diff options
| author | Joel Klinghed <the_jk@yahoo.com> | 2017-02-28 21:50:44 +0100 |
|---|---|---|
| committer | Joel Klinghed <the_jk@yahoo.com> | 2017-02-28 21:50:44 +0100 |
| commit | c029d90d1975e124d237605f1edb2be16bd05b5d (patch) | |
| tree | 9df87ffb365354bdb74a969440b32c8304bdbcb7 /src/proxy.cc | |
Initial commit
Diffstat (limited to 'src/proxy.cc')
| -rw-r--r-- | src/proxy.cc | 1373 |
1 files changed, 1373 insertions, 0 deletions
diff --git a/src/proxy.cc b/src/proxy.cc new file mode 100644 index 0000000..fcd02f6 --- /dev/null +++ b/src/proxy.cc @@ -0,0 +1,1373 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include <chrono> +#include <cstring> +#include <fcntl.h> +#include <memory> +#include <netdb.h> +#include <signal.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <vector> + +#include "buffer.hh" +#include "chunked.hh" +#include "config.hh" +#include "http.hh" +#include "io.hh" +#include "logger.hh" +#include "looper.hh" +#include "resolver.hh" +#include "paths.hh" +#include "proxy.hh" +#include "url.hh" + +namespace { + +auto const NEW_CONNECTION_TIMEOUT = std::chrono::duration<float>(5.0f); +auto const CONNECTION_TIMEOUT = std::chrono::duration<float>(30.0f); + +template<typename T> +class clients { +public: + class iterator { + public: + iterator(clients* clients, size_t active) + : clients_(clients), index_(0), active_(active) { + if (active_) { + while (clients_->is_free(index_)) { + ++index_; + } + } + } + iterator(iterator const& i) + : clients_(i.clients_), index_(i.index_), active_(i.active_) { + } + + iterator& operator=(iterator const& i) { + clients_ = i.clients_; + index_ = i.index_; + active_ = i.active_; + return *this; + } + + bool operator==(iterator const& i) const { + return active_ == i.active_; + } + bool operator!=(iterator const& i) const { + return !(*this == i); + } + bool operator<(iterator const& i) const { + return active_ < i.active_; + } + bool operator<=(iterator const& i) const { + return (*this < i) || (*this == i); + } + bool operator>=(iterator const& i) const { + return !(*this < i); + } + bool operator>(iterator const& i) const { + return !(*this <= i); + } + + iterator& operator++() { + next(); + return *this; + } + + iterator operator++(int UNUSED(dummy)) { + iterator ret(*this); + ++(*this); + return ret; + } + + T& operator*() { + return (*clients_)[index_]; + } + T* operator->() { + return &(*clients_)[index_]; + } + + size_t index() const { + return index_; + } + + private: + void next() { + if (active_ == 0) return; + --active_; + if (active_ == 0) return; + ++index_; + while (clients_->is_free(index_)) { + ++index_; + } + } + + clients* clients_; + size_t index_; + size_t active_; + }; + + clients() + : active_(0), max_(0) { + } + + void resize(size_t max) { + max_ = max; + if (max > data_.size()) { + data_.resize(max); + } + } + + bool full() const { + return active_ >= max_; + } + + iterator begin() { + return iterator(this, active_); + } + + iterator end() { + return iterator(this, 0); + } + + T& operator[] (size_t index) { + return data_[index]; + } + + size_t new_client() { + size_t index; + if (active_ == data_.size()) { + index = data_.size(); + data_.emplace_back(); + } else { + index = rand() % data_.size(); + while (!is_free(index)) { + if (++index == data_.size()) { + index = 0; + } + } + } + ++active_; + return index; + } + + void erase(size_t index) { + assert(active_ > 0); + --active_; + assert(is_free(index)); + if (data_.size() > max_ && index >= max_) { + size_t size = data_.size(); + while (size > max_ && is_free(size - 1)) --size; + data_.resize(size); + } + } + + void clear() { + if (active_ > 0) { + data_.clear(); + active_ = 0; + } + data_.resize(max_); + } + +private: + bool is_free(size_t index) const { + return !data_[index].fd; + } + + size_t active_; + size_t max_; + std::vector<T> data_; +}; + +struct BaseClient { + io::auto_fd fd; + bool new_connection; + Looper::clock::time_point last; + std::unique_ptr<Buffer> in; + std::unique_ptr<Buffer> out; + uint8_t read_flag; + uint8_t write_flag; +}; + +enum ContentType { + CONTENT_NONE, + CONTENT_LEN, + CONTENT_CHUNKED, + CONTENT_CLOSE +}; + +struct Content { + ContentType type; + uint64_t len; + std::unique_ptr<Chunked> chunked; +}; + +enum RemoteState { + CLOSED, + RESOLVING, + CONNECTING, + CONNECTED, + WAITING, +}; + +struct RemoteClient : public BaseClient { + Content content; +}; + +struct Client : public BaseClient { + Client() + : resolve(nullptr) { + } + std::unique_ptr<HttpRequest> request; + std::unique_ptr<Url> url; + Content content; + RemoteState remote_state; + void* resolve; + RemoteClient remote; +}; + +struct Monitor : public BaseClient { +}; + +class ProxyImpl : public Proxy { +public: + ProxyImpl(Config* config, std::string const& cwd, char const* configfile, + char const* logfile, Logger* logger, int accept_fd, int monitor_fd) + : config_(config), cwd_(cwd), configfile_(configfile), logfile_(logfile), + logger_(logger), accept_socket_(accept_fd), monitor_socket_(monitor_fd), + looper_(Looper::create()), resolver_(Resolver::create(looper_.get())), + new_timeout_(nullptr), timeout_(nullptr) { + setup(); + } + ~ProxyImpl() override { + if (accept_socket_) { + looper_->remove(accept_socket_.get()); + } + if (monitor_socket_) { + looper_->remove(monitor_socket_.get()); + } + if (signal_pipe_) { + looper_->remove(signal_pipe_.read()); + } + } + + void quit() override { + char cmd = 'q'; + io::write_all(signal_pipe_.write(), &cmd, 1); + } + + void reload() override { + char cmd = 'r'; + io::write_all(signal_pipe_.write(), &cmd, 1); + } + + bool run() override; + +private: + void setup(); + bool reload_config(); + void fatal_error(); + void new_client(int fd, uint8_t events); + void new_monitor(int fd, uint8_t events); + void new_base(BaseClient* client, int fd); + void signal_event(int fd, uint8_t events); + bool base_event(BaseClient* client, uint8_t events, + size_t index, char const* name); + void client_event(size_t index, int fd, uint8_t events); + void client_remote_event(size_t index, int fd, uint8_t events); + void client_empty_input(size_t index); + void monitor_event(size_t index, int fd, uint8_t events); + void close_client(size_t index); + void close_monitor(size_t index); + void close_base(BaseClient* client); + void new_timeout(); + void timeout(); + float handle_timeout(bool new_conn, + std::chrono::duration<float> const& timeout); + bool base_send(BaseClient* client, void const* data, size_t size, + size_t index, char const* name); + void client_error(size_t index, + uint16_t status_code, std::string const& status); + bool client_request(size_t index); + bool client_send(size_t index, void const* ptr, size_t avail); + void client_remote_error(size_t index, uint16_t error); + void close_client_when_done(size_t index); + void client_remote_resolved(size_t index, int fd, bool connected, + char const* error); + + Config* const config_; + std::string cwd_; + char const* const configfile_; + char const* const logfile_; + Logger* logger_; + std::unique_ptr<Logger> priv_logger_; + io::auto_fd accept_socket_; + io::auto_fd monitor_socket_; + io::auto_pipe signal_pipe_; + std::unique_ptr<Looper> looper_; + std::unique_ptr<Resolver> resolver_; + bool good_; + void* new_timeout_; + void* timeout_; + + clients<Client> clients_; + clients<Monitor> monitors_; +}; + +size_t get_size(Config* config, Logger* logger, std::string const& name, + size_t fallback) { + auto value = config->get(name, nullptr); + if (!value) return fallback; + char* end = nullptr; + errno = 0; + auto tmp = strtoul(value, &end, 10); + if (errno || !end || *end) { + logger->out(Logger::WARN, + "Invalid value given for %s: %s, using fallback %zu instead", + name.c_str(), value, fallback); + return fallback; + } + return static_cast<size_t>(tmp); +} + +void ProxyImpl::setup() { + clients_.resize(get_size(config_, logger_, "max_clients", 1024)); + monitors_.resize(get_size(config_, logger_, "max_monitors", 2)); + looper_->add(accept_socket_.get(), + clients_.full() ? 0 : Looper::EVENT_READ, + std::bind(&ProxyImpl::new_client, this, + std::placeholders::_1, + std::placeholders::_2)); + if (monitor_socket_) { + looper_->add(monitor_socket_.get(), + monitors_.full() ? 0 :Looper::EVENT_READ, + std::bind(&ProxyImpl::new_monitor, this, + std::placeholders::_1, + std::placeholders::_2)); + } else { + monitors_.clear(); + } +} + +bool ProxyImpl::reload_config() { + if (configfile_) { + auto file = Paths::join(cwd_, configfile_); + logger_->out(Logger::INFO, "Reloading config file %s", file.c_str()); + config_->load_file(file); + } else { + logger_->out(Logger::INFO, "Reloading config"); + config_->load_name("tp"); + } + if (!config_->good()) { + logger_->out(Logger::WARN, "New config invalid, ignored."); + return true; + } + if (!logfile_) { + auto logfile = config_->get("logfile", nullptr); + if (logfile) { + if (logfile[0] != '/') { + logger_->out(Logger::WARN, + "Logfile need to be an absolute path, not: %s", + logfile); + } else { + std::unique_ptr<Logger> tmp(Logger::create_file(logfile_)); + if (tmp) { + priv_logger_.swap(tmp); + logger_ = priv_logger_.get(); + } else { + logger_->out(Logger::WARN, + "Unable to open %s for logging", logfile); + } + } + } else { + priv_logger_.reset(Logger::create_syslog("tp")); + logger_ = priv_logger_.get(); + } + } + auto const old_accept = accept_socket_.get(); + auto const old_monitor = monitor_socket_.get(); + looper_->remove(old_accept); + if (monitor_socket_) looper_->remove(old_monitor); + accept_socket_.reset(setup_accept_socket(config_, logger_)); + monitor_socket_.reset(); + if (!accept_socket_) { + logger_->out(Logger::ERR, "Unable to bind to new configuration, abort"); + return false; + } + if (config_->get("monitor", false)) { + monitor_socket_.reset(setup_monitor_socket(config_, logger_)); + if (!monitor_socket_) { + logger_->out(Logger::ERR, "Unable to bind to new configuration, abort"); + return false; + } + } + setup(); + return true; +} + +void ProxyImpl::signal_event(int fd, uint8_t events) { + assert(fd == signal_pipe_.read()); + if (events == Looper::EVENT_READ) { + char cmd; + auto ret = io::read(fd, &cmd, 1); + if (ret == 1) { + switch (cmd) { + case 'q': + logger_->out(Logger::INFO, "Exiting ..."); + looper_->quit(); + return; + case 'r': + if (!reload_config()) { + fatal_error(); + return; + } + break; + } + } + } else { + logger_->out(Logger::WARN, "Signal pipes have crashed"); + signal_pipe_.reset(); + looper_->remove(fd); + } +} + +void ProxyImpl::close_base(BaseClient* client) { + if (client->fd) { + looper_->remove(client->fd.get()); + client->fd.reset(); + } + client->in.reset(); + client->out.reset(); +} + +void ProxyImpl::close_client(size_t index) { + bool was_full = clients_.full(); + auto& client = clients_[index]; + client.request.reset(); + client.url.reset(); + client.content.type = CONTENT_NONE; + client.content.chunked.reset(); + client.remote_state = CLOSED; + if (client.resolve) { + resolver_->cancel(client.resolve); + client.resolve = nullptr; + } + close_base(&client.remote); + close_base(&client); + clients_.erase(index); + if (was_full && !clients_.full() && accept_socket_) { + looper_->modify(accept_socket_.get(), Looper::EVENT_READ); + } +} + +void ProxyImpl::close_monitor(size_t index) { + bool was_full = monitors_.full(); + auto& monitor = monitors_[index]; + close_base(&monitor); + monitors_.erase(index); + if (was_full && !monitors_.full() && monitor_socket_) { + looper_->modify(monitor_socket_.get(), Looper::EVENT_READ); + } +} + +float ProxyImpl::handle_timeout(bool new_conn, + std::chrono::duration<float> const& timeout) { + auto now = looper_->now(); + std::vector<size_t> close; + std::vector<bool> remote; + float next = -1.0f; + for (auto i = clients_.begin(); i != clients_.end(); ++i) { + if (i->new_connection != new_conn) continue; + auto diff = std::chrono::duration_cast<std::chrono::duration<float>>( + ((i->last + timeout) - now)).count(); + if (diff < 0.0f) { + close.push_back(i.index()); + remote.push_back(false); + } else { + if (!new_conn && i->remote_state > CLOSED) { + auto diff2 = std::chrono::duration_cast<std::chrono::duration<float>>( + ((i->remote.last + timeout) - now)).count(); + if (diff2 < 0.0f) { + close.push_back(i.index()); + remote.push_back(true); + } else if (diff2 < diff) { + diff = diff2; + } + } + if (next < 0.0f || diff < next) { + next = diff; + } + } + } + assert(close.size() == remote.size()); + auto j = remote.rbegin(); + for (auto i = close.rbegin(); i != close.rend(); ++i, ++j) { + if (*j) { + client_remote_error(*i, 504); + } else { + close_client(*i); + } + } + close.clear(); + for (auto i = monitors_.begin(); i != monitors_.end(); ++i) { + if (i->new_connection != new_conn) continue; + auto diff = std::chrono::duration_cast<std::chrono::duration<float>>( + ((i->last + timeout) - now)).count(); + if (diff < 0.0f) { + close.push_back(i.index()); + } else if (next < 0.0f || diff < next) { + next = diff; + } + } + for (auto i = close.rbegin(); i != close.rend(); ++i) { + close_monitor(*i); + } + return next; +} + +void ProxyImpl::timeout() { + assert(timeout_); + timeout_ = nullptr; + float next = handle_timeout(false, CONNECTION_TIMEOUT); + if (next < 0.0f) return; + timeout_ = looper_->schedule(next, std::bind(&ProxyImpl::timeout, this)); +} + +void ProxyImpl::new_timeout() { + assert(new_timeout_); + new_timeout_ = nullptr; + float next = handle_timeout(true, NEW_CONNECTION_TIMEOUT); + if (next < 0.0f) return; + new_timeout_ = + looper_->schedule(next, std::bind(&ProxyImpl::new_timeout, this)); +} + +bool ProxyImpl::base_send(BaseClient* client, void const* data, size_t size, + size_t index, char const* name) { + if (size == 0) return true; + if (!client->out->empty()) { + // Already waiting for write event + client->out->write(data, size); + return true; + } + auto ret = io::write(client->fd.get(), data, size); + if (ret == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name, + strerror(errno)); + return false; + } + client->out->write(data, size); + } else { + if (static_cast<size_t>(ret) == size) { + if (!client->in) { + // If input is closed, close after sending all data + return false; + } + return true; + } + client->out->write(reinterpret_cast<char const*>(data) + ret, size - ret); + } + client->write_flag = Looper::EVENT_WRITE; + looper_->modify(client->fd.get(), client->read_flag | client->write_flag); + return true; +} + +bool ProxyImpl::base_event(BaseClient* client, uint8_t events, + size_t index, char const* name) { + if (events & Looper::EVENT_READ) { + if (client->new_connection) { + char tmp[1]; + auto ret = io::read(client->fd.get(), &tmp, 1); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name, + strerror(errno)); + return false; + } + if (ret == 0) { + return false; + } + client->last = looper_->now(); + client->new_connection = false; + client->in.reset(Buffer::create(8192, 1024)); + client->out.reset(Buffer::create(8192, 1024)); + client->in->write(tmp, 1); + if (!timeout_) { + timeout_ = looper_->schedule(CONNECTION_TIMEOUT.count(), + std::bind(&ProxyImpl::timeout, this)); + } + } + if (!client->in) { + assert(false); + return false; + } + size_t avail; + auto ptr = client->in->write_ptr(&avail); + assert(avail > 0); + auto ret = io::read(client->fd.get(), ptr, avail); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name, + strerror(errno)); + return false; + } + if (ret == 0) { + return false; + } + client->last = looper_->now(); + client->in->commit(ret); + } + if (events & Looper::EVENT_WRITE) { + if (client->new_connection) { + assert(false); + return true; + } + size_t avail; + auto ptr = client->out->read_ptr(&avail); + if (avail == 0) { + assert(false); + if (!client->in) return false; + client->write_flag = 0; + looper_->modify(client->fd.get(), client->read_flag); + return true; + } + auto ret = io::write(client->fd.get(), ptr, avail); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name, + strerror(errno)); + return false; + } + if (ret == 0) { + return false; + } + client->last = looper_->now(); + client->out->consume(ret); + if (client->out->empty()) { + if (!client->in) return false; + client->write_flag = 0; + looper_->modify(client->fd.get(), client->read_flag); + } + } + return true; +} + +inline char lower_ascii(char c) { + return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c; +} + +bool lower_equal(char const* data, size_t start, size_t end, + std::string const& str) { + assert(start <= end); + if (str.size() != end - start) return false; + for (auto i = str.begin(); start < end; ++start, ++i) { + if (lower_ascii(*i) != lower_ascii(data[start])) return false; + } + return true; +} + +bool header_token_eq(std::string const& value, std::string const& token) { + if (value.empty()) return false; + auto pos = value.find(';'); + if (pos == std::string::npos) pos = value.size(); + return lower_equal(value.data(), 0, pos, token); +} + +void ProxyImpl::client_remote_error(size_t index, uint16_t error) { + assert(false); + auto& client = clients_[index]; + if (client.remote_state > CONNECTED) { + // Already started to return response, too late to do anything + close_client(index); + return; + } + char const* status; + switch (error) { + case 502: + status = "Bad Gateway"; + break; + case 504: + status = "Gateway Timeout"; + break; + default: + assert(false); + error = 500; + status = "Internal Server Error"; + break; + } + + client_error(index, error, status); +} + +void ProxyImpl::client_error(size_t index, uint16_t status_code, + std::string const& status_message) { + auto& client = clients_[index]; + // No more input + client.read_flag = 0; + looper_->modify(client.fd.get(), client.write_flag); + client.url.reset(); + + std::string proto; + Version version; + + if (!client.request) { + proto = "HTTP"; + version.major = 1; + version.minor = 1; + } else { + proto = client.request->proto(); + version = client.request->proto_version(); + } + + auto resp = std::unique_ptr<HttpResponseBuilder>( + HttpResponseBuilder::create( + proto, version, status_code, status_message)); + resp->add_header("Content-Length", "0"); + resp->add_header("Connection", "close"); + + client.in.reset(); + client.request.reset(); + + auto data = resp->build(); + + if (!base_send(&client, data.data(), data.size(), index, "Client")) { + close_client(index); + return; + } +} + +bool ProxyImpl::client_request(size_t index) { + auto& client = clients_[index]; + auto version = client.request->proto_version(); + if (version.major != 1 || version.minor > 1) { + client_error(index, 505, "HTTP Version Not Supported"); + return false; + } + if (client.url->scheme() != "http") { + client_error(index, 501, "Not Implemented"); + return false; + } + if (client.url->userinfo()) { + client_error(index, 400, "Bad request"); + return false; + } + if (client.remote_state == WAITING) { + client.remote_state = CONNECTING; + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag); + client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE); + } else { + assert(client.remote_state == CLOSED); + client.remote.last = looper_->now(); + client.remote.new_connection = true; + client.remote_state = RESOLVING; + + auto port = client.url->port(); + if (port == 0) port = 80; + client.resolve = resolver_->request( + client.url->host(), port, + std::bind(&ProxyImpl::client_remote_resolved, this, index, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3)); + } + return true; +} + +void ProxyImpl::client_remote_resolved(size_t index, int fd, bool connected, + char const* error) { + auto& client = clients_[index]; + assert(client.resolve); + client.resolve = nullptr; + assert(client.remote_state == RESOLVING); + if (fd < 0) { + logger_->out(Logger::INFO, "%zu: Client unable to resolve remote: %s", + index, error); + client_remote_error(index, 502); + return; + } + client.remote_state = CONNECTING; + client.remote.fd.reset(fd); + client.remote.last = looper_->now(); + client.remote.new_connection = false; + client.remote.in.reset(Buffer::create(8192, 1024)); + client.remote.out.reset(Buffer::create(8192, 1024)); + if (connected) { + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = 0; + } else { + client.remote.read_flag = 0; + client.remote.write_flag = Looper::EVENT_WRITE; + } + looper_->add(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag, + std::bind(&ProxyImpl::client_remote_event, this, index, + std::placeholders::_1, + std::placeholders::_2)); + assert(timeout_); + if (connected) { + client_remote_event(index, fd, Looper::EVENT_WRITE); + } +} + +void ProxyImpl::client_event(size_t index, int fd, uint8_t events) { + auto& client = clients_[index]; + assert(client.fd.get() == fd); + if (events & Looper::EVENT_HUP) { + close_client(index); + return; + } + if (events & Looper::EVENT_ERROR) { + logger_->out(Logger::INFO, "%zu: Client connection error", index); + close_client(index); + return; + } + if (!base_event(&client, events, index, "Client")) { + close_client(index); + return; + } + if (client.new_connection) return; + client_empty_input(index); +} + +bool setup_content(Http const* http, Content* content) { + assert(content->type == CONTENT_NONE); + std::string te = http->first_header("transfer-encoding"); + if (te.empty() || header_token_eq(te, "identity")) { + std::string len = http->first_header("content-length"); + if (len.empty()) { + content->type = CONTENT_CLOSE; + return true; + } + char* end = nullptr; + errno = 0; + auto tmp = strtoull(len.c_str(), &end, 10); + if (errno || !end || *end) { + return false; + } + if (tmp == 0) { + content->type = CONTENT_NONE; + return true; + } + content->len = tmp; + content->type = CONTENT_LEN; + } else { + content->chunked.reset(Chunked::create()); + content->type = CONTENT_CHUNKED; + } + return true; +} + +void ProxyImpl::client_empty_input(size_t index) { + auto& client = clients_[index]; + while (true) { + size_t avail; + auto ptr = client.in->read_ptr(&avail); + if (avail == 0) return; + switch (client.content.type) { + case CONTENT_CLOSE: + assert(false); + // falltrough + case CONTENT_NONE: { + if (client.remote_state != CLOSED && client.remote_state != WAITING) { + // Still working on the last request, wait + return; + } + client.request.reset( + HttpRequest::parse( + reinterpret_cast<char const*>(ptr), avail, false)); + if (!client.request) { + if (avail >= 1024 * 1024) { + logger_->out(Logger::INFO, "%zu: Client too large request %zu", + index, avail); + close_client(index); + } + return; + } + if (!client.request->good()) { + client_error(index, 400, "Bad request"); + return; + } + if (client.request->method_equal("CONNECT")) { + client_error(index, 501, "Not Implemented"); + return; + } + client.url.reset(Url::parse(client.request->url())); + if (!client.url) { + client_error(index, 400, "Bad request"); + return; + } + if (!setup_content(client.request.get(), &client.content)) { + logger_->out(Logger::INFO, "%zu: Client bad content-length", index); + client_error(index, 400, "Bad request"); + return; + } + if (client.content.type == CONTENT_CLOSE) { + client.content.type = CONTENT_NONE; + } + if (!client_request(index)) { + client.content.type = CONTENT_NONE; + return; + } + break; + } + case CONTENT_LEN: + if (client.remote_state < CONNECTED) { + // Request hasn't been sent yet, still collecting data + return; + } + if (avail < client.content.len) { + if (!client_send(index, ptr, avail)) { + return; + } + client.in->consume(avail); + client.content.len -= avail; + return; + } + if (!client_send(index, ptr, client.content.len)) { + return; + } + client.in->consume(client.content.len); + client.content.len = 0; + client.content.type = CONTENT_NONE; + break; + case CONTENT_CHUNKED: + if (client.remote_state < CONNECTED) { + // Request hasn't been sent yet, still collecting data + return; + } + auto used = client.content.chunked->add(ptr, avail); + if (!client.content.chunked->good()) { + client_error(index, 400, "Bad request"); + return; + } + if (!client_send(index, ptr, used)) { + return; + } + client.in->consume(used); + if (client.content.chunked->eof()) { + client.content.chunked.reset(); + client.content.type = CONTENT_NONE; + } + break; + } + } +} + +void ProxyImpl::close_client_when_done(size_t index) { + auto& client = clients_[index]; + if (client.out->empty()) { + close_client(index); + return; + } + client.remote_state = CLOSED; + close_base(&client.remote); + client.in.reset(); + client.read_flag = 0; + client.request.reset(); + client.content.type = CONTENT_NONE; + looper_->modify(client.fd.get(), client.write_flag); +} + +void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { + auto& client = clients_[index]; + assert(client.remote.fd.get() == fd); + if (events & Looper::EVENT_HUP) { + logger_->out(Logger::INFO, "%zu: Client remote connection closed", index); + client_remote_error(index, 502); + return; + } + if (events & Looper::EVENT_ERROR) { + logger_->out(Logger::INFO, "%zu: Client remote connection error", index); + client_remote_error(index, 502); + return; + } + if (client.remote_state == CONNECTING) { + if (events & Looper::EVENT_WRITE) { + std::string url(client.url->path_escaped()); + if (url.empty()) url.push_back('/'); + auto query = client.url->full_query_escaped(); + if (query) { + url.push_back('?'); + url.append(query); + } + auto req = std::unique_ptr<HttpRequestBuilder>( + HttpRequestBuilder::create( + client.request->method(), + url, + client.request->proto(), + client.request->proto_version())); + auto iter = client.request->header(); + bool have_host = false; + for (; iter->valid(); iter->next()) { + if (!have_host && iter->name_equal("host")) have_host = true; + if (iter->name_equal("proxy-connection") || + iter->name_equal("proxy-authenticate") || + iter->name_equal("proxy-authorization")) { + continue; + } + req->add_header(iter->name(), iter->value()); + } + if (!have_host && + (client.request->proto_version().major == 1 && + client.request->proto_version().minor == 1)) { + req->add_header("host", client.url->host()); + } + auto data = req->build(); + client.in->consume(client.request->size()); + client.request.reset(); + client.url.reset(); + client.remote.out->write(data.data(), data.size()); + client.remote_state = CONNECTED; + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = Looper::EVENT_WRITE; + looper_->modify(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag); + client_empty_input(index); + } else { + return; + } + } + if (!base_event(&client.remote, events, index, "Client remote")) { + switch (client.remote_state) { + case CONNECTED: + switch (client.remote.content.type) { + case CONTENT_CLOSE: + close_client_when_done(index); + break; + case CONTENT_CHUNKED: + case CONTENT_LEN: + case CONTENT_NONE: + client_remote_error(index, 502); + break; + } + break; + case CONNECTING: + case RESOLVING: + case CLOSED: + assert(false); + client_remote_error(index, 502); + break; + case WAITING: + close_base(&client.remote); + client.remote_state = CLOSED; + break; + } + return; + } + while (true) { + size_t avail; + auto ptr = client.remote.in->read_ptr(&avail); + if (avail == 0) return; + switch (client.remote_state) { + case CONNECTED: + switch (client.remote.content.type) { + case CONTENT_NONE: { + auto response = std::unique_ptr<HttpResponse>( + HttpResponse::parse( + reinterpret_cast<char const*>(ptr), avail, false)); + if (!response) { + if (avail >= 1024 * 1024) { + logger_->out(Logger::INFO, + "%zu: Client remote too large request %zu", + index, avail); + client_remote_error(index, 502); + } + return; + } + if (!response->good()) { + client_remote_error(index, 502); + return; + } + if (!setup_content(response.get(), &client.remote.content)) { + logger_->out(Logger::INFO, "%zu: Client remote bad content-length", + index); + client.remote.content.type = CONTENT_CLOSE; + } else { + if (client.remote.content.type == CONTENT_NONE) { + client.remote_state = WAITING; + client.remote.read_flag = 0; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), 0); + } + } + if (!base_send(&client, ptr, response->size(), index, "Client")) { + return; + } + client.remote.in->consume(response->size()); + break; + } + case CONTENT_LEN: + if (avail < client.remote.content.len) { + if (!base_send(&client, ptr, avail, index, "Client")) { + return; + } + client.remote.in->consume(avail); + client.remote.content.len -= avail; + return; + } + if (!base_send(&client, ptr, client.remote.content.len, + index, "Client")) { + return; + } + client.remote.in->consume(client.remote.content.len); + client.remote.content.len = 0; + client.remote.content.type = CONTENT_NONE; + client.remote_state = WAITING; + client.remote.read_flag = 0; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), 0); + return; + case CONTENT_CHUNKED: { + auto used = client.remote.content.chunked->add(ptr, avail); + if (!client.remote.content.chunked->good()) { + logger_->out(Logger::INFO, "%zu: Client remote bad chunked", + index); + client.remote.content.type = CONTENT_CLOSE; + break; + } + if (!base_send(&client, ptr, used, index, "Client")) { + return; + } + client.remote.in->consume(used); + if (client.remote.content.chunked->eof()) { + client.remote.content.chunked.reset(); + client.remote.content.type = CONTENT_NONE; + logger_->out(Logger::INFO, "%zu: chunked -> waiting", index); + client.remote_state = WAITING; + client.remote.read_flag = 0; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), 0); + } + break; + } + case CONTENT_CLOSE: + if (!base_send(&client, ptr, avail, index, "Client")) { + return; + } + client.remote.in->consume(avail); + break; + } + break; + case CONNECTING: + case RESOLVING: + case CLOSED: + assert(false); + return; + case WAITING: + return; + } + } +} + +bool ProxyImpl::client_send(size_t index, void const* ptr, size_t size) { + auto& client = clients_[index]; + assert(!client.request); + assert(client.remote_state >= CONNECTED); + + return base_send(&client.remote, ptr, size, index, "Client remote"); +} + +void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) { + auto& monitor = monitors_[index]; + assert(monitor.fd.get() == fd); + if (events & Looper::EVENT_HUP) { + close_monitor(index); + return; + } + if (events & Looper::EVENT_ERROR) { + logger_->out(Logger::INFO, "%zu: Monitor connection error", index); + close_monitor(index); + return; + } + if (!base_event(&monitor, events, index, "Monitor")) { + close_monitor(index); + return; + } +} + +void ProxyImpl::new_base(BaseClient* client, int fd) { + client->fd.reset(fd); + client->new_connection = true; + client->read_flag = Looper::EVENT_READ; + client->write_flag = 0; + client->last = looper_->now(); + if (!new_timeout_) { + new_timeout_ = looper_->schedule( + NEW_CONNECTION_TIMEOUT.count(), + std::bind(&ProxyImpl::new_timeout, this)); + } +} + +void ProxyImpl::new_client(int fd, uint8_t events) { + assert(fd == accept_socket_.get()); + if (events == Looper::EVENT_READ) { + assert(!clients_.full()); + while (true) { + int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return; + if (errno == EINTR) continue; + logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno)); + return; + } + auto index = clients_.new_client(); + new_base(&clients_[index], ret); + clients_[index].content.type = CONTENT_NONE; + clients_[index].remote_state = CLOSED; + clients_[index].remote.content.type = CONTENT_NONE; + looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag, + std::bind(&ProxyImpl::client_event, this, + index, + std::placeholders::_1, + std::placeholders::_2)); + break; + } + if (clients_.full()) looper_->modify(fd, 0); + } else { + logger_->out(Logger::ERR, "Accept socket died"); + fatal_error(); + } +} + +void ProxyImpl::new_monitor(int fd, uint8_t events) { + assert(fd == monitor_socket_.get()); + if (events == Looper::EVENT_READ) { + assert(!monitors_.full()); + while (true) { + int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return; + if (errno == EINTR) continue; + logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno)); + return; + } + auto index = monitors_.new_client(); + new_base(&monitors_[index], ret); + looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag, + std::bind(&ProxyImpl::monitor_event, this, + index, + std::placeholders::_1, + std::placeholders::_2)); + break; + } + if (monitors_.full()) { + looper_->modify(fd, 0); + } + } else { + logger_->out(Logger::ERR, "Monitor socket died"); + fatal_error(); + } +} + +void ProxyImpl::fatal_error() { + looper_->quit(); + good_ = false; +} + +bool ProxyImpl::run() { + good_ = true; + if (!logger_) { + priv_logger_.reset(Logger::create_syslog("tp")); + logger_ = priv_logger_.get(); + } + { + struct sigaction action; + memset(&action, 0, sizeof(action)); + action.sa_handler = SIG_IGN; + action.sa_flags = SA_RESTART; + sigaction(SIGPIPE, &action, nullptr); + } + if (!signal_pipe_.open()) { + logger_->out(Logger::WARN, + "Failed to create pipes, signals wont work: %s", + strerror(errno)); + } + if (signal_pipe_) { + looper_->add(signal_pipe_.read(), Looper::EVENT_READ, + std::bind(&ProxyImpl::signal_event, this, + std::placeholders::_1, + std::placeholders::_2)); + } + if (!looper_->run()) { + logger_->out(Logger::ERR, "poll() failed: %s", strerror(errno)); + return false; + } + return good_; +} + +int setup_socket(char const* host, std::string const& port, Logger* logger) { + io::auto_fd ret; + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_PASSIVE; + struct addrinfo* result; + auto retval = getaddrinfo(host, port.c_str(), &hints, &result); + if (retval) { + logger->out(Logger::ERR, "getaddrinfo failed for %s:%s: %s", + host ? host : "*", port.c_str(), gai_strerror(retval)); + return -1; + } + auto retp = result; + for (; retp; retp = retp->ai_next) { + ret.reset(socket(retp->ai_family, retp->ai_socktype, + retp->ai_protocol)); + if (!ret) continue; + if (bind(ret.get(), retp->ai_addr, retp->ai_addrlen) == 0) + break; + } + freeaddrinfo(result); + if (!retp) { + logger->out(Logger::ERR, "Failed to bind %s:%s: %s", + host ? host : "*", port.c_str(), strerror(errno)); + return -1; + } + if (listen(ret.get(), SOMAXCONN)) { + logger->out(Logger::ERR, "Failed to listen: %s", strerror(errno)); + return -1; + } + if (fcntl(ret.get(), F_SETFL, O_NONBLOCK)) { + logger->out(Logger::ERR, "fcntl(O_NONBLOCK) failed: %s", strerror(errno)); + return -1; + } + return ret.release(); +} + +} // namespace + +// static +Proxy* Proxy::create(Config* config, std::string const& cwd, + char const* configfile, + char const* logfile, + Logger* logger, + int accept_fd, + int monitor_fd) { + return new ProxyImpl(config, cwd, configfile, logfile, logger, + accept_fd, monitor_fd); +} + +// static +int Proxy::setup_accept_socket(Config* config, Logger* logger) { + return setup_socket(config->get("proxy_bind", nullptr), + config->get("proxy_port", "8080"), logger); +} + +// static +int Proxy::setup_monitor_socket(Config* config, Logger* logger) { + assert(config->get("monitor", false)); + return setup_socket(config->get("monitor_bind", "localhost"), + config->get("monitor_port", "9000"), logger); +} |
