summaryrefslogtreecommitdiff
path: root/src/ssl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/ssl.cc')
-rw-r--r--src/ssl.cc602
1 files changed, 602 insertions, 0 deletions
diff --git a/src/ssl.cc b/src/ssl.cc
new file mode 100644
index 0000000..3395d83
--- /dev/null
+++ b/src/ssl.cc
@@ -0,0 +1,602 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <mbedtls/certs.h>
+#include <mbedtls/ctr_drbg.h>
+#include <mbedtls/entropy.h>
+#include <mbedtls/error.h>
+#include <mbedtls/pk.h>
+#include <mbedtls/rsa.h>
+#include <mbedtls/ssl.h>
+#include <mbedtls/x509.h>
+#include <memory>
+
+#include "buffer.hh"
+#include "logger.hh"
+#include "ssl.hh"
+
+namespace {
+
+void logout(Logger* logger, Logger::Level lvl, int errnum,
+ const std::string& msg) {
+ char tmp[256];
+ mbedtls_strerror(errnum, tmp, sizeof(tmp));
+ logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp);
+}
+
+inline void loginfo(Logger* logger, int errnum, const std::string& msg) {
+ logout(logger, Logger::INFO, errnum, msg);
+}
+
+inline void logerr(Logger* logger, int errnum, const std::string& msg) {
+ logout(logger, Logger::ERR, errnum, msg);
+}
+
+class SSLEntropyImpl : public SSLEntropy {
+public:
+ SSLEntropyImpl() {
+ mbedtls_entropy_init(&entropy_);
+ mbedtls_ctr_drbg_init(&ctr_drbg_);
+ }
+
+ ~SSLEntropyImpl() override {
+ mbedtls_ctr_drbg_free(&ctr_drbg_);
+ mbedtls_entropy_free(&entropy_);
+ }
+
+ bool init(Logger* logger) {
+ auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func,
+ &entropy_, nullptr, 0);
+ if (ret) {
+ logerr(logger, ret, "Error seeding entropy");
+ return false;
+ }
+ return true;
+ }
+
+ void setup(mbedtls_ssl_config* conf) {
+ mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_);
+ }
+
+ mbedtls_ctr_drbg_context* random() {
+ return &ctr_drbg_;
+ }
+
+private:
+ mbedtls_entropy_context entropy_;
+ mbedtls_ctr_drbg_context ctr_drbg_;
+};
+
+class SSLCertStoreImpl : public SSLCertStore {
+public:
+ SSLCertStoreImpl() {
+ mbedtls_x509_crt_init(&chain_);
+ }
+
+ ~SSLCertStoreImpl() override {
+ mbedtls_x509_crt_free(&chain_);
+ }
+
+ bool init(Logger* logger, std::string const& path) {
+ auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str());
+ if (ret) {
+ logerr(logger, ret, "Error loeading cert store bundle");
+ return false;
+ }
+ return true;
+ }
+
+ mbedtls_x509_crt* chain() {
+ return &chain_;
+ }
+
+private:
+ mbedtls_x509_crt chain_;
+};
+
+class SSLKeyImpl : public SSLKey {
+public:
+ SSLKeyImpl() {
+ mbedtls_pk_init(&key_);
+ }
+
+ ~SSLKeyImpl() override {
+ mbedtls_pk_free(&key_);
+ }
+
+ bool load(Logger* logger, std::string const& data) {
+ auto ret = mbedtls_pk_parse_key(
+ &key_,
+ reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1,
+ nullptr, 0);
+ if (ret) {
+ logerr(logger, ret, "Error parsing key");
+ return false;
+ }
+ return true;
+ }
+
+ mbedtls_pk_context* key() {
+ return &key_;
+ }
+
+private:
+ mbedtls_pk_context key_;
+};
+
+class SSLCertImpl : public SSLCert {
+public:
+ SSLCertImpl() {
+ mbedtls_x509_crt_init(&cert_);
+ }
+
+ ~SSLCertImpl() override {
+ mbedtls_x509_crt_free(&cert_);
+ }
+
+ bool load(Logger* logger, std::string const& data) {
+ auto ret = mbedtls_x509_crt_parse(
+ &cert_,
+ reinterpret_cast<const unsigned char*>(data.c_str()), data.size() + 1);
+ if (ret < 0) {
+ logerr(logger, ret, "Error parsing cert");
+ return false;
+ }
+ return true;
+ }
+
+ mbedtls_x509_crt* cert() {
+ return &cert_;
+ }
+
+private:
+ mbedtls_x509_crt cert_;
+};
+
+template<typename SetupData>
+class SSLImpl : public SSL {
+public:
+ SSLImpl(Logger* logger, uint16_t flags)
+ : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr),
+ result_(NO_ERR), state_(HANDSHAKE) {
+ }
+
+ ~SSLImpl() override {
+ mbedtls_ssl_free(&ssl_);
+ mbedtls_ssl_config_free(&conf_);
+ }
+
+ bool unsecure() const {
+ return flags_ & UNSECURE;
+ }
+
+ TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out,
+ Buffer* data_in, Buffer* data_out) override {
+ if (result_) return result_;
+
+ in_ = ssl_in;
+ out_ = ssl_out;
+
+ result_ = transfer(data_in, data_out);
+
+ in_ = nullptr;
+ out_ = nullptr;
+
+ return result_;
+ }
+
+ void close() override {
+ state_ = CLOSE;
+ }
+
+protected:
+ void init(SSLEntropy* entropy, int endpoint, SetupData const& data) {
+ mbedtls_ssl_init(&ssl_);
+ mbedtls_ssl_config_init(&conf_);
+
+ auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint,
+ MBEDTLS_SSL_TRANSPORT_STREAM,
+ MBEDTLS_SSL_PRESET_DEFAULT);
+ if (ret) {
+ logerr(logger_, ret, "Error configuring SSL");
+ result_ = ERR;
+ return;
+ }
+
+ static_cast<SSLEntropyImpl*>(entropy)->setup(&conf_);
+
+ if (!before_setup(data)) {
+ result_ = ERR;
+ return;
+ }
+
+ ret = mbedtls_ssl_setup(&ssl_, &conf_);
+ if (ret) {
+ logerr(logger_, ret, "Failed to setup SSL");
+ result_ = ERR;
+ return;
+ }
+
+ mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr);
+
+ if (!after_setup(data)) {
+ result_ = ERR;
+ return;
+ }
+
+ result_ = NO_ERR;
+ }
+
+ virtual bool before_setup(SetupData const& UNUSED(data)) {
+ return true;
+ }
+
+ virtual bool after_setup(SetupData const& UNUSED(data)) {
+ return true;
+ }
+
+ Logger* const logger_;
+ mbedtls_ssl_context ssl_;
+ mbedtls_ssl_config conf_;
+
+private:
+ enum State {
+ HANDSHAKE,
+ TRANSFER,
+ CLOSE
+ };
+
+ TransferResult transfer(Buffer* in, Buffer* out) {
+ bool want_read = false, want_write = false;
+ while (!want_read || !want_write) {
+ switch (state_) {
+ case HANDSHAKE: {
+ auto ret = mbedtls_ssl_handshake(&ssl_);
+ if (ret) {
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ return NO_ERR;
+ }
+ loginfo(logger_, ret, "Handshake failed");
+ return ERR;
+ }
+ state_ = TRANSFER;
+ break;
+ }
+ case CLOSE: {
+ auto ret = mbedtls_ssl_close_notify(&ssl_);
+ if (ret) {
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ return NO_ERR;
+ }
+ loginfo(logger_, ret, "Close notify failed");
+ return ERR;
+ }
+ return CLOSED;
+ }
+ case TRANSFER: {
+ size_t avail;
+ auto wptr = out->write_ptr(&avail);
+ if (avail > 0) {
+ auto ret = mbedtls_ssl_read(&ssl_,
+ reinterpret_cast<unsigned char*>(wptr),
+ avail);
+ if (ret > 0) {
+ out->commit(ret);
+ } else if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ want_read = true;
+ } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
+ return CLOSED;
+ } else {
+ loginfo(logger_, ret, "SSL read failed");
+ return ERR;
+ }
+ } else {
+ assert(false);
+ want_read = true;
+ }
+ auto rptr = in->read_ptr(&avail);
+ if (avail > 0) {
+ auto ret = mbedtls_ssl_write(
+ &ssl_,
+ reinterpret_cast<unsigned char const*>(rptr),
+ avail);
+ if (ret > 0) {
+ in->consume(ret);
+ } else if (ret == MBEDTLS_ERR_SSL_WANT_READ
+ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ want_write = true;
+ } else {
+ loginfo(logger_, ret, "SSL write failed");
+ return ERR;
+ }
+ } else {
+ want_write = true;
+ }
+ break;
+ }
+ }
+ }
+ return NO_ERR;
+ }
+
+ int send(unsigned char const* buf, size_t len) {
+ if (!out_) {
+ assert(false);
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+ }
+ out_->write(buf, len);
+ return len;
+ }
+ int recv(unsigned char* buf, size_t len) {
+ if (!in_) {
+ assert(false);
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
+ auto ret = in_->read(buf, len);
+ if (ret == 0) {
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
+ return ret;
+ }
+ static int ssl_send(void* ctx, unsigned char const* buf, size_t len) {
+ return reinterpret_cast<SSLImpl*>(ctx)->send(buf, len);
+ }
+ static int ssl_recv(void* ctx, unsigned char* buf, size_t len) {
+ return reinterpret_cast<SSLImpl*>(ctx)->recv(buf, len);
+ }
+
+ uint16_t const flags_;
+ Buffer* in_;
+ Buffer* out_;
+ TransferResult result_;
+ State state_;
+};
+
+struct CertKey {
+ SSLCert* cert;
+ SSLKey* key;
+
+ CertKey(SSLCert* cert, SSLKey* key)
+ : cert(cert), key(key) {
+ }
+};
+
+class SSLServerImpl : public SSLImpl<CertKey> {
+public:
+ SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert,
+ SSLKey* key, uint16_t flags)
+ : SSLImpl(logger, flags) {
+ SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key));
+ }
+
+ ~SSLServerImpl() override {
+ }
+
+private:
+ bool before_setup(CertKey const& data) override {
+ mbedtls_ssl_conf_ca_chain(
+ &conf_, static_cast<SSLCertImpl*>(data.cert)->cert()->next, nullptr);
+ auto ret = mbedtls_ssl_conf_own_cert(
+ &conf_,
+ static_cast<SSLCertImpl*>(data.cert)->cert(),
+ static_cast<SSLKeyImpl*>(data.key)->key());
+ if (ret) {
+ logerr(logger_, ret, "Server: Error setting certificate");
+ return false;
+ }
+
+ mbedtls_ssl_conf_min_version(&conf_,
+ MBEDTLS_SSL_MAJOR_VERSION_3,
+ unsecure() ? MBEDTLS_SSL_MINOR_VERSION_0 :
+ MBEDTLS_SSL_MINOR_VERSION_1);
+ return true;
+ }
+};
+
+struct HostStore {
+ std::string const& host;
+ SSLCertStore* const store;
+
+ HostStore(std::string const& host, SSLCertStore* store)
+ : host(host), store(store) {
+ }
+};
+
+class SSLClientImpl : public SSLImpl<HostStore> {
+public:
+ SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store,
+ std::string const& host, uint16_t flags)
+ : SSLImpl(logger, flags) {
+ SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store));
+ }
+
+private:
+ bool before_setup(HostStore const& data) override {
+ mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL
+ : MBEDTLS_SSL_VERIFY_REQUIRED);
+ mbedtls_ssl_conf_ca_chain(
+ &conf_, static_cast<SSLCertStoreImpl*>(data.store)->chain(), nullptr);
+ return true;
+ }
+
+ bool after_setup(HostStore const& data) override {
+ auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str());
+ if (ret) {
+ logerr(logger_, ret, "Failed to set hostname");
+ return false;
+ }
+ return true;
+ }
+};
+
+} // namespace
+
+// static
+SSLEntropy* SSLEntropy::create(Logger* logger) {
+ std::unique_ptr<SSLEntropyImpl> entropy(new SSLEntropyImpl());
+ if (!entropy->init(logger)) return nullptr;
+ return entropy.release();
+}
+
+// static
+SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) {
+ std::unique_ptr<SSLCertStoreImpl> store(new SSLCertStoreImpl());
+ if (!store->init(logger, path)) return nullptr;
+ return store.release();
+}
+
+// static
+SSLKey* SSLKey::load(Logger* logger, std::string const& data) {
+ std::unique_ptr<SSLKeyImpl> key(new SSLKeyImpl());
+ if (!key->load(logger, data)) return nullptr;
+ return key.release();
+}
+
+// static
+bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) {
+ mbedtls_pk_context pk;
+ bool ok = false;
+ unsigned char buffer[16000];
+ mbedtls_pk_init(&pk);
+
+ auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA));
+ if (ret) {
+ logerr(logger, ret, "Error setting up key");
+ goto error;
+ }
+
+ ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk),
+ mbedtls_ctr_drbg_random,
+ static_cast<SSLEntropyImpl*>(entropy)->random(),
+ 4096, 65537);
+ if (ret) {
+ logerr(logger, ret, "Error generating key");
+ goto error;
+ }
+
+ ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer));
+ if (ret) {
+ logerr(logger, ret, "Error writing key");
+ goto error;
+ }
+ key->assign(reinterpret_cast<char*>(buffer));
+
+ ok = true;
+ error:
+ mbedtls_pk_free(&pk);
+ return ok;
+}
+
+// static
+SSLCert* SSLCert::load(Logger* logger, std::string const& data) {
+ std::unique_ptr<SSLCertImpl> key(new SSLCertImpl());
+ if (!key->load(logger, data)) return nullptr;
+ return key.release();
+}
+
+// static
+bool SSLCert::generate(Logger* logger, SSLEntropy* entropy,
+ SSLCert* issuer_cert, SSLKey* issuer_key,
+ std::string const& host, SSLKey* key,
+ std::string* cert) {
+ mbedtls_x509write_cert crt;
+ char issuer_name[256];
+ std::string subject;
+ bool ok = false;
+ time_t now, tmp;
+ struct tm tm;
+ char not_before[20];
+ char not_after[20];
+ unsigned char buffer[16000];
+ mbedtls_x509write_crt_init(&crt);
+ mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256);
+
+ if (key) {
+ mbedtls_x509write_crt_set_subject_key(
+ &crt, static_cast<SSLKeyImpl*>(key)->key());
+ }
+ if (issuer_key) {
+ mbedtls_x509write_crt_set_issuer_key(
+ &crt, static_cast<SSLKeyImpl*>(issuer_key)->key());
+ }
+
+ subject = "CN=" + host;
+ auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str());
+ if (ret) {
+ logerr(logger, ret, "Invalid subject name");
+ goto error;
+ }
+
+ if (issuer_cert) {
+ ret = mbedtls_x509_dn_gets(
+ issuer_name, sizeof(issuer_name),
+ &static_cast<SSLCertImpl*>(issuer_cert)->cert()->subject);
+ if (ret < 0) {
+ logerr(logger, ret, "Unable to get issuer name");
+ goto error;
+ }
+ ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name);
+ } else {
+ ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str());
+ }
+ if (ret) {
+ logerr(logger, ret, "Invalid issuer name");
+ goto error;
+ }
+ now = time(nullptr);
+ tmp = now - (24 * 60 * 60 - 1);
+ gmtime_r(&tmp, &tm);
+ strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm);
+ tmp = now + (30 * 24 * 60 * 60);
+ gmtime_r(&tmp, &tm);
+ strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm);
+ ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after);
+ if (ret) {
+ logerr(logger, ret, "Unable to set validity");
+ goto error;
+ }
+ if (issuer_cert) {
+ ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1);
+ } else {
+ ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1);
+ }
+ if (ret) {
+ logerr(logger, ret, "Unable to set basic constraints");
+ goto error;
+ }
+
+ ret = mbedtls_x509write_crt_pem(
+ &crt, buffer, sizeof(buffer),
+ mbedtls_ctr_drbg_random,
+ static_cast<SSLEntropyImpl*>(entropy)->random());
+ if (ret) {
+ logerr(logger, ret, "Error writing cert");
+ goto error;
+ }
+ cert->assign(reinterpret_cast<char*>(buffer));
+
+ ok = true;
+ error:
+ mbedtls_x509write_crt_free(&crt);
+ return ok;
+}
+
+// static
+SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert,
+ SSLKey* key, uint16_t flags) {
+ return new SSLServerImpl(logger, entropy, cert, key, flags);
+}
+
+// static
+SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store,
+ std::string const& host, uint16_t flags) {
+ return new SSLClientImpl(logger, entropy, store, host, flags);
+}
+
+// static
+const uint16_t SSL::UNSECURE = 0x01;
+