// -*- mode: c++; c-basic-offset: 2; -*- #include "common.hh" #define SSL oSSL #include #include #include #include #include #include #undef SSL #include #include "buffer.hh" #include "logger.hh" #include "ssl.hh" namespace { void check_init() { #if OPENSSL_VERSION_NUMBER < 0x10100000 static bool initialized; if (initialized) return; initialized = true; SSL_load_error_strings(); SSL_library_init(); #endif } class SSLEntropyImpl : public SSLEntropy { public: SSLEntropyImpl() { } }; class SSLCertStoreImpl : public SSLCertStore { public: SSLCertStoreImpl(std::string const& filename) : file_(filename) { } ~SSLCertStoreImpl() override { } std::string const& file() const { return file_; } private: std::string file_; }; class SSLKeyImpl : public SSLKey { public: SSLKeyImpl(EVP_PKEY* key) : key_(key) { } ~SSLKeyImpl() override { EVP_PKEY_free(key_); } EVP_PKEY* key() const { return key_; } private: EVP_PKEY* const key_; }; class SSLCertImpl : public SSLCert { public: SSLCertImpl(X509* x509) : x509_(x509) { } ~SSLCertImpl() override { X509_free(x509_); } X509* x509() const { return x509_; } private: X509* const x509_; }; void logerr(Logger* logger, const char* message) { bool any = false; while (true) { auto err = ERR_get_error(); if (!err) { if (!any) logger->out(Logger::ERR, "%s: Unknown error", message); break; } any = true; char buffer[512]; ERR_error_string_n(err, buffer, sizeof(buffer)); logger->out(Logger::ERR, "%s: %s", message, buffer); } } class SSLImpl : public SSL { public: SSLImpl(Logger* logger, uint16_t flags) : logger_(logger), flags_(flags), bio_method({ (99 | BIO_TYPE_SOURCE_SINK), "SSLImpl", bio_write, bio_read, bio_puts, bio_gets, bio_ctrl, bio_create, bio_destroy, nullptr }), ctx_(nullptr), ssl_(nullptr), bio_(nullptr), rbuf_(nullptr), wbuf_(nullptr) { } ~SSLImpl() override { if (ssl_) SSL_free(ssl_); if (ctx_) SSL_CTX_free(ctx_); if (bio_) BIO_free(bio_); } bool unsecure() const { return flags_ & UNSECURE; } TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, Buffer* data_in, Buffer* data_out) override { bool want_read = false, want_write = false; rbuf_ = ssl_in; wbuf_ = ssl_out; while (!want_read || !want_write) { switch (state_) { case HANDSHAKE: { auto ret = handshake(); if (ret == 1) { state_ = TRANSFER; continue; } ret = SSL_get_error(ssl_, ret); if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { return SSL::NO_ERR; } logsslerr("Handshake", ret); return SSL::ERR; } case CLOSED: { auto ret = SSL_shutdown(ssl_); if (ret == 1) { return SSL::CLOSED; } else if (ret == 0) { continue; } ret = SSL_get_error(ssl_, ret); if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { return SSL::NO_ERR; } logsslerr("Close", ret); return SSL::ERR; } case ERROR: return SSL::ERR; case TRANSFER: { size_t avail; auto wptr = data_out->write_ptr(&avail); if (avail > 0) { auto ret = SSL_read(ssl_, wptr, avail); if (ret > 0) { data_out->commit(ret); } else { ret = SSL_get_error(ssl_, ret); if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { want_read = true; } else if (ret == SSL_ERROR_ZERO_RETURN) { return SSL::CLOSED; } else { logsslerr("SSL_read", ret); return SSL::ERR; } } } else { assert(false); want_read = true; } auto rptr = data_in->read_ptr(&avail); if (avail > 0) { auto ret = SSL_write(ssl_, rptr, avail); if (ret > 0) { data_in->consume(ret); } else { ret = SSL_get_error(ssl_, ret); if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { want_write = true; } else if (ret == SSL_ERROR_ZERO_RETURN) { return SSL::CLOSED; } else { logsslerr("SSL_write", ret); return SSL::ERR; } } } else { want_write = true; } break; } } } rbuf_ = nullptr; wbuf_ = nullptr; return NO_ERR; } void close() override { switch (state_) { case CLOSED: case ERROR: return; default: state_ = CLOSED; } } protected: Logger* const logger_; uint16_t const flags_; void logsslerr(std::string const& message, int err) { state_ = ERROR; switch (err) { case SSL_ERROR_NONE: logger_->out(Logger::ERR, "%s: No error (?)", message.c_str()); return; case SSL_ERROR_ZERO_RETURN: logger_->out(Logger::ERR, "%s: Connection closed", message.c_str()); return; case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_CONNECT: case SSL_ERROR_WANT_ACCEPT: assert(false); logger_->out(Logger::ERR, "%s: Non-blocking error", message.c_str()); return; case SSL_ERROR_WANT_X509_LOOKUP: assert(false); logger_->out(Logger::ERR, "%s: Unknown error", message.c_str()); return; case SSL_ERROR_SYSCALL: case SSL_ERROR_SSL: default: logerr(logger_, message.c_str()); return; } } bool setup(SSL_CTX* ctx) { if (!ctx) return false; ctx_ = ctx; SSL_CTX_set_mode(ctx_, SSL_MODE_ENABLE_PARTIAL_WRITE); ssl_ = SSL_new(ctx_); if (!ssl_) { logerr(logger_, "Unable to create SSL"); return false; } bio_ = BIO_new(&bio_method); bio_->ptr = this; SSL_set_bio(ssl_, bio_, bio_); return true; } oSSL* ssl() { return ssl_; } virtual int handshake() = 0; private: enum State { HANDSHAKE, CLOSED, ERROR, TRANSFER }; BIO_METHOD bio_method; static int bio_write(BIO* bio, const char* buf, int len) { BIO_clear_retry_flags(bio); if (len <= 0) return 0; auto impl = reinterpret_cast(bio->ptr); if (impl->wbuf_) { impl->wbuf_->write(buf, len); return len; } BIO_set_retry_write(bio); return -1; } static int bio_read(BIO* bio, char* buf, int len) { BIO_clear_retry_flags(bio); if (len <= 0) return 0; auto impl = reinterpret_cast(bio->ptr); if (impl->rbuf_) { auto ret = impl->rbuf_->read(buf, len); if (ret > 0) return ret; } BIO_set_retry_read(bio); return -1; } static int bio_puts(BIO* bio, const char* str) { return bio_write(bio, str, strlen(str)); } static int bio_gets(BIO* UNUSED(bio), char* UNUSED(str), int UNUSED(size)) { return -2; } static long bio_ctrl(BIO* UNUSED(bio), int cmd, long UNUSED(num), void* UNUSED(ptr)) { if (cmd == BIO_CTRL_FLUSH) { return 1; } return 0; } static int bio_create(BIO* bio) { bio->shutdown = 0; bio->init = 1; bio->ptr = nullptr; return 1; } static int bio_destroy(BIO* bio) { if (!bio) return 0; bio->init = 0; bio->ptr = nullptr; return 1; } SSL_CTX* ctx_; oSSL* ssl_; State state_; BIO* bio_; Buffer* rbuf_; Buffer* wbuf_; }; class SSLServerImpl : public SSLImpl { public: SSLServerImpl(Logger* logger, uint16_t flags) : SSLImpl(logger, flags) { } bool setup(SSLCert* cert, SSLKey* key) { auto ctx = SSL_CTX_new(SSLv23_server_method()); if (!ctx) { logerr(logger_, "Unable to create server context"); return false; } if (!unsecure()) { SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); } if (cert) { if (SSL_CTX_use_certificate( ctx, static_cast(cert)->x509()) != 1) { logerr(logger_, "Unable to set certificate"); return false; } } if (key) { if (SSL_CTX_use_PrivateKey( ctx, static_cast(key)->key()) != 1) { logerr(logger_, "Unable to set private key"); return false; } } return SSLImpl::setup(ctx); } int handshake() { return SSL_accept(ssl()); } }; class SSLClientImpl : public SSLImpl { public: SSLClientImpl(Logger* logger, uint16_t flags) : SSLImpl(logger, flags) { } bool setup(SSLCertStore* store, std::string const& host) { auto ctx = SSL_CTX_new(SSLv23_client_method()); if (!ctx) { logerr(logger_, "Unable to create client context"); return false; } if (!unsecure()) { SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); } if (SSL_CTX_load_verify_locations( ctx, static_cast(store)->file().c_str(), nullptr) != 1) { logerr(logger_, "Unable to load certificate store"); return false; } if (!SSLImpl::setup(ctx)) return false; // Setup peer verification auto param = SSL_get0_param(ssl()); X509_VERIFY_PARAM_set_hostflags( param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); X509_VERIFY_PARAM_set1_host(param, host.data(), host.size()); SSL_set_verify(ssl(), unsecure() ? SSL_VERIFY_NONE : SSL_VERIFY_PEER, nullptr); return true; } int handshake() { return SSL_connect(ssl()); } protected: std::string host_; }; } // namespace // static SSLEntropy* SSLEntropy::create(Logger* UNUSED(logger)) { return new SSLEntropyImpl(); } // static SSLCertStore* SSLCertStore::create(Logger* UNUSED(logger), std::string const& bundle) { // TODO(the_jk): Read certificates here and not later when store is used return new SSLCertStoreImpl(bundle); } // static bool SSLKey::generate(Logger* logger, SSLEntropy* UNUSED(entropy), std::string* key) { check_init(); RSA* rsa = RSA_new(); EVP_PKEY* pk = EVP_PKEY_new(); BIGNUM* exp = BN_new(); BIO* bio = BIO_new(BIO_s_mem()); bool ok = false; int len; if (BN_set_word(exp, 65537) != 1) { logerr(logger, "Unable to set exponent"); goto error; } if (RSA_generate_key_ex(rsa, 4096, exp, nullptr) != 1) { logerr(logger, "Unable to generate key"); goto error; } if (!EVP_PKEY_assign_RSA(pk, rsa)) { logerr(logger, "Unable to copy key"); goto error; } rsa = nullptr; if (PEM_write_bio_PKCS8PrivateKey(bio, pk, nullptr, nullptr, 0, nullptr, nullptr) != 1) { logerr(logger, "Error writing key"); goto error; } len = BIO_pending(bio); key->resize(len); BIO_read(bio, &(*key)[0], len); ok = true; error: BIO_free(bio); EVP_PKEY_free(pk); if (rsa) RSA_free(rsa); BN_free(exp); return ok; } // static SSLKey* SSLKey::load(Logger* logger, std::string const& data) { check_init(); EVP_PKEY* key = EVP_PKEY_new(); BIO* bio = BIO_new_mem_buf(data.data(), data.size()); SSLKey* ret = nullptr; if (!PEM_read_bio_PrivateKey(bio, &key, nullptr, nullptr)) { logerr(logger, "Error reading key"); goto error; } ret = new SSLKeyImpl(key); key = nullptr; error: BIO_free(bio); if (key) EVP_PKEY_free(key); return ret; } // static bool SSLCert::generate(Logger* logger, SSLEntropy* UNUSED(entropy), SSLCert* issuer_cert, SSLKey* issuer_key, std::string const& host, SSLKey* key, std::string* cert) { check_init(); X509* x509 = X509_new(); BIO* bio = BIO_new(BIO_s_mem()); BIGNUM* bn = BN_new(); ASN1_INTEGER* serial = ASN1_INTEGER_new(); X509_NAME* name; X509_EXTENSION* ext; X509V3_CTX ctx; EVP_PKEY* sign_key = nullptr; bool ok = false; int len; char const* constraints; std::string tmp; int ret; if (X509_set_version(x509, 2) != 1) { logerr(logger, "Unable to set cert version"); goto error; } if (BN_pseudo_rand(bn, 32, 0, 0) != 1) { logerr(logger, "Unable to generate random serial"); goto error; } if (!BN_to_ASN1_INTEGER(bn, serial)) { logerr(logger, "Unable to convert serial"); goto error; } if (X509_set_serialNumber(x509, serial) != 1) { logerr(logger, "Unable to set serial"); goto error; } if (!X509_gmtime_adj(X509_get_notBefore(x509), - (24 * 60 * 60 - 1))) { logerr(logger, "Unable to not before time"); goto error; } if (!X509_gmtime_adj(X509_get_notAfter(x509), 30 * 24 * 60 * 60)) { logerr(logger, "Unable to not after time"); goto error; } if (key) { if (X509_set_pubkey(x509, static_cast(key)->key()) != 1) { logerr(logger, "Unable to set public key"); goto error; } } name = X509_get_subject_name(x509); if (X509_NAME_add_entry_by_txt( name, "CN", MBSTRING_ASC, reinterpret_cast(host.data()), host.size(), -1, 0) != 1) { logerr(logger, "Unable to set common name"); goto error; } X509V3_set_ctx_nodb(&ctx); if (issuer_cert) { auto issuer = static_cast(issuer_cert)->x509(); X509_set_issuer_name(x509, X509_get_subject_name(issuer)); X509V3_set_ctx(&ctx, issuer, x509, nullptr, nullptr, 0); } else { X509_set_issuer_name(x509, name); X509V3_set_ctx(&ctx, x509, x509, nullptr, nullptr, 0); } if (issuer_cert) { constraints = "CA:FALSE"; } else { constraints = "CA:TRUE,pathlen:1"; } ext = X509V3_EXT_conf_nid(nullptr, &ctx, NID_basic_constraints, const_cast(constraints)); if (!ext) { logerr(logger, "Unable to create basic constraints extension"); goto error; } ret = X509_add_ext(x509, ext, -1); X509_EXTENSION_free(ext); if (ret != 1) { logerr(logger, "Unable to add basic constraints extension"); goto error; } if (issuer_cert) { tmp = "DNS:" + host; ext = X509V3_EXT_conf_nid(nullptr, &ctx, NID_subject_alt_name, const_cast(tmp.c_str())); if (!ext) { logerr(logger, "Unable to create subject alt name extension"); goto error; } ret = X509_add_ext(x509, ext, -1); X509_EXTENSION_free(ext); if (ret != 1) { logerr(logger, "Unable to add subject alt name extension"); goto error; } } if (issuer_key) { sign_key = static_cast(issuer_key)->key(); } else if (key) { sign_key = static_cast(key)->key(); } if (!X509_sign(x509, sign_key, EVP_sha256())) { logerr(logger, "Error signing cert"); goto error; } if (PEM_write_bio_X509(bio, x509) != 1) { logerr(logger, "Error writing cert"); goto error; } len = BIO_pending(bio); cert->resize(len); BIO_read(bio, &(*cert)[0], len); ok = true; error: BN_free(bn); ASN1_INTEGER_free(serial); BIO_free(bio); X509_free(x509); return ok; } // static SSLCert* SSLCert::load(Logger* logger, std::string const& data) { check_init(); X509* x509 = X509_new(); BIO* bio = BIO_new_mem_buf(data.data(), data.size()); SSLCert* ret = nullptr; if (!PEM_read_bio_X509(bio, &x509, nullptr, nullptr)) { logerr(logger, "Error reading cert"); goto error; } ret = new SSLCertImpl(x509); x509 = nullptr; error: BIO_free(bio); if (x509) X509_free(x509); return ret; } // static const uint16_t SSL::UNSECURE = 0x01; // static SSL* SSL::server(Logger* logger, SSLEntropy* entropy, SSLCert* cert, SSLKey* key, uint16_t flags) { std::unique_ptr ret(new SSLServerImpl(logger, flags)); if (!ret->setup(cert, key)) return nullptr; return ret.release(); } // static SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, std::string const& host, uint16_t flags) { std::unique_ptr ret(new SSLClientImpl(logger, flags)); if (!ret->setup(store, host)) return nullptr; return ret.release(); }