From 7e9b90fb692b80df159992f62458c70c9fe36781 Mon Sep 17 00:00:00 2001 From: Joel Klinghed Date: Tue, 28 Mar 2017 22:53:04 +0200 Subject: Support compiling without SSL And prepare for other SSL implementations than mbedtls --- src/Makefile.am | 8 +- src/mitm_stub.cc | 11 + src/ssl.cc | 602 ----------------------------------------------------- src/ssl_mbedtls.cc | 602 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 620 insertions(+), 603 deletions(-) create mode 100644 src/mitm_stub.cc delete mode 100644 src/ssl.cc create mode 100644 src/ssl_mbedtls.cc (limited to 'src') diff --git a/src/Makefile.am b/src/Makefile.am index fbcdc67..7b1034c 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -22,9 +22,15 @@ tp_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \ character.cc config.cc strings.cc io.cc looper.cc \ buffer.cc chunked.cc +if !HAVE_SSL +libtp_a_SOURCES += mitm_stub.cc +endif libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"' -libmitm_a_SOURCES = ssl.cc mitm.cc +libmitm_a_SOURCES = mitm.cc +if HAVE_MBEDTLS +libmitm_a_SOURCES += ssl_mbedtls.cc +endif libmitm_a_CXXFLAGS = $(AM_CXXFLAGS) @SSL_CFLAGS@ tp_genca_SOURCES = genca.cc logger.cc diff --git a/src/mitm_stub.cc b/src/mitm_stub.cc new file mode 100644 index 0000000..6f4dcb0 --- /dev/null +++ b/src/mitm_stub.cc @@ -0,0 +1,11 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include "mitm.hh" + +// static +Mitm* Mitm::create(Logger* UNUSED(logger), Config* UNUSED(config), + std::string const& UNUSED(cwd)) { + return nullptr; +} diff --git a/src/ssl.cc b/src/ssl.cc deleted file mode 100644 index 3395d83..0000000 --- a/src/ssl.cc +++ /dev/null @@ -1,602 +0,0 @@ -// -*- mode: c++; c-basic-offset: 2; -*- - -#include "common.hh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "buffer.hh" -#include "logger.hh" -#include "ssl.hh" - -namespace { - -void logout(Logger* logger, Logger::Level lvl, int errnum, - const std::string& msg) { - char tmp[256]; - mbedtls_strerror(errnum, tmp, sizeof(tmp)); - logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp); -} - -inline void loginfo(Logger* logger, int errnum, const std::string& msg) { - logout(logger, Logger::INFO, errnum, msg); -} - -inline void logerr(Logger* logger, int errnum, const std::string& msg) { - logout(logger, Logger::ERR, errnum, msg); -} - -class SSLEntropyImpl : public SSLEntropy { -public: - SSLEntropyImpl() { - mbedtls_entropy_init(&entropy_); - mbedtls_ctr_drbg_init(&ctr_drbg_); - } - - ~SSLEntropyImpl() override { - mbedtls_ctr_drbg_free(&ctr_drbg_); - mbedtls_entropy_free(&entropy_); - } - - bool init(Logger* logger) { - auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, - &entropy_, nullptr, 0); - if (ret) { - logerr(logger, ret, "Error seeding entropy"); - return false; - } - return true; - } - - void setup(mbedtls_ssl_config* conf) { - mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_); - } - - mbedtls_ctr_drbg_context* random() { - return &ctr_drbg_; - } - -private: - mbedtls_entropy_context entropy_; - mbedtls_ctr_drbg_context ctr_drbg_; -}; - -class SSLCertStoreImpl : public SSLCertStore { -public: - SSLCertStoreImpl() { - mbedtls_x509_crt_init(&chain_); - } - - ~SSLCertStoreImpl() override { - mbedtls_x509_crt_free(&chain_); - } - - bool init(Logger* logger, std::string const& path) { - auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str()); - if (ret) { - logerr(logger, ret, "Error loeading cert store bundle"); - return false; - } - return true; - } - - mbedtls_x509_crt* chain() { - return &chain_; - } - -private: - mbedtls_x509_crt chain_; -}; - -class SSLKeyImpl : public SSLKey { -public: - SSLKeyImpl() { - mbedtls_pk_init(&key_); - } - - ~SSLKeyImpl() override { - mbedtls_pk_free(&key_); - } - - bool load(Logger* logger, std::string const& data) { - auto ret = mbedtls_pk_parse_key( - &key_, - reinterpret_cast(data.c_str()), data.size() + 1, - nullptr, 0); - if (ret) { - logerr(logger, ret, "Error parsing key"); - return false; - } - return true; - } - - mbedtls_pk_context* key() { - return &key_; - } - -private: - mbedtls_pk_context key_; -}; - -class SSLCertImpl : public SSLCert { -public: - SSLCertImpl() { - mbedtls_x509_crt_init(&cert_); - } - - ~SSLCertImpl() override { - mbedtls_x509_crt_free(&cert_); - } - - bool load(Logger* logger, std::string const& data) { - auto ret = mbedtls_x509_crt_parse( - &cert_, - reinterpret_cast(data.c_str()), data.size() + 1); - if (ret < 0) { - logerr(logger, ret, "Error parsing cert"); - return false; - } - return true; - } - - mbedtls_x509_crt* cert() { - return &cert_; - } - -private: - mbedtls_x509_crt cert_; -}; - -template -class SSLImpl : public SSL { -public: - SSLImpl(Logger* logger, uint16_t flags) - : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr), - result_(NO_ERR), state_(HANDSHAKE) { - } - - ~SSLImpl() override { - mbedtls_ssl_free(&ssl_); - mbedtls_ssl_config_free(&conf_); - } - - bool unsecure() const { - return flags_ & UNSECURE; - } - - TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, - Buffer* data_in, Buffer* data_out) override { - if (result_) return result_; - - in_ = ssl_in; - out_ = ssl_out; - - result_ = transfer(data_in, data_out); - - in_ = nullptr; - out_ = nullptr; - - return result_; - } - - void close() override { - state_ = CLOSE; - } - -protected: - void init(SSLEntropy* entropy, int endpoint, SetupData const& data) { - mbedtls_ssl_init(&ssl_); - mbedtls_ssl_config_init(&conf_); - - auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint, - MBEDTLS_SSL_TRANSPORT_STREAM, - MBEDTLS_SSL_PRESET_DEFAULT); - if (ret) { - logerr(logger_, ret, "Error configuring SSL"); - result_ = ERR; - return; - } - - static_cast(entropy)->setup(&conf_); - - if (!before_setup(data)) { - result_ = ERR; - return; - } - - ret = mbedtls_ssl_setup(&ssl_, &conf_); - if (ret) { - logerr(logger_, ret, "Failed to setup SSL"); - result_ = ERR; - return; - } - - mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr); - - if (!after_setup(data)) { - result_ = ERR; - return; - } - - result_ = NO_ERR; - } - - virtual bool before_setup(SetupData const& UNUSED(data)) { - return true; - } - - virtual bool after_setup(SetupData const& UNUSED(data)) { - return true; - } - - Logger* const logger_; - mbedtls_ssl_context ssl_; - mbedtls_ssl_config conf_; - -private: - enum State { - HANDSHAKE, - TRANSFER, - CLOSE - }; - - TransferResult transfer(Buffer* in, Buffer* out) { - bool want_read = false, want_write = false; - while (!want_read || !want_write) { - switch (state_) { - case HANDSHAKE: { - auto ret = mbedtls_ssl_handshake(&ssl_); - if (ret) { - if (ret == MBEDTLS_ERR_SSL_WANT_READ - || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { - return NO_ERR; - } - loginfo(logger_, ret, "Handshake failed"); - return ERR; - } - state_ = TRANSFER; - break; - } - case CLOSE: { - auto ret = mbedtls_ssl_close_notify(&ssl_); - if (ret) { - if (ret == MBEDTLS_ERR_SSL_WANT_READ - || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { - return NO_ERR; - } - loginfo(logger_, ret, "Close notify failed"); - return ERR; - } - return CLOSED; - } - case TRANSFER: { - size_t avail; - auto wptr = out->write_ptr(&avail); - if (avail > 0) { - auto ret = mbedtls_ssl_read(&ssl_, - reinterpret_cast(wptr), - avail); - if (ret > 0) { - out->commit(ret); - } else if (ret == MBEDTLS_ERR_SSL_WANT_READ - || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { - want_read = true; - } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { - return CLOSED; - } else { - loginfo(logger_, ret, "SSL read failed"); - return ERR; - } - } else { - assert(false); - want_read = true; - } - auto rptr = in->read_ptr(&avail); - if (avail > 0) { - auto ret = mbedtls_ssl_write( - &ssl_, - reinterpret_cast(rptr), - avail); - if (ret > 0) { - in->consume(ret); - } else if (ret == MBEDTLS_ERR_SSL_WANT_READ - || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { - want_write = true; - } else { - loginfo(logger_, ret, "SSL write failed"); - return ERR; - } - } else { - want_write = true; - } - break; - } - } - } - return NO_ERR; - } - - int send(unsigned char const* buf, size_t len) { - if (!out_) { - assert(false); - return MBEDTLS_ERR_SSL_WANT_WRITE; - } - out_->write(buf, len); - return len; - } - int recv(unsigned char* buf, size_t len) { - if (!in_) { - assert(false); - return MBEDTLS_ERR_SSL_WANT_READ; - } - auto ret = in_->read(buf, len); - if (ret == 0) { - return MBEDTLS_ERR_SSL_WANT_READ; - } - return ret; - } - static int ssl_send(void* ctx, unsigned char const* buf, size_t len) { - return reinterpret_cast(ctx)->send(buf, len); - } - static int ssl_recv(void* ctx, unsigned char* buf, size_t len) { - return reinterpret_cast(ctx)->recv(buf, len); - } - - uint16_t const flags_; - Buffer* in_; - Buffer* out_; - TransferResult result_; - State state_; -}; - -struct CertKey { - SSLCert* cert; - SSLKey* key; - - CertKey(SSLCert* cert, SSLKey* key) - : cert(cert), key(key) { - } -}; - -class SSLServerImpl : public SSLImpl { -public: - SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert, - SSLKey* key, uint16_t flags) - : SSLImpl(logger, flags) { - SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key)); - } - - ~SSLServerImpl() override { - } - -private: - bool before_setup(CertKey const& data) override { - mbedtls_ssl_conf_ca_chain( - &conf_, static_cast(data.cert)->cert()->next, nullptr); - auto ret = mbedtls_ssl_conf_own_cert( - &conf_, - static_cast(data.cert)->cert(), - static_cast(data.key)->key()); - if (ret) { - logerr(logger_, ret, "Server: Error setting certificate"); - return false; - } - - mbedtls_ssl_conf_min_version(&conf_, - MBEDTLS_SSL_MAJOR_VERSION_3, - unsecure() ? MBEDTLS_SSL_MINOR_VERSION_0 : - MBEDTLS_SSL_MINOR_VERSION_1); - return true; - } -}; - -struct HostStore { - std::string const& host; - SSLCertStore* const store; - - HostStore(std::string const& host, SSLCertStore* store) - : host(host), store(store) { - } -}; - -class SSLClientImpl : public SSLImpl { -public: - SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, - std::string const& host, uint16_t flags) - : SSLImpl(logger, flags) { - SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store)); - } - -private: - bool before_setup(HostStore const& data) override { - mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL - : MBEDTLS_SSL_VERIFY_REQUIRED); - mbedtls_ssl_conf_ca_chain( - &conf_, static_cast(data.store)->chain(), nullptr); - return true; - } - - bool after_setup(HostStore const& data) override { - auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str()); - if (ret) { - logerr(logger_, ret, "Failed to set hostname"); - return false; - } - return true; - } -}; - -} // namespace - -// static -SSLEntropy* SSLEntropy::create(Logger* logger) { - std::unique_ptr entropy(new SSLEntropyImpl()); - if (!entropy->init(logger)) return nullptr; - return entropy.release(); -} - -// static -SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) { - std::unique_ptr store(new SSLCertStoreImpl()); - if (!store->init(logger, path)) return nullptr; - return store.release(); -} - -// static -SSLKey* SSLKey::load(Logger* logger, std::string const& data) { - std::unique_ptr key(new SSLKeyImpl()); - if (!key->load(logger, data)) return nullptr; - return key.release(); -} - -// static -bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) { - mbedtls_pk_context pk; - bool ok = false; - unsigned char buffer[16000]; - mbedtls_pk_init(&pk); - - auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); - if (ret) { - logerr(logger, ret, "Error setting up key"); - goto error; - } - - ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk), - mbedtls_ctr_drbg_random, - static_cast(entropy)->random(), - 4096, 65537); - if (ret) { - logerr(logger, ret, "Error generating key"); - goto error; - } - - ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer)); - if (ret) { - logerr(logger, ret, "Error writing key"); - goto error; - } - key->assign(reinterpret_cast(buffer)); - - ok = true; - error: - mbedtls_pk_free(&pk); - return ok; -} - -// static -SSLCert* SSLCert::load(Logger* logger, std::string const& data) { - std::unique_ptr key(new SSLCertImpl()); - if (!key->load(logger, data)) return nullptr; - return key.release(); -} - -// static -bool SSLCert::generate(Logger* logger, SSLEntropy* entropy, - SSLCert* issuer_cert, SSLKey* issuer_key, - std::string const& host, SSLKey* key, - std::string* cert) { - mbedtls_x509write_cert crt; - char issuer_name[256]; - std::string subject; - bool ok = false; - time_t now, tmp; - struct tm tm; - char not_before[20]; - char not_after[20]; - unsigned char buffer[16000]; - mbedtls_x509write_crt_init(&crt); - mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256); - - if (key) { - mbedtls_x509write_crt_set_subject_key( - &crt, static_cast(key)->key()); - } - if (issuer_key) { - mbedtls_x509write_crt_set_issuer_key( - &crt, static_cast(issuer_key)->key()); - } - - subject = "CN=" + host; - auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str()); - if (ret) { - logerr(logger, ret, "Invalid subject name"); - goto error; - } - - if (issuer_cert) { - ret = mbedtls_x509_dn_gets( - issuer_name, sizeof(issuer_name), - &static_cast(issuer_cert)->cert()->subject); - if (ret < 0) { - logerr(logger, ret, "Unable to get issuer name"); - goto error; - } - ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name); - } else { - ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str()); - } - if (ret) { - logerr(logger, ret, "Invalid issuer name"); - goto error; - } - now = time(nullptr); - tmp = now - (24 * 60 * 60 - 1); - gmtime_r(&tmp, &tm); - strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm); - tmp = now + (30 * 24 * 60 * 60); - gmtime_r(&tmp, &tm); - strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm); - ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after); - if (ret) { - logerr(logger, ret, "Unable to set validity"); - goto error; - } - if (issuer_cert) { - ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1); - } else { - ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1); - } - if (ret) { - logerr(logger, ret, "Unable to set basic constraints"); - goto error; - } - - ret = mbedtls_x509write_crt_pem( - &crt, buffer, sizeof(buffer), - mbedtls_ctr_drbg_random, - static_cast(entropy)->random()); - if (ret) { - logerr(logger, ret, "Error writing cert"); - goto error; - } - cert->assign(reinterpret_cast(buffer)); - - ok = true; - error: - mbedtls_x509write_crt_free(&crt); - return ok; -} - -// static -SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert, - SSLKey* key, uint16_t flags) { - return new SSLServerImpl(logger, entropy, cert, key, flags); -} - -// static -SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, - std::string const& host, uint16_t flags) { - return new SSLClientImpl(logger, entropy, store, host, flags); -} - -// static -const uint16_t SSL::UNSECURE = 0x01; - diff --git a/src/ssl_mbedtls.cc b/src/ssl_mbedtls.cc new file mode 100644 index 0000000..3395d83 --- /dev/null +++ b/src/ssl_mbedtls.cc @@ -0,0 +1,602 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "buffer.hh" +#include "logger.hh" +#include "ssl.hh" + +namespace { + +void logout(Logger* logger, Logger::Level lvl, int errnum, + const std::string& msg) { + char tmp[256]; + mbedtls_strerror(errnum, tmp, sizeof(tmp)); + logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp); +} + +inline void loginfo(Logger* logger, int errnum, const std::string& msg) { + logout(logger, Logger::INFO, errnum, msg); +} + +inline void logerr(Logger* logger, int errnum, const std::string& msg) { + logout(logger, Logger::ERR, errnum, msg); +} + +class SSLEntropyImpl : public SSLEntropy { +public: + SSLEntropyImpl() { + mbedtls_entropy_init(&entropy_); + mbedtls_ctr_drbg_init(&ctr_drbg_); + } + + ~SSLEntropyImpl() override { + mbedtls_ctr_drbg_free(&ctr_drbg_); + mbedtls_entropy_free(&entropy_); + } + + bool init(Logger* logger) { + auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, + &entropy_, nullptr, 0); + if (ret) { + logerr(logger, ret, "Error seeding entropy"); + return false; + } + return true; + } + + void setup(mbedtls_ssl_config* conf) { + mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_); + } + + mbedtls_ctr_drbg_context* random() { + return &ctr_drbg_; + } + +private: + mbedtls_entropy_context entropy_; + mbedtls_ctr_drbg_context ctr_drbg_; +}; + +class SSLCertStoreImpl : public SSLCertStore { +public: + SSLCertStoreImpl() { + mbedtls_x509_crt_init(&chain_); + } + + ~SSLCertStoreImpl() override { + mbedtls_x509_crt_free(&chain_); + } + + bool init(Logger* logger, std::string const& path) { + auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str()); + if (ret) { + logerr(logger, ret, "Error loeading cert store bundle"); + return false; + } + return true; + } + + mbedtls_x509_crt* chain() { + return &chain_; + } + +private: + mbedtls_x509_crt chain_; +}; + +class SSLKeyImpl : public SSLKey { +public: + SSLKeyImpl() { + mbedtls_pk_init(&key_); + } + + ~SSLKeyImpl() override { + mbedtls_pk_free(&key_); + } + + bool load(Logger* logger, std::string const& data) { + auto ret = mbedtls_pk_parse_key( + &key_, + reinterpret_cast(data.c_str()), data.size() + 1, + nullptr, 0); + if (ret) { + logerr(logger, ret, "Error parsing key"); + return false; + } + return true; + } + + mbedtls_pk_context* key() { + return &key_; + } + +private: + mbedtls_pk_context key_; +}; + +class SSLCertImpl : public SSLCert { +public: + SSLCertImpl() { + mbedtls_x509_crt_init(&cert_); + } + + ~SSLCertImpl() override { + mbedtls_x509_crt_free(&cert_); + } + + bool load(Logger* logger, std::string const& data) { + auto ret = mbedtls_x509_crt_parse( + &cert_, + reinterpret_cast(data.c_str()), data.size() + 1); + if (ret < 0) { + logerr(logger, ret, "Error parsing cert"); + return false; + } + return true; + } + + mbedtls_x509_crt* cert() { + return &cert_; + } + +private: + mbedtls_x509_crt cert_; +}; + +template +class SSLImpl : public SSL { +public: + SSLImpl(Logger* logger, uint16_t flags) + : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr), + result_(NO_ERR), state_(HANDSHAKE) { + } + + ~SSLImpl() override { + mbedtls_ssl_free(&ssl_); + mbedtls_ssl_config_free(&conf_); + } + + bool unsecure() const { + return flags_ & UNSECURE; + } + + TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, + Buffer* data_in, Buffer* data_out) override { + if (result_) return result_; + + in_ = ssl_in; + out_ = ssl_out; + + result_ = transfer(data_in, data_out); + + in_ = nullptr; + out_ = nullptr; + + return result_; + } + + void close() override { + state_ = CLOSE; + } + +protected: + void init(SSLEntropy* entropy, int endpoint, SetupData const& data) { + mbedtls_ssl_init(&ssl_); + mbedtls_ssl_config_init(&conf_); + + auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret) { + logerr(logger_, ret, "Error configuring SSL"); + result_ = ERR; + return; + } + + static_cast(entropy)->setup(&conf_); + + if (!before_setup(data)) { + result_ = ERR; + return; + } + + ret = mbedtls_ssl_setup(&ssl_, &conf_); + if (ret) { + logerr(logger_, ret, "Failed to setup SSL"); + result_ = ERR; + return; + } + + mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr); + + if (!after_setup(data)) { + result_ = ERR; + return; + } + + result_ = NO_ERR; + } + + virtual bool before_setup(SetupData const& UNUSED(data)) { + return true; + } + + virtual bool after_setup(SetupData const& UNUSED(data)) { + return true; + } + + Logger* const logger_; + mbedtls_ssl_context ssl_; + mbedtls_ssl_config conf_; + +private: + enum State { + HANDSHAKE, + TRANSFER, + CLOSE + }; + + TransferResult transfer(Buffer* in, Buffer* out) { + bool want_read = false, want_write = false; + while (!want_read || !want_write) { + switch (state_) { + case HANDSHAKE: { + auto ret = mbedtls_ssl_handshake(&ssl_); + if (ret) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return NO_ERR; + } + loginfo(logger_, ret, "Handshake failed"); + return ERR; + } + state_ = TRANSFER; + break; + } + case CLOSE: { + auto ret = mbedtls_ssl_close_notify(&ssl_); + if (ret) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return NO_ERR; + } + loginfo(logger_, ret, "Close notify failed"); + return ERR; + } + return CLOSED; + } + case TRANSFER: { + size_t avail; + auto wptr = out->write_ptr(&avail); + if (avail > 0) { + auto ret = mbedtls_ssl_read(&ssl_, + reinterpret_cast(wptr), + avail); + if (ret > 0) { + out->commit(ret); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + want_read = true; + } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + return CLOSED; + } else { + loginfo(logger_, ret, "SSL read failed"); + return ERR; + } + } else { + assert(false); + want_read = true; + } + auto rptr = in->read_ptr(&avail); + if (avail > 0) { + auto ret = mbedtls_ssl_write( + &ssl_, + reinterpret_cast(rptr), + avail); + if (ret > 0) { + in->consume(ret); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + want_write = true; + } else { + loginfo(logger_, ret, "SSL write failed"); + return ERR; + } + } else { + want_write = true; + } + break; + } + } + } + return NO_ERR; + } + + int send(unsigned char const* buf, size_t len) { + if (!out_) { + assert(false); + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + out_->write(buf, len); + return len; + } + int recv(unsigned char* buf, size_t len) { + if (!in_) { + assert(false); + return MBEDTLS_ERR_SSL_WANT_READ; + } + auto ret = in_->read(buf, len); + if (ret == 0) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return ret; + } + static int ssl_send(void* ctx, unsigned char const* buf, size_t len) { + return reinterpret_cast(ctx)->send(buf, len); + } + static int ssl_recv(void* ctx, unsigned char* buf, size_t len) { + return reinterpret_cast(ctx)->recv(buf, len); + } + + uint16_t const flags_; + Buffer* in_; + Buffer* out_; + TransferResult result_; + State state_; +}; + +struct CertKey { + SSLCert* cert; + SSLKey* key; + + CertKey(SSLCert* cert, SSLKey* key) + : cert(cert), key(key) { + } +}; + +class SSLServerImpl : public SSLImpl { +public: + SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert, + SSLKey* key, uint16_t flags) + : SSLImpl(logger, flags) { + SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key)); + } + + ~SSLServerImpl() override { + } + +private: + bool before_setup(CertKey const& data) override { + mbedtls_ssl_conf_ca_chain( + &conf_, static_cast(data.cert)->cert()->next, nullptr); + auto ret = mbedtls_ssl_conf_own_cert( + &conf_, + static_cast(data.cert)->cert(), + static_cast(data.key)->key()); + if (ret) { + logerr(logger_, ret, "Server: Error setting certificate"); + return false; + } + + mbedtls_ssl_conf_min_version(&conf_, + MBEDTLS_SSL_MAJOR_VERSION_3, + unsecure() ? MBEDTLS_SSL_MINOR_VERSION_0 : + MBEDTLS_SSL_MINOR_VERSION_1); + return true; + } +}; + +struct HostStore { + std::string const& host; + SSLCertStore* const store; + + HostStore(std::string const& host, SSLCertStore* store) + : host(host), store(store) { + } +}; + +class SSLClientImpl : public SSLImpl { +public: + SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) + : SSLImpl(logger, flags) { + SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store)); + } + +private: + bool before_setup(HostStore const& data) override { + mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL + : MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_ssl_conf_ca_chain( + &conf_, static_cast(data.store)->chain(), nullptr); + return true; + } + + bool after_setup(HostStore const& data) override { + auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str()); + if (ret) { + logerr(logger_, ret, "Failed to set hostname"); + return false; + } + return true; + } +}; + +} // namespace + +// static +SSLEntropy* SSLEntropy::create(Logger* logger) { + std::unique_ptr entropy(new SSLEntropyImpl()); + if (!entropy->init(logger)) return nullptr; + return entropy.release(); +} + +// static +SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) { + std::unique_ptr store(new SSLCertStoreImpl()); + if (!store->init(logger, path)) return nullptr; + return store.release(); +} + +// static +SSLKey* SSLKey::load(Logger* logger, std::string const& data) { + std::unique_ptr key(new SSLKeyImpl()); + if (!key->load(logger, data)) return nullptr; + return key.release(); +} + +// static +bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) { + mbedtls_pk_context pk; + bool ok = false; + unsigned char buffer[16000]; + mbedtls_pk_init(&pk); + + auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); + if (ret) { + logerr(logger, ret, "Error setting up key"); + goto error; + } + + ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk), + mbedtls_ctr_drbg_random, + static_cast(entropy)->random(), + 4096, 65537); + if (ret) { + logerr(logger, ret, "Error generating key"); + goto error; + } + + ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer)); + if (ret) { + logerr(logger, ret, "Error writing key"); + goto error; + } + key->assign(reinterpret_cast(buffer)); + + ok = true; + error: + mbedtls_pk_free(&pk); + return ok; +} + +// static +SSLCert* SSLCert::load(Logger* logger, std::string const& data) { + std::unique_ptr key(new SSLCertImpl()); + if (!key->load(logger, data)) return nullptr; + return key.release(); +} + +// static +bool SSLCert::generate(Logger* logger, SSLEntropy* entropy, + SSLCert* issuer_cert, SSLKey* issuer_key, + std::string const& host, SSLKey* key, + std::string* cert) { + mbedtls_x509write_cert crt; + char issuer_name[256]; + std::string subject; + bool ok = false; + time_t now, tmp; + struct tm tm; + char not_before[20]; + char not_after[20]; + unsigned char buffer[16000]; + mbedtls_x509write_crt_init(&crt); + mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256); + + if (key) { + mbedtls_x509write_crt_set_subject_key( + &crt, static_cast(key)->key()); + } + if (issuer_key) { + mbedtls_x509write_crt_set_issuer_key( + &crt, static_cast(issuer_key)->key()); + } + + subject = "CN=" + host; + auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str()); + if (ret) { + logerr(logger, ret, "Invalid subject name"); + goto error; + } + + if (issuer_cert) { + ret = mbedtls_x509_dn_gets( + issuer_name, sizeof(issuer_name), + &static_cast(issuer_cert)->cert()->subject); + if (ret < 0) { + logerr(logger, ret, "Unable to get issuer name"); + goto error; + } + ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name); + } else { + ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str()); + } + if (ret) { + logerr(logger, ret, "Invalid issuer name"); + goto error; + } + now = time(nullptr); + tmp = now - (24 * 60 * 60 - 1); + gmtime_r(&tmp, &tm); + strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm); + tmp = now + (30 * 24 * 60 * 60); + gmtime_r(&tmp, &tm); + strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm); + ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after); + if (ret) { + logerr(logger, ret, "Unable to set validity"); + goto error; + } + if (issuer_cert) { + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1); + } else { + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1); + } + if (ret) { + logerr(logger, ret, "Unable to set basic constraints"); + goto error; + } + + ret = mbedtls_x509write_crt_pem( + &crt, buffer, sizeof(buffer), + mbedtls_ctr_drbg_random, + static_cast(entropy)->random()); + if (ret) { + logerr(logger, ret, "Error writing cert"); + goto error; + } + cert->assign(reinterpret_cast(buffer)); + + ok = true; + error: + mbedtls_x509write_crt_free(&crt); + return ok; +} + +// static +SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert, + SSLKey* key, uint16_t flags) { + return new SSLServerImpl(logger, entropy, cert, key, flags); +} + +// static +SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) { + return new SSLClientImpl(logger, entropy, store, host, flags); +} + +// static +const uint16_t SSL::UNSECURE = 0x01; + -- cgit v1.2.3-70-g09d2