diff options
Diffstat (limited to 'src/http.cc')
| -rw-r--r-- | src/http.cc | 724 |
1 files changed, 724 insertions, 0 deletions
diff --git a/src/http.cc b/src/http.cc new file mode 100644 index 0000000..8a0a1e6 --- /dev/null +++ b/src/http.cc @@ -0,0 +1,724 @@ +#include "http.hh" + +#include "buffer.hh" +#include "cfg.hh" +#include "logger.hh" +#include "looper.hh" +#include "str.hh" +#include "unique_fd.hh" + +#include <algorithm> +#include <cassert> +#include <cerrno> +#include <charconv> +#include <chrono> +#include <cstdint> +#include <cstring> +#include <fcntl.h> +#include <format> +#include <map> +#include <memory> +#include <netdb.h> +#include <optional> +#include <string> +#include <string_view> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> +#include <utility> +#include <vector> + +namespace http { + +namespace { + +std::string_view ascii_lowercase(std::string_view input, std::string& tmp) { + auto it = input.begin(); + for (; it != input.end(); ++it) { + if (*it >= 'A' && *it <= 'Z') { + tmp.resize(input.size()); + auto out = std::copy(input.begin(), it, tmp.begin()); + // NOLINTNEXTLINE(bugprone-narrowing-conversions) + *out++ = *it++ | 0x20; + for (; it != input.end(); ++it) { + if (*it >= 'A' && *it <= 'Z') { + // NOLINTNEXTLINE(bugprone-narrowing-conversions) + *out++ = *it | 0x20; + } else { + *out++ = *it; + } + } + return tmp; + } + } + return input; +} + +std::string standard_header(std::string_view input) { + bool capitalize = true; + std::string ret; + for (auto c : input) { + if (c >= 'A' && c <= 'Z') { + if (capitalize) { + ret.push_back(c); + capitalize = false; + } else { + // NOLINTNEXTLINE(bugprone-narrowing-conversions) + ret.push_back(c | 0x20); + } + } else if (c >= 'a' && c <= 'z') { + if (capitalize) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions) + ret.push_back(c & 0x20); + capitalize = false; + } else { + ret.push_back(c); + } + } else if (c == '-') { + ret.push_back(c); + capitalize = true; + } else { + ret.push_back(c); + } + } + return ret; +} + +void make_nonblock(int fd) { + int flags = fcntl(fd, F_GETFL); + if (flags & O_NONBLOCK) + return; + fcntl(fd, F_SETFL, flags | O_NONBLOCK); +} + +void clear(Buffer& buffer) { + if (buffer.empty()) + return; + while (true) { + size_t avail; + buffer.rptr(avail); + if (avail == 0) + break; + buffer.consume(avail); + } +} + +class MimeTypeImpl : public MimeType { + public: + MimeTypeImpl(std::string str, size_t slash, size_t subtype_len, + std::map<std::string_view, std::string_view> params = {}) + : str_(std::move(str)), + slash_(slash), + subtype_len_(subtype_len), + params_(std::move(params)) { + assert(slash_ < str_.size()); + assert(slash_ + 1 + subtype_len_ <= str_.size()); + } + + [[nodiscard]] + std::string_view type() const override { + return std::string_view(str_).substr(0, slash_); + } + + [[nodiscard]] + std::string_view subtype() const override { + return std::string_view(str_).substr(slash_ + 1, subtype_len_); + } + + [[nodiscard]] + std::optional<std::string_view> parameter( + std::string_view name) const override { + std::optional<std::string_view> ret; + std::string tmp; + auto it = params_.find(ascii_lowercase(name, tmp)); + if (it != params_.end()) + ret = it->second; + return ret; + } + + [[nodiscard]] + std::string const& string() const override { + return str_; + } + + private: + std::string const str_; + size_t const slash_; + size_t const subtype_len_; + std::map<std::string_view, std::string_view> const params_; +}; + +class MimeTypeBuilderImpl : public MimeType::Builder { + public: + MimeTypeBuilderImpl(std::string_view type, std::string_view subtype) + : slash_(type.size()), subtype_len_(subtype.size()) { + str_.reserve(type.size() + 1 + subtype.size()); + str_.append(type); + str_.push_back('/'); + str_.append(subtype); + } + + MimeType::Builder& parameter(std::string_view name, + std::string_view value) override { + // TODO: make value a quoted string if it contains ; + std::string tmp; + auto ret = params_.emplace(ascii_lowercase(name, tmp), value); + if (!ret.second) { + ret.first->second = value; + } + return *this; + } + + [[nodiscard]] + std::unique_ptr<MimeType> build() const override { + if (params_.empty()) { + return std::make_unique<MimeTypeImpl>(str_, slash_, subtype_len_); + } + + std::string tmp(str_); + std::map<std::pair<size_t, size_t>, size_t> params1; + for (auto const& pair : params_) { + tmp.append("; "); + params1.emplace(std::make_pair(tmp.size(), pair.first.size()), + pair.second.size()); + tmp.append(pair.first); + tmp.push_back('='); + tmp.append(pair.second); + } + std::map<std::string_view, std::string_view> params2; + std::string_view tmp_view(tmp); + for (auto const& pair : params1) { + params2.emplace(tmp_view.substr(pair.first.first, pair.first.second), + tmp_view.substr(pair.first.first + pair.first.second + 1, + pair.second)); + } + return std::make_unique<MimeTypeImpl>(std::move(tmp), slash_, subtype_len_, + std::move(params2)); + } + + private: + std::string str_; + size_t const slash_; + size_t const subtype_len_; + std::map<std::string, std::string> params_; +}; + +class OpenPortImpl : public OpenPort { + public: + explicit OpenPortImpl(std::vector<unique_fd> fd) : fd_(std::move(fd)) {} + + std::vector<unique_fd> release() { return std::move(fd_); } + + private: + std::vector<unique_fd> fd_; +}; + +class ResponseImpl : public Response { + public: + explicit ResponseImpl(std::string data) : data_(std::move(data)) {} + + bool write(Buffer& buffer) override { + if (offset_ >= data_.size()) + return false; + + size_t avail; + auto* ptr = buffer.wptr(avail); + if (avail == 0) + return true; + + avail = std::min(data_.size() - offset_, avail); + + std::copy_n(data_.data() + offset_, avail, reinterpret_cast<char*>(ptr)); + offset_ += avail; + buffer.commit(avail); + + return offset_ < data_.size(); + } + + private: + std::string const data_; + size_t offset_{0}; +}; + +class ResponseBuilderImpl : public Response::Builder { + public: + explicit ResponseBuilderImpl(StatusCode status) : status_(status) {} + + Response::Builder& content(std::string_view content) override { + content_ = content; + return *this; + } + + Response::Builder& add_header(std::string_view name, + std::string_view value) override { + // TODO: Make sure name or value doesn't contain invalid chars + headers_.emplace_back(standard_header(name), value); + return *this; + } + + [[nodiscard]] + std::unique_ptr<Response> build() const override { + std::string ret; + ret.reserve(100 + content_.size()); + ret.append("HTTP/1.1 "); + { + char tmp[4]; + auto [ptr, ec] = + std::to_chars(tmp, tmp + sizeof(tmp), std::to_underlying(status_)); + ret.append(tmp, ptr - tmp); + } + ret.append(" "); + switch (status_) { + case StatusCode::kOK: + ret.append("OK"); + break; + case StatusCode::kNoContent: + ret.append("No Content"); + break; + case StatusCode::kNotFound: + ret.append("Not Found"); + break; + case StatusCode::kMethodNotAllowed: + ret.append("Method Not Allowed"); + break; + } + ret.append("\r\n"); + bool have_content_len = false; + for (auto const& pair : headers_) { + if (!have_content_len && pair.first == "Content-Length") + have_content_len = true; + ret.append(pair.first); + ret.append(": "); + ret.append(pair.second); + ret.append("\r\n"); + } + if (!have_content_len && status_ != StatusCode::kNoContent) { + char tmp[20]; + auto [ptr, ec] = std::to_chars(tmp, tmp + sizeof(tmp), content_.size()); + ret.append("Content-Length"); + ret.append(": "); + ret.append(tmp, ptr - tmp); + ret.append("\r\n"); + } + ret.append("\r\n"); + if (status_ != StatusCode::kNoContent) + ret.append(content_); + return std::make_unique<ResponseImpl>(std::move(ret)); + } + + private: + StatusCode const status_; + std::string content_; + std::vector<std::pair<std::string, std::string>> headers_; +}; + +class RequestImpl : public Request { + public: + RequestImpl(std::string_view method, std::string_view path) + : method_(method), path_(path) {} + + [[nodiscard]] + std::string_view method() const override { + return method_; + } + + [[nodiscard]] + std::string_view path() const override { + return path_; + } + + private: + std::string_view method_; + std::string_view path_; +}; + +// TODO: What is a good value? +const int kListenBacklog = 10; + +class ServerImpl : public Server { + public: + ServerImpl(logger::Logger& logger, cfg::Config const& cfg, + looper::Looper& looper, std::unique_ptr<OpenPort> open_port, + Server::Delegate& delegate) + : logger_(logger), + cfg_(cfg), + looper_(looper), + delegate_(delegate), + listens_(static_cast<OpenPortImpl*>(open_port.get())->release()) { + client_timeout_ = std::chrono::duration<double>( + cfg_.get_uint64("http.client.timeout.seconds").value_or(60)); + client_.resize(cfg_.get_uint64("http.max.clients").value_or(100)); + start_to_listen(); + } + + private: + static const size_t kInputMinSize = 1024; + static const size_t kInputMaxSize = static_cast<size_t>(1) * 1024 * 1024; + static const size_t kOutputSize = static_cast<size_t>(512) * 1024; + + struct Client { + unique_fd fd; + std::unique_ptr<Buffer> in{Buffer::dynamic(kInputMinSize, kInputMaxSize)}; + std::unique_ptr<Buffer> out{Buffer::fixed(kOutputSize)}; + std::unique_ptr<Response> resp; + std::chrono::steady_clock::time_point last; + uint32_t timeout{0}; + bool read_closed_{false}; + }; + + void start_to_listen() { + auto it = listens_.begin(); + while (it != listens_.end()) { + if (listen(it->get(), + static_cast<int>(cfg_.get_uint64("http.listen.backlog") + .value_or(kListenBacklog)))) { + logger_.warn( + std::format("Error listening to socket: {}", strerror(errno))); + it = listens_.erase(it); + } else { + make_nonblock(it->get()); + + looper_.add( + it->get(), looper::EVENT_READ, + [this, fd = it->get()](auto events) { accept_client(fd, events); }); + ++it; + } + } + + if (listens_.empty()) { + logger_.err("No ports left to listen to, exit."); + looper_.quit(); + } + } + + void accept_client(int fd, uint8_t event) { + if (event & looper::EVENT_ERROR) { + logger_.warn("Listening port returned error, closing."); + looper_.remove(fd); + auto it = std::ranges::find_if( + listens_, [&fd](auto& ufd) { return ufd.get() == fd; }); + if (it != listens_.end()) { + listens_.erase(it); + if (listens_.empty()) { + logger_.err("No ports left to listen to, exit."); + looper_.quit(); + } + } + return; + } + + unique_fd client_fd(accept4(fd, nullptr, nullptr, SOCK_NONBLOCK)); + if (!client_fd) { + logger_.info(std::format("Accept returned error: {}", strerror(errno))); + return; + } + + if (active_clients_ == client_.size()) { + logger_.warn("Max number of clients already."); + return; + } + + for (; client_[next_client_].fd; next_client_++) { + if (next_client_ == client_.size()) + next_client_ = 0; + } + + logger_.dbg(std::format("New client: {}", next_client_)); + + auto& client = client_[next_client_]; + client.fd = std::move(client_fd); + client.last = std::chrono::steady_clock::now(); + + looper_.add( + client.fd.get(), looper::EVENT_READ, + [this, id = next_client_](auto events) { client_event(id, events); }); + client.timeout = looper_.schedule( + client_timeout_.count(), + [this, id = next_client_](auto handle) { client_timeout(id, handle); }); + + ++active_clients_; + ++next_client_; + // TODO: Stopping listening if active_clients_ == client_.size() + } + + void client_event(size_t client_id, uint8_t event) { + if (event & looper::EVENT_ERROR) { + logger_.info(std::format("Client socket error: {}", client_id)); + close_client(client_id); + return; + } + + auto& client = client_[client_id]; + if (event & looper::EVENT_READ) { + size_t avail; + auto* ptr = client.in->wptr(avail); + if (avail == 0) { + logger_.info(std::format("Client too large request: {}", client_id)); + close_client(client_id); + return; + } + + ssize_t got; + while (true) { + got = read(client.fd.get(), ptr, avail); + if (got < 0 && errno == EINTR) + continue; + break; + } + if (got < 0) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + logger_.info(std::format("Client read error: {}: {}", client_id, + strerror(errno))); + close_client(client_id); + return; + } + } else if (got == 0) { + client.read_closed_ = true; + } else { + client.in->commit(got); + + if (!client_read(client_id)) + return; + } + } + + while (true) { + size_t avail; + auto* ptr = client.out->rptr(avail); + if (avail == 0) { + if (client.resp) { + if (!client.resp->write(*client.out)) { + client.resp.reset(); + + if (!client_read(client_id)) + return; + } + continue; + } + break; + } + ssize_t got; + while (true) { + got = write(client.fd.get(), ptr, avail); + if (got < 0 && errno == EINTR) + continue; + break; + } + if (got < 0) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + logger_.info(std::format("Client write error: {}: {}", client_id, + strerror(errno))); + close_client(client_id); + return; + } + break; + } + client.out->consume(got); + + if (std::cmp_less(got, avail)) + break; + } + + if (client.read_closed_ && client.in->empty() && client.out->empty() && + !client.resp) { + close_client(client_id); + return; + } + + bool want_read = !client.read_closed_ && !client.resp; + bool want_write = !client.out->empty(); + + looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) | + (want_write ? looper::EVENT_WRITE : 0)); + } + + bool client_read(size_t client_id) { + auto& client = client_[client_id]; + size_t avail; + auto* ptr = client.in->rptr(avail); + if (avail == 0) + return true; + std::string_view data(reinterpret_cast<char const*>(ptr), avail); + // TODO: Cache last search to not have to restart all the time + auto end = data.find("\r\n\r\n"); + if (end == std::string_view::npos) + return true; + + auto lines = str::split(data.substr(0, end), "\r\n", /* keep_empty */ true); + std::string_view method; + std::string_view path; + { + auto parts = str::split(lines[0], ' ', false); + if (parts.size() != 3) { + logger_.info(std::format("Client invalid request: {}", client_id)); + close_client(client_id); + return false; + } + method = parts[0]; + path = parts[1]; + if (!parts[2].starts_with("HTTP/")) { + logger_.info(std::format("Client invalid request: {}", client_id)); + close_client(client_id); + return false; + } + } + RequestImpl request(method, path); + client.resp = delegate_.handle(request); + client.in->consume(end + 4); + return true; + } + + void client_timeout(size_t client_id, uint32_t id) { + auto now = std::chrono::steady_clock::now(); + auto& client = client_[client_id]; + assert(client.timeout == id); + if (now - client.last < + std::chrono::duration_cast<std::chrono::steady_clock::duration>( + client_timeout_)) { + // TODO: Reschedule for delay left, not the full one + client.timeout = looper_.schedule(client_timeout_.count(), + [this, client_id](auto handle) { + client_timeout(client_id, handle); + }); + return; + } + + logger_.dbg(std::format("Client timeout: {}", client_id)); + + client.timeout = 0; + close_client(client_id); + } + + void close_client(size_t client_id) { + auto& client = client_[client_id]; + if (!client.fd) { + assert(false); + return; + } + + logger_.dbg(std::format("Drop client: {}", client_id)); + + looper_.remove(client.fd.get()); + if (client.timeout) + looper_.cancel(client.timeout); + client.fd.reset(); + clear(*client.in); + clear(*client.out); + client.resp.reset(); + client.read_closed_ = false; + assert(active_clients_ > 0); + --active_clients_; + if (next_client_ == client_id + 1) + next_client_ = client_id; + } + + logger::Logger& logger_; + cfg::Config const& cfg_; + looper::Looper& looper_; + Server::Delegate& delegate_; + std::chrono::duration<double> client_timeout_; + std::vector<unique_fd> listens_; + std::vector<Client> client_; + size_t next_client_{0}; + size_t active_clients_{0}; +}; + +} // namespace + +std::unique_ptr<MimeType> MimeType::create(std::string_view type, + std::string_view subtype) { + std::string tmp; + tmp.reserve(type.size() + 1 + subtype.size()); + tmp.append(type); + tmp.push_back('/'); + tmp.append(subtype); + return std::make_unique<MimeTypeImpl>(std::move(tmp), type.size(), + subtype.size()); +} + +std::unique_ptr<MimeType::Builder> MimeType::Builder::create( + std::string_view type, std::string_view subtype) { + return std::make_unique<MimeTypeBuilderImpl>(type, subtype); +} + +std::unique_ptr<OpenPort> open_port(std::string_view host_port, + logger::Logger& logger) { + auto colon = host_port.find(':'); + std::string host; + std::string port; + if (colon == std::string_view::npos) { + host = "localhost"; + port = host_port; + } else { + host = host_port.substr(0, colon); + port = host_port.substr(colon + 1); + } + + if (host == "*") + host = ""; + + struct addrinfo hints = {}; + hints.ai_flags = + AI_PASSIVE // Use wildcard if host is nullptr + | + AI_ADDRCONFIG; // Only use IPv4 or IPv6 if there is at least one interface of that type + hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 + hints.ai_socktype = SOCK_STREAM; // TCP + struct addrinfo* res; + auto ret = getaddrinfo(host.empty() ? nullptr : host.c_str(), port.c_str(), + &hints, &res); + if (ret) { + logger.err( + std::format("Unable to bind to {}: {}", host_port, gai_strerror(ret))); + return nullptr; + } + + struct addrinfo* addr; + std::vector<unique_fd> fds; + for (addr = res; addr != nullptr; addr = addr->ai_next) { + unique_fd fd{socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol)}; + if (!fd) + continue; + if (bind(fd.get(), addr->ai_addr, addr->ai_addrlen) == 0) { + fds.push_back(std::move(fd)); + } + } + freeaddrinfo(res); + + if (fds.empty()) { + // Assume that the last errno will be bind or socket failing + logger.err( + std::format("Unable to bind to {}: {}", host_port, strerror(errno))); + return nullptr; + } + + return std::make_unique<OpenPortImpl>(std::move(fds)); +} + +std::unique_ptr<Response::Builder> Response::Builder::create( + StatusCode status_code) { + return std::make_unique<ResponseBuilderImpl>(status_code); +} + +std::unique_ptr<Response> Response::status(StatusCode status_code) { + return Response::Builder::create(status_code)->build(); +} + +std::unique_ptr<Response> Response::content(std::string_view content, + MimeType const& mime_type) { + return Response::Builder::create(StatusCode::kOK) + ->content(content) + .add_header("Content-Type", mime_type.string()) + .build(); +} + +std::unique_ptr<Server> create_server(logger::Logger& logger, + cfg::Config const& cfg, + looper::Looper& looper, + std::unique_ptr<OpenPort> open_port, + Server::Delegate& delegate) { + return std::make_unique<ServerImpl>(logger, cfg, looper, std::move(open_port), + delegate); +} + +} // namespace http |
