Disable SSLAdapter methods Listen and Accept

Only affects turn server. Refactored to wrap sockets with SSLAdapter
after Accept, using the SSLAdapterFactory to hold needed configuration.

Bug: webrtc:13065
Change-Id: I5df65aad5728d8d40d95b22db6398a573ec7a36f
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/235823
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Niels Moller <nisse@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35258}
This commit is contained in:
Niels Möller 2021-10-20 15:25:09 +02:00 committed by WebRTC LUCI CQ
parent 5e67b6a90d
commit ac9a288274
7 changed files with 96 additions and 36 deletions

View file

@ -11,7 +11,9 @@
#ifndef P2P_BASE_TEST_TURN_SERVER_H_
#define P2P_BASE_TEST_TURN_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "api/sequence_checker.h"
@ -104,21 +106,24 @@ class TestTurnServer : public TurnAuthInterface {
// new connections.
rtc::Socket* socket =
thread_->socketserver()->CreateSocket(AF_INET, SOCK_STREAM);
socket->Bind(int_addr);
socket->Listen(5);
if (proto == cricket::PROTO_TLS) {
// For TLS, wrap the TCP socket with an SSL adapter. The adapter must
// be configured with a self-signed certificate for testing.
// Additionally, the client will not present a valid certificate, so we
// must not fail when checking the peer's identity.
rtc::SSLAdapter* adapter = rtc::SSLAdapter::Create(socket);
adapter->SetRole(rtc::SSL_SERVER);
adapter->SetIdentity(
std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory =
rtc::SSLAdapterFactory::Create();
ssl_adapter_factory->SetRole(rtc::SSL_SERVER);
ssl_adapter_factory->SetIdentity(
rtc::SSLIdentity::Create(common_name, rtc::KeyParams()));
adapter->SetIgnoreBadCert(ignore_bad_cert);
socket = adapter;
ssl_adapter_factory->SetIgnoreBadCert(ignore_bad_cert);
server_.AddInternalServerSocket(socket, proto,
std::move(ssl_adapter_factory));
} else {
server_.AddInternalServerSocket(socket, proto);
}
socket->Bind(int_addr);
socket->Listen(5);
server_.AddInternalServerSocket(socket, proto);
} else {
RTC_NOTREACHED() << "Unknown protocol type: " << proto;
}

View file

@ -152,12 +152,15 @@ void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket,
socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket);
}
void TurnServer::AddInternalServerSocket(rtc::Socket* socket,
ProtocolType proto) {
void TurnServer::AddInternalServerSocket(
rtc::Socket* socket,
ProtocolType proto,
std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory) {
RTC_DCHECK_RUN_ON(thread_);
RTC_DCHECK(server_listen_sockets_.end() ==
server_listen_sockets_.find(socket));
server_listen_sockets_[socket] = proto;
server_listen_sockets_[socket] = {proto, std::move(ssl_adapter_factory)};
socket->SignalReadEvent.connect(this, &TurnServer::OnNewInternalConnection);
}
@ -181,13 +184,19 @@ void TurnServer::AcceptConnection(rtc::Socket* server_socket) {
rtc::SocketAddress accept_addr;
rtc::Socket* accepted_socket = server_socket->Accept(&accept_addr);
if (accepted_socket != NULL) {
ProtocolType proto = server_listen_sockets_[server_socket];
const ServerSocketInfo& info = server_listen_sockets_[server_socket];
if (info.ssl_adapter_factory) {
rtc::SSLAdapter* ssl_adapter =
info.ssl_adapter_factory->CreateAdapter(accepted_socket);
ssl_adapter->StartSSL("");
accepted_socket = ssl_adapter;
}
cricket::AsyncStunTCPSocket* tcp_socket =
new cricket::AsyncStunTCPSocket(accepted_socket);
tcp_socket->SignalClose.connect(this, &TurnServer::OnInternalSocketClose);
// Finally add the socket so it can start communicating with the client.
AddInternalSocket(tcp_socket, proto);
AddInternalSocket(tcp_socket, info.proto);
}
}

View file

@ -23,6 +23,7 @@
#include "p2p/base/port_interface.h"
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/socket_address.h"
#include "rtc_base/ssl_adapter.h"
#include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h"
@ -237,7 +238,10 @@ class TurnServer : public sigslot::has_slots<> {
// Starts listening for the connections on this socket. When someone tries
// to connect, the connection will be accepted and a new internal socket
// will be added.
void AddInternalServerSocket(rtc::Socket* socket, ProtocolType proto);
void AddInternalServerSocket(
rtc::Socket* socket,
ProtocolType proto,
std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory = nullptr);
// Specifies the factory to use for creating external sockets.
void SetExternalSocketFactory(rtc::PacketSocketFactory* factory,
const rtc::SocketAddress& address);
@ -320,7 +324,12 @@ class TurnServer : public sigslot::has_slots<> {
RTC_RUN_ON(thread_);
typedef std::map<rtc::AsyncPacketSocket*, ProtocolType> InternalSocketMap;
typedef std::map<rtc::Socket*, ProtocolType> ServerSocketMap;
struct ServerSocketInfo {
ProtocolType proto;
// If non-null, used to wrap accepted sockets.
std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory;
};
typedef std::map<rtc::Socket*, ServerSocketInfo> ServerSocketMap;
rtc::Thread* const thread_;
const std::string nonce_key_;

View file

@ -250,21 +250,6 @@ void OpenSSLAdapter::SetRole(SSLRole role) {
role_ = role;
}
Socket* OpenSSLAdapter::Accept(SocketAddress* paddr) {
RTC_DCHECK(role_ == SSL_SERVER);
Socket* socket = SSLAdapter::Accept(paddr);
if (!socket) {
return nullptr;
}
SSLAdapter* adapter = SSLAdapter::Create(socket);
adapter->SetIdentity(identity_->Clone());
adapter->SetRole(rtc::SSL_SERVER);
adapter->SetIgnoreBadCert(ignore_bad_cert_);
adapter->StartSSL("");
return adapter;
}
int OpenSSLAdapter::StartSSL(const char* hostname) {
if (state_ != SSL_NONE)
return -1;
@ -1038,6 +1023,21 @@ void OpenSSLAdapterFactory::SetCertVerifier(
ssl_cert_verifier_ = ssl_cert_verifier;
}
void OpenSSLAdapterFactory::SetIdentity(std::unique_ptr<SSLIdentity> identity) {
RTC_DCHECK(!ssl_session_cache_);
identity_ = std::move(identity);
}
void OpenSSLAdapterFactory::SetRole(SSLRole role) {
RTC_DCHECK(!ssl_session_cache_);
ssl_role_ = role;
}
void OpenSSLAdapterFactory::SetIgnoreBadCert(bool ignore) {
RTC_DCHECK(!ssl_session_cache_);
ignore_bad_cert_ = ignore;
}
OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(Socket* socket) {
if (ssl_session_cache_ == nullptr) {
SSL_CTX* ssl_ctx = OpenSSLAdapter::CreateContext(ssl_mode_, true);
@ -1049,8 +1049,14 @@ OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(Socket* socket) {
std::make_unique<OpenSSLSessionCache>(ssl_mode_, ssl_ctx);
SSL_CTX_free(ssl_ctx);
}
return new OpenSSLAdapter(socket, ssl_session_cache_.get(),
ssl_cert_verifier_);
OpenSSLAdapter* ssl_adapter =
new OpenSSLAdapter(socket, ssl_session_cache_.get(), ssl_cert_verifier_);
ssl_adapter->SetRole(ssl_role_);
ssl_adapter->SetIgnoreBadCert(ignore_bad_cert_);
if (identity_) {
ssl_adapter->SetIdentity(identity_->Clone());
}
return ssl_adapter;
}
OpenSSLAdapter::EarlyExitCatcher::EarlyExitCatcher(OpenSSLAdapter& adapter_ptr)

View file

@ -60,7 +60,6 @@ class OpenSSLAdapter final : public SSLAdapter,
void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
void SetRole(SSLRole role) override;
Socket* Accept(SocketAddress* paddr) override;
int StartSSL(const char* hostname) override;
int Send(const void* pv, size_t cb) override;
int SendTo(const void* pv, size_t cb, const SocketAddress& addr) override;
@ -191,10 +190,21 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory {
// the first adapter is created with the factory. If it is called after it
// will DCHECK.
void SetMode(SSLMode mode) override;
// Set a custom certificate verifier to be passed down to each instance
// created with this factory. This should only ever be set before the first
// call to the factory and cannot be changed after the fact.
void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
// Choose whether the socket acts as a server socket or client socket.
void SetRole(SSLRole role) override;
// Methods that control server certificate verification, used in unit tests.
// Do not call these methods in production code.
void SetIgnoreBadCert(bool ignore) override;
// Constructs a new socket using the shared OpenSSLSessionCache. This means
// existing SSLSessions already in the cache will be reused instead of
// re-created for improved performance.
@ -203,6 +213,11 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory {
private:
// Holds the SSLMode (DTLS,TLS) that will be used to set the session cache.
SSLMode ssl_mode_ = SSL_MODE_TLS;
SSLRole ssl_role_ = SSL_CLIENT;
bool ignore_bad_cert_ = false;
std::unique_ptr<SSLIdentity> identity_;
// Holds a cache of existing SSL Sessions.
std::unique_ptr<OpenSSLSessionCache> ssl_session_cache_;
// Provides an optional custom callback for verifying SSL certificates, this

View file

@ -16,8 +16,8 @@
namespace rtc {
SSLAdapterFactory* SSLAdapterFactory::Create() {
return new OpenSSLAdapterFactory();
std::unique_ptr<SSLAdapterFactory> SSLAdapterFactory::Create() {
return std::make_unique<OpenSSLAdapterFactory>();
}
SSLAdapter* SSLAdapter::Create(Socket* socket) {

View file

@ -39,10 +39,21 @@ class SSLAdapterFactory {
// Specify a custom certificate verifier for SSL.
virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0;
// Set the certificate this socket will present to incoming clients.
// Takes ownership of `identity`.
virtual void SetIdentity(std::unique_ptr<SSLIdentity> identity) = 0;
// Choose whether the socket acts as a server socket or client socket.
virtual void SetRole(SSLRole role) = 0;
// Methods that control server certificate verification, used in unit tests.
// Do not call these methods in production code.
virtual void SetIgnoreBadCert(bool ignore) = 0;
// Creates a new SSL adapter, but from a shared context.
virtual SSLAdapter* CreateAdapter(Socket* socket) = 0;
static SSLAdapterFactory* Create();
static std::unique_ptr<SSLAdapterFactory> Create();
};
// Class that abstracts a client-to-server SSL session. It can be created
@ -91,6 +102,11 @@ class SSLAdapter : public AsyncSocketAdapter {
// and deletes `socket`. Otherwise, the returned SSLAdapter takes ownership
// of `socket`.
static SSLAdapter* Create(Socket* socket);
private:
// Not supported.
int Listen(int backlog) override { RTC_CHECK(false); }
Socket* Accept(SocketAddress* paddr) override { RTC_CHECK(false); }
};
///////////////////////////////////////////////////////////////////////////////