// gen_ast: reads a grammar file and generates C++ AST classes (a header and
// its implementation) with one Node subclass per named grammar element.

#include "args.hh"
#include "errors.hh"
#include "grammar.hh"
#include "io.hh"
#include "prefix_tree.hh"

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <ostream>
#include <set>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

namespace {

// Character classes the grammar may reference. The numeric values are the
// positions of the corresponding names in kCharacterClassNames.
enum class CharacterClass : uint8_t {
  kIdentifier = 0,
  kLiteral = 1,
};

// Names of the character classes, handed to grammar::load().
// NOTE(review): the element type was reconstructed; it must match what
// grammar::load expects — confirm against grammar.hh.
std::vector<std::string_view> const kCharacterClassNames(
    {"Identifier", "Literal"});

// Turns a file name into an include-guard macro name: ASCII lower-case
// letters are upper-cased, upper-case letters, digits and '_' are kept and
// every other character becomes '_'. "ast.hh" -> "AST_HH".
std::string make_define(std::string_view filename) {
  std::string ret;
  ret.reserve(filename.size());
  for (char c : filename) {
    if ((c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
      ret.push_back(c);
    } else if (c >= 'a' && c <= 'z') {
      ret.push_back(static_cast<char>(c & ~0x20));
    } else {
      ret.push_back('_');
    }
  }
  return ret;
}

// Converts a CamelCase element name into a snake_case accessor name:
// "ClassBody" -> "class_body". Only ASCII upper-case letters are affected.
std::string field_name(std::string_view name) {
  std::string ret;
  ret.reserve(name.size());
  for (char c : name) {
    if (c >= 'A' && c <= 'Z') {
      if (!ret.empty()) ret.push_back('_');
      ret.push_back(static_cast<char>(c | 0x20));
    } else {
      ret.push_back(c);
    }
  }
  return ret;
}

// Combines the multiplicity of an enclosing anonymous group (`parent`) with
// the multiplicity of a symbol inside it (`child`). For example, a required
// symbol inside an optional group becomes optional.
grammar::Symbol::Optional merge(grammar::Symbol::Optional parent,
                                grammar::Symbol::Optional child) {
  // An excluded symbol never yields a child, whatever surrounds it. Without
  // this check a ZeroOrMore parent would turn it into a repeated field, while
  // Required/ZeroOrOne parents already kept it excluded.
  if (child == grammar::Symbol::Optional::kExcluded) return child;
  switch (parent) {
    case grammar::Symbol::Optional::kRequired:
      break;
    case grammar::Symbol::Optional::kZeroOrOne:
      if (child == grammar::Symbol::Optional::kRequired) return parent;
      break;
    case grammar::Symbol::Optional::kZeroOrMore:
    case grammar::Symbol::Optional::kExcluded:
      return parent;
  }
  return child;
}

// Callback for visit_named_symbols: the symbol and its effective
// multiplicity.
using SymbolVisitor =
    std::function<void(grammar::Symbol const&, grammar::Symbol::Optional)>;

// Calls `visitor` for every symbol of `definition`. Anonymous (unnamed)
// non-terminals are not reported themselves; their definitions are walked
// instead, so their symbols appear as if written inline. `optional` is the
// multiplicity of the enclosing group and is merged into each symbol's own.
void visit_named_symbols(grammar::Definition const& definition,
                         grammar::Symbol::Optional optional,
                         SymbolVisitor const& visitor) {
  for (auto const& symbol : definition.symbols) {
    auto const symbol_optional = merge(optional, symbol.optional);
    switch (symbol.type) {
      case grammar::Symbol::Type::kNonTerminal:
        if (symbol.element->name.empty()) {
          for (auto const& element_definition : symbol.element->definitions) {
            visit_named_symbols(element_definition, symbol_optional, visitor);
          }
        } else {
          visitor(symbol, symbol_optional);
        }
        break;
      case grammar::Symbol::Type::kTerminal:
      case grammar::Symbol::Type::kCharacterClass:
        visitor(symbol, symbol_optional);
        break;
    }
  }
}

// Same as above, starting with a required enclosing multiplicity.
void visit_named_symbols(grammar::Definition const& definition,
                         SymbolVisitor const& visitor) {
  visit_named_symbols(definition, grammar::Symbol::Optional::kRequired,
                      visitor);
}

// Kind of a member of a generated node class.
enum class NodeChildType : uint8_t {
  kElement,     // another generated node class
  kIdentifier,  // an Identifier leaf, exposed as std::string_view
};

// Multiplicity of a member of a generated node class.
enum class Optional : uint8_t {
  kRequired = 0,
  kZeroOrOne,
  kZeroOrMore,
};

// One member (child slot) of a generated node class.
struct NodeChild {
  NodeChildType child_type;
  std::string name;       // accessor name, e.g. "class_body"
  std::string type_name;  // C++ type of the child
  Optional optional;

  NodeChild(NodeChildType child_type, std::string name, std::string type_name,
            Optional optional)
      : child_type(child_type),
        name(std::move(name)),
        type_name(std::move(type_name)),
        optional(optional) {}
};

// A generated node class. The generated code keeps children[i] in
// Node::children_[i].
struct Node {
  std::vector<NodeChild> children;
};

// Collects the named elements reachable from a grammar's root and writes
// the corresponding AST header and source.
class Generator {
 public:
  // Writes the AST declarations to `header_name` and their implementation to
  // `source_name`. Returns false (after printing a message to std::cerr) if
  // either file cannot be opened or written.
  bool generate(std::string_view header_name, std::string_view source_name,
                grammar::Grammar& grammar);

 private:
  void declare_nodes(std::ostream& out) const;
  void declare_node_types(std::ostream& out) const;
  void implement_nodes(std::ostream& out) const;
  void find_nodes(grammar::Element const& element);

  // Node classes keyed by element name. std::map keeps the output in name
  // order, so it is deterministic.
  std::map<std::string, Node> nodes_;
};

bool Generator::generate(std::string_view header_name,
                         std::string_view source_name,
                         grammar::Grammar& grammar) {
  std::ofstream header{std::string(header_name)};
  if (!header.is_open()) {
    std::cerr << "Unable to open " << header_name << " for writing\n";
    return false;
  }
  std::ofstream source{std::string(source_name)};
  if (!source.is_open()) {
    std::cerr << "Unable to open " << source_name << " for writing\n";
    return false;
  }

  find_nodes(grammar.root());

  auto header_guard = make_define(header_name);
  header << "#ifndef " << header_guard << "\n"
         << "#define " << header_guard << "\n"
         << "\n"
         << "#include <cstdint>\n"
         << "#include <memory>\n"
         << "#include <optional>\n"
         << "#include <string_view>\n"
         << "#include <vector>\n"
         << "\n"
         << "namespace java {\n"
         << "namespace ast {\n"
         << "\n"
         << "class Node {\n"
         << " public:\n"
         << "  virtual ~Node() = default;\n"
         << "  Node(Node const&) = delete;\n"
         << "  Node& operator=(Node const&) = delete;\n"
         << "\n"
         << "  enum class Type : uint16_t {\n"
         << "    kIdentifier,\n";
  declare_node_types(header);
  header << "  };\n"
         << "\n"
         << "  Type const type;\n"
         << "\n"
         << "  virtual std::vector<std::unique_ptr<Node>> const& children() "
            "const;\n"
         << "\n"
         << " protected:\n"
         << "  Node(Type type, std::vector<std::unique_ptr<Node>> children);\n"
         << "  std::vector<std::unique_ptr<Node>> children_;\n"
         << "};\n"
         << "\n"
         << "class Identifier : public Node {\n"
         << " public:\n"
         << "  explicit Identifier(std::string_view value);\n"
         << "\n"
         << "  std::string_view value;\n"
         << "};\n"
         << "\n";
  declare_nodes(header);
  header << "\n"
         << "}  // namespace ast\n"
         << "}  // namespace java\n"
         << "\n"
         << "#endif  // " << header_guard << "\n";

  source << "#include \"" << header_name << "\"\n"
         << "\n"
         << "#include <utility>\n"
         << "\n"
         << "namespace java {\n"
         << "namespace ast {\n"
         << "\n"
         << "Node::Node(Type type, std::vector<std::unique_ptr<Node>> "
            "children)\n"
         << "    : type(type), children_(std::move(children)) {}\n"
         << "\n"
         << "std::vector<std::unique_ptr<Node>> const& Node::children() "
            "const {\n"
         << "  return children_;\n"
         << "}\n"
         << "\n"
         << "Identifier::Identifier(std::string_view value)\n"
         << "    : Node(Type::kIdentifier, {}), value(value) {}\n"
         << "\n";
  implement_nodes(source);
  source << "\n"
         << "}  // namespace ast\n"
         << "}  // namespace java\n";

  // close() sets failbit if flushing fails; report that instead of silently
  // leaving truncated output behind.
  header.close();
  source.close();
  if (header.fail() || source.fail()) {
    std::cerr << "Error writing " << header_name << " or " << source_name
              << '\n';
    return false;
  }
  return true;
}

// Registers `element` as a node class if it is named, then recursively
// registers every named element it references. An anonymous element is not
// a node itself; only its definitions are searched for named elements.
void Generator::find_nodes(grammar::Element const& element) {
  if (element.name.empty()) {
    for (auto const& definition : element.definitions) {
      for (auto const& symbol : definition.symbols) {
        if (symbol.type == grammar::Symbol::Type::kNonTerminal) {
          find_nodes(*symbol.element);
        }
      }
    }
    return;
  }

  auto [node_it, inserted] = nodes_.emplace(element.name, Node());
  if (!inserted) return;  // Already queued; this also stops recursion cycles.

  // All keyed by the child's name as it appears in the grammar.
  //   elements: the child's element, or nullptr for identifier leaves.
  //   lines: number of element's definitions in which the child occurs.
  //   optional: the widest multiplicity seen for the child (absent means
  //             required).
  std::map<std::string, grammar::Element const*, std::less<>> elements;
  std::map<std::string, size_t, std::less<>> lines;
  std::map<std::string, Optional, std::less<>> optional;
  for (auto const& definition : element.definitions) {
    // Identifier leaves are numbered per definition: identifier,
    // identifier1, ...
    size_t identifiers = 0;
    // Names already counted for this definition, so a child that occurs
    // twice in one alternative is still counted as present in only one.
    std::set<std::string, std::less<>> seen;
    visit_named_symbols(
        definition, [&](grammar::Symbol const& symbol,
                        grammar::Symbol::Optional symbol_optional) {
          std::string name;
          grammar::Element const* child = nullptr;
          switch (symbol.type) {
            case grammar::Symbol::Type::kNonTerminal:
              child = symbol.element;
              name = child->name;
              break;
            case grammar::Symbol::Type::kTerminal:
              return;
            case grammar::Symbol::Type::kCharacterClass:
              if (symbol.char_class !=
                  static_cast<decltype(symbol.char_class)>(
                      CharacterClass::kIdentifier)) {
                return;
              }
              name = "identifier";
              if (identifiers > 0) name += std::to_string(identifiers);
              ++identifiers;
              break;
          }

          switch (symbol_optional) {
            case grammar::Symbol::Optional::kRequired:
              break;
            case grammar::Symbol::Optional::kZeroOrOne: {
              auto& current = optional[name];  // defaults to kRequired
              if (current == Optional::kRequired) {
                current = Optional::kZeroOrOne;
              }
              break;
            }
            case grammar::Symbol::Optional::kZeroOrMore:
              optional[name] = Optional::kZeroOrMore;
              break;
            case grammar::Symbol::Optional::kExcluded:
              // Don't include in elements
              return;
          }
          elements[name] = child;
          if (seen.insert(name).second) ++lines[name];
        });
  }

  auto& node = node_it->second;
  for (auto const& [name, child] : elements) {
    auto child_optional = Optional::kRequired;
    if (auto it = optional.find(name); it != optional.end()) {
      child_optional = it->second;
    }
    // A child missing from some alternatives may be absent.
    if (child_optional == Optional::kRequired &&
        lines[name] < element.definitions.size()) {
      child_optional = Optional::kZeroOrOne;
    }
    if (child) {
      node.children.emplace_back(NodeChildType::kElement, field_name(name),
                                 name, child_optional);
      find_nodes(*child);
    } else {
      node.children.emplace_back(NodeChildType::kIdentifier, name,
                                 "std::string_view", child_optional);
    }
  }
}

// Writes one class declaration per collected node, followed by inline
// definitions of the element accessors. Classes are emitted in name order,
// so a child type not yet declared is written as "class Foo" (an elaborated
// type specifier), which forward-declares it.
// NOTE(review): the declared move constructor is never defined by
// implement_nodes(), and nothing here sizes children_ before the generated
// code indexes it — presumably still to be written.
void Generator::declare_nodes(std::ostream& out) const {
  std::set<std::string, std::less<>> declared;
  for (auto const& [name, node] : nodes_) {
    declared.insert(name);
    out << "class " << name << " : public Node {\n"
        << " public:\n"
        << "  " << name << "(" << name << "&&);\n";
    for (size_t i = 0; i < node.children.size(); ++i) {
      auto const& child = node.children[i];
      switch (child.child_type) {
        case NodeChildType::kElement: {
          std::string type_name = child.type_name;
          if (!declared.contains(type_name)) {
            type_name = "class " + type_name;
          }
          switch (child.optional) {
            case Optional::kRequired:
              out << "  " << type_name << "& " << child.name << "() const;\n"
                  << "  void set_" << child.name << "(" << type_name << "&& "
                  << child.name << ");\n";
              break;
            case Optional::kZeroOrOne:
              out << "  " << type_name << "* " << child.name << "() const;\n"
                  << "  void set_" << child.name << "(std::unique_ptr<"
                  << type_name << "> " << child.name << ");\n";
              break;
            case Optional::kZeroOrMore:
              // std::vector<T&> is ill-formed (a vector cannot hold
              // references), so expose non-owning pointers instead.
              out << "  std::vector<" << type_name << "*> " << child.name
                  << "() const;\n";
              break;
          }
          break;
        }
        case NodeChildType::kIdentifier:
          switch (child.optional) {
            case Optional::kRequired:
              out << "  std::string_view " << child.name << "() const {\n"
                  << "    return static_cast<Identifier const*>(children_["
                  << i << "].get())->value;\n"
                  << "  }\n"
                  << "  void set_" << child.name << "(std::string_view "
                  << child.name << ");\n";
              break;
            case Optional::kZeroOrOne:
              out << "  std::optional<std::string_view> " << child.name
                  << "() const {\n"
                  << "    auto const* ptr = static_cast<Identifier const*>("
                  << "children_[" << i << "].get());\n"
                  << "    if (ptr) return ptr->value;\n"
                  << "    return std::nullopt;\n"
                  << "  }\n"
                  << "  void set_" << child.name
                  << "(std::optional<std::string_view> " << child.name
                  << ");\n";
              break;
            case Optional::kZeroOrMore:
              out << "  std::vector<std::string_view> " << child.name
                  << "() const;\n";
              break;
          }
          break;
      }
    }
    out << "};\n"
        << "\n";
  }

  // Element accessors are defined after all classes, where every child type
  // is complete and the static_cast from Node* is valid.
  for (auto const& [name, node] : nodes_) {
    for (size_t i = 0; i < node.children.size(); ++i) {
      auto const& child = node.children[i];
      if (child.child_type != NodeChildType::kElement) continue;
      switch (child.optional) {
        case Optional::kRequired:
          out << "inline " << child.type_name << "& " << name
              << "::" << child.name << "() const {\n"
              << "  return *static_cast<" << child.type_name << "*>("
              << "children_[" << i << "].get());\n"
              << "}\n"
              << "\n";
          break;
        case Optional::kZeroOrOne:
          out << "inline " << child.type_name << "* " << name
              << "::" << child.name << "() const {\n"
              << "  return static_cast<" << child.type_name << "*>("
              << "children_[" << i << "].get());\n"
              << "}\n"
              << "\n";
          break;
        case Optional::kZeroOrMore:
          break;
      }
    }
  }
}

// Writes one Node::Type enumerator ("k<Name>,") per collected node.
void Generator::declare_node_types(std::ostream& out) const {
  for (auto const& entry : nodes_) {
    out << "    k" << entry.first << ",\n";
  }
}

// Writes the out-of-line setter definitions; each setter stores its value in
// the child's slot, children_[i].
// NOTE(review): kZeroOrMore children get no setter or accessor definition.
void Generator::implement_nodes(std::ostream& out) const {
  for (auto const& [name, node] : nodes_) {
    for (size_t i = 0; i < node.children.size(); ++i) {
      auto const& child = node.children[i];
      switch (child.child_type) {
        case NodeChildType::kElement:
          switch (child.optional) {
            case Optional::kRequired:
              out << "void " << name << "::set_" << child.name << "("
                  << child.type_name << "&& " << child.name << ") {\n"
                  << "  children_[" << i << "] = std::make_unique<"
                  << child.type_name << ">(std::move(" << child.name
                  << "));\n"
                  << "}\n";
              break;
            case Optional::kZeroOrOne:
              out << "void " << name << "::set_" << child.name
                  << "(std::unique_ptr<" << child.type_name << "> "
                  << child.name << ") {\n"
                  << "  children_[" << i << "] = std::move(" << child.name
                  << ");\n"
                  << "}\n";
              break;
            case Optional::kZeroOrMore:
              break;
          }
          break;
        case NodeChildType::kIdentifier:
          switch (child.optional) {
            case Optional::kRequired:
              out << "void " << name << "::set_" << child.name
                  << "(std::string_view " << child.name << ") {\n"
                  << "  children_[" << i
                  << "] = std::make_unique<Identifier>(" << child.name
                  << ");\n"
                  << "}\n";
              break;
            case Optional::kZeroOrOne:
              out << "void " << name << "::set_" << child.name
                  << "(std::optional<std::string_view> " << child.name
                  << ") {\n"
                  << "  if (" << child.name << ".has_value()) {\n"
                  << "    children_[" << i
                  << "] = std::make_unique<Identifier>(*" << child.name
                  << ");\n"
                  << "  } else {\n"
                  << "    children_[" << i << "].reset();\n"
                  << "  }\n"
                  << "}\n";
              break;
            case Optional::kZeroOrMore:
              break;
          }
          break;
      }
    }
    out << "\n";
  }
}

}  // namespace

// Usage: gen_ast [OPTIONS...] syntax.grammar OUTPUT.hh OUTPUT.cc
// Exits with 0 on success and 1 on any argument, grammar or output error.
int main(int argc, char** argv) {
  auto args = Args::create();
  auto opt_help = args->option('h', "help", "display this text and exit");
  // NOTE(review): element type reconstructed; it must match Args::run.
  std::vector<std::string_view> arguments;
  if (!args->run(argc, argv, &arguments)) {
    args->print_error(std::cerr);
    std::cerr << "Try `gen_ast --help` for usage\n";
    return 1;
  }
  if (opt_help->is_set()) {
    std::cout << "Usage: `gen_ast [OPTIONS...] syntax.grammar"
              << " OUTPUT.hh OUTPUT.cc`\n"
              << "Generates an AST for grammar.\n"
              << "\n";
    args->print_help(std::cout);
    return 0;
  }
  if (arguments.size() != 3) {
    std::cerr << "Expecting three arguments. No more, no less.\n"
              << "Try `gen_ast --help` for usage\n";
    return 1;
  }
  auto filename = std::string(arguments[0]);
  auto reader = io::open(filename);
  if (!reader.has_value()) {
    std::cerr << "Unable to open " << filename << '\n';
    return 1;
  }
  auto errors = src::file_errors(std::move(filename));
  auto grammar = grammar::load(std::move(reader.value()), kCharacterClassNames,
                               *errors);
  if (!grammar || errors->errors() > 0) return 1;
  Generator generator;
  if (!generator.generate(arguments[1], arguments[2], *grammar)) return 1;
  return 0;
}