diff options
Diffstat (limited to 'src/ssl_mbedtls.cc')
| -rw-r--r-- | src/ssl_mbedtls.cc | 602 |
1 files changed, 602 insertions, 0 deletions
diff --git a/src/ssl_mbedtls.cc b/src/ssl_mbedtls.cc new file mode 100644 index 0000000..3395d83 --- /dev/null +++ b/src/ssl_mbedtls.cc @@ -0,0 +1,602 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <mbedtls/certs.h> +#include <mbedtls/ctr_drbg.h> +#include <mbedtls/entropy.h> +#include <mbedtls/error.h> +#include <mbedtls/pk.h> +#include <mbedtls/rsa.h> +#include <mbedtls/ssl.h> +#include <mbedtls/x509.h> +#include <memory> + +#include "buffer.hh" +#include "logger.hh" +#include "ssl.hh" + +namespace { + +void logout(Logger* logger, Logger::Level lvl, int errnum, + const std::string& msg) { + char tmp[256]; + mbedtls_strerror(errnum, tmp, sizeof(tmp)); + logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp); +} + +inline void loginfo(Logger* logger, int errnum, const std::string& msg) { + logout(logger, Logger::INFO, errnum, msg); +} + +inline void logerr(Logger* logger, int errnum, const std::string& msg) { + logout(logger, Logger::ERR, errnum, msg); +} + +class SSLEntropyImpl : public SSLEntropy { +public: + SSLEntropyImpl() { + mbedtls_entropy_init(&entropy_); + mbedtls_ctr_drbg_init(&ctr_drbg_); + } + + ~SSLEntropyImpl() override { + mbedtls_ctr_drbg_free(&ctr_drbg_); + mbedtls_entropy_free(&entropy_); + } + + bool init(Logger* logger) { + auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, + &entropy_, nullptr, 0); + if (ret) { + logerr(logger, ret, "Error seeding entropy"); + return false; + } + return true; + } + + void setup(mbedtls_ssl_config* conf) { + mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_); + } + + mbedtls_ctr_drbg_context* random() { + return &ctr_drbg_; + } + +private: + mbedtls_entropy_context entropy_; + mbedtls_ctr_drbg_context ctr_drbg_; +}; + +class SSLCertStoreImpl : public SSLCertStore { +public: + SSLCertStoreImpl() { + mbedtls_x509_crt_init(&chain_); + } + + ~SSLCertStoreImpl() override { + mbedtls_x509_crt_free(&chain_); + } + + bool init(Logger* logger, std::string const& path) { + auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str()); + if (ret) { + logerr(logger, ret, "Error loeading cert store bundle"); + return false; + } + return true; + } + + mbedtls_x509_crt* chain() { + return &chain_; + } + +private: + mbedtls_x509_crt chain_; +}; + +class SSLKeyImpl : public SSLKey { +public: + SSLKeyImpl() { + mbedtls_pk_init(&key_); + } + + ~SSLKeyImpl() override { + mbedtls_pk_free(&key_); + } + + bool load(Logger* logger, std::string const& data) { + auto ret = mbedtls_pk_parse_key( + &key_, + reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1, + nullptr, 0); + if (ret) { + logerr(logger, ret, "Error parsing key"); + return false; + } + return true; + } + + mbedtls_pk_context* key() { + return &key_; + } + +private: + mbedtls_pk_context key_; +}; + +class SSLCertImpl : public SSLCert { +public: + SSLCertImpl() { + mbedtls_x509_crt_init(&cert_); + } + + ~SSLCertImpl() override { + mbedtls_x509_crt_free(&cert_); + } + + bool load(Logger* logger, std::string const& data) { + auto ret = mbedtls_x509_crt_parse( + &cert_, + reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1); + if (ret < 0) { + logerr(logger, ret, "Error parsing cert"); + return false; + } + return true; + } + + mbedtls_x509_crt* cert() { + return &cert_; + } + +private: + mbedtls_x509_crt cert_; +}; + +template<typename SetupData> +class SSLImpl : public SSL { +public: + SSLImpl(Logger* logger, uint16_t flags) + : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr), + result_(NO_ERR), state_(HANDSHAKE) { + } + + ~SSLImpl() override { + mbedtls_ssl_free(&ssl_); + mbedtls_ssl_config_free(&conf_); + } + + bool unsecure() const { + return flags_ & UNSECURE; + } + + TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, + Buffer* data_in, Buffer* data_out) override { + if (result_) return result_; + + in_ = ssl_in; + out_ = ssl_out; + + result_ = transfer(data_in, data_out); + + in_ = nullptr; + out_ = nullptr; + + return result_; + } + + void close() override { + state_ = CLOSE; + } + +protected: + void init(SSLEntropy* entropy, int endpoint, SetupData const& data) { + mbedtls_ssl_init(&ssl_); + mbedtls_ssl_config_init(&conf_); + + auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret) { + logerr(logger_, ret, "Error configuring SSL"); + result_ = ERR; + return; + } + + static_cast<SSLEntropyImpl*>(entropy)->setup(&conf_); + + if (!before_setup(data)) { + result_ = ERR; + return; + } + + ret = mbedtls_ssl_setup(&ssl_, &conf_); + if (ret) { + logerr(logger_, ret, "Failed to setup SSL"); + result_ = ERR; + return; + } + + mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr); + + if (!after_setup(data)) { + result_ = ERR; + return; + } + + result_ = NO_ERR; + } + + virtual bool before_setup(SetupData const& UNUSED(data)) { + return true; + } + + virtual bool after_setup(SetupData const& UNUSED(data)) { + return true; + } + + Logger* const logger_; + mbedtls_ssl_context ssl_; + mbedtls_ssl_config conf_; + +private: + enum State { + HANDSHAKE, + TRANSFER, + CLOSE + }; + + TransferResult transfer(Buffer* in, Buffer* out) { + bool want_read = false, want_write = false; + while (!want_read || !want_write) { + switch (state_) { + case HANDSHAKE: { + auto ret = mbedtls_ssl_handshake(&ssl_); + if (ret) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return NO_ERR; + } + loginfo(logger_, ret, "Handshake failed"); + return ERR; + } + state_ = TRANSFER; + break; + } + case CLOSE: { + auto ret = mbedtls_ssl_close_notify(&ssl_); + if (ret) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + return NO_ERR; + } + loginfo(logger_, ret, "Close notify failed"); + return ERR; + } + return CLOSED; + } + case TRANSFER: { + size_t avail; + auto wptr = out->write_ptr(&avail); + if (avail > 0) { + auto ret = mbedtls_ssl_read(&ssl_, + reinterpret_cast<unsigned char*>(wptr), + avail); + if (ret > 0) { + out->commit(ret); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + want_read = true; + } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + return CLOSED; + } else { + loginfo(logger_, ret, "SSL read failed"); + return ERR; + } + } else { + assert(false); + want_read = true; + } + auto rptr = in->read_ptr(&avail); + if (avail > 0) { + auto ret = mbedtls_ssl_write( + &ssl_, + reinterpret_cast<unsigned char const*>(rptr), + avail); + if (ret > 0) { + in->consume(ret); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ + || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + want_write = true; + } else { + loginfo(logger_, ret, "SSL write failed"); + return ERR; + } + } else { + want_write = true; + } + break; + } + } + } + return NO_ERR; + } + + int send(unsigned char const* buf, size_t len) { + if (!out_) { + assert(false); + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + out_->write(buf, len); + return len; + } + int recv(unsigned char* buf, size_t len) { + if (!in_) { + assert(false); + return MBEDTLS_ERR_SSL_WANT_READ; + } + auto ret = in_->read(buf, len); + if (ret == 0) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return ret; + } + static int ssl_send(void* ctx, unsigned char const* buf, size_t len) { + return reinterpret_cast<SSLImpl*>(ctx)->send(buf, len); + } + static int ssl_recv(void* ctx, unsigned char* buf, size_t len) { + return reinterpret_cast<SSLImpl*>(ctx)->recv(buf, len); + } + + uint16_t const flags_; + Buffer* in_; + Buffer* out_; + TransferResult result_; + State state_; +}; + +struct CertKey { + SSLCert* cert; + SSLKey* key; + + CertKey(SSLCert* cert, SSLKey* key) + : cert(cert), key(key) { + } +}; + +class SSLServerImpl : public SSLImpl<CertKey> { +public: + SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert, + SSLKey* key, uint16_t flags) + : SSLImpl(logger, flags) { + SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key)); + } + + ~SSLServerImpl() override { + } + +private: + bool before_setup(CertKey const& data) override { + mbedtls_ssl_conf_ca_chain( + &conf_, static_cast<SSLCertImpl*>(data.cert)->cert()->next, nullptr); + auto ret = mbedtls_ssl_conf_own_cert( + &conf_, + static_cast<SSLCertImpl*>(data.cert)->cert(), + static_cast<SSLKeyImpl*>(data.key)->key()); + if (ret) { + logerr(logger_, ret, "Server: Error setting certificate"); + return false; + } + + mbedtls_ssl_conf_min_version(&conf_, + MBEDTLS_SSL_MAJOR_VERSION_3, + unsecure() ? MBEDTLS_SSL_MINOR_VERSION_0 : + MBEDTLS_SSL_MINOR_VERSION_1); + return true; + } +}; + +struct HostStore { + std::string const& host; + SSLCertStore* const store; + + HostStore(std::string const& host, SSLCertStore* store) + : host(host), store(store) { + } +}; + +class SSLClientImpl : public SSLImpl<HostStore> { +public: + SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) + : SSLImpl(logger, flags) { + SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store)); + } + +private: + bool before_setup(HostStore const& data) override { + mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL + : MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_ssl_conf_ca_chain( + &conf_, static_cast<SSLCertStoreImpl*>(data.store)->chain(), nullptr); + return true; + } + + bool after_setup(HostStore const& data) override { + auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str()); + if (ret) { + logerr(logger_, ret, "Failed to set hostname"); + return false; + } + return true; + } +}; + +} // namespace + +// static +SSLEntropy* SSLEntropy::create(Logger* logger) { + std::unique_ptr<SSLEntropyImpl> entropy(new SSLEntropyImpl()); + if (!entropy->init(logger)) return nullptr; + return entropy.release(); +} + +// static +SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) { + std::unique_ptr<SSLCertStoreImpl> store(new SSLCertStoreImpl()); + if (!store->init(logger, path)) return nullptr; + return store.release(); +} + +// static +SSLKey* SSLKey::load(Logger* logger, std::string const& data) { + std::unique_ptr<SSLKeyImpl> key(new SSLKeyImpl()); + if (!key->load(logger, data)) return nullptr; + return key.release(); +} + +// static +bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) { + mbedtls_pk_context pk; + bool ok = false; + unsigned char buffer[16000]; + mbedtls_pk_init(&pk); + + auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); + if (ret) { + logerr(logger, ret, "Error setting up key"); + goto error; + } + + ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk), + mbedtls_ctr_drbg_random, + static_cast<SSLEntropyImpl*>(entropy)->random(), + 4096, 65537); + if (ret) { + logerr(logger, ret, "Error generating key"); + goto error; + } + + ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer)); + if (ret) { + logerr(logger, ret, "Error writing key"); + goto error; + } + key->assign(reinterpret_cast<char*>(buffer)); + + ok = true; + error: + mbedtls_pk_free(&pk); + return ok; +} + +// static +SSLCert* SSLCert::load(Logger* logger, std::string const& data) { + std::unique_ptr<SSLCertImpl> key(new SSLCertImpl()); + if (!key->load(logger, data)) return nullptr; + return key.release(); +} + +// static +bool SSLCert::generate(Logger* logger, SSLEntropy* entropy, + SSLCert* issuer_cert, SSLKey* issuer_key, + std::string const& host, SSLKey* key, + std::string* cert) { + mbedtls_x509write_cert crt; + char issuer_name[256]; + std::string subject; + bool ok = false; + time_t now, tmp; + struct tm tm; + char not_before[20]; + char not_after[20]; + unsigned char buffer[16000]; + mbedtls_x509write_crt_init(&crt); + mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256); + + if (key) { + mbedtls_x509write_crt_set_subject_key( + &crt, static_cast<SSLKeyImpl*>(key)->key()); + } + if (issuer_key) { + mbedtls_x509write_crt_set_issuer_key( + &crt, static_cast<SSLKeyImpl*>(issuer_key)->key()); + } + + subject = "CN=" + host; + auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str()); + if (ret) { + logerr(logger, ret, "Invalid subject name"); + goto error; + } + + if (issuer_cert) { + ret = mbedtls_x509_dn_gets( + issuer_name, sizeof(issuer_name), + &static_cast<SSLCertImpl*>(issuer_cert)->cert()->subject); + if (ret < 0) { + logerr(logger, ret, "Unable to get issuer name"); + goto error; + } + ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name); + } else { + ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str()); + } + if (ret) { + logerr(logger, ret, "Invalid issuer name"); + goto error; + } + now = time(nullptr); + tmp = now - (24 * 60 * 60 - 1); + gmtime_r(&tmp, &tm); + strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm); + tmp = now + (30 * 24 * 60 * 60); + gmtime_r(&tmp, &tm); + strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm); + ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after); + if (ret) { + logerr(logger, ret, "Unable to set validity"); + goto error; + } + if (issuer_cert) { + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1); + } else { + ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1); + } + if (ret) { + logerr(logger, ret, "Unable to set basic constraints"); + goto error; + } + + ret = mbedtls_x509write_crt_pem( + &crt, buffer, sizeof(buffer), + mbedtls_ctr_drbg_random, + static_cast<SSLEntropyImpl*>(entropy)->random()); + if (ret) { + logerr(logger, ret, "Error writing cert"); + goto error; + } + cert->assign(reinterpret_cast<char*>(buffer)); + + ok = true; + error: + mbedtls_x509write_crt_free(&crt); + return ok; +} + +// static +SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert, + SSLKey* key, uint16_t flags) { + return new SSLServerImpl(logger, entropy, cert, key, flags); +} + +// static +SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) { + return new SSLClientImpl(logger, entropy, store, host, flags); +} + +// static +const uint16_t SSL::UNSECURE = 0x01; + |
