#include "http.hh"

#include "buffer.hh"
#include "cfg.hh"
#include "logger.hh"
#include "looper.hh"
#include "str.hh"
#include "unique_fd.hh"

#include <algorithm>
#include <cassert>
#include <cerrno>
#include <charconv>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <fcntl.h>
#include <format>
#include <functional>
#include <map>
#include <memory>
#include <netdb.h>
#include <optional>
#include <span>
#include <string>
#include <string_view>
#include <sys/socket.h>
#include <sys/types.h>
#include <system_error>
#include <unistd.h>
#include <utility>
#include <vector>

namespace http {

namespace {

// Returns `input` with the ASCII letters A-Z lowercased.
//
// If `input` contains no uppercase ASCII letter it is returned unchanged and
// `tmp` is not touched. Otherwise the lowercased copy is written to `tmp` and
// the returned view points into `tmp`, so it is only valid while `tmp` is
// alive and unmodified.
std::string_view ascii_lowercase(std::string_view input, std::string& tmp) {
  auto it = input.begin();
  for (; it != input.end(); ++it) {
    if (*it >= 'A' && *it <= 'Z') {
      tmp.resize(input.size());
      auto out = std::copy(input.begin(), it, tmp.begin());
      // NOLINTNEXTLINE(bugprone-narrowing-conversions)
      *out++ = *it++ | 0x20;
      for (; it != input.end(); ++it) {
        if (*it >= 'A' && *it <= 'Z') {
          // NOLINTNEXTLINE(bugprone-narrowing-conversions)
          *out++ = *it | 0x20;
        } else {
          *out++ = *it;
        }
      }
      return tmp;
    }
  }
  return input;
}

// Returns true if `input`, with ASCII A-Z lowercased, equals `cmp`.
// `cmp` is compared byte-for-byte, so it is expected to already be lowercase.
bool eq_lowercase(std::string_view input, std::string_view cmp) {
  if (input.size() != cmp.size()) return false;
  auto a = input.begin();
  auto b = cmp.begin();
  for (; a != input.end(); ++a, ++b) {
    if (*a >= 'A' && *a <= 'Z') {
      // NOLINTNEXTLINE(bugprone-narrowing-conversions)
      if ((*a | 0x20) != *b) return false;
    } else {
      if (*a != *b) return false;
    }
  }
  return true;
}

// Normalizes a header name to the conventional "Word-Word" spelling: the first
// ASCII letter of the name and the first letter after every '-' are
// uppercased, all other ASCII letters are lowercased. Non-letters are copied
// unchanged.
std::string standard_header(std::string_view input) {
  bool capitalize = true;
  std::string ret;
  ret.reserve(input.size());
  for (auto c : input) {
    if (c >= 'A' && c <= 'Z') {
      if (capitalize) {
        ret.push_back(c);
        capitalize = false;
      } else {
        // NOLINTNEXTLINE(bugprone-narrowing-conversions)
        ret.push_back(c | 0x20);
      }
    } else if (c >= 'a' && c <= 'z') {
      if (capitalize) {
        // Clear the ASCII case bit to uppercase. (Masking with 0x20 instead
        // would keep only the case bit and yield a space.)
        ret.push_back(static_cast<char>(c & ~0x20));
        capitalize = false;
      } else {
        ret.push_back(c);
      }
    } else {
      ret.push_back(c);
      if (c == '-') capitalize = true;
    }
  }
  return ret;
}

// Sets O_NONBLOCK on `fd` unless it is already set. fcntl failures are
// ignored.
void make_nonblock(int fd) {
  int flags = fcntl(fd, F_GETFL);
  if (flags < 0 || (flags & O_NONBLOCK)) return;
  fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}

// Discards all readable data in `buffer`.
void clear(Buffer& buffer) {
  if (buffer.empty()) return;
  while (true) {
    size_t avail;
    buffer.rptr(avail);
    if (avail == 0) break;
    buffer.consume(avail);
  }
}

class MimeTypeImpl : public MimeType {
 public:
  // Lowercased parameter name -> value. The strings are owned by the map so
  // the views returned by parameter() do not depend on where `str_` keeps its
  // characters (a moved short std::string may relocate them).
  using Params = std::map<std::string, std::string, std::less<>>;

  MimeTypeImpl(std::string str, size_t slash, size_t subtype_len,
               Params params = {})
      : str_(std::move(str)),
        slash_(slash),
        subtype_len_(subtype_len),
        params_(std::move(params)) {
    assert(slash_ < str_.size());
    assert(slash_ + 1 + subtype_len_ <= str_.size());
  }

  [[nodiscard]] std::string_view type() const override {
    return std::string_view(str_).substr(0, slash_);
  }

  [[nodiscard]] std::string_view subtype() const override {
    return std::string_view(str_).substr(slash_ + 1, subtype_len_);
  }

  // Looks up a parameter by name, ignoring ASCII case of `name`.
  [[nodiscard]] std::optional<std::string_view> parameter(
      std::string_view name) const override {
    std::string tmp;
    auto it = params_.find(ascii_lowercase(name, tmp));
    if (it == params_.end()) return std::nullopt;
    return it->second;
  }

  [[nodiscard]] std::string const& string() const override { return str_; }

 private:
  std::string const str_;
  size_t const slash_;
  size_t const subtype_len_;
  Params const params_;
};

class MimeTypeBuilderImpl : public MimeType::Builder {
 public:
  MimeTypeBuilderImpl(std::string_view type, std::string_view subtype)
      : slash_(type.size()), subtype_len_(subtype.size()) {
    str_.reserve(type.size() + 1 + subtype.size());
    str_.append(type);
    str_.push_back('/');
    str_.append(subtype);
  }

  // Sets (or replaces) a parameter. The name is stored lowercased.
  MimeType::Builder& parameter(std::string_view name,
                               std::string_view value) override {
    // TODO: make value a quoted string if it contains ;
    std::string tmp;
    params_.insert_or_assign(std::string(ascii_lowercase(name, tmp)),
                             std::string(value));
    return *this;
  }

  // Produces "type/subtype" followed by "; name=value" for every parameter,
  // in name order.
  [[nodiscard]] std::unique_ptr<MimeType> build() const override {
    if (params_.empty()) {
      return std::make_unique<MimeTypeImpl>(str_, slash_, subtype_len_);
    }
    std::string str(str_);
    for (auto const& [name, value] : params_) {
      str.append("; ");
      str.append(name);
      str.push_back('=');
      str.append(value);
    }
    return std::make_unique<MimeTypeImpl>(std::move(str), slash_, subtype_len_,
                                          MimeTypeImpl::Params(params_));
  }

 private:
  std::string str_;
  size_t const slash_;
  size_t const subtype_len_;
  MimeTypeImpl::Params params_;
};

class OpenPortImpl : public OpenPort {
 public:
  explicit OpenPortImpl(std::vector<unique_fd> fd) : fd_(std::move(fd)) {}

  // Hands the bound sockets over to the caller.
  std::vector<unique_fd> release() { return std::move(fd_); }

 private:
  std::vector<unique_fd> fd_;
};

// A fully serialized response (status line, headers, body) that write()
// copies into an output buffer piece by piece.
class ResponseImpl : public Response {
 public:
  ResponseImpl(std::string data, SocketReceiver* socket_receiver)
      : data_(std::move(data)), socket_receiver_(socket_receiver) {}

  // Copies as much of the remaining data as fits into `buffer`.
  // Returns true while data is left to write.
  bool write(Buffer& buffer) override {
    if (offset_ >= data_.size()) return false;
    size_t avail;
    auto* ptr = buffer.wptr(avail);
    if (avail == 0) return true;
    avail = std::min(data_.size() - offset_, avail);
    std::copy_n(data_.data() + offset_, avail, reinterpret_cast<char*>(ptr));
    offset_ += avail;
    buffer.commit(avail);
    return offset_ < data_.size();
  }

  [[nodiscard]] SocketReceiver* socket_receiver() const override {
    return socket_receiver_;
  }

 private:
  std::string const data_;
  size_t offset_{0};
  SocketReceiver* socket_receiver_;
};

class ResponseBuilderImpl : public Response::Builder {
 public:
  explicit ResponseBuilderImpl(StatusCode status) : status_(status) {}

  Response::Builder& content(std::string_view content) override {
    content_ = content;
    return *this;
  }

  Response::Builder& add_header(std::string_view name,
                                std::string_view value) override {
    // TODO: Make sure name or value doesn't contain invalid chars
    headers_.emplace_back(standard_header(name), value);
    return *this;
  }

  Builder& take_over(SocketReceiver& receiver) override {
    socket_receiver_ = &receiver;
    return *this;
  }

  // Serializes status line, headers (adding Content-Length when the status
  // carries a body and no such header was given) and the content.
  [[nodiscard]] std::unique_ptr<Response> build() const override {
    std::string ret;
    ret.reserve(100 + content_.size());
    ret.append("HTTP/1.1 ");
    {
      char tmp[4];
      auto [ptr, ec] =
          std::to_chars(tmp, tmp + sizeof(tmp), std::to_underlying(status_));
      ret.append(tmp, ptr - tmp);
    }
    ret.append(" ");
    switch (status_) {
      case StatusCode::kSwitchingProtocols:
        ret.append("Switching Protocols");
        break;
      case StatusCode::kOK:
        ret.append("OK");
        break;
      case StatusCode::kNoContent:
        ret.append("No Content");
        break;
      case StatusCode::kBadRequest:
        ret.append("Bad Request");
        break;
      case StatusCode::kNotFound:
        ret.append("Not Found");
        break;
      case StatusCode::kMethodNotAllowed:
        ret.append("Method Not Allowed");
        break;
    }
    ret.append("\r\n");

    // Header names were normalized by standard_header(), so an exact compare
    // finds any spelling of Content-Length.
    bool have_content_len = false;
    for (auto const& [name, value] : headers_) {
      if (name == "Content-Length") have_content_len = true;
      ret.append(name);
      ret.append(": ");
      ret.append(value);
      ret.append("\r\n");
    }

    bool need_content_len = true;
    switch (status_) {
      case StatusCode::kSwitchingProtocols:
      case StatusCode::kNoContent:
        need_content_len = false;
        break;
      case StatusCode::kOK:
      case StatusCode::kBadRequest:
      case StatusCode::kNotFound:
      case StatusCode::kMethodNotAllowed:
        need_content_len = true;
        break;
    }
    if (!have_content_len && need_content_len) {
      char tmp[20];
      auto [ptr, ec] = std::to_chars(tmp, tmp + sizeof(tmp), content_.size());
      ret.append("Content-Length");
      ret.append(": ");
      ret.append(tmp, ptr - tmp);
      ret.append("\r\n");
    }
    ret.append("\r\n");
    if (status_ != StatusCode::kNoContent) ret.append(content_);
    return std::make_unique<ResponseImpl>(std::move(ret), socket_receiver_);
  }

 private:
  StatusCode const status_;
  std::string content_;
  std::vector<std::pair<std::string, std::string>> headers_;
  SocketReceiver* socket_receiver_{nullptr};
};

// Walks raw header lines (one view per line, CRLF already removed).
//
// A line other than the first that is empty or starts with SP/HTAB is a
// continuation (obs-fold) of the preceding header's value.
class HeaderIterator {
 public:
  explicit HeaderIterator(std::span<std::string_view const> headers)
      : headers_(headers) {
    next_name();
  }

  // True while positioned on a header.
  [[nodiscard]] bool has_name() const { return valid_; }

  // Moves to the next header, skipping any continuation lines of the current
  // one. Returns has_name().
  bool next_name() {
    valid_ = false;
    while (offset_ < headers_.size()) {
      size_t const index = offset_++;
      auto line = headers_[index];
      if (index > 0 && is_continuation(line)) continue;
      auto colon = line.find(':');
      if (colon == std::string_view::npos) {
        // The request parser rejects such lines before iterating.
        assert(false);
        continue;
      }
      name_ = str::trim(line.substr(0, colon));
      value_ = str::ltrim(line.substr(colon + 1));
      valid_ = true;
      break;
    }
    return valid_;
  }

  // If the line after the current value is a continuation line, makes it the
  // current value and returns true. Otherwise leaves the iterator where it is
  // (so next_name() still finds the following header) and returns false.
  bool next_value() {
    if (!valid_ || offset_ >= headers_.size() ||
        !is_continuation(headers_[offset_])) {
      return false;
    }
    value_ = str::ltrim(headers_[offset_++]);
    return true;
  }

  [[nodiscard]] std::string_view name() const { return name_; }
  [[nodiscard]] std::string_view value() const { return value_; }

 private:
  static bool is_continuation(std::string_view line) {
    return line.empty() || line.front() == ' ' || line.front() == '\t';
  }

  std::span<std::string_view const> headers_;
  size_t offset_{0};  // Index of the next unread line.
  bool valid_{false};
  std::string_view name_;
  std::string_view value_;
};

class RequestImpl : public Request {
 public:
  RequestImpl(std::string_view method, std::string_view path,
              std::span<std::string_view const> headers)
      : method_(method), path_(path), headers_(headers) {}

  [[nodiscard]] std::string_view method() const override { return method_; }
  [[nodiscard]] std::string_view path() const override { return path_; }
  [[nodiscard]] std::string_view body() const override { return body_; }

  void set_body(std::string_view body) { body_ = body; }

  // True if any header called `name` (ASCII case-insensitive) has `value` as
  // one of its comma separated elements. Both arguments must be lowercase.
  [[nodiscard]] bool header_contains(std::string_view name,
                                     std::string_view value) const override {
    for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) {
      if (!eq_lowercase(it.name(), name)) continue;
      auto values = str::split(get_value(it), ',', /* keep_empty */ true);
      for (auto v : values) {
        if (eq_lowercase(str::trim(v), value)) return true;
      }
    }
    return false;
  }

  // Value of the first header called `name` (ASCII case-insensitive, `name`
  // must be lowercase). The view may point into an internal buffer that is
  // reused by later calls.
  [[nodiscard]] std::optional<std::string_view> header_value(
      std::string_view name) const override {
    for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) {
      if (eq_lowercase(it.name(), name)) {
        // Do not look for more entries of the header, this method is used for
        // header values that cannot be combined.
        return get_value(it);
      }
    }
    return std::nullopt;
  }

 private:
  // Returns the complete value of the header `it` is on. Folded values are
  // joined with a single SP (RFC 7230 section 3.2.4) into tmp_, so the result
  // is only valid until the next call. Leaves `it` on the header's last line.
  std::string_view get_value(HeaderIterator& it) const {
    std::string_view first = it.value();
    if (!it.next_value()) return first;
    tmp_.assign(first);
    do {
      auto part = it.value();
      if (part.empty()) continue;
      if (!tmp_.empty()) tmp_.push_back(' ');
      tmp_.append(part);
    } while (it.next_value());
    return tmp_;
  }

  std::string_view method_;
  std::string_view path_;
  std::string_view body_;
  std::span<std::string_view const> headers_;
  mutable std::string tmp_;
};

// TODO: What is a good value?
constexpr int kListenBacklog = 10;

class ServerImpl : public Server {
 public:
  ServerImpl(logger::Logger& logger, cfg::Config const& cfg,
             looper::Looper& looper, std::unique_ptr<OpenPort> open_port,
             Server::Delegate& delegate)
      : logger_(logger::prefix(logger, "http")),
        cfg_(cfg),
        looper_(looper),
        delegate_(delegate),
        client_timeout_(
            cfg.get_uint64("http.client.timeout.seconds").value_or(60)),
        // A null open_port leaves nothing to listen to; start_to_listen()
        // then logs and quits the looper.
        listens_(open_port
                     ? static_cast<OpenPortImpl*>(open_port.get())->release()
                     : std::vector<unique_fd>()) {
    client_.resize(cfg_.get_uint64("http.max.clients").value_or(100));
    start_to_listen();
  }

 private:
  static constexpr size_t kInputMinSize = 1024;
  static constexpr size_t kInputMaxSize = size_t{1} * 1024 * 1024;
  static constexpr size_t kOutputSize = size_t{512} * 1024;

  struct Client {
    unique_fd fd;  // Empty when the slot is free.
    std::unique_ptr<Buffer> in{Buffer::dynamic(kInputMinSize, kInputMaxSize)};
    std::unique_ptr<Buffer> out{Buffer::fixed(kOutputSize)};
    // Response still being copied into `out`.
    std::unique_ptr<Response> resp;
    // Time of accept or of the last successful read/write.
    std::chrono::steady_clock::time_point last;
    // Handle from looper_.schedule(), 0 when no timeout is pending.
    uint32_t timeout{0};
    bool read_closed_{false};
    SocketReceiver* take_over_{nullptr};
  };

  // Starts listening on every bound socket, dropping those that fail.
  void start_to_listen() {
    auto const backlog = static_cast<int>(
        cfg_.get_uint64("http.listen.backlog").value_or(kListenBacklog));
    auto it = listens_.begin();
    while (it != listens_.end()) {
      if (listen(it->get(), backlog)) {
        logger_->warn(
            std::format("Error listening to socket: {}", strerror(errno)));
        it = listens_.erase(it);
      } else {
        make_nonblock(it->get());
        looper_.add(it->get(), looper::EVENT_READ,
                    [this, fd = it->get()](auto events) {
                      accept_client(fd, events);
                    });
        ++it;
      }
    }
    if (listens_.empty()) {
      logger_->err("No ports left to listen to, exit.");
      looper_.quit();
    }
  }

  void accept_client(int fd, uint8_t event) {
    if (event & looper::EVENT_ERROR) {
      logger_->warn("Listening port returned error, closing.");
      looper_.remove(fd);
      auto it = std::ranges::find_if(
          listens_, [fd](auto& ufd) { return ufd.get() == fd; });
      if (it != listens_.end()) {
        listens_.erase(it);
        if (listens_.empty()) {
          logger_->err("No ports left to listen to, exit.");
          looper_.quit();
        }
      }
      return;
    }

    unique_fd client_fd(accept4(fd, nullptr, nullptr, SOCK_NONBLOCK));
    if (!client_fd) {
      logger_->info(std::format("Accept returned error: {}", strerror(errno)));
      return;
    }

    if (active_clients_ == client_.size()) {
      logger_->warn("Max number of clients already.");
      return;
    }

    // Find a free slot starting at next_client_, wrapping around. The index
    // must be wrapped *before* it is used; a free slot exists because
    // active_clients_ < client_.size().
    if (next_client_ >= client_.size()) next_client_ = 0;
    while (client_[next_client_].fd) {
      if (++next_client_ == client_.size()) next_client_ = 0;
    }

    logger_->dbg(std::format("New client: {}", next_client_));

    auto& client = client_[next_client_];
    client.fd = std::move(client_fd);
    client.last = std::chrono::steady_clock::now();
    looper_.add(
        client.fd.get(), looper::EVENT_READ,
        [this, id = next_client_](auto events) { client_event(id, events); });
    client.timeout = looper_.schedule(
        client_timeout_.count(),
        [this, id = next_client_](auto handle) { client_timeout(id, handle); });
    ++active_clients_;
    ++next_client_;
    // TODO: Stopping listening if active_clients_ == client_.size()
  }

  void client_event(size_t client_id, uint8_t event) {
    if (event & looper::EVENT_ERROR) {
      logger_->info(std::format("Client socket error: {}", client_id));
      close_client(client_id);
      return;
    }

    auto& client = client_[client_id];
    if (event & looper::EVENT_READ) {
      size_t avail;
      auto* ptr = client.in->wptr(avail);
      if (avail == 0) {
        logger_->info(std::format("Client too large request: {}", client_id));
        close_client(client_id);
        return;
      }
      ssize_t got;
      do {
        got = read(client.fd.get(), ptr, avail);
      } while (got < 0 && errno == EINTR);
      if (got < 0) {
        if (errno != EWOULDBLOCK && errno != EAGAIN) {
          logger_->info(std::format("Client read error: {}: {}", client_id,
                                    strerror(errno)));
          close_client(client_id);
          return;
        }
      } else if (got == 0) {
        client.read_closed_ = true;
      } else {
        // Activity: keeps client_timeout() from dropping a live connection.
        client.last = std::chrono::steady_clock::now();
        client.in->commit(got);
        if (!client_read(client_id)) return;
      }
    }

    // Flush `out`, refilling it from the pending response as it drains.
    while (true) {
      size_t avail;
      auto* ptr = client.out->rptr(avail);
      if (avail == 0) {
        if (client.resp) {
          if (!client.resp->write(*client.out)) {
            client.take_over_ = client.resp->socket_receiver();
            client.resp.reset();
            if (client.take_over_) {
              client.read_closed_ = true;
            } else {
              // Response done, handle any request already buffered.
              if (!client_read(client_id)) return;
            }
          }
          continue;
        }
        break;
      }
      ssize_t got;
      do {
        got = write(client.fd.get(), ptr, avail);
      } while (got < 0 && errno == EINTR);
      if (got < 0) {
        if (errno != EWOULDBLOCK && errno != EAGAIN) {
          logger_->info(std::format("Client write error: {}: {}", client_id,
                                    strerror(errno)));
          close_client(client_id);
          return;
        }
        break;
      }
      if (got > 0) client.last = std::chrono::steady_clock::now();
      client.out->consume(got);
      if (std::cmp_less(got, avail)) break;
    }

    if (client.read_closed_ && client.in->empty() && client.out->empty() &&
        !client.resp) {
      if (client.take_over_) {
        looper_.remove(client.fd.get());
        client.take_over_->receive(std::move(client.fd));
        reset_client(client_id);
      } else {
        close_client(client_id);
      }
      return;
    }

    bool want_read = !client.read_closed_ && !client.resp;
    bool want_write = !client.out->empty();
    looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) |
                                        (want_write ? looper::EVENT_WRITE : 0));
  }

  // Parses and dispatches one complete request from the input buffer if one
  // is available. Returns false if the client was closed.
  bool client_read(size_t client_id) {
    auto& client = client_[client_id];
    size_t avail;
    auto* raw = client.in->rptr(avail);
    if (avail == 0) return true;
    std::string_view data(reinterpret_cast<char const*>(raw), avail);
    // TODO: Cache last search to not have to restart all the time
    auto const end = data.find("\r\n\r\n");
    if (end == std::string_view::npos) return true;
    // First byte after the blank line that terminates the header block.
    size_t const body_start = end + 4;
    // Bytes this request occupies in the input buffer (headers + body).
    size_t request_len = body_start;

    auto lines =
        str::split(data.substr(0, end), "\r\n", /* keep_empty */ true);
    if (lines.empty()) {
      logger_->info(std::format("Client invalid request: {}", client_id));
      close_client(client_id);
      return false;
    }

    std::string_view method;
    std::string_view path;
    {
      auto parts = str::split(lines[0], ' ', false);
      if (parts.size() != 3 || !parts[2].starts_with("HTTP/")) {
        logger_->info(std::format("Client invalid request: {}", client_id));
        close_client(client_id);
        return false;
      }
      method = parts[0];
      path = parts[1];
      size_t query = path.find('?');
      if (query != std::string_view::npos) path = path.substr(0, query);
    }

    for (size_t i = 1; i < lines.size(); ++i) {
      auto colon = lines[i].find(':');
      if (colon == std::string_view::npos &&
          // Only continue header value lines are allowed to not have a colon.
          (i == 1 || lines[i].empty() ||
           (lines[i].front() != ' ' && lines[i].front() != '\t'))) {
        logger_->info(std::format("Client invalid request: {}", client_id));
        close_client(client_id);
        return false;
      }
    }

    RequestImpl request(method, path, std::span{lines}.subspan(1));
    if (auto maybe_content_length = request.header_value("content-length")) {
      // The field value may carry surrounding optional whitespace.
      auto const text = str::trim(*maybe_content_length);
      uint64_t content_length;
      auto [parsed_end, ec] =
          std::from_chars(text.data(), text.data() + text.size(), content_length);
      if (ec != std::errc() || parsed_end != text.data() + text.size()) {
        logger_->info(std::format(
            "Client invalid request, bad content-length: {}", client_id));
        close_client(client_id);
        return false;
      }
      // The body starts after the "\r\n\r\n", not at it.
      if (avail - body_start < content_length) {
        // Need more data.
        if (client.in->full()) {
          logger_->info(std::format(
              "Client invalid request, too much data: {}", client_id));
          close_client(client_id);
          return false;
        }
        return true;
      }
      request.set_body(data.substr(body_start, content_length));
      request_len += content_length;
    }

    client.resp = delegate_.handle(request);
    // Inspect the request before consume(): its views point into client.in.
    bool const close_requested =
        request.header_contains("connection", "close");
    client.in->consume(request_len);

    // If client requests close, do so.
    if (close_requested && shutdown(client.fd.get(), SHUT_RD) == 0) {
      client.read_closed_ = true;
    }
    return true;
  }

  // Scheduled callback: closes the client unless it has been active within
  // the last client_timeout_, in which case the check is scheduled again.
  void client_timeout(size_t client_id, uint32_t id) {
    auto now = std::chrono::steady_clock::now();
    auto& client = client_[client_id];
    assert(client.timeout == id);
    if (now - client.last <
        std::chrono::duration_cast<std::chrono::steady_clock::duration>(
            client_timeout_)) {
      // TODO: Reschedule for delay left, not the full one
      client.timeout =
          looper_.schedule(client_timeout_.count(),
                           [this, client_id](auto handle) {
                             client_timeout(client_id, handle);
                           });
      return;
    }
    logger_->dbg(std::format("Client timeout: {}", client_id));
    client.timeout = 0;
    close_client(client_id);
  }

  void close_client(size_t client_id) {
    auto& client = client_[client_id];
    if (!client.fd) {
      assert(false);
      return;
    }
    logger_->dbg(std::format("Drop client: {}", client_id));
    looper_.remove(client.fd.get());
    client.fd.reset();
    reset_client(client_id);
  }

  // Returns the slot to its unused state (the fd must already be gone).
  void reset_client(size_t client_id) {
    auto& client = client_[client_id];
    if (client.timeout) {
      looper_.cancel(client.timeout);
      client.timeout = 0;
    }
    clear(*client.in);
    clear(*client.out);
    client.resp.reset();
    client.read_closed_ = false;
    client.take_over_ = nullptr;
    assert(active_clients_ > 0);
    --active_clients_;
    if (next_client_ == client_id + 1) next_client_ = client_id;
  }

  std::unique_ptr<logger::Logger> logger_;
  cfg::Config const& cfg_;
  looper::Looper& looper_;
  Server::Delegate& delegate_;
  std::chrono::duration<double> client_timeout_;
  std::vector<unique_fd> listens_;
  std::vector<Client> client_;
  size_t next_client_{0};
  size_t active_clients_{0};
};

}  // namespace

std::unique_ptr<MimeType> MimeType::create(std::string_view type,
                                           std::string_view subtype) {
  std::string tmp;
  tmp.reserve(type.size() + 1 + subtype.size());
  tmp.append(type);
  tmp.push_back('/');
  tmp.append(subtype);
  return std::make_unique<MimeTypeImpl>(std::move(tmp), type.size(),
                                        subtype.size());
}

std::unique_ptr<MimeType::Builder> MimeType::Builder::create(
    std::string_view type, std::string_view subtype) {
  return std::make_unique<MimeTypeBuilderImpl>(type, subtype);
}

// Resolves `host_port` ("port", "host:port" or "*:port" for all interfaces;
// a bare port binds "localhost") and binds a TCP socket for every address.
// Returns nullptr (after logging) if nothing could be bound.
std::unique_ptr<OpenPort> open_port(std::string_view host_port,
                                    logger::Logger& logger) {
  auto colon = host_port.find(':');
  std::string host;
  std::string port;
  if (colon == std::string_view::npos) {
    host = "localhost";
    port = host_port;
  } else {
    host = host_port.substr(0, colon);
    port = host_port.substr(colon + 1);
  }
  if (host == "*") host = "";

  struct addrinfo hints = {};
  hints.ai_flags =
      AI_PASSIVE  // Use wildcard if host is nullptr
      | AI_ADDRCONFIG;  // Only use IPv4 or IPv6 if there is at least one
                        // interface of that type
  hints.ai_family = AF_UNSPEC;  // Allow IPv4 or IPv6
  hints.ai_socktype = SOCK_STREAM;  // TCP
  struct addrinfo* res;
  auto ret = getaddrinfo(host.empty() ? nullptr : host.c_str(), port.c_str(),
                         &hints, &res);
  if (ret) {
    logger.err(
        std::format("Unable to bind to {}: {}", host_port, gai_strerror(ret)));
    return nullptr;
  }

  std::vector<unique_fd> fds;
  // errno of the most recent socket()/bind() failure, captured right away so
  // later calls (close, freeaddrinfo) cannot overwrite it.
  int last_error = 0;
  for (auto* addr = res; addr != nullptr; addr = addr->ai_next) {
    unique_fd fd{
        socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol)};
    if (!fd) {
      last_error = errno;
      continue;
    }
    if (bind(fd.get(), addr->ai_addr, addr->ai_addrlen) == 0) {
      fds.push_back(std::move(fd));
    } else {
      last_error = errno;
    }
  }
  freeaddrinfo(res);

  if (fds.empty()) {
    logger.err(std::format("Unable to bind to {}: {}", host_port,
                           strerror(last_error)));
    return nullptr;
  }
  return std::make_unique<OpenPortImpl>(std::move(fds));
}

std::unique_ptr<Response::Builder> Response::Builder::create(
    StatusCode status_code) {
  return std::make_unique<ResponseBuilderImpl>(status_code);
}

std::unique_ptr<Response> Response::status(StatusCode status_code) {
  return Response::Builder::create(status_code)->build();
}

std::unique_ptr<Response> Response::content(std::string_view content,
                                            MimeType const& mime_type) {
  return Response::Builder::create(StatusCode::kOK)
      ->content(content)
      .add_header("Content-Type", mime_type.string())
      .build();
}

std::unique_ptr<Server> create_server(logger::Logger& logger,
                                      cfg::Config const& cfg,
                                      looper::Looper& looper,
                                      std::unique_ptr<OpenPort> open_port,
                                      Server::Delegate& delegate) {
  return std::make_unique<ServerImpl>(logger, cfg, looper,
                                      std::move(open_port), delegate);
}

}  // namespace http