#include "common.hh"

#include "config.hh"
#include "file_opener.hh"
#include "image.hh"
#include "io.hh"
#include "logger.hh"
#include "looper.hh"
#include "task_runner.hh"
#include "transport_base.hh"
#include "urlutil.hh"

// NOTE(review): the template argument lists in this file were reconstructed
// from how each value is used. Two of them are inferred from names only and
// must be confirmed against the headers: `TaskRunner` (from task_runner.hh)
// and `Transport::Request` (in client_request). Calls such as
// `config->get(name, 10)` are kept exactly as found.

namespace {

// Common part of the responses created by TransportBase: holds the status
// code given at construction and the headers added so far.
class ResponseImpl : public Transport::Response {
 public:
  uint16_t code() const override { return code_; }

  std::vector<std::pair<std::string, std::string>> const& headers()
      const override {
    return headers_;
  }

  void add_header(std::string name, std::string value) override {
    headers_.emplace_back(std::move(name), std::move(value));
  }

 protected:
  explicit ResponseImpl(uint16_t code) : code_(code) {}

 private:
  uint16_t const code_;
  std::vector<std::pair<std::string, std::string>> headers_;
};

// Input that serves an in-memory string. It never returns WAIT, so
// wait_once() is not expected to be called.
class DataInput : public Transport::Input {
 public:
  explicit DataInput(std::string data) : data_(std::move(data)) {}

  // Writes as much of the remaining data as Buffer::write() accepts.
  // Returns END once everything is written, FULL when nothing was accepted.
  Return fill(Buffer* buffer, size_t) override {
    if (offset_ >= data_.size())
      return Return::END;
    auto bytes = Buffer::write(buffer, data_.data() + offset_,
                               data_.size() - offset_);
    if (bytes > 0) {
      offset_ += bytes;
      return Return::OK;
    }
    return Return::FULL;
  }

  void wait_once(std::shared_ptr<Looper>,
                 std::function<void()> callback) override {
    assert(false);
    callback();
  }

 private:
  std::string data_;
  size_t offset_{0};
};

// Response whose body is a string; adds a Content-Length header for it.
class ResponseData : public ResponseImpl {
 public:
  ResponseData(uint16_t code, std::string data)
      : ResponseImpl(code), data_(std::move(data)) {
    add_header("Content-Length", std::to_string(data_.size()));
  }

  // Moves the stored data into the returned input.
  std::unique_ptr<Transport::Input> open_content() override {
    return std::make_unique<DataInput>(std::move(data_));
  }

 private:
  std::string data_;
};

// Input that reads from a file descriptor using io::fill().
class FileInput : public Transport::Input {
 public:
  explicit FileInput(unique_fd&& fd) : fd_(std::move(fd)) {}

  ~FileInput() override {
    // Drop a still pending wait_once() registration.
    if (looper_)
      looper_->remove(fd_.get());
  }

  Return fill(Buffer* buffer, size_t buf_request_size) override {
    size_t bytes = 0;
    switch (io::fill(fd_.get(), buffer, buf_request_size, &bytes)) {
      case io::Return::OK:
        break;
      case io::Return::ERR:
        return Return::ERR;
      case io::Return::CLOSED:
        return Return::END;
    }
    if (bytes > 0)
      return Return::OK;
    if (buffer->full())
      return Return::FULL;
    return Return::WAIT;
  }

  // Registers the descriptor for EVENT_READ on |looper| and runs |callback|
  // once when the event arrives.
  void wait_once(std::shared_ptr<Looper> looper,
                 std::function<void()> callback) override {
    if (looper_) {
      // Only one outstanding wait is expected.
      assert(false);
      looper_->remove(fd_.get());
    }
    looper_ = looper;
    waiting_callback_ = std::move(callback);
    looper_->add(fd_.get(), Looper::EVENT_READ,
                 std::bind(&FileInput::event, this, std::placeholders::_1));
  }

 private:
  void event(uint8_t) {
    looper_->remove(fd_.get());
    looper_.reset();
    // Move the callback out before calling it, it may call wait_once() again.
    auto callback = std::move(waiting_callback_);
    callback();
  }

  unique_fd fd_;
  std::shared_ptr<Looper> looper_;
  std::function<void()> waiting_callback_;
};

// Input that always fails; used when a file could not be opened.
class ErrorInput : public Transport::Input {
 public:
  Return fill(Buffer*, size_t) override { return Return::ERR; }

  void wait_once(std::shared_ptr<Looper>,
                 std::function<void()> callback) override {
    assert(false);
    callback();
  }
};

// Response whose body is a file that is opened through a FileOpener.
// open_content() returns nullptr while the open is still pending, in which
// case open_content_async() should be used.
class ResponseFile : public ResponseImpl {
 public:
  ResponseFile(uint16_t code, std::shared_ptr<FileOpener> file_opener,
               std::filesystem::path path)
      : ResponseImpl(code),
        opener_(std::move(file_opener)),
        open_id_(opener_->open(
            std::move(path),
            std::bind(&ResponseFile::opened, this, std::placeholders::_1,
                      std::placeholders::_2))) {}

  ~ResponseFile() {
    // The open callback is bound to |this|, so cancel a pending open.
    if (open_id_)
      opener_->cancel(open_id_);
  }

  std::unique_ptr<Transport::Input> open_content() override {
    if (open_id_)
      return nullptr;
    if (fd_)
      return create_input(std::move(fd_));
    return std::make_unique<ErrorInput>();
  }

  // Calls |callback| with the input; directly if the open already finished,
  // otherwise via |runner| once opened() has been called.
  void open_content_async(
      std::shared_ptr<TaskRunner> runner,
      std::function<void(std::unique_ptr<Transport::Input>)> callback)
      override {
    if (open_id_) {
      waiting_ = std::make_unique<WaitingCallback>(std::move(callback));
      waiting_runner_ = std::move(runner);
    } else {
      callback(open_content());
    }
  }

 protected:
  // Hook for subclasses that want another input for the opened file.
  virtual std::unique_ptr<Transport::Input> create_input(unique_fd&& fd) {
    return std::make_unique<FileInput>(std::move(fd));
  }

 private:
  // Keeps the callback and the input together so both survive until the
  // task posted to the runner is executed.
  class WaitingCallback {
   public:
    explicit WaitingCallback(
        std::function<void(std::unique_ptr<Transport::Input>)> callback)
        : callback_(std::move(callback)) {}

    void input(std::unique_ptr<Transport::Input> input) {
      assert(!input_);
      input_ = std::move(input);
    }

    void call() {
      assert(input_);
      callback_(std::move(input_));
    }

   private:
    std::function<void(std::unique_ptr<Transport::Input>)> callback_;
    std::unique_ptr<Transport::Input> input_;
  };

  // Callback from the FileOpener with the result of the open.
  void opened(uint32_t id, unique_fd fd) {
    assert(open_id_ == id);
    open_id_ = 0;
    opener_.reset();
    fd_ = std::move(fd);
    if (waiting_) {
      waiting_->input(open_content());
      // The bound shared_ptr keeps the WaitingCallback alive until called.
      waiting_runner_->post(std::bind(&WaitingCallback::call, waiting_));
      waiting_.reset();
      waiting_runner_.reset();
    }
  }

  std::shared_ptr<FileOpener> opener_;
  uint32_t open_id_;
  unique_fd fd_;
  std::shared_ptr<WaitingCallback> waiting_;
  std::shared_ptr<TaskRunner> waiting_runner_;
};

// Input that feeds the file to a ThumbnailReader and, once the reader is
// done, serves reader_->data() instead of the file content.
class ExifThumbnailInput : public FileInput {
 public:
  explicit ExifThumbnailInput(unique_fd&& fd)
      : FileInput(std::move(fd)), reader_(ThumbnailReader::create()) {}

  Return fill(Buffer* buffer, size_t buf_request_size) override {
    // buf_ is reset when the reader is done; after that only serve the data.
    if (!buf_)
      return fill_with_data(buffer, buf_request_size);

    auto file_ret = FileInput::fill(buf_.get(), 1);
    switch (reader_->drain(buf_.get())) {
      case ThumbnailReader::Return::NEED_MORE:
        break;
      case ThumbnailReader::Return::DONE:
        buf_.reset();
        return fill_with_data(buffer, buf_request_size);
      case ThumbnailReader::Return::ERR:
        return Return::ERR;
    }
    switch (file_ret) {
      case Return::OK:
      case Return::ERR:
      case Return::WAIT:
        return file_ret;
      case Return::FULL:
        // ThumbnailReader should drain more than this
        assert(false);
        return Return::ERR;
      case Return::END:
        // File ended while the reader still wanted more.
        return Return::ERR;
    }
    // Every enumerator returns above; keep a defined result regardless.
    assert(false);
    return Return::ERR;
  }

 private:
  // Copies the next part of reader_->data() into |buffer|.
  Return fill_with_data(Buffer* buffer, size_t buf_request_size) {
    if (offset_ >= reader_->data().size())
      return Return::END;
    size_t avail;
    auto* ptr = buffer->wbuf(buf_request_size, avail);
    if (avail == 0)
      return Return::FULL;
    auto got = reader_->data().size() - offset_;
    if (avail > got)
      avail = got;
    std::copy_n(reader_->data().data() + offset_, avail, ptr);
    buffer->wcommit(avail);
    offset_ += avail;
    return Return::OK;
  }

  std::unique_ptr<ThumbnailReader> reader_;
  std::unique_ptr<Buffer> buf_{Buffer::fixed(10 * 1024)};
  size_t offset_{0};
};

// ResponseFile that wraps the opened file in an ExifThumbnailInput.
class ResponseExifThumbnail : public ResponseFile {
 public:
  ResponseExifThumbnail(uint16_t code, std::shared_ptr<FileOpener> file_opener,
                        std::filesystem::path path)
      : ResponseFile(code, std::move(file_opener), std::move(path)) {}

 protected:
  std::unique_ptr<Transport::Input> create_input(unique_fd&& fd) override {
    return std::make_unique<ExifThumbnailInput>(std::move(fd));
  }
};

}  // namespace

TransportBase::TransportBase(std::shared_ptr<Logger> logger,
                             std::shared_ptr<Looper> looper,
                             std::shared_ptr<TaskRunner> runner,
                             Handler* handler)
    : logger_(std::move(logger)),
      looper_(std::move(looper)),
      runner_(std::move(runner)),
      handler_(handler) {}

TransportBase::~TransportBase() {
  // Clear these before calling client_abort to not cause any
  // unnecessary client_new.
  client_wait_.clear();
  for (auto& client : client_)
    client_abort(&client);
}

uint64_t TransportBase::default_client_input_buffer_size() const {
  // No POST/PUT support, really shouldn't be that big.
  return 100 * 1024;
}

uint64_t TransportBase::default_client_output_buffer_size() const {
  // Might return actual files, but that is async so whole file
  // doesn't need to fit.
  return 1 * 1024 * 1024;
}

std::unique_ptr<Transport::Response> TransportBase::create_data(
    uint16_t code, std::string data) {
  return std::make_unique<ResponseData>(code, std::move(data));
}

std::unique_ptr<Transport::Response> TransportBase::create_file(
    uint16_t code, std::filesystem::path path) {
  return std::make_unique<ResponseFile>(code, file_opener_, std::move(path));
}

std::unique_ptr<Transport::Response> TransportBase::create_exif_thumbnail(
    uint16_t code, std::filesystem::path path) {
  return std::make_unique<ResponseExifThumbnail>(code, file_opener_,
                                                 std::move(path));
}

// Reads the transport and client settings from |config|, allocates the
// client slots with their buffers and creates the file opener.
// Returns false, after logging to |logger|, if any setting is invalid.
bool TransportBase::setup(Logger* logger, Config const* config) {
  client_.clear();

  auto clients = config->get("transport.max_clients", 10);
  if (!clients.has_value()) {
    logger->err("transport.max_clients is unknown value: '%s'",
                config->get("transport.max_clients", nullptr));
    return false;
  }
  if (clients.value() < 1) {
    logger->err("transport.max_clients must be > 0");
    return false;
  }
  for (size_t i = 0; i < clients.value(); ++i)
    client_.emplace_back(i);

  auto in_buffer_size = config->get_size("client.input_buffer_size",
                                         default_client_input_buffer_size());
  if (!in_buffer_size.has_value()) {
    logger->err("client.input_buffer_size is unknown size: `%s'",
                config->get("client.input_buffer_size", nullptr));
    return false;
  }
  if (in_buffer_size.value() < 1) {
    logger->err("client.input_buffer_size must be > 0");
    return false;
  }

  auto out_buffer_size = config->get_size("client.output_buffer_size",
                                          default_client_output_buffer_size());
  if (!out_buffer_size.has_value()) {
    logger->err("client.output_buffer_size is unknown size: `%s'",
                config->get("client.output_buffer_size", nullptr));
    return false;
  }
  if (out_buffer_size.value() < 1) {
    logger->err("client.output_buffer_size must be > 0");
    return false;
  }

  for (auto& client : client_) {
    client.in_ = Buffer::fixed(in_buffer_size.value());
    client.out_ = Buffer::fixed(out_buffer_size.value());
  }

  auto timeout = config->get_duration("client.timeout", 30.0);
  if (!timeout.has_value()) {
    logger->err("client.timeout is unknown duration: `%s'",
                config->get("client.timeout", nullptr));
    return false;
  }
  if (timeout.value() <= 0.0) {
    logger->err("client.timeout must be > 0");
    return false;
  }
  client_timeout_ = timeout.value();

  auto file_opener_threads = config->get("transport.workers", 1);
  if (!file_opener_threads.has_value()) {
    logger->err("transport.workers is unknown value: '%s'",
                config->get("transport.workers", nullptr));
    return false;
  }
  if (file_opener_threads.value() <= 0) {
    logger->err("transport.workers must be > 0");
    return false;
  }
  file_opener_ = FileOpener::create(runner_, file_opener_threads.value());

  return true;
}

// Assigns |fd| to the first free client slot, searching from
// next_avail_client_. If every slot is in use the descriptor is appended to
// client_wait_ instead.
void TransportBase::add_client(unique_fd&& fd) {
  if (!client_full_) {
    auto const start = next_avail_client_;
    do {
      auto& client = client_[next_avail_client_++];
      if (next_avail_client_ == client_.size())
        next_avail_client_ = 0;
      if (!client.fd_) {
        client.fd_ = std::move(fd);
        client_new(&client);
        // Assume there is data available directly to speed up responses
        // in the common case.
        client_event(&client, Looper::EVENT_READ);
        return;
      }
    } while (next_avail_client_ != start);
    client_full_ = true;
  }
  client_wait_.push_back(std::move(fd));
}

// Registers the client's descriptor with the looper and starts its timeout.
void TransportBase::client_new(Client* client) {
  assert(client->fd_);
  looper_->add(client->fd_.get(), Looper::EVENT_READ,
               std::bind(&TransportBase::client_event, this, client,
                         std::placeholders::_1));
  client->last_event_ = std::chrono::steady_clock::now();
  client->timeout_ = looper_->schedule(
      client_timeout_,
      std::bind(&TransportBase::client_timeout, this, client,
                std::placeholders::_1));
}

// Timer callback. Aborts the client if it has been idle for client_timeout_,
// otherwise re-schedules the timer for the remaining time.
void TransportBase::client_timeout(Client* client, uint32_t id) {
  assert(client->timeout_ == id);
  client->timeout_ = 0;
  std::chrono::duration<double> delay =
      std::chrono::steady_clock::now() - client->last_event_;
  if (delay.count() < client_timeout_) {
    client->timeout_ = looper_->schedule(
        client_timeout_ - delay.count(),
        std::bind(&TransportBase::client_timeout, this, client,
                  std::placeholders::_1));
  } else {
    logger_->dbg("Client timeout %zu", client->index_);
    client_abort(client);
  }
}

// Looper callback for a client descriptor. Reads and/or writes as indicated
// by |event|, calls client_handle() if anything changed and then updates the
// events the client is registered for.
void TransportBase::client_event(Client* client, uint8_t event) {
  client->last_event_ = std::chrono::steady_clock::now();
  bool call_handle = false;

  if (event & Looper::EVENT_READ) {
    size_t bytes = 0;
    switch (io::fill(client->fd_.get(), client->in_.get(), client->expect_in_,
                     &bytes)) {
      case io::Return::OK:
        if (bytes > 0)
          call_handle = true;
        break;
      case io::Return::ERR:
        logger_->dbg("Error reading from client %zu", client->index_);
        client_abort(client);
        return;
      case io::Return::CLOSED:
        // Only tell the handler about the close once.
        if (!client->in_closed_)
          call_handle = true;
        client->in_closed_ = true;
        break;
    }
  }

  if (event & Looper::EVENT_WRITE) {
    size_t bytes = 0;
    if (!io::drain(client->out_.get(), client->fd_.get(), &bytes)) {
      logger_->dbg("Error writing to client %zu", client->index_);
      client_abort(client);
      return;
    }
    if (bytes > 0)
      call_handle = true;
  }

  if (event & Looper::EVENT_ERROR) {
    logger_->dbg("Looper error on client %zu", client->index_);
    client_abort(client);
    return;
  }

  if (call_handle) {
    if (!client_handle(client)) {
      client_abort(client);
      return;
    }
  }

  client_update_event(client);
}

// Registers for read while more input is expected and there is room for it,
// and for write while there is output left to send.
void TransportBase::client_update_event(Client* client) {
  uint8_t events = 0;
  if (!client->in_closed_ && client->expect_in_ > 0 && !client->in_->full())
    events |= Looper::EVENT_READ;
  if (!client->out_->empty())
    events |= Looper::EVENT_WRITE;
  looper_->update(client->fd_.get(), events);
}

// Writes as much as possible of the client's output buffer. If the buffer
// was full and something was written, responses that have content get to
// fill the buffer again. Returns false if the client was aborted.
bool TransportBase::client_flush(Client* client) {
  size_t bytes = 0;
  const bool was_full = client->out_->full();
  if (!io::drain(client->out_.get(), client->fd_.get(), &bytes)) {
    logger_->dbg("Error writing to client %zu", client->index_);
    client_abort(client);
    return false;
  }

  if (bytes > 0) {
    client->last_event_ = std::chrono::steady_clock::now();
  }
  if (!client->out_->empty()) {
    // Make sure to add EVENT_WRITE
    client_update_event(client);
  }
  if (bytes > 0 && was_full) {
    // client_response_content() may erase entries (client_response_end) or
    // add new ones (via client_handle), so never call it while iterating
    // responses_. Collect the ids first and look each one up again.
    std::vector<uint32_t> pending;
    for (auto const& pair : client->responses_) {
      if (pair.second.response_ && pair.second.content_)
        pending.push_back(pair.first);
    }
    for (auto const id : pending) {
      auto it = client->responses_.find(id);
      if (it == client->responses_.end())
        continue;  // Finished by an earlier call in this loop.
      if (!it->second.response_ || !it->second.content_)
        continue;
      if (!client_response_content(client, id))
        return false;
    }
  }
  return true;
}

// Closes the client connection and resets the slot so it can be reused.
// Does nothing if the slot has no descriptor.
void TransportBase::client_abort(Client* client) {
  if (!client->fd_)
    return;
  looper_->remove(client->fd_.get());
  if (client->timeout_) {
    looper_->cancel(client->timeout_);
    client->timeout_ = 0;
  }
  client->fd_.reset();
  client->in_->clear();
  client->out_->clear();
  client->responses_.clear();
  client->in_closed_ = false;
  client->expect_in_ = 1;
  next_avail_client_ = client->index_;
  client_full_ = false;
  // NOTE(review): nothing visible in this file moves a descriptor from
  // client_wait_ into the slot freed here, although the comment in the
  // destructor suggests that was intended - confirm.
}

// Starts sending |response| as the response with |id|: writes the header
// and then the content, opening the content asynchronously if needed.
// Returns false if the client was aborted.
bool TransportBase::client_response(
    Client* client, uint32_t id,
    std::unique_ptr<Transport::Response> response) {
  auto ret = client->responses_.emplace(id, std::move(response));
  assert(ret.second);
  if (!client_response_header(client, id))
    return false;
  assert(client->responses_.count(id));
  auto& cli_response = client->responses_[id];
  cli_response.content_ = cli_response.response_->open_content();
  if (cli_response.content_)
    return client_response_content(client, id);
  cli_response.response_->open_content_async(
      runner_,
      std::bind(&TransportBase::client_response_open, this, client, id,
                std::placeholders::_1));
  return true;
}

// Callback for open_content_async(); continues with the content of |id|.
void TransportBase::client_response_open(
    Client* client, uint32_t id, std::unique_ptr<Transport::Input> input) {
  // ResponseFile posts this call to the runner, and client_abort() clears
  // responses_. If the client was aborted in between, the entry is gone and
  // the input must be dropped rather than attached to a new empty entry.
  auto it = client->responses_.find(id);
  if (it == client->responses_.end())
    return;
  it->second.content_ = std::move(input);
  client_response_content(client, id);
}

bool TransportBase::client_response_header(Client*, uint32_t) {
  return true;
}

bool TransportBase::client_response_content(Client* client, uint32_t id) {
  return client_response_content(client, id, client->out_.get());
}

// Moves content for response |id| into |out| and flushes the client until
// the content ends, the buffer is full or the input has to wait for data.
// Returns false if the client was aborted.
bool TransportBase::client_response_content(Client* client, uint32_t id,
                                            Buffer* out) {
  assert(client->responses_.count(id));
  auto& cli_response = client->responses_[id];
  switch (cli_response.content_->fill(out)) {
    case Input::Return::OK:
      if (!client_flush(client))
        return false;
      // client_flush() may already have finished and removed this response.
      return client->responses_.count(id) == 0 ||
             client_response_content(client, id);
    case Input::Return::FULL:
      return client_flush(client);
    case Input::Return::END:
      return client_response_end(client, id);
    case Input::Return::ERR:
      logger_->warn("Input error for client %zu", client->index_);
      client_abort(client);
      return false;
    case Input::Return::WAIT:
      cli_response.content_->wait_once(
          looper_,
          std::bind(&TransportBase::client_response_content_wait, this,
                    client, id));
      return true;
  }
  assert(false);
  return true;
}

void TransportBase::client_response_content_wait(Client* client, uint32_t id) {
  client_response_content(client, id);
}

// Finishes response |id|: writes the footer, removes the response and lets
// client_handle() continue with the client.
bool TransportBase::client_response_end(Client* client, uint32_t id) {
  if (!client_response_footer(client, id))
    return false;
  client->responses_.erase(id);
  if (!client_handle(client))
    return false;
  client_update_event(client);
  return true;
}

bool TransportBase::client_response_footer(Client*, uint32_t) {
  return true;
}

// Passes |request| to the handler and sends its response, or an empty 500
// response if the handler returned none.
// NOTE(review): the element type of |request| is inferred - confirm against
// the declaration in transport_base.hh.
bool TransportBase::client_request(
    Client* client, uint32_t id,
    std::unique_ptr<Transport::Request> request) {
  auto response = handler_->request(this, request.get());
  if (response) {
    return client_response(client, id, std::move(response));
  } else {
    return client_response(client, id, create_data(500, ""));
  }
}

std::string_view TransportBase::UrlRequest::path() const {
  split_url_if_needed();
  return path_;
}

// Returns the value of query parameter |name|, or an empty view if the
// parameter is not present.
std::string_view TransportBase::UrlRequest::query(
    std::string_view name) const {
  split_url_if_needed();
  auto it = query_.find(std::string(name));
  if (it == query_.end())
    return std::string_view();
  return it->second;
}

// Splits url() into path_ and query_ the first time it is needed; an empty
// path_ is used as the "not split yet" marker.
void TransportBase::UrlRequest::split_url_if_needed() const {
  if (path_.empty())
    url::split_and_unescape_path_and_query(url(), path_, query_);
}