summaryrefslogtreecommitdiff
path: root/src/http.cc
diff options
context:
space:
mode:
authorJoel Klinghed <the_jk@spawned.biz>2025-10-19 00:12:50 +0200
committerJoel Klinghed <the_jk@spawned.biz>2025-10-19 00:31:19 +0200
commite7c74917191a4953d495295b65732aa3549ba753 (patch)
treefde3fc6080786b3e3e4526b3793bacbb390d2b17 /src/http.cc
parent4f6ead7c2c646b6b866274299c05d08170d2dfb0 (diff)
Add new module websocket and use it
Implement /api/v1/events which will send out messages when things change, such as the main controller.
Diffstat (limited to 'src/http.cc')
-rw-r--r--src/http.cc295
1 files changed, 266 insertions, 29 deletions
diff --git a/src/http.cc b/src/http.cc
index 8a0a1e6..21de28d 100644
--- a/src/http.cc
+++ b/src/http.cc
@@ -20,10 +20,12 @@
#include <memory>
#include <netdb.h>
#include <optional>
+#include <span>
#include <string>
#include <string_view>
#include <sys/socket.h>
#include <sys/types.h>
+#include <system_error>
#include <unistd.h>
#include <utility>
#include <vector>
@@ -54,6 +56,25 @@ std::string_view ascii_lowercase(std::string_view input, std::string& tmp) {
return input;
}
+bool eq_lowercase(std::string_view input, std::string_view cmp) {
+ if (input.size() != cmp.size())
+ return false;
+
+ auto a = input.begin();
+ auto b = cmp.begin();
+ for (; a != input.end(); ++a, ++b) {
+ if (*a >= 'A' && *a <= 'Z') {
+ // NOLINTNEXTLINE(bugprone-narrowing-conversions)
+ if ((*a | 0x20) != *b)
+ return false;
+ } else {
+ if (*a != *b)
+ return false;
+ }
+ }
+ return true;
+}
+
std::string standard_header(std::string_view input) {
bool capitalize = true;
std::string ret;
@@ -215,7 +236,8 @@ class OpenPortImpl : public OpenPort {
class ResponseImpl : public Response {
public:
- explicit ResponseImpl(std::string data) : data_(std::move(data)) {}
+ ResponseImpl(std::string data, SocketReceiver* socket_receiver)
+ : data_(std::move(data)), socket_receiver_(socket_receiver) {}
bool write(Buffer& buffer) override {
if (offset_ >= data_.size())
@@ -235,9 +257,15 @@ class ResponseImpl : public Response {
return offset_ < data_.size();
}
+ [[nodiscard]]
+ SocketReceiver* socket_receiver() const override {
+ return socket_receiver_;
+ }
+
private:
std::string const data_;
size_t offset_{0};
+ SocketReceiver* socket_receiver_;
};
class ResponseBuilderImpl : public Response::Builder {
@@ -256,6 +284,11 @@ class ResponseBuilderImpl : public Response::Builder {
return *this;
}
+ Builder& take_over(SocketReceiver& receiver) override {
+ socket_receiver_ = &receiver;
+ return *this;
+ }
+
[[nodiscard]]
std::unique_ptr<Response> build() const override {
std::string ret;
@@ -269,12 +302,18 @@ class ResponseBuilderImpl : public Response::Builder {
}
ret.append(" ");
switch (status_) {
+ case StatusCode::kSwitchingProtocols:
+ ret.append("Switching Protocols");
+ break;
case StatusCode::kOK:
ret.append("OK");
break;
case StatusCode::kNoContent:
ret.append("No Content");
break;
+ case StatusCode::kBadRequest:
+ ret.append("Bad Request");
+ break;
case StatusCode::kNotFound:
ret.append("Not Found");
break;
@@ -292,7 +331,20 @@ class ResponseBuilderImpl : public Response::Builder {
ret.append(pair.second);
ret.append("\r\n");
}
- if (!have_content_len && status_ != StatusCode::kNoContent) {
+ bool need_content_len;
+ switch (status_) {
+ case StatusCode::kSwitchingProtocols:
+ case StatusCode::kNoContent:
+ need_content_len = false;
+ break;
+ case StatusCode::kOK:
+ case StatusCode::kBadRequest:
+ case StatusCode::kNotFound:
+ case StatusCode::kMethodNotAllowed:
+ need_content_len = true;
+ break;
+ }
+ if (!have_content_len && need_content_len) {
char tmp[20];
auto [ptr, ec] = std::to_chars(tmp, tmp + sizeof(tmp), content_.size());
ret.append("Content-Length");
@@ -303,19 +355,93 @@ class ResponseBuilderImpl : public Response::Builder {
ret.append("\r\n");
if (status_ != StatusCode::kNoContent)
ret.append(content_);
- return std::make_unique<ResponseImpl>(std::move(ret));
+ return std::make_unique<ResponseImpl>(std::move(ret), socket_receiver_);
}
private:
StatusCode const status_;
std::string content_;
std::vector<std::pair<std::string, std::string>> headers_;
+ SocketReceiver* socket_receiver_{nullptr};
+};
+
+class HeaderIterator {
+ public:
+ explicit HeaderIterator(std::span<std::string_view> headers)
+ : headers_(headers) {
+ next();
+ }
+
+ [[nodiscard]]
+ bool has_name() const {
+ return offset_ < headers_.size();
+ }
+
+ [[nodiscard]]
+ bool has_value() const {
+ return has_name() && continue_value_;
+ }
+
+ bool next_name() {
+ do {
+ next();
+ } while (has_value());
+ return has_name();
+ }
+
+ bool next_value() {
+ next();
+ return has_value();
+ }
+
+ [[nodiscard]]
+ std::string_view name() const {
+ return name_;
+ }
+
+ [[nodiscard]]
+ std::string_view value() const {
+ return value_;
+ }
+
+ private:
+ void next() {
+ if (offset_ >= headers_.size())
+ return;
+
+ auto line = headers_[offset_];
+ if (offset_ > 0 &&
+ (line.empty() || line.front() == ' ' || line.front() == '\t')) {
+ continue_value_ = true;
+ value_ = str::ltrim(line);
+ ++offset_;
+ return;
+ }
+ auto colon = line.find(':');
+ if (colon == std::string::npos) {
+ assert(false);
+ ++offset_;
+ next();
+ return;
+ }
+ continue_value_ = false;
+ name_ = str::trim(line.substr(0, colon));
+ value_ = str::ltrim(line.substr(colon + 1));
+ ++offset_;
+ }
+
+ std::span<std::string_view> headers_;
+ size_t offset_{0};
+ bool continue_value_{false};
+ std::string_view name_;
+ std::string_view value_;
};
class RequestImpl : public Request {
public:
- RequestImpl(std::string_view method, std::string_view path)
- : method_(method), path_(path) {}
+ RequestImpl(std::string_view method, std::string_view path,
+ std::span<std::string_view> headers)
+ : method_(method), path_(path), headers_(headers) {}
[[nodiscard]]
std::string_view method() const override {
@@ -327,9 +453,62 @@ class RequestImpl : public Request {
return path_;
}
+ [[nodiscard]]
+ std::string_view body() const override {
+ return body_;
+ }
+
+ void set_body(std::string_view body) { body_ = body; }
+
+ [[nodiscard]]
+ bool header_contains(std::string_view name,
+ std::string_view value) const override {
+ for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) {
+ if (eq_lowercase(it.name(), name)) {
+ std::string_view tmp = get_value(it);
+ auto values = str::split(tmp, ',', /* keep_empty */ true);
+ for (auto v : values) {
+ if (eq_lowercase(str::trim(v), value)) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ [[nodiscard]]
+ std::optional<std::string_view> header_value(
+ std::string_view name) const override {
+ for (HeaderIterator it{headers_}; it.has_name(); it.next_name()) {
+ if (eq_lowercase(it.name(), name)) {
+ // Do not look for more entries of the header, this method is used for header
+ // values that cannot be combined.
+ return get_value(it);
+ }
+ }
+ return std::nullopt;
+ }
+
private:
+ std::string_view get_value(HeaderIterator& it) const {
+ std::string_view ret = it.value();
+ if (it.next_value()) {
+ tmp_.clear();
+ tmp_.append(ret);
+ do {
+ tmp_.append(it.value());
+ } while (it.next_value());
+ return tmp_;
+ }
+ return ret;
+ }
+
std::string_view method_;
std::string_view path_;
+ std::string_view body_;
+ std::span<std::string_view> headers_;
+ mutable std::string tmp_;
};
// TODO: What is a good value?
@@ -340,7 +519,7 @@ class ServerImpl : public Server {
ServerImpl(logger::Logger& logger, cfg::Config const& cfg,
looper::Looper& looper, std::unique_ptr<OpenPort> open_port,
Server::Delegate& delegate)
- : logger_(logger),
+ : logger_(logger::prefix(logger, "http")),
cfg_(cfg),
looper_(looper),
delegate_(delegate),
@@ -364,6 +543,7 @@ class ServerImpl : public Server {
std::chrono::steady_clock::time_point last;
uint32_t timeout{0};
bool read_closed_{false};
+ SocketReceiver* take_over_{nullptr};
};
void start_to_listen() {
@@ -372,7 +552,7 @@ class ServerImpl : public Server {
if (listen(it->get(),
static_cast<int>(cfg_.get_uint64("http.listen.backlog")
.value_or(kListenBacklog)))) {
- logger_.warn(
+ logger_->warn(
std::format("Error listening to socket: {}", strerror(errno)));
it = listens_.erase(it);
} else {
@@ -386,21 +566,21 @@ class ServerImpl : public Server {
}
if (listens_.empty()) {
- logger_.err("No ports left to listen to, exit.");
+ logger_->err("No ports left to listen to, exit.");
looper_.quit();
}
}
void accept_client(int fd, uint8_t event) {
if (event & looper::EVENT_ERROR) {
- logger_.warn("Listening port returned error, closing.");
+ logger_->warn("Listening port returned error, closing.");
looper_.remove(fd);
auto it = std::ranges::find_if(
listens_, [&fd](auto& ufd) { return ufd.get() == fd; });
if (it != listens_.end()) {
listens_.erase(it);
if (listens_.empty()) {
- logger_.err("No ports left to listen to, exit.");
+ logger_->err("No ports left to listen to, exit.");
looper_.quit();
}
}
@@ -409,12 +589,12 @@ class ServerImpl : public Server {
unique_fd client_fd(accept4(fd, nullptr, nullptr, SOCK_NONBLOCK));
if (!client_fd) {
- logger_.info(std::format("Accept returned error: {}", strerror(errno)));
+ logger_->info(std::format("Accept returned error: {}", strerror(errno)));
return;
}
if (active_clients_ == client_.size()) {
- logger_.warn("Max number of clients already.");
+ logger_->warn("Max number of clients already.");
return;
}
@@ -423,7 +603,7 @@ class ServerImpl : public Server {
next_client_ = 0;
}
- logger_.dbg(std::format("New client: {}", next_client_));
+ logger_->dbg(std::format("New client: {}", next_client_));
auto& client = client_[next_client_];
client.fd = std::move(client_fd);
@@ -443,7 +623,7 @@ class ServerImpl : public Server {
void client_event(size_t client_id, uint8_t event) {
if (event & looper::EVENT_ERROR) {
- logger_.info(std::format("Client socket error: {}", client_id));
+ logger_->info(std::format("Client socket error: {}", client_id));
close_client(client_id);
return;
}
@@ -453,7 +633,7 @@ class ServerImpl : public Server {
size_t avail;
auto* ptr = client.in->wptr(avail);
if (avail == 0) {
- logger_.info(std::format("Client too large request: {}", client_id));
+ logger_->info(std::format("Client too large request: {}", client_id));
close_client(client_id);
return;
}
@@ -467,8 +647,8 @@ class ServerImpl : public Server {
}
if (got < 0) {
if (errno != EWOULDBLOCK && errno != EAGAIN) {
- logger_.info(std::format("Client read error: {}: {}", client_id,
- strerror(errno)));
+ logger_->info(std::format("Client read error: {}: {}", client_id,
+ strerror(errno)));
close_client(client_id);
return;
}
@@ -488,10 +668,15 @@ class ServerImpl : public Server {
if (avail == 0) {
if (client.resp) {
if (!client.resp->write(*client.out)) {
+ client.take_over_ = client.resp->socket_receiver();
client.resp.reset();
- if (!client_read(client_id))
- return;
+ if (client.take_over_) {
+ client.read_closed_ = true;
+ } else {
+ if (!client_read(client_id))
+ return;
+ }
}
continue;
}
@@ -506,8 +691,8 @@ class ServerImpl : public Server {
}
if (got < 0) {
if (errno != EWOULDBLOCK && errno != EAGAIN) {
- logger_.info(std::format("Client write error: {}: {}", client_id,
- strerror(errno)));
+ logger_->info(std::format("Client write error: {}: {}", client_id,
+ strerror(errno)));
close_client(client_id);
return;
}
@@ -521,7 +706,13 @@ class ServerImpl : public Server {
if (client.read_closed_ && client.in->empty() && client.out->empty() &&
!client.resp) {
- close_client(client_id);
+ if (client.take_over_) {
+ looper_.remove(client.fd.get());
+ client.take_over_->receive(std::move(client.fd));
+ reset_client(client_id);
+ } else {
+ close_client(client_id);
+ }
return;
}
@@ -550,19 +741,57 @@ class ServerImpl : public Server {
{
auto parts = str::split(lines[0], ' ', false);
if (parts.size() != 3) {
- logger_.info(std::format("Client invalid request: {}", client_id));
+ logger_->info(std::format("Client invalid request: {}", client_id));
close_client(client_id);
return false;
}
method = parts[0];
path = parts[1];
if (!parts[2].starts_with("HTTP/")) {
- logger_.info(std::format("Client invalid request: {}", client_id));
+ logger_->info(std::format("Client invalid request: {}", client_id));
close_client(client_id);
return false;
}
}
- RequestImpl request(method, path);
+ for (size_t i = 1; i < lines.size(); ++i) {
+ auto colon = lines[i].find(':');
+ if (colon == std::string::npos &&
+ // Only continue header value lines are allowed to not have a colon.
+ (i == 1 || lines[i].empty() ||
+ (lines[i].front() != ' ' && lines[i].front() != '\t'))) {
+ logger_->info(std::format("Client invalid request: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ }
+ RequestImpl request(method, path, std::span{lines}.subspan(1));
+ auto maybe_content_length = request.header_value("content-length");
+ if (maybe_content_length.has_value()) {
+ uint64_t content_length;
+ auto [ptr, ec] = std::from_chars(
+ maybe_content_length->data(),
+ maybe_content_length->data() + maybe_content_length->size(),
+ content_length);
+ if (ec != std::errc() ||
+ ptr != maybe_content_length->data() + maybe_content_length->size()) {
+ logger_->info(std::format(
+ "Client invalid request, bad content-length: {}", client_id));
+ close_client(client_id);
+ return false;
+ }
+ if (avail - end < content_length) {
+ // Need more data.
+ if (client.in->full()) {
+ logger_->info(std::format("Client invalid request, too much data: {}",
+ client_id));
+ close_client(client_id);
+ return false;
+ }
+ return true;
+ }
+ request.set_body(data.substr(end + 4, content_length));
+ end += content_length;
+ }
client.resp = delegate_.handle(request);
client.in->consume(end + 4);
return true;
@@ -583,7 +812,7 @@ class ServerImpl : public Server {
return;
}
- logger_.dbg(std::format("Client timeout: {}", client_id));
+ logger_->dbg(std::format("Client timeout: {}", client_id));
client.timeout = 0;
close_client(client_id);
@@ -596,23 +825,31 @@ class ServerImpl : public Server {
return;
}
- logger_.dbg(std::format("Drop client: {}", client_id));
+ logger_->dbg(std::format("Drop client: {}", client_id));
looper_.remove(client.fd.get());
+ client.fd.reset();
+
+ reset_client(client_id);
+ }
+
+ void reset_client(size_t client_id) {
+ auto& client = client_[client_id];
+
if (client.timeout)
looper_.cancel(client.timeout);
- client.fd.reset();
clear(*client.in);
clear(*client.out);
client.resp.reset();
client.read_closed_ = false;
+ client.take_over_ = nullptr;
assert(active_clients_ > 0);
--active_clients_;
if (next_client_ == client_id + 1)
next_client_ = client_id;
}
- logger::Logger& logger_;
+ std::unique_ptr<logger::Logger> logger_;
cfg::Config const& cfg_;
looper::Looper& looper_;
Server::Delegate& delegate_;