#include "common.hh"
#include "http_protocol.hh"
#include "logger.hh"
#include "strutil.hh"
#include "transport_base.hh"
#include "transport_http.hh"

// NOTE(review): in this copy of the file some template argument lists appear
// to be missing: the std::shared_ptr parameters of HttpTransport's constructor
// and HttpFactory::create, the std::optional return type of
// WrapRequest::header_one, and the std::unique_ptr return types of
// HttpFactory::create and create_transport_factory_http. As written, those
// declarations cannot compile. The intended arguments cannot be recovered from
// this file alone; confirm them against version control.

namespace {

// Returns true when `a` and `b` are equal, ignoring ASCII letter case.
// HTTP field names are case-insensitive (RFC 7230, section 3.2). A check for
// a specific header name must therefore not depend on how it was spelled.
bool ascii_iequals(std::string_view a, std::string_view b) {
  if (a.size() != b.size()) return false;
  for (std::string_view::size_type i = 0; i < a.size(); ++i) {
    char ca = a[i];
    char cb = b[i];
    if (ca >= 'A' && ca <= 'Z') ca = static_cast<char>(ca - 'A' + 'a');
    if (cb >= 'A' && cb <= 'Z') cb = static_cast<char>(cb - 'A' + 'a');
    if (ca != cb) return false;
  }
  return true;
}

// HTTP/1.x transport built on TransportBase.
// - Parses requests out of each client's input buffer.
// - Forwards them as UrlRequest objects via TransportBase::client_request().
// - Serializes response status lines and headers.
// Per-client HTTP state lives in extra_, indexed by Client::index_.
class HttpTransport : public TransportBase {
 public:
  HttpTransport(std::shared_ptr logger, std::shared_ptr looper,
                std::shared_ptr runner, Transport::Handler* handler)
      : TransportBase(logger, looper, runner, handler) {}

  // Runs TransportBase::setup(). On success, allocates one Extra entry per
  // client slot (clients()). Returns false if the base setup fails.
  bool setup(Logger* logger, Config const* config) override {
    if (!TransportBase::setup(logger, config)) return false;
    extra_.resize(clients());
    return true;
  }

  // Tries to parse one request from client->in_.
  //  - Incomplete request: keep reading (expect_in_ = 1). If the input side
  //    is closed, or the connection is marked for closing, and the output
  //    buffer is already empty, abort the client and return false.
  //  - Malformed request (!good()): fatal 400 response.
  //  - Unsupported protocol/version: fatal 505 response.
  //  - Otherwise: record the keep-alive decision and the protocol version,
  //    pause reading, and hand the request to TransportBase::client_request().
  bool client_handle(Client* client) override {
    auto& extra = extra_[client->index_];
    auto req = HttpRequest::parse(client->in_.get());
    if (!req) {
      client->expect_in_ = 1;  // Don't know how big the request will be
      if (client->in_closed_ || extra.close_connection_) {
        if (client->out_->empty()) {
          client_abort(client);
          return false;
        }
        // Wait for output to be sent (or timeout, whichever is first).
      }
      return true;
    }
    if (!req->good()) return client_fatal_response(client, 400);
    if (!supported_request(req.get())) {
      return client_fatal_response(client, 505);
    }
    extra.close_connection_ = close_connection(req.get());
    extra.version_ = req->proto_version();
    // Stop reading in_ buffer until request is handled.
    client->expect_in_ = 0;
    return TransportBase::client_request(
        client, 0, std::make_unique<WrapRequest>(std::move(req)));
  }

  // Accepts only requests whose protocol is "HTTP" with major version 1.
  static bool supported_request(HttpRequest const* request) {
    if (request->proto() != "HTTP") return false;
    return request->proto_version().major == 1;
  }

  // HTTP/1.1: keep the connection open unless the first Connection header is
  // exactly "close". Any other version: close after the response.
  // NOTE(review): RFC 7230 section 6.1 defines Connection options as
  // case-insensitive, comma-separated tokens. Values such as "Close" or
  // "keep-alive, close" are therefore not detected here. Consider a
  // token-aware comparison once first_header()'s return type is confirmed.
  static bool close_connection(HttpRequest const* request) {
    auto const version = request->proto_version();
    if (version.major == 1 && version.minor == 1) {
      return request->first_header("connection") == "close";
    }
    return true;
  }

  // Initializes a newly accepted client. Base bookkeeping runs first. Then
  // the HTTP state defaults to HTTP/1.0 with close-after-response until a
  // parsed request says otherwise.
  void client_new(Client* client) override {
    TransportBase::client_new(client);
    auto& extra = extra_[client->index_];
    extra.close_connection_ = true;
    extra.version_.major = 1;
    extra.version_.minor = 0;
  }

  // Queues the error response built by create_data(status_code, "") and
  // marks the connection for closing. Buffered input is discarded and
  // reading stops.
  bool client_fatal_response(Client* client, uint16_t status_code) {
    auto& extra = extra_[client->index_];
    extra.close_connection_ = true;
    client->expect_in_ = 0;
    client->in_->clear();
    return client_response(client, 0, create_data(status_code, ""));
  }

  // Writes the status line and headers for response `id` to client->out_,
  // then flushes.
  // - The status line uses extra.version_: the version of the last accepted
  //   request, or HTTP/1.0 by default.
  // - A response without a Content-Length header forces "Connection: close",
  //   so the end of its body is signalled by closing the connection
  //   (RFC 7230, section 3.3.3).
  // - If the output buffer cannot hold the headers, the client is aborted
  //   and false is returned.
  bool client_response_header(Client* client, uint32_t id) override {
    auto& extra = extra_[client->index_];
    assert(client->responses_.count(id));
    auto& cli_response = client->responses_[id];
    auto status_code = cli_response.response_->code();
    auto builder = HttpResponseBuilder::create(
        "HTTP", extra.version_, status_code,
        std::string(http_standard_message(status_code)));
    bool have_content_length = false;
    for (auto const& pair : cli_response.response_->headers()) {
      // Field names are case-insensitive, so "content-length" counts too.
      if (!have_content_length && ascii_iequals(pair.first, "Content-Length"))
        have_content_length = true;
      builder->add_header(pair.first, pair.second);
    }
    if (!have_content_length) extra.close_connection_ = true;
    if (extra.close_connection_) builder->add_header("Connection", "close");
    if (!builder->build(client->out_.get())) {
      logger_->warn("Output buffer full for client: %zu", client->index_);
      client_abort(client);
      return false;
    }
    return client_flush(client);
  }

  // Response footer hook. If the connection is marked for closing and the
  // output buffer is already empty, aborts the client and returns false.
  bool client_response_footer(Client* client, uint32_t) override {
    auto const& extra = extra_[client->index_];
    if (extra.close_connection_ && client->out_->empty()) {
      client_abort(client);
      return false;
    }
    return true;
  }

 private:
  // Adapts a parsed HttpRequest to the UrlRequest interface passed to
  // TransportBase::client_request(). Owns the wrapped request.
  class WrapRequest : public UrlRequest {
   public:
    explicit WrapRequest(std::unique_ptr<HttpRequest> req)
        : req_(std::move(req)) {}

    std::string_view method() const override { return req_->method(); }

    std::string_view url() const override { return req_->url(); }

    // Value of the first `name` header, or std::nullopt if there is none.
    std::optional header_one(std::string_view name) const override {
      auto it = req_->header(name);
      if (it->valid()) return it->value();
      return std::nullopt;
    }

    // Values of every `name` header, split on ',' and trimmed. Empty list
    // elements are skipped; RFC 7230 section 7 requires recipients to ignore
    // them.
    std::vector<std::string> header_all(
        std::string_view name) const override {
      std::vector<std::string> ret;
      for (auto it = req_->header(name); it->valid(); it->next()) {
        // Loop variable is not named `str`, to avoid shadowing namespace str.
        for (auto part : str::split(it->value(), ',')) {
          std::string item(str::trim(part));
          if (!item.empty()) ret.push_back(std::move(item));
        }
      }
      return ret;
    }

   private:
    std::unique_ptr<HttpRequest> req_;
  };

  // Per-client HTTP state, indexed by Client::index_.
  struct Extra {
    // Close the connection once the current response has been sent.
    // Defaults to true, the same conservative default client_new() sets.
    bool close_connection_ = true;
    // Protocol version used in the response status line.
    Version version_;
  };

  std::vector<Extra> extra_;
};

// Creates HttpTransport instances. Returns nullptr when setup() rejects the
// configuration.
class HttpFactory : public Transport::Factory {
 public:
  std::unique_ptr create(std::shared_ptr logger, std::shared_ptr looper,
                         std::shared_ptr runner, Logger* config_logger,
                         Config const* config,
                         Transport::Handler* handler) override {
    auto transport =
        std::make_unique<HttpTransport>(logger, looper, runner, handler);
    if (transport->setup(config_logger, config)) return transport;
    return nullptr;
  }
};

}  // namespace

std::unique_ptr create_transport_factory_http() {
  return std::make_unique<HttpFactory>();
}