summaryrefslogtreecommitdiff
path: root/src/websocket.cc
diff options
context:
space:
mode:
authorJoel Klinghed <the_jk@spawned.biz>2025-10-19 00:12:50 +0200
committerJoel Klinghed <the_jk@spawned.biz>2025-10-19 00:31:19 +0200
commite7c74917191a4953d495295b65732aa3549ba753 (patch)
treefde3fc6080786b3e3e4526b3793bacbb390d2b17 /src/websocket.cc
parent4f6ead7c2c646b6b866274299c05d08170d2dfb0 (diff)
Add new module websocket and use it
Implement /api/v1/events which will send out messages when things change, such as the main controller.
Diffstat (limited to 'src/websocket.cc')
-rw-r--r--src/websocket.cc676
1 files changed, 676 insertions, 0 deletions
diff --git a/src/websocket.cc b/src/websocket.cc
new file mode 100644
index 0000000..391727d
--- /dev/null
+++ b/src/websocket.cc
@@ -0,0 +1,676 @@
+#include "websocket.hh"
+
+#include "base64.hh"
+#include "buffer.hh"
+#include "cfg.hh"
+#include "http.hh"
+#include "logger.hh"
+#include "looper.hh"
+#include "sha1.hh"
+#include "unique_fd.hh"
+
+#include <algorithm>
+#include <cassert>
+#include <cerrno>
+#include <chrono>
+#include <cstddef>
+#include <cstring>
+#include <format>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <string_view>
+#include <sys/types.h>
+#include <unistd.h>
+#include <utility>
+#include <vector>
+
+namespace ws {
+
+namespace {
+
+const std::string_view kWebsocketGuid("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+
+enum OpCode : uint8_t {
+ kContinuation = 0,
+ kText = 1,
+ kBinary = 2,
+ kClose = 8,
+ kPing = 9,
+ kPong = 10,
+};
+
+void clear(Buffer& buffer) {
+ if (buffer.empty())
+ return;
+ while (true) {
+ size_t avail;
+ buffer.rptr(avail);
+ if (avail == 0)
+ break;
+ buffer.consume(avail);
+ }
+}
+
+class TextMessageImpl : public Message {
+ public:
+ explicit TextMessageImpl(std::string_view text) : text_(text) {}
+
+ [[nodiscard]]
+ bool is_text() const override {
+ return true;
+ }
+
+ [[nodiscard]]
+ bool is_binary() const override {
+ return false;
+ }
+
+ [[nodiscard]]
+ std::string_view text() const override {
+ return text_;
+ }
+
+ [[nodiscard]]
+ std::span<uint8_t const> binary() const override {
+ return {};
+ }
+
+ private:
+ std::string_view text_;
+};
+
+class BinaryMessageImpl : public Message {
+ public:
+ explicit BinaryMessageImpl(std::span<uint8_t const> data) : data_(data) {}
+
+ [[nodiscard]]
+ bool is_text() const override {
+ return false;
+ }
+
+ [[nodiscard]]
+ bool is_binary() const override {
+ return true;
+ }
+
+ [[nodiscard]]
+ std::string_view text() const override {
+ return {};
+ }
+
+ [[nodiscard]]
+ std::span<uint8_t const> binary() const override {
+ return data_;
+ }
+
+ private:
+ std::span<uint8_t const> data_;
+};
+
+class ServerImpl : public Server, http::SocketReceiver {
+ public:
+ ServerImpl(logger::Logger& logger, cfg::Config const& cfg,
+ looper::Looper& looper, Delegate& delegate)
+ : logger_(logger::prefix(logger, "ws")),
+ cfg_(cfg),
+ looper_(looper),
+ delegate_(delegate) {
+ client_half_timeout_ =
+ std::chrono::duration<double>(
+ cfg_.get_uint64("websocket.client.timeout.seconds").value_or(120)) /
+ 2;
+ client_.resize(cfg_.get_uint64("websocket.max.clients").value_or(100));
+ }
+
+ std::unique_ptr<http::Response> handle(
+ http::Request const& request) override {
+ if (request.method() != "GET")
+ return nullptr;
+
+ if (!request.header_contains("upgrade", "websocket"))
+ return nullptr;
+
+ if (!request.header_contains("connection", "upgrade"))
+ return nullptr;
+
+ auto maybe_key = request.header_value("sec-websocket-key");
+ if (!maybe_key.has_value())
+ return nullptr;
+ auto maybe_nonce = base64::decode(maybe_key.value());
+ if (!maybe_nonce.has_value() || maybe_nonce->size() != 16)
+ return nullptr;
+
+ auto maybe_version = request.header_value("sec-websocket-version");
+ if (!maybe_version.has_value())
+ return nullptr;
+ if (maybe_version.value() != "13") {
+ auto builder =
+ http::Response::Builder::create(http::StatusCode::kBadRequest);
+ builder->add_header("Sec-WebSocket-Version", "13");
+ return builder->build();
+ }
+
+ std::string accept{maybe_key.value()};
+ accept.append(kWebsocketGuid);
+ accept = base64::encode(sha1::hash(accept));
+
+ auto builder =
+ http::Response::Builder::create(http::StatusCode::kSwitchingProtocols);
+ builder->add_header("Upgrade", "websocket");
+ builder->add_header("Connection", "Upgrade");
+ builder->add_header("Sec-WebSocket-Accept", accept);
+ builder->take_over(*this);
+
+ return builder->build();
+ }
+
+ void receive(unique_fd&& fd) override {
+ if (active_clients_ == client_.size()) {
+ logger_->warn("Max number of clients already.");
+ return;
+ }
+
+ for (; client_[next_client_].fd; next_client_++) {
+ if (next_client_ == client_.size())
+ next_client_ = 0;
+ }
+
+ logger_->dbg(std::format("New client: {}", next_client_));
+
+ auto& client = client_[next_client_];
+ client.fd = std::move(fd);
+ client.last = std::chrono::steady_clock::now();
+
+ looper_.add(
+ client.fd.get(), looper::EVENT_READ,
+ [this, id = next_client_](auto events) { client_event(id, events); });
+ client.timeout = looper_.schedule(
+ client_half_timeout_.count(),
+ [this, id = next_client_](auto handle) { client_timeout(id, handle); });
+
+ ++active_clients_;
+ ++next_client_;
+ }
+
+ void send_to_all(Message const& msg) override {
+ auto active = active_clients_;
+ for (size_t i = 0; i < client_.size() && active; ++i) {
+ if (client_[i].fd) {
+ client_send(i, msg);
+
+ if (--active == 0)
+ break;
+ }
+ }
+ }
+
+ private:
+ static const size_t kInputSize = static_cast<size_t>(256) * 1024;
+ static const size_t kOutputSize = static_cast<size_t>(256) * 1024;
+
+ struct Client {
+ unique_fd fd;
+ std::unique_ptr<Buffer> in{Buffer::fixed(kInputSize)};
+ std::unique_ptr<Buffer> out{Buffer::fixed(kOutputSize)};
+ std::chrono::steady_clock::time_point last;
+ uint32_t timeout{0};
+ bool read_closed_{false};
+
+ uint8_t ping_{0};
+ std::optional<uint8_t> expect_ping_;
+ bool expect_close_{false};
+
+ OpCode msg_{OpCode::kContinuation};
+ std::vector<uint8_t> payload_;
+ };
+
+ void client_event(size_t client_id, uint8_t event) {
+ if (event & looper::EVENT_ERROR) {
+ logger_->info(std::format("Client socket error: {}", client_id));
+ close_client(client_id);
+ return;
+ }
+
+ auto& client = client_[client_id];
+ if (event & looper::EVENT_READ) {
+ size_t avail;
+ auto* ptr = client.in->wptr(avail);
+ if (avail == 0) {
+ logger_->info(std::format("Client too large request: {}", client_id));
+ close_client(client_id);
+ return;
+ }
+
+ ssize_t got;
+ while (true) {
+ got = read(client.fd.get(), ptr, avail);
+ if (got < 0 && errno == EINTR)
+ continue;
+ break;
+ }
+ if (got < 0) {
+ if (errno != EWOULDBLOCK && errno != EAGAIN) {
+ logger_->info(std::format("Client read error: {}: {}", client_id,
+ strerror(errno)));
+ close_client(client_id);
+ return;
+ }
+ } else if (got == 0) {
+ client.read_closed_ = true;
+ } else {
+ client.in->commit(got);
+
+ if (!client_read(client_id))
+ return;
+ }
+ }
+
+ if (!client_write(client_id))
+ return;
+
+ if (client.read_closed_ && client.in->empty() && client.out->empty()) {
+ close_client(client_id);
+ return;
+ }
+
+ bool want_read = !client.read_closed_;
+ bool want_write = !client.out->empty();
+
+ looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) |
+ (want_write ? looper::EVENT_WRITE : 0));
+ }
+
+ bool client_write(size_t client_id) {
+ auto& client = client_[client_id];
+ while (true) {
+ size_t avail;
+ auto* ptr = client.out->rptr(avail);
+ if (avail == 0)
+ break;
+ ssize_t got;
+ while (true) {
+ got = write(client.fd.get(), ptr, avail);
+ if (got < 0 && errno == EINTR)
+ continue;
+ break;
+ }
+ if (got < 0) {
+ if (errno != EWOULDBLOCK && errno != EAGAIN) {
+ logger_->info(std::format("Client write error: {}: {}", client_id,
+ strerror(errno)));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ }
+ client.out->consume(got);
+
+ if (std::cmp_less(got, avail))
+ break;
+ }
+ return true;
+ }
+
+ bool client_read(size_t client_id) {
+ auto& client = client_[client_id];
+ size_t avail;
+ auto* ptr = client.in->rptr(avail, /* need */ 2);
+ if (avail < 2)
+ return true;
+
+ std::span<uint8_t const> data(reinterpret_cast<uint8_t const*>(ptr), avail);
+ bool fin = data[0] & 0x80;
+ if (data[1] & 0x70) {
+ logger_->info(std::format(
+ "Client invalid request, reserved bits are not zero: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ uint8_t opcode = data[0] & 0x0f;
+ switch (opcode) {
+ case OpCode::kContinuation:
+ case OpCode::kText:
+ case OpCode::kBinary:
+ case OpCode::kClose:
+ case OpCode::kPing:
+ case OpCode::kPong:
+ break;
+ default:
+ logger_->info(std::format(
+ "Client invalid request, reserved opcode used: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ bool mask = data[1] & 0x80;
+ if (!mask) {
+ logger_->info(
+ std::format("Client invalid request, not masked: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ size_t payload_offset = 2;
+ uint64_t payload_len = data[1] & 0x7f;
+ if (payload_len == 126) {
+ payload_offset += 2;
+ } else if (payload_len == 127) {
+ payload_offset += 8;
+ }
+ /* if (mask) */ payload_offset += 4;
+
+ if (payload_offset > data.size()) {
+ ptr = client.in->rptr(avail, /* need */ payload_offset);
+ if (avail < payload_offset)
+ return true;
+ data = std::span(reinterpret_cast<uint8_t const*>(ptr), avail);
+ }
+
+ if (payload_len == 126) {
+ payload_len = (static_cast<uint16_t>(data[2]) << 8) | data[3];
+ } else if (payload_len == 127) {
+ payload_len = (static_cast<uint64_t>(data[2]) << 56) |
+ (static_cast<uint64_t>(data[3]) << 48) |
+ (static_cast<uint64_t>(data[4]) << 40) |
+ (static_cast<uint64_t>(data[5]) << 32) |
+ (static_cast<uint64_t>(data[6]) << 24) |
+ (static_cast<uint64_t>(data[7]) << 16) |
+ (static_cast<uint64_t>(data[8]) << 8) | data[9];
+ }
+ /* if (mask) */
+ auto mask_key = data.subspan(payload_offset - 4, 4);
+
+ if (payload_offset + payload_len > data.size()) {
+ ptr = client.in->rptr(avail, /* need */ payload_offset + payload_len);
+ if (avail < payload_offset + payload_len)
+ return true;
+ data = std::span(reinterpret_cast<uint8_t const*>(ptr), avail);
+ }
+
+ auto payload = data.subspan(payload_offset, payload_len);
+ /* if (mask) { */
+ // Unmask the data
+ std::span<uint8_t> tmp{const_cast<uint8_t*>(payload.data()),
+ payload.size()};
+ for (size_t i = 0; i < tmp.size(); ++i) {
+ tmp[i] ^= mask_key[i % 4];
+ }
+
+ if (client_handle(client_id, fin, static_cast<OpCode>(opcode), payload)) {
+ client.in->consume(payload_offset + payload_len);
+ return true;
+ }
+ return false;
+ }
+
+ bool client_handle(size_t client_id, bool fin, OpCode opcode,
+ std::span<uint8_t const> payload) {
+ auto& client = client_[client_id];
+
+ // Validate continuation and fin frames.
+ switch (opcode) {
+ case OpCode::kContinuation:
+ if (client.msg_ == OpCode::kContinuation) {
+ logger_->info(std::format(
+ "Client invalid frame, unexpected continuation: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ case OpCode::kBinary:
+ case OpCode::kText:
+ if (client.msg_ != OpCode::kContinuation) {
+ logger_->info(std::format(
+ "Client invalid frame, unexpected non-continuation: {}",
+ client_id));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ case OpCode::kClose:
+ case OpCode::kPing:
+ case OpCode::kPong:
+ if (!fin) {
+ logger_->info(
+ std::format("Client invalid control frame: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ }
+
+ switch (opcode) {
+ case OpCode::kContinuation:
+ std::ranges::copy(payload, std::back_inserter(client.payload_));
+ if (fin) {
+ std::unique_ptr<Message> reply;
+ switch (client.msg_) {
+ case OpCode::kText:
+ reply = delegate_.handle(TextMessageImpl{std::string_view(
+ reinterpret_cast<char const*>(client.payload_.data()),
+ client.payload_.size())});
+ break;
+ case OpCode::kBinary:
+ reply = delegate_.handle(BinaryMessageImpl{client.payload_});
+ break;
+ default:
+ std::unreachable();
+ }
+ client.msg_ = OpCode::kContinuation;
+ client.payload_.clear();
+ if (reply) {
+ if (!client_send(client_id, *reply))
+ return false;
+ }
+ }
+ break;
+ case OpCode::kBinary:
+ case OpCode::kText:
+ if (fin) {
+ std::unique_ptr<Message> reply;
+ if (opcode == OpCode::kText) {
+ reply = delegate_.handle(TextMessageImpl{
+ std::string_view(reinterpret_cast<char const*>(payload.data()),
+ payload.size())});
+ } else {
+ reply = delegate_.handle(BinaryMessageImpl{payload});
+ }
+ if (reply) {
+ if (!client_send(client_id, *reply))
+ return false;
+ }
+ } else {
+ client.msg_ = opcode;
+ client.payload_.assign(payload.begin(), payload.end());
+ }
+ break;
+ case OpCode::kClose:
+ if (client.expect_close_) {
+ close_client(client_id);
+ return false;
+ }
+ if (!client_send(client_id, OpCode::kClose, payload))
+ return false;
+ client.read_closed_ = true;
+ break;
+ case OpCode::kPing:
+ if (!client_send(client_id, OpCode::kPong, payload))
+ return false;
+ break;
+ case OpCode::kPong:
+ if (client.expect_ping_.has_value()) {
+ if (payload.size() != 1 ||
+ payload[0] != client.expect_ping_.value()) {
+ logger_->info(
+ std::format("Client closed, mismatched pong: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ client.expect_ping_.reset();
+ } else {
+ // A Pong frame MAY be sent unsolicited. This serves as a unidirectional heartbeat.
+ }
+ break;
+ }
+ return true;
+ }
+
+ bool client_send(size_t client_id, Message const& message) {
+ if (message.is_text()) {
+ auto payload = message.text();
+ return client_send(
+ client_id, OpCode::kText,
+ std::span{reinterpret_cast<uint8_t const*>(payload.data()),
+ payload.size()});
+ }
+ return client_send(client_id, OpCode::kBinary, message.binary());
+ }
+
+ bool client_send(size_t client_id, OpCode opcode,
+ std::span<uint8_t const> payload) {
+ auto& client = client_[client_id];
+ if (client.read_closed_)
+ return true;
+
+ uint8_t header[10];
+ size_t header_size = 0;
+ // Always send FIN frames.
+ header[header_size++] = 0x80 | std::to_underlying(opcode);
+ // Server never sends masked messages.
+ if (payload.size() < 126) {
+ header[header_size++] = payload.size();
+ } else if (payload.size() <= 0xffff) {
+ header[header_size++] = 126;
+ header[header_size++] = payload.size() >> 8;
+ header[header_size++] = payload.size() & 0xff;
+ } else {
+ header[header_size++] = 127;
+ header[header_size++] = payload.size() >> 56;
+ header[header_size++] = (payload.size() & 0xff000000000000) >> 48;
+ header[header_size++] = (payload.size() & 0xff0000000000) >> 40;
+ header[header_size++] = (payload.size() & 0xff00000000) >> 32;
+ header[header_size++] = (payload.size() & 0xff000000) >> 24;
+ header[header_size++] = (payload.size() & 0xff0000) >> 16;
+ header[header_size++] = (payload.size() & 0xff00) >> 8;
+ header[header_size++] = payload.size() & 0xff;
+ }
+ size_t avail;
+ bool const was_empty = client.out->empty();
+ auto* wptr =
+ client.out->wptr(avail, /* need */ header_size + payload.size());
+ if (avail < header_size + payload.size()) {
+ logger_->info(
+ std::format("Client closed, too much output: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ auto* data = reinterpret_cast<uint8_t*>(wptr);
+ std::copy_n(header, header_size, data);
+ std::ranges::copy(payload, data + header_size);
+ client.out->commit(header_size + payload.size());
+
+ if (was_empty) {
+ if (!client_write(client_id))
+ return false;
+ }
+
+ return true;
+ }
+
+ void client_timeout(size_t client_id, uint32_t id) {
+ auto now = std::chrono::steady_clock::now();
+ auto& client = client_[client_id];
+ assert(client.timeout == id);
+ if (now - client.last <
+ std::chrono::duration_cast<std::chrono::steady_clock::duration>(
+ client_half_timeout_)) {
+ // TODO: Reschedule for delay left, not the full one
+ client.timeout = looper_.schedule(client_half_timeout_.count(),
+ [this, client_id](auto handle) {
+ client_timeout(client_id, handle);
+ });
+ return;
+ }
+
+ client.timeout = 0;
+
+ if (client.expect_ping_.has_value()) {
+ logger_->dbg(std::format("Client timeout: {}", client_id));
+
+ close_client(client_id);
+ } else {
+ client.expect_ping_ = ++client.ping_;
+ if (client_send(client_id, OpCode::kPing, std::span{&client.ping_, 1})) {
+ client.timeout = looper_.schedule(client_half_timeout_.count(),
+ [this, client_id](auto handle) {
+ client_timeout(client_id, handle);
+ });
+ }
+ }
+ }
+
+ void close_client(size_t client_id) {
+ auto& client = client_[client_id];
+ if (!client.fd) {
+ assert(false);
+ return;
+ }
+
+ logger_->dbg(std::format("Drop client: {}", client_id));
+
+ looper_.remove(client.fd.get());
+ if (client.timeout)
+ looper_.cancel(client.timeout);
+ client.fd.reset();
+ clear(*client.in);
+ clear(*client.out);
+ client.read_closed_ = false;
+ client.expect_ping_.reset();
+ client.ping_ = 0;
+ client.expect_close_ = false;
+ client.msg_ = OpCode::kContinuation;
+ client.payload_.clear();
+
+ assert(active_clients_ > 0);
+ --active_clients_;
+ if (next_client_ == client_id + 1)
+ next_client_ = client_id;
+ }
+
+ std::unique_ptr<logger::Logger> logger_;
+ cfg::Config const& cfg_;
+ looper::Looper& looper_;
+ Delegate& delegate_;
+
+ // Timeout first sends a PING and then if it times out again, then closes connection.
+ std::chrono::duration<double> client_half_timeout_;
+ std::vector<Client> client_;
+ size_t next_client_{0};
+ size_t active_clients_{0};
+};
+
+} // namespace
+
+void Server::send_text_to_all(std::string_view text) {
+ send_to_all(TextMessageImpl{text});
+}
+
+void Server::send_binary_to_all(std::span<uint8_t const> data) {
+ send_to_all(BinaryMessageImpl{data});
+}
+
+std::unique_ptr<Server> create_server(logger::Logger& logger,
+ cfg::Config const& cfg,
+ looper::Looper& looper,
+ Server::Delegate& delegate) {
+ return std::make_unique<ServerImpl>(logger, cfg, looper, delegate);
+}
+
+std::unique_ptr<Message> Message::create(std::string_view text) {
+ return std::make_unique<TextMessageImpl>(text);
+}
+
+std::unique_ptr<Message> Message::create(std::span<uint8_t const> data) {
+ return std::make_unique<BinaryMessageImpl>(data);
+}
+
+} // namespace ws