diff options
| author | Joel Klinghed <the_jk@yahoo.com> | 2017-03-28 22:36:44 +0200 |
|---|---|---|
| committer | Joel Klinghed <the_jk@yahoo.com> | 2017-03-28 22:36:44 +0200 |
| commit | d01e13c9dee53c3ab4faf70a215f4d1dcfed9e87 (patch) | |
| tree | 90975d8502a6c610a58f5d3cd8014bcf8443c0e9 /src | |
| parent | 87774d8981ae7a079492d8949e205065ba72a8e4 (diff) | |
MITM SSL Interception support using mbedtls
Diffstat (limited to 'src')
| -rw-r--r-- | src/.gitignore | 2 | ||||
| -rw-r--r-- | src/Makefile.am | 14 | ||||
| -rw-r--r-- | src/buffer.cc | 2 | ||||
| -rw-r--r-- | src/genca.cc | 69 | ||||
| -rw-r--r-- | src/logger.cc | 21 | ||||
| -rw-r--r-- | src/logger.hh | 3 | ||||
| -rw-r--r-- | src/lru.hh | 70 | ||||
| -rw-r--r-- | src/mitm.cc | 253 | ||||
| -rw-r--r-- | src/mitm.hh | 66 | ||||
| -rw-r--r-- | src/proxy.cc | 160 | ||||
| -rw-r--r-- | src/ssl.cc | 602 | ||||
| -rw-r--r-- | src/ssl.hh | 89 |
12 files changed, 1342 insertions, 9 deletions
diff --git a/src/.gitignore b/src/.gitignore index e4fbae2..7045943 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -1,6 +1,8 @@ /config.h /config.h.in~ /libmonitor.a +/libmitm.a /libtp.a /tp +/tp-genca /tp-monitor diff --git a/src/Makefile.am b/src/Makefile.am index 84c4401..fbcdc67 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -7,9 +7,16 @@ ARFLAGS = cr bin_PROGRAMS = tp tp-monitor noinst_LIBRARIES = libtp.a libmonitor.a +if HAVE_SSL +bin_PROGRAMS += tp-genca +noinst_LIBRARIES += libmitm.a +endif tp_SOURCES = main.cc proxy.cc logger.cc resolver.cc tp_LDADD = libtp.a @THREAD_LIBS@ +if HAVE_SSL +tp_LDADD += libmitm.a @SSL_LIBS@ +endif tp_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \ @@ -17,6 +24,13 @@ libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \ buffer.cc chunked.cc libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"' +libmitm_a_SOURCES = ssl.cc mitm.cc +libmitm_a_CXXFLAGS = $(AM_CXXFLAGS) @SSL_CFLAGS@ + +tp_genca_SOURCES = genca.cc logger.cc +tp_genca_LDADD = libmitm.a libtp.a @SSL_LIBS@ +tp_genca_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' + libmonitor_a_SOURCES = monitor.cc resolver.cc libmonitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ diff --git a/src/buffer.cc b/src/buffer.cc index 4bde195..d220514 100644 --- a/src/buffer.cc +++ b/src/buffer.cc @@ -95,7 +95,7 @@ size_t Buffer::read(void* data, size_t max) { if (avail == 0) return 0; avail = std::min(avail, max); memcpy(data, ptr, avail); - commit(avail); + consume(avail); return avail; } diff --git a/src/genca.cc b/src/genca.cc new file mode 100644 index 0000000..ee0f580 --- /dev/null +++ b/src/genca.cc @@ -0,0 +1,69 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <fstream> +#include <iostream> +#include <memory> +#include <string> + +#include "args.hh" +#include "logger.hh" +#include "ssl.hh" + +namespace { + +bool genca(std::ostream& out, std::string const& name) { + std::unique_ptr<Logger> logger(Logger::create_stderr()); + std::unique_ptr<SSLEntropy> entropy(SSLEntropy::create(logger.get())); + if (!entropy) return false; + std::string key; + if (!SSLKey::generate(logger.get(), entropy.get(), &key)) return false; + std::string cert; + std::unique_ptr<SSLKey> pkey(SSLKey::load(logger.get(), key)); + if (!SSLCert::generate(logger.get(), entropy.get(), nullptr, nullptr, name, + pkey.get(), &cert)) return false; + out << cert << '\n' << key << std::endl; + return true; +} + +} // namespace + +int main(int argc, char** argv) { + std::unique_ptr<Args> args(Args::create()); + args->add('o', "output", "FILE", + "output certificate and key to FILE instead of stdout."); + args->add('n', "name", "NAME", + "Issuer name to use instead of TransparentProxy"); + args->add('h', "help", "display this text and exit."); + args->add('V', "version", "display version and exit."); + if (!args->run(argc, argv)) { + std::cerr << "Try `tp-genca --help` for usage." << std::endl; + return EXIT_FAILURE; + } + if (args->is_set('h')) { + std::cout << "Usage: `tp-genca [OPTIONS...]`\n" + << "Generate a self-signed CA to use for MITM SSL interception.\n" + << '\n'; + args->print_help(); + return EXIT_SUCCESS; + } + if (args->is_set('V')) { + std::cout << "TransparentProxy version " VERSION + << " written by Joel Klinghed <the_jk@yahoo.com>" << std::endl; + return EXIT_SUCCESS; + } + if (!args->arguments().empty()) { + std::cerr << "Too many arguments.\n" + << "Try `tp-genca --help` for usage." << std::endl; + return EXIT_FAILURE; + } + auto name = args->arg('n', "TransparentProxy"); + auto output = args->arg('o', nullptr); + if (output) { + std::ofstream out(output); + return genca(out, name) ? EXIT_SUCCESS : EXIT_FAILURE; + } else { + return genca(std::cout, name) ? EXIT_SUCCESS : EXIT_FAILURE; + } +} diff --git a/src/logger.cc b/src/logger.cc index 29dcb38..343bb4d 100644 --- a/src/logger.cc +++ b/src/logger.cc @@ -21,7 +21,12 @@ namespace { class LoggerStdErr : public Logger { public: - void out(Level UNUSED(lvl), char const* format, ...) override { + void out(Level lvl, char const* format, ...) override { + if (lvl == DBG) { +#ifdef NDEBUG + return; +#endif + } char* tmp; va_list args; va_start(args, format); @@ -44,6 +49,11 @@ public: } void out(Level lvl, char const* format, ...) override { + if (lvl == DBG) { +#ifdef NDEBUG + return; +#endif + } va_list args; va_start(args, format); vsyslog(lvl2prio(lvl), format, args); @@ -59,6 +69,8 @@ private: return LOG_WARNING; case INFO: return LOG_INFO; + case DBG: + return LOG_DEBUG; } assert(false); return LOG_INFO; @@ -82,6 +94,11 @@ public: } void out(Level lvl, char const* format, ...) override { + if (lvl == DBG) { +#ifdef NDEBUG + return; +#endif + } fputs(lvl2str(lvl), fh_); fwrite(": ", 1, 2, fh_); va_list args; @@ -100,6 +117,8 @@ private: return "Warning"; case INFO: return "Info"; + case DBG: + return "Debug"; } assert(false); return "Info"; diff --git a/src/logger.hh b/src/logger.hh index 8b0db05..3700635 100644 --- a/src/logger.hh +++ b/src/logger.hh @@ -12,7 +12,8 @@ public: enum Level { ERR, WARN, - INFO + INFO, + DBG }; static Logger* create_stderr(); diff --git a/src/lru.hh b/src/lru.hh new file mode 100644 index 0000000..3279b35 --- /dev/null +++ b/src/lru.hh @@ -0,0 +1,70 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef LRU_HH +#define LRU_HH + +#include <algorithm> +#include <deque> +#include <unordered_map> + +template<typename K, typename V> +class lru { +public: + lru(size_t size) + : size_(size) { + } + + bool empty() const { + return usage_.empty(); + } + + size_t size() const { + return usage_.size(); + } + + V* get(K const& key) { + auto it = data_.find(key); + if (it == data_.end()) return nullptr; + auto it2 = std::find(usage_.begin(), usage_.end(), key); + if (it2 != usage_.begin()) { + usage_.erase(it2); + usage_.push_front(key); + } + return &it->second; + } + + void insert(K const& key, V const& data) { + auto pair = data_.insert(std::make_pair(key, data)); + if (pair.second) { + usage_.push_front(key); + shrink_to_size(); + } else { + pair.first->second = data; + auto it = std::find(usage_.begin(), usage_.end(), key); + if (it != usage_.begin()) { + usage_.erase(it); + usage_.push_front(key); + } + } + } + + void erase(K const& key) { + if (data_.erase(key)) { + usage_.erase(std::find(usage_.begin(), usage_.end(), key)); + } + } + +private: + void shrink_to_size() { + if (usage_.size() > size_) { + data_.erase(usage_.back()); + usage_.pop_back(); + } + } + + size_t const size_; + std::deque<K> usage_; + std::unordered_map<K, V> data_; +}; + +#endif // LRU_HH diff --git a/src/mitm.cc b/src/mitm.cc new file mode 100644 index 0000000..7e0fd19 --- /dev/null +++ b/src/mitm.cc @@ -0,0 +1,253 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <fstream> +#include <memory> + +#include "config.hh" +#include "buffer.hh" +#include "paths.hh" +#include "logger.hh" +#include "lru.hh" +#include "mitm.hh" +#include "ssl.hh" + +namespace { + +class ConnectionImpl : public Mitm::Connection { +public: + ConnectionImpl(Logger* logger) + : logger_(logger){ + } + + bool connect(SSLEntropy* entropy, SSLCertStore* store, + std::string const& cert, SSLKey* key, + bool unsecure, std::string const& host) { + uint16_t flags = unsecure ? SSL::UNSECURE : 0; + cert_.reset(SSLCert::load(logger_, cert)); + if (!cert_) return false; + local_.reset(SSL::server(logger_, entropy, cert_.get(), key, flags)); + if (!local_) return false; + remote_.reset(SSL::client(logger_, entropy, store, host, flags)); + if (!remote_) return false; + + in_.reset(Buffer::create(8192, 1024)); + out_.reset(Buffer::create(8192, 1024)); + + return true; + } + + bool transfer(Buffer* local_in, Buffer* local_out, + Buffer* remote_in, Buffer* remote_out, + Mitm::Monitor* monitor) override { + bool local_active = true, remote_active = true; + while (local_active || remote_active) { + local_active = false; + if (local_) { + size_t avail; + out_->read_ptr(&avail); + switch (local_->transfer(local_in, local_out, in_.get(), out_.get())) { + case SSL::NO_ERR: { + size_t avail2; + auto ptr = out_->read_ptr(&avail2); + if (avail < avail2) { + if (monitor) { + monitor->local2remote( + static_cast<char const*>(ptr) + avail, avail2 - avail); + } + local_active = true; + } + break; + } + case SSL::ERR: + return false; + case SSL::CLOSED: + local_.reset(); + break; + } + } + remote_active = false; + if (remote_) { + size_t avail; + in_->read_ptr(&avail); + switch (remote_->transfer(remote_in, remote_out, + out_.get(), in_.get())) { + case SSL::NO_ERR: { + size_t avail2; + auto ptr = in_->read_ptr(&avail2); + if (avail < avail2) { + if (monitor) { + monitor->remote2local( + static_cast<char const*>(ptr) + avail, avail2 - avail); + } + remote_active = true; + } + break; + } + case SSL::ERR: + return false; + case SSL::CLOSED: + remote_.reset(); + break; + } + } + } + return true; + } + + bool local_eof() const override { + return !local_; + } + + bool remote_eof() const override { + return !remote_; + } + + void close_local() override { + if (local_) { + local_->close(); + } + } + void close_remote() override { + if (remote_) { + remote_->close(); + } + } + +private: + Logger* const logger_; + std::unique_ptr<SSLCert> cert_; + std::unique_ptr<SSL> local_; + std::unique_ptr<SSL> remote_; + std::unique_ptr<Buffer> in_; + std::unique_ptr<Buffer> out_; +}; + +class MitmImpl : public Mitm { +public: + MitmImpl(Logger* logger) + : logger_(logger), unsecure_(false), cache_(42) { + } + + ~MitmImpl() override { + } + + bool load(Config* config, std::string const& cwd) { + entropy_.reset(SSLEntropy::create(logger_)); + if (!entropy_) return false; + store_.reset(SSLCertStore::create( + logger_, config->get("ssl_cert_bundle", ""))); + if (!store_) return false; + std::string ca_cert, ca_key; + if (!load_file(config->get("ssl_ca_cert", ""), cwd, &ca_cert)) return false; + if (!load_file(config->get("ssl_ca_key", ""), cwd, &ca_key)) return false; + unsecure_ = config->get("ssl_unsecure", false); + issuer_cert_.reset(SSLCert::load(logger_, ca_cert)); + if (!issuer_cert_) return false; + issuer_key_.reset(SSLKey::load(logger_, ca_key)); + if (!issuer_key_) return false; + return true; + } + + bool reload(Config* config, std::string const& cwd) override { + std::unique_ptr<SSLCertStore> store( + SSLCertStore::create(logger_, config->get("ssl_cert_bundle", ""))); + if (!store) return false; + std::string ca_cert, ca_key; + if (!load_file(config->get("ssl_ca_cert", ""), cwd, &ca_cert)) return false; + if (!load_file(config->get("ssl_ca_key", ""), cwd, &ca_key)) return false; + unsecure_ = config->get("ssl_unsecure", false); + std::unique_ptr<SSLCert> issuer_cert(SSLCert::load(logger_, ca_cert)); + if (!issuer_cert) return false; + std::unique_ptr<SSLKey> issuer_key(SSLKey::load(logger_, ca_key)); + if (!issuer_key) return false; + store_.swap(store); + issuer_cert_.swap(issuer_cert); + issuer_key_.swap(issuer_key); + return true; + } + + DetectResult detect(void const* data, size_t avail) override { + if (avail == 0) return NEED_MORE; + auto d = static_cast<uint8_t const*>(data); + // SSL 3.0 + if (d[0] == 0x16) { + if (avail < 2) return NEED_MORE; + if (d[1] != 0x03) return OTHER; // Need fixing when SSL 4.0 shows up + if (avail < 5) return NEED_MORE; + if (((d[3] << 8) | d[4]) < 9) // Min size of client hello + return OTHER; + if (avail < 6) return NEED_MORE; + return d[5] == 0x01 ? SSL : OTHER; + } + // SSL 2.0 + if (!(d[0] & 0x80)) return OTHER; + if (avail < 2) return NEED_MORE; + if (((d[0] & 0x7f) << 8 | d[1]) < 9) // Min size of client hello + return OTHER; + if (avail < 3) return NEED_MORE; + return d[2] == 0x01 ? SSL : OTHER; + } + + Connection* open(std::string const& host) override { + CacheEntry* entry; + CacheEntry _entry; + entry = cache_.get(host); + if (!entry) { + entry = &_entry; + if (!SSLCert::generate(logger_, entropy_.get(), + issuer_cert_.get(), issuer_key_.get(), host, + issuer_key_.get(), &entry->cert)) { + return nullptr; + } + cache_.insert(host, _entry); + } + std::unique_ptr<ConnectionImpl> conn(new ConnectionImpl(logger_)); + if (!conn->connect(entropy_.get(), store_.get(), entry->cert, + issuer_key_.get(), unsecure_, host)) return nullptr; + return conn.release(); + } + +private: + struct CacheEntry { + std::string cert; + }; + + bool load_file(std::string const& path, std::string const& cwd, + std::string* out) const { + std::ifstream in(Paths::join(cwd, path)); + if (!in.good()) { + logger_->out(Logger::ERR, "Unable to open: %s", path.c_str()); + return false; + } + out->clear(); + while (in.good()) { + char buffer[8192]; + in.read(buffer, sizeof(buffer)); + out->append(buffer, in.gcount()); + } + if (!in.eof()) { + logger_->out(Logger::ERR, "Unable to open: %s", path.c_str()); + return false; + } + return true; + } + + Logger* const logger_; + std::unique_ptr<SSLEntropy> entropy_; + std::unique_ptr<SSLCertStore> store_; + std::unique_ptr<SSLCert> issuer_cert_; + std::unique_ptr<SSLKey> issuer_key_; + bool unsecure_; + lru<std::string, CacheEntry> cache_; +}; + +} // namespace + +// static +Mitm* Mitm::create(Logger* logger, Config* config, std::string const& cwd) { + std::unique_ptr<MitmImpl> mitm(new MitmImpl(logger)); + if (!mitm->load(config, cwd)) return nullptr; + return mitm.release(); +} diff --git a/src/mitm.hh b/src/mitm.hh new file mode 100644 index 0000000..6d79e8f --- /dev/null +++ b/src/mitm.hh @@ -0,0 +1,66 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef MITM_HH +#define MITM_HH + +#include <cstddef> +#include <string> + +class Buffer; +class Config; +class Logger; + +class Mitm { +public: + virtual ~Mitm() {} + + static Mitm* create(Logger* logger, Config* config, std::string const& cwd); + + virtual bool reload(Config* config, std::string const& cwd) = 0; + + enum DetectResult { + SSL, + OTHER, + NEED_MORE, + }; + virtual DetectResult detect(void const* data, size_t avail) = 0; + + class Monitor { + public: + virtual ~Monitor() {} + + virtual void local2remote(void const* data, size_t size) = 0; + virtual void remote2local(void const* data, size_t size) = 0; + + protected: + Monitor() {} + }; + + class Connection { + public: + virtual ~Connection() {} + + virtual bool transfer( + Buffer* local_in, Buffer* local_out, + Buffer* remote_in, Buffer* remote_out, + Monitor* monitor) = 0; + + virtual bool local_eof() const = 0; + virtual bool remote_eof() const = 0; + + virtual void close_local() = 0; + virtual void close_remote() = 0; + + protected: + Connection() {} + Connection(Connection const&) = delete; + }; + + virtual Connection* open(std::string const& host) = 0; + +protected: + Mitm() {} + Mitm(Mitm const&) = delete; +}; + +#endif // MITM_HH diff --git a/src/proxy.cc b/src/proxy.cc index 5abe257..fdb6df0 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -27,6 +27,7 @@ #include "io.hh" #include "logger.hh" #include "looper.hh" +#include "mitm.hh" #include "resolver.hh" #include "paths.hh" #include "proxy.hh" @@ -217,6 +218,9 @@ struct Content { struct Connect { std::string host; uint16_t port; + Mitm::DetectResult mitm_detect; + std::unique_ptr<Mitm::Connection> mitm; + std::unique_ptr<Mitm::Monitor> mitm_monitor; }; enum RemoteState { @@ -290,6 +294,27 @@ public: bool run() override; private: + class MitmMonitor : public Mitm::Monitor { + public: + MitmMonitor(ProxyImpl* proxy, uint32_t local_pkg_id, uint32_t remote_pkg_id) + : proxy_(proxy), local_pkg_id_(local_pkg_id), + remote_pkg_id_(remote_pkg_id) { + assert(local_pkg_id_ && remote_pkg_id_); + } + + void local2remote(void const* data, size_t size) override { + proxy_->send_attached_data(local_pkg_id_, data, size, false); + } + void remote2local(void const* data, size_t size) override { + proxy_->send_attached_data(remote_pkg_id_, data, size, false); + } + + private: + ProxyImpl* const proxy_; + uint32_t const local_pkg_id_; + uint32_t const remote_pkg_id_; + }; + void setup(); bool reload_config(); void fatal_error(); @@ -312,6 +337,8 @@ private: std::chrono::duration<float> const& timeout); bool base_send(BaseClient* client, void const* data, size_t size, size_t index, char const* name); + bool base_continue_send(BaseClient* client, bool was_empty, + size_t index, char const* name); void client_error(size_t index, uint16_t status_code, std::string const& status); bool client_request(size_t index); @@ -354,6 +381,7 @@ private: io::auto_pipe signal_pipe_; std::unique_ptr<Looper> looper_; std::unique_ptr<Resolver> resolver_; + std::unique_ptr<Mitm> mitm_; bool good_; void* new_timeout_; void* timeout_; @@ -385,6 +413,10 @@ void ProxyImpl::setup() { monitor_send_proxied_ = !config_->get("monitor_proxy_request", false); clients_.resize(get_size(config_, logger_, "max_clients", 1024)); monitors_.resize(get_size(config_, logger_, "max_monitors", 2)); + mitm_.reset(Mitm::create(logger_, config_, cwd_)); + if (mitm_) { + logger_->out(Logger::INFO, "MITM SSL Interception active"); + } looper_->add(accept_socket_.get(), clients_.full() ? 0 : Looper::EVENT_READ, std::bind(&ProxyImpl::new_client, this, @@ -437,6 +469,13 @@ bool ProxyImpl::reload_config() { logger_ = priv_logger_.get(); } } + if (mitm_) { + if (!mitm_->reload(config_, cwd_)) { + logger_->out(Logger::WARN, "Invalid mitm config, ignored"); + } + } else { + mitm_.reset(Mitm::create(logger_, config_, cwd_)); + } auto const old_accept = accept_socket_.get(); auto const old_monitor = monitor_socket_.get(); looper_->remove(old_accept); @@ -653,6 +692,37 @@ bool ProxyImpl::base_send(BaseClient* client, void const* data, size_t size, return true; } +bool ProxyImpl::base_continue_send(BaseClient* client, bool was_empty, + size_t index, char const* name) { + if (client->out->empty()) return true; + if (!was_empty) { + // Already waiting for write event + return true; + } + size_t avail; + auto ptr = client->out->read_ptr(&avail); + auto ret = io::write(client->fd.get(), ptr, avail); + if (ret == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name, + strerror(errno)); + return false; + } + } else { + client->out->consume(ret); + if (static_cast<size_t>(ret) == avail) { + if (!client->in) { + // If input is closed, close after sending all data + return false; + } + return true; + } + } + client->write_flag = Looper::EVENT_WRITE; + looper_->modify(client->fd.get(), client->read_flag | client->write_flag); + return true; +} + bool ProxyImpl::base_event(BaseClient* client, uint8_t events, size_t index, char const* name) { if (events & Looper::EVENT_READ) { @@ -964,14 +1034,54 @@ void ProxyImpl::client_empty_input(size_t index) { // falltrough case CONTENT_NONE: { if (client.connect && client.remote_state == CONNECTED) { - if (client.pkg_id != 0) { - send_attached_data(client.pkg_id, ptr, avail, false); + if (client.connect->mitm_detect == Mitm::NEED_MORE) { + client.connect->mitm_detect = mitm_->detect(ptr, avail); + if (client.connect->mitm_detect == Mitm::NEED_MORE) { + // Need more data + return; + } + if (client.connect->mitm_detect == Mitm::SSL) { + client.connect->mitm.reset(mitm_->open(client.connect->host)); + if (!client.connect->mitm) { + client.connect->mitm_detect = Mitm::OTHER; + } else if (client.pkg_id != 0) { + client.connect->mitm_monitor.reset( + new MitmMonitor(this, client.pkg_id, client.remote.pkg_id)); + } + } } - if (!base_send(&client.remote, ptr, avail, index, "Client remote")) { - return; + if (client.connect->mitm) { + bool local_empty = client.out->empty(); + bool remote_empty = client.remote.out->empty(); + // TODO(the_jk): Monitor + if (!client.connect->mitm->transfer( + client.in.get(), client.out.get(), + client.remote.in.get(), client.remote.out.get(), + client.connect->mitm_monitor.get())) { + close_client(index); + return; + } + if (!base_continue_send(&client, local_empty, index, "Client")) { + close_client(index); + return; + } + if (!base_continue_send(&client.remote, remote_empty, index, + "Client remote")) { + client_remote_error(index, 502); + return; + } + break; + } else { + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, ptr, avail, false); + } + if (!base_send(&client.remote, ptr, avail, index, "Client remote")) { + client_remote_error(index, 502); + return; + } + client.in->consume(avail); + break; } - client.in->consume(avail); - break; } if (client.remote_state != CLOSED && client.remote_state != WAITING) { // Still working on the last request, wait @@ -1003,6 +1113,11 @@ void ProxyImpl::client_empty_input(size_t index) { client_error(index, 400, "Bad request"); return; } + if (mitm_) { + client.connect->mitm_detect = Mitm::NEED_MORE; + } else { + client.connect->mitm_detect = Mitm::OTHER; + } } else { client.url.reset(Url::parse(client.request->url())); if (!client.url) { @@ -1132,6 +1247,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { client.source_host, client.source_port, false); } if (!base_send(&client, data.data(), data.size(), index, "Client")) { + close_client(index); return; } events &= ~Looper::EVENT_WRITE; @@ -1280,6 +1396,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { } } if (!base_send(&client, ptr, response->size(), index, "Client")) { + close_client(index); return; } client.remote.in->consume(response->size()); @@ -1291,6 +1408,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { send_attached_data(client.remote.pkg_id, ptr, avail, false); } if (!base_send(&client, ptr, avail, index, "Client")) { + close_client(index); return; } client.remote.in->consume(avail); @@ -1304,6 +1422,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { } if (!base_send(&client, ptr, client.remote.content.len, index, "Client")) { + close_client(index); return; } client.remote.in->consume(client.remote.content.len); @@ -1326,6 +1445,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { send_attached_data(client.remote.pkg_id, ptr, used, false); } if (!base_send(&client, ptr, used, index, "Client")) { + close_client(index); return; } client.remote.in->consume(used); @@ -1344,10 +1464,38 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { break; } case CONTENT_CLOSE: + if (client.connect) { + if (client.connect->mitm_detect == Mitm::NEED_MORE) { + // Need more data (on client side) + return; + } + if (client.connect->mitm) { + bool local_empty = client.out->empty(); + bool remote_empty = client.remote.out->empty(); + // TODO(the_jk): Monitor + if (!client.connect->mitm->transfer( + client.in.get(), client.out.get(), + client.remote.in.get(), client.remote.out.get(), + client.connect->mitm_monitor.get())) { + close_client(index); + return; + } + if (!base_continue_send(&client, local_empty, index, "Client")) { + return; + } + if (!base_continue_send(&client.remote, remote_empty, index, + "Client remote")) { + client_remote_error(index, 502); + return; + } + break; + } + } if (client.remote.pkg_id != 0) { send_attached_data(client.remote.pkg_id, ptr, avail, false); } if (!base_send(&client, ptr, avail, index, "Client")) { + close_client(index); return; } client.remote.in->consume(avail); diff --git a/src/ssl.cc b/src/ssl.cc new file mode 100644 index 0000000..3395d83 --- /dev/null +++ b/src/ssl.cc @@ -0,0 +1,602 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <mbedtls/certs.h> +#include <mbedtls/ctr_drbg.h> +#include <mbedtls/entropy.h> +#include <mbedtls/error.h> +#include <mbedtls/pk.h> +#include <mbedtls/rsa.h> +#include <mbedtls/ssl.h> +#include <mbedtls/x509.h> +#include <memory> + +#include "buffer.hh" +#include "logger.hh" +#include "ssl.hh" + +namespace { + +void logout(Logger* logger, Logger::Level lvl, int errnum, + const std::string& msg) { + char tmp[256]; + mbedtls_strerror(errnum, tmp, sizeof(tmp)); + logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp); +} + +inline void loginfo(Logger* logger, int errnum, const std::string& msg) { + logout(logger, Logger::INFO, errnum, msg); +} + +inline void logerr(Logger* logger, int errnum, const std::string& msg) { + logout(logger, Logger::ERR, errnum, msg); +} + +class SSLEntropyImpl : public SSLEntropy { +public: + SSLEntropyImpl() { + mbedtls_entropy_init(&entropy_); + mbedtls_ctr_drbg_init(&ctr_drbg_); + } + + ~SSLEntropyImpl() override { + mbedtls_ctr_drbg_free(&ctr_drbg_); + mbedtls_entropy_free(&entropy_); + } + + bool init(Logger* logger) { + auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, + &entropy_, nullptr, 0); + if (ret) { + logerr(logger, ret, "Error seeding entropy"); + return false; + } + return true; + } + + void setup(mbedtls_ssl_config* conf) { + mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_); + } + + mbedtls_ctr_drbg_context* random() { + return &ctr_drbg_; + } + +private: + mbedtls_entropy_context entropy_; + mbedtls_ctr_drbg_context ctr_drbg_; +}; + +class SSLCertStoreImpl : public SSLCertStore { +public: + SSLCertStoreImpl() { + mbedtls_x509_crt_init(&chain_); + } + + ~SSLCertStoreImpl() override { + mbedtls_x509_crt_free(&chain_); + } + + bool init(Logger* logger, std::string const& path) { + auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str()); + if (ret) { + logerr(logger, ret, "Error loeading cert store bundle"); + return false; + } + return true; + } + + mbedtls_x509_crt* chain() { + return &chain_; + } + +private: + mbedtls_x509_crt chain_; +}; + +class SSLKeyImpl : public SSLKey { +public: + SSLKeyImpl() { + mbedtls_pk_init(&key_); + } + + ~SSLKeyImpl() override { + mbedtls_pk_free(&key_); + } + + bool load(Logger* logger, std::string const& data) { + auto ret = mbedtls_pk_parse_key( + &key_, + reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1, + nullptr, 0); + if (ret) { + logerr(logger, ret, "Error parsing key"); + return false; + } + return true; + } + + mbedtls_pk_context* key() { + return &key_; + } + +private: + mbedtls_pk_context key_; +}; + +class SSLCertImpl : public SSLCert { +public: + SSLCertImpl() { + mbedtls_x509_crt_init(&cert_); + } + + ~SSLCertImpl() override { + mbedtls_x509_crt_free(&cert_); + } + + bool load(Logger* logger, std::string const& data) { + auto ret = mbedtls_x509_crt_parse( + &cert_, + reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1); + if (ret < 0) { + logerr(logger, ret, "Error parsing cert"); + return false; + } + return true; + } + + mbedtls_x509_crt* cert() { + return &cert_; + } + +private: + mbedtls_x509_crt cert_; +}; + +template<typename SetupData> +class SSLImpl : public SSL { +public: + SSLImpl(Logger* logger, uint16_t flags) + : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr), + result_(NO_ERR), state_(HANDSHAKE) { + } + + ~SSLImpl() override { + mbedtls_ssl_free(&ssl_); + mbedtls_ssl_config_free(&conf_); + } + + bool unsecure() const { + return flags_ & UNSECURE; + } + + TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, + Buffer* data_in, Buffer* data_out) override { + if (result_) return result_; + + in_ = ssl_in; + out_ = ssl_out; + + result_ = transfer(data_in, data_out); + + in_ = nullptr; + out_ = nullptr; + + return result_; + } + + void close() override { + state_ = CLOSE; + } + +protected: + void init(SSLEntropy* entropy, int endpoint, SetupData const& data) { + mbedtls_ssl_init(&ssl_); + mbedtls_ssl_config_init(&conf_); + + auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret) { + logerr(logger_, ret, "Error configuring SSL"); + result_ = ERR; + return; + } + + static_cast<SSLEntropyImpl*>(entropy)->setup(&conf_); + + if (!before_setup(data)) { + result_ = ERR; + return; + } + + ret = mbedtls_ssl_setup(&ssl_, &conf_); + if (ret) { + logerr(logger_, ret, "Failed to setup SSL"); + result_ = ERR; + return; + } + + mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr); + + if (!after_setup(data)) { + result_ = ERR; + return; + } + + result_ = NO_ERR; + } + + virtual bool before_setup(SetupData const& UNUSED(data)) { + return true; + } + + virtual bool after_setup(SetupData const& UNUSED(data)) { + return true; + } + + Logger* const logger_; + mbedtls_ssl_context ssl_; + mbedtls_ssl_config conf_; + +private: + enum State { + HANDSHAKE, + TRANSFER, + CLOSE + }; + + TransferResult transfer(Buffer* in, Buffer* out) { + bool want_read = false, want_write = false; + while (!want_read || !want_write) { + switch (state_) { + case HANDSHAKE: { + auto ret = mbedtls_ssl_handshake(&ssl_); + if (ret) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return NO_ERR; + } + loginfo(logger_, ret, "Handshake failed"); + return ERR; + } + state_ = TRANSFER; + break; + } + case CLOSE: { + auto ret = mbedtls_ssl_close_notify(&ssl_); + if (ret) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return NO_ERR; + } + loginfo(logger_, ret, "Close notify failed"); + return ERR; + } + return CLOSED; + } + case TRANSFER: { + size_t avail; + auto wptr = out->write_ptr(&avail); + if (avail > 0) { + auto ret = mbedtls_ssl_read(&ssl_, + reinterpret_cast<unsigned char*>(wptr), + avail); + if (ret > 0) { + out->commit(ret); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + want_read = true; + } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + return CLOSED; + } else { + loginfo(logger_, ret, "SSL read failed"); + return ERR; + } + } else { + assert(false); + want_read = true; + } + auto rptr = in->read_ptr(&avail); + if (avail > 0) { + auto ret = mbedtls_ssl_write( + &ssl_, + reinterpret_cast<unsigned char const*>(rptr), + avail); + if (ret > 0) { + in->consume(ret); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + want_write = true; + } else { + loginfo(logger_, ret, "SSL write failed"); + return ERR; + } + } else { + want_write = true; + } + break; + } + } + } + return NO_ERR; + } + + int send(unsigned char const* buf, size_t len) { + if (!out_) { + assert(false); + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + out_->write(buf, len); + return len; + } + int recv(unsigned char* buf, size_t len) { + if (!in_) { + assert(false); + return MBEDTLS_ERR_SSL_WANT_READ; + } + auto ret = in_->read(buf, len); + if (ret == 0) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return ret; + } + static int ssl_send(void* ctx, unsigned char const* buf, size_t len) { + return reinterpret_cast<SSLImpl*>(ctx)->send(buf, len); + } + static int ssl_recv(void* ctx, unsigned char* buf, size_t len) { + return reinterpret_cast<SSLImpl*>(ctx)->recv(buf, len); + } + + uint16_t const flags_; + Buffer* in_; + Buffer* out_; + TransferResult result_; + State state_; +}; + +struct CertKey { + SSLCert* cert; + SSLKey* key; + + CertKey(SSLCert* cert, SSLKey* key) + : cert(cert), key(key) { + } +}; + +class SSLServerImpl : public SSLImpl<CertKey> { +public: + SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert, + SSLKey* key, uint16_t flags) + : SSLImpl(logger, flags) { + SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key)); + } + + ~SSLServerImpl() override { + } + +private: + bool before_setup(CertKey const& data) override { + mbedtls_ssl_conf_ca_chain( + &conf_, static_cast<SSLCertImpl*>(data.cert)->cert()->next, nullptr); + auto ret = mbedtls_ssl_conf_own_cert( + &conf_, + static_cast<SSLCertImpl*>(data.cert)->cert(), + static_cast<SSLKeyImpl*>(data.key)->key()); + if (ret) { + logerr(logger_, ret, "Server: Error setting certificate"); + return false; + } + + mbedtls_ssl_conf_min_version(&conf_, + MBEDTLS_SSL_MAJOR_VERSION_3, + unsecure() ? MBEDTLS_SSL_MINOR_VERSION_0 : + MBEDTLS_SSL_MINOR_VERSION_1); + return true; + } +}; + +struct HostStore { + std::string const& host; + SSLCertStore* const store; + + HostStore(std::string const& host, SSLCertStore* store) + : host(host), store(store) { + } +}; + +class SSLClientImpl : public SSLImpl<HostStore> { +public: + SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) + : SSLImpl(logger, flags) { + SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store)); + } + +private: + bool before_setup(HostStore const& data) override { + mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL + : MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_ssl_conf_ca_chain( + &conf_, static_cast<SSLCertStoreImpl*>(data.store)->chain(), nullptr); + return true; + } + + bool after_setup(HostStore const& data) override { + auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str()); + if (ret) { + logerr(logger_, ret, "Failed to set hostname"); + return false; + } + return true; + } +}; + +} // namespace + +// static +SSLEntropy* SSLEntropy::create(Logger* logger) { + std::unique_ptr<SSLEntropyImpl> entropy(new SSLEntropyImpl()); + if (!entropy->init(logger)) return nullptr; + return entropy.release(); +} + +// static +SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) { + std::unique_ptr<SSLCertStoreImpl> store(new SSLCertStoreImpl()); + if (!store->init(logger, path)) return nullptr; + return store.release(); +} + +// static +SSLKey* SSLKey::load(Logger* logger, std::string const& data) { + std::unique_ptr<SSLKeyImpl> key(new SSLKeyImpl()); + if (!key->load(logger, data)) return nullptr; + return key.release(); +} + +// static +bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) { + mbedtls_pk_context pk; + bool ok = false; + unsigned char buffer[16000]; + mbedtls_pk_init(&pk); + + auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); + if (ret) { + logerr(logger, ret, "Error setting up key"); + goto error; + } + + ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk), + mbedtls_ctr_drbg_random, + static_cast<SSLEntropyImpl*>(entropy)->random(), + 4096, 65537); + if (ret) { + logerr(logger, ret, "Error generating key"); + goto error; + } + + ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer)); + if (ret) { + logerr(logger, ret, "Error writing key"); + goto error; + } + key->assign(reinterpret_cast<char*>(buffer)); + + ok = true; + error: + mbedtls_pk_free(&pk); + return ok; +} + +// static +SSLCert* SSLCert::load(Logger* logger, std::string const& data) { + std::unique_ptr<SSLCertImpl> key(new SSLCertImpl()); + if (!key->load(logger, data)) return nullptr; + return key.release(); +} + +// static +bool SSLCert::generate(Logger* logger, SSLEntropy* entropy, + SSLCert* issuer_cert, SSLKey* issuer_key, + std::string const& host, SSLKey* key, + std::string* cert) { + mbedtls_x509write_cert crt; + char issuer_name[256]; + std::string subject; + bool ok = false; + time_t now, tmp; + struct tm tm; + char not_before[20]; + char not_after[20]; + unsigned char buffer[16000]; + mbedtls_x509write_crt_init(&crt); + mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256); + + if (key) { + mbedtls_x509write_crt_set_subject_key( + &crt, static_cast<SSLKeyImpl*>(key)->key()); + } + if (issuer_key) { + mbedtls_x509write_crt_set_issuer_key( + &crt, static_cast<SSLKeyImpl*>(issuer_key)->key()); + } + + subject = "CN=" + host; + auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str()); + if (ret) { + logerr(logger, ret, "Invalid subject name"); + goto error; + } + + if (issuer_cert) { + ret = mbedtls_x509_dn_gets( + issuer_name, sizeof(issuer_name), + &static_cast<SSLCertImpl*>(issuer_cert)->cert()->subject); + if (ret < 0) { + logerr(logger, ret, "Unable to get issuer name"); + goto error; + } + ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name); + } else { + ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str()); + } + if (ret) { + logerr(logger, ret, "Invalid issuer name"); + goto error; + } + now = time(nullptr); + tmp = now - (24 * 60 * 60 - 1); + gmtime_r(&tmp, &tm); + strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm); + tmp = now + (30 * 24 * 60 * 60); + gmtime_r(&tmp, &tm); + strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm); + ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after); + if (ret) { + logerr(logger, ret, "Unable to set validity"); + goto error; + } + if (issuer_cert) { + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1); + } else { + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1); + } + if (ret) { + logerr(logger, ret, "Unable to set basic constraints"); + goto error; + } + + ret = mbedtls_x509write_crt_pem( + &crt, buffer, sizeof(buffer), + mbedtls_ctr_drbg_random, + static_cast<SSLEntropyImpl*>(entropy)->random()); + if (ret) { + logerr(logger, ret, "Error writing cert"); + goto error; + } + cert->assign(reinterpret_cast<char*>(buffer)); + + ok = true; + error: + mbedtls_x509write_crt_free(&crt); + return ok; +} + +// static +SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert, + SSLKey* key, uint16_t flags) { + return new SSLServerImpl(logger, entropy, cert, key, flags); +} + +// static +SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) { + return new SSLClientImpl(logger, entropy, store, host, flags); +} + +// static +const uint16_t SSL::UNSECURE = 0x01; + diff --git a/src/ssl.hh b/src/ssl.hh new file mode 100644 index 0000000..1cd6aea --- /dev/null +++ b/src/ssl.hh @@ -0,0 +1,89 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef SSL_HH +#define SSL_HH + +#include <string> + +class Buffer; +class Logger; + +class SSLEntropy { +public: + virtual ~SSLEntropy() {} + + static SSLEntropy* create(Logger* logger); + +protected: + SSLEntropy() {} + SSLEntropy(SSLEntropy const&) = delete; +}; + +class SSLCertStore { +public: + virtual ~SSLCertStore() {} + + static SSLCertStore* create(Logger* logger, std::string const& bundle); + +protected: + SSLCertStore() {} + SSLCertStore(SSLCertStore const&) = delete; +}; + +class SSLKey { +public: + virtual ~SSLKey() {} + static bool generate(Logger* logger, SSLEntropy* entropy, std::string* key); + static SSLKey* load(Logger* logger, std::string const& data); + +protected: + SSLKey() {} + SSLKey(SSLKey const&) = delete; +}; + +class SSLCert { +public: + virtual ~SSLCert() {} + static bool generate(Logger* logger, SSLEntropy* entropy, + SSLCert* issuer_cert, SSLKey* issuer_key, + std::string const& host, SSLKey* key, + std::string* cert); + static SSLCert* load(Logger* logger, std::string const& data); + +protected: + SSLCert() {} + SSLCert(SSLCert const&) = delete; +}; + +class SSL { +public: + virtual ~SSL() {} + + // For server: allow SSLv3 and old unsecure certs like RC4 + // For client: allow self signed certs and missmatched hostname + static const uint16_t UNSECURE; + + static SSL* server(Logger* logger, SSLEntropy* entropy, + SSLCert* cert, SSLKey* key, + uint16_t flags); + static SSL* client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags); + + enum TransferResult { + NO_ERR, + ERR, + CLOSED, + }; + + // reads SSL from ssl_in and writes SSL to ssl_out + // reads data from data_in and writes data to data_out + virtual TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, + Buffer* data_in, Buffer* data_out) = 0; + virtual void close() = 0; + +protected: + SSL() {} + SSL(SSL const&) = delete; +}; + +#endif // SSL_HH |
