summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/.gitignore2
-rw-r--r--src/Makefile.am14
-rw-r--r--src/buffer.cc2
-rw-r--r--src/genca.cc69
-rw-r--r--src/logger.cc21
-rw-r--r--src/logger.hh3
-rw-r--r--src/lru.hh70
-rw-r--r--src/mitm.cc253
-rw-r--r--src/mitm.hh66
-rw-r--r--src/proxy.cc160
-rw-r--r--src/ssl.cc602
-rw-r--r--src/ssl.hh89
12 files changed, 1342 insertions, 9 deletions
diff --git a/src/.gitignore b/src/.gitignore
index e4fbae2..7045943 100644
--- a/src/.gitignore
+++ b/src/.gitignore
@@ -1,6 +1,8 @@
/config.h
/config.h.in~
/libmonitor.a
+/libmitm.a
/libtp.a
/tp
+/tp-genca
/tp-monitor
diff --git a/src/Makefile.am b/src/Makefile.am
index 84c4401..fbcdc67 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -7,9 +7,16 @@ ARFLAGS = cr
bin_PROGRAMS = tp tp-monitor
noinst_LIBRARIES = libtp.a libmonitor.a
+if HAVE_SSL
+bin_PROGRAMS += tp-genca
+noinst_LIBRARIES += libmitm.a
+endif
tp_SOURCES = main.cc proxy.cc logger.cc resolver.cc
tp_LDADD = libtp.a @THREAD_LIBS@
+if HAVE_SSL
+tp_LDADD += libmitm.a @SSL_LIBS@
+endif
tp_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@
libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \
@@ -17,6 +24,13 @@ libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \
buffer.cc chunked.cc
libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"'
+libmitm_a_SOURCES = ssl.cc mitm.cc
+libmitm_a_CXXFLAGS = $(AM_CXXFLAGS) @SSL_CFLAGS@
+
+tp_genca_SOURCES = genca.cc logger.cc
+tp_genca_LDADD = libmitm.a libtp.a @SSL_LIBS@
+tp_genca_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"'
+
libmonitor_a_SOURCES = monitor.cc resolver.cc
libmonitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@
diff --git a/src/buffer.cc b/src/buffer.cc
index 4bde195..d220514 100644
--- a/src/buffer.cc
+++ b/src/buffer.cc
@@ -95,7 +95,7 @@ size_t Buffer::read(void* data, size_t max) {
if (avail == 0) return 0;
avail = std::min(avail, max);
memcpy(data, ptr, avail);
- commit(avail);
+ consume(avail);
return avail;
}
diff --git a/src/genca.cc b/src/genca.cc
new file mode 100644
index 0000000..ee0f580
--- /dev/null
+++ b/src/genca.cc
@@ -0,0 +1,69 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "args.hh"
+#include "logger.hh"
+#include "ssl.hh"
+
+namespace {
+
+bool genca(std::ostream& out, std::string const& name) {
+ std::unique_ptr<Logger> logger(Logger::create_stderr());
+ std::unique_ptr<SSLEntropy> entropy(SSLEntropy::create(logger.get()));
+ if (!entropy) return false;
+ std::string key;
+ if (!SSLKey::generate(logger.get(), entropy.get(), &key)) return false;
+ std::string cert;
+ std::unique_ptr<SSLKey> pkey(SSLKey::load(logger.get(), key));
+ if (!SSLCert::generate(logger.get(), entropy.get(), nullptr, nullptr, name,
+ pkey.get(), &cert)) return false;
+ out << cert << '\n' << key << std::endl;
+ return true;
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ std::unique_ptr<Args> args(Args::create());
+ args->add('o', "output", "FILE",
+ "output certificate and key to FILE instead of stdout.");
+ args->add('n', "name", "NAME",
+ "Issuer name to use instead of TransparentProxy");
+ args->add('h', "help", "display this text and exit.");
+ args->add('V', "version", "display version and exit.");
+ if (!args->run(argc, argv)) {
+ std::cerr << "Try `tp-genca --help` for usage." << std::endl;
+ return EXIT_FAILURE;
+ }
+ if (args->is_set('h')) {
+ std::cout << "Usage: `tp-genca [OPTIONS...]`\n"
+ << "Generate a self-signed CA to use for MITM SSL interception.\n"
+ << '\n';
+ args->print_help();
+ return EXIT_SUCCESS;
+ }
+ if (args->is_set('V')) {
+ std::cout << "TransparentProxy version " VERSION
+ << " written by Joel Klinghed <the_jk@yahoo.com>" << std::endl;
+ return EXIT_SUCCESS;
+ }
+ if (!args->arguments().empty()) {
+ std::cerr << "Too many arguments.\n"
+ << "Try `tp-genca --help` for usage." << std::endl;
+ return EXIT_FAILURE;
+ }
+ auto name = args->arg('n', "TransparentProxy");
+ auto output = args->arg('o', nullptr);
+ if (output) {
+ std::ofstream out(output);
+ return genca(out, name) ? EXIT_SUCCESS : EXIT_FAILURE;
+ } else {
+ return genca(std::cout, name) ? EXIT_SUCCESS : EXIT_FAILURE;
+ }
+}
diff --git a/src/logger.cc b/src/logger.cc
index 29dcb38..343bb4d 100644
--- a/src/logger.cc
+++ b/src/logger.cc
@@ -21,7 +21,12 @@ namespace {
class LoggerStdErr : public Logger {
public:
- void out(Level UNUSED(lvl), char const* format, ...) override {
+ void out(Level lvl, char const* format, ...) override {
+ if (lvl == DBG) {
+#ifdef NDEBUG
+ return;
+#endif
+ }
char* tmp;
va_list args;
va_start(args, format);
@@ -44,6 +49,11 @@ public:
}
void out(Level lvl, char const* format, ...) override {
+ if (lvl == DBG) {
+#ifdef NDEBUG
+ return;
+#endif
+ }
va_list args;
va_start(args, format);
vsyslog(lvl2prio(lvl), format, args);
@@ -59,6 +69,8 @@ private:
return LOG_WARNING;
case INFO:
return LOG_INFO;
+ case DBG:
+ return LOG_DEBUG;
}
assert(false);
return LOG_INFO;
@@ -82,6 +94,11 @@ public:
}
void out(Level lvl, char const* format, ...) override {
+ if (lvl == DBG) {
+#ifdef NDEBUG
+ return;
+#endif
+ }
fputs(lvl2str(lvl), fh_);
fwrite(": ", 1, 2, fh_);
va_list args;
@@ -100,6 +117,8 @@ private:
return "Warning";
case INFO:
return "Info";
+ case DBG:
+ return "Debug";
}
assert(false);
return "Info";
diff --git a/src/logger.hh b/src/logger.hh
index 8b0db05..3700635 100644
--- a/src/logger.hh
+++ b/src/logger.hh
@@ -12,7 +12,8 @@ public:
enum Level {
ERR,
WARN,
- INFO
+ INFO,
+ DBG
};
static Logger* create_stderr();
diff --git a/src/lru.hh b/src/lru.hh
new file mode 100644
index 0000000..3279b35
--- /dev/null
+++ b/src/lru.hh
@@ -0,0 +1,70 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef LRU_HH
+#define LRU_HH
+
+#include <algorithm>
+#include <deque>
+#include <unordered_map>
+
+template<typename K, typename V>
+class lru {
+public:
+ lru(size_t size)
+ : size_(size) {
+ }
+
+ bool empty() const {
+ return usage_.empty();
+ }
+
+ size_t size() const {
+ return usage_.size();
+ }
+
+ V* get(K const& key) {
+ auto it = data_.find(key);
+ if (it == data_.end()) return nullptr;
+ auto it2 = std::find(usage_.begin(), usage_.end(), key);
+ if (it2 != usage_.begin()) {
+ usage_.erase(it2);
+ usage_.push_front(key);
+ }
+ return &it->second;
+ }
+
+ void insert(K const& key, V const& data) {
+ auto pair = data_.insert(std::make_pair(key, data));
+ if (pair.second) {
+ usage_.push_front(key);
+ shrink_to_size();
+ } else {
+ pair.first->second = data;
+ auto it = std::find(usage_.begin(), usage_.end(), key);
+ if (it != usage_.begin()) {
+ usage_.erase(it);
+ usage_.push_front(key);
+ }
+ }
+ }
+
+ void erase(K const& key) {
+ if (data_.erase(key)) {
+ usage_.erase(std::find(usage_.begin(), usage_.end(), key));
+ }
+ }
+
+private:
+ void shrink_to_size() {
+ if (usage_.size() > size_) {
+ data_.erase(usage_.back());
+ usage_.pop_back();
+ }
+ }
+
+ size_t const size_;
+ std::deque<K> usage_;
+ std::unordered_map<K, V> data_;
+};
+
+#endif // LRU_HH
diff --git a/src/mitm.cc b/src/mitm.cc
new file mode 100644
index 0000000..7e0fd19
--- /dev/null
+++ b/src/mitm.cc
@@ -0,0 +1,253 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <fstream>
+#include <memory>
+
+#include "config.hh"
+#include "buffer.hh"
+#include "paths.hh"
+#include "logger.hh"
+#include "lru.hh"
+#include "mitm.hh"
+#include "ssl.hh"
+
+namespace {
+
+class ConnectionImpl : public Mitm::Connection {
+public:
+ ConnectionImpl(Logger* logger)
+ : logger_(logger){
+ }
+
+ bool connect(SSLEntropy* entropy, SSLCertStore* store,
+ std::string const& cert, SSLKey* key,
+ bool unsecure, std::string const& host) {
+ uint16_t flags = unsecure ? SSL::UNSECURE : 0;
+ cert_.reset(SSLCert::load(logger_, cert));
+ if (!cert_) return false;
+ local_.reset(SSL::server(logger_, entropy, cert_.get(), key, flags));
+ if (!local_) return false;
+ remote_.reset(SSL::client(logger_, entropy, store, host, flags));
+ if (!remote_) return false;
+
+ in_.reset(Buffer::create(8192, 1024));
+ out_.reset(Buffer::create(8192, 1024));
+
+ return true;
+ }
+
+ bool transfer(Buffer* local_in, Buffer* local_out,
+ Buffer* remote_in, Buffer* remote_out,
+ Mitm::Monitor* monitor) override {
+ bool local_active = true, remote_active = true;
+ while (local_active || remote_active) {
+ local_active = false;
+ if (local_) {
+ size_t avail;
+ out_->read_ptr(&avail);
+ switch (local_->transfer(local_in, local_out, in_.get(), out_.get())) {
+ case SSL::NO_ERR: {
+ size_t avail2;
+ auto ptr = out_->read_ptr(&avail2);
+ if (avail < avail2) {
+ if (monitor) {
+ monitor->local2remote(
+ static_cast<char const*>(ptr) + avail, avail2 - avail);
+ }
+ local_active = true;
+ }
+ break;
+ }
+ case SSL::ERR:
+ return false;
+ case SSL::CLOSED:
+ local_.reset();
+ break;
+ }
+ }
+ remote_active = false;
+ if (remote_) {
+ size_t avail;
+ in_->read_ptr(&avail);
+ switch (remote_->transfer(remote_in, remote_out,
+ out_.get(), in_.get())) {
+ case SSL::NO_ERR: {
+ size_t avail2;
+ auto ptr = in_->read_ptr(&avail2);
+ if (avail < avail2) {
+ if (monitor) {
+ monitor->remote2local(
+ static_cast<char const*>(ptr) + avail, avail2 - avail);
+ }
+ remote_active = true;
+ }
+ break;
+ }
+ case SSL::ERR:
+ return false;
+ case SSL::CLOSED:
+ remote_.reset();
+ break;
+ }
+ }
+ }
+ return true;
+ }
+
+ bool local_eof() const override {
+ return !local_;
+ }
+
+ bool remote_eof() const override {
+ return !remote_;
+ }
+
+ void close_local() override {
+ if (local_) {
+ local_->close();
+ }
+ }
+ void close_remote() override {
+ if (remote_) {
+ remote_->close();
+ }
+ }
+
+private:
+ Logger* const logger_;
+ std::unique_ptr<SSLCert> cert_;
+ std::unique_ptr<SSL> local_;
+ std::unique_ptr<SSL> remote_;
+ std::unique_ptr<Buffer> in_;
+ std::unique_ptr<Buffer> out_;
+};
+
+class MitmImpl : public Mitm {
+public:
+ MitmImpl(Logger* logger)
+ : logger_(logger), unsecure_(false), cache_(42) {
+ }
+
+ ~MitmImpl() override {
+ }
+
+ bool load(Config* config, std::string const& cwd) {
+ entropy_.reset(SSLEntropy::create(logger_));
+ if (!entropy_) return false;
+ store_.reset(SSLCertStore::create(
+ logger_, config->get("ssl_cert_bundle", "")));
+ if (!store_) return false;
+ std::string ca_cert, ca_key;
+ if (!load_file(config->get("ssl_ca_cert", ""), cwd, &ca_cert)) return false;
+ if (!load_file(config->get("ssl_ca_key", ""), cwd, &ca_key)) return false;
+ unsecure_ = config->get("ssl_unsecure", false);
+ issuer_cert_.reset(SSLCert::load(logger_, ca_cert));
+ if (!issuer_cert_) return false;
+ issuer_key_.reset(SSLKey::load(logger_, ca_key));
+ if (!issuer_key_) return false;
+ return true;
+ }
+
+ bool reload(Config* config, std::string const& cwd) override {
+ std::unique_ptr<SSLCertStore> store(
+ SSLCertStore::create(logger_, config->get("ssl_cert_bundle", "")));
+ if (!store) return false;
+ std::string ca_cert, ca_key;
+ if (!load_file(config->get("ssl_ca_cert", ""), cwd, &ca_cert)) return false;
+ if (!load_file(config->get("ssl_ca_key", ""), cwd, &ca_key)) return false;
+ unsecure_ = config->get("ssl_unsecure", false);
+ std::unique_ptr<SSLCert> issuer_cert(SSLCert::load(logger_, ca_cert));
+ if (!issuer_cert) return false;
+ std::unique_ptr<SSLKey> issuer_key(SSLKey::load(logger_, ca_key));
+ if (!issuer_key) return false;
+ store_.swap(store);
+ issuer_cert_.swap(issuer_cert);
+ issuer_key_.swap(issuer_key);
+ return true;
+ }
+
+ DetectResult detect(void const* data, size_t avail) override {
+ if (avail == 0) return NEED_MORE;
+ auto d = static_cast<uint8_t const*>(data);
+ // SSL 3.0
+ if (d[0] == 0x16) {
+ if (avail < 2) return NEED_MORE;
+ if (d[1] != 0x03) return OTHER; // Need fixing when SSL 4.0 shows up
+ if (avail < 5) return NEED_MORE;
+ if (((d[3] << 8) | d[4]) < 9) // Min size of client hello
+ return OTHER;
+ if (avail < 6) return NEED_MORE;
+ return d[5] == 0x01 ? SSL : OTHER;
+ }
+ // SSL 2.0
+ if (!(d[0] & 0x80)) return OTHER;
+ if (avail < 2) return NEED_MORE;
+ if (((d[0] & 0x7f) << 8 | d[1]) < 9) // Min size of client hello
+ return OTHER;
+ if (avail < 3) return NEED_MORE;
+ return d[2] == 0x01 ? SSL : OTHER;
+ }
+
+ Connection* open(std::string const& host) override {
+ CacheEntry* entry;
+ CacheEntry _entry;
+ entry = cache_.get(host);
+ if (!entry) {
+ entry = &_entry;
+ if (!SSLCert::generate(logger_, entropy_.get(),
+ issuer_cert_.get(), issuer_key_.get(), host,
+ issuer_key_.get(), &entry->cert)) {
+ return nullptr;
+ }
+ cache_.insert(host, _entry);
+ }
+ std::unique_ptr<ConnectionImpl> conn(new ConnectionImpl(logger_));
+ if (!conn->connect(entropy_.get(), store_.get(), entry->cert,
+ issuer_key_.get(), unsecure_, host)) return nullptr;
+ return conn.release();
+ }
+
+private:
+ struct CacheEntry {
+ std::string cert;
+ };
+
+ bool load_file(std::string const& path, std::string const& cwd,
+ std::string* out) const {
+ std::ifstream in(Paths::join(cwd, path));
+ if (!in.good()) {
+ logger_->out(Logger::ERR, "Unable to open: %s", path.c_str());
+ return false;
+ }
+ out->clear();
+ while (in.good()) {
+ char buffer[8192];
+ in.read(buffer, sizeof(buffer));
+ out->append(buffer, in.gcount());
+ }
+ if (!in.eof()) {
+ logger_->out(Logger::ERR, "Unable to open: %s", path.c_str());
+ return false;
+ }
+ return true;
+ }
+
+ Logger* const logger_;
+ std::unique_ptr<SSLEntropy> entropy_;
+ std::unique_ptr<SSLCertStore> store_;
+ std::unique_ptr<SSLCert> issuer_cert_;
+ std::unique_ptr<SSLKey> issuer_key_;
+ bool unsecure_;
+ lru<std::string, CacheEntry> cache_;
+};
+
+} // namespace
+
+// static
+Mitm* Mitm::create(Logger* logger, Config* config, std::string const& cwd) {
+ std::unique_ptr<MitmImpl> mitm(new MitmImpl(logger));
+ if (!mitm->load(config, cwd)) return nullptr;
+ return mitm.release();
+}
diff --git a/src/mitm.hh b/src/mitm.hh
new file mode 100644
index 0000000..6d79e8f
--- /dev/null
+++ b/src/mitm.hh
@@ -0,0 +1,66 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef MITM_HH
+#define MITM_HH
+
+#include <cstddef>
+#include <string>
+
+class Buffer;
+class Config;
+class Logger;
+
+class Mitm {
+public:
+ virtual ~Mitm() {}
+
+ static Mitm* create(Logger* logger, Config* config, std::string const& cwd);
+
+ virtual bool reload(Config* config, std::string const& cwd) = 0;
+
+ enum DetectResult {
+ SSL,
+ OTHER,
+ NEED_MORE,
+ };
+ virtual DetectResult detect(void const* data, size_t avail) = 0;
+
+ class Monitor {
+ public:
+ virtual ~Monitor() {}
+
+ virtual void local2remote(void const* data, size_t size) = 0;
+ virtual void remote2local(void const* data, size_t size) = 0;
+
+ protected:
+ Monitor() {}
+ };
+
+ class Connection {
+ public:
+ virtual ~Connection() {}
+
+ virtual bool transfer(
+ Buffer* local_in, Buffer* local_out,
+ Buffer* remote_in, Buffer* remote_out,
+ Monitor* monitor) = 0;
+
+ virtual bool local_eof() const = 0;
+ virtual bool remote_eof() const = 0;
+
+ virtual void close_local() = 0;
+ virtual void close_remote() = 0;
+
+ protected:
+ Connection() {}
+ Connection(Connection const&) = delete;
+ };
+
+ virtual Connection* open(std::string const& host) = 0;
+
+protected:
+ Mitm() {}
+ Mitm(Mitm const&) = delete;
+};
+
+#endif // MITM_HH
diff --git a/src/proxy.cc b/src/proxy.cc
index 5abe257..fdb6df0 100644
--- a/src/proxy.cc
+++ b/src/proxy.cc
@@ -27,6 +27,7 @@
#include "io.hh"
#include "logger.hh"
#include "looper.hh"
+#include "mitm.hh"
#include "resolver.hh"
#include "paths.hh"
#include "proxy.hh"
@@ -217,6 +218,9 @@ struct Content {
struct Connect {
std::string host;
uint16_t port;
+ Mitm::DetectResult mitm_detect;
+ std::unique_ptr<Mitm::Connection> mitm;
+ std::unique_ptr<Mitm::Monitor> mitm_monitor;
};
enum RemoteState {
@@ -290,6 +294,27 @@ public:
bool run() override;
private:
+ class MitmMonitor : public Mitm::Monitor {
+ public:
+ MitmMonitor(ProxyImpl* proxy, uint32_t local_pkg_id, uint32_t remote_pkg_id)
+ : proxy_(proxy), local_pkg_id_(local_pkg_id),
+ remote_pkg_id_(remote_pkg_id) {
+ assert(local_pkg_id_ && remote_pkg_id_);
+ }
+
+ void local2remote(void const* data, size_t size) override {
+ proxy_->send_attached_data(local_pkg_id_, data, size, false);
+ }
+ void remote2local(void const* data, size_t size) override {
+ proxy_->send_attached_data(remote_pkg_id_, data, size, false);
+ }
+
+ private:
+ ProxyImpl* const proxy_;
+ uint32_t const local_pkg_id_;
+ uint32_t const remote_pkg_id_;
+ };
+
void setup();
bool reload_config();
void fatal_error();
@@ -312,6 +337,8 @@ private:
std::chrono::duration<float> const& timeout);
bool base_send(BaseClient* client, void const* data, size_t size,
size_t index, char const* name);
+ bool base_continue_send(BaseClient* client, bool was_empty,
+ size_t index, char const* name);
void client_error(size_t index,
uint16_t status_code, std::string const& status);
bool client_request(size_t index);
@@ -354,6 +381,7 @@ private:
io::auto_pipe signal_pipe_;
std::unique_ptr<Looper> looper_;
std::unique_ptr<Resolver> resolver_;
+ std::unique_ptr<Mitm> mitm_;
bool good_;
void* new_timeout_;
void* timeout_;
@@ -385,6 +413,10 @@ void ProxyImpl::setup() {
monitor_send_proxied_ = !config_->get("monitor_proxy_request", false);
clients_.resize(get_size(config_, logger_, "max_clients", 1024));
monitors_.resize(get_size(config_, logger_, "max_monitors", 2));
+ mitm_.reset(Mitm::create(logger_, config_, cwd_));
+ if (mitm_) {
+ logger_->out(Logger::INFO, "MITM SSL Interception active");
+ }
looper_->add(accept_socket_.get(),
clients_.full() ? 0 : Looper::EVENT_READ,
std::bind(&ProxyImpl::new_client, this,
@@ -437,6 +469,13 @@ bool ProxyImpl::reload_config() {
logger_ = priv_logger_.get();
}
}
+ if (mitm_) {
+ if (!mitm_->reload(config_, cwd_)) {
+ logger_->out(Logger::WARN, "Invalid mitm config, ignored");
+ }
+ } else {
+ mitm_.reset(Mitm::create(logger_, config_, cwd_));
+ }
auto const old_accept = accept_socket_.get();
auto const old_monitor = monitor_socket_.get();
looper_->remove(old_accept);
@@ -653,6 +692,37 @@ bool ProxyImpl::base_send(BaseClient* client, void const* data, size_t size,
return true;
}
+bool ProxyImpl::base_continue_send(BaseClient* client, bool was_empty,
+ size_t index, char const* name) {
+ if (client->out->empty()) return true;
+ if (!was_empty) {
+ // Already waiting for write event
+ return true;
+ }
+ size_t avail;
+ auto ptr = client->out->read_ptr(&avail);
+ auto ret = io::write(client->fd.get(), ptr, avail);
+ if (ret == -1) {
+ if (errno != EAGAIN && errno != EWOULDBLOCK) {
+ logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ } else {
+ client->out->consume(ret);
+ if (static_cast<size_t>(ret) == avail) {
+ if (!client->in) {
+ // If input is closed, close after sending all data
+ return false;
+ }
+ return true;
+ }
+ }
+ client->write_flag = Looper::EVENT_WRITE;
+ looper_->modify(client->fd.get(), client->read_flag | client->write_flag);
+ return true;
+}
+
bool ProxyImpl::base_event(BaseClient* client, uint8_t events,
size_t index, char const* name) {
if (events & Looper::EVENT_READ) {
@@ -964,14 +1034,54 @@ void ProxyImpl::client_empty_input(size_t index) {
// falltrough
case CONTENT_NONE: {
if (client.connect && client.remote_state == CONNECTED) {
- if (client.pkg_id != 0) {
- send_attached_data(client.pkg_id, ptr, avail, false);
+ if (client.connect->mitm_detect == Mitm::NEED_MORE) {
+ client.connect->mitm_detect = mitm_->detect(ptr, avail);
+ if (client.connect->mitm_detect == Mitm::NEED_MORE) {
+ // Need more data
+ return;
+ }
+ if (client.connect->mitm_detect == Mitm::SSL) {
+ client.connect->mitm.reset(mitm_->open(client.connect->host));
+ if (!client.connect->mitm) {
+ client.connect->mitm_detect = Mitm::OTHER;
+ } else if (client.pkg_id != 0) {
+ client.connect->mitm_monitor.reset(
+ new MitmMonitor(this, client.pkg_id, client.remote.pkg_id));
+ }
+ }
}
- if (!base_send(&client.remote, ptr, avail, index, "Client remote")) {
- return;
+ if (client.connect->mitm) {
+ bool local_empty = client.out->empty();
+ bool remote_empty = client.remote.out->empty();
+ // TODO(the_jk): Monitor
+ if (!client.connect->mitm->transfer(
+ client.in.get(), client.out.get(),
+ client.remote.in.get(), client.remote.out.get(),
+ client.connect->mitm_monitor.get())) {
+ close_client(index);
+ return;
+ }
+ if (!base_continue_send(&client, local_empty, index, "Client")) {
+ close_client(index);
+ return;
+ }
+ if (!base_continue_send(&client.remote, remote_empty, index,
+ "Client remote")) {
+ client_remote_error(index, 502);
+ return;
+ }
+ break;
+ } else {
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, ptr, avail, false);
+ }
+ if (!base_send(&client.remote, ptr, avail, index, "Client remote")) {
+ client_remote_error(index, 502);
+ return;
+ }
+ client.in->consume(avail);
+ break;
}
- client.in->consume(avail);
- break;
}
if (client.remote_state != CLOSED && client.remote_state != WAITING) {
// Still working on the last request, wait
@@ -1003,6 +1113,11 @@ void ProxyImpl::client_empty_input(size_t index) {
client_error(index, 400, "Bad request");
return;
}
+ if (mitm_) {
+ client.connect->mitm_detect = Mitm::NEED_MORE;
+ } else {
+ client.connect->mitm_detect = Mitm::OTHER;
+ }
} else {
client.url.reset(Url::parse(client.request->url()));
if (!client.url) {
@@ -1132,6 +1247,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
client.source_host, client.source_port, false);
}
if (!base_send(&client, data.data(), data.size(), index, "Client")) {
+ close_client(index);
return;
}
events &= ~Looper::EVENT_WRITE;
@@ -1280,6 +1396,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
}
}
if (!base_send(&client, ptr, response->size(), index, "Client")) {
+ close_client(index);
return;
}
client.remote.in->consume(response->size());
@@ -1291,6 +1408,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
send_attached_data(client.remote.pkg_id, ptr, avail, false);
}
if (!base_send(&client, ptr, avail, index, "Client")) {
+ close_client(index);
return;
}
client.remote.in->consume(avail);
@@ -1304,6 +1422,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
}
if (!base_send(&client, ptr, client.remote.content.len,
index, "Client")) {
+ close_client(index);
return;
}
client.remote.in->consume(client.remote.content.len);
@@ -1326,6 +1445,7 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
send_attached_data(client.remote.pkg_id, ptr, used, false);
}
if (!base_send(&client, ptr, used, index, "Client")) {
+ close_client(index);
return;
}
client.remote.in->consume(used);
@@ -1344,10 +1464,38 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
break;
}
case CONTENT_CLOSE:
+ if (client.connect) {
+ if (client.connect->mitm_detect == Mitm::NEED_MORE) {
+ // Need more data (on client side)
+ return;
+ }
+ if (client.connect->mitm) {
+ bool local_empty = client.out->empty();
+ bool remote_empty = client.remote.out->empty();
+ // TODO(the_jk): Monitor
+ if (!client.connect->mitm->transfer(
+ client.in.get(), client.out.get(),
+ client.remote.in.get(), client.remote.out.get(),
+ client.connect->mitm_monitor.get())) {
+ close_client(index);
+ return;
+ }
+ if (!base_continue_send(&client, local_empty, index, "Client")) {
+ return;
+ }
+ if (!base_continue_send(&client.remote, remote_empty, index,
+ "Client remote")) {
+ client_remote_error(index, 502);
+ return;
+ }
+ break;
+ }
+ }
if (client.remote.pkg_id != 0) {
send_attached_data(client.remote.pkg_id, ptr, avail, false);
}
if (!base_send(&client, ptr, avail, index, "Client")) {
+ close_client(index);
return;
}
client.remote.in->consume(avail);
diff --git a/src/ssl.cc b/src/ssl.cc
new file mode 100644
index 0000000..3395d83
--- /dev/null
+++ b/src/ssl.cc
@@ -0,0 +1,602 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <mbedtls/certs.h>
+#include <mbedtls/ctr_drbg.h>
+#include <mbedtls/entropy.h>
+#include <mbedtls/error.h>
+#include <mbedtls/pk.h>
+#include <mbedtls/rsa.h>
+#include <mbedtls/ssl.h>
+#include <mbedtls/x509.h>
+#include <memory>
+
+#include "buffer.hh"
+#include "logger.hh"
+#include "ssl.hh"
+
+namespace {
+
+void logout(Logger* logger, Logger::Level lvl, int errnum,
+ const std::string& msg) {
+ char tmp[256];
+ mbedtls_strerror(errnum, tmp, sizeof(tmp));
+ logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp);
+}
+
+inline void loginfo(Logger* logger, int errnum, const std::string& msg) {
+ logout(logger, Logger::INFO, errnum, msg);
+}
+
+inline void logerr(Logger* logger, int errnum, const std::string& msg) {
+ logout(logger, Logger::ERR, errnum, msg);
+}
+
+class SSLEntropyImpl : public SSLEntropy {
+public:
+ SSLEntropyImpl() {
+ mbedtls_entropy_init(&entropy_);
+ mbedtls_ctr_drbg_init(&ctr_drbg_);
+ }
+
+ ~SSLEntropyImpl() override {
+ mbedtls_ctr_drbg_free(&ctr_drbg_);
+ mbedtls_entropy_free(&entropy_);
+ }
+
+ bool init(Logger* logger) {
+ auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func,
+ &entropy_, nullptr, 0);
+ if (ret) {
+ logerr(logger, ret, "Error seeding entropy");
+ return false;
+ }
+ return true;
+ }
+
+ void setup(mbedtls_ssl_config* conf) {
+ mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_);
+ }
+
+ mbedtls_ctr_drbg_context* random() {
+ return &ctr_drbg_;
+ }
+
+private:
+ mbedtls_entropy_context entropy_;
+ mbedtls_ctr_drbg_context ctr_drbg_;
+};
+
+class SSLCertStoreImpl : public SSLCertStore {
+public:
+ SSLCertStoreImpl() {
+ mbedtls_x509_crt_init(&chain_);
+ }
+
+ ~SSLCertStoreImpl() override {
+ mbedtls_x509_crt_free(&chain_);
+ }
+
+ bool init(Logger* logger, std::string const& path) {
+ auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str());
+ if (ret) {
+ logerr(logger, ret, "Error loeading cert store bundle");
+ return false;
+ }
+ return true;
+ }
+
+ mbedtls_x509_crt* chain() {
+ return &chain_;
+ }
+
+private:
+ mbedtls_x509_crt chain_;
+};
+
+class SSLKeyImpl : public SSLKey {
+public:
+ SSLKeyImpl() {
+ mbedtls_pk_init(&key_);
+ }
+
+ ~SSLKeyImpl() override {
+ mbedtls_pk_free(&key_);
+ }
+
+ bool load(Logger* logger, std::string const& data) {
+ auto ret = mbedtls_pk_parse_key(
+ &key_,
+ reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1,
+ nullptr, 0);
+ if (ret) {
+ logerr(logger, ret, "Error parsing key");
+ return false;
+ }
+ return true;
+ }
+
+ mbedtls_pk_context* key() {
+ return &key_;
+ }
+
+private:
+ mbedtls_pk_context key_;
+};
+
+class SSLCertImpl : public SSLCert {
+public:
+ SSLCertImpl() {
+ mbedtls_x509_crt_init(&cert_);
+ }
+
+ ~SSLCertImpl() override {
+ mbedtls_x509_crt_free(&cert_);
+ }
+
+ bool load(Logger* logger, std::string const& data) {
+ auto ret = mbedtls_x509_crt_parse(
+ &cert_,
+ reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1);
+ if (ret < 0) {
+ logerr(logger, ret, "Error parsing cert");
+ return false;
+ }
+ return true;
+ }
+
+ mbedtls_x509_crt* cert() {
+ return &cert_;
+ }
+
+private:
+ mbedtls_x509_crt cert_;
+};
+
+template<typename SetupData>
+class SSLImpl : public SSL {
+public:
+ SSLImpl(Logger* logger, uint16_t flags)
+ : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr),
+ result_(NO_ERR), state_(HANDSHAKE) {
+ }
+
+ ~SSLImpl() override {
+ mbedtls_ssl_free(&ssl_);
+ mbedtls_ssl_config_free(&conf_);
+ }
+
+ bool unsecure() const {
+ return flags_ & UNSECURE;
+ }
+
+ TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out,
+ Buffer* data_in, Buffer* data_out) override {
+ if (result_) return result_;
+
+ in_ = ssl_in;
+ out_ = ssl_out;
+
+ result_ = transfer(data_in, data_out);
+
+ in_ = nullptr;
+ out_ = nullptr;
+
+ return result_;
+ }
+
+ void close() override {
+ state_ = CLOSE;
+ }
+
+protected:
+ void init(SSLEntropy* entropy, int endpoint, SetupData const& data) {
+ mbedtls_ssl_init(&ssl_);
+ mbedtls_ssl_config_init(&conf_);
+
+ auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint,
+ MBEDTLS_SSL_TRANSPORT_STREAM,
+ MBEDTLS_SSL_PRESET_DEFAULT);
+ if (ret) {
+ logerr(logger_, ret, "Error configuring SSL");
+ result_ = ERR;
+ return;
+ }
+
+ static_cast<SSLEntropyImpl*>(entropy)->setup(&conf_);
+
+ if (!before_setup(data)) {
+ result_ = ERR;
+ return;
+ }
+
+ ret = mbedtls_ssl_setup(&ssl_, &conf_);
+ if (ret) {
+ logerr(logger_, ret, "Failed to setup SSL");
+ result_ = ERR;
+ return;
+ }
+
+ mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr);
+
+ if (!after_setup(data)) {
+ result_ = ERR;
+ return;
+ }
+
+ result_ = NO_ERR;
+ }
+
+ virtual bool before_setup(SetupData const& UNUSED(data)) {
+ return true;
+ }
+
+ virtual bool after_setup(SetupData const& UNUSED(data)) {
+ return true;
+ }
+
+ Logger* const logger_;
+ mbedtls_ssl_context ssl_;
+ mbedtls_ssl_config conf_;
+
+private:
+ enum State {
+ HANDSHAKE,
+ TRANSFER,
+ CLOSE
+ };
+
+ TransferResult transfer(Buffer* in, Buffer* out) {
+ bool want_read = false, want_write = false;
+ while (!want_read || !want_write) {
+ switch (state_) {
+ case HANDSHAKE: {
+ auto ret = mbedtls_ssl_handshake(&ssl_);
+ if (ret) {
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ return NO_ERR;
+ }
+ loginfo(logger_, ret, "Handshake failed");
+ return ERR;
+ }
+ state_ = TRANSFER;
+ break;
+ }
+ case CLOSE: {
+ auto ret = mbedtls_ssl_close_notify(&ssl_);
+ if (ret) {
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ return NO_ERR;
+ }
+ loginfo(logger_, ret, "Close notify failed");
+ return ERR;
+ }
+ return CLOSED;
+ }
+ case TRANSFER: {
+ size_t avail;
+ auto wptr = out->write_ptr(&avail);
+ if (avail > 0) {
+ auto ret = mbedtls_ssl_read(&ssl_,
+ reinterpret_cast<unsigned char*>(wptr),
+ avail);
+ if (ret > 0) {
+ out->commit(ret);
+ } else if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ want_read = true;
+ } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
+ return CLOSED;
+ } else {
+ loginfo(logger_, ret, "SSL read failed");
+ return ERR;
+ }
+ } else {
+ assert(false);
+ want_read = true;
+ }
+ auto rptr = in->read_ptr(&avail);
+ if (avail > 0) {
+ auto ret = mbedtls_ssl_write(
+ &ssl_,
+ reinterpret_cast<unsigned char const*>(rptr),
+ avail);
+ if (ret > 0) {
+ in->consume(ret);
+ } else if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ want_write = true;
+ } else {
+ loginfo(logger_, ret, "SSL write failed");
+ return ERR;
+ }
+ } else {
+ want_write = true;
+ }
+ break;
+ }
+ }
+ }
+ return NO_ERR;
+ }
+
+ int send(unsigned char const* buf, size_t len) {
+ if (!out_) {
+ assert(false);
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+ }
+ out_->write(buf, len);
+ return len;
+ }
+ int recv(unsigned char* buf, size_t len) {
+ if (!in_) {
+ assert(false);
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
+ auto ret = in_->read(buf, len);
+ if (ret == 0) {
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
+ return ret;
+ }
+ static int ssl_send(void* ctx, unsigned char const* buf, size_t len) {
+ return reinterpret_cast<SSLImpl*>(ctx)->send(buf, len);
+ }
+ static int ssl_recv(void* ctx, unsigned char* buf, size_t len) {
+ return reinterpret_cast<SSLImpl*>(ctx)->recv(buf, len);
+ }
+
+ uint16_t const flags_;
+ Buffer* in_;
+ Buffer* out_;
+ TransferResult result_;
+ State state_;
+};
+
+struct CertKey {
+ SSLCert* cert;
+ SSLKey* key;
+
+ CertKey(SSLCert* cert, SSLKey* key)
+ : cert(cert), key(key) {
+ }
+};
+
+class SSLServerImpl : public SSLImpl<CertKey> {
+public:
+ SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert,
+ SSLKey* key, uint16_t flags)
+ : SSLImpl(logger, flags) {
+ SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key));
+ }
+
+ ~SSLServerImpl() override {
+ }
+
+private:
+ bool before_setup(CertKey const& data) override {
+ mbedtls_ssl_conf_ca_chain(
+ &conf_, static_cast<SSLCertImpl*>(data.cert)->cert()->next, nullptr);
+ auto ret = mbedtls_ssl_conf_own_cert(
+ &conf_,
+ static_cast<SSLCertImpl*>(data.cert)->cert(),
+ static_cast<SSLKeyImpl*>(data.key)->key());
+ if (ret) {
+ logerr(logger_, ret, "Server: Error setting certificate");
+ return false;
+ }
+
+ mbedtls_ssl_conf_min_version(&conf_,
+ MBEDTLS_SSL_MAJOR_VERSION_3,
+ unsecure() ? MBEDTLS_SSL_MINOR_VERSION_0 :
+ MBEDTLS_SSL_MINOR_VERSION_1);
+ return true;
+ }
+};
+
+struct HostStore {
+ std::string const& host;
+ SSLCertStore* const store;
+
+ HostStore(std::string const& host, SSLCertStore* store)
+ : host(host), store(store) {
+ }
+};
+
+class SSLClientImpl : public SSLImpl<HostStore> {
+public:
+ SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store,
+ std::string const& host, uint16_t flags)
+ : SSLImpl(logger, flags) {
+ SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store));
+ }
+
+private:
+ bool before_setup(HostStore const& data) override {
+ mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL
+ : MBEDTLS_SSL_VERIFY_REQUIRED);
+ mbedtls_ssl_conf_ca_chain(
+ &conf_, static_cast<SSLCertStoreImpl*>(data.store)->chain(), nullptr);
+ return true;
+ }
+
+ bool after_setup(HostStore const& data) override {
+ auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str());
+ if (ret) {
+ logerr(logger_, ret, "Failed to set hostname");
+ return false;
+ }
+ return true;
+ }
+};
+
+} // namespace
+
+// static
+SSLEntropy* SSLEntropy::create(Logger* logger) {
+ std::unique_ptr<SSLEntropyImpl> entropy(new SSLEntropyImpl());
+ if (!entropy->init(logger)) return nullptr;
+ return entropy.release();
+}
+
+// static
+SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) {
+ std::unique_ptr<SSLCertStoreImpl> store(new SSLCertStoreImpl());
+ if (!store->init(logger, path)) return nullptr;
+ return store.release();
+}
+
+// static
+SSLKey* SSLKey::load(Logger* logger, std::string const& data) {
+ std::unique_ptr<SSLKeyImpl> key(new SSLKeyImpl());
+ if (!key->load(logger, data)) return nullptr;
+ return key.release();
+}
+
+// static
+bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) {
+ mbedtls_pk_context pk;
+ bool ok = false;
+ unsigned char buffer[16000];
+ mbedtls_pk_init(&pk);
+
+ auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA));
+ if (ret) {
+ logerr(logger, ret, "Error setting up key");
+ goto error;
+ }
+
+ ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk),
+ mbedtls_ctr_drbg_random,
+ static_cast<SSLEntropyImpl*>(entropy)->random(),
+ 4096, 65537);
+ if (ret) {
+ logerr(logger, ret, "Error generating key");
+ goto error;
+ }
+
+ ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer));
+ if (ret) {
+ logerr(logger, ret, "Error writing key");
+ goto error;
+ }
+ key->assign(reinterpret_cast<char*>(buffer));
+
+ ok = true;
+ error:
+ mbedtls_pk_free(&pk);
+ return ok;
+}
+
+// static
+SSLCert* SSLCert::load(Logger* logger, std::string const& data) {
+ std::unique_ptr<SSLCertImpl> key(new SSLCertImpl());
+ if (!key->load(logger, data)) return nullptr;
+ return key.release();
+}
+
+// static
+bool SSLCert::generate(Logger* logger, SSLEntropy* entropy,
+ SSLCert* issuer_cert, SSLKey* issuer_key,
+ std::string const& host, SSLKey* key,
+ std::string* cert) {
+ mbedtls_x509write_cert crt;
+ char issuer_name[256];
+ std::string subject;
+ bool ok = false;
+ time_t now, tmp;
+ struct tm tm;
+ char not_before[20];
+ char not_after[20];
+ unsigned char buffer[16000];
+ mbedtls_x509write_crt_init(&crt);
+ mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256);
+
+ if (key) {
+ mbedtls_x509write_crt_set_subject_key(
+ &crt, static_cast<SSLKeyImpl*>(key)->key());
+ }
+ if (issuer_key) {
+ mbedtls_x509write_crt_set_issuer_key(
+ &crt, static_cast<SSLKeyImpl*>(issuer_key)->key());
+ }
+
+ subject = "CN=" + host;
+ auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str());
+ if (ret) {
+ logerr(logger, ret, "Invalid subject name");
+ goto error;
+ }
+
+ if (issuer_cert) {
+ ret = mbedtls_x509_dn_gets(
+ issuer_name, sizeof(issuer_name),
+ &static_cast<SSLCertImpl*>(issuer_cert)->cert()->subject);
+ if (ret < 0) {
+ logerr(logger, ret, "Unable to get issuer name");
+ goto error;
+ }
+ ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name);
+ } else {
+ ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str());
+ }
+ if (ret) {
+ logerr(logger, ret, "Invalid issuer name");
+ goto error;
+ }
+ now = time(nullptr);
+ tmp = now - (24 * 60 * 60 - 1);
+ gmtime_r(&tmp, &tm);
+ strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm);
+ tmp = now + (30 * 24 * 60 * 60);
+ gmtime_r(&tmp, &tm);
+ strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm);
+ ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after);
+ if (ret) {
+ logerr(logger, ret, "Unable to set validity");
+ goto error;
+ }
+ if (issuer_cert) {
+ ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1);
+ } else {
+ ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1);
+ }
+ if (ret) {
+ logerr(logger, ret, "Unable to set basic constraints");
+ goto error;
+ }
+
+ ret = mbedtls_x509write_crt_pem(
+ &crt, buffer, sizeof(buffer),
+ mbedtls_ctr_drbg_random,
+ static_cast<SSLEntropyImpl*>(entropy)->random());
+ if (ret) {
+ logerr(logger, ret, "Error writing cert");
+ goto error;
+ }
+ cert->assign(reinterpret_cast<char*>(buffer));
+
+ ok = true;
+ error:
+ mbedtls_x509write_crt_free(&crt);
+ return ok;
+}
+
+// static
+SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert,
+ SSLKey* key, uint16_t flags) {
+ return new SSLServerImpl(logger, entropy, cert, key, flags);
+}
+
+// static
+SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store,
+ std::string const& host, uint16_t flags) {
+ return new SSLClientImpl(logger, entropy, store, host, flags);
+}
+
+// static
+const uint16_t SSL::UNSECURE = 0x01;
+
diff --git a/src/ssl.hh b/src/ssl.hh
new file mode 100644
index 0000000..1cd6aea
--- /dev/null
+++ b/src/ssl.hh
@@ -0,0 +1,89 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef SSL_HH
+#define SSL_HH
+
+#include <string>
+
+class Buffer;
+class Logger;
+
+class SSLEntropy {
+public:
+ virtual ~SSLEntropy() {}
+
+ static SSLEntropy* create(Logger* logger);
+
+protected:
+ SSLEntropy() {}
+ SSLEntropy(SSLEntropy const&) = delete;
+};
+
+class SSLCertStore {
+public:
+ virtual ~SSLCertStore() {}
+
+ static SSLCertStore* create(Logger* logger, std::string const& bundle);
+
+protected:
+ SSLCertStore() {}
+ SSLCertStore(SSLCertStore const&) = delete;
+};
+
+class SSLKey {
+public:
+ virtual ~SSLKey() {}
+ static bool generate(Logger* logger, SSLEntropy* entropy, std::string* key);
+ static SSLKey* load(Logger* logger, std::string const& data);
+
+protected:
+ SSLKey() {}
+ SSLKey(SSLKey const&) = delete;
+};
+
+class SSLCert {
+public:
+ virtual ~SSLCert() {}
+ static bool generate(Logger* logger, SSLEntropy* entropy,
+ SSLCert* issuer_cert, SSLKey* issuer_key,
+ std::string const& host, SSLKey* key,
+ std::string* cert);
+ static SSLCert* load(Logger* logger, std::string const& data);
+
+protected:
+ SSLCert() {}
+ SSLCert(SSLCert const&) = delete;
+};
+
+class SSL {
+public:
+ virtual ~SSL() {}
+
+ // For server: allow SSLv3 and old unsecure certs like RC4
+ // For client: allow self signed certs and missmatched hostname
+ static const uint16_t UNSECURE;
+
+ static SSL* server(Logger* logger, SSLEntropy* entropy,
+ SSLCert* cert, SSLKey* key,
+ uint16_t flags);
+ static SSL* client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store,
+ std::string const& host, uint16_t flags);
+
+ enum TransferResult {
+ NO_ERR,
+ ERR,
+ CLOSED,
+ };
+
+ // reads SSL from ssl_in and writes SSL to ssl_out
+ // reads data from data_in and writes data to data_out
+ virtual TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out,
+ Buffer* data_in, Buffer* data_out) = 0;
+ virtual void close() = 0;
+
+protected:
+ SSL() {}
+ SSL(SSL const&) = delete;
+};
+
+#endif // SSL_HH