From e7c74917191a4953d495295b65732aa3549ba753 Mon Sep 17 00:00:00 2001 From: Joel Klinghed Date: Sun, 19 Oct 2025 00:12:50 +0200 Subject: Add new module websocket and use it Implement /api/v1/events which will send out messages when things change, such as the main controller. --- src/websocket.cc | 676 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 676 insertions(+) create mode 100644 src/websocket.cc (limited to 'src/websocket.cc') diff --git a/src/websocket.cc b/src/websocket.cc new file mode 100644 index 0000000..391727d --- /dev/null +++ b/src/websocket.cc @@ -0,0 +1,676 @@ +#include "websocket.hh" + +#include "base64.hh" +#include "buffer.hh" +#include "cfg.hh" +#include "http.hh" +#include "logger.hh" +#include "looper.hh" +#include "sha1.hh" +#include "unique_fd.hh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ws { + +namespace { + +const std::string_view kWebsocketGuid("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + +enum OpCode : uint8_t { + kContinuation = 0, + kText = 1, + kBinary = 2, + kClose = 8, + kPing = 9, + kPong = 10, +}; + +void clear(Buffer& buffer) { + if (buffer.empty()) + return; + while (true) { + size_t avail; + buffer.rptr(avail); + if (avail == 0) + break; + buffer.consume(avail); + } +} + +class TextMessageImpl : public Message { + public: + explicit TextMessageImpl(std::string_view text) : text_(text) {} + + [[nodiscard]] + bool is_text() const override { + return true; + } + + [[nodiscard]] + bool is_binary() const override { + return false; + } + + [[nodiscard]] + std::string_view text() const override { + return text_; + } + + [[nodiscard]] + std::span binary() const override { + return {}; + } + + private: + std::string_view text_; +}; + +class BinaryMessageImpl : public Message { + public: + explicit BinaryMessageImpl(std::span data) : data_(data) {} + + [[nodiscard]] + bool is_text() const override { + return false; + } + + [[nodiscard]] + bool is_binary() const override { + return true; + } + + [[nodiscard]] + std::string_view text() const override { + return {}; + } + + [[nodiscard]] + std::span binary() const override { + return data_; + } + + private: + std::span data_; +}; + +class ServerImpl : public Server, http::SocketReceiver { + public: + ServerImpl(logger::Logger& logger, cfg::Config const& cfg, + looper::Looper& looper, Delegate& delegate) + : logger_(logger::prefix(logger, "ws")), + cfg_(cfg), + looper_(looper), + delegate_(delegate) { + client_half_timeout_ = + std::chrono::duration( + cfg_.get_uint64("websocket.client.timeout.seconds").value_or(120)) / + 2; + client_.resize(cfg_.get_uint64("websocket.max.clients").value_or(100)); + } + + std::unique_ptr handle( + http::Request const& request) override { + if (request.method() != "GET") + return nullptr; + + if (!request.header_contains("upgrade", "websocket")) + return nullptr; + + if (!request.header_contains("connection", "upgrade")) + return nullptr; + + auto maybe_key = request.header_value("sec-websocket-key"); + if (!maybe_key.has_value()) + return nullptr; + auto maybe_nonce = base64::decode(maybe_key.value()); + if (!maybe_nonce.has_value() || maybe_nonce->size() != 16) + return nullptr; + + auto maybe_version = request.header_value("sec-websocket-version"); + if (!maybe_version.has_value()) + return nullptr; + if (maybe_version.value() != "13") { + auto builder = + http::Response::Builder::create(http::StatusCode::kBadRequest); + builder->add_header("Sec-WebSocket-Version", "13"); + return builder->build(); + } + + std::string accept{maybe_key.value()}; + accept.append(kWebsocketGuid); + accept = base64::encode(sha1::hash(accept)); + + auto builder = + http::Response::Builder::create(http::StatusCode::kSwitchingProtocols); + builder->add_header("Upgrade", "websocket"); + builder->add_header("Connection", "Upgrade"); + builder->add_header("Sec-WebSocket-Accept", accept); + builder->take_over(*this); + + return builder->build(); + } + + void receive(unique_fd&& fd) override { + if (active_clients_ == client_.size()) { + logger_->warn("Max number of clients already."); + return; + } + + for (; client_[next_client_].fd; next_client_++) { + if (next_client_ == client_.size()) + next_client_ = 0; + } + + logger_->dbg(std::format("New client: {}", next_client_)); + + auto& client = client_[next_client_]; + client.fd = std::move(fd); + client.last = std::chrono::steady_clock::now(); + + looper_.add( + client.fd.get(), looper::EVENT_READ, + [this, id = next_client_](auto events) { client_event(id, events); }); + client.timeout = looper_.schedule( + client_half_timeout_.count(), + [this, id = next_client_](auto handle) { client_timeout(id, handle); }); + + ++active_clients_; + ++next_client_; + } + + void send_to_all(Message const& msg) override { + auto active = active_clients_; + for (size_t i = 0; i < client_.size() && active; ++i) { + if (client_[i].fd) { + client_send(i, msg); + + if (--active == 0) + break; + } + } + } + + private: + static const size_t kInputSize = static_cast(256) * 1024; + static const size_t kOutputSize = static_cast(256) * 1024; + + struct Client { + unique_fd fd; + std::unique_ptr in{Buffer::fixed(kInputSize)}; + std::unique_ptr out{Buffer::fixed(kOutputSize)}; + std::chrono::steady_clock::time_point last; + uint32_t timeout{0}; + bool read_closed_{false}; + + uint8_t ping_{0}; + std::optional expect_ping_; + bool expect_close_{false}; + + OpCode msg_{OpCode::kContinuation}; + std::vector payload_; + }; + + void client_event(size_t client_id, uint8_t event) { + if (event & looper::EVENT_ERROR) { + logger_->info(std::format("Client socket error: {}", client_id)); + close_client(client_id); + return; + } + + auto& client = client_[client_id]; + if (event & looper::EVENT_READ) { + size_t avail; + auto* ptr = client.in->wptr(avail); + if (avail == 0) { + logger_->info(std::format("Client too large request: {}", client_id)); + close_client(client_id); + return; + } + + ssize_t got; + while (true) { + got = read(client.fd.get(), ptr, avail); + if (got < 0 && errno == EINTR) + continue; + break; + } + if (got < 0) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + logger_->info(std::format("Client read error: {}: {}", client_id, + strerror(errno))); + close_client(client_id); + return; + } + } else if (got == 0) { + client.read_closed_ = true; + } else { + client.in->commit(got); + + if (!client_read(client_id)) + return; + } + } + + if (!client_write(client_id)) + return; + + if (client.read_closed_ && client.in->empty() && client.out->empty()) { + close_client(client_id); + return; + } + + bool want_read = !client.read_closed_; + bool want_write = !client.out->empty(); + + looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) | + (want_write ? looper::EVENT_WRITE : 0)); + } + + bool client_write(size_t client_id) { + auto& client = client_[client_id]; + while (true) { + size_t avail; + auto* ptr = client.out->rptr(avail); + if (avail == 0) + break; + ssize_t got; + while (true) { + got = write(client.fd.get(), ptr, avail); + if (got < 0 && errno == EINTR) + continue; + break; + } + if (got < 0) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + logger_->info(std::format("Client write error: {}: {}", client_id, + strerror(errno))); + close_client(client_id); + return false; + } + break; + } + client.out->consume(got); + + if (std::cmp_less(got, avail)) + break; + } + return true; + } + + bool client_read(size_t client_id) { + auto& client = client_[client_id]; + size_t avail; + auto* ptr = client.in->rptr(avail, /* need */ 2); + if (avail < 2) + return true; + + std::span data(reinterpret_cast(ptr), avail); + bool fin = data[0] & 0x80; + if (data[1] & 0x70) { + logger_->info(std::format( + "Client invalid request, reserved bits are not zero: {}", client_id)); + close_client(client_id); + return false; + } + uint8_t opcode = data[0] & 0x0f; + switch (opcode) { + case OpCode::kContinuation: + case OpCode::kText: + case OpCode::kBinary: + case OpCode::kClose: + case OpCode::kPing: + case OpCode::kPong: + break; + default: + logger_->info(std::format( + "Client invalid request, reserved opcode used: {}", client_id)); + close_client(client_id); + return false; + } + bool mask = data[1] & 0x80; + if (!mask) { + logger_->info( + std::format("Client invalid request, not masked: {}", client_id)); + close_client(client_id); + return false; + } + size_t payload_offset = 2; + uint64_t payload_len = data[1] & 0x7f; + if (payload_len == 126) { + payload_offset += 2; + } else if (payload_len == 127) { + payload_offset += 8; + } + /* if (mask) */ payload_offset += 4; + + if (payload_offset > data.size()) { + ptr = client.in->rptr(avail, /* need */ payload_offset); + if (avail < payload_offset) + return true; + data = std::span(reinterpret_cast(ptr), avail); + } + + if (payload_len == 126) { + payload_len = (static_cast(data[2]) << 8) | data[3]; + } else if (payload_len == 127) { + payload_len = (static_cast(data[2]) << 56) | + (static_cast(data[3]) << 48) | + (static_cast(data[4]) << 40) | + (static_cast(data[5]) << 32) | + (static_cast(data[6]) << 24) | + (static_cast(data[7]) << 16) | + (static_cast(data[8]) << 8) | data[9]; + } + /* if (mask) */ + auto mask_key = data.subspan(payload_offset - 4, 4); + + if (payload_offset + payload_len > data.size()) { + ptr = client.in->rptr(avail, /* need */ payload_offset + payload_len); + if (avail < payload_offset + payload_len) + return true; + data = std::span(reinterpret_cast(ptr), avail); + } + + auto payload = data.subspan(payload_offset, payload_len); + /* if (mask) { */ + // Unmask the data + std::span tmp{const_cast(payload.data()), + payload.size()}; + for (size_t i = 0; i < tmp.size(); ++i) { + tmp[i] ^= mask_key[i % 4]; + } + + if (client_handle(client_id, fin, static_cast(opcode), payload)) { + client.in->consume(payload_offset + payload_len); + return true; + } + return false; + } + + bool client_handle(size_t client_id, bool fin, OpCode opcode, + std::span payload) { + auto& client = client_[client_id]; + + // Validate continuation and fin frames. + switch (opcode) { + case OpCode::kContinuation: + if (client.msg_ == OpCode::kContinuation) { + logger_->info(std::format( + "Client invalid frame, unexpected continuation: {}", client_id)); + close_client(client_id); + return false; + } + break; + case OpCode::kBinary: + case OpCode::kText: + if (client.msg_ != OpCode::kContinuation) { + logger_->info(std::format( + "Client invalid frame, unexpected non-continuation: {}", + client_id)); + close_client(client_id); + return false; + } + break; + case OpCode::kClose: + case OpCode::kPing: + case OpCode::kPong: + if (!fin) { + logger_->info( + std::format("Client invalid control frame: {}", client_id)); + close_client(client_id); + return false; + } + break; + } + + switch (opcode) { + case OpCode::kContinuation: + std::ranges::copy(payload, std::back_inserter(client.payload_)); + if (fin) { + std::unique_ptr reply; + switch (client.msg_) { + case OpCode::kText: + reply = delegate_.handle(TextMessageImpl{std::string_view( + reinterpret_cast(client.payload_.data()), + client.payload_.size())}); + break; + case OpCode::kBinary: + reply = delegate_.handle(BinaryMessageImpl{client.payload_}); + break; + default: + std::unreachable(); + } + client.msg_ = OpCode::kContinuation; + client.payload_.clear(); + if (reply) { + if (!client_send(client_id, *reply)) + return false; + } + } + break; + case OpCode::kBinary: + case OpCode::kText: + if (fin) { + std::unique_ptr reply; + if (opcode == OpCode::kText) { + reply = delegate_.handle(TextMessageImpl{ + std::string_view(reinterpret_cast(payload.data()), + payload.size())}); + } else { + reply = delegate_.handle(BinaryMessageImpl{payload}); + } + if (reply) { + if (!client_send(client_id, *reply)) + return false; + } + } else { + client.msg_ = opcode; + client.payload_.assign(payload.begin(), payload.end()); + } + break; + case OpCode::kClose: + if (client.expect_close_) { + close_client(client_id); + return false; + } + if (!client_send(client_id, OpCode::kClose, payload)) + return false; + client.read_closed_ = true; + break; + case OpCode::kPing: + if (!client_send(client_id, OpCode::kPong, payload)) + return false; + break; + case OpCode::kPong: + if (client.expect_ping_.has_value()) { + if (payload.size() != 1 || + payload[0] != client.expect_ping_.value()) { + logger_->info( + std::format("Client closed, mismatched pong: {}", client_id)); + close_client(client_id); + return false; + } + client.expect_ping_.reset(); + } else { + // A Pong frame MAY be sent unsolicited. This serves as a unidirectional heartbeat. + } + break; + } + return true; + } + + bool client_send(size_t client_id, Message const& message) { + if (message.is_text()) { + auto payload = message.text(); + return client_send( + client_id, OpCode::kText, + std::span{reinterpret_cast(payload.data()), + payload.size()}); + } + return client_send(client_id, OpCode::kBinary, message.binary()); + } + + bool client_send(size_t client_id, OpCode opcode, + std::span payload) { + auto& client = client_[client_id]; + if (client.read_closed_) + return true; + + uint8_t header[10]; + size_t header_size = 0; + // Always send FIN frames. + header[header_size++] = 0x80 | std::to_underlying(opcode); + // Server never sends masked messages. + if (payload.size() < 126) { + header[header_size++] = payload.size(); + } else if (payload.size() <= 0xffff) { + header[header_size++] = 126; + header[header_size++] = payload.size() >> 8; + header[header_size++] = payload.size() & 0xff; + } else { + header[header_size++] = 127; + header[header_size++] = payload.size() >> 56; + header[header_size++] = (payload.size() & 0xff000000000000) >> 48; + header[header_size++] = (payload.size() & 0xff0000000000) >> 40; + header[header_size++] = (payload.size() & 0xff00000000) >> 32; + header[header_size++] = (payload.size() & 0xff000000) >> 24; + header[header_size++] = (payload.size() & 0xff0000) >> 16; + header[header_size++] = (payload.size() & 0xff00) >> 8; + header[header_size++] = payload.size() & 0xff; + } + size_t avail; + bool const was_empty = client.out->empty(); + auto* wptr = + client.out->wptr(avail, /* need */ header_size + payload.size()); + if (avail < header_size + payload.size()) { + logger_->info( + std::format("Client closed, too much output: {}", client_id)); + close_client(client_id); + return false; + } + auto* data = reinterpret_cast(wptr); + std::copy_n(header, header_size, data); + std::ranges::copy(payload, data + header_size); + client.out->commit(header_size + payload.size()); + + if (was_empty) { + if (!client_write(client_id)) + return false; + } + + return true; + } + + void client_timeout(size_t client_id, uint32_t id) { + auto now = std::chrono::steady_clock::now(); + auto& client = client_[client_id]; + assert(client.timeout == id); + if (now - client.last < + std::chrono::duration_cast( + client_half_timeout_)) { + // TODO: Reschedule for delay left, not the full one + client.timeout = looper_.schedule(client_half_timeout_.count(), + [this, client_id](auto handle) { + client_timeout(client_id, handle); + }); + return; + } + + client.timeout = 0; + + if (client.expect_ping_.has_value()) { + logger_->dbg(std::format("Client timeout: {}", client_id)); + + close_client(client_id); + } else { + client.expect_ping_ = ++client.ping_; + if (client_send(client_id, OpCode::kPing, std::span{&client.ping_, 1})) { + client.timeout = looper_.schedule(client_half_timeout_.count(), + [this, client_id](auto handle) { + client_timeout(client_id, handle); + }); + } + } + } + + void close_client(size_t client_id) { + auto& client = client_[client_id]; + if (!client.fd) { + assert(false); + return; + } + + logger_->dbg(std::format("Drop client: {}", client_id)); + + looper_.remove(client.fd.get()); + if (client.timeout) + looper_.cancel(client.timeout); + client.fd.reset(); + clear(*client.in); + clear(*client.out); + client.read_closed_ = false; + client.expect_ping_.reset(); + client.ping_ = 0; + client.expect_close_ = false; + client.msg_ = OpCode::kContinuation; + client.payload_.clear(); + + assert(active_clients_ > 0); + --active_clients_; + if (next_client_ == client_id + 1) + next_client_ = client_id; + } + + std::unique_ptr logger_; + cfg::Config const& cfg_; + looper::Looper& looper_; + Delegate& delegate_; + + // Timeout first sends a PING and then if it times out again, then closes connection. + std::chrono::duration client_half_timeout_; + std::vector client_; + size_t next_client_{0}; + size_t active_clients_{0}; +}; + +} // namespace + +void Server::send_text_to_all(std::string_view text) { + send_to_all(TextMessageImpl{text}); +} + +void Server::send_binary_to_all(std::span data) { + send_to_all(BinaryMessageImpl{data}); +} + +std::unique_ptr create_server(logger::Logger& logger, + cfg::Config const& cfg, + looper::Looper& looper, + Server::Delegate& delegate) { + return std::make_unique(logger, cfg, looper, delegate); +} + +std::unique_ptr Message::create(std::string_view text) { + return std::make_unique(text); +} + +std::unique_ptr Message::create(std::span data) { + return std::make_unique(data); +} + +} // namespace ws -- cgit v1.2.3-70-g09d2