summaryrefslogtreecommitdiff
path: root/src/prefix_tree.cc
diff options
context:
space:
mode:
authorJoel Klinghed <the_jk@spawned.biz>2025-09-27 18:25:10 +0200
committerJoel Klinghed <the_jk@spawned.biz>2025-09-27 18:49:23 +0200
commit2f13baa843bd1fb5db6630a2823681ffaff9fb11 (patch)
treea8c619cfa52ceb3b31b125b11e6bb15f7e268ed1 /src/prefix_tree.cc
parentce271f82f16ee89a18e7bfc9ed8eab7cbd6f37bc (diff)
Add simple prefix_tree
Will be used by tokenizer for short lists of strings
Diffstat (limited to 'src/prefix_tree.cc')
-rw-r--r--src/prefix_tree.cc173
1 files changed, 173 insertions, 0 deletions
diff --git a/src/prefix_tree.cc b/src/prefix_tree.cc
new file mode 100644
index 0000000..f16df22
--- /dev/null
+++ b/src/prefix_tree.cc
@@ -0,0 +1,173 @@
+#include "prefix_tree.hh"
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <optional>
+#include <set>
+#include <string>
+#include <string_view>
+
+namespace prefix_tree {
+
+namespace {
+
+// Returns the byte at position |index| of |str| as an unsigned 8-bit value.
+// The access is unchecked, so |index| must be less than str.size().
+[[nodiscard]]
+inline uint8_t get_u8(std::string_view str, size_t index) {
+  const auto byte = static_cast<unsigned char>(str[index]);
+  return byte;
+}
+
+// Decodes the two bytes at |index| and |index| + 1 of |str| as an unsigned
+// 16-bit value stored big-endian (the first byte is the high byte).
+// The access is unchecked, so |index| + 1 must be less than str.size().
+[[nodiscard]]
+inline uint16_t get_u16(std::string_view str, size_t index) {
+  const auto high = static_cast<uint8_t>(str[index] & 0xff);
+  const auto low = static_cast<uint8_t>(str[index + 1] & 0xff);
+  return static_cast<uint16_t>((high << 8) | low);
+}
+
+// Appends |value| to the end of |str| as one byte.
+inline void write_u8(std::string& str, uint8_t value) {
+  const char byte = static_cast<char>(value);
+  str += byte;
+}
+
+// Appends |value| to the end of |str| as two bytes in big-endian order
+// (high byte first, then low byte).
+inline void write_u16(std::string& str, uint16_t value) {
+  const auto high = static_cast<uint8_t>(value >> 8);
+  const auto low = static_cast<uint8_t>(value & 0xff);
+  str.push_back(static_cast<char>(high));
+  str.push_back(static_cast<char>(low));
+}
+
+// Builder implementation that collects strings and serializes them into a
+// compact prefix tree.
+//
+// Serialized layout produced by build():
+//   u16   size in bytes of the node area (big-endian)
+//   ...   node area, starting with the root node
+//   ...   string area holding the labels referenced by the nodes
+//
+// Each node in the node area is:
+//   u8    number of child entries
+//   then one 5-byte entry per child:
+//     u8    label length in bytes; 0 marks "one of the added strings ends
+//           at this node" and, when present, is always the first entry
+//     u16   offset of the label inside the string area
+//     u16   offset of the child's node, counted from the end of this node's
+//           entries
+//   followed by the nodes of all its children.
+class BuilderImpl : public Builder {
+ public:
+  BuilderImpl() = default;
+
+  // Registers |str| for inclusion in the tree. Adding the same string more
+  // than once has no further effect because the strings are kept in a set.
+  void add(std::string_view str) override { strings_.emplace(str); }
+
+  // Serializes all added strings into a single buffer.
+  // Returns std::nullopt when the data cannot be represented, that is when
+  // an offset or a count does not fit the field reserved for it.
+  [[nodiscard]]
+  std::optional<std::string> build() const override {
+    std::string tree;
+    std::string strings;
+
+    // write_tree() works on views so that it can cheaply strip prefixes.
+    std::set<std::string_view> tmp;
+    for (auto const& str : strings_)
+      tmp.emplace(str);
+
+    if (!write_tree(tmp, tree, strings))
+      return std::nullopt;
+
+    // The node area size is stored in a u16 header.
+    if (tree.size() > 0xffff)
+      return std::nullopt;
+    std::string header;
+    write_u16(header, static_cast<uint16_t>(tree.size()));
+
+    return header + tree + strings;
+  }
+
+ private:
+  // Largest label length that fits the u8 length field of a child entry.
+  static constexpr size_t kMaxLabelSize = 0xff;
+  // Largest number of entries that fits the u8 count field of a node.
+  static constexpr size_t kMaxChildren = 0xff;
+
+  // Writes the node for |input| to |tree|, followed by the nodes of all its
+  // children, and appends the labels it needs to |strings|.
+  // Returns false if a count or offset does not fit its field.
+  bool write_tree(std::set<std::string_view> const& input, std::string& tree,
+                  std::string& strings) const {
+    // Group the strings by first byte; each group becomes one child. An
+    // empty string means that a string ends exactly at this node.
+    std::map<char, std::set<std::string_view>> buckets;
+    bool match = false;
+    for (auto const& str : input) {
+      if (str.empty()) {
+        match = true;
+        continue;
+      }
+      buckets[str.front()].emplace(str.substr(1));
+    }
+
+    // The entry count is stored in a single byte; refuse to wrap around.
+    size_t const children = buckets.size() + (match ? 1 : 0);
+    if (children > kMaxChildren)
+      return false;
+    write_u8(tree, static_cast<uint8_t>(children));
+    if (match) {
+      // Marker entry: zero length, no label and no child node.
+      write_u8(tree, 0);
+      write_u16(tree, 0);
+      write_u16(tree, 0);
+    }
+
+    // Child nodes are collected separately because they must be placed
+    // after all entries of this node.
+    std::string extra;
+    for (auto const& pair : buckets) {
+      // Find the prefix shared by every remainder in the bucket so that it
+      // can be folded into this child's label.
+      auto it = pair.second.begin();
+      auto str = *it;
+      for (++it; it != pair.second.end(); ++it) {
+        size_t i = 0;
+        while (i < str.size() && i < it->size() && str[i] == it->at(i))
+          ++i;
+        if (i == 0) {
+          str = "";
+          break;
+        }
+        str = str.substr(0, i);
+      }
+
+      // The label is the bucket byte plus the shared prefix and its length
+      // is stored in one byte. Cap the prefix so the length cannot wrap;
+      // whatever is cut off here is emitted by the child node instead.
+      if (str.size() > kMaxLabelSize - 1)
+        str = str.substr(0, kMaxLabelSize - 1);
+
+      write_u8(tree, static_cast<uint8_t>(1 + str.size()));
+      if (strings.size() > 0xffff)
+        return false;
+      write_u16(tree, static_cast<uint16_t>(strings.size()));
+      strings.push_back(pair.first);
+      strings.append(str);
+      if (extra.size() > 0xffff)
+        return false;
+      write_u16(tree, static_cast<uint16_t>(extra.size()));
+      if (str.empty()) {
+        // Nothing to strip, the bucket can be used as it is.
+        if (!write_tree(pair.second, extra, strings))
+          return false;
+      } else {
+        // Strip the part that was folded into the label.
+        std::set<std::string_view> tmp;
+        for (auto const& str2 : pair.second) {
+          tmp.emplace(str2.substr(str.size()));
+        }
+        if (!write_tree(tmp, extra, strings))
+          return false;
+      }
+    }
+
+    tree.append(extra);
+    return true;
+  }
+
+  std::set<std::string> strings_;
+};
+
+} // namespace
+
+// Finds the longest string stored in |tree| that |str| starts with.
+//
+// |tree| is expected to be a buffer produced by Builder::build(); apart from
+// the size checks at the top, its contents are not validated.
+//
+// Returns the length in bytes of the longest stored string that is a prefix
+// of |str|, or std::nullopt if no stored string is a prefix of |str|. A
+// stored empty string is never reported as a match.
+std::optional<size_t> lookup(std::string_view tree, std::string_view str) {
+  // Reject buffers that cannot hold the header or the node area it declares.
+  if (tree.size() < 2)
+    return std::nullopt;
+  size_t const base_str = 2 + get_u16(tree, 0);
+  if (base_str > tree.size())
+    return std::nullopt;
+
+  size_t node = 2;
+  // Number of bytes of the input consumed on the way to |node|.
+  std::optional<size_t> match;
+  // Value of |match| at the last node where a stored string ended.
+  std::optional<size_t> earlier_match;
+
+  while (node < base_str && !str.empty()) {
+    auto const children = get_u8(tree, node);
+
+    if (children == 0) {
+      // Node without entries; the builder only emits this for an empty tree.
+      return match;
+    }
+
+    size_t child_node = node + 1;
+    size_t const child_end = child_node + (static_cast<size_t>(children) * 5);
+    for (; child_node < child_end; child_node += 5) {
+      uint8_t const len = get_u8(tree, child_node);
+      if (len == 0) {
+        // Marker entry, always first in the list: a stored string ends at
+        // this node. Remember it in case no longer string matches.
+        earlier_match = match;
+        continue;
+      }
+
+      uint16_t const offset = get_u16(tree, child_node + 1);
+      if (str.starts_with(tree.substr(base_str + offset, len))) {
+        match = match.value_or(0) + len;
+        str = str.substr(len);
+        auto const jump = get_u16(tree, child_node + 3);
+        node = child_end + jump;
+        break;
+      }
+    }
+
+    // No entry matched the remaining input.
+    if (child_node == child_end)
+      return earlier_match;
+  }
+
+  // The input was consumed completely. The bytes consumed so far only form a
+  // stored string if this node carries the marker entry; otherwise the input
+  // is merely a prefix of longer strings and the last real match applies.
+  if (node < base_str && get_u8(tree, node) != 0 &&
+      get_u8(tree, node + 1) == 0)
+    return match;
+  return earlier_match;
+}
+
+// Creates a new builder with no strings added yet.
+std::unique_ptr<Builder> builder() {
+  std::unique_ptr<Builder> instance = std::make_unique<BuilderImpl>();
+  return instance;
+}
+
+} // namespace prefix_tree