summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/prefix_tree.cc112
-rw-r--r--test/prefix_tree.cc26
2 files changed, 123 insertions, 15 deletions
diff --git a/src/prefix_tree.cc b/src/prefix_tree.cc
index f16df22..56466e8 100644
--- a/src/prefix_tree.cc
+++ b/src/prefix_tree.cc
@@ -42,6 +42,17 @@ class BuilderImpl : public Builder {
[[nodiscard]]
std::optional<std::string> build() const override {
+ auto ret = build(1);
+ if (!ret.has_value())
+ ret = build(2);
+ return ret;
+ }
+
+ private:
+ [[nodiscard]]
+ std::optional<std::string> build(uint8_t size) const {
+ assert(size > 0);
+
std::string tree;
std::string strings;
@@ -49,20 +60,19 @@ class BuilderImpl : public Builder {
for (auto const& str : strings_)
tmp.emplace(str);
- if (!write_tree(tmp, tree, strings))
+ if (!write_tree(tmp, tree, strings, size))
return std::nullopt;
std::string header;
- if (tree.size() > 0xffff)
+ write_u8(header, size);
+ if (!write_size(header, size, tree.size()))
return std::nullopt;
- write_u16(header, tree.size());
return header + tree + strings;
}
- private:
bool write_tree(std::set<std::string_view> const& input, std::string& tree,
- std::string& strings) const {
+ std::string& strings, uint8_t size) const {
std::map<char, std::set<std::string_view>> buckets;
bool match = false;
for (auto& str : input) {
@@ -76,8 +86,8 @@ class BuilderImpl : public Builder {
write_u8(tree, buckets.size() + (match ? 1 : 0));
if (match) {
write_u8(tree, 0);
- write_u16(tree, 0);
- write_u16(tree, 0);
+ write_size(tree, size, 0);
+ write_size(tree, size, 0);
}
std::string extra;
for (auto& pair : buckets) {
@@ -95,23 +105,23 @@ class BuilderImpl : public Builder {
}
write_u8(tree, 1 + str.size());
- if (strings.size() > 0xffff)
+ if (!write_size(tree, size, strings.size()))
return false;
- write_u16(tree, strings.size());
strings.push_back(pair.first);
strings.append(str);
if (extra.size() > 0xffff)
return false;
- write_u16(tree, extra.size());
+ if (!write_size(tree, size, extra.size()))
+ return false;
if (str.empty()) {
- if (!write_tree(pair.second, extra, strings))
+ if (!write_tree(pair.second, extra, strings, size))
return false;
} else {
std::set<std::string_view> tmp;
for (auto& str2 : pair.second) {
tmp.emplace(str2.substr(str.size()));
}
- if (!write_tree(tmp, extra, strings))
+ if (!write_tree(tmp, extra, strings, size))
return false;
}
}
@@ -120,12 +130,27 @@ class BuilderImpl : public Builder {
return true;
}
+ static bool write_size(std::string& str, uint8_t size, size_t value) {
+ if (size == 1) {
+ if (value > 0xff)
+ return false;
+ write_u8(str, value);
+ return true;
+ }
+ if (size == 2) {
+ if (value > 0xffff)
+ return false;
+ write_u16(str, value);
+ return true;
+ }
+ assert(false);
+ return false;
+ }
+
std::set<std::string> strings_;
};
-} // namespace
-
-std::optional<size_t> lookup(std::string_view tree, std::string_view str) {
+std::optional<size_t> lookup16(std::string_view tree, std::string_view str) {
size_t base_str = 2 + get_u16(tree, 0);
size_t node = 2;
std::optional<size_t> match;
@@ -168,6 +193,63 @@ std::optional<size_t> lookup(std::string_view tree, std::string_view str) {
return match;
}
+std::optional<size_t> lookup8(std::string_view tree, std::string_view str) {
+ size_t base_str = 1 + get_u8(tree, 0);
+ size_t node = 1;
+ std::optional<size_t> match;
+ std::optional<size_t> earlier_match;
+
+ while (node < base_str && !str.empty()) {
+ auto children = get_u8(tree, node);
+
+ if (children == 0) {
+ // Leaf
+ return match;
+ }
+
+ size_t child_node = node + 1;
+ size_t child_end = child_node + (static_cast<size_t>(children) * 3);
+ for (; child_node < child_end; child_node += 3) {
+ uint8_t len = get_u8(tree, child_node);
+ uint8_t offset = get_u8(tree, child_node + 1);
+
+ if (str.starts_with(tree.substr(base_str + offset, len))) {
+ // Match but not a leaf, always first in the list of children
+ if (len == 0) {
+ earlier_match = match;
+ continue;
+ }
+ match = match.value_or(0) + len;
+ str = str.substr(len);
+ auto jump = get_u8(tree, child_node + 2);
+ node = child_end + jump;
+ break;
+ }
+ }
+
+ if (child_node == child_end)
+ return earlier_match;
+ }
+
+ if (node == base_str)
+ return earlier_match;
+ return match;
+}
+
+} // namespace
+
+std::optional<size_t> lookup(std::string_view tree, std::string_view str) {
+ auto size = get_u8(tree, 0);
+ if (size == 1) {
+ return lookup8(tree.substr(1), str);
+ }
+ if (size == 2) {
+ return lookup16(tree.substr(1), str);
+ }
+ assert(false);
+ return std::nullopt;
+}
+
std::unique_ptr<Builder> builder() { return std::make_unique<BuilderImpl>(); }
} // namespace prefix_tree
diff --git a/test/prefix_tree.cc b/test/prefix_tree.cc
index 6c00adb..86c8990 100644
--- a/test/prefix_tree.cc
+++ b/test/prefix_tree.cc
@@ -1,5 +1,7 @@
#include "prefix_tree.hh"
+#include "str.hh"
+
#include <gtest/gtest.h>
TEST(prefix_tree, empty) {
@@ -45,3 +47,27 @@ TEST(prefix_tree, sanity) {
ASSERT_TRUE(ret.has_value());
EXPECT_EQ(3, ret.value());
}
+
+TEST(prefix_tree, many_and_long) {
+ auto builder = prefix_tree::builder();
+ for (auto str : str::split(
+ "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do "
+ "eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut "
+ "enim ad minim veniam, quis nostrud exercitation ullamco laboris "
+ "nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in "
+ "reprehenderit in voluptate velit esse cillum dolore eu fugiat "
+ "nulla pariatur. Excepteur sint occaecat cupidatat non proident, "
+ "sunt in culpa qui officia deserunt mollit anim id est laborum.")) {
+ builder->add(str);
+ }
+ auto tree = builder->build();
+ ASSERT_TRUE(tree.has_value());
+ auto ret = prefix_tree::lookup(tree.value(), "");
+ EXPECT_FALSE(ret.has_value());
+ ret = prefix_tree::lookup(tree.value(), "Lorem");
+ ASSERT_TRUE(ret.has_value());
+ EXPECT_EQ(5, ret.value());
+ ret = prefix_tree::lookup(tree.value(), "cillum");
+ ASSERT_TRUE(ret.has_value());
+ EXPECT_EQ(6, ret.value());
+}