diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/http.cc | 295 | ||||
| -rw-r--r-- | src/http.hh | 28 | ||||
| -rw-r--r-- | src/logger.cc | 34 | ||||
| -rw-r--r-- | src/logger.hh | 3 | ||||
| -rw-r--r-- | src/main.cc | 127 | ||||
| -rw-r--r-- | src/websocket.cc | 676 | ||||
| -rw-r--r-- | src/websocket.hh | 88 |
7 files changed, 1213 insertions, 38 deletions
diff --git a/src/http.cc b/src/http.cc index 8a0a1e6..21de28d 100644 --- a/src/http.cc +++ b/src/http.cc @@ -20,10 +20,12 @@ #include <memory> #include <netdb.h> #include <optional> +#include <span> #include <string> #include <string_view> #include <sys/socket.h> #include <sys/types.h> +#include <system_error> #include <unistd.h> #include <utility> #include <vector> @@ -54,6 +56,25 @@ std::string_view ascii_lowercase(std::string_view input, std::string& tmp) { return input; } +bool eq_lowercase(std::string_view input, std::string_view cmp) { + if (input.size() != cmp.size()) + return false; + + auto a = input.begin(); + auto b = cmp.begin(); + for (; a != input.end(); ++a, ++b) { + if (*a >= 'A' && *a <= 'Z') { + // NOLINTNEXTLINE(bugprone-narrowing-conversions) + if ((*a | 0x20) != *b) + return false; + } else { + if (*a != *b) + return false; + } + } + return true; +} + std::string standard_header(std::string_view input) { bool capitalize = true; std::string ret; @@ -215,7 +236,8 @@ class OpenPortImpl : public OpenPort { class ResponseImpl : public Response { public: - explicit ResponseImpl(std::string data) : data_(std::move(data)) {} + ResponseImpl(std::string data, SocketReceiver* socket_receiver) + : data_(std::move(data)), socket_receiver_(socket_receiver) {} bool write(Buffer& buffer) override { if (offset_ >= data_.size()) @@ -235,9 +257,15 @@ class ResponseImpl : public Response { return offset_ < data_.size(); } + [[nodiscard]] + SocketReceiver* socket_receiver() const override { + return socket_receiver_; + } + private: std::string const data_; size_t offset_{0}; + SocketReceiver* socket_receiver_; }; class ResponseBuilderImpl : public Response::Builder { @@ -256,6 +284,11 @@ class ResponseBuilderImpl : public Response::Builder { return *this; } + Builder& take_over(SocketReceiver& receiver) override { + socket_receiver_ = &receiver; + return *this; + } + [[nodiscard]] std::unique_ptr<Response> build() const override { std::string ret; @@ -269,12 +302,18 @@ class ResponseBuilderImpl : public Response::Builder { } ret.append(" "); switch (status_) { + case StatusCode::kSwitchingProtocols: + ret.append("Switching Protocols"); + break; case StatusCode::kOK: ret.append("OK"); break; case StatusCode::kNoContent: ret.append("No Content"); break; + case StatusCode::kBadRequest: + ret.append("Bad Request"); + break; case StatusCode::kNotFound: ret.append("Not Found"); break; @@ -292,7 +331,20 @@ class ResponseBuilderImpl : public Response::Builder { ret.append(pair.second); ret.append("\r\n"); } - if (!have_content_len && status_ != StatusCode::kNoContent) { + bool need_content_len; + switch (status_) { + case StatusCode::kSwitchingProtocols: + case StatusCode::kNoContent: + need_content_len = false; + break; + case StatusCode::kOK: + case StatusCode::kBadRequest: + case StatusCode::kNotFound: + case StatusCode::kMethodNotAllowed: + need_content_len = true; + break; + } + if (!have_content_len && need_content_len) { char tmp[20]; auto [ptr, ec] = std::to_chars(tmp, tmp + sizeof(tmp), content_.size()); ret.append("Content-Length"); @@ -303,19 +355,93 @@ class ResponseBuilderImpl : public Response::Builder { ret.append("\r\n"); if (status_ != StatusCode::kNoContent) ret.append(content_); - return std::make_unique<ResponseImpl>(std::move(ret)); + return std::make_unique<ResponseImpl>(std::move(ret), socket_receiver_); } private: StatusCode const status_; std::string content_; std::vector<std::pair<std::string, std::string>> headers_; + SocketReceiver* socket_receiver_{nullptr}; +}; + +class HeaderIterator { + public: + explicit HeaderIterator(std::span<std::string_view> headers) + : headers_(headers) { + next(); + } + + [[nodiscard]] + bool has_name() const { + return offset_ < headers_.size(); + } + + [[nodiscard]] + bool has_value() const { + return has_name() && continue_value_; + } + + bool next_name() { + do { + next(); + } while (has_value()); + return has_name(); + } + + bool next_value() { + next(); + return has_value(); + } + + [[nodiscard]] + std::string_view name() const { + return name_; + } + + [[nodiscard]] + std::string_view value() const { + return value_; + } + + private: + void next() { + if (offset_ >= headers_.size()) + return; + + auto line = headers_[offset_]; + if (offset_ > 0 && + (line.empty() || line.front() == ' ' || line.front() == '\t')) { + continue_value_ = true; + value_ = str::ltrim(line); + ++offset_; + return; + } + auto colon = line.find(':'); + if (colon == std::string::npos) { + assert(false); + ++offset_; + next(); + return; + } + continue_value_ = false; + name_ = str::trim(line.substr(0, colon)); + value_ = str::ltrim(line.substr(colon + 1)); + ++offset_; + } + + std::span<std::string_view> headers_; + size_t offset_{0}; + bool continue_value_{false}; + std::string_view name_; + std::string_view value_; }; class RequestImpl : public Request { public: - RequestImpl(std::string_view method, std::string_view path) - : method_(method), path_(path) {} + RequestImpl(std::string_view method, std::string_view path, + std::span<std::string_view> headers) + : method_(method), path_(path), headers_(headers) {} [[nodiscard]] std::string_view method() const override { @@ -327,9 +453,62 @@ class RequestImpl : public Request { return path_; } + [[nodiscard]] + std::string_view body() const override { + return body_; + } + + void set_body(std::string_view body) { body_ = body; } + + [[nodiscard]] + bool header_contains(std::string_view name, + std::string_view value) const override { + for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) { + if (eq_lowercase(it.name(), name)) { + std::string_view tmp = get_value(it); + auto values = str::split(tmp, ',', /* keep_empty */ true); + for (auto v : values) { + if (eq_lowercase(str::trim(v), value)) { + return true; + } + } + } + } + return false; + } + + [[nodiscard]] + std::optional<std::string_view> header_value( + std::string_view name) const override { + for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) { + if (eq_lowercase(it.name(), name)) { + // Do not look for more entries of the header, this method is used for header + // values that cannot be combined. + return get_value(it); + } + } + return std::nullopt; + } + private: + std::string_view get_value(HeaderIterator& it) const { + std::string_view ret = it.value(); + if (it.next_value()) { + tmp_.clear(); + tmp_.append(ret); + do { + tmp_.append(it.value()); + } while (it.next_value()); + return tmp_; + } + return ret; + } + std::string_view method_; std::string_view path_; + std::string_view body_; + std::span<std::string_view> headers_; + mutable std::string tmp_; }; // TODO: What is a good value? @@ -340,7 +519,7 @@ class ServerImpl : public Server { ServerImpl(logger::Logger& logger, cfg::Config const& cfg, looper::Looper& looper, std::unique_ptr<OpenPort> open_port, Server::Delegate& delegate) - : logger_(logger), + : logger_(logger::prefix(logger, "http")), cfg_(cfg), looper_(looper), delegate_(delegate), @@ -364,6 +543,7 @@ class ServerImpl : public Server { std::chrono::steady_clock::time_point last; uint32_t timeout{0}; bool read_closed_{false}; + SocketReceiver* take_over_{nullptr}; }; void start_to_listen() { @@ -372,7 +552,7 @@ class ServerImpl : public Server { if (listen(it->get(), static_cast<int>(cfg_.get_uint64("http.listen.backlog") .value_or(kListenBacklog)))) { - logger_.warn( + logger_->warn( std::format("Error listening to socket: {}", strerror(errno))); it = listens_.erase(it); } else { @@ -386,21 +566,21 @@ class ServerImpl : public Server { } if (listens_.empty()) { - logger_.err("No ports left to listen to, exit."); + logger_->err("No ports left to listen to, exit."); looper_.quit(); } } void accept_client(int fd, uint8_t event) { if (event & looper::EVENT_ERROR) { - logger_.warn("Listening port returned error, closing."); + logger_->warn("Listening port returned error, closing."); looper_.remove(fd); auto it = std::ranges::find_if( listens_, [&fd](auto& ufd) { return ufd.get() == fd; }); if (it != listens_.end()) { listens_.erase(it); if (listens_.empty()) { - logger_.err("No ports left to listen to, exit."); + logger_->err("No ports left to listen to, exit."); looper_.quit(); } } @@ -409,12 +589,12 @@ class ServerImpl : public Server { unique_fd client_fd(accept4(fd, nullptr, nullptr, SOCK_NONBLOCK)); if (!client_fd) { - logger_.info(std::format("Accept returned error: {}", strerror(errno))); + logger_->info(std::format("Accept returned error: {}", strerror(errno))); return; } if (active_clients_ == client_.size()) { - logger_.warn("Max number of clients already."); + logger_->warn("Max number of clients already."); return; } @@ -423,7 +603,7 @@ class ServerImpl : public Server { next_client_ = 0; } - logger_.dbg(std::format("New client: {}", next_client_)); + logger_->dbg(std::format("New client: {}", next_client_)); auto& client = client_[next_client_]; client.fd = std::move(client_fd); @@ -443,7 +623,7 @@ class ServerImpl : public Server { void client_event(size_t client_id, uint8_t event) { if (event & looper::EVENT_ERROR) { - logger_.info(std::format("Client socket error: {}", client_id)); + logger_->info(std::format("Client socket error: {}", client_id)); close_client(client_id); return; } @@ -453,7 +633,7 @@ class ServerImpl : public Server { size_t avail; auto* ptr = client.in->wptr(avail); if (avail == 0) { - logger_.info(std::format("Client too large request: {}", client_id)); + logger_->info(std::format("Client too large request: {}", client_id)); close_client(client_id); return; } @@ -467,8 +647,8 @@ class ServerImpl : public Server { } if (got < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { - logger_.info(std::format("Client read error: {}: {}", client_id, - strerror(errno))); + logger_->info(std::format("Client read error: {}: {}", client_id, + strerror(errno))); close_client(client_id); return; } @@ -488,10 +668,15 @@ class ServerImpl : public Server { if (avail == 0) { if (client.resp) { if (!client.resp->write(*client.out)) { + client.take_over_ = client.resp->socket_receiver(); client.resp.reset(); - if (!client_read(client_id)) - return; + if (client.take_over_) { + client.read_closed_ = true; + } else { + if (!client_read(client_id)) + return; + } } continue; } @@ -506,8 +691,8 @@ class ServerImpl : public Server { } if (got < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { - logger_.info(std::format("Client write error: {}: {}", client_id, - strerror(errno))); + logger_->info(std::format("Client write error: {}: {}", client_id, + strerror(errno))); close_client(client_id); return; } @@ -521,7 +706,13 @@ class ServerImpl : public Server { if (client.read_closed_ && client.in->empty() && client.out->empty() && !client.resp) { - close_client(client_id); + if (client.take_over_) { + looper_.remove(client.fd.get()); + client.take_over_->receive(std::move(client.fd)); + reset_client(client_id); + } else { + close_client(client_id); + } return; } @@ -550,19 +741,57 @@ class ServerImpl : public Server { { auto parts = str::split(lines[0], ' ', false); if (parts.size() != 3) { - logger_.info(std::format("Client invalid request: {}", client_id)); + logger_->info(std::format("Client invalid request: {}", client_id)); close_client(client_id); return false; } method = parts[0]; path = parts[1]; if (!parts[2].starts_with("HTTP/")) { - logger_.info(std::format("Client invalid request: {}", client_id)); + logger_->info(std::format("Client invalid request: {}", client_id)); close_client(client_id); return false; } } - RequestImpl request(method, path); + for (size_t i = 1; i < lines.size(); ++i) { + auto colon = lines[i].find(':'); + if (colon == std::string::npos && + // Only continue header value lines are allowed to not have a colon. + (i == 1 || lines[i].empty() || + (lines[i].front() != ' ' && lines[i].front() != '\t'))) { + logger_->info(std::format("Client invalid request: {}", client_id)); + close_client(client_id); + return false; + } + } + RequestImpl request(method, path, std::span{lines}.subspan(1)); + auto maybe_content_length = request.header_value("content-length"); + if (maybe_content_length.has_value()) { + uint64_t content_length; + auto [ptr, ec] = std::from_chars( + maybe_content_length->data(), + maybe_content_length->data() + maybe_content_length->size(), + content_length); + if (ec != std::errc() || + ptr != maybe_content_length->data() + maybe_content_length->size()) { + logger_->info(std::format( + "Client invalid request, bad content-length: {}", client_id)); + close_client(client_id); + return false; + } + if (avail - end < content_length) { + // Need more data. + if (client.in->full()) { + logger_->info(std::format("Client invalid request, too much data: {}", + client_id)); + close_client(client_id); + return false; + } + return true; + } + request.set_body(data.substr(end + 4, content_length)); + end += content_length; + } client.resp = delegate_.handle(request); client.in->consume(end + 4); return true; @@ -583,7 +812,7 @@ class ServerImpl : public Server { return; } - logger_.dbg(std::format("Client timeout: {}", client_id)); + logger_->dbg(std::format("Client timeout: {}", client_id)); client.timeout = 0; close_client(client_id); @@ -596,23 +825,31 @@ class ServerImpl : public Server { return; } - logger_.dbg(std::format("Drop client: {}", client_id)); + logger_->dbg(std::format("Drop client: {}", client_id)); looper_.remove(client.fd.get()); + client.fd.reset(); + + reset_client(client_id); + } + + void reset_client(size_t client_id) { + auto& client = client_[client_id]; + if (client.timeout) looper_.cancel(client.timeout); - client.fd.reset(); clear(*client.in); clear(*client.out); client.resp.reset(); client.read_closed_ = false; + client.take_over_ = nullptr; assert(active_clients_ > 0); --active_clients_; if (next_client_ == client_id + 1) next_client_ = client_id; } - logger::Logger& logger_; + std::unique_ptr<logger::Logger> logger_; cfg::Config const& cfg_; looper::Looper& looper_; Server::Delegate& delegate_; diff --git a/src/http.hh b/src/http.hh index ca3f7d4..d85420e 100644 --- a/src/http.hh +++ b/src/http.hh @@ -1,6 +1,8 @@ #ifndef HTTP_HH #define HTTP_HH +#include "unique_fd.hh" + #include <cstdint> #include <memory> #include <optional> @@ -23,8 +25,10 @@ class Config; namespace http { enum class StatusCode : uint16_t { + kSwitchingProtocols = 101, kOK = 200, kNoContent = 204, + kBadRequest = 400, kNotFound = 404, kMethodNotAllowed = 405, }; @@ -96,6 +100,16 @@ class Request { virtual std::string_view method() const = 0; [[nodiscard]] virtual std::string_view path() const = 0; + [[nodiscard]] + virtual std::string_view body() const = 0; + + [[nodiscard]] + virtual bool header_contains(std::string_view name, + std::string_view value) const = 0; + + [[nodiscard]] + virtual std::optional<std::string_view> header_value( + std::string_view name) const = 0; protected: Request() = default; @@ -103,6 +117,16 @@ class Request { Request& operator=(Request const&) = delete; }; +class SocketReceiver { + public: + virtual ~SocketReceiver() = default; + + virtual void receive(unique_fd&& fd) = 0; + + protected: + SocketReceiver() = default; +}; + class Response { public: virtual ~Response() = default; @@ -115,6 +139,8 @@ class Response { virtual Builder& add_header(std::string_view name, std::string_view value) = 0; + virtual Builder& take_over(SocketReceiver& receiver) = 0; + [[nodiscard]] virtual std::unique_ptr<Response> build() const = 0; @@ -140,6 +166,8 @@ class Response { // Returns true while there is more data to write. virtual bool write(Buffer& buffer) = 0; + virtual SocketReceiver* socket_receiver() const = 0; + protected: Response() = default; Response(Response const&) = delete; diff --git a/src/logger.cc b/src/logger.cc index 21effff..104797f 100644 --- a/src/logger.cc +++ b/src/logger.cc @@ -1,6 +1,7 @@ #include "logger.hh" #include <cstdint> +#include <format> #include <iostream> #include <memory> #include <string> @@ -125,6 +126,34 @@ class StderrLogger : public BaseLogger { bool const verbose_; }; +class PrefixLogger : public Logger { + public: + PrefixLogger(Logger& logger, std::string prefix) + : logger_(logger), prefix_(std::move(prefix)) {} + + void err(std::string_view message) override { + logger_.err(std::format("{}: {}", prefix_, message)); + } + + void warn(std::string_view message) override { + logger_.warn(std::format("{}: {}", prefix_, message)); + } + + void info(std::string_view message) override { + logger_.info(std::format("{}: {}", prefix_, message)); + } + +#if !defined(NDEBUG) + void dbg(std::string_view message) override { + logger_.dbg(std::format("{}: {}", prefix_, message)); + } +#endif + + private: + Logger& logger_; + std::string prefix_; +}; + } // namespace [[nodiscard]] @@ -142,4 +171,9 @@ std::unique_ptr<Logger> stderr(bool verbose) { return std::make_unique<StderrLogger>(verbose); } +[[nodiscard]] +std::unique_ptr<Logger> prefix(Logger& logger, std::string prefix) { + return std::make_unique<PrefixLogger>(logger, std::move(prefix)); +} + } // namespace logger diff --git a/src/logger.hh b/src/logger.hh index 5c1e599..88edcf3 100644 --- a/src/logger.hh +++ b/src/logger.hh @@ -36,6 +36,9 @@ std::unique_ptr<Logger> syslog(std::string ident, bool verbose = false); [[nodiscard]] std::unique_ptr<Logger> stderr(bool verbose = false); +[[nodiscard]] +std::unique_ptr<Logger> prefix(Logger& logger, std::string prefix); + } // namespace logger #endif // LOGGER_HH diff --git a/src/main.cc b/src/main.cc index 6883f30..494c49c 100644 --- a/src/main.cc +++ b/src/main.cc @@ -3,9 +3,11 @@ #include "cfg.hh" #include "config.h" #include "http.hh" +#include "json.hh" #include "logger.hh" #include "looper.hh" #include "signals.hh" +#include "websocket.hh" #include <cerrno> #include <cstdint> @@ -16,6 +18,7 @@ #include <memory> #include <optional> #include <string> +#include <string_view> #include <unistd.h> #include <utility> #include <vector> @@ -26,9 +29,38 @@ namespace { +class Api { + public: + virtual ~Api() = default; + + [[nodiscard]] + virtual bt::Adapter* adapter() const = 0; + + protected: + Api() = default; +}; + +const std::string_view kSignalUpdateAdapter("controller/update"); + +class Signaler { + public: + virtual ~Signaler() = default; + + virtual void send(std::string_view signal) = 0; + virtual std::unique_ptr<http::Response> handle( + http::Request const& request) = 0; + + protected: + Signaler() = default; +}; + class HttpServerDelegate : public http::Server::Delegate { public: - HttpServerDelegate() = default; + HttpServerDelegate(Api& api, Signaler& signaler) + : api_(api), + signaler_(signaler), + json_writer_(json::writer(json_tmp_)), + json_mimetype_(http::MimeType::create("application", "json")) {} std::unique_ptr<http::Response> handle( http::Request const& request) override { @@ -36,21 +68,66 @@ class HttpServerDelegate : public http::Server::Delegate { return http::Response::status(http::StatusCode::kMethodNotAllowed); } - if (request.path() == "/api/v1/status") { - return http::Response::content( - R"({ status: "OK" })", - *http::MimeType::create("application", "json")); + if (request.path().starts_with("/api/v1/")) { + auto path = request.path().substr(8); + if (path == "status") { + json_writer_->clear(); + json_writer_->start_object(); + json_writer_->key("status"); + json_writer_->value("OK"); + json_writer_->end_object(); + return http::Response::content(json_tmp_, *json_mimetype_); + } + + if (path == "controller") { + auto* adapter = api_.adapter(); + json_writer_->clear(); + json_writer_->start_object(); + json_writer_->key("name"); + json_writer_->value(adapter ? adapter->name() : "unknown"); + json_writer_->key("pairable"); + json_writer_->value(adapter ? adapter->pairable() : false); + json_writer_->key("pairing"); + json_writer_->value(adapter ? adapter->pairing() : false); + json_writer_->end_object(); + + return http::Response::content(json_tmp_, *json_mimetype_); + } + + if (path == "events") { + auto resp = signaler_.handle(request); + if (resp) + return resp; + return http::Response::status(http::StatusCode::kBadRequest); + } } return http::Response::status(http::StatusCode::kNotFound); } + + private: + Api& api_; + Signaler& signaler_; + std::unique_ptr<json::Writer> json_writer_; + std::string json_tmp_; + std::unique_ptr<http::MimeType> json_mimetype_; }; -class BluetoothManagerDelegate : public bt::Manager::Delegate { +class BluetoothManagerDelegate : public bt::Manager::Delegate, public Api { public: - explicit BluetoothManagerDelegate(logger::Logger& logger) : logger_(logger) {} + BluetoothManagerDelegate(logger::Logger& logger, Signaler& signaler) + : logger_(logger), signaler_(signaler) {} + + [[nodiscard]] + bt::Adapter* adapter() const override { + return adapter_; + } void new_adapter(bt::Adapter* adapter) override { + adapter_ = adapter; + + signaler_.send(kSignalUpdateAdapter); + if (adapter) { logger_.info(std::format("New adapter: {} [{}]", adapter->name(), adapter->address())); @@ -59,6 +136,11 @@ class BluetoothManagerDelegate : public bt::Manager::Delegate { } } + void updated_adapter(bt::Adapter& adapter) override { + if (adapter_ == &adapter) + signaler_.send(kSignalUpdateAdapter); + } + void added_device(bt::Device& device) override { logger_.info( std::format("New device: {} [{}]", device.name(), device.address())); @@ -117,15 +199,42 @@ class BluetoothManagerDelegate : public bt::Manager::Delegate { private: logger::Logger& logger_; + Signaler& signaler_; + bt::Adapter* adapter_{nullptr}; +}; + +class SignalerImpl : public Signaler, ws::Server::Delegate { + public: + SignalerImpl(logger::Logger& logger, cfg::Config const& cfg, + looper::Looper& looper) + : server_(ws::create_server(logger, cfg, looper, *this)) {} + + void send(std::string_view signal) override { + server_->send_text_to_all(signal); + } + + std::unique_ptr<http::Response> handle( + http::Request const& request) override { + return server_->handle(request); + } + + std::unique_ptr<ws::Message> handle(ws::Message const& /* msg */) override { + // Ignore anything sent by clients + return nullptr; + } + + private: + std::unique_ptr<ws::Server> server_; }; bool run(logger::Logger& logger, cfg::Config const& cfg, std::unique_ptr<http::OpenPort> port) { auto looper = looper::create(); - HttpServerDelegate http_delegate; + SignalerImpl signaler(logger, cfg, *looper); + BluetoothManagerDelegate bt_delegate(logger, signaler); + HttpServerDelegate http_delegate(bt_delegate, signaler); auto server = http::create_server(logger, cfg, *looper, std::move(port), http_delegate); - BluetoothManagerDelegate bt_delegate(logger); auto manager = bt::create_manager(logger, cfg, *looper, bt_delegate); auto sigint_handler = signals::Handler::create( *looper, signals::Signal::INT, [&looper, &logger]() { diff --git a/src/websocket.cc b/src/websocket.cc new file mode 100644 index 0000000..391727d --- /dev/null +++ b/src/websocket.cc @@ -0,0 +1,676 @@ +#include "websocket.hh" + +#include "base64.hh" +#include "buffer.hh" +#include "cfg.hh" +#include "http.hh" +#include "logger.hh" +#include "looper.hh" +#include "sha1.hh" +#include "unique_fd.hh" + +#include <algorithm> +#include <cassert> +#include <cerrno> +#include <chrono> +#include <cstddef> +#include <cstring> +#include <format> +#include <iterator> +#include <memory> +#include <optional> +#include <string_view> +#include <sys/types.h> +#include <unistd.h> +#include <utility> +#include <vector> + +namespace ws { + +namespace { + +const std::string_view kWebsocketGuid("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + +enum OpCode : uint8_t { + kContinuation = 0, + kText = 1, + kBinary = 2, + kClose = 8, + kPing = 9, + kPong = 10, +}; + +void clear(Buffer& buffer) { + if (buffer.empty()) + return; + while (true) { + size_t avail; + buffer.rptr(avail); + if (avail == 0) + break; + buffer.consume(avail); + } +} + +class TextMessageImpl : public Message { + public: + explicit TextMessageImpl(std::string_view text) : text_(text) {} + + [[nodiscard]] + bool is_text() const override { + return true; + } + + [[nodiscard]] + bool is_binary() const override { + return false; + } + + [[nodiscard]] + std::string_view text() const override { + return text_; + } + + [[nodiscard]] + std::span<uint8_t const> binary() const override { + return {}; + } + + private: + std::string_view text_; +}; + +class BinaryMessageImpl : public Message { + public: + explicit BinaryMessageImpl(std::span<uint8_t const> data) : data_(data) {} + + [[nodiscard]] + bool is_text() const override { + return false; + } + + [[nodiscard]] + bool is_binary() const override { + return true; + } + + [[nodiscard]] + std::string_view text() const override { + return {}; + } + + [[nodiscard]] + std::span<uint8_t const> binary() const override { + return data_; + } + + private: + std::span<uint8_t const> data_; +}; + +class ServerImpl : public Server, http::SocketReceiver { + public: + ServerImpl(logger::Logger& logger, cfg::Config const& cfg, + looper::Looper& looper, Delegate& delegate) + : logger_(logger::prefix(logger, "ws")), + cfg_(cfg), + looper_(looper), + delegate_(delegate) { + client_half_timeout_ = + std::chrono::duration<double>( + cfg_.get_uint64("websocket.client.timeout.seconds").value_or(120)) / + 2; + client_.resize(cfg_.get_uint64("websocket.max.clients").value_or(100)); + } + + std::unique_ptr<http::Response> handle( + http::Request const& request) override { + if (request.method() != "GET") + return nullptr; + + if (!request.header_contains("upgrade", "websocket")) + return nullptr; + + if (!request.header_contains("connection", "upgrade")) + return nullptr; + + auto maybe_key = request.header_value("sec-websocket-key"); + if (!maybe_key.has_value()) + return nullptr; + auto maybe_nonce = base64::decode(maybe_key.value()); + if (!maybe_nonce.has_value() || maybe_nonce->size() != 16) + return nullptr; + + auto maybe_version = request.header_value("sec-websocket-version"); + if (!maybe_version.has_value()) + return nullptr; + if (maybe_version.value() != "13") { + auto builder = + http::Response::Builder::create(http::StatusCode::kBadRequest); + builder->add_header("Sec-WebSocket-Version", "13"); + return builder->build(); + } + + std::string accept{maybe_key.value()}; + accept.append(kWebsocketGuid); + accept = base64::encode(sha1::hash(accept)); + + auto builder = + http::Response::Builder::create(http::StatusCode::kSwitchingProtocols); + builder->add_header("Upgrade", "websocket"); + builder->add_header("Connection", "Upgrade"); + builder->add_header("Sec-WebSocket-Accept", accept); + builder->take_over(*this); + + return builder->build(); + } + + void receive(unique_fd&& fd) override { + if (active_clients_ == client_.size()) { + logger_->warn("Max number of clients already."); + return; + } + + for (; client_[next_client_].fd; next_client_++) { + if (next_client_ == client_.size()) + next_client_ = 0; + } + + logger_->dbg(std::format("New client: {}", next_client_)); + + auto& client = client_[next_client_]; + client.fd = std::move(fd); + client.last = std::chrono::steady_clock::now(); + + looper_.add( + client.fd.get(), looper::EVENT_READ, + [this, id = next_client_](auto events) { client_event(id, events); }); + client.timeout = looper_.schedule( + client_half_timeout_.count(), + [this, id = next_client_](auto handle) { client_timeout(id, handle); }); + + ++active_clients_; + ++next_client_; + } + + void send_to_all(Message const& msg) override { + auto active = active_clients_; + for (size_t i = 0; i < client_.size() && active; ++i) { + if (client_[i].fd) { + client_send(i, msg); + + if (--active == 0) + break; + } + } + } + + private: + static const size_t kInputSize = static_cast<size_t>(256) * 1024; + static const size_t kOutputSize = static_cast<size_t>(256) * 1024; + + struct Client { + unique_fd fd; + std::unique_ptr<Buffer> in{Buffer::fixed(kInputSize)}; + std::unique_ptr<Buffer> out{Buffer::fixed(kOutputSize)}; + std::chrono::steady_clock::time_point last; + uint32_t timeout{0}; + bool read_closed_{false}; + + uint8_t ping_{0}; + std::optional<uint8_t> expect_ping_; + bool expect_close_{false}; + + OpCode msg_{OpCode::kContinuation}; + std::vector<uint8_t> payload_; + }; + + void client_event(size_t client_id, uint8_t event) { + if (event & looper::EVENT_ERROR) { + logger_->info(std::format("Client socket error: {}", client_id)); + close_client(client_id); + return; + } + + auto& client = client_[client_id]; + if (event & looper::EVENT_READ) { + size_t avail; + auto* ptr = client.in->wptr(avail); + if (avail == 0) { + logger_->info(std::format("Client too large request: {}", client_id)); + close_client(client_id); + return; + } + + ssize_t got; + while (true) { + got = read(client.fd.get(), ptr, avail); + if (got < 0 && errno == EINTR) + continue; + break; + } + if (got < 0) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + logger_->info(std::format("Client read error: {}: {}", client_id, + strerror(errno))); + close_client(client_id); + return; + } + } else if (got == 0) { + client.read_closed_ = true; + } else { + client.in->commit(got); + + if (!client_read(client_id)) + return; + } + } + + if (!client_write(client_id)) + return; + + if (client.read_closed_ && client.in->empty() && client.out->empty()) { + close_client(client_id); + return; + } + + bool want_read = !client.read_closed_; + bool want_write = !client.out->empty(); + + looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) | + (want_write ? looper::EVENT_WRITE : 0)); + } + + bool client_write(size_t client_id) { + auto& client = client_[client_id]; + while (true) { + size_t avail; + auto* ptr = client.out->rptr(avail); + if (avail == 0) + break; + ssize_t got; + while (true) { + got = write(client.fd.get(), ptr, avail); + if (got < 0 && errno == EINTR) + continue; + break; + } + if (got < 0) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + logger_->info(std::format("Client write error: {}: {}", client_id, + strerror(errno))); + close_client(client_id); + return false; + } + break; + } + client.out->consume(got); + + if (std::cmp_less(got, avail)) + break; + } + return true; + } + + bool client_read(size_t client_id) { + auto& client = client_[client_id]; + size_t avail; + auto* ptr = client.in->rptr(avail, /* need */ 2); + if (avail < 2) + return true; + + std::span<uint8_t const> data(reinterpret_cast<uint8_t const*>(ptr), avail); + bool fin = data[0] & 0x80; + if (data[1] & 0x70) { + logger_->info(std::format( + "Client invalid request, reserved bits are not zero: {}", client_id)); + close_client(client_id); + return false; + } + uint8_t opcode = data[0] & 0x0f; + switch (opcode) { + case OpCode::kContinuation: + case OpCode::kText: + case OpCode::kBinary: + case OpCode::kClose: + case OpCode::kPing: + case OpCode::kPong: + break; + default: + logger_->info(std::format( + "Client invalid request, reserved opcode used: {}", client_id)); + close_client(client_id); + return false; + } + bool mask = data[1] & 0x80; + if (!mask) { + logger_->info( + std::format("Client invalid request, not masked: {}", client_id)); + close_client(client_id); + return false; + } + size_t payload_offset = 2; + uint64_t payload_len = data[1] & 0x7f; + if (payload_len == 126) { + payload_offset += 2; + } else if (payload_len == 127) { + payload_offset += 8; + } + /* if (mask) */ payload_offset += 4; + + if (payload_offset > data.size()) { + ptr = client.in->rptr(avail, /* need */ payload_offset); + if (avail < payload_offset) + return true; + data = std::span(reinterpret_cast<uint8_t const*>(ptr), avail); + } + + if (payload_len == 126) { + payload_len = (static_cast<uint16_t>(data[2]) << 8) | data[3]; + } else if (payload_len == 127) { + payload_len = (static_cast<uint64_t>(data[2]) << 56) | + (static_cast<uint64_t>(data[3]) << 48) | + (static_cast<uint64_t>(data[4]) << 40) | + (static_cast<uint64_t>(data[5]) << 32) | + (static_cast<uint64_t>(data[6]) << 24) | + (static_cast<uint64_t>(data[7]) << 16) | + (static_cast<uint64_t>(data[8]) << 8) | data[9]; + } + /* if (mask) */ + auto mask_key = data.subspan(payload_offset - 4, 4); + + if (payload_offset + payload_len > data.size()) { + ptr = client.in->rptr(avail, /* need */ payload_offset + payload_len); + if (avail < payload_offset + payload_len) + return true; + data = std::span(reinterpret_cast<uint8_t const*>(ptr), avail); + } + + auto payload = data.subspan(payload_offset, payload_len); + /* if (mask) { */ + // Unmask the data + std::span<uint8_t> tmp{const_cast<uint8_t*>(payload.data()), + payload.size()}; + for (size_t i = 0; i < tmp.size(); ++i) { + tmp[i] ^= mask_key[i % 4]; + } + + if (client_handle(client_id, fin, static_cast<OpCode>(opcode), payload)) { + client.in->consume(payload_offset + payload_len); + return true; + } + return false; + } + + bool client_handle(size_t client_id, bool fin, OpCode opcode, + std::span<uint8_t const> payload) { + auto& client = client_[client_id]; + + // Validate continuation and fin frames. + switch (opcode) { + case OpCode::kContinuation: + if (client.msg_ == OpCode::kContinuation) { + logger_->info(std::format( + "Client invalid frame, unexpected continuation: {}", client_id)); + close_client(client_id); + return false; + } + break; + case OpCode::kBinary: + case OpCode::kText: + if (client.msg_ != OpCode::kContinuation) { + logger_->info(std::format( + "Client invalid frame, unexpected non-continuation: {}", + client_id)); + close_client(client_id); + return false; + } + break; + case OpCode::kClose: + case OpCode::kPing: + case OpCode::kPong: + if (!fin) { + logger_->info( + std::format("Client invalid control frame: {}", client_id)); + close_client(client_id); + return false; + } + break; + } + + switch (opcode) { + case OpCode::kContinuation: + std::ranges::copy(payload, std::back_inserter(client.payload_)); + if (fin) { + std::unique_ptr<Message> reply; + switch (client.msg_) { + case OpCode::kText: + reply = delegate_.handle(TextMessageImpl{std::string_view( + reinterpret_cast<char const*>(client.payload_.data()), + client.payload_.size())}); + break; + case OpCode::kBinary: + reply = delegate_.handle(BinaryMessageImpl{client.payload_}); + break; + default: + std::unreachable(); + } + client.msg_ = OpCode::kContinuation; + client.payload_.clear(); + if (reply) { + if (!client_send(client_id, *reply)) + return false; + } + } + break; + case OpCode::kBinary: + case OpCode::kText: + if (fin) { + std::unique_ptr<Message> reply; + if (opcode == OpCode::kText) { + reply = delegate_.handle(TextMessageImpl{ + std::string_view(reinterpret_cast<char const*>(payload.data()), + payload.size())}); + } else { + reply = delegate_.handle(BinaryMessageImpl{payload}); + } + if (reply) { + if (!client_send(client_id, *reply)) + return false; + } + } else { + client.msg_ = opcode; + client.payload_.assign(payload.begin(), payload.end()); + } + break; + case OpCode::kClose: + if (client.expect_close_) { + close_client(client_id); + return false; + } + if (!client_send(client_id, OpCode::kClose, payload)) + return false; + client.read_closed_ = true; + break; + case OpCode::kPing: + if (!client_send(client_id, OpCode::kPong, payload)) + return false; + break; + case OpCode::kPong: + if (client.expect_ping_.has_value()) { + if (payload.size() != 1 || + payload[0] != client.expect_ping_.value()) { + logger_->info( + std::format("Client closed, mismatched pong: {}", client_id)); + close_client(client_id); + return false; + } + client.expect_ping_.reset(); + } else { + // A Pong frame MAY be sent unsolicited. This serves as a unidirectional heartbeat. + } + break; + } + return true; + } + + bool client_send(size_t client_id, Message const& message) { + if (message.is_text()) { + auto payload = message.text(); + return client_send( + client_id, OpCode::kText, + std::span{reinterpret_cast<uint8_t const*>(payload.data()), + payload.size()}); + } + return client_send(client_id, OpCode::kBinary, message.binary()); + } + + bool client_send(size_t client_id, OpCode opcode, + std::span<uint8_t const> payload) { + auto& client = client_[client_id]; + if (client.read_closed_) + return true; + + uint8_t header[10]; + size_t header_size = 0; + // Always send FIN frames. + header[header_size++] = 0x80 | std::to_underlying(opcode); + // Server never sends masked messages. + if (payload.size() < 126) { + header[header_size++] = payload.size(); + } else if (payload.size() <= 0xffff) { + header[header_size++] = 126; + header[header_size++] = payload.size() >> 8; + header[header_size++] = payload.size() & 0xff; + } else { + header[header_size++] = 127; + header[header_size++] = payload.size() >> 56; + header[header_size++] = (payload.size() & 0xff000000000000) >> 48; + header[header_size++] = (payload.size() & 0xff0000000000) >> 40; + header[header_size++] = (payload.size() & 0xff00000000) >> 32; + header[header_size++] = (payload.size() & 0xff000000) >> 24; + header[header_size++] = (payload.size() & 0xff0000) >> 16; + header[header_size++] = (payload.size() & 0xff00) >> 8; + header[header_size++] = payload.size() & 0xff; + } + size_t avail; + bool const was_empty = client.out->empty(); + auto* wptr = + client.out->wptr(avail, /* need */ header_size + payload.size()); + if (avail < header_size + payload.size()) { + logger_->info( + std::format("Client closed, too much output: {}", client_id)); + close_client(client_id); + return false; + } + auto* data = reinterpret_cast<uint8_t*>(wptr); + std::copy_n(header, header_size, data); + std::ranges::copy(payload, data + header_size); + client.out->commit(header_size + payload.size()); + + if (was_empty) { + if (!client_write(client_id)) + return false; + } + + return true; + } + + void client_timeout(size_t client_id, uint32_t id) { + auto now = std::chrono::steady_clock::now(); + auto& client = client_[client_id]; + assert(client.timeout == id); + if (now - client.last < + std::chrono::duration_cast<std::chrono::steady_clock::duration>( + client_half_timeout_)) { + // TODO: Reschedule for delay left, not the full one + client.timeout = looper_.schedule(client_half_timeout_.count(), + [this, client_id](auto handle) { + client_timeout(client_id, handle); + }); + return; + } + + client.timeout = 0; + + if (client.expect_ping_.has_value()) { + logger_->dbg(std::format("Client timeout: {}", client_id)); + + close_client(client_id); + } else { + client.expect_ping_ = ++client.ping_; + if (client_send(client_id, OpCode::kPing, std::span{&client.ping_, 1})) { + client.timeout = looper_.schedule(client_half_timeout_.count(), + [this, client_id](auto handle) { + client_timeout(client_id, handle); + }); + } + } + } + + void close_client(size_t client_id) { + auto& client = client_[client_id]; + if (!client.fd) { + assert(false); + return; + } + + logger_->dbg(std::format("Drop client: {}", client_id)); + + looper_.remove(client.fd.get()); + if (client.timeout) + looper_.cancel(client.timeout); + client.fd.reset(); + clear(*client.in); + clear(*client.out); + client.read_closed_ = false; + client.expect_ping_.reset(); + client.ping_ = 0; + client.expect_close_ = false; + client.msg_ = OpCode::kContinuation; + client.payload_.clear(); + + assert(active_clients_ > 0); + --active_clients_; + if (next_client_ == client_id + 1) + next_client_ = client_id; + } + + std::unique_ptr<logger::Logger> logger_; + cfg::Config const& cfg_; + looper::Looper& looper_; + Delegate& delegate_; + + // Timeout first sends a PING and then if it times out again, then closes connection. + std::chrono::duration<double> client_half_timeout_; + std::vector<Client> client_; + size_t next_client_{0}; + size_t active_clients_{0}; +}; + +} // namespace + +void Server::send_text_to_all(std::string_view text) { + send_to_all(TextMessageImpl{text}); +} + +void Server::send_binary_to_all(std::span<uint8_t const> data) { + send_to_all(BinaryMessageImpl{data}); +} + +std::unique_ptr<Server> create_server(logger::Logger& logger, + cfg::Config const& cfg, + looper::Looper& looper, + Server::Delegate& delegate) { + return std::make_unique<ServerImpl>(logger, cfg, looper, delegate); +} + +std::unique_ptr<Message> Message::create(std::string_view text) { + return std::make_unique<TextMessageImpl>(text); +} + +std::unique_ptr<Message> Message::create(std::span<uint8_t const> data) { + return std::make_unique<BinaryMessageImpl>(data); +} + +} // namespace ws diff --git a/src/websocket.hh b/src/websocket.hh new file mode 100644 index 0000000..f7a99f1 --- /dev/null +++ b/src/websocket.hh @@ -0,0 +1,88 @@ +#ifndef WEBSOCKET_HH +#define WEBSOCKET_HH + +#include <memory> +#include <span> +#include <string_view> + +namespace logger { +class Logger; +} // namespace logger + +namespace looper { +class Looper; +} // namespace looper + +namespace cfg { +class Config; +} // namespace cfg + +namespace http { +class Response; +class Request; +} // namespace http + +namespace ws { + +class Message { + public: + virtual ~Message() = default; + + [[nodiscard]] + virtual bool is_text() const = 0; + + [[nodiscard]] + virtual bool is_binary() const = 0; + + [[nodiscard]] + virtual std::string_view text() const = 0; + + [[nodiscard]] + virtual std::span<uint8_t const> binary() const = 0; + + static std::unique_ptr<Message> create(std::string_view text); + static std::unique_ptr<Message> create(std::span<uint8_t const> data); + + protected: + Message() = default; + Message(Message const&) = delete; + Message& operator=(Message const&) = delete; +}; + +class Server { + public: + virtual ~Server() = default; + + class Delegate { + public: + virtual ~Delegate() = default; + + virtual std::unique_ptr<Message> handle(Message const& msg) = 0; + + protected: + Delegate() = default; + }; + + virtual std::unique_ptr<http::Response> handle( + http::Request const& request) = 0; + + virtual void send_text_to_all(std::string_view text); + virtual void send_binary_to_all(std::span<uint8_t const> data); + virtual void send_to_all(Message const& message) = 0; + + protected: + Server() = default; + + private: + Server(Server const&) = delete; + Server& operator=(Server const&) = delete; +}; + +std::unique_ptr<Server> create_server(logger::Logger& logger, + cfg::Config const& cfg, + looper::Looper& looper, + Server::Delegate& delegate); + +} // namespace ws + +#endif // WEBSOCKET_HH |
