#include "websocket.hh"

#include "base64.hh"
#include "buffer.hh"
#include "cfg.hh"
#include "http.hh"
#include "logger.hh"
#include "looper.hh"
#include "sha1.hh"
#include "unique_fd.hh"

#include <algorithm>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <format>
#include <iterator>
#include <memory>
#include <optional>
#include <span>
#include <string>
#include <string_view>
#include <unistd.h>
#include <utility>
#include <vector>

namespace ws {

namespace {

// Fixed GUID appended to Sec-WebSocket-Key before hashing (RFC 6455, 1.3).
constexpr std::string_view kWebsocketGuid("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");

// Frame opcodes (RFC 6455, 5.2). Values 8 and above are control frames.
enum OpCode : uint8_t {
  kContinuation = 0,
  kText = 1,
  kBinary = 2,
  kClose = 8,
  kPing = 9,
  kPong = 10,
};

// Control frames may carry at most 125 payload bytes (RFC 6455, 5.5).
constexpr uint64_t kMaxControlPayload = 125;

// Discards every byte currently readable from `buffer`.
void clear(Buffer& buffer) {
  if (buffer.empty()) return;
  while (true) {
    size_t avail;
    buffer.rptr(avail);
    if (avail == 0) break;
    buffer.consume(avail);
  }
}

// Non-owning text Message: the viewed characters must outlive this object.
class TextMessageImpl : public Message {
 public:
  explicit TextMessageImpl(std::string_view text) : text_(text) {}

  [[nodiscard]] bool is_text() const override { return true; }
  [[nodiscard]] bool is_binary() const override { return false; }
  [[nodiscard]] std::string_view text() const override { return text_; }
  [[nodiscard]] std::span<uint8_t const> binary() const override { return {}; }

 private:
  std::string_view text_;
};

// Non-owning binary Message: the viewed bytes must outlive this object.
class BinaryMessageImpl : public Message {
 public:
  explicit BinaryMessageImpl(std::span<uint8_t const> data) : data_(data) {}

  [[nodiscard]] bool is_text() const override { return false; }
  [[nodiscard]] bool is_binary() const override { return true; }
  [[nodiscard]] std::string_view text() const override { return {}; }
  [[nodiscard]] std::span<uint8_t const> binary() const override {
    return data_;
  }

 private:
  std::span<uint8_t const> data_;
};

// WebSocket server: answers the HTTP upgrade handshake, then takes over the
// socket and speaks the RFC 6455 framing protocol on it. Clients live in a
// fixed-size slot table sized by "websocket.max.clients".
class ServerImpl : public Server, http::SocketReceiver {
 public:
  ServerImpl(logger::Logger& logger, cfg::Config const& cfg,
             looper::Looper& looper, Delegate& delegate)
      : logger_(logger::prefix(logger, "ws")),
        cfg_(cfg),
        looper_(looper),
        delegate_(delegate) {
    client_half_timeout_ =
        std::chrono::duration<double>(
            cfg_.get_uint64("websocket.client.timeout.seconds").value_or(120)) /
        2;
    client_.resize(cfg_.get_uint64("websocket.max.clients").value_or(100));
  }

  // Validates a WebSocket upgrade request. Returns nullptr when the request is
  // not a WebSocket upgrade, 400 for an unsupported version, otherwise a 101
  // response that hands the socket over to receive().
  std::unique_ptr<http::Response> handle(
      http::Request const& request) override {
    if (request.method() != "GET") return nullptr;
    if (!request.header_contains("upgrade", "websocket")) return nullptr;
    if (!request.header_contains("connection", "upgrade")) return nullptr;
    auto maybe_key = request.header_value("sec-websocket-key");
    if (!maybe_key.has_value()) return nullptr;
    // The key must be a base64-encoded 16-byte nonce.
    auto maybe_nonce = base64::decode(maybe_key.value());
    if (!maybe_nonce.has_value() || maybe_nonce->size() != 16) return nullptr;
    auto maybe_version = request.header_value("sec-websocket-version");
    if (!maybe_version.has_value()) return nullptr;
    if (maybe_version.value() != "13") {
      auto builder =
          http::Response::Builder::create(http::StatusCode::kBadRequest);
      builder->add_header("Sec-WebSocket-Version", "13");
      return builder->build();
    }

    std::string accept{maybe_key.value()};
    accept.append(kWebsocketGuid);
    accept = base64::encode(sha1::hash(accept));

    auto builder =
        http::Response::Builder::create(http::StatusCode::kSwitchingProtocols);
    builder->add_header("Upgrade", "websocket");
    builder->add_header("Connection", "Upgrade");
    builder->add_header("Sec-WebSocket-Accept", accept);
    builder->take_over(*this);
    return builder->build();
  }

  // Adopts an upgraded socket into a free client slot, or drops it (closing
  // the fd) when every slot is taken.
  void receive(unique_fd&& fd) override {
    if (active_clients_ == client_.size()) {
      logger_->warn("Max number of clients already.");
      return;
    }

    // A free slot exists (active_clients_ < size). Wrap *before* indexing:
    // next_client_ may equal client_.size() after the last slot was used.
    if (next_client_ >= client_.size()) next_client_ = 0;
    while (client_[next_client_].fd) {
      next_client_ = (next_client_ + 1) % client_.size();
    }

    logger_->dbg(std::format("New client: {}", next_client_));

    auto& client = client_[next_client_];
    client.fd = std::move(fd);
    client.last = std::chrono::steady_clock::now();
    looper_.add(
        client.fd.get(), looper::EVENT_READ,
        [this, id = next_client_](auto events) { client_event(id, events); });
    client.timeout = looper_.schedule(
        client_half_timeout_.count(),
        [this, id = next_client_](auto handle) { client_timeout(id, handle); });
    ++active_clients_;
    ++next_client_;
  }

  // Sends `msg` to every connected client.
  void send_to_all(Message const& msg) override {
    auto active = active_clients_;
    for (size_t i = 0; i < client_.size() && active; ++i) {
      if (client_[i].fd) {
        client_send(i, msg);
        if (--active == 0) break;
      }
    }
  }

 private:
  static constexpr size_t kInputSize = static_cast<size_t>(256) * 1024;
  static constexpr size_t kOutputSize = static_cast<size_t>(256) * 1024;

  // Per-connection state. A slot is free when `fd` is empty.
  struct Client {
    unique_fd fd;
    std::unique_ptr<Buffer> in{Buffer::fixed(kInputSize)};
    std::unique_ptr<Buffer> out{Buffer::fixed(kOutputSize)};
    // Time of the last successful read; drives ping-on-idle.
    std::chrono::steady_clock::time_point last;
    // Handle of the pending looper timer, 0 when none is scheduled.
    uint32_t timeout{0};
    // Set after peer EOF or after a Close frame was received.
    bool read_closed_{false};
    // Counter used as the 1-byte ping payload.
    uint8_t ping_{0};
    // Payload of the outstanding ping, if any.
    std::optional<uint8_t> expect_ping_;
    bool expect_close_{false};
    // Opcode of the fragmented message being assembled, or kContinuation when
    // no fragmented message is in progress.
    OpCode msg_{OpCode::kContinuation};
    std::vector<uint8_t> payload_;
  };

  // Outcome of trying to parse and handle one frame from a client's input.
  enum class FrameResult {
    kNeedMore,      // Incomplete frame buffered; wait for more input.
    kHandled,       // One frame consumed and handled.
    kClientClosed,  // The client was closed; its slot must not be touched.
  };

  // Looper callback for a client socket.
  void client_event(size_t client_id, uint8_t event) {
    if (event & looper::EVENT_ERROR) {
      logger_->info(std::format("Client socket error: {}", client_id));
      close_client(client_id);
      return;
    }

    auto& client = client_[client_id];
    if (event & looper::EVENT_READ) {
      size_t avail;
      auto* ptr = client.in->wptr(avail);
      if (avail == 0) {
        logger_->info(std::format("Client too large request: {}", client_id));
        close_client(client_id);
        return;
      }
      ssize_t got;
      do {
        got = read(client.fd.get(), ptr, avail);
      } while (got < 0 && errno == EINTR);

      if (got < 0) {
        if (errno != EWOULDBLOCK && errno != EAGAIN) {
          logger_->info(std::format("Client read error: {}: {}", client_id,
                                    strerror(errno)));
          close_client(client_id);
          return;
        }
      } else if (got == 0) {
        client.read_closed_ = true;
        // client_read() consumes every complete frame, so whatever is left is
        // a partial frame that can never complete. Drop it so the connection
        // is closed once pending output has been flushed.
        clear(*client.in);
      } else {
        client.in->commit(got);
        client.last = std::chrono::steady_clock::now();
        if (!client_read(client_id)) return;
      }
    }

    if (!client_write(client_id)) return;

    if (client.read_closed_ && client.in->empty() && client.out->empty()) {
      close_client(client_id);
      return;
    }

    bool const want_read = !client.read_closed_;
    bool const want_write = !client.out->empty();
    looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) |
                                        (want_write ? looper::EVENT_WRITE : 0));
  }

  // Flushes as much buffered output as the socket accepts. Returns false if
  // the client was closed because of a write error.
  bool client_write(size_t client_id) {
    auto& client = client_[client_id];
    while (true) {
      size_t avail;
      auto* ptr = client.out->rptr(avail);
      if (avail == 0) break;

      ssize_t got;
      do {
        got = write(client.fd.get(), ptr, avail);
      } while (got < 0 && errno == EINTR);

      if (got < 0) {
        if (errno != EWOULDBLOCK && errno != EAGAIN) {
          logger_->info(std::format("Client write error: {}: {}", client_id,
                                    strerror(errno)));
          close_client(client_id);
          return false;
        }
        break;
      }

      client.out->consume(got);
      // Short write: the socket buffer is full, wait for EVENT_WRITE.
      if (std::cmp_less(got, avail)) break;
    }
    return true;
  }

  // Handles every complete frame currently buffered. Returns false if the
  // client was closed.
  bool client_read(size_t client_id) {
    auto& client = client_[client_id];
    while (!client.read_closed_) {
      switch (client_read_frame(client_id)) {
        case FrameResult::kNeedMore:
          return true;
        case FrameResult::kHandled:
          break;
        case FrameResult::kClientClosed:
          return false;
      }
    }
    // A Close frame was received: anything after it is ignored.
    clear(*client.in);
    return true;
  }

  // Parses, unmasks and handles the first frame in the client's input buffer
  // (RFC 6455, 5.2).
  FrameResult client_read_frame(size_t client_id) {
    auto& client = client_[client_id];
    size_t avail;
    auto* ptr = client.in->rptr(avail, /* need */ 2);
    if (avail < 2) return FrameResult::kNeedMore;
    std::span<uint8_t const> data(reinterpret_cast<uint8_t const*>(ptr), avail);

    // Byte 0: FIN | RSV1 | RSV2 | RSV3 | opcode(4).
    bool const fin = data[0] & 0x80;
    if (data[0] & 0x70) {
      logger_->info(std::format(
          "Client invalid request, reserved bits are not zero: {}", client_id));
      close_client(client_id);
      return FrameResult::kClientClosed;
    }
    uint8_t const opcode = data[0] & 0x0f;
    switch (opcode) {
      case OpCode::kContinuation:
      case OpCode::kText:
      case OpCode::kBinary:
      case OpCode::kClose:
      case OpCode::kPing:
      case OpCode::kPong:
        break;
      default:
        logger_->info(std::format(
            "Client invalid request, reserved opcode used: {}", client_id));
        close_client(client_id);
        return FrameResult::kClientClosed;
    }

    // Byte 1: MASK | payload length(7). Client frames must be masked.
    bool const mask = data[1] & 0x80;
    if (!mask) {
      logger_->info(
          std::format("Client invalid request, not masked: {}", client_id));
      close_client(client_id);
      return FrameResult::kClientClosed;
    }

    size_t payload_offset = 2;
    uint64_t payload_len = data[1] & 0x7f;
    if (payload_len == 126) {
      payload_offset += 2;
    } else if (payload_len == 127) {
      payload_offset += 8;
    }
    // Masking key (always present, see check above).
    payload_offset += 4;

    if (payload_offset > data.size()) {
      ptr = client.in->rptr(avail, /* need */ payload_offset);
      if (avail < payload_offset) return FrameResult::kNeedMore;
      data = std::span<uint8_t const>(reinterpret_cast<uint8_t const*>(ptr),
                                      avail);
    }

    if (payload_len == 126) {
      payload_len = (static_cast<uint64_t>(data[2]) << 8) | data[3];
    } else if (payload_len == 127) {
      payload_len = 0;
      for (size_t i = 2; i < 10; ++i) {
        payload_len = (payload_len << 8) | data[i];
      }
    }

    if ((opcode & 0x08) && payload_len > kMaxControlPayload) {
      logger_->info(std::format("Client invalid control frame: {}", client_id));
      close_client(client_id);
      return FrameResult::kClientClosed;
    }
    // The length is untrusted: reject anything that cannot fit in the input
    // buffer. This also guarantees payload_offset + payload_len cannot wrap.
    if (payload_len > kInputSize - payload_offset) {
      logger_->info(std::format("Client too large request: {}", client_id));
      close_client(client_id);
      return FrameResult::kClientClosed;
    }
    size_t const frame_size = payload_offset + static_cast<size_t>(payload_len);

    if (frame_size > data.size()) {
      ptr = client.in->rptr(avail, /* need */ frame_size);
      if (avail < frame_size) return FrameResult::kNeedMore;
      data = std::span<uint8_t const>(reinterpret_cast<uint8_t const*>(ptr),
                                      avail);
    }

    // Taken only now: the rptr() call above may have relocated the bytes.
    auto const mask_key = data.subspan(payload_offset - 4, 4);
    auto const payload = data.subspan(payload_offset, frame_size - payload_offset);

    // Unmask in place; the bytes belong to our own input buffer.
    std::span<uint8_t> unmasked{const_cast<uint8_t*>(payload.data()),
                                payload.size()};
    for (size_t i = 0; i < unmasked.size(); ++i) {
      unmasked[i] ^= mask_key[i % 4];
    }

    if (!client_handle(client_id, fin, static_cast<OpCode>(opcode), payload)) {
      return FrameResult::kClientClosed;
    }
    client.in->consume(frame_size);
    return FrameResult::kHandled;
  }

  // Validates fragmentation rules and dispatches one unmasked frame. Returns
  // false if the client was closed.
  bool client_handle(size_t client_id, bool fin, OpCode opcode,
                     std::span<uint8_t const> payload) {
    auto& client = client_[client_id];

    // Validate continuation and fin frames.
    switch (opcode) {
      case OpCode::kContinuation:
        if (client.msg_ == OpCode::kContinuation) {
          logger_->info(std::format(
              "Client invalid frame, unexpected continuation: {}", client_id));
          close_client(client_id);
          return false;
        }
        break;
      case OpCode::kBinary:
      case OpCode::kText:
        if (client.msg_ != OpCode::kContinuation) {
          logger_->info(std::format(
              "Client invalid frame, unexpected non-continuation: {}",
              client_id));
          close_client(client_id);
          return false;
        }
        break;
      case OpCode::kClose:
      case OpCode::kPing:
      case OpCode::kPong:
        // Control frames must not be fragmented.
        if (!fin) {
          logger_->info(
              std::format("Client invalid control frame: {}", client_id));
          close_client(client_id);
          return false;
        }
        break;
    }

    switch (opcode) {
      case OpCode::kContinuation:
        // NOTE(review): the assembled message is not size-limited; a client
        // can grow payload_ without bound by streaming continuation frames.
        std::ranges::copy(payload, std::back_inserter(client.payload_));
        if (fin) {
          std::unique_ptr<Message> reply;
          switch (client.msg_) {
            case OpCode::kText:
              reply = delegate_.handle(TextMessageImpl{std::string_view(
                  reinterpret_cast<char const*>(client.payload_.data()),
                  client.payload_.size())});
              break;
            case OpCode::kBinary:
              reply = delegate_.handle(BinaryMessageImpl{client.payload_});
              break;
            default:
              std::unreachable();
          }
          client.msg_ = OpCode::kContinuation;
          client.payload_.clear();
          if (reply) {
            if (!client_send(client_id, *reply)) return false;
          }
        }
        break;
      case OpCode::kBinary:
      case OpCode::kText:
        if (fin) {
          std::unique_ptr<Message> reply;
          if (opcode == OpCode::kText) {
            reply = delegate_.handle(TextMessageImpl{
                std::string_view(reinterpret_cast<char const*>(payload.data()),
                                 payload.size())});
          } else {
            reply = delegate_.handle(BinaryMessageImpl{payload});
          }
          if (reply) {
            if (!client_send(client_id, *reply)) return false;
          }
        } else {
          // First fragment: remember the type and start assembling.
          client.msg_ = opcode;
          client.payload_.assign(payload.begin(), payload.end());
        }
        break;
      case OpCode::kClose:
        if (client.expect_close_) {
          close_client(client_id);
          return false;
        }
        // Echo the close frame, then stop reading.
        if (!client_send(client_id, OpCode::kClose, payload)) return false;
        client.read_closed_ = true;
        break;
      case OpCode::kPing:
        if (!client_send(client_id, OpCode::kPong, payload)) return false;
        break;
      case OpCode::kPong:
        if (client.expect_ping_.has_value()) {
          if (payload.size() != 1 ||
              payload[0] != client.expect_ping_.value()) {
            logger_->info(
                std::format("Client closed, mismatched pong: {}", client_id));
            close_client(client_id);
            return false;
          }
          client.expect_ping_.reset();
        } else {
          // A Pong frame MAY be sent unsolicited. This serves as a
          // unidirectional heartbeat.
        }
        break;
    }
    return true;
  }

  // Sends a Message as a single text or binary frame.
  bool client_send(size_t client_id, Message const& message) {
    if (message.is_text()) {
      auto payload = message.text();
      return client_send(
          client_id, OpCode::kText,
          std::span<uint8_t const>{
              reinterpret_cast<uint8_t const*>(payload.data()),
              payload.size()});
    }
    return client_send(client_id, OpCode::kBinary, message.binary());
  }

  // Queues one unfragmented, unmasked frame and tries to flush it. Returns
  // false if the client was closed (output overflow or write error).
  bool client_send(size_t client_id, OpCode opcode,
                   std::span<uint8_t const> payload) {
    auto& client = client_[client_id];
    if (client.read_closed_) return true;

    uint8_t header[10];
    size_t header_size = 0;
    // Always send FIN frames.
    header[header_size++] = 0x80 | std::to_underlying(opcode);
    // Server never sends masked messages.
    if (payload.size() < 126) {
      header[header_size++] = static_cast<uint8_t>(payload.size());
    } else if (payload.size() <= 0xffff) {
      header[header_size++] = 126;
      header[header_size++] = static_cast<uint8_t>(payload.size() >> 8);
      header[header_size++] = static_cast<uint8_t>(payload.size() & 0xff);
    } else {
      header[header_size++] = 127;
      uint64_t const size = payload.size();
      for (int shift = 56; shift >= 0; shift -= 8) {
        header[header_size++] = static_cast<uint8_t>(size >> shift);
      }
    }

    size_t avail;
    bool const was_empty = client.out->empty();
    auto* wptr =
        client.out->wptr(avail, /* need */ header_size + payload.size());
    if (avail < header_size + payload.size()) {
      logger_->info(
          std::format("Client closed, too much output: {}", client_id));
      close_client(client_id);
      return false;
    }
    auto* data = reinterpret_cast<uint8_t*>(wptr);
    std::copy_n(header, header_size, data);
    std::ranges::copy(payload, data + header_size);
    client.out->commit(header_size + payload.size());

    if (was_empty && !client_write(client_id)) return false;
    if (!client.out->empty()) {
      // We may be outside client_event (send_to_all, ping timer), where
      // nothing else would register write interest: ask the looper to tell
      // us when the rest can be flushed.
      looper_.update(client.fd.get(),
                     looper::EVENT_READ | looper::EVENT_WRITE);
    }
    return true;
  }

  // Timer callback. If the client was idle for a half-timeout, send a ping;
  // if a ping is still unanswered a half-timeout later, close the client.
  void client_timeout(size_t client_id, [[maybe_unused]] uint32_t id) {
    auto now = std::chrono::steady_clock::now();
    auto& client = client_[client_id];
    assert(client.timeout == id);

    if (now - client.last <
        std::chrono::duration_cast<std::chrono::steady_clock::duration>(
            client_half_timeout_)) {
      // TODO: Reschedule for delay left, not the full one
      client.timeout = looper_.schedule(
          client_half_timeout_.count(),
          [this, client_id](auto handle) { client_timeout(client_id, handle); });
      return;
    }

    client.timeout = 0;
    if (client.expect_ping_.has_value()) {
      logger_->dbg(std::format("Client timeout: {}", client_id));
      close_client(client_id);
    } else {
      client.expect_ping_ = ++client.ping_;
      if (client_send(client_id, OpCode::kPing,
                      std::span<uint8_t const>{&client.ping_, 1})) {
        client.timeout = looper_.schedule(
            client_half_timeout_.count(), [this, client_id](auto handle) {
              client_timeout(client_id, handle);
            });
      }
    }
  }

  // Releases a client slot: unregisters the fd, cancels its timer and resets
  // all per-connection state so the slot can be reused.
  void close_client(size_t client_id) {
    auto& client = client_[client_id];
    if (!client.fd) {
      assert(false);
      return;
    }

    logger_->dbg(std::format("Drop client: {}", client_id));

    looper_.remove(client.fd.get());
    if (client.timeout) looper_.cancel(client.timeout);
    client.timeout = 0;
    client.fd.reset();
    clear(*client.in);
    clear(*client.out);
    client.read_closed_ = false;
    client.expect_ping_.reset();
    client.ping_ = 0;
    client.expect_close_ = false;
    client.msg_ = OpCode::kContinuation;
    client.payload_.clear();

    assert(active_clients_ > 0);
    --active_clients_;
    if (next_client_ == client_id + 1) next_client_ = client_id;
  }

  std::unique_ptr<logger::Logger> logger_;
  cfg::Config const& cfg_;
  looper::Looper& looper_;
  Delegate& delegate_;
  // Timeout first sends a PING and then if it times out again, then closes
  // connection. NOTE(review): representation reconstructed as seconds in a
  // double (count() is passed to looper::schedule) — confirm against looper.
  std::chrono::duration<double> client_half_timeout_;
  std::vector<Client> client_;
  size_t next_client_{0};
  size_t active_clients_{0};
};

}  // namespace

void Server::send_text_to_all(std::string_view text) {
  send_to_all(TextMessageImpl{text});
}

void Server::send_binary_to_all(std::span<uint8_t const> data) {
  send_to_all(BinaryMessageImpl{data});
}

std::unique_ptr<Server> create_server(logger::Logger& logger,
                                      cfg::Config const& cfg,
                                      looper::Looper& looper,
                                      Server::Delegate& delegate) {
  return std::make_unique<ServerImpl>(logger, cfg, looper, delegate);
}

std::unique_ptr<Message> Message::create(std::string_view text) {
  return std::make_unique<TextMessageImpl>(text);
}

std::unique_ptr<Message> Message::create(std::span<uint8_t const> data) {
  return std::make_unique<BinaryMessageImpl>(data);
}

}  // namespace ws