diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 32047c06d8..f2836b9092 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -861,6 +861,8 @@ rtc_static_library("rtc_base_generic") { "openssldigest.h", "opensslidentity.cc", "opensslidentity.h", + "opensslsessioncache.cc", + "opensslsessioncache.h", "opensslstreamadapter.cc", "opensslstreamadapter.h", "physicalsocketserver.cc", @@ -1359,6 +1361,7 @@ if (rtc_include_tests) { sources += [ "openssladapter_unittest.cc", "opensslcommon_unittest.cc", + "opensslsessioncache_unittest.cc", "ssladapter_unittest.cc", "sslidentity_unittest.cc", "sslstreamadapter_unittest.cc", diff --git a/rtc_base/openssl.h b/rtc_base/openssl.h index dbbae05319..eeed373c43 100644 --- a/rtc_base/openssl.h +++ b/rtc_base/openssl.h @@ -11,6 +11,11 @@ #ifndef RTC_BASE_OPENSSL_H_ #define RTC_BASE_OPENSSL_H_ +#if defined(WEBRTC_WIN) +// Must be included first before openssl headers. +#include "rtc_base/win32.h" // NOLINT +#endif // WEBRTC_WIN + #include #if (OPENSSL_VERSION_NUMBER < 0x10100000L) diff --git a/rtc_base/openssladapter.cc b/rtc_base/openssladapter.cc index 4c2276bf77..03b3ca8c62 100644 --- a/rtc_base/openssladapter.cc +++ b/rtc_base/openssladapter.cc @@ -14,11 +14,6 @@ #include #endif -#if defined(WEBRTC_WIN) -// Must be included first before openssl headers. -#include "rtc_base/win32.h" // NOLINT -#endif // WEBRTC_WIN - #include #include #include @@ -26,13 +21,14 @@ #include #include #include +#include "rtc_base/openssl.h" #include "rtc_base/arraysize.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/openssl.h" #include "rtc_base/opensslcommon.h" +#include "rtc_base/ptr_util.h" #include "rtc_base/sslroots.h" #include "rtc_base/stringencode.h" #include "rtc_base/stringutils.h" @@ -206,9 +202,9 @@ bool OpenSSLAdapter::CleanupSSL() { } OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket, - OpenSSLAdapterFactory* factory) + OpenSSLSessionCache* ssl_session_cache) : SSLAdapter(socket), - factory_(factory), + ssl_session_cache_(ssl_session_cache), state_(SSL_NONE), role_(SSL_CLIENT), ssl_read_needs_write_(false), @@ -222,8 +218,8 @@ OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket, // If a factory is used, take a reference on the factory's SSL_CTX. // Otherwise, we'll create our own later. // Either way, we'll release our reference via SSL_CTX_free() in Cleanup(). - if (factory_) { - ssl_ctx_ = factory_->ssl_ctx(); + if (ssl_session_cache_ != nullptr) { + ssl_ctx_ = ssl_session_cache_->GetSSLContext(); RTC_DCHECK(ssl_ctx_); // Note: if using OpenSSL, requires version 1.1.0 or later. SSL_CTX_up_ref(ssl_ctx_); @@ -307,7 +303,7 @@ int OpenSSLAdapter::BeginSSL() { // First set up the context. We should either have a factory, with its own // pre-existing context, or be running standalone, in which case we will // need to create one, and specify |false| to disable session caching. - if (!factory_) { + if (ssl_session_cache_ == nullptr) { RTC_DCHECK(!ssl_ctx_); ssl_ctx_ = CreateContext(ssl_mode_, false); } @@ -352,8 +348,8 @@ int OpenSSLAdapter::BeginSSL() { SSL_set_tlsext_host_name(ssl_, ssl_host_name_.c_str()); // Enable session caching, if configured and a hostname is supplied. - if (factory_) { - SSL_SESSION* cached = factory_->LookupSession(ssl_host_name_); + if (ssl_session_cache_ != nullptr) { + SSL_SESSION* cached = ssl_session_cache_->LookupSession(ssl_host_name_); if (cached) { if (SSL_set_session(ssl_, cached) == 0) { RTC_LOG(LS_WARNING) << "Failed to apply SSL session from cache"; @@ -615,7 +611,6 @@ int OpenSSLAdapter::SendTo(const void* pv, int OpenSSLAdapter::Recv(void* pv, size_t cb, int64_t* timestamp) { switch (state_) { - case SSL_NONE: return AsyncSocketAdapter::Recv(pv, cb, timestamp); @@ -883,9 +878,9 @@ int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { int OpenSSLAdapter::NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session) { OpenSSLAdapter* stream = reinterpret_cast(SSL_get_app_data(ssl)); - RTC_DCHECK(stream->factory_); + RTC_DCHECK(stream->ssl_session_cache_); RTC_LOG(LS_INFO) << "Caching SSL session for " << stream->ssl_host_name_; - stream->factory_->AddSession(stream->ssl_host_name_, session); + stream->ssl_session_cache_->AddSession(stream->ssl_host_name_, session); return 1; // We've taken ownership of the session; OpenSSL shouldn't free it. } @@ -984,43 +979,26 @@ std::string TransformAlpnProtocols( // OpenSSLAdapterFactory ////////////////////////////////////////////////////////////////////// -OpenSSLAdapterFactory::OpenSSLAdapterFactory() - : ssl_mode_(SSL_MODE_TLS), ssl_ctx_(nullptr) {} - -OpenSSLAdapterFactory::~OpenSSLAdapterFactory() { - for (auto it : sessions_) { - SSL_SESSION_free(it.second); - } - SSL_CTX_free(ssl_ctx_); -} +OpenSSLAdapterFactory::OpenSSLAdapterFactory() = default; +OpenSSLAdapterFactory::~OpenSSLAdapterFactory() = default; void OpenSSLAdapterFactory::SetMode(SSLMode mode) { - RTC_DCHECK(!ssl_ctx_); + RTC_DCHECK(!ssl_session_cache_); ssl_mode_ = mode; } OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) { - if (!ssl_ctx_) { - bool enable_cache = true; - ssl_ctx_ = OpenSSLAdapter::CreateContext(ssl_mode_, enable_cache); - if (!ssl_ctx_) { + if (ssl_session_cache_ == nullptr) { + SSL_CTX* ssl_ctx = + OpenSSLAdapter::CreateContext(ssl_mode_, /* enable_cache = */ true); + if (ssl_ctx == nullptr) { return nullptr; } + // The OpenSSLSessionCache will upref the ssl_ctx. + ssl_session_cache_ = MakeUnique(ssl_mode_, ssl_ctx); + SSL_CTX_free(ssl_ctx); } - - return new OpenSSLAdapter(socket, this); + return new OpenSSLAdapter(socket, ssl_session_cache_.get()); } -SSL_SESSION* OpenSSLAdapterFactory::LookupSession(const std::string& hostname) { - auto it = sessions_.find(hostname); - return (it != sessions_.end()) ? it->second : nullptr; -} - -void OpenSSLAdapterFactory::AddSession(const std::string& hostname, - SSL_SESSION* new_session) { - SSL_SESSION* old_session = LookupSession(hostname); - SSL_SESSION_free(old_session); - sessions_[hostname] = new_session; -} - -} // namespace rtc +} // namespace rtc diff --git a/rtc_base/openssladapter.h b/rtc_base/openssladapter.h index fbbd88c981..5f5eb80c6e 100644 --- a/rtc_base/openssladapter.h +++ b/rtc_base/openssladapter.h @@ -11,30 +11,29 @@ #ifndef RTC_BASE_OPENSSLADAPTER_H_ #define RTC_BASE_OPENSSLADAPTER_H_ +#include + #include +#include #include +#include + #include "rtc_base/buffer.h" #include "rtc_base/messagehandler.h" #include "rtc_base/messagequeue.h" #include "rtc_base/opensslidentity.h" +#include "rtc_base/opensslsessioncache.h" #include "rtc_base/ssladapter.h" -typedef struct ssl_st SSL; -typedef struct ssl_ctx_st SSL_CTX; -typedef struct x509_store_ctx_st X509_STORE_CTX; -typedef struct ssl_session_st SSL_SESSION; - namespace rtc { -class OpenSSLAdapterFactory; - class OpenSSLAdapter : public SSLAdapter, public MessageHandler { public: static bool InitializeSSL(VerificationCallback callback); static bool CleanupSSL(); explicit OpenSSLAdapter(AsyncSocket* socket, - OpenSSLAdapterFactory* factory = nullptr); + OpenSSLSessionCache* ssl_session_cache = nullptr); ~OpenSSLAdapter() override; void SetIgnoreBadCert(bool ignore) override; @@ -110,7 +109,7 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { // Parent object that maintains shared state. // Can be null if state sharing is not needed. - OpenSSLAdapterFactory* factory_; + OpenSSLSessionCache* ssl_session_cache_ = nullptr; SSLState state_; std::unique_ptr identity_; @@ -144,31 +143,32 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { std::string TransformAlpnProtocols(const std::vector& protos); ///////////////////////////////////////////////////////////////////////////// + +// The OpenSSLAdapterFactory is responsbile for creating multiple new +// OpenSSLAdapters with a shared SSL_CTX and a shared SSL_SESSION cache. The +// SSL_SESSION cache allows existing SSL_SESSIONS to be reused instead of +// recreating them leading to a significant performance improvement. class OpenSSLAdapterFactory : public SSLAdapterFactory { public: OpenSSLAdapterFactory(); ~OpenSSLAdapterFactory() override; - + // Set the SSL Mode to use with this factory. This should only be set before + // the first adapter is created with the factory. If it is called after it + // will DCHECK. void SetMode(SSLMode mode) override; + // Constructs a new socket using the shared OpenSSLSessionCache. This means + // existing SSLSessions already in the cache will be reused instead of + // re-created for improved performance. OpenSSLAdapter* CreateAdapter(AsyncSocket* socket) override; - static OpenSSLAdapterFactory* Create(); private: - SSL_CTX* ssl_ctx() { return ssl_ctx_; } - // Looks up a session by hostname. The returned SSL_SESSION is not up_refed. - SSL_SESSION* LookupSession(const std::string& hostname); - // Adds a session to the cache, and up_refs it. Any existing session with the - // same hostname is replaced. - void AddSession(const std::string& hostname, SSL_SESSION* session); + // Holds the SSLMode (DTLS,TLS) that will be used to set the session cache. + SSLMode ssl_mode_ = SSL_MODE_TLS; + // Holds a cache of existing SSL Sessions. + std::unique_ptr ssl_session_cache_; + // TODO(benwright): Remove this when context is moved to OpenSSLCommon. + // Hold a friend class to the OpenSSLAdapter to retrieve the context. friend class OpenSSLAdapter; - - SSLMode ssl_mode_; - // Holds the shared SSL_CTX for all created adapters. - SSL_CTX* ssl_ctx_; - // Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs, - // which are cleaned up when the factory is destroyed. - // TODO(juberti): Add LRU eviction to keep the cache from growing forever. - std::map sessions_; }; } // namespace rtc diff --git a/rtc_base/opensslsessioncache.cc b/rtc_base/opensslsessioncache.cc new file mode 100644 index 0000000000..2e37d55deb --- /dev/null +++ b/rtc_base/opensslsessioncache.cc @@ -0,0 +1,52 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/opensslsessioncache.h" +#include "rtc_base/checks.h" +#include "rtc_base/openssl.h" + +namespace rtc { + +OpenSSLSessionCache::OpenSSLSessionCache(SSLMode ssl_mode, SSL_CTX* ssl_ctx) + : ssl_mode_(ssl_mode), ssl_ctx_(ssl_ctx) { + // It is invalid to pass in a null context. + RTC_DCHECK(ssl_ctx != nullptr); + SSL_CTX_up_ref(ssl_ctx); +} + +OpenSSLSessionCache::~OpenSSLSessionCache() { + for (auto it : sessions_) { + SSL_SESSION_free(it.second); + } + SSL_CTX_free(ssl_ctx_); +} + +SSL_SESSION* OpenSSLSessionCache::LookupSession( + const std::string& hostname) const { + auto it = sessions_.find(hostname); + return (it != sessions_.end()) ? it->second : nullptr; +} + +void OpenSSLSessionCache::AddSession(const std::string& hostname, + SSL_SESSION* new_session) { + SSL_SESSION* old_session = LookupSession(hostname); + SSL_SESSION_free(old_session); + sessions_[hostname] = new_session; +} + +SSL_CTX* OpenSSLSessionCache::GetSSLContext() const { + return ssl_ctx_; +} + +SSLMode OpenSSLSessionCache::GetSSLMode() const { + return ssl_mode_; +} + +} // namespace rtc diff --git a/rtc_base/opensslsessioncache.h b/rtc_base/opensslsessioncache.h new file mode 100644 index 0000000000..ee5b525334 --- /dev/null +++ b/rtc_base/opensslsessioncache.h @@ -0,0 +1,63 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef RTC_BASE_OPENSSLSESSIONCACHE_H_ +#define RTC_BASE_OPENSSLSESSIONCACHE_H_ + +#include +#include +#include + +#include "rtc_base/constructormagic.h" +#include "rtc_base/sslstreamadapter.h" + +namespace rtc { + +// The OpenSSLSessionCache maps hostnames to SSL_SESSIONS. This cache is +// owned by the OpenSSLAdapterFactory and is passed down to each OpenSSLAdapter +// created with the factory. +class OpenSSLSessionCache final { + public: + // Creates a new OpenSSLSessionCache using the provided the SSL_CTX and + // the ssl_mode. The SSL_CTX will be up_refed. ssl_ctx cannot be nullptr, + // the constructor immediately dchecks this. + OpenSSLSessionCache(SSLMode ssl_mode, SSL_CTX* ssl_ctx); + // Frees the cached SSL_SESSIONS and then frees the SSL_CTX. + ~OpenSSLSessionCache(); + // Looks up a session by hostname. The returned SSL_SESSION is not up_refed. + SSL_SESSION* LookupSession(const std::string& hostname) const; + // Adds a session to the cache, and up_refs it. Any existing session with the + // same hostname is replaced. + void AddSession(const std::string& hostname, SSL_SESSION* session); + // Returns the true underlying SSL Context that holds these cached sessions. + SSL_CTX* GetSSLContext() const; + // The SSL Mode tht the OpenSSLSessionCache was constructed with. This cannot + // be changed after launch. + SSLMode GetSSLMode() const; + + private: + // Holds the SSL Mode that the OpenSSLCache was initialized with. This is + // immutable after creation and cannot change. + const SSLMode ssl_mode_; + /// SSL Context for all shared cached sessions. This SSL_CTX is initialized + // with SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_CLIENT); Meaning + // all client sessions will be added to the cache internal to the context. + SSL_CTX* ssl_ctx_ = nullptr; + // Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs, + // which are cleaned up when the factory is destroyed. + // TODO(juberti): Add LRU eviction to keep the cache from growing forever. + std::map sessions_; + // The cache should never be copied or assigned directly. + RTC_DISALLOW_COPY_AND_ASSIGN(OpenSSLSessionCache); +}; + +} // namespace rtc + +#endif // RTC_BASE_OPENSSLSESSIONCACHE_H_ diff --git a/rtc_base/opensslsessioncache_unittest.cc b/rtc_base/opensslsessioncache_unittest.cc new file mode 100644 index 0000000000..6489b2bc2f --- /dev/null +++ b/rtc_base/opensslsessioncache_unittest.cc @@ -0,0 +1,85 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include + +#include +#include + +#include "rtc_base/gunit.h" +#include "rtc_base/openssl.h" +#include "rtc_base/opensslsessioncache.h" + +namespace rtc { + +TEST(OpenSSLSessionCache, DTLSModeSetCorrectly) { + SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method()); + + OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx); + EXPECT_EQ(session_cache.GetSSLMode(), SSL_MODE_DTLS); + + SSL_CTX_free(ssl_ctx); +} + +TEST(OpenSSLSessionCache, TLSModeSetCorrectly) { + SSL_CTX* ssl_ctx = SSL_CTX_new(TLSv1_2_client_method()); + + OpenSSLSessionCache session_cache(SSL_MODE_TLS, ssl_ctx); + EXPECT_EQ(session_cache.GetSSLMode(), SSL_MODE_TLS); + + SSL_CTX_free(ssl_ctx); +} + +TEST(OpenSSLSessionCache, SSLContextSetCorrectly) { + SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method()); + + OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx); + EXPECT_EQ(session_cache.GetSSLContext(), ssl_ctx); + + SSL_CTX_free(ssl_ctx); +} + +TEST(OpenSSLSessionCache, InvalidLookupReturnsNullptr) { + SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method()); + + OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx); + EXPECT_EQ(session_cache.LookupSession("Invalid"), nullptr); + EXPECT_EQ(session_cache.LookupSession(""), nullptr); + EXPECT_EQ(session_cache.LookupSession("."), nullptr); + + SSL_CTX_free(ssl_ctx); +} + +TEST(OpenSSLSessionCache, SimpleValidSessionLookup) { + SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method()); + SSL_SESSION* ssl_session = SSL_SESSION_new(ssl_ctx); + + OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx); + session_cache.AddSession("webrtc.org", ssl_session); + EXPECT_EQ(session_cache.LookupSession("webrtc.org"), ssl_session); + + SSL_CTX_free(ssl_ctx); +} + +TEST(OpenSSLSessionCache, AddToExistingReplacesPrevious) { + SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method()); + SSL_SESSION* ssl_session_1 = SSL_SESSION_new(ssl_ctx); + SSL_SESSION* ssl_session_2 = SSL_SESSION_new(ssl_ctx); + + OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx); + session_cache.AddSession("webrtc.org", ssl_session_1); + session_cache.AddSession("webrtc.org", ssl_session_2); + EXPECT_EQ(session_cache.LookupSession("webrtc.org"), ssl_session_2); + + SSL_CTX_free(ssl_ctx); +} + +} // namespace rtc diff --git a/rtc_base/opensslstreamadapter.h b/rtc_base/opensslstreamadapter.h index 97ab557f41..7a6e099d22 100644 --- a/rtc_base/opensslstreamadapter.h +++ b/rtc_base/opensslstreamadapter.h @@ -11,6 +11,8 @@ #ifndef RTC_BASE_OPENSSLSTREAMADAPTER_H_ #define RTC_BASE_OPENSSLSTREAMADAPTER_H_ +#include + #include #include #include @@ -19,11 +21,6 @@ #include "rtc_base/opensslidentity.h" #include "rtc_base/sslstreamadapter.h" -typedef struct ssl_st SSL; -typedef struct ssl_ctx_st SSL_CTX; -typedef struct ssl_cipher_st SSL_CIPHER; -typedef struct x509_store_ctx_st X509_STORE_CTX; - namespace rtc { // This class was written with OpenSSLAdapter (a socket adapter) as a