summaryrefslogtreecommitdiff
path: root/src/proxy.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/proxy.cc')
-rw-r--r--src/proxy.cc1373
1 files changed, 1373 insertions, 0 deletions
diff --git a/src/proxy.cc b/src/proxy.cc
new file mode 100644
index 0000000..fcd02f6
--- /dev/null
+++ b/src/proxy.cc
@@ -0,0 +1,1373 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include <chrono>
+#include <cstring>
+#include <fcntl.h>
+#include <memory>
+#include <netdb.h>
+#include <signal.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <vector>
+
+#include "buffer.hh"
+#include "chunked.hh"
+#include "config.hh"
+#include "http.hh"
+#include "io.hh"
+#include "logger.hh"
+#include "looper.hh"
+#include "resolver.hh"
+#include "paths.hh"
+#include "proxy.hh"
+#include "url.hh"
+
+namespace {
+
+auto const NEW_CONNECTION_TIMEOUT = std::chrono::duration<float>(5.0f);
+auto const CONNECTION_TIMEOUT = std::chrono::duration<float>(30.0f);
+
+template<typename T>
+class clients {
+public:
+ class iterator {
+ public:
+ iterator(clients* clients, size_t active)
+ : clients_(clients), index_(0), active_(active) {
+ if (active_) {
+ while (clients_->is_free(index_)) {
+ ++index_;
+ }
+ }
+ }
+ iterator(iterator const& i)
+ : clients_(i.clients_), index_(i.index_), active_(i.active_) {
+ }
+
+ iterator& operator=(iterator const& i) {
+ clients_ = i.clients_;
+ index_ = i.index_;
+ active_ = i.active_;
+ return *this;
+ }
+
+ bool operator==(iterator const& i) const {
+ return active_ == i.active_;
+ }
+ bool operator!=(iterator const& i) const {
+ return !(*this == i);
+ }
+ bool operator<(iterator const& i) const {
+ return active_ < i.active_;
+ }
+ bool operator<=(iterator const& i) const {
+ return (*this < i) || (*this == i);
+ }
+ bool operator>=(iterator const& i) const {
+ return !(*this < i);
+ }
+ bool operator>(iterator const& i) const {
+ return !(*this <= i);
+ }
+
+ iterator& operator++() {
+ next();
+ return *this;
+ }
+
+ iterator operator++(int UNUSED(dummy)) {
+ iterator ret(*this);
+ ++(*this);
+ return ret;
+ }
+
+ T& operator*() {
+ return (*clients_)[index_];
+ }
+ T* operator->() {
+ return &(*clients_)[index_];
+ }
+
+ size_t index() const {
+ return index_;
+ }
+
+ private:
+ void next() {
+ if (active_ == 0) return;
+ --active_;
+ if (active_ == 0) return;
+ ++index_;
+ while (clients_->is_free(index_)) {
+ ++index_;
+ }
+ }
+
+ clients* clients_;
+ size_t index_;
+ size_t active_;
+ };
+
+ clients()
+ : active_(0), max_(0) {
+ }
+
+ void resize(size_t max) {
+ max_ = max;
+ if (max > data_.size()) {
+ data_.resize(max);
+ }
+ }
+
+ bool full() const {
+ return active_ >= max_;
+ }
+
+ iterator begin() {
+ return iterator(this, active_);
+ }
+
+ iterator end() {
+ return iterator(this, 0);
+ }
+
+ T& operator[] (size_t index) {
+ return data_[index];
+ }
+
+ size_t new_client() {
+ size_t index;
+ if (active_ == data_.size()) {
+ index = data_.size();
+ data_.emplace_back();
+ } else {
+ index = rand() % data_.size();
+ while (!is_free(index)) {
+ if (++index == data_.size()) {
+ index = 0;
+ }
+ }
+ }
+ ++active_;
+ return index;
+ }
+
+ void erase(size_t index) {
+ assert(active_ > 0);
+ --active_;
+ assert(is_free(index));
+ if (data_.size() > max_ && index >= max_) {
+ size_t size = data_.size();
+ while (size > max_ && is_free(size - 1)) --size;
+ data_.resize(size);
+ }
+ }
+
+ void clear() {
+ if (active_ > 0) {
+ data_.clear();
+ active_ = 0;
+ }
+ data_.resize(max_);
+ }
+
+private:
+ bool is_free(size_t index) const {
+ return !data_[index].fd;
+ }
+
+ size_t active_;
+ size_t max_;
+ std::vector<T> data_;
+};
+
+struct BaseClient {
+ io::auto_fd fd;
+ bool new_connection;
+ Looper::clock::time_point last;
+ std::unique_ptr<Buffer> in;
+ std::unique_ptr<Buffer> out;
+ uint8_t read_flag;
+ uint8_t write_flag;
+};
+
+enum ContentType {
+ CONTENT_NONE,
+ CONTENT_LEN,
+ CONTENT_CHUNKED,
+ CONTENT_CLOSE
+};
+
+struct Content {
+ ContentType type;
+ uint64_t len;
+ std::unique_ptr<Chunked> chunked;
+};
+
+enum RemoteState {
+ CLOSED,
+ RESOLVING,
+ CONNECTING,
+ CONNECTED,
+ WAITING,
+};
+
+struct RemoteClient : public BaseClient {
+ Content content;
+};
+
+struct Client : public BaseClient {
+ Client()
+ : resolve(nullptr) {
+ }
+ std::unique_ptr<HttpRequest> request;
+ std::unique_ptr<Url> url;
+ Content content;
+ RemoteState remote_state;
+ void* resolve;
+ RemoteClient remote;
+};
+
+struct Monitor : public BaseClient {
+};
+
+class ProxyImpl : public Proxy {
+public:
+ ProxyImpl(Config* config, std::string const& cwd, char const* configfile,
+ char const* logfile, Logger* logger, int accept_fd, int monitor_fd)
+ : config_(config), cwd_(cwd), configfile_(configfile), logfile_(logfile),
+ logger_(logger), accept_socket_(accept_fd), monitor_socket_(monitor_fd),
+ looper_(Looper::create()), resolver_(Resolver::create(looper_.get())),
+ new_timeout_(nullptr), timeout_(nullptr) {
+ setup();
+ }
+ ~ProxyImpl() override {
+ if (accept_socket_) {
+ looper_->remove(accept_socket_.get());
+ }
+ if (monitor_socket_) {
+ looper_->remove(monitor_socket_.get());
+ }
+ if (signal_pipe_) {
+ looper_->remove(signal_pipe_.read());
+ }
+ }
+
+ void quit() override {
+ char cmd = 'q';
+ io::write_all(signal_pipe_.write(), &cmd, 1);
+ }
+
+ void reload() override {
+ char cmd = 'r';
+ io::write_all(signal_pipe_.write(), &cmd, 1);
+ }
+
+ bool run() override;
+
+private:
+ void setup();
+ bool reload_config();
+ void fatal_error();
+ void new_client(int fd, uint8_t events);
+ void new_monitor(int fd, uint8_t events);
+ void new_base(BaseClient* client, int fd);
+ void signal_event(int fd, uint8_t events);
+ bool base_event(BaseClient* client, uint8_t events,
+ size_t index, char const* name);
+ void client_event(size_t index, int fd, uint8_t events);
+ void client_remote_event(size_t index, int fd, uint8_t events);
+ void client_empty_input(size_t index);
+ void monitor_event(size_t index, int fd, uint8_t events);
+ void close_client(size_t index);
+ void close_monitor(size_t index);
+ void close_base(BaseClient* client);
+ void new_timeout();
+ void timeout();
+ float handle_timeout(bool new_conn,
+ std::chrono::duration<float> const& timeout);
+ bool base_send(BaseClient* client, void const* data, size_t size,
+ size_t index, char const* name);
+ void client_error(size_t index,
+ uint16_t status_code, std::string const& status);
+ bool client_request(size_t index);
+ bool client_send(size_t index, void const* ptr, size_t avail);
+ void client_remote_error(size_t index, uint16_t error);
+ void close_client_when_done(size_t index);
+ void client_remote_resolved(size_t index, int fd, bool connected,
+ char const* error);
+
+ Config* const config_;
+ std::string cwd_;
+ char const* const configfile_;
+ char const* const logfile_;
+ Logger* logger_;
+ std::unique_ptr<Logger> priv_logger_;
+ io::auto_fd accept_socket_;
+ io::auto_fd monitor_socket_;
+ io::auto_pipe signal_pipe_;
+ std::unique_ptr<Looper> looper_;
+ std::unique_ptr<Resolver> resolver_;
+ bool good_;
+ void* new_timeout_;
+ void* timeout_;
+
+ clients<Client> clients_;
+ clients<Monitor> monitors_;
+};
+
+size_t get_size(Config* config, Logger* logger, std::string const& name,
+ size_t fallback) {
+ auto value = config->get(name, nullptr);
+ if (!value) return fallback;
+ char* end = nullptr;
+ errno = 0;
+ auto tmp = strtoul(value, &end, 10);
+ if (errno || !end || *end) {
+ logger->out(Logger::WARN,
+ "Invalid value given for %s: %s, using fallback %zu instead",
+ name.c_str(), value, fallback);
+ return fallback;
+ }
+ return static_cast<size_t>(tmp);
+}
+
+void ProxyImpl::setup() {
+ clients_.resize(get_size(config_, logger_, "max_clients", 1024));
+ monitors_.resize(get_size(config_, logger_, "max_monitors", 2));
+ looper_->add(accept_socket_.get(),
+ clients_.full() ? 0 : Looper::EVENT_READ,
+ std::bind(&ProxyImpl::new_client, this,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ if (monitor_socket_) {
+ looper_->add(monitor_socket_.get(),
+ monitors_.full() ? 0 :Looper::EVENT_READ,
+ std::bind(&ProxyImpl::new_monitor, this,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ } else {
+ monitors_.clear();
+ }
+}
+
+bool ProxyImpl::reload_config() {
+ if (configfile_) {
+ auto file = Paths::join(cwd_, configfile_);
+ logger_->out(Logger::INFO, "Reloading config file %s", file.c_str());
+ config_->load_file(file);
+ } else {
+ logger_->out(Logger::INFO, "Reloading config");
+ config_->load_name("tp");
+ }
+ if (!config_->good()) {
+ logger_->out(Logger::WARN, "New config invalid, ignored.");
+ return true;
+ }
+ if (!logfile_) {
+ auto logfile = config_->get("logfile", nullptr);
+ if (logfile) {
+ if (logfile[0] != '/') {
+ logger_->out(Logger::WARN,
+ "Logfile need to be an absolute path, not: %s",
+ logfile);
+ } else {
+ std::unique_ptr<Logger> tmp(Logger::create_file(logfile_));
+ if (tmp) {
+ priv_logger_.swap(tmp);
+ logger_ = priv_logger_.get();
+ } else {
+ logger_->out(Logger::WARN,
+ "Unable to open %s for logging", logfile);
+ }
+ }
+ } else {
+ priv_logger_.reset(Logger::create_syslog("tp"));
+ logger_ = priv_logger_.get();
+ }
+ }
+ auto const old_accept = accept_socket_.get();
+ auto const old_monitor = monitor_socket_.get();
+ looper_->remove(old_accept);
+ if (monitor_socket_) looper_->remove(old_monitor);
+ accept_socket_.reset(setup_accept_socket(config_, logger_));
+ monitor_socket_.reset();
+ if (!accept_socket_) {
+ logger_->out(Logger::ERR, "Unable to bind to new configuration, abort");
+ return false;
+ }
+ if (config_->get("monitor", false)) {
+ monitor_socket_.reset(setup_monitor_socket(config_, logger_));
+ if (!monitor_socket_) {
+ logger_->out(Logger::ERR, "Unable to bind to new configuration, abort");
+ return false;
+ }
+ }
+ setup();
+ return true;
+}
+
+void ProxyImpl::signal_event(int fd, uint8_t events) {
+ assert(fd == signal_pipe_.read());
+ if (events == Looper::EVENT_READ) {
+ char cmd;
+ auto ret = io::read(fd, &cmd, 1);
+ if (ret == 1) {
+ switch (cmd) {
+ case 'q':
+ logger_->out(Logger::INFO, "Exiting ...");
+ looper_->quit();
+ return;
+ case 'r':
+ if (!reload_config()) {
+ fatal_error();
+ return;
+ }
+ break;
+ }
+ }
+ } else {
+ logger_->out(Logger::WARN, "Signal pipes have crashed");
+ signal_pipe_.reset();
+ looper_->remove(fd);
+ }
+}
+
+void ProxyImpl::close_base(BaseClient* client) {
+ if (client->fd) {
+ looper_->remove(client->fd.get());
+ client->fd.reset();
+ }
+ client->in.reset();
+ client->out.reset();
+}
+
+void ProxyImpl::close_client(size_t index) {
+ bool was_full = clients_.full();
+ auto& client = clients_[index];
+ client.request.reset();
+ client.url.reset();
+ client.content.type = CONTENT_NONE;
+ client.content.chunked.reset();
+ client.remote_state = CLOSED;
+ if (client.resolve) {
+ resolver_->cancel(client.resolve);
+ client.resolve = nullptr;
+ }
+ close_base(&client.remote);
+ close_base(&client);
+ clients_.erase(index);
+ if (was_full && !clients_.full() && accept_socket_) {
+ looper_->modify(accept_socket_.get(), Looper::EVENT_READ);
+ }
+}
+
+void ProxyImpl::close_monitor(size_t index) {
+ bool was_full = monitors_.full();
+ auto& monitor = monitors_[index];
+ close_base(&monitor);
+ monitors_.erase(index);
+ if (was_full && !monitors_.full() && monitor_socket_) {
+ looper_->modify(monitor_socket_.get(), Looper::EVENT_READ);
+ }
+}
+
+float ProxyImpl::handle_timeout(bool new_conn,
+ std::chrono::duration<float> const& timeout) {
+ auto now = looper_->now();
+ std::vector<size_t> close;
+ std::vector<bool> remote;
+ float next = -1.0f;
+ for (auto i = clients_.begin(); i != clients_.end(); ++i) {
+ if (i->new_connection != new_conn) continue;
+ auto diff = std::chrono::duration_cast<std::chrono::duration<float>>(
+ ((i->last + timeout) - now)).count();
+ if (diff < 0.0f) {
+ close.push_back(i.index());
+ remote.push_back(false);
+ } else {
+ if (!new_conn && i->remote_state > CLOSED) {
+ auto diff2 = std::chrono::duration_cast<std::chrono::duration<float>>(
+ ((i->remote.last + timeout) - now)).count();
+ if (diff2 < 0.0f) {
+ close.push_back(i.index());
+ remote.push_back(true);
+ } else if (diff2 < diff) {
+ diff = diff2;
+ }
+ }
+ if (next < 0.0f || diff < next) {
+ next = diff;
+ }
+ }
+ }
+ assert(close.size() == remote.size());
+ auto j = remote.rbegin();
+ for (auto i = close.rbegin(); i != close.rend(); ++i, ++j) {
+ if (*j) {
+ client_remote_error(*i, 504);
+ } else {
+ close_client(*i);
+ }
+ }
+ close.clear();
+ for (auto i = monitors_.begin(); i != monitors_.end(); ++i) {
+ if (i->new_connection != new_conn) continue;
+ auto diff = std::chrono::duration_cast<std::chrono::duration<float>>(
+ ((i->last + timeout) - now)).count();
+ if (diff < 0.0f) {
+ close.push_back(i.index());
+ } else if (next < 0.0f || diff < next) {
+ next = diff;
+ }
+ }
+ for (auto i = close.rbegin(); i != close.rend(); ++i) {
+ close_monitor(*i);
+ }
+ return next;
+}
+
+void ProxyImpl::timeout() {
+ assert(timeout_);
+ timeout_ = nullptr;
+ float next = handle_timeout(false, CONNECTION_TIMEOUT);
+ if (next < 0.0f) return;
+ timeout_ = looper_->schedule(next, std::bind(&ProxyImpl::timeout, this));
+}
+
+void ProxyImpl::new_timeout() {
+ assert(new_timeout_);
+ new_timeout_ = nullptr;
+ float next = handle_timeout(true, NEW_CONNECTION_TIMEOUT);
+ if (next < 0.0f) return;
+ new_timeout_ =
+ looper_->schedule(next, std::bind(&ProxyImpl::new_timeout, this));
+}
+
+bool ProxyImpl::base_send(BaseClient* client, void const* data, size_t size,
+ size_t index, char const* name) {
+ if (size == 0) return true;
+ if (!client->out->empty()) {
+ // Already waiting for write event
+ client->out->write(data, size);
+ return true;
+ }
+ auto ret = io::write(client->fd.get(), data, size);
+ if (ret == -1) {
+ if (errno != EAGAIN && errno != EWOULDBLOCK) {
+ logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ client->out->write(data, size);
+ } else {
+ if (static_cast<size_t>(ret) == size) {
+ if (!client->in) {
+ // If input is closed, close after sending all data
+ return false;
+ }
+ return true;
+ }
+ client->out->write(reinterpret_cast<char const*>(data) + ret, size - ret);
+ }
+ client->write_flag = Looper::EVENT_WRITE;
+ looper_->modify(client->fd.get(), client->read_flag | client->write_flag);
+ return true;
+}
+
+bool ProxyImpl::base_event(BaseClient* client, uint8_t events,
+ size_t index, char const* name) {
+ if (events & Looper::EVENT_READ) {
+ if (client->new_connection) {
+ char tmp[1];
+ auto ret = io::read(client->fd.get(), &tmp, 1);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return true;
+ logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ if (ret == 0) {
+ return false;
+ }
+ client->last = looper_->now();
+ client->new_connection = false;
+ client->in.reset(Buffer::create(8192, 1024));
+ client->out.reset(Buffer::create(8192, 1024));
+ client->in->write(tmp, 1);
+ if (!timeout_) {
+ timeout_ = looper_->schedule(CONNECTION_TIMEOUT.count(),
+ std::bind(&ProxyImpl::timeout, this));
+ }
+ }
+ if (!client->in) {
+ assert(false);
+ return false;
+ }
+ size_t avail;
+ auto ptr = client->in->write_ptr(&avail);
+ assert(avail > 0);
+ auto ret = io::read(client->fd.get(), ptr, avail);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return true;
+ logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ if (ret == 0) {
+ return false;
+ }
+ client->last = looper_->now();
+ client->in->commit(ret);
+ }
+ if (events & Looper::EVENT_WRITE) {
+ if (client->new_connection) {
+ assert(false);
+ return true;
+ }
+ size_t avail;
+ auto ptr = client->out->read_ptr(&avail);
+ if (avail == 0) {
+ assert(false);
+ if (!client->in) return false;
+ client->write_flag = 0;
+ looper_->modify(client->fd.get(), client->read_flag);
+ return true;
+ }
+ auto ret = io::write(client->fd.get(), ptr, avail);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return true;
+ logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ if (ret == 0) {
+ return false;
+ }
+ client->last = looper_->now();
+ client->out->consume(ret);
+ if (client->out->empty()) {
+ if (!client->in) return false;
+ client->write_flag = 0;
+ looper_->modify(client->fd.get(), client->read_flag);
+ }
+ }
+ return true;
+}
+
+inline char lower_ascii(char c) {
+ return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c;
+}
+
+bool lower_equal(char const* data, size_t start, size_t end,
+ std::string const& str) {
+ assert(start <= end);
+ if (str.size() != end - start) return false;
+ for (auto i = str.begin(); start < end; ++start, ++i) {
+ if (lower_ascii(*i) != lower_ascii(data[start])) return false;
+ }
+ return true;
+}
+
+bool header_token_eq(std::string const& value, std::string const& token) {
+ if (value.empty()) return false;
+ auto pos = value.find(';');
+ if (pos == std::string::npos) pos = value.size();
+ return lower_equal(value.data(), 0, pos, token);
+}
+
+void ProxyImpl::client_remote_error(size_t index, uint16_t error) {
+ assert(false);
+ auto& client = clients_[index];
+ if (client.remote_state > CONNECTED) {
+ // Already started to return response, too late to do anything
+ close_client(index);
+ return;
+ }
+ char const* status;
+ switch (error) {
+ case 502:
+ status = "Bad Gateway";
+ break;
+ case 504:
+ status = "Gateway Timeout";
+ break;
+ default:
+ assert(false);
+ error = 500;
+ status = "Internal Server Error";
+ break;
+ }
+
+ client_error(index, error, status);
+}
+
+void ProxyImpl::client_error(size_t index, uint16_t status_code,
+ std::string const& status_message) {
+ auto& client = clients_[index];
+ // No more input
+ client.read_flag = 0;
+ looper_->modify(client.fd.get(), client.write_flag);
+ client.url.reset();
+
+ std::string proto;
+ Version version;
+
+ if (!client.request) {
+ proto = "HTTP";
+ version.major = 1;
+ version.minor = 1;
+ } else {
+ proto = client.request->proto();
+ version = client.request->proto_version();
+ }
+
+ auto resp = std::unique_ptr<HttpResponseBuilder>(
+ HttpResponseBuilder::create(
+ proto, version, status_code, status_message));
+ resp->add_header("Content-Length", "0");
+ resp->add_header("Connection", "close");
+
+ client.in.reset();
+ client.request.reset();
+
+ auto data = resp->build();
+
+ if (!base_send(&client, data.data(), data.size(), index, "Client")) {
+ close_client(index);
+ return;
+ }
+}
+
+bool ProxyImpl::client_request(size_t index) {
+ auto& client = clients_[index];
+ auto version = client.request->proto_version();
+ if (version.major != 1 || version.minor > 1) {
+ client_error(index, 505, "HTTP Version Not Supported");
+ return false;
+ }
+ if (client.url->scheme() != "http") {
+ client_error(index, 501, "Not Implemented");
+ return false;
+ }
+ if (client.url->userinfo()) {
+ client_error(index, 400, "Bad request");
+ return false;
+ }
+ if (client.remote_state == WAITING) {
+ client.remote_state = CONNECTING;
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag);
+ client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE);
+ } else {
+ assert(client.remote_state == CLOSED);
+ client.remote.last = looper_->now();
+ client.remote.new_connection = true;
+ client.remote_state = RESOLVING;
+
+ auto port = client.url->port();
+ if (port == 0) port = 80;
+ client.resolve = resolver_->request(
+ client.url->host(), port,
+ std::bind(&ProxyImpl::client_remote_resolved, this, index,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3));
+ }
+ return true;
+}
+
+void ProxyImpl::client_remote_resolved(size_t index, int fd, bool connected,
+ char const* error) {
+ auto& client = clients_[index];
+ assert(client.resolve);
+ client.resolve = nullptr;
+ assert(client.remote_state == RESOLVING);
+ if (fd < 0) {
+ logger_->out(Logger::INFO, "%zu: Client unable to resolve remote: %s",
+ index, error);
+ client_remote_error(index, 502);
+ return;
+ }
+ client.remote_state = CONNECTING;
+ client.remote.fd.reset(fd);
+ client.remote.last = looper_->now();
+ client.remote.new_connection = false;
+ client.remote.in.reset(Buffer::create(8192, 1024));
+ client.remote.out.reset(Buffer::create(8192, 1024));
+ if (connected) {
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = 0;
+ } else {
+ client.remote.read_flag = 0;
+ client.remote.write_flag = Looper::EVENT_WRITE;
+ }
+ looper_->add(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag,
+ std::bind(&ProxyImpl::client_remote_event, this, index,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ assert(timeout_);
+ if (connected) {
+ client_remote_event(index, fd, Looper::EVENT_WRITE);
+ }
+}
+
+void ProxyImpl::client_event(size_t index, int fd, uint8_t events) {
+ auto& client = clients_[index];
+ assert(client.fd.get() == fd);
+ if (events & Looper::EVENT_HUP) {
+ close_client(index);
+ return;
+ }
+ if (events & Looper::EVENT_ERROR) {
+ logger_->out(Logger::INFO, "%zu: Client connection error", index);
+ close_client(index);
+ return;
+ }
+ if (!base_event(&client, events, index, "Client")) {
+ close_client(index);
+ return;
+ }
+ if (client.new_connection) return;
+ client_empty_input(index);
+}
+
+bool setup_content(Http const* http, Content* content) {
+ assert(content->type == CONTENT_NONE);
+ std::string te = http->first_header("transfer-encoding");
+ if (te.empty() || header_token_eq(te, "identity")) {
+ std::string len = http->first_header("content-length");
+ if (len.empty()) {
+ content->type = CONTENT_CLOSE;
+ return true;
+ }
+ char* end = nullptr;
+ errno = 0;
+ auto tmp = strtoull(len.c_str(), &end, 10);
+ if (errno || !end || *end) {
+ return false;
+ }
+ if (tmp == 0) {
+ content->type = CONTENT_NONE;
+ return true;
+ }
+ content->len = tmp;
+ content->type = CONTENT_LEN;
+ } else {
+ content->chunked.reset(Chunked::create());
+ content->type = CONTENT_CHUNKED;
+ }
+ return true;
+}
+
+void ProxyImpl::client_empty_input(size_t index) {
+ auto& client = clients_[index];
+ while (true) {
+ size_t avail;
+ auto ptr = client.in->read_ptr(&avail);
+ if (avail == 0) return;
+ switch (client.content.type) {
+ case CONTENT_CLOSE:
+ assert(false);
+ // falltrough
+ case CONTENT_NONE: {
+ if (client.remote_state != CLOSED && client.remote_state != WAITING) {
+ // Still working on the last request, wait
+ return;
+ }
+ client.request.reset(
+ HttpRequest::parse(
+ reinterpret_cast<char const*>(ptr), avail, false));
+ if (!client.request) {
+ if (avail >= 1024 * 1024) {
+ logger_->out(Logger::INFO, "%zu: Client too large request %zu",
+ index, avail);
+ close_client(index);
+ }
+ return;
+ }
+ if (!client.request->good()) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (client.request->method_equal("CONNECT")) {
+ client_error(index, 501, "Not Implemented");
+ return;
+ }
+ client.url.reset(Url::parse(client.request->url()));
+ if (!client.url) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (!setup_content(client.request.get(), &client.content)) {
+ logger_->out(Logger::INFO, "%zu: Client bad content-length", index);
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (client.content.type == CONTENT_CLOSE) {
+ client.content.type = CONTENT_NONE;
+ }
+ if (!client_request(index)) {
+ client.content.type = CONTENT_NONE;
+ return;
+ }
+ break;
+ }
+ case CONTENT_LEN:
+ if (client.remote_state < CONNECTED) {
+ // Request hasn't been sent yet, still collecting data
+ return;
+ }
+ if (avail < client.content.len) {
+ if (!client_send(index, ptr, avail)) {
+ return;
+ }
+ client.in->consume(avail);
+ client.content.len -= avail;
+ return;
+ }
+ if (!client_send(index, ptr, client.content.len)) {
+ return;
+ }
+ client.in->consume(client.content.len);
+ client.content.len = 0;
+ client.content.type = CONTENT_NONE;
+ break;
+ case CONTENT_CHUNKED:
+ if (client.remote_state < CONNECTED) {
+ // Request hasn't been sent yet, still collecting data
+ return;
+ }
+ auto used = client.content.chunked->add(ptr, avail);
+ if (!client.content.chunked->good()) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (!client_send(index, ptr, used)) {
+ return;
+ }
+ client.in->consume(used);
+ if (client.content.chunked->eof()) {
+ client.content.chunked.reset();
+ client.content.type = CONTENT_NONE;
+ }
+ break;
+ }
+ }
+}
+
+void ProxyImpl::close_client_when_done(size_t index) {
+ auto& client = clients_[index];
+ if (client.out->empty()) {
+ close_client(index);
+ return;
+ }
+ client.remote_state = CLOSED;
+ close_base(&client.remote);
+ client.in.reset();
+ client.read_flag = 0;
+ client.request.reset();
+ client.content.type = CONTENT_NONE;
+ looper_->modify(client.fd.get(), client.write_flag);
+}
+
+void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
+ auto& client = clients_[index];
+ assert(client.remote.fd.get() == fd);
+ if (events & Looper::EVENT_HUP) {
+ logger_->out(Logger::INFO, "%zu: Client remote connection closed", index);
+ client_remote_error(index, 502);
+ return;
+ }
+ if (events & Looper::EVENT_ERROR) {
+ logger_->out(Logger::INFO, "%zu: Client remote connection error", index);
+ client_remote_error(index, 502);
+ return;
+ }
+ if (client.remote_state == CONNECTING) {
+ if (events & Looper::EVENT_WRITE) {
+ std::string url(client.url->path_escaped());
+ if (url.empty()) url.push_back('/');
+ auto query = client.url->full_query_escaped();
+ if (query) {
+ url.push_back('?');
+ url.append(query);
+ }
+ auto req = std::unique_ptr<HttpRequestBuilder>(
+ HttpRequestBuilder::create(
+ client.request->method(),
+ url,
+ client.request->proto(),
+ client.request->proto_version()));
+ auto iter = client.request->header();
+ bool have_host = false;
+ for (; iter->valid(); iter->next()) {
+ if (!have_host && iter->name_equal("host")) have_host = true;
+ if (iter->name_equal("proxy-connection") ||
+ iter->name_equal("proxy-authenticate") ||
+ iter->name_equal("proxy-authorization")) {
+ continue;
+ }
+ req->add_header(iter->name(), iter->value());
+ }
+ if (!have_host &&
+ (client.request->proto_version().major == 1 &&
+ client.request->proto_version().minor == 1)) {
+ req->add_header("host", client.url->host());
+ }
+ auto data = req->build();
+ client.in->consume(client.request->size());
+ client.request.reset();
+ client.url.reset();
+ client.remote.out->write(data.data(), data.size());
+ client.remote_state = CONNECTED;
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = Looper::EVENT_WRITE;
+ looper_->modify(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag);
+ client_empty_input(index);
+ } else {
+ return;
+ }
+ }
+ if (!base_event(&client.remote, events, index, "Client remote")) {
+ switch (client.remote_state) {
+ case CONNECTED:
+ switch (client.remote.content.type) {
+ case CONTENT_CLOSE:
+ close_client_when_done(index);
+ break;
+ case CONTENT_CHUNKED:
+ case CONTENT_LEN:
+ case CONTENT_NONE:
+ client_remote_error(index, 502);
+ break;
+ }
+ break;
+ case CONNECTING:
+ case RESOLVING:
+ case CLOSED:
+ assert(false);
+ client_remote_error(index, 502);
+ break;
+ case WAITING:
+ close_base(&client.remote);
+ client.remote_state = CLOSED;
+ break;
+ }
+ return;
+ }
+ while (true) {
+ size_t avail;
+ auto ptr = client.remote.in->read_ptr(&avail);
+ if (avail == 0) return;
+ switch (client.remote_state) {
+ case CONNECTED:
+ switch (client.remote.content.type) {
+ case CONTENT_NONE: {
+ auto response = std::unique_ptr<HttpResponse>(
+ HttpResponse::parse(
+ reinterpret_cast<char const*>(ptr), avail, false));
+ if (!response) {
+ if (avail >= 1024 * 1024) {
+ logger_->out(Logger::INFO,
+ "%zu: Client remote too large request %zu",
+ index, avail);
+ client_remote_error(index, 502);
+ }
+ return;
+ }
+ if (!response->good()) {
+ client_remote_error(index, 502);
+ return;
+ }
+ if (!setup_content(response.get(), &client.remote.content)) {
+ logger_->out(Logger::INFO, "%zu: Client remote bad content-length",
+ index);
+ client.remote.content.type = CONTENT_CLOSE;
+ } else {
+ if (client.remote.content.type == CONTENT_NONE) {
+ client.remote_state = WAITING;
+ client.remote.read_flag = 0;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(), 0);
+ }
+ }
+ if (!base_send(&client, ptr, response->size(), index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(response->size());
+ break;
+ }
+ case CONTENT_LEN:
+ if (avail < client.remote.content.len) {
+ if (!base_send(&client, ptr, avail, index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(avail);
+ client.remote.content.len -= avail;
+ return;
+ }
+ if (!base_send(&client, ptr, client.remote.content.len,
+ index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(client.remote.content.len);
+ client.remote.content.len = 0;
+ client.remote.content.type = CONTENT_NONE;
+ client.remote_state = WAITING;
+ client.remote.read_flag = 0;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(), 0);
+ return;
+ case CONTENT_CHUNKED: {
+ auto used = client.remote.content.chunked->add(ptr, avail);
+ if (!client.remote.content.chunked->good()) {
+ logger_->out(Logger::INFO, "%zu: Client remote bad chunked",
+ index);
+ client.remote.content.type = CONTENT_CLOSE;
+ break;
+ }
+ if (!base_send(&client, ptr, used, index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(used);
+ if (client.remote.content.chunked->eof()) {
+ client.remote.content.chunked.reset();
+ client.remote.content.type = CONTENT_NONE;
+ logger_->out(Logger::INFO, "%zu: chunked -> waiting", index);
+ client.remote_state = WAITING;
+ client.remote.read_flag = 0;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(), 0);
+ }
+ break;
+ }
+ case CONTENT_CLOSE:
+ if (!base_send(&client, ptr, avail, index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(avail);
+ break;
+ }
+ break;
+ case CONNECTING:
+ case RESOLVING:
+ case CLOSED:
+ assert(false);
+ return;
+ case WAITING:
+ return;
+ }
+ }
+}
+
+bool ProxyImpl::client_send(size_t index, void const* ptr, size_t size) {
+ auto& client = clients_[index];
+ assert(!client.request);
+ assert(client.remote_state >= CONNECTED);
+
+ return base_send(&client.remote, ptr, size, index, "Client remote");
+}
+
+void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) {
+ auto& monitor = monitors_[index];
+ assert(monitor.fd.get() == fd);
+ if (events & Looper::EVENT_HUP) {
+ close_monitor(index);
+ return;
+ }
+ if (events & Looper::EVENT_ERROR) {
+ logger_->out(Logger::INFO, "%zu: Monitor connection error", index);
+ close_monitor(index);
+ return;
+ }
+ if (!base_event(&monitor, events, index, "Monitor")) {
+ close_monitor(index);
+ return;
+ }
+}
+
+void ProxyImpl::new_base(BaseClient* client, int fd) {
+ client->fd.reset(fd);
+ client->new_connection = true;
+ client->read_flag = Looper::EVENT_READ;
+ client->write_flag = 0;
+ client->last = looper_->now();
+ if (!new_timeout_) {
+ new_timeout_ = looper_->schedule(
+ NEW_CONNECTION_TIMEOUT.count(),
+ std::bind(&ProxyImpl::new_timeout, this));
+ }
+}
+
+void ProxyImpl::new_client(int fd, uint8_t events) {
+ assert(fd == accept_socket_.get());
+ if (events == Looper::EVENT_READ) {
+ assert(!clients_.full());
+ while (true) {
+ int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return;
+ if (errno == EINTR) continue;
+ logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno));
+ return;
+ }
+ auto index = clients_.new_client();
+ new_base(&clients_[index], ret);
+ clients_[index].content.type = CONTENT_NONE;
+ clients_[index].remote_state = CLOSED;
+ clients_[index].remote.content.type = CONTENT_NONE;
+ looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag,
+ std::bind(&ProxyImpl::client_event, this,
+ index,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ break;
+ }
+ if (clients_.full()) looper_->modify(fd, 0);
+ } else {
+ logger_->out(Logger::ERR, "Accept socket died");
+ fatal_error();
+ }
+}
+
+void ProxyImpl::new_monitor(int fd, uint8_t events) {
+ assert(fd == monitor_socket_.get());
+ if (events == Looper::EVENT_READ) {
+ assert(!monitors_.full());
+ while (true) {
+ int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return;
+ if (errno == EINTR) continue;
+ logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno));
+ return;
+ }
+ auto index = monitors_.new_client();
+ new_base(&monitors_[index], ret);
+ looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag,
+ std::bind(&ProxyImpl::monitor_event, this,
+ index,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ break;
+ }
+ if (monitors_.full()) {
+ looper_->modify(fd, 0);
+ }
+ } else {
+ logger_->out(Logger::ERR, "Monitor socket died");
+ fatal_error();
+ }
+}
+
+void ProxyImpl::fatal_error() {
+ looper_->quit();
+ good_ = false;
+}
+
+bool ProxyImpl::run() {
+ good_ = true;
+ if (!logger_) {
+ priv_logger_.reset(Logger::create_syslog("tp"));
+ logger_ = priv_logger_.get();
+ }
+ {
+ struct sigaction action;
+ memset(&action, 0, sizeof(action));
+ action.sa_handler = SIG_IGN;
+ action.sa_flags = SA_RESTART;
+ sigaction(SIGPIPE, &action, nullptr);
+ }
+ if (!signal_pipe_.open()) {
+ logger_->out(Logger::WARN,
+ "Failed to create pipes, signals wont work: %s",
+ strerror(errno));
+ }
+ if (signal_pipe_) {
+ looper_->add(signal_pipe_.read(), Looper::EVENT_READ,
+ std::bind(&ProxyImpl::signal_event, this,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ }
+ if (!looper_->run()) {
+ logger_->out(Logger::ERR, "poll() failed: %s", strerror(errno));
+ return false;
+ }
+ return good_;
+}
+
+int setup_socket(char const* host, std::string const& port, Logger* logger) {
+ io::auto_fd ret;
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_PASSIVE;
+ struct addrinfo* result;
+ auto retval = getaddrinfo(host, port.c_str(), &hints, &result);
+ if (retval) {
+ logger->out(Logger::ERR, "getaddrinfo failed for %s:%s: %s",
+ host ? host : "*", port.c_str(), gai_strerror(retval));
+ return -1;
+ }
+ auto retp = result;
+ for (; retp; retp = retp->ai_next) {
+ ret.reset(socket(retp->ai_family, retp->ai_socktype,
+ retp->ai_protocol));
+ if (!ret) continue;
+ if (bind(ret.get(), retp->ai_addr, retp->ai_addrlen) == 0)
+ break;
+ }
+ freeaddrinfo(result);
+ if (!retp) {
+ logger->out(Logger::ERR, "Failed to bind %s:%s: %s",
+ host ? host : "*", port.c_str(), strerror(errno));
+ return -1;
+ }
+ if (listen(ret.get(), SOMAXCONN)) {
+ logger->out(Logger::ERR, "Failed to listen: %s", strerror(errno));
+ return -1;
+ }
+ if (fcntl(ret.get(), F_SETFL, O_NONBLOCK)) {
+ logger->out(Logger::ERR, "fcntl(O_NONBLOCK) failed: %s", strerror(errno));
+ return -1;
+ }
+ return ret.release();
+}
+
+} // namespace
+
+// static
+Proxy* Proxy::create(Config* config, std::string const& cwd,
+ char const* configfile,
+ char const* logfile,
+ Logger* logger,
+ int accept_fd,
+ int monitor_fd) {
+ return new ProxyImpl(config, cwd, configfile, logfile, logger,
+ accept_fd, monitor_fd);
+}
+
+// static
+int Proxy::setup_accept_socket(Config* config, Logger* logger) {
+ return setup_socket(config->get("proxy_bind", nullptr),
+ config->get("proxy_port", "8080"), logger);
+}
+
+// static
+int Proxy::setup_monitor_socket(Config* config, Logger* logger) {
+ assert(config->get("monitor", false));
+ return setup_socket(config->get("monitor_bind", "localhost"),
+ config->get("monitor_port", "9000"), logger);
+}