diff options
Diffstat (limited to 'src/http.cc')
| -rw-r--r-- | src/http.cc | 295 |
1 files changed, 266 insertions, 29 deletions
diff --git a/src/http.cc b/src/http.cc index 8a0a1e6..21de28d 100644 --- a/src/http.cc +++ b/src/http.cc @@ -20,10 +20,12 @@ #include <memory> #include <netdb.h> #include <optional> +#include <span> #include <string> #include <string_view> #include <sys/socket.h> #include <sys/types.h> +#include <system_error> #include <unistd.h> #include <utility> #include <vector> @@ -54,6 +56,25 @@ std::string_view ascii_lowercase(std::string_view input, std::string& tmp) { return input; } +bool eq_lowercase(std::string_view input, std::string_view cmp) { + if (input.size() != cmp.size()) + return false; + + auto a = input.begin(); + auto b = cmp.begin(); + for (; a != input.end(); ++a, ++b) { + if (*a >= 'A' && *a <= 'Z') { + // NOLINTNEXTLINE(bugprone-narrowing-conversions) + if ((*a | 0x20) != *b) + return false; + } else { + if (*a != *b) + return false; + } + } + return true; +} + std::string standard_header(std::string_view input) { bool capitalize = true; std::string ret; @@ -215,7 +236,8 @@ class OpenPortImpl : public OpenPort { class ResponseImpl : public Response { public: - explicit ResponseImpl(std::string data) : data_(std::move(data)) {} + ResponseImpl(std::string data, SocketReceiver* socket_receiver) + : data_(std::move(data)), socket_receiver_(socket_receiver) {} bool write(Buffer& buffer) override { if (offset_ >= data_.size()) @@ -235,9 +257,15 @@ class ResponseImpl : public Response { return offset_ < data_.size(); } + [[nodiscard]] + SocketReceiver* socket_receiver() const override { + return socket_receiver_; + } + private: std::string const data_; size_t offset_{0}; + SocketReceiver* socket_receiver_; }; class ResponseBuilderImpl : public Response::Builder { @@ -256,6 +284,11 @@ class ResponseBuilderImpl : public Response::Builder { return *this; } + Builder& take_over(SocketReceiver& receiver) override { + socket_receiver_ = &receiver; + return *this; + } + [[nodiscard]] std::unique_ptr<Response> build() const override { std::string ret; @@ -269,12 +302,18 @@ class ResponseBuilderImpl : public Response::Builder { } ret.append(" "); switch (status_) { + case StatusCode::kSwitchingProtocols: + ret.append("Switching Protocols"); + break; case StatusCode::kOK: ret.append("OK"); break; case StatusCode::kNoContent: ret.append("No Content"); break; + case StatusCode::kBadRequest: + ret.append("Bad Request"); + break; case StatusCode::kNotFound: ret.append("Not Found"); break; @@ -292,7 +331,20 @@ class ResponseBuilderImpl : public Response::Builder { ret.append(pair.second); ret.append("\r\n"); } - if (!have_content_len && status_ != StatusCode::kNoContent) { + bool need_content_len; + switch (status_) { + case StatusCode::kSwitchingProtocols: + case StatusCode::kNoContent: + need_content_len = false; + break; + case StatusCode::kOK: + case StatusCode::kBadRequest: + case StatusCode::kNotFound: + case StatusCode::kMethodNotAllowed: + need_content_len = true; + break; + } + if (!have_content_len && need_content_len) { char tmp[20]; auto [ptr, ec] = std::to_chars(tmp, tmp + sizeof(tmp), content_.size()); ret.append("Content-Length"); @@ -303,19 +355,93 @@ class ResponseBuilderImpl : public Response::Builder { ret.append("\r\n"); if (status_ != StatusCode::kNoContent) ret.append(content_); - return std::make_unique<ResponseImpl>(std::move(ret)); + return std::make_unique<ResponseImpl>(std::move(ret), socket_receiver_); } private: StatusCode const status_; std::string content_; std::vector<std::pair<std::string, std::string>> headers_; + SocketReceiver* socket_receiver_{nullptr}; +}; + +class HeaderIterator { + public: + explicit HeaderIterator(std::span<std::string_view> headers) + : headers_(headers) { + next(); + } + + [[nodiscard]] + bool has_name() const { + return offset_ < headers_.size(); + } + + [[nodiscard]] + bool has_value() const { + return has_name() && continue_value_; + } + + bool next_name() { + do { + next(); + } while (has_value()); + return has_name(); + } + + bool next_value() { + next(); + return has_value(); + } + + [[nodiscard]] + std::string_view name() const { + return name_; + } + + [[nodiscard]] + std::string_view value() const { + return value_; + } + + private: + void next() { + if (offset_ >= headers_.size()) + return; + + auto line = headers_[offset_]; + if (offset_ > 0 && + (line.empty() || line.front() == ' ' || line.front() == '\t')) { + continue_value_ = true; + value_ = str::ltrim(line); + ++offset_; + return; + } + auto colon = line.find(':'); + if (colon == std::string::npos) { + assert(false); + ++offset_; + next(); + return; + } + continue_value_ = false; + name_ = str::trim(line.substr(0, colon)); + value_ = str::ltrim(line.substr(colon + 1)); + ++offset_; + } + + std::span<std::string_view> headers_; + size_t offset_{0}; + bool continue_value_{false}; + std::string_view name_; + std::string_view value_; }; class RequestImpl : public Request { public: - RequestImpl(std::string_view method, std::string_view path) - : method_(method), path_(path) {} + RequestImpl(std::string_view method, std::string_view path, + std::span<std::string_view> headers) + : method_(method), path_(path), headers_(headers) {} [[nodiscard]] std::string_view method() const override { @@ -327,9 +453,62 @@ class RequestImpl : public Request { return path_; } + [[nodiscard]] + std::string_view body() const override { + return body_; + } + + void set_body(std::string_view body) { body_ = body; } + + [[nodiscard]] + bool header_contains(std::string_view name, + std::string_view value) const override { + for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) { + if (eq_lowercase(it.name(), name)) { + std::string_view tmp = get_value(it); + auto values = str::split(tmp, ',', /* keep_empty */ true); + for (auto v : values) { + if (eq_lowercase(str::trim(v), value)) { + return true; + } + } + } + } + return false; + } + + [[nodiscard]] + std::optional<std::string_view> header_value( + std::string_view name) const override { + for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) { + if (eq_lowercase(it.name(), name)) { + // Do not look for more entries of the header, this method is used for header + // values that cannot be combined. + return get_value(it); + } + } + return std::nullopt; + } + private: + std::string_view get_value(HeaderIterator& it) const { + std::string_view ret = it.value(); + if (it.next_value()) { + tmp_.clear(); + tmp_.append(ret); + do { + tmp_.append(it.value()); + } while (it.next_value()); + return tmp_; + } + return ret; + } + std::string_view method_; std::string_view path_; + std::string_view body_; + std::span<std::string_view> headers_; + mutable std::string tmp_; }; // TODO: What is a good value? @@ -340,7 +519,7 @@ class ServerImpl : public Server { ServerImpl(logger::Logger& logger, cfg::Config const& cfg, looper::Looper& looper, std::unique_ptr<OpenPort> open_port, Server::Delegate& delegate) - : logger_(logger), + : logger_(logger::prefix(logger, "http")), cfg_(cfg), looper_(looper), delegate_(delegate), @@ -364,6 +543,7 @@ class ServerImpl : public Server { std::chrono::steady_clock::time_point last; uint32_t timeout{0}; bool read_closed_{false}; + SocketReceiver* take_over_{nullptr}; }; void start_to_listen() { @@ -372,7 +552,7 @@ class ServerImpl : public Server { if (listen(it->get(), static_cast<int>(cfg_.get_uint64("http.listen.backlog") .value_or(kListenBacklog)))) { - logger_.warn( + logger_->warn( std::format("Error listening to socket: {}", strerror(errno))); it = listens_.erase(it); } else { @@ -386,21 +566,21 @@ class ServerImpl : public Server { } if (listens_.empty()) { - logger_.err("No ports left to listen to, exit."); + logger_->err("No ports left to listen to, exit."); looper_.quit(); } } void accept_client(int fd, uint8_t event) { if (event & looper::EVENT_ERROR) { - logger_.warn("Listening port returned error, closing."); + logger_->warn("Listening port returned error, closing."); looper_.remove(fd); auto it = std::ranges::find_if( listens_, [&fd](auto& ufd) { return ufd.get() == fd; }); if (it != listens_.end()) { listens_.erase(it); if (listens_.empty()) { - logger_.err("No ports left to listen to, exit."); + logger_->err("No ports left to listen to, exit."); looper_.quit(); } } @@ -409,12 +589,12 @@ class ServerImpl : public Server { unique_fd client_fd(accept4(fd, nullptr, nullptr, SOCK_NONBLOCK)); if (!client_fd) { - logger_.info(std::format("Accept returned error: {}", strerror(errno))); + logger_->info(std::format("Accept returned error: {}", strerror(errno))); return; } if (active_clients_ == client_.size()) { - logger_.warn("Max number of clients already."); + logger_->warn("Max number of clients already."); return; } @@ -423,7 +603,7 @@ class ServerImpl : public Server { next_client_ = 0; } - logger_.dbg(std::format("New client: {}", next_client_)); + logger_->dbg(std::format("New client: {}", next_client_)); auto& client = client_[next_client_]; client.fd = std::move(client_fd); @@ -443,7 +623,7 @@ class ServerImpl : public Server { void client_event(size_t client_id, uint8_t event) { if (event & looper::EVENT_ERROR) { - logger_.info(std::format("Client socket error: {}", client_id)); + logger_->info(std::format("Client socket error: {}", client_id)); close_client(client_id); return; } @@ -453,7 +633,7 @@ class ServerImpl : public Server { size_t avail; auto* ptr = client.in->wptr(avail); if (avail == 0) { - logger_.info(std::format("Client too large request: {}", client_id)); + logger_->info(std::format("Client too large request: {}", client_id)); close_client(client_id); return; } @@ -467,8 +647,8 @@ class ServerImpl : public Server { } if (got < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { - logger_.info(std::format("Client read error: {}: {}", client_id, - strerror(errno))); + logger_->info(std::format("Client read error: {}: {}", client_id, + strerror(errno))); close_client(client_id); return; } @@ -488,10 +668,15 @@ class ServerImpl : public Server { if (avail == 0) { if (client.resp) { if (!client.resp->write(*client.out)) { + client.take_over_ = client.resp->socket_receiver(); client.resp.reset(); - if (!client_read(client_id)) - return; + if (client.take_over_) { + client.read_closed_ = true; + } else { + if (!client_read(client_id)) + return; + } } continue; } @@ -506,8 +691,8 @@ class ServerImpl : public Server { } if (got < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { - logger_.info(std::format("Client write error: {}: {}", client_id, - strerror(errno))); + logger_->info(std::format("Client write error: {}: {}", client_id, + strerror(errno))); close_client(client_id); return; } @@ -521,7 +706,13 @@ class ServerImpl : public Server { if (client.read_closed_ && client.in->empty() && client.out->empty() && !client.resp) { - close_client(client_id); + if (client.take_over_) { + looper_.remove(client.fd.get()); + client.take_over_->receive(std::move(client.fd)); + reset_client(client_id); + } else { + close_client(client_id); + } return; } @@ -550,19 +741,57 @@ class ServerImpl : public Server { { auto parts = str::split(lines[0], ' ', false); if (parts.size() != 3) { - logger_.info(std::format("Client invalid request: {}", client_id)); + logger_->info(std::format("Client invalid request: {}", client_id)); close_client(client_id); return false; } method = parts[0]; path = parts[1]; if (!parts[2].starts_with("HTTP/")) { - logger_.info(std::format("Client invalid request: {}", client_id)); + logger_->info(std::format("Client invalid request: {}", client_id)); close_client(client_id); return false; } } - RequestImpl request(method, path); + for (size_t i = 1; i < lines.size(); ++i) { + auto colon = lines[i].find(':'); + if (colon == std::string::npos && + // Only continue header value lines are allowed to not have a colon. + (i == 1 || lines[i].empty() || + (lines[i].front() != ' ' && lines[i].front() != '\t'))) { + logger_->info(std::format("Client invalid request: {}", client_id)); + close_client(client_id); + return false; + } + } + RequestImpl request(method, path, std::span{lines}.subspan(1)); + auto maybe_content_length = request.header_value("content-length"); + if (maybe_content_length.has_value()) { + uint64_t content_length; + auto [ptr, ec] = std::from_chars( + maybe_content_length->data(), + maybe_content_length->data() + maybe_content_length->size(), + content_length); + if (ec != std::errc() || + ptr != maybe_content_length->data() + maybe_content_length->size()) { + logger_->info(std::format( + "Client invalid request, bad content-length: {}", client_id)); + close_client(client_id); + return false; + } + if (avail - end < content_length) { + // Need more data. + if (client.in->full()) { + logger_->info(std::format("Client invalid request, too much data: {}", + client_id)); + close_client(client_id); + return false; + } + return true; + } + request.set_body(data.substr(end + 4, content_length)); + end += content_length; + } client.resp = delegate_.handle(request); client.in->consume(end + 4); return true; @@ -583,7 +812,7 @@ class ServerImpl : public Server { return; } - logger_.dbg(std::format("Client timeout: {}", client_id)); + logger_->dbg(std::format("Client timeout: {}", client_id)); client.timeout = 0; close_client(client_id); @@ -596,23 +825,31 @@ class ServerImpl : public Server { return; } - logger_.dbg(std::format("Drop client: {}", client_id)); + logger_->dbg(std::format("Drop client: {}", client_id)); looper_.remove(client.fd.get()); + client.fd.reset(); + + reset_client(client_id); + } + + void reset_client(size_t client_id) { + auto& client = client_[client_id]; + if (client.timeout) looper_.cancel(client.timeout); - client.fd.reset(); clear(*client.in); clear(*client.out); client.resp.reset(); client.read_closed_ = false; + client.take_over_ = nullptr; assert(active_clients_ > 0); --active_clients_; if (next_client_ == client_id + 1) next_client_ = client_id; } - logger::Logger& logger_; + std::unique_ptr<logger::Logger> logger_; cfg::Config const& cfg_; looper::Looper& looper_; Server::Delegate& delegate_; |
