summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--meson.build16
-rw-r--r--src/http.cc295
-rw-r--r--src/http.hh28
-rw-r--r--src/logger.cc34
-rw-r--r--src/logger.hh3
-rw-r--r--src/main.cc127
-rw-r--r--src/websocket.cc676
-rw-r--r--src/websocket.hh88
8 files changed, 1229 insertions, 38 deletions
diff --git a/meson.build b/meson.build
index 7b751db..93e6534 100644
--- a/meson.build
+++ b/meson.build
@@ -230,6 +230,20 @@ sha1_dep = declare_dependency(
dependencies: [sha1_inner_dep],
)
+websocket_lib = library(
+ 'websocket',
+ sources: [
+ 'src/websocket.cc',
+ 'src/websocket.hh',
+ ],
+ include_directories: inc,
+ dependencies: [base64_dep, looper_dep, logger_dep, str_dep, sha1_dep],
+)
+websocket_dep = declare_dependency(
+ link_with: websocket_lib,
+ dependencies: [base64_dep, looper_dep, logger_dep, str_dep, sha1_dep],
+)
+
bluetooth_jukebox = executable(
'bluetooth-jukebox',
sources: [
@@ -242,8 +256,10 @@ bluetooth_jukebox = executable(
bt_dep,
cfg_dep,
http_dep,
+ json_dep,
looper_dep,
signals_dep,
+ websocket_dep,
],
)
diff --git a/src/http.cc b/src/http.cc
index 8a0a1e6..21de28d 100644
--- a/src/http.cc
+++ b/src/http.cc
@@ -20,10 +20,12 @@
#include <memory>
#include <netdb.h>
#include <optional>
+#include <span>
#include <string>
#include <string_view>
#include <sys/socket.h>
#include <sys/types.h>
+#include <system_error>
#include <unistd.h>
#include <utility>
#include <vector>
@@ -54,6 +56,25 @@ std::string_view ascii_lowercase(std::string_view input, std::string& tmp) {
return input;
}
+bool eq_lowercase(std::string_view input, std::string_view cmp) {
+ if (input.size() != cmp.size())
+ return false;
+
+ auto a = input.begin();
+ auto b = cmp.begin();
+ for (; a != input.end(); ++a, ++b) {
+ if (*a >= 'A' && *a <= 'Z') {
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions)
+ if ((*a | 0x20) != *b)
+ return false;
+ } else {
+ if (*a != *b)
+ return false;
+ }
+ }
+ return true;
+}
+
std::string standard_header(std::string_view input) {
bool capitalize = true;
std::string ret;
@@ -215,7 +236,8 @@ class OpenPortImpl : public OpenPort {
class ResponseImpl : public Response {
public:
- explicit ResponseImpl(std::string data) : data_(std::move(data)) {}
+ ResponseImpl(std::string data, SocketReceiver* socket_receiver)
+ : data_(std::move(data)), socket_receiver_(socket_receiver) {}
bool write(Buffer& buffer) override {
if (offset_ >= data_.size())
@@ -235,9 +257,15 @@ class ResponseImpl : public Response {
return offset_ < data_.size();
}
+ [[nodiscard]]
+ SocketReceiver* socket_receiver() const override {
+ return socket_receiver_;
+ }
+
private:
std::string const data_;
size_t offset_{0};
+ SocketReceiver* socket_receiver_;
};
class ResponseBuilderImpl : public Response::Builder {
@@ -256,6 +284,11 @@ class ResponseBuilderImpl : public Response::Builder {
return *this;
}
+ Builder& take_over(SocketReceiver& receiver) override {
+ socket_receiver_ = &receiver;
+ return *this;
+ }
+
[[nodiscard]]
std::unique_ptr<Response> build() const override {
std::string ret;
@@ -269,12 +302,18 @@ class ResponseBuilderImpl : public Response::Builder {
}
ret.append(" ");
switch (status_) {
+ case StatusCode::kSwitchingProtocols:
+ ret.append("Switching Protocols");
+ break;
case StatusCode::kOK:
ret.append("OK");
break;
case StatusCode::kNoContent:
ret.append("No Content");
break;
+ case StatusCode::kBadRequest:
+ ret.append("Bad Request");
+ break;
case StatusCode::kNotFound:
ret.append("Not Found");
break;
@@ -292,7 +331,20 @@ class ResponseBuilderImpl : public Response::Builder {
ret.append(pair.second);
ret.append("\r\n");
}
- if (!have_content_len && status_ != StatusCode::kNoContent) {
+ bool need_content_len;
+ switch (status_) {
+ case StatusCode::kSwitchingProtocols:
+ case StatusCode::kNoContent:
+ need_content_len = false;
+ break;
+ case StatusCode::kOK:
+ case StatusCode::kBadRequest:
+ case StatusCode::kNotFound:
+ case StatusCode::kMethodNotAllowed:
+ need_content_len = true;
+ break;
+ }
+ if (!have_content_len && need_content_len) {
char tmp[20];
auto [ptr, ec] = std::to_chars(tmp, tmp + sizeof(tmp), content_.size());
ret.append("Content-Length");
@@ -303,19 +355,93 @@ class ResponseBuilderImpl : public Response::Builder {
ret.append("\r\n");
if (status_ != StatusCode::kNoContent)
ret.append(content_);
- return std::make_unique<ResponseImpl>(std::move(ret));
+ return std::make_unique<ResponseImpl>(std::move(ret), socket_receiver_);
}
private:
StatusCode const status_;
std::string content_;
std::vector<std::pair<std::string, std::string>> headers_;
+ SocketReceiver* socket_receiver_{nullptr};
+};
+
+class HeaderIterator {
+ public:
+ explicit HeaderIterator(std::span<std::string_view> headers)
+ : headers_(headers) {
+ next();
+ }
+
+ [[nodiscard]]
+ bool has_name() const {
+ return offset_ < headers_.size();
+ }
+
+ [[nodiscard]]
+ bool has_value() const {
+ return has_name() && continue_value_;
+ }
+
+ bool next_name() {
+ do {
+ next();
+ } while (has_value());
+ return has_name();
+ }
+
+ bool next_value() {
+ next();
+ return has_value();
+ }
+
+ [[nodiscard]]
+ std::string_view name() const {
+ return name_;
+ }
+
+ [[nodiscard]]
+ std::string_view value() const {
+ return value_;
+ }
+
+ private:
+ void next() {
+ if (offset_ >= headers_.size())
+ return;
+
+ auto line = headers_[offset_];
+ if (offset_ > 0 &&
+ (line.empty() || line.front() == ' ' || line.front() == '\t')) {
+ continue_value_ = true;
+ value_ = str::ltrim(line);
+ ++offset_;
+ return;
+ }
+ auto colon = line.find(':');
+ if (colon == std::string::npos) {
+ assert(false);
+ ++offset_;
+ next();
+ return;
+ }
+ continue_value_ = false;
+ name_ = str::trim(line.substr(0, colon));
+ value_ = str::ltrim(line.substr(colon + 1));
+ ++offset_;
+ }
+
+ std::span<std::string_view> headers_;
+ size_t offset_{0};
+ bool continue_value_{false};
+ std::string_view name_;
+ std::string_view value_;
};
class RequestImpl : public Request {
public:
- RequestImpl(std::string_view method, std::string_view path)
- : method_(method), path_(path) {}
+ RequestImpl(std::string_view method, std::string_view path,
+ std::span<std::string_view> headers)
+ : method_(method), path_(path), headers_(headers) {}
[[nodiscard]]
std::string_view method() const override {
@@ -327,9 +453,62 @@ class RequestImpl : public Request {
return path_;
}
+ [[nodiscard]]
+ std::string_view body() const override {
+ return body_;
+ }
+
+ void set_body(std::string_view body) { body_ = body; }
+
+ [[nodiscard]]
+ bool header_contains(std::string_view name,
+ std::string_view value) const override {
+ for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) {
+ if (eq_lowercase(it.name(), name)) {
+ std::string_view tmp = get_value(it);
+ auto values = str::split(tmp, ',', /* keep_empty */ true);
+ for (auto v : values) {
+ if (eq_lowercase(str::trim(v), value)) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ [[nodiscard]]
+ std::optional<std::string_view> header_value(
+ std::string_view name) const override {
+ for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) {
+ if (eq_lowercase(it.name(), name)) {
+ // Do not look for more entries of the header, this method is used for header
+ // values that cannot be combined.
+ return get_value(it);
+ }
+ }
+ return std::nullopt;
+ }
+
private:
+ std::string_view get_value(HeaderIterator& it) const {
+ std::string_view ret = it.value();
+ if (it.next_value()) {
+ tmp_.clear();
+ tmp_.append(ret);
+ do {
+ tmp_.append(it.value());
+ } while (it.next_value());
+ return tmp_;
+ }
+ return ret;
+ }
+
std::string_view method_;
std::string_view path_;
+ std::string_view body_;
+ std::span<std::string_view> headers_;
+ mutable std::string tmp_;
};
// TODO: What is a good value?
@@ -340,7 +519,7 @@ class ServerImpl : public Server {
ServerImpl(logger::Logger& logger, cfg::Config const& cfg,
looper::Looper& looper, std::unique_ptr<OpenPort> open_port,
Server::Delegate& delegate)
- : logger_(logger),
+ : logger_(logger::prefix(logger, "http")),
cfg_(cfg),
looper_(looper),
delegate_(delegate),
@@ -364,6 +543,7 @@ class ServerImpl : public Server {
std::chrono::steady_clock::time_point last;
uint32_t timeout{0};
bool read_closed_{false};
+ SocketReceiver* take_over_{nullptr};
};
void start_to_listen() {
@@ -372,7 +552,7 @@ class ServerImpl : public Server {
if (listen(it->get(),
static_cast<int>(cfg_.get_uint64("http.listen.backlog")
.value_or(kListenBacklog)))) {
- logger_.warn(
+ logger_->warn(
std::format("Error listening to socket: {}", strerror(errno)));
it = listens_.erase(it);
} else {
@@ -386,21 +566,21 @@ class ServerImpl : public Server {
}
if (listens_.empty()) {
- logger_.err("No ports left to listen to, exit.");
+ logger_->err("No ports left to listen to, exit.");
looper_.quit();
}
}
void accept_client(int fd, uint8_t event) {
if (event & looper::EVENT_ERROR) {
- logger_.warn("Listening port returned error, closing.");
+ logger_->warn("Listening port returned error, closing.");
looper_.remove(fd);
auto it = std::ranges::find_if(
listens_, [&fd](auto& ufd) { return ufd.get() == fd; });
if (it != listens_.end()) {
listens_.erase(it);
if (listens_.empty()) {
- logger_.err("No ports left to listen to, exit.");
+ logger_->err("No ports left to listen to, exit.");
looper_.quit();
}
}
@@ -409,12 +589,12 @@ class ServerImpl : public Server {
unique_fd client_fd(accept4(fd, nullptr, nullptr, SOCK_NONBLOCK));
if (!client_fd) {
- logger_.info(std::format("Accept returned error: {}", strerror(errno)));
+ logger_->info(std::format("Accept returned error: {}", strerror(errno)));
return;
}
if (active_clients_ == client_.size()) {
- logger_.warn("Max number of clients already.");
+ logger_->warn("Max number of clients already.");
return;
}
@@ -423,7 +603,7 @@ class ServerImpl : public Server {
next_client_ = 0;
}
- logger_.dbg(std::format("New client: {}", next_client_));
+ logger_->dbg(std::format("New client: {}", next_client_));
auto& client = client_[next_client_];
client.fd = std::move(client_fd);
@@ -443,7 +623,7 @@ class ServerImpl : public Server {
void client_event(size_t client_id, uint8_t event) {
if (event & looper::EVENT_ERROR) {
- logger_.info(std::format("Client socket error: {}", client_id));
+ logger_->info(std::format("Client socket error: {}", client_id));
close_client(client_id);
return;
}
@@ -453,7 +633,7 @@ class ServerImpl : public Server {
size_t avail;
auto* ptr = client.in->wptr(avail);
if (avail == 0) {
- logger_.info(std::format("Client too large request: {}", client_id));
+ logger_->info(std::format("Client too large request: {}", client_id));
close_client(client_id);
return;
}
@@ -467,8 +647,8 @@ class ServerImpl : public Server {
}
if (got < 0) {
if (errno != EWOULDBLOCK && errno != EAGAIN) {
- logger_.info(std::format("Client read error: {}: {}", client_id,
- strerror(errno)));
+ logger_->info(std::format("Client read error: {}: {}", client_id,
+ strerror(errno)));
close_client(client_id);
return;
}
@@ -488,10 +668,15 @@ class ServerImpl : public Server {
if (avail == 0) {
if (client.resp) {
if (!client.resp->write(*client.out)) {
+ client.take_over_ = client.resp->socket_receiver();
client.resp.reset();
- if (!client_read(client_id))
- return;
+ if (client.take_over_) {
+ client.read_closed_ = true;
+ } else {
+ if (!client_read(client_id))
+ return;
+ }
}
continue;
}
@@ -506,8 +691,8 @@ class ServerImpl : public Server {
}
if (got < 0) {
if (errno != EWOULDBLOCK && errno != EAGAIN) {
- logger_.info(std::format("Client write error: {}: {}", client_id,
- strerror(errno)));
+ logger_->info(std::format("Client write error: {}: {}", client_id,
+ strerror(errno)));
close_client(client_id);
return;
}
@@ -521,7 +706,13 @@ class ServerImpl : public Server {
if (client.read_closed_ && client.in->empty() && client.out->empty() &&
!client.resp) {
- close_client(client_id);
+ if (client.take_over_) {
+ looper_.remove(client.fd.get());
+ client.take_over_->receive(std::move(client.fd));
+ reset_client(client_id);
+ } else {
+ close_client(client_id);
+ }
return;
}
@@ -550,19 +741,57 @@ class ServerImpl : public Server {
{
auto parts = str::split(lines[0], ' ', false);
if (parts.size() != 3) {
- logger_.info(std::format("Client invalid request: {}", client_id));
+ logger_->info(std::format("Client invalid request: {}", client_id));
close_client(client_id);
return false;
}
method = parts[0];
path = parts[1];
if (!parts[2].starts_with("HTTP/")) {
- logger_.info(std::format("Client invalid request: {}", client_id));
+ logger_->info(std::format("Client invalid request: {}", client_id));
close_client(client_id);
return false;
}
}
- RequestImpl request(method, path);
+ for (size_t i = 1; i < lines.size(); ++i) {
+ auto colon = lines[i].find(':');
+ if (colon == std::string::npos &&
+ // Only continue header value lines are allowed to not have a colon.
+ (i == 1 || lines[i].empty() ||
+ (lines[i].front() != ' ' && lines[i].front() != '\t'))) {
+ logger_->info(std::format("Client invalid request: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ }
+ RequestImpl request(method, path, std::span{lines}.subspan(1));
+ auto maybe_content_length = request.header_value("content-length");
+ if (maybe_content_length.has_value()) {
+ uint64_t content_length;
+ auto [ptr, ec] = std::from_chars(
+ maybe_content_length->data(),
+ maybe_content_length->data() + maybe_content_length->size(),
+ content_length);
+ if (ec != std::errc() ||
+ ptr != maybe_content_length->data() + maybe_content_length->size()) {
+ logger_->info(std::format(
+ "Client invalid request, bad content-length: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ if (avail - end < content_length) {
+ // Need more data.
+ if (client.in->full()) {
+ logger_->info(std::format("Client invalid request, too much data: {}",
+ client_id));
+ close_client(client_id);
+ return false;
+ }
+ return true;
+ }
+ request.set_body(data.substr(end + 4, content_length));
+ end += content_length;
+ }
client.resp = delegate_.handle(request);
client.in->consume(end + 4);
return true;
@@ -583,7 +812,7 @@ class ServerImpl : public Server {
return;
}
- logger_.dbg(std::format("Client timeout: {}", client_id));
+ logger_->dbg(std::format("Client timeout: {}", client_id));
client.timeout = 0;
close_client(client_id);
@@ -596,23 +825,31 @@ class ServerImpl : public Server {
return;
}
- logger_.dbg(std::format("Drop client: {}", client_id));
+ logger_->dbg(std::format("Drop client: {}", client_id));
looper_.remove(client.fd.get());
+ client.fd.reset();
+
+ reset_client(client_id);
+ }
+
+ void reset_client(size_t client_id) {
+ auto& client = client_[client_id];
+
if (client.timeout)
looper_.cancel(client.timeout);
- client.fd.reset();
clear(*client.in);
clear(*client.out);
client.resp.reset();
client.read_closed_ = false;
+ client.take_over_ = nullptr;
assert(active_clients_ > 0);
--active_clients_;
if (next_client_ == client_id + 1)
next_client_ = client_id;
}
- logger::Logger& logger_;
+ std::unique_ptr<logger::Logger> logger_;
cfg::Config const& cfg_;
looper::Looper& looper_;
Server::Delegate& delegate_;
diff --git a/src/http.hh b/src/http.hh
index ca3f7d4..d85420e 100644
--- a/src/http.hh
+++ b/src/http.hh
@@ -1,6 +1,8 @@
#ifndef HTTP_HH
#define HTTP_HH
+#include "unique_fd.hh"
+
#include <cstdint>
#include <memory>
#include <optional>
@@ -23,8 +25,10 @@ class Config;
namespace http {
enum class StatusCode : uint16_t {
+ kSwitchingProtocols = 101,
kOK = 200,
kNoContent = 204,
+ kBadRequest = 400,
kNotFound = 404,
kMethodNotAllowed = 405,
};
@@ -96,6 +100,16 @@ class Request {
virtual std::string_view method() const = 0;
[[nodiscard]]
virtual std::string_view path() const = 0;
+ [[nodiscard]]
+ virtual std::string_view body() const = 0;
+
+ [[nodiscard]]
+ virtual bool header_contains(std::string_view name,
+ std::string_view value) const = 0;
+
+ [[nodiscard]]
+ virtual std::optional<std::string_view> header_value(
+ std::string_view name) const = 0;
protected:
Request() = default;
@@ -103,6 +117,16 @@ class Request {
Request& operator=(Request const&) = delete;
};
+class SocketReceiver {
+ public:
+ virtual ~SocketReceiver() = default;
+
+ virtual void receive(unique_fd&& fd) = 0;
+
+ protected:
+ SocketReceiver() = default;
+};
+
class Response {
public:
virtual ~Response() = default;
@@ -115,6 +139,8 @@ class Response {
virtual Builder& add_header(std::string_view name,
std::string_view value) = 0;
+ virtual Builder& take_over(SocketReceiver& receiver) = 0;
+
[[nodiscard]]
virtual std::unique_ptr<Response> build() const = 0;
@@ -140,6 +166,8 @@ class Response {
// Returns true while there is more data to write.
virtual bool write(Buffer& buffer) = 0;
+ virtual SocketReceiver* socket_receiver() const = 0;
+
protected:
Response() = default;
Response(Response const&) = delete;
diff --git a/src/logger.cc b/src/logger.cc
index 21effff..104797f 100644
--- a/src/logger.cc
+++ b/src/logger.cc
@@ -1,6 +1,7 @@
#include "logger.hh"
#include <cstdint>
+#include <format>
#include <iostream>
#include <memory>
#include <string>
@@ -125,6 +126,34 @@ class StderrLogger : public BaseLogger {
bool const verbose_;
};
+class PrefixLogger : public Logger {
+ public:
+ PrefixLogger(Logger& logger, std::string prefix)
+ : logger_(logger), prefix_(std::move(prefix)) {}
+
+ void err(std::string_view message) override {
+ logger_.err(std::format("{}: {}", prefix_, message));
+ }
+
+ void warn(std::string_view message) override {
+ logger_.warn(std::format("{}: {}", prefix_, message));
+ }
+
+ void info(std::string_view message) override {
+ logger_.info(std::format("{}: {}", prefix_, message));
+ }
+
+#if !defined(NDEBUG)
+ void dbg(std::string_view message) override {
+ logger_.dbg(std::format("{}: {}", prefix_, message));
+ }
+#endif
+
+ private:
+ Logger& logger_;
+ std::string prefix_;
+};
+
} // namespace
[[nodiscard]]
@@ -142,4 +171,9 @@ std::unique_ptr<Logger> stderr(bool verbose) {
return std::make_unique<StderrLogger>(verbose);
}
+[[nodiscard]]
+std::unique_ptr<Logger> prefix(Logger& logger, std::string prefix) {
+ return std::make_unique<PrefixLogger>(logger, std::move(prefix));
+}
+
} // namespace logger
diff --git a/src/logger.hh b/src/logger.hh
index 5c1e599..88edcf3 100644
--- a/src/logger.hh
+++ b/src/logger.hh
@@ -36,6 +36,9 @@ std::unique_ptr<Logger> syslog(std::string ident, bool verbose = false);
[[nodiscard]]
std::unique_ptr<Logger> stderr(bool verbose = false);
+[[nodiscard]]
+std::unique_ptr<Logger> prefix(Logger& logger, std::string prefix);
+
} // namespace logger
#endif // LOGGER_HH
diff --git a/src/main.cc b/src/main.cc
index 6883f30..494c49c 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -3,9 +3,11 @@
#include "cfg.hh"
#include "config.h"
#include "http.hh"
+#include "json.hh"
#include "logger.hh"
#include "looper.hh"
#include "signals.hh"
+#include "websocket.hh"
#include <cerrno>
#include <cstdint>
@@ -16,6 +18,7 @@
#include <memory>
#include <optional>
#include <string>
+#include <string_view>
#include <unistd.h>
#include <utility>
#include <vector>
@@ -26,9 +29,38 @@
namespace {
+class Api {
+ public:
+ virtual ~Api() = default;
+
+ [[nodiscard]]
+ virtual bt::Adapter* adapter() const = 0;
+
+ protected:
+ Api() = default;
+};
+
+const std::string_view kSignalUpdateAdapter("controller/update");
+
+class Signaler {
+ public:
+ virtual ~Signaler() = default;
+
+ virtual void send(std::string_view signal) = 0;
+ virtual std::unique_ptr<http::Response> handle(
+ http::Request const& request) = 0;
+
+ protected:
+ Signaler() = default;
+};
+
class HttpServerDelegate : public http::Server::Delegate {
public:
- HttpServerDelegate() = default;
+ HttpServerDelegate(Api& api, Signaler& signaler)
+ : api_(api),
+ signaler_(signaler),
+ json_writer_(json::writer(json_tmp_)),
+ json_mimetype_(http::MimeType::create("application", "json")) {}
std::unique_ptr<http::Response> handle(
http::Request const& request) override {
@@ -36,21 +68,66 @@ class HttpServerDelegate : public http::Server::Delegate {
return http::Response::status(http::StatusCode::kMethodNotAllowed);
}
- if (request.path() == "/api/v1/status") {
- return http::Response::content(
- R"({ status: "OK" })",
- *http::MimeType::create("application", "json"));
+ if (request.path().starts_with("/api/v1/")) {
+ auto path = request.path().substr(8);
+ if (path == "status") {
+ json_writer_->clear();
+ json_writer_->start_object();
+ json_writer_->key("status");
+ json_writer_->value("OK");
+ json_writer_->end_object();
+ return http::Response::content(json_tmp_, *json_mimetype_);
+ }
+
+ if (path == "controller") {
+ auto* adapter = api_.adapter();
+ json_writer_->clear();
+ json_writer_->start_object();
+ json_writer_->key("name");
+ json_writer_->value(adapter ? adapter->name() : "unknown");
+ json_writer_->key("pairable");
+ json_writer_->value(adapter ? adapter->pairable() : false);
+ json_writer_->key("pairing");
+ json_writer_->value(adapter ? adapter->pairing() : false);
+ json_writer_->end_object();
+
+ return http::Response::content(json_tmp_, *json_mimetype_);
+ }
+
+ if (path == "events") {
+ auto resp = signaler_.handle(request);
+ if (resp)
+ return resp;
+ return http::Response::status(http::StatusCode::kBadRequest);
+ }
}
return http::Response::status(http::StatusCode::kNotFound);
}
+
+ private:
+ Api& api_;
+ Signaler& signaler_;
+ std::unique_ptr<json::Writer> json_writer_;
+ std::string json_tmp_;
+ std::unique_ptr<http::MimeType> json_mimetype_;
};
-class BluetoothManagerDelegate : public bt::Manager::Delegate {
+class BluetoothManagerDelegate : public bt::Manager::Delegate, public Api {
public:
- explicit BluetoothManagerDelegate(logger::Logger& logger) : logger_(logger) {}
+ BluetoothManagerDelegate(logger::Logger& logger, Signaler& signaler)
+ : logger_(logger), signaler_(signaler) {}
+
+ [[nodiscard]]
+ bt::Adapter* adapter() const override {
+ return adapter_;
+ }
void new_adapter(bt::Adapter* adapter) override {
+ adapter_ = adapter;
+
+ signaler_.send(kSignalUpdateAdapter);
+
if (adapter) {
logger_.info(std::format("New adapter: {} [{}]", adapter->name(),
adapter->address()));
@@ -59,6 +136,11 @@ class BluetoothManagerDelegate : public bt::Manager::Delegate {
}
}
+ void updated_adapter(bt::Adapter& adapter) override {
+ if (adapter_ == &adapter)
+ signaler_.send(kSignalUpdateAdapter);
+ }
+
void added_device(bt::Device& device) override {
logger_.info(
std::format("New device: {} [{}]", device.name(), device.address()));
@@ -117,15 +199,42 @@ class BluetoothManagerDelegate : public bt::Manager::Delegate {
private:
logger::Logger& logger_;
+ Signaler& signaler_;
+ bt::Adapter* adapter_{nullptr};
+};
+
+class SignalerImpl : public Signaler, ws::Server::Delegate {
+ public:
+ SignalerImpl(logger::Logger& logger, cfg::Config const& cfg,
+ looper::Looper& looper)
+ : server_(ws::create_server(logger, cfg, looper, *this)) {}
+
+ void send(std::string_view signal) override {
+ server_->send_text_to_all(signal);
+ }
+
+ std::unique_ptr<http::Response> handle(
+ http::Request const& request) override {
+ return server_->handle(request);
+ }
+
+ std::unique_ptr<ws::Message> handle(ws::Message const& /* msg */) override {
+ // Ignore anything sent by clients
+ return nullptr;
+ }
+
+ private:
+ std::unique_ptr<ws::Server> server_;
};
bool run(logger::Logger& logger, cfg::Config const& cfg,
std::unique_ptr<http::OpenPort> port) {
auto looper = looper::create();
- HttpServerDelegate http_delegate;
+ SignalerImpl signaler(logger, cfg, *looper);
+ BluetoothManagerDelegate bt_delegate(logger, signaler);
+ HttpServerDelegate http_delegate(bt_delegate, signaler);
auto server =
http::create_server(logger, cfg, *looper, std::move(port), http_delegate);
- BluetoothManagerDelegate bt_delegate(logger);
auto manager = bt::create_manager(logger, cfg, *looper, bt_delegate);
auto sigint_handler = signals::Handler::create(
*looper, signals::Signal::INT, [&looper, &logger]() {
diff --git a/src/websocket.cc b/src/websocket.cc
new file mode 100644
index 0000000..391727d
--- /dev/null
+++ b/src/websocket.cc
@@ -0,0 +1,676 @@
+#include "websocket.hh"
+
+#include "base64.hh"
+#include "buffer.hh"
+#include "cfg.hh"
+#include "http.hh"
+#include "logger.hh"
+#include "looper.hh"
+#include "sha1.hh"
+#include "unique_fd.hh"
+
+#include <algorithm>
+#include <cassert>
+#include <cerrno>
+#include <chrono>
+#include <cstddef>
+#include <cstring>
+#include <format>
+#include <iterator>
+#include <memory>
+#include <optional>
+#include <string_view>
+#include <sys/types.h>
+#include <unistd.h>
+#include <utility>
+#include <vector>
+
+namespace ws {
+
+namespace {
+
+const std::string_view kWebsocketGuid("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+
+enum OpCode : uint8_t {
+ kContinuation = 0,
+ kText = 1,
+ kBinary = 2,
+ kClose = 8,
+ kPing = 9,
+ kPong = 10,
+};
+
+void clear(Buffer& buffer) {
+ if (buffer.empty())
+ return;
+ while (true) {
+ size_t avail;
+ buffer.rptr(avail);
+ if (avail == 0)
+ break;
+ buffer.consume(avail);
+ }
+}
+
+class TextMessageImpl : public Message {
+ public:
+ explicit TextMessageImpl(std::string_view text) : text_(text) {}
+
+ [[nodiscard]]
+ bool is_text() const override {
+ return true;
+ }
+
+ [[nodiscard]]
+ bool is_binary() const override {
+ return false;
+ }
+
+ [[nodiscard]]
+ std::string_view text() const override {
+ return text_;
+ }
+
+ [[nodiscard]]
+ std::span<uint8_t const> binary() const override {
+ return {};
+ }
+
+ private:
+ std::string_view text_;
+};
+
+class BinaryMessageImpl : public Message {
+ public:
+ explicit BinaryMessageImpl(std::span<uint8_t const> data) : data_(data) {}
+
+ [[nodiscard]]
+ bool is_text() const override {
+ return false;
+ }
+
+ [[nodiscard]]
+ bool is_binary() const override {
+ return true;
+ }
+
+ [[nodiscard]]
+ std::string_view text() const override {
+ return {};
+ }
+
+ [[nodiscard]]
+ std::span<uint8_t const> binary() const override {
+ return data_;
+ }
+
+ private:
+ std::span<uint8_t const> data_;
+};
+
+class ServerImpl : public Server, http::SocketReceiver {
+ public:
+ ServerImpl(logger::Logger& logger, cfg::Config const& cfg,
+ looper::Looper& looper, Delegate& delegate)
+ : logger_(logger::prefix(logger, "ws")),
+ cfg_(cfg),
+ looper_(looper),
+ delegate_(delegate) {
+ client_half_timeout_ =
+ std::chrono::duration<double>(
+ cfg_.get_uint64("websocket.client.timeout.seconds").value_or(120)) /
+ 2;
+ client_.resize(cfg_.get_uint64("websocket.max.clients").value_or(100));
+ }
+
+ std::unique_ptr<http::Response> handle(
+ http::Request const& request) override {
+ if (request.method() != "GET")
+ return nullptr;
+
+ if (!request.header_contains("upgrade", "websocket"))
+ return nullptr;
+
+ if (!request.header_contains("connection", "upgrade"))
+ return nullptr;
+
+ auto maybe_key = request.header_value("sec-websocket-key");
+ if (!maybe_key.has_value())
+ return nullptr;
+ auto maybe_nonce = base64::decode(maybe_key.value());
+ if (!maybe_nonce.has_value() || maybe_nonce->size() != 16)
+ return nullptr;
+
+ auto maybe_version = request.header_value("sec-websocket-version");
+ if (!maybe_version.has_value())
+ return nullptr;
+ if (maybe_version.value() != "13") {
+ auto builder =
+ http::Response::Builder::create(http::StatusCode::kBadRequest);
+ builder->add_header("Sec-WebSocket-Version", "13");
+ return builder->build();
+ }
+
+ std::string accept{maybe_key.value()};
+ accept.append(kWebsocketGuid);
+ accept = base64::encode(sha1::hash(accept));
+
+ auto builder =
+ http::Response::Builder::create(http::StatusCode::kSwitchingProtocols);
+ builder->add_header("Upgrade", "websocket");
+ builder->add_header("Connection", "Upgrade");
+ builder->add_header("Sec-WebSocket-Accept", accept);
+ builder->take_over(*this);
+
+ return builder->build();
+ }
+
+ void receive(unique_fd&& fd) override {
+ if (active_clients_ == client_.size()) {
+ logger_->warn("Max number of clients already.");
+ return;
+ }
+
+ for (; client_[next_client_].fd; next_client_++) {
+ if (next_client_ == client_.size())
+ next_client_ = 0;
+ }
+
+ logger_->dbg(std::format("New client: {}", next_client_));
+
+ auto& client = client_[next_client_];
+ client.fd = std::move(fd);
+ client.last = std::chrono::steady_clock::now();
+
+ looper_.add(
+ client.fd.get(), looper::EVENT_READ,
+ [this, id = next_client_](auto events) { client_event(id, events); });
+ client.timeout = looper_.schedule(
+ client_half_timeout_.count(),
+ [this, id = next_client_](auto handle) { client_timeout(id, handle); });
+
+ ++active_clients_;
+ ++next_client_;
+ }
+
+ void send_to_all(Message const& msg) override {
+ auto active = active_clients_;
+ for (size_t i = 0; i < client_.size() && active; ++i) {
+ if (client_[i].fd) {
+ client_send(i, msg);
+
+ if (--active == 0)
+ break;
+ }
+ }
+ }
+
+ private:
+ static const size_t kInputSize = static_cast<size_t>(256) * 1024;
+ static const size_t kOutputSize = static_cast<size_t>(256) * 1024;
+
+ struct Client {
+ unique_fd fd;
+ std::unique_ptr<Buffer> in{Buffer::fixed(kInputSize)};
+ std::unique_ptr<Buffer> out{Buffer::fixed(kOutputSize)};
+ std::chrono::steady_clock::time_point last;
+ uint32_t timeout{0};
+ bool read_closed_{false};
+
+ uint8_t ping_{0};
+ std::optional<uint8_t> expect_ping_;
+ bool expect_close_{false};
+
+ OpCode msg_{OpCode::kContinuation};
+ std::vector<uint8_t> payload_;
+ };
+
+ void client_event(size_t client_id, uint8_t event) {
+ if (event & looper::EVENT_ERROR) {
+ logger_->info(std::format("Client socket error: {}", client_id));
+ close_client(client_id);
+ return;
+ }
+
+ auto& client = client_[client_id];
+ if (event & looper::EVENT_READ) {
+ size_t avail;
+ auto* ptr = client.in->wptr(avail);
+ if (avail == 0) {
+ logger_->info(std::format("Client too large request: {}", client_id));
+ close_client(client_id);
+ return;
+ }
+
+ ssize_t got;
+ while (true) {
+ got = read(client.fd.get(), ptr, avail);
+ if (got < 0 && errno == EINTR)
+ continue;
+ break;
+ }
+ if (got < 0) {
+ if (errno != EWOULDBLOCK && errno != EAGAIN) {
+ logger_->info(std::format("Client read error: {}: {}", client_id,
+ strerror(errno)));
+ close_client(client_id);
+ return;
+ }
+ } else if (got == 0) {
+ client.read_closed_ = true;
+ } else {
+ client.in->commit(got);
+
+ if (!client_read(client_id))
+ return;
+ }
+ }
+
+ if (!client_write(client_id))
+ return;
+
+ if (client.read_closed_ && client.in->empty() && client.out->empty()) {
+ close_client(client_id);
+ return;
+ }
+
+ bool want_read = !client.read_closed_;
+ bool want_write = !client.out->empty();
+
+ looper_.update(client.fd.get(), (want_read ? looper::EVENT_READ : 0) |
+ (want_write ? looper::EVENT_WRITE : 0));
+ }
+
+ bool client_write(size_t client_id) {
+ auto& client = client_[client_id];
+ while (true) {
+ size_t avail;
+ auto* ptr = client.out->rptr(avail);
+ if (avail == 0)
+ break;
+ ssize_t got;
+ while (true) {
+ got = write(client.fd.get(), ptr, avail);
+ if (got < 0 && errno == EINTR)
+ continue;
+ break;
+ }
+ if (got < 0) {
+ if (errno != EWOULDBLOCK && errno != EAGAIN) {
+ logger_->info(std::format("Client write error: {}: {}", client_id,
+ strerror(errno)));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ }
+ client.out->consume(got);
+
+ if (std::cmp_less(got, avail))
+ break;
+ }
+ return true;
+ }
+
+ bool client_read(size_t client_id) {
+ auto& client = client_[client_id];
+ size_t avail;
+ auto* ptr = client.in->rptr(avail, /* need */ 2);
+ if (avail < 2)
+ return true;
+
+ std::span<uint8_t const> data(reinterpret_cast<uint8_t const*>(ptr), avail);
+ bool fin = data[0] & 0x80;
+ if (data[1] & 0x70) {
+ logger_->info(std::format(
+ "Client invalid request, reserved bits are not zero: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ uint8_t opcode = data[0] & 0x0f;
+ switch (opcode) {
+ case OpCode::kContinuation:
+ case OpCode::kText:
+ case OpCode::kBinary:
+ case OpCode::kClose:
+ case OpCode::kPing:
+ case OpCode::kPong:
+ break;
+ default:
+ logger_->info(std::format(
+ "Client invalid request, reserved opcode used: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ bool mask = data[1] & 0x80;
+ if (!mask) {
+ logger_->info(
+ std::format("Client invalid request, not masked: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ size_t payload_offset = 2;
+ uint64_t payload_len = data[1] & 0x7f;
+ if (payload_len == 126) {
+ payload_offset += 2;
+ } else if (payload_len == 127) {
+ payload_offset += 8;
+ }
+ /* if (mask) */ payload_offset += 4;
+
+ if (payload_offset > data.size()) {
+ ptr = client.in->rptr(avail, /* need */ payload_offset);
+ if (avail < payload_offset)
+ return true;
+ data = std::span(reinterpret_cast<uint8_t const*>(ptr), avail);
+ }
+
+ if (payload_len == 126) {
+ payload_len = (static_cast<uint16_t>(data[2]) << 8) | data[3];
+ } else if (payload_len == 127) {
+ payload_len = (static_cast<uint64_t>(data[2]) << 56) |
+ (static_cast<uint64_t>(data[3]) << 48) |
+ (static_cast<uint64_t>(data[4]) << 40) |
+ (static_cast<uint64_t>(data[5]) << 32) |
+ (static_cast<uint64_t>(data[6]) << 24) |
+ (static_cast<uint64_t>(data[7]) << 16) |
+ (static_cast<uint64_t>(data[8]) << 8) | data[9];
+ }
+ /* if (mask) */
+ auto mask_key = data.subspan(payload_offset - 4, 4);
+
+ if (payload_offset + payload_len > data.size()) {
+ ptr = client.in->rptr(avail, /* need */ payload_offset + payload_len);
+ if (avail < payload_offset + payload_len)
+ return true;
+ data = std::span(reinterpret_cast<uint8_t const*>(ptr), avail);
+ }
+
+ auto payload = data.subspan(payload_offset, payload_len);
+ /* if (mask) { */
+ // Unmask the data
+ std::span<uint8_t> tmp{const_cast<uint8_t*>(payload.data()),
+ payload.size()};
+ for (size_t i = 0; i < tmp.size(); ++i) {
+ tmp[i] ^= mask_key[i % 4];
+ }
+
+ if (client_handle(client_id, fin, static_cast<OpCode>(opcode), payload)) {
+ client.in->consume(payload_offset + payload_len);
+ return true;
+ }
+ return false;
+ }
+
+ bool client_handle(size_t client_id, bool fin, OpCode opcode,
+ std::span<uint8_t const> payload) {
+ auto& client = client_[client_id];
+
+ // Validate continuation and fin frames.
+ switch (opcode) {
+ case OpCode::kContinuation:
+ if (client.msg_ == OpCode::kContinuation) {
+ logger_->info(std::format(
+ "Client invalid frame, unexpected continuation: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ case OpCode::kBinary:
+ case OpCode::kText:
+ if (client.msg_ != OpCode::kContinuation) {
+ logger_->info(std::format(
+ "Client invalid frame, unexpected non-continuation: {}",
+ client_id));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ case OpCode::kClose:
+ case OpCode::kPing:
+ case OpCode::kPong:
+ if (!fin) {
+ logger_->info(
+ std::format("Client invalid control frame: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ break;
+ }
+
+ switch (opcode) {
+ case OpCode::kContinuation:
+ std::ranges::copy(payload, std::back_inserter(client.payload_));
+ if (fin) {
+ std::unique_ptr<Message> reply;
+ switch (client.msg_) {
+ case OpCode::kText:
+ reply = delegate_.handle(TextMessageImpl{std::string_view(
+ reinterpret_cast<char const*>(client.payload_.data()),
+ client.payload_.size())});
+ break;
+ case OpCode::kBinary:
+ reply = delegate_.handle(BinaryMessageImpl{client.payload_});
+ break;
+ default:
+ std::unreachable();
+ }
+ client.msg_ = OpCode::kContinuation;
+ client.payload_.clear();
+ if (reply) {
+ if (!client_send(client_id, *reply))
+ return false;
+ }
+ }
+ break;
+ case OpCode::kBinary:
+ case OpCode::kText:
+ if (fin) {
+ std::unique_ptr<Message> reply;
+ if (opcode == OpCode::kText) {
+ reply = delegate_.handle(TextMessageImpl{
+ std::string_view(reinterpret_cast<char const*>(payload.data()),
+ payload.size())});
+ } else {
+ reply = delegate_.handle(BinaryMessageImpl{payload});
+ }
+ if (reply) {
+ if (!client_send(client_id, *reply))
+ return false;
+ }
+ } else {
+ client.msg_ = opcode;
+ client.payload_.assign(payload.begin(), payload.end());
+ }
+ break;
+ case OpCode::kClose:
+ if (client.expect_close_) {
+ close_client(client_id);
+ return false;
+ }
+ if (!client_send(client_id, OpCode::kClose, payload))
+ return false;
+ client.read_closed_ = true;
+ break;
+ case OpCode::kPing:
+ if (!client_send(client_id, OpCode::kPong, payload))
+ return false;
+ break;
+ case OpCode::kPong:
+ if (client.expect_ping_.has_value()) {
+ if (payload.size() != 1 ||
+ payload[0] != client.expect_ping_.value()) {
+ logger_->info(
+ std::format("Client closed, mismatched pong: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ client.expect_ping_.reset();
+ } else {
+ // A Pong frame MAY be sent unsolicited. This serves as a unidirectional heartbeat.
+ }
+ break;
+ }
+ return true;
+ }
+
+ bool client_send(size_t client_id, Message const& message) {
+ if (message.is_text()) {
+ auto payload = message.text();
+ return client_send(
+ client_id, OpCode::kText,
+ std::span{reinterpret_cast<uint8_t const*>(payload.data()),
+ payload.size()});
+ }
+ return client_send(client_id, OpCode::kBinary, message.binary());
+ }
+
+ bool client_send(size_t client_id, OpCode opcode,
+ std::span<uint8_t const> payload) {
+ auto& client = client_[client_id];
+ if (client.read_closed_)
+ return true;
+
+ uint8_t header[10];
+ size_t header_size = 0;
+ // Always send FIN frames.
+ header[header_size++] = 0x80 | std::to_underlying(opcode);
+ // Server never sends masked messages.
+ if (payload.size() < 126) {
+ header[header_size++] = payload.size();
+ } else if (payload.size() <= 0xffff) {
+ header[header_size++] = 126;
+ header[header_size++] = payload.size() >> 8;
+ header[header_size++] = payload.size() & 0xff;
+ } else {
+ header[header_size++] = 127;
+ header[header_size++] = payload.size() >> 56;
+ header[header_size++] = (payload.size() & 0xff000000000000) >> 48;
+ header[header_size++] = (payload.size() & 0xff0000000000) >> 40;
+ header[header_size++] = (payload.size() & 0xff00000000) >> 32;
+ header[header_size++] = (payload.size() & 0xff000000) >> 24;
+ header[header_size++] = (payload.size() & 0xff0000) >> 16;
+ header[header_size++] = (payload.size() & 0xff00) >> 8;
+ header[header_size++] = payload.size() & 0xff;
+ }
+ size_t avail;
+ bool const was_empty = client.out->empty();
+ auto* wptr =
+ client.out->wptr(avail, /* need */ header_size + payload.size());
+ if (avail < header_size + payload.size()) {
+ logger_->info(
+ std::format("Client closed, too much output: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ auto* data = reinterpret_cast<uint8_t*>(wptr);
+ std::copy_n(header, header_size, data);
+ std::ranges::copy(payload, data + header_size);
+ client.out->commit(header_size + payload.size());
+
+ if (was_empty) {
+ if (!client_write(client_id))
+ return false;
+ }
+
+ return true;
+ }
+
+ void client_timeout(size_t client_id, uint32_t id) {
+ auto now = std::chrono::steady_clock::now();
+ auto& client = client_[client_id];
+ assert(client.timeout == id);
+ if (now - client.last <
+ std::chrono::duration_cast<std::chrono::steady_clock::duration>(
+ client_half_timeout_)) {
+ // TODO: Reschedule for delay left, not the full one
+ client.timeout = looper_.schedule(client_half_timeout_.count(),
+ [this, client_id](auto handle) {
+ client_timeout(client_id, handle);
+ });
+ return;
+ }
+
+ client.timeout = 0;
+
+ if (client.expect_ping_.has_value()) {
+ logger_->dbg(std::format("Client timeout: {}", client_id));
+
+ close_client(client_id);
+ } else {
+ client.expect_ping_ = ++client.ping_;
+ if (client_send(client_id, OpCode::kPing, std::span{&client.ping_, 1})) {
+ client.timeout = looper_.schedule(client_half_timeout_.count(),
+ [this, client_id](auto handle) {
+ client_timeout(client_id, handle);
+ });
+ }
+ }
+ }
+
+ void close_client(size_t client_id) {
+ auto& client = client_[client_id];
+ if (!client.fd) {
+ assert(false);
+ return;
+ }
+
+ logger_->dbg(std::format("Drop client: {}", client_id));
+
+ looper_.remove(client.fd.get());
+ if (client.timeout)
+ looper_.cancel(client.timeout);
+ client.fd.reset();
+ clear(*client.in);
+ clear(*client.out);
+ client.read_closed_ = false;
+ client.expect_ping_.reset();
+ client.ping_ = 0;
+ client.expect_close_ = false;
+ client.msg_ = OpCode::kContinuation;
+ client.payload_.clear();
+
+ assert(active_clients_ > 0);
+ --active_clients_;
+ if (next_client_ == client_id + 1)
+ next_client_ = client_id;
+ }
+
+ std::unique_ptr<logger::Logger> logger_;
+ cfg::Config const& cfg_;
+ looper::Looper& looper_;
+ Delegate& delegate_;
+
+ // Timeout first sends a PING and then if it times out again, then closes connection.
+ std::chrono::duration<double> client_half_timeout_;
+ std::vector<Client> client_;
+ size_t next_client_{0};
+ size_t active_clients_{0};
+};
+
+} // namespace
+
+void Server::send_text_to_all(std::string_view text) {
+ send_to_all(TextMessageImpl{text});
+}
+
+void Server::send_binary_to_all(std::span<uint8_t const> data) {
+ send_to_all(BinaryMessageImpl{data});
+}
+
+std::unique_ptr<Server> create_server(logger::Logger& logger,
+ cfg::Config const& cfg,
+ looper::Looper& looper,
+ Server::Delegate& delegate) {
+ return std::make_unique<ServerImpl>(logger, cfg, looper, delegate);
+}
+
+std::unique_ptr<Message> Message::create(std::string_view text) {
+ return std::make_unique<TextMessageImpl>(text);
+}
+
+std::unique_ptr<Message> Message::create(std::span<uint8_t const> data) {
+ return std::make_unique<BinaryMessageImpl>(data);
+}
+
+} // namespace ws
diff --git a/src/websocket.hh b/src/websocket.hh
new file mode 100644
index 0000000..f7a99f1
--- /dev/null
+++ b/src/websocket.hh
@@ -0,0 +1,88 @@
+#ifndef WEBSOCKET_HH
+#define WEBSOCKET_HH
+
+#include <memory>
+#include <span>
+#include <string_view>
+
+namespace logger {
+class Logger;
+} // namespace logger
+
+namespace looper {
+class Looper;
+} // namespace looper
+
+namespace cfg {
+class Config;
+} // namespace cfg
+
+namespace http {
+class Response;
+class Request;
+} // namespace http
+
+namespace ws {
+
+class Message {
+ public:
+ virtual ~Message() = default;
+
+ [[nodiscard]]
+ virtual bool is_text() const = 0;
+
+ [[nodiscard]]
+ virtual bool is_binary() const = 0;
+
+ [[nodiscard]]
+ virtual std::string_view text() const = 0;
+
+ [[nodiscard]]
+ virtual std::span<uint8_t const> binary() const = 0;
+
+ static std::unique_ptr<Message> create(std::string_view text);
+ static std::unique_ptr<Message> create(std::span<uint8_t const> data);
+
+ protected:
+ Message() = default;
+ Message(Message const&) = delete;
+ Message& operator=(Message const&) = delete;
+};
+
+class Server {
+ public:
+ virtual ~Server() = default;
+
+ class Delegate {
+ public:
+ virtual ~Delegate() = default;
+
+ virtual std::unique_ptr<Message> handle(Message const& msg) = 0;
+
+ protected:
+ Delegate() = default;
+ };
+
+ virtual std::unique_ptr<http::Response> handle(
+ http::Request const& request) = 0;
+
+ virtual void send_text_to_all(std::string_view text);
+ virtual void send_binary_to_all(std::span<uint8_t const> data);
+ virtual void send_to_all(Message const& message) = 0;
+
+ protected:
+ Server() = default;
+
+ private:
+ Server(Server const&) = delete;
+ Server& operator=(Server const&) = delete;
+};
+
+std::unique_ptr<Server> create_server(logger::Logger& logger,
+ cfg::Config const& cfg,
+ looper::Looper& looper,
+ Server::Delegate& delegate);
+
+} // namespace ws
+
+#endif // WEBSOCKET_HH