diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Makefile.am | 3 | ||||
| -rw-r--r-- | src/ssl_openssl.cc | 682 |
2 files changed, 685 insertions, 0 deletions
diff --git a/src/Makefile.am b/src/Makefile.am index 7b1034c..c050770 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -31,6 +31,9 @@ libmitm_a_SOURCES = mitm.cc if HAVE_MBEDTLS libmitm_a_SOURCES += ssl_mbedtls.cc endif +if HAVE_OPENSSL +libmitm_a_SOURCES += ssl_openssl.cc +endif libmitm_a_CXXFLAGS = $(AM_CXXFLAGS) @SSL_CFLAGS@ tp_genca_SOURCES = genca.cc logger.cc diff --git a/src/ssl_openssl.cc b/src/ssl_openssl.cc new file mode 100644 index 0000000..0c3eed0 --- /dev/null +++ b/src/ssl_openssl.cc @@ -0,0 +1,682 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#define SSL oSSL + +#include <openssl/err.h> +#include <openssl/pem.h> +#include <openssl/ssl.h> +#include <openssl/rsa.h> +#include <openssl/x509.h> +#include <openssl/x509v3.h> + +#undef SSL + +#include <memory> + +#include "buffer.hh" +#include "logger.hh" +#include "ssl.hh" + +namespace { + +void check_init() { +#if OPENSSL_VERSION_NUMBER < 0x10100000 + static bool initialized; + if (initialized) return; + initialized = true; + SSL_load_error_strings(); + SSL_library_init(); +#endif +} + +class SSLEntropyImpl : public SSLEntropy { +public: + SSLEntropyImpl() { + } +}; + +class SSLCertStoreImpl : public SSLCertStore { +public: + SSLCertStoreImpl(std::string const& filename) + : file_(filename) { + } + ~SSLCertStoreImpl() override { + } + + std::string const& file() const { + return file_; + } +private: + std::string file_; +}; + +class SSLKeyImpl : public SSLKey { +public: + SSLKeyImpl(EVP_PKEY* key) + : key_(key) { + } + ~SSLKeyImpl() override { + EVP_PKEY_free(key_); + } + + EVP_PKEY* key() const { + return key_; + } + +private: + EVP_PKEY* const key_; +}; + +class SSLCertImpl : public SSLCert { +public: + SSLCertImpl(X509* x509) + : x509_(x509) { + } + ~SSLCertImpl() override { + X509_free(x509_); + } + + X509* x509() const { + return x509_; + } + +private: + X509* const x509_; +}; + +void logerr(Logger* logger, const char* message) { + bool any = false; + while (true) { + auto err = ERR_get_error(); + if (!err) { + if (!any) logger->out(Logger::ERR, "%s: Unknown error", message); + break; + } + any = true; + char buffer[512]; + ERR_error_string_n(err, buffer, sizeof(buffer)); + logger->out(Logger::ERR, "%s: %s", message, buffer); + } +} + +class SSLImpl : public SSL { +public: + SSLImpl(Logger* logger, uint16_t flags) + : logger_(logger), flags_(flags), + bio_method({ + (99 | BIO_TYPE_SOURCE_SINK), + "SSLImpl", + bio_write, + bio_read, + bio_puts, + bio_gets, + bio_ctrl, + bio_create, + bio_destroy, + nullptr + }), + ctx_(nullptr), ssl_(nullptr), + bio_(nullptr), rbuf_(nullptr), wbuf_(nullptr) { + } + + ~SSLImpl() override { + if (ssl_) SSL_free(ssl_); + if (ctx_) SSL_CTX_free(ctx_); + if (bio_) BIO_free(bio_); + } + + bool unsecure() const { + return flags_ & UNSECURE; + } + + TransferResult transfer(Buffer* ssl_in, Buffer* ssl_out, + Buffer* data_in, Buffer* data_out) override { + bool want_read = false, want_write = false; + rbuf_ = ssl_in; + wbuf_ = ssl_out; + while (!want_read || !want_write) { + switch (state_) { + case HANDSHAKE: { + auto ret = handshake(); + if (ret == 1) { + state_ = TRANSFER; + continue; + } + ret = SSL_get_error(ssl_, ret); + if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { + return SSL::NO_ERR; + } + logsslerr("Handshake", ret); + return SSL::ERR; + } + case CLOSED: { + auto ret = SSL_shutdown(ssl_); + if (ret == 1) { + return SSL::CLOSED; + } else if (ret == 0) { + continue; + } + ret = SSL_get_error(ssl_, ret); + if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { + return SSL::NO_ERR; + } + logsslerr("Close", ret); + return SSL::ERR; + } + case ERROR: + return SSL::ERR; + case TRANSFER: { + size_t avail; + auto wptr = data_out->write_ptr(&avail); + if (avail > 0) { + auto ret = SSL_read(ssl_, wptr, avail); + if (ret > 0) { + data_out->commit(ret); + } else { + ret = SSL_get_error(ssl_, ret); + if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { + want_read = true; + } else if (ret == SSL_ERROR_ZERO_RETURN) { + return SSL::CLOSED; + } else { + logsslerr("SSL_read", ret); + return SSL::ERR; + } + } + } else { + assert(false); + want_read = true; + } + auto rptr = data_in->read_ptr(&avail); + if (avail > 0) { + auto ret = SSL_write(ssl_, rptr, avail); + if (ret > 0) { + data_in->consume(ret); + } else { + ret = SSL_get_error(ssl_, ret); + if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE) { + want_write = true; + } else if (ret == SSL_ERROR_ZERO_RETURN) { + return SSL::CLOSED; + } else { + logsslerr("SSL_write", ret); + return SSL::ERR; + } + } + } else { + want_write = true; + } + break; + } + } + } + rbuf_ = nullptr; + wbuf_ = nullptr; + return NO_ERR; + } + + void close() override { + switch (state_) { + case CLOSED: + case ERROR: + return; + default: + state_ = CLOSED; + } + } + +protected: + Logger* const logger_; + uint16_t const flags_; + + void logsslerr(std::string const& message, int err) { + state_ = ERROR; + switch (err) { + case SSL_ERROR_NONE: + logger_->out(Logger::ERR, "%s: No error (?)", message.c_str()); + return; + case SSL_ERROR_ZERO_RETURN: + logger_->out(Logger::ERR, "%s: Connection closed", message.c_str()); + return; + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_CONNECT: + case SSL_ERROR_WANT_ACCEPT: + assert(false); + logger_->out(Logger::ERR, "%s: Non-blocking error", message.c_str()); + return; + case SSL_ERROR_WANT_X509_LOOKUP: + assert(false); + logger_->out(Logger::ERR, "%s: Unknown error", message.c_str()); + return; + case SSL_ERROR_SYSCALL: + case SSL_ERROR_SSL: + default: + logerr(logger_, message.c_str()); + return; + } + } + + bool setup(SSL_CTX* ctx) { + if (!ctx) return false; + ctx_ = ctx; + SSL_CTX_set_mode(ctx_, SSL_MODE_ENABLE_PARTIAL_WRITE); + ssl_ = SSL_new(ctx_); + if (!ssl_) { + logerr(logger_, "Unable to create SSL"); + return false; + } + bio_ = BIO_new(&bio_method); + bio_->ptr = this; + SSL_set_bio(ssl_, bio_, bio_); + return true; + } + + oSSL* ssl() { + return ssl_; + } + + virtual int handshake() = 0; + +private: + enum State { + HANDSHAKE, + CLOSED, + ERROR, + TRANSFER + }; + BIO_METHOD bio_method; + + static int bio_write(BIO* bio, const char* buf, int len) { + BIO_clear_retry_flags(bio); + if (len <= 0) return 0; + auto impl = reinterpret_cast<SSLImpl*>(bio->ptr); + if (impl->wbuf_) { + impl->wbuf_->write(buf, len); + return len; + } + BIO_set_retry_write(bio); + return -1; + } + static int bio_read(BIO* bio, char* buf, int len) { + BIO_clear_retry_flags(bio); + if (len <= 0) return 0; + auto impl = reinterpret_cast<SSLImpl*>(bio->ptr); + if (impl->rbuf_) { + auto ret = impl->rbuf_->read(buf, len); + if (ret > 0) return ret; + } + BIO_set_retry_read(bio); + return -1; + } + static int bio_puts(BIO* bio, const char* str) { + return bio_write(bio, str, strlen(str)); + } + static int bio_gets(BIO* UNUSED(bio), char* UNUSED(str), int UNUSED(size)) { + return -2; + } + static long bio_ctrl(BIO* UNUSED(bio), int cmd, + long UNUSED(num), void* UNUSED(ptr)) { + if (cmd == BIO_CTRL_FLUSH) { + return 1; + } + return 0; + } + static int bio_create(BIO* bio) { + bio->shutdown = 0; + bio->init = 1; + bio->ptr = nullptr; + return 1; + } + static int bio_destroy(BIO* bio) { + if (!bio) return 0; + bio->init = 0; + bio->ptr = nullptr; + return 1; + } + SSL_CTX* ctx_; + oSSL* ssl_; + State state_; + BIO* bio_; + Buffer* rbuf_; + Buffer* wbuf_; +}; + +class SSLServerImpl : public SSLImpl { +public: + SSLServerImpl(Logger* logger, uint16_t flags) + : SSLImpl(logger, flags) { + } + + bool setup(SSLCert* cert, SSLKey* key) { + auto ctx = SSL_CTX_new(SSLv23_server_method()); + if (!ctx) { + logerr(logger_, "Unable to create server context"); + return false; + } + if (!unsecure()) { + SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + } + if (cert) { + if (SSL_CTX_use_certificate( + ctx, static_cast<SSLCertImpl*>(cert)->x509()) != 1) { + logerr(logger_, "Unable to set certificate"); + return false; + } + } + if (key) { + if (SSL_CTX_use_PrivateKey( + ctx, static_cast<SSLKeyImpl*>(key)->key()) != 1) { + logerr(logger_, "Unable to set private key"); + return false; + } + } + return SSLImpl::setup(ctx); + } + + int handshake() { + return SSL_accept(ssl()); + } +}; + +class SSLClientImpl : public SSLImpl { +public: + SSLClientImpl(Logger* logger, uint16_t flags) + : SSLImpl(logger, flags) { + } + + bool setup(SSLCertStore* store, std::string const& host) { + auto ctx = SSL_CTX_new(SSLv23_client_method()); + if (!ctx) { + logerr(logger_, "Unable to create client context"); + return false; + } + if (!unsecure()) { + SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + } + if (SSL_CTX_load_verify_locations( + ctx, static_cast<SSLCertStoreImpl*>(store)->file().c_str(), + nullptr) != 1) { + logerr(logger_, "Unable to load certificate store"); + return false; + } + if (!SSLImpl::setup(ctx)) return false; + + // Setup peer verification + auto param = SSL_get0_param(ssl()); + X509_VERIFY_PARAM_set_hostflags( + param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); + X509_VERIFY_PARAM_set1_host(param, host.data(), host.size()); + SSL_set_verify(ssl(), unsecure() ? SSL_VERIFY_NONE : SSL_VERIFY_PEER, + nullptr); + return true; + } + + int handshake() { + return SSL_connect(ssl()); + } + +protected: + std::string host_; +}; + +} // namespace + +// static +SSLEntropy* SSLEntropy::create(Logger* UNUSED(logger)) { + return new SSLEntropyImpl(); +} + +// static +SSLCertStore* SSLCertStore::create(Logger* UNUSED(logger), + std::string const& bundle) { + // TODO(the_jk): Read certificates here and not later when store is used + return new SSLCertStoreImpl(bundle); +} + +// static +bool SSLKey::generate(Logger* logger, SSLEntropy* UNUSED(entropy), + std::string* key) { + check_init(); + + RSA* rsa = RSA_new(); + EVP_PKEY* pk = EVP_PKEY_new(); + BIGNUM* exp = BN_new(); + BIO* bio = BIO_new(BIO_s_mem()); + bool ok = false; + int len; + + if (BN_set_word(exp, 65537) != 1) { + logerr(logger, "Unable to set exponent"); + goto error; + } + if (RSA_generate_key_ex(rsa, 4096, exp, nullptr) != 1) { + logerr(logger, "Unable to generate key"); + goto error; + } + if (!EVP_PKEY_assign_RSA(pk, rsa)) { + logerr(logger, "Unable to copy key"); + goto error; + } + rsa = nullptr; + + if (PEM_write_bio_PKCS8PrivateKey(bio, pk, nullptr, nullptr, 0, + nullptr, nullptr) != 1) { + logerr(logger, "Error writing key"); + goto error; + } + + len = BIO_pending(bio); + key->resize(len); + BIO_read(bio, &(*key)[0], len); + + ok = true; + + error: + BIO_free(bio); + EVP_PKEY_free(pk); + if (rsa) RSA_free(rsa); + BN_free(exp); + return ok; +} + +// static +SSLKey* SSLKey::load(Logger* logger, std::string const& data) { + check_init(); + + EVP_PKEY* key = EVP_PKEY_new(); + BIO* bio = BIO_new_mem_buf(data.data(), data.size()); + SSLKey* ret = nullptr; + + if (!PEM_read_bio_PrivateKey(bio, &key, nullptr, nullptr)) { + logerr(logger, "Error reading key"); + goto error; + } + + ret = new SSLKeyImpl(key); + key = nullptr; + + error: + BIO_free(bio); + if (key) EVP_PKEY_free(key); + return ret; +} + +// static +bool SSLCert::generate(Logger* logger, SSLEntropy* UNUSED(entropy), + SSLCert* issuer_cert, SSLKey* issuer_key, + std::string const& host, SSLKey* key, + std::string* cert) { + check_init(); + + X509* x509 = X509_new(); + BIO* bio = BIO_new(BIO_s_mem()); + BIGNUM* bn = BN_new(); + ASN1_INTEGER* serial = ASN1_INTEGER_new(); + X509_NAME* name; + X509_EXTENSION* ext; + X509V3_CTX ctx; + EVP_PKEY* sign_key = nullptr; + bool ok = false; + int len; + char const* constraints; + std::string tmp; + int ret; + + if (X509_set_version(x509, 2) != 1) { + logerr(logger, "Unable to set cert version"); + goto error; + } + if (BN_pseudo_rand(bn, 32, 0, 0) != 1) { + logerr(logger, "Unable to generate random serial"); + goto error; + } + if (!BN_to_ASN1_INTEGER(bn, serial)) { + logerr(logger, "Unable to convert serial"); + goto error; + } + if (X509_set_serialNumber(x509, serial) != 1) { + logerr(logger, "Unable to set serial"); + goto error; + } + if (!X509_gmtime_adj(X509_get_notBefore(x509), - (24 * 60 * 60 - 1))) { + logerr(logger, "Unable to not before time"); + goto error; + } + if (!X509_gmtime_adj(X509_get_notAfter(x509), 30 * 24 * 60 * 60)) { + logerr(logger, "Unable to not after time"); + goto error; + } + + if (key) { + if (X509_set_pubkey(x509, static_cast<SSLKeyImpl*>(key)->key()) != 1) { + logerr(logger, "Unable to set public key"); + goto error; + } + } + + name = X509_get_subject_name(x509); + if (X509_NAME_add_entry_by_txt( + name, "CN", MBSTRING_ASC, + reinterpret_cast<unsigned char const*>(host.data()), host.size(), + -1, 0) != 1) { + logerr(logger, "Unable to set common name"); + goto error; + } + X509V3_set_ctx_nodb(&ctx); + if (issuer_cert) { + auto issuer = static_cast<SSLCertImpl*>(issuer_cert)->x509(); + X509_set_issuer_name(x509, X509_get_subject_name(issuer)); + X509V3_set_ctx(&ctx, issuer, x509, nullptr, nullptr, 0); + } else { + X509_set_issuer_name(x509, name); + X509V3_set_ctx(&ctx, x509, x509, nullptr, nullptr, 0); + } + + if (issuer_cert) { + constraints = "CA:FALSE"; + } else { + constraints = "CA:TRUE,pathlen:1"; + } + ext = X509V3_EXT_conf_nid(nullptr, &ctx, NID_basic_constraints, + const_cast<char*>(constraints)); + if (!ext) { + logerr(logger, "Unable to create basic constraints extension"); + goto error; + } + ret = X509_add_ext(x509, ext, -1); + X509_EXTENSION_free(ext); + if (ret != 1) { + logerr(logger, "Unable to add basic constraints extension"); + goto error; + } + + if (issuer_cert) { + tmp = "DNS:" + host; + ext = X509V3_EXT_conf_nid(nullptr, &ctx, NID_subject_alt_name, + const_cast<char*>(tmp.c_str())); + if (!ext) { + logerr(logger, "Unable to create subject alt name extension"); + goto error; + } + ret = X509_add_ext(x509, ext, -1); + X509_EXTENSION_free(ext); + if (ret != 1) { + logerr(logger, "Unable to add subject alt name extension"); + goto error; + } + } + + if (issuer_key) { + sign_key = static_cast<SSLKeyImpl*>(issuer_key)->key(); + } else if (key) { + sign_key = static_cast<SSLKeyImpl*>(key)->key(); + } + + if (!X509_sign(x509, sign_key, EVP_sha256())) { + logerr(logger, "Error signing cert"); + goto error; + } + + if (PEM_write_bio_X509(bio, x509) != 1) { + logerr(logger, "Error writing cert"); + goto error; + } + + len = BIO_pending(bio); + cert->resize(len); + BIO_read(bio, &(*cert)[0], len); + + ok = true; + + error: + BN_free(bn); + ASN1_INTEGER_free(serial); + BIO_free(bio); + X509_free(x509); + return ok; +} + +// static +SSLCert* SSLCert::load(Logger* logger, std::string const& data) { + check_init(); + + X509* x509 = X509_new(); + BIO* bio = BIO_new_mem_buf(data.data(), data.size()); + SSLCert* ret = nullptr; + + if (!PEM_read_bio_X509(bio, &x509, nullptr, nullptr)) { + logerr(logger, "Error reading cert"); + goto error; + } + + ret = new SSLCertImpl(x509); + x509 = nullptr; + + error: + BIO_free(bio); + if (x509) X509_free(x509); + return ret; +} + +// static +const uint16_t SSL::UNSECURE = 0x01; + +// static +SSL* SSL::server(Logger* logger, SSLEntropy* entropy, + SSLCert* cert, SSLKey* key, + uint16_t flags) { + std::unique_ptr<SSLServerImpl> ret(new SSLServerImpl(logger, flags)); + if (!ret->setup(cert, key)) return nullptr; + return ret.release(); +} + +// static +SSL* SSL::client(Logger* logger, SSLEntropy* entropy, SSLCertStore* store, + std::string const& host, uint16_t flags) { + std::unique_ptr<SSLClientImpl> ret(new SSLClientImpl(logger, flags)); + if (!ret->setup(store, host)) return nullptr; + return ret.release(); +} |
