// -*- mode: c++; c-basic-offset: 2; -*- #include "common.hh" #include #include #include #include #include #include #include #include #include #include #include #include "buffer.hh" #include "logger.hh" #include "ssl.hh" namespace { void logout(Logger* logger, Logger::Level lvl, int errnum, const std::string& msg) { char tmp[256]; mbedtls_strerror(errnum, tmp, sizeof(tmp)); logger->out(lvl, "%s (%d): %s", msg.c_str(), errnum, tmp); } inline void loginfo(Logger* logger, int errnum, const std::string& msg) { logout(logger, Logger::INFO, errnum, msg); } inline void logerr(Logger* logger, int errnum, const std::string& msg) { logout(logger, Logger::ERR, errnum, msg); } class SSLEntropyImpl : public SSLEntropy { public: SSLEntropyImpl() { mbedtls_entropy_init(&entropy_); mbedtls_ctr_drbg_init(&ctr_drbg_); } ~SSLEntropyImpl() override { mbedtls_ctr_drbg_free(&ctr_drbg_); mbedtls_entropy_free(&entropy_); } bool init(Logger* logger) { auto ret = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, nullptr, 0); if (ret) { logerr(logger, ret, "Error seeding entropy"); return false; } return true; } void setup(mbedtls_ssl_config* conf) { mbedtls_ssl_conf_rng(conf, mbedtls_ctr_drbg_random, &ctr_drbg_); } mbedtls_ctr_drbg_context* random() { return &ctr_drbg_; } private: mbedtls_entropy_context entropy_; mbedtls_ctr_drbg_context ctr_drbg_; }; class SSLCertStoreImpl : public SSLCertStore { public: SSLCertStoreImpl() { mbedtls_x509_crt_init(&chain_); } ~SSLCertStoreImpl() override { mbedtls_x509_crt_free(&chain_); } bool init(Logger* logger, std::string const& path) { auto ret = mbedtls_x509_crt_parse_file(&chain_, path.c_str()); if (ret) { logerr(logger, ret, "Error loading cert store bundle"); return false; } return true; } mbedtls_x509_crt* chain() { return &chain_; } private: mbedtls_x509_crt chain_; }; class SSLKeyImpl : public SSLKey { public: SSLKeyImpl() { mbedtls_pk_init(&key_); } ~SSLKeyImpl() override { mbedtls_pk_free(&key_); } bool load(Logger* logger, std::string const& data, SSLEntropy* entropy) { auto ret = mbedtls_pk_parse_key( &key_, reinterpret_cast(data.c_str()), data.size() + 1, nullptr, 0, mbedtls_ctr_drbg_random, static_cast(entropy)->random()); if (ret) { logerr(logger, ret, "Error parsing key"); return false; } return true; } mbedtls_pk_context* key() { return &key_; } private: mbedtls_pk_context key_; }; class SSLCertImpl : public SSLCert { public: SSLCertImpl() { mbedtls_x509_crt_init(&cert_); } ~SSLCertImpl() override { mbedtls_x509_crt_free(&cert_); } bool load(Logger* logger, std::string const& data) { auto ret = mbedtls_x509_crt_parse( &cert_, reinterpret_cast(data.c_str()), data.size() + 1); if (ret < 0) { logerr(logger, ret, "Error parsing cert"); return false; } return true; } mbedtls_x509_crt* cert() { return &cert_; } private: mbedtls_x509_crt cert_; }; template class SSLImpl : public SSL { public: SSLImpl(Logger* logger, uint16_t flags) : logger_(logger), flags_(flags), in_(nullptr), out_(nullptr), result_(NO_ERR), state_(HANDSHAKE) { } ~SSLImpl() override { mbedtls_ssl_free(&ssl_); mbedtls_ssl_config_free(&conf_); } bool unsecure() const { return flags_ & UNSECURE; } TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, Buffer* data_in, Buffer* data_out) override { if (result_) return result_; in_ = ssl_in; out_ = ssl_out; result_ = transfer(data_in, data_out); in_ = nullptr; out_ = nullptr; return result_; } void close() override { state_ = CLOSE; } protected: void init(SSLEntropy* entropy, int endpoint, SetupData const& data) { mbedtls_ssl_init(&ssl_); mbedtls_ssl_config_init(&conf_); auto ret = mbedtls_ssl_config_defaults(&conf_, endpoint, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret) { logerr(logger_, ret, "Error configuring SSL"); result_ = ERR; return; } static_cast(entropy)->setup(&conf_); if (!before_setup(data)) { result_ = ERR; return; } ret = mbedtls_ssl_setup(&ssl_, &conf_); if (ret) { logerr(logger_, ret, "Failed to setup SSL"); result_ = ERR; return; } mbedtls_ssl_set_bio(&ssl_, this, &ssl_send, &ssl_recv, nullptr); if (!after_setup(data)) { result_ = ERR; return; } result_ = NO_ERR; } virtual bool before_setup(SetupData const&) { return true; } virtual bool after_setup(SetupData const&) { return true; } Logger* const logger_; mbedtls_ssl_context ssl_; mbedtls_ssl_config conf_; private: enum State { HANDSHAKE, TRANSFER, CLOSE }; TransferResult transfer(Buffer* in, Buffer* out) { bool want_read = false, want_write = false; while (!want_read || !want_write) { switch (state_) { case HANDSHAKE: { auto ret = mbedtls_ssl_handshake(&ssl_); if (ret) { if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { return NO_ERR; } loginfo(logger_, ret, "Handshake failed"); return ERR; } state_ = TRANSFER; break; } case CLOSE: { auto ret = mbedtls_ssl_close_notify(&ssl_); if (ret) { if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { return NO_ERR; } loginfo(logger_, ret, "Close notify failed"); return ERR; } return CLOSED; } case TRANSFER: { size_t avail; auto wptr = out->write_ptr(&avail); if (avail > 0) { auto ret = mbedtls_ssl_read(&ssl_, reinterpret_cast(wptr), avail); if (ret > 0) { out->commit(ret); } else if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { want_read = true; } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { return CLOSED; } else { loginfo(logger_, ret, "SSL read failed"); return ERR; } } else { assert(false); want_read = true; } auto rptr = in->read_ptr(&avail); if (avail > 0) { auto ret = mbedtls_ssl_write( &ssl_, reinterpret_cast(rptr), avail); if (ret > 0) { in->consume(ret); } else if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { want_write = true; } else { loginfo(logger_, ret, "SSL write failed"); return ERR; } } else { want_write = true; } break; } } } return NO_ERR; } int send(unsigned char const* buf, size_t len) { if (!out_) { assert(false); return MBEDTLS_ERR_SSL_WANT_WRITE; } out_->write(buf, len); return len; } int recv(unsigned char* buf, size_t len) { if (!in_) { assert(false); return MBEDTLS_ERR_SSL_WANT_READ; } auto ret = in_->read(buf, len); if (ret == 0) { return MBEDTLS_ERR_SSL_WANT_READ; } return ret; } static int ssl_send(void* ctx, unsigned char const* buf, size_t len) { return reinterpret_cast(ctx)->send(buf, len); } static int ssl_recv(void* ctx, unsigned char* buf, size_t len) { return reinterpret_cast(ctx)->recv(buf, len); } uint16_t const flags_; Buffer* in_; Buffer* out_; TransferResult result_; State state_; }; struct CertKey { SSLCert* cert; SSLKey* key; CertKey(SSLCert* cert, SSLKey* key) : cert(cert), key(key) { } }; class SSLServerImpl : public SSLImpl { public: SSLServerImpl(Logger* logger, SSLEntropy* entropy, SSLCert* cert, SSLKey* key, uint16_t flags) : SSLImpl(logger, flags) { SSLImpl::init(entropy, MBEDTLS_SSL_IS_SERVER, CertKey(cert, key)); } ~SSLServerImpl() override { } private: bool before_setup(CertKey const& data) override { mbedtls_ssl_conf_ca_chain( &conf_, static_cast(data.cert)->cert()->next, nullptr); auto ret = mbedtls_ssl_conf_own_cert( &conf_, static_cast(data.cert)->cert(), static_cast(data.key)->key()); if (ret) { logerr(logger_, ret, "Server: Error setting certificate"); return false; } mbedtls_ssl_conf_min_version(&conf_, MBEDTLS_SSL_MAJOR_VERSION_3, unsecure() ? MBEDTLS_SSL_MINOR_VERSION_3 : MBEDTLS_SSL_MINOR_VERSION_4); return true; } }; struct HostStore { std::string const& host; SSLCertStore* const store; HostStore(std::string const& host, SSLCertStore* store) : host(host), store(store) { } }; class SSLClientImpl : public SSLImpl { public: SSLClientImpl(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, std::string const& host, uint16_t flags) : SSLImpl(logger, flags) { SSLImpl::init(entropy, MBEDTLS_SSL_IS_CLIENT, HostStore(host, store)); } private: bool before_setup(HostStore const& data) override { mbedtls_ssl_conf_authmode(&conf_, unsecure() ? MBEDTLS_SSL_VERIFY_OPTIONAL : MBEDTLS_SSL_VERIFY_REQUIRED); mbedtls_ssl_conf_ca_chain( &conf_, static_cast(data.store)->chain(), nullptr); return true; } bool after_setup(HostStore const& data) override { auto ret = mbedtls_ssl_set_hostname(&ssl_, data.host.c_str()); if (ret) { logerr(logger_, ret, "Failed to set hostname"); return false; } return true; } }; int mbedtls_x509write_crt_set_subject_alt_name( mbedtls_x509write_cert* ctx, const char* name) { unsigned char buf[256]; unsigned char *c = buf + sizeof(buf); int ret; size_t len = 0; size_t namelen; if (name == NULL) return MBEDTLS_ERR_X509_BAD_INPUT_DATA; namelen = strlen(name); MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_raw_buffer(&c, buf, reinterpret_cast(name), namelen)); MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_len(&c, buf, namelen)); MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_tag(&c, buf, MBEDTLS_ASN1_CONTEXT_SPECIFIC | 2)); MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_len(&c, buf, len)); MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_tag(&c, buf, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE)); return mbedtls_x509write_crt_set_extension( ctx, MBEDTLS_OID_SUBJECT_ALT_NAME, MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME), 1, c, len); } } // namespace // static SSLEntropy* SSLEntropy::create(Logger* logger) { std::unique_ptr entropy(new SSLEntropyImpl()); if (!entropy->init(logger)) return nullptr; return entropy.release(); } // static SSLCertStore* SSLCertStore::create(Logger* logger, std::string const& path) { std::unique_ptr store(new SSLCertStoreImpl()); if (!store->init(logger, path)) return nullptr; return store.release(); } // static SSLKey* SSLKey::load(Logger* logger, std::string const& data, SSLEntropy* entropy) { std::unique_ptr key(new SSLKeyImpl()); if (!key->load(logger, data, entropy)) return nullptr; return key.release(); } // static bool SSLKey::generate(Logger* logger, SSLEntropy* entropy, std::string* key) { mbedtls_pk_context pk; bool ok = false; unsigned char buffer[16000]; mbedtls_pk_init(&pk); auto ret = mbedtls_pk_setup(&pk, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)); if (ret) { logerr(logger, ret, "Error setting up key"); goto error; } ret = mbedtls_rsa_gen_key(mbedtls_pk_rsa(pk), mbedtls_ctr_drbg_random, static_cast(entropy)->random(), 4096, 65537); if (ret) { logerr(logger, ret, "Error generating key"); goto error; } ret = mbedtls_pk_write_key_pem(&pk, buffer, sizeof(buffer)); if (ret) { logerr(logger, ret, "Error writing key"); goto error; } key->assign(reinterpret_cast(buffer)); ok = true; error: mbedtls_pk_free(&pk); return ok; } // static SSLCert* SSLCert::load(Logger* logger, std::string const& data) { std::unique_ptr key(new SSLCertImpl()); if (!key->load(logger, data)) return nullptr; return key.release(); } // static bool SSLCert::generate(Logger* logger, SSLEntropy* entropy, SSLCert* issuer_cert, SSLKey* issuer_key, std::string const& host, SSLKey* key, std::string* cert) { mbedtls_x509write_cert crt; mbedtls_mpi serial; char issuer_name[256]; std::string subject; bool ok = false; time_t now, tmp; struct tm tm; char not_before[20]; char not_after[20]; unsigned char buffer[16000]; mbedtls_x509write_crt_init(&crt); mbedtls_mpi_init(&serial); mbedtls_x509write_crt_set_md_alg(&crt, MBEDTLS_MD_SHA256); if (key) { mbedtls_x509write_crt_set_subject_key( &crt, static_cast(key)->key()); } if (issuer_key) { mbedtls_x509write_crt_set_issuer_key( &crt, static_cast(issuer_key)->key()); } else if (key) { // Without an issuer_key mbedtls_x509write_crt_pem always fails because // it uses the type of the issuer_key to figure out signature algo mbedtls_x509write_crt_set_issuer_key( &crt, static_cast(key)->key()); } subject = "CN=" + host; auto ret = mbedtls_x509write_crt_set_subject_name(&crt, subject.c_str()); if (ret) { logerr(logger, ret, "Invalid subject name"); goto error; } if (issuer_cert) { ret = mbedtls_x509_dn_gets( issuer_name, sizeof(issuer_name), &static_cast(issuer_cert)->cert()->subject); if (ret < 0) { logerr(logger, ret, "Unable to get issuer name"); goto error; } ret = mbedtls_x509write_crt_set_issuer_name(&crt, issuer_name); } else { ret = mbedtls_x509write_crt_set_issuer_name(&crt, subject.c_str()); } if (ret) { logerr(logger, ret, "Invalid issuer name"); goto error; } now = time(nullptr); tmp = now - (24 * 60 * 60 - 1); gmtime_r(&tmp, &tm); strftime(not_before, sizeof(not_before), "%Y%m%d000000", &tm); tmp = now + (30 * 24 * 60 * 60); gmtime_r(&tmp, &tm); strftime(not_after, sizeof(not_after), "%Y%m%d000000", &tm); ret = mbedtls_x509write_crt_set_validity(&crt, not_before, not_after); if (ret) { logerr(logger, ret, "Unable to set validity"); goto error; } if (issuer_cert) { ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 0, -1); } else { ret = mbedtls_x509write_crt_set_basic_constraints(&crt, 1, 1); } if (ret) { logerr(logger, ret, "Unable to set basic constraints"); goto error; } if (issuer_cert) { if (mbedtls_x509write_crt_set_subject_alt_name(&crt, host.c_str())) { logerr(logger, ret, "Unable to set subject alt name"); goto error; } } if (mbedtls_mpi_fill_random( &serial, 32, mbedtls_ctr_drbg_random, static_cast(entropy)->random())) { logerr(logger, ret, "Unable generate serial"); goto error; } if (mbedtls_x509write_crt_set_serial(&crt, &serial)) { logerr(logger, ret, "Unable to set serial"); goto error; } ret = mbedtls_x509write_crt_pem( &crt, buffer, sizeof(buffer), mbedtls_ctr_drbg_random, static_cast(entropy)->random()); if (ret) { logerr(logger, ret, "Error writing cert"); goto error; } cert->assign(reinterpret_cast(buffer)); ok = true; error: mbedtls_mpi_free(&serial); mbedtls_x509write_crt_free(&crt); return ok; } // static SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert, SSLKey* key, uint16_t flags) { return new SSLServerImpl(logger, entropy, cert, key, flags); } // static SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, std::string const& host, uint16_t flags) { return new SSLClientImpl(logger, entropy, store, host, flags); } // static const uint16_t SSL::UNSECURE = 0x01;