Refactor OpenSSLSessionCache out of OpenSSLAdapterFactory.

This changeset refactors the OpenSSLSessionCache out of the Factory. Instead of
directly injecting a pointer to the factory to each OpenSSLAdapter instead just
a pointer to the OpenSSLSessionCache is submitted which the Factory is the sole
owner of. This provides a cleaner dependency injection interface and allows the
OpenSSLSessionCache to be tested independently of the factory that uses it. It
also allows for the factories role to be more clearly defined allowing for
additional dependency injection in future updates.

This change also removes the habit of having OpenSSL typedefs around certain
functions and instead uses the standardised ossl_typ.h header which contains
these typedefs. This makes the headers more directly tied to just what they are
responsible for doing.

Bug: webrtc:9085
Change-Id: I7938178b70acc613856139d387a1b46928dca6ad
Reviewed-on: https://webrtc-review.googlesource.com/66941
Commit-Queue: Benjamin Wright <benwright@webrtc.org>
Reviewed-by: Taylor Brandstetter <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#22758}
This commit is contained in:
Benjamin Wright 2018-04-05 15:39:06 -07:00 committed by Commit Bot
parent fd350d74ee
commit 19aab2ee7c
8 changed files with 258 additions and 75 deletions

View file

@ -861,6 +861,8 @@ rtc_static_library("rtc_base_generic") {
"openssldigest.h",
"opensslidentity.cc",
"opensslidentity.h",
"opensslsessioncache.cc",
"opensslsessioncache.h",
"opensslstreamadapter.cc",
"opensslstreamadapter.h",
"physicalsocketserver.cc",
@ -1359,6 +1361,7 @@ if (rtc_include_tests) {
sources += [
"openssladapter_unittest.cc",
"opensslcommon_unittest.cc",
"opensslsessioncache_unittest.cc",
"ssladapter_unittest.cc",
"sslidentity_unittest.cc",
"sslstreamadapter_unittest.cc",

View file

@ -11,6 +11,11 @@
#ifndef RTC_BASE_OPENSSL_H_
#define RTC_BASE_OPENSSL_H_
#if defined(WEBRTC_WIN)
// Must be included first before openssl headers.
#include "rtc_base/win32.h" // NOLINT
#endif // WEBRTC_WIN
#include <openssl/ssl.h>
#if (OPENSSL_VERSION_NUMBER < 0x10100000L)

View file

@ -14,11 +14,6 @@
#include <unistd.h>
#endif
#if defined(WEBRTC_WIN)
// Must be included first before openssl headers.
#include "rtc_base/win32.h" // NOLINT
#endif // WEBRTC_WIN
#include <openssl/bio.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
@ -26,13 +21,14 @@
#include <openssl/rand.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include "rtc_base/openssl.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/openssl.h"
#include "rtc_base/opensslcommon.h"
#include "rtc_base/ptr_util.h"
#include "rtc_base/sslroots.h"
#include "rtc_base/stringencode.h"
#include "rtc_base/stringutils.h"
@ -206,9 +202,9 @@ bool OpenSSLAdapter::CleanupSSL() {
}
OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket,
OpenSSLAdapterFactory* factory)
OpenSSLSessionCache* ssl_session_cache)
: SSLAdapter(socket),
factory_(factory),
ssl_session_cache_(ssl_session_cache),
state_(SSL_NONE),
role_(SSL_CLIENT),
ssl_read_needs_write_(false),
@ -222,8 +218,8 @@ OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket,
// If a factory is used, take a reference on the factory's SSL_CTX.
// Otherwise, we'll create our own later.
// Either way, we'll release our reference via SSL_CTX_free() in Cleanup().
if (factory_) {
ssl_ctx_ = factory_->ssl_ctx();
if (ssl_session_cache_ != nullptr) {
ssl_ctx_ = ssl_session_cache_->GetSSLContext();
RTC_DCHECK(ssl_ctx_);
// Note: if using OpenSSL, requires version 1.1.0 or later.
SSL_CTX_up_ref(ssl_ctx_);
@ -307,7 +303,7 @@ int OpenSSLAdapter::BeginSSL() {
// First set up the context. We should either have a factory, with its own
// pre-existing context, or be running standalone, in which case we will
// need to create one, and specify |false| to disable session caching.
if (!factory_) {
if (ssl_session_cache_ == nullptr) {
RTC_DCHECK(!ssl_ctx_);
ssl_ctx_ = CreateContext(ssl_mode_, false);
}
@ -352,8 +348,8 @@ int OpenSSLAdapter::BeginSSL() {
SSL_set_tlsext_host_name(ssl_, ssl_host_name_.c_str());
// Enable session caching, if configured and a hostname is supplied.
if (factory_) {
SSL_SESSION* cached = factory_->LookupSession(ssl_host_name_);
if (ssl_session_cache_ != nullptr) {
SSL_SESSION* cached = ssl_session_cache_->LookupSession(ssl_host_name_);
if (cached) {
if (SSL_set_session(ssl_, cached) == 0) {
RTC_LOG(LS_WARNING) << "Failed to apply SSL session from cache";
@ -615,7 +611,6 @@ int OpenSSLAdapter::SendTo(const void* pv,
int OpenSSLAdapter::Recv(void* pv, size_t cb, int64_t* timestamp) {
switch (state_) {
case SSL_NONE:
return AsyncSocketAdapter::Recv(pv, cb, timestamp);
@ -883,9 +878,9 @@ int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) {
int OpenSSLAdapter::NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session) {
OpenSSLAdapter* stream =
reinterpret_cast<OpenSSLAdapter*>(SSL_get_app_data(ssl));
RTC_DCHECK(stream->factory_);
RTC_DCHECK(stream->ssl_session_cache_);
RTC_LOG(LS_INFO) << "Caching SSL session for " << stream->ssl_host_name_;
stream->factory_->AddSession(stream->ssl_host_name_, session);
stream->ssl_session_cache_->AddSession(stream->ssl_host_name_, session);
return 1; // We've taken ownership of the session; OpenSSL shouldn't free it.
}
@ -984,43 +979,26 @@ std::string TransformAlpnProtocols(
// OpenSSLAdapterFactory
//////////////////////////////////////////////////////////////////////
OpenSSLAdapterFactory::OpenSSLAdapterFactory()
: ssl_mode_(SSL_MODE_TLS), ssl_ctx_(nullptr) {}
OpenSSLAdapterFactory::~OpenSSLAdapterFactory() {
for (auto it : sessions_) {
SSL_SESSION_free(it.second);
}
SSL_CTX_free(ssl_ctx_);
}
OpenSSLAdapterFactory::OpenSSLAdapterFactory() = default;
OpenSSLAdapterFactory::~OpenSSLAdapterFactory() = default;
void OpenSSLAdapterFactory::SetMode(SSLMode mode) {
RTC_DCHECK(!ssl_ctx_);
RTC_DCHECK(!ssl_session_cache_);
ssl_mode_ = mode;
}
OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) {
if (!ssl_ctx_) {
bool enable_cache = true;
ssl_ctx_ = OpenSSLAdapter::CreateContext(ssl_mode_, enable_cache);
if (!ssl_ctx_) {
if (ssl_session_cache_ == nullptr) {
SSL_CTX* ssl_ctx =
OpenSSLAdapter::CreateContext(ssl_mode_, /* enable_cache = */ true);
if (ssl_ctx == nullptr) {
return nullptr;
}
// The OpenSSLSessionCache will upref the ssl_ctx.
ssl_session_cache_ = MakeUnique<OpenSSLSessionCache>(ssl_mode_, ssl_ctx);
SSL_CTX_free(ssl_ctx);
}
return new OpenSSLAdapter(socket, this);
return new OpenSSLAdapter(socket, ssl_session_cache_.get());
}
SSL_SESSION* OpenSSLAdapterFactory::LookupSession(const std::string& hostname) {
auto it = sessions_.find(hostname);
return (it != sessions_.end()) ? it->second : nullptr;
}
void OpenSSLAdapterFactory::AddSession(const std::string& hostname,
SSL_SESSION* new_session) {
SSL_SESSION* old_session = LookupSession(hostname);
SSL_SESSION_free(old_session);
sessions_[hostname] = new_session;
}
} // namespace rtc
} // namespace rtc

View file

@ -11,30 +11,29 @@
#ifndef RTC_BASE_OPENSSLADAPTER_H_
#define RTC_BASE_OPENSSLADAPTER_H_
#include <openssl/ossl_typ.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "rtc_base/buffer.h"
#include "rtc_base/messagehandler.h"
#include "rtc_base/messagequeue.h"
#include "rtc_base/opensslidentity.h"
#include "rtc_base/opensslsessioncache.h"
#include "rtc_base/ssladapter.h"
typedef struct ssl_st SSL;
typedef struct ssl_ctx_st SSL_CTX;
typedef struct x509_store_ctx_st X509_STORE_CTX;
typedef struct ssl_session_st SSL_SESSION;
namespace rtc {
class OpenSSLAdapterFactory;
class OpenSSLAdapter : public SSLAdapter, public MessageHandler {
public:
static bool InitializeSSL(VerificationCallback callback);
static bool CleanupSSL();
explicit OpenSSLAdapter(AsyncSocket* socket,
OpenSSLAdapterFactory* factory = nullptr);
OpenSSLSessionCache* ssl_session_cache = nullptr);
~OpenSSLAdapter() override;
void SetIgnoreBadCert(bool ignore) override;
@ -110,7 +109,7 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler {
// Parent object that maintains shared state.
// Can be null if state sharing is not needed.
OpenSSLAdapterFactory* factory_;
OpenSSLSessionCache* ssl_session_cache_ = nullptr;
SSLState state_;
std::unique_ptr<OpenSSLIdentity> identity_;
@ -144,31 +143,32 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler {
std::string TransformAlpnProtocols(const std::vector<std::string>& protos);
/////////////////////////////////////////////////////////////////////////////
// The OpenSSLAdapterFactory is responsbile for creating multiple new
// OpenSSLAdapters with a shared SSL_CTX and a shared SSL_SESSION cache. The
// SSL_SESSION cache allows existing SSL_SESSIONS to be reused instead of
// recreating them leading to a significant performance improvement.
class OpenSSLAdapterFactory : public SSLAdapterFactory {
public:
OpenSSLAdapterFactory();
~OpenSSLAdapterFactory() override;
// Set the SSL Mode to use with this factory. This should only be set before
// the first adapter is created with the factory. If it is called after it
// will DCHECK.
void SetMode(SSLMode mode) override;
// Constructs a new socket using the shared OpenSSLSessionCache. This means
// existing SSLSessions already in the cache will be reused instead of
// re-created for improved performance.
OpenSSLAdapter* CreateAdapter(AsyncSocket* socket) override;
static OpenSSLAdapterFactory* Create();
private:
SSL_CTX* ssl_ctx() { return ssl_ctx_; }
// Looks up a session by hostname. The returned SSL_SESSION is not up_refed.
SSL_SESSION* LookupSession(const std::string& hostname);
// Adds a session to the cache, and up_refs it. Any existing session with the
// same hostname is replaced.
void AddSession(const std::string& hostname, SSL_SESSION* session);
// Holds the SSLMode (DTLS,TLS) that will be used to set the session cache.
SSLMode ssl_mode_ = SSL_MODE_TLS;
// Holds a cache of existing SSL Sessions.
std::unique_ptr<OpenSSLSessionCache> ssl_session_cache_;
// TODO(benwright): Remove this when context is moved to OpenSSLCommon.
// Hold a friend class to the OpenSSLAdapter to retrieve the context.
friend class OpenSSLAdapter;
SSLMode ssl_mode_;
// Holds the shared SSL_CTX for all created adapters.
SSL_CTX* ssl_ctx_;
// Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs,
// which are cleaned up when the factory is destroyed.
// TODO(juberti): Add LRU eviction to keep the cache from growing forever.
std::map<std::string, SSL_SESSION*> sessions_;
};
} // namespace rtc

View file

@ -0,0 +1,52 @@
/*
* Copyright 2018 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "rtc_base/opensslsessioncache.h"
#include "rtc_base/checks.h"
#include "rtc_base/openssl.h"
namespace rtc {
OpenSSLSessionCache::OpenSSLSessionCache(SSLMode ssl_mode, SSL_CTX* ssl_ctx)
: ssl_mode_(ssl_mode), ssl_ctx_(ssl_ctx) {
// It is invalid to pass in a null context.
RTC_DCHECK(ssl_ctx != nullptr);
SSL_CTX_up_ref(ssl_ctx);
}
OpenSSLSessionCache::~OpenSSLSessionCache() {
for (auto it : sessions_) {
SSL_SESSION_free(it.second);
}
SSL_CTX_free(ssl_ctx_);
}
SSL_SESSION* OpenSSLSessionCache::LookupSession(
const std::string& hostname) const {
auto it = sessions_.find(hostname);
return (it != sessions_.end()) ? it->second : nullptr;
}
void OpenSSLSessionCache::AddSession(const std::string& hostname,
SSL_SESSION* new_session) {
SSL_SESSION* old_session = LookupSession(hostname);
SSL_SESSION_free(old_session);
sessions_[hostname] = new_session;
}
SSL_CTX* OpenSSLSessionCache::GetSSLContext() const {
return ssl_ctx_;
}
SSLMode OpenSSLSessionCache::GetSSLMode() const {
return ssl_mode_;
}
} // namespace rtc

View file

@ -0,0 +1,63 @@
/*
* Copyright 2018 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef RTC_BASE_OPENSSLSESSIONCACHE_H_
#define RTC_BASE_OPENSSLSESSIONCACHE_H_
#include <openssl/ossl_typ.h>
#include <map>
#include <string>
#include "rtc_base/constructormagic.h"
#include "rtc_base/sslstreamadapter.h"
namespace rtc {
// The OpenSSLSessionCache maps hostnames to SSL_SESSIONS. This cache is
// owned by the OpenSSLAdapterFactory and is passed down to each OpenSSLAdapter
// created with the factory.
class OpenSSLSessionCache final {
public:
// Creates a new OpenSSLSessionCache using the provided the SSL_CTX and
// the ssl_mode. The SSL_CTX will be up_refed. ssl_ctx cannot be nullptr,
// the constructor immediately dchecks this.
OpenSSLSessionCache(SSLMode ssl_mode, SSL_CTX* ssl_ctx);
// Frees the cached SSL_SESSIONS and then frees the SSL_CTX.
~OpenSSLSessionCache();
// Looks up a session by hostname. The returned SSL_SESSION is not up_refed.
SSL_SESSION* LookupSession(const std::string& hostname) const;
// Adds a session to the cache, and up_refs it. Any existing session with the
// same hostname is replaced.
void AddSession(const std::string& hostname, SSL_SESSION* session);
// Returns the true underlying SSL Context that holds these cached sessions.
SSL_CTX* GetSSLContext() const;
// The SSL Mode tht the OpenSSLSessionCache was constructed with. This cannot
// be changed after launch.
SSLMode GetSSLMode() const;
private:
// Holds the SSL Mode that the OpenSSLCache was initialized with. This is
// immutable after creation and cannot change.
const SSLMode ssl_mode_;
/// SSL Context for all shared cached sessions. This SSL_CTX is initialized
// with SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_CLIENT); Meaning
// all client sessions will be added to the cache internal to the context.
SSL_CTX* ssl_ctx_ = nullptr;
// Map of hostnames to SSL_SESSIONs; holds references to the SSL_SESSIONs,
// which are cleaned up when the factory is destroyed.
// TODO(juberti): Add LRU eviction to keep the cache from growing forever.
std::map<std::string, SSL_SESSION*> sessions_;
// The cache should never be copied or assigned directly.
RTC_DISALLOW_COPY_AND_ASSIGN(OpenSSLSessionCache);
};
} // namespace rtc
#endif // RTC_BASE_OPENSSLSESSIONCACHE_H_

View file

@ -0,0 +1,85 @@
/*
* Copyright 2018 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <openssl/ssl.h>
#include <stdlib.h>
#include <map>
#include <memory>
#include "rtc_base/gunit.h"
#include "rtc_base/openssl.h"
#include "rtc_base/opensslsessioncache.h"
namespace rtc {
TEST(OpenSSLSessionCache, DTLSModeSetCorrectly) {
SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method());
OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx);
EXPECT_EQ(session_cache.GetSSLMode(), SSL_MODE_DTLS);
SSL_CTX_free(ssl_ctx);
}
TEST(OpenSSLSessionCache, TLSModeSetCorrectly) {
SSL_CTX* ssl_ctx = SSL_CTX_new(TLSv1_2_client_method());
OpenSSLSessionCache session_cache(SSL_MODE_TLS, ssl_ctx);
EXPECT_EQ(session_cache.GetSSLMode(), SSL_MODE_TLS);
SSL_CTX_free(ssl_ctx);
}
TEST(OpenSSLSessionCache, SSLContextSetCorrectly) {
SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method());
OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx);
EXPECT_EQ(session_cache.GetSSLContext(), ssl_ctx);
SSL_CTX_free(ssl_ctx);
}
TEST(OpenSSLSessionCache, InvalidLookupReturnsNullptr) {
SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method());
OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx);
EXPECT_EQ(session_cache.LookupSession("Invalid"), nullptr);
EXPECT_EQ(session_cache.LookupSession(""), nullptr);
EXPECT_EQ(session_cache.LookupSession("."), nullptr);
SSL_CTX_free(ssl_ctx);
}
TEST(OpenSSLSessionCache, SimpleValidSessionLookup) {
SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method());
SSL_SESSION* ssl_session = SSL_SESSION_new(ssl_ctx);
OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx);
session_cache.AddSession("webrtc.org", ssl_session);
EXPECT_EQ(session_cache.LookupSession("webrtc.org"), ssl_session);
SSL_CTX_free(ssl_ctx);
}
TEST(OpenSSLSessionCache, AddToExistingReplacesPrevious) {
SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method());
SSL_SESSION* ssl_session_1 = SSL_SESSION_new(ssl_ctx);
SSL_SESSION* ssl_session_2 = SSL_SESSION_new(ssl_ctx);
OpenSSLSessionCache session_cache(SSL_MODE_DTLS, ssl_ctx);
session_cache.AddSession("webrtc.org", ssl_session_1);
session_cache.AddSession("webrtc.org", ssl_session_2);
EXPECT_EQ(session_cache.LookupSession("webrtc.org"), ssl_session_2);
SSL_CTX_free(ssl_ctx);
}
} // namespace rtc

View file

@ -11,6 +11,8 @@
#ifndef RTC_BASE_OPENSSLSTREAMADAPTER_H_
#define RTC_BASE_OPENSSLSTREAMADAPTER_H_
#include <openssl/ossl_typ.h>
#include <string>
#include <memory>
#include <vector>
@ -19,11 +21,6 @@
#include "rtc_base/opensslidentity.h"
#include "rtc_base/sslstreamadapter.h"
typedef struct ssl_st SSL;
typedef struct ssl_ctx_st SSL_CTX;
typedef struct ssl_cipher_st SSL_CIPHER;
typedef struct x509_store_ctx_st X509_STORE_CTX;
namespace rtc {
// This class was written with OpenSSLAdapter (a socket adapter) as a