diff --git a/p2p/base/dtls_transport.cc b/p2p/base/dtls_transport.cc index f6f6847a96..056b5a56dd 100644 --- a/p2p/base/dtls_transport.cc +++ b/p2p/base/dtls_transport.cc @@ -379,7 +379,8 @@ bool DtlsTransport::SetupDtls() { dtls_->SetMode(rtc::SSL_MODE_DTLS); dtls_->SetMaxProtocolVersion(ssl_max_version_); dtls_->SetServerRole(*dtls_role_); - dtls_->SignalEvent.connect(this, &DtlsTransport::OnDtlsEvent); + dtls_->SetEventCallback( + [this](int events, int err) { OnDtlsEvent(events, err); }); if (remote_fingerprint_value_.size() && !dtls_->SetPeerCertificateDigest( remote_fingerprint_algorithm_, @@ -698,9 +699,8 @@ void DtlsTransport::OnReadyToSend(rtc::PacketTransportInternal* transport) { } } -void DtlsTransport::OnDtlsEvent(rtc::StreamInterface* dtls, int sig, int err) { +void DtlsTransport::OnDtlsEvent(int sig, int err) { RTC_DCHECK_RUN_ON(&thread_checker_); - RTC_DCHECK(dtls == dtls_.get()); if (sig & rtc::SE_OPEN) { // This is the first time. RTC_LOG(LS_INFO) << ToString() << ": DTLS handshake complete."; diff --git a/p2p/base/dtls_transport.h b/p2p/base/dtls_transport.h index 109dbf58c9..ceeaa8430d 100644 --- a/p2p/base/dtls_transport.h +++ b/p2p/base/dtls_transport.h @@ -221,7 +221,7 @@ class DtlsTransport : public DtlsTransportInternal { const rtc::SentPacket& sent_packet); void OnReadyToSend(rtc::PacketTransportInternal* transport); void OnReceivingState(rtc::PacketTransportInternal* transport); - void OnDtlsEvent(rtc::StreamInterface* stream_, int sig, int err); + void OnDtlsEvent(int sig, int err); void OnNetworkRouteChanged(absl::optional network_route); bool SetupDtls(); void MaybeStartDtls(); diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index a7a2adadfc..f761c36400 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -1464,6 +1464,8 @@ rtc_library("stream") { "system:rtc_export", "third_party/sigslot", ] + + absl_deps = [ "//third_party/abseil-cpp/absl/functional:any_invocable" ] } rtc_library("rtc_certificate_generator") { @@ -2131,6 +2133,7 @@ if (rtc_include_tests) { "ssl_identity_unittest.cc", "ssl_stream_adapter_unittest.cc", ] + deps += [ ":callback_list" ] } absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", diff --git a/rtc_base/openssl_stream_adapter.cc b/rtc_base/openssl_stream_adapter.cc index 357510ba10..28da5d4530 100644 --- a/rtc_base/openssl_stream_adapter.cc +++ b/rtc_base/openssl_stream_adapter.cc @@ -296,7 +296,8 @@ OpenSSLStreamAdapter::OpenSSLStreamAdapter( #endif ssl_mode_(SSL_MODE_TLS), ssl_max_version_(SSL_PROTOCOL_TLS_12) { - stream_->SignalEvent.connect(this, &OpenSSLStreamAdapter::OnEvent); + stream_->SetEventCallback( + [this](int events, int err) { OnEvent(events, err); }); } OpenSSLStreamAdapter::~OpenSSLStreamAdapter() { @@ -741,13 +742,10 @@ StreamState OpenSSLStreamAdapter::GetState() const { // not reached } -void OpenSSLStreamAdapter::OnEvent(StreamInterface* stream, - int events, - int err) { +void OpenSSLStreamAdapter::OnEvent(int events, int err) { RTC_DCHECK_RUN_ON(&callback_sequence_); int events_to_signal = 0; int signal_error = 0; - RTC_DCHECK(stream == stream_.get()); if ((events & SE_OPEN)) { RTC_DLOG(LS_VERBOSE) << "OpenSSLStreamAdapter::OnEvent SE_OPEN"; diff --git a/rtc_base/openssl_stream_adapter.h b/rtc_base/openssl_stream_adapter.h index 3ef1363ed5..c3558b35fd 100644 --- a/rtc_base/openssl_stream_adapter.h +++ b/rtc_base/openssl_stream_adapter.h @@ -148,7 +148,7 @@ class OpenSSLStreamAdapter final : public SSLStreamAdapter, SSL_CLOSED // Clean close }; - void OnEvent(StreamInterface* stream, int events, int err); + void OnEvent(int events, int err); void PostEvent(int events, int err); void SetTimeout(int delay_ms); diff --git a/rtc_base/ssl_adapter_unittest.cc b/rtc_base/ssl_adapter_unittest.cc index 084594f6b9..ec407c531e 100644 --- a/rtc_base/ssl_adapter_unittest.cc +++ b/rtc_base/ssl_adapter_unittest.cc @@ -321,7 +321,7 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { DoHandshake(server_socket_->Accept(nullptr)); } - void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) { + void OnSSLStreamAdapterEvent(int sig, int err) { if (sig & rtc::SE_READ) { uint8_t buffer[4096] = ""; size_t read; @@ -329,7 +329,7 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { // Read data received from the client and store it in our internal // buffer. - rtc::StreamResult r = stream->Read(buffer, read, error); + rtc::StreamResult r = ssl_stream_adapter_->Read(buffer, read, error); if (r == rtc::SR_SUCCESS) { buffer[read] = '\0'; // Here we assume that the buffer is interpretable as string. @@ -365,8 +365,8 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { ssl_stream_adapter_->StartSSL(); - ssl_stream_adapter_->SignalEvent.connect( - this, &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent); + ssl_stream_adapter_->SetEventCallback( + [this](int events, int err) { OnSSLStreamAdapterEvent(events, err); }); } const rtc::SSLMode ssl_mode_; diff --git a/rtc_base/ssl_stream_adapter_unittest.cc b/rtc_base/ssl_stream_adapter_unittest.cc index fc6532c1f0..6d76c7b0d2 100644 --- a/rtc_base/ssl_stream_adapter_unittest.cc +++ b/rtc_base/ssl_stream_adapter_unittest.cc @@ -20,6 +20,7 @@ #include "api/array_view.h" #include "api/task_queue/pending_task_safety_flag.h" #include "rtc_base/buffer_queue.h" +#include "rtc_base/callback_list.h" #include "rtc_base/checks.h" #include "rtc_base/gunit.h" #include "rtc_base/helpers.h" @@ -149,17 +150,75 @@ static const char kCACert[] = class SSLStreamAdapterTestBase; -class SSLDummyStreamBase : public rtc::StreamInterface, - public sigslot::has_slots<> { +// StreamWrapper is a middle layer between `stream`, which supports a single +// event callback, and test classes in this file that need that event forwarded +// to them. I.e. this class wraps a `stream` object that it delegates all calls +// to, but for the event callback, `StreamWrapper` additionally provides support +// for forwarding event notifications to test classes that call +// `SubscribeStreamEvent()`. +// +// This is needed because in this file, tests connect both client and server +// streams (SSLDummyStream) to the same underlying `stream` objects +// (see CreateClientStream() and CreateServerStream()). +class StreamWrapper : public rtc::StreamInterface { public: - SSLDummyStreamBase(SSLStreamAdapterTestBase* test, - absl::string_view side, - rtc::StreamInterface* in, - rtc::StreamInterface* out) + explicit StreamWrapper(std::unique_ptr stream) + : stream_(std::move(stream)) { + stream_->SetEventCallback([this](int events, int err) { + RTC_DCHECK_RUN_ON(&callback_sequence_); + callbacks_.Send(events, err); + FireEvent(events, err); + }); + } + + template + void SubscribeStreamEvent(const void* removal_tag, F&& callback) { + callbacks_.AddReceiver(removal_tag, std::forward(callback)); + } + + void UnsubscribeStreamEvent(const void* removal_tag) { + callbacks_.RemoveReceivers(removal_tag); + } + + rtc::StreamState GetState() const override { return stream_->GetState(); } + + void Close() override { stream_->Close(); } + + rtc::StreamResult Read(rtc::ArrayView buffer, + size_t& read, + int& error) override { + return stream_->Read(buffer, read, error); + } + + rtc::StreamResult Write(rtc::ArrayView data, + size_t& written, + int& error) override { + return stream_->Write(data, written, error); + } + + private: + const std::unique_ptr stream_; + webrtc::CallbackList callbacks_; +}; + +class SSLDummyStream final : public rtc::StreamInterface { + public: + SSLDummyStream(SSLStreamAdapterTestBase* test, + absl::string_view side, + StreamWrapper* in, + StreamWrapper* out) : test_base_(test), side_(side), in_(in), out_(out), first_packet_(true) { - RTC_DCHECK_NE(in, out); - in_->SignalEvent.connect(this, &SSLDummyStreamBase::OnEventIn); - out_->SignalEvent.connect(this, &SSLDummyStreamBase::OnEventOut); + RTC_CHECK(thread_); + RTC_CHECK_NE(in, out); + in_->SubscribeStreamEvent( + this, [this](int events, int err) { OnEventIn(events, err); }); + out_->SubscribeStreamEvent( + this, [this](int events, int err) { OnEventOut(events, err); }); + } + + ~SSLDummyStream() override { + in_->UnsubscribeStreamEvent(this); + out_->UnsubscribeStreamEvent(this); } rtc::StreamState GetState() const override { return rtc::SS_OPEN; } @@ -184,20 +243,20 @@ class SSLDummyStreamBase : public rtc::StreamInterface, } // Catch readability events on in and pass them up. - void OnEventIn(rtc::StreamInterface* stream, int sig, int err) { + void OnEventIn(int sig, int err) { int mask = (rtc::SE_READ | rtc::SE_CLOSE); if (sig & mask) { - RTC_LOG(LS_VERBOSE) << "SSLDummyStreamBase::OnEventIn side=" << side_ + RTC_LOG(LS_VERBOSE) << "SSLDummyStream::OnEventIn side=" << side_ << " sig=" << sig << " forwarding upward"; PostEvent(sig & mask, 0); } } // Catch writeability events on out and pass them up. - void OnEventOut(rtc::StreamInterface* stream, int sig, int err) { + void OnEventOut(int sig, int err) { if (sig & rtc::SE_WRITE) { - RTC_LOG(LS_VERBOSE) << "SSLDummyStreamBase::OnEventOut side=" << side_ + RTC_LOG(LS_VERBOSE) << "SSLDummyStream::OnEventOut side=" << side_ << " sig=" << sig << " forwarding upward"; PostEvent(sig & rtc::SE_WRITE, 0); @@ -232,20 +291,11 @@ class SSLDummyStreamBase : public rtc::StreamInterface, rtc::Thread* const thread_ = rtc::Thread::Current(); SSLStreamAdapterTestBase* test_base_; const std::string side_; - rtc::StreamInterface* in_; - rtc::StreamInterface* out_; + StreamWrapper* const in_; + StreamWrapper* const out_; bool first_packet_; }; -class SSLDummyStreamTLS : public SSLDummyStreamBase { - public: - SSLDummyStreamTLS(SSLStreamAdapterTestBase* test, - absl::string_view side, - rtc::FifoBuffer* in, - rtc::FifoBuffer* out) - : SSLDummyStreamBase(test, side, in, out) {} -}; - class BufferQueueStream : public rtc::StreamInterface { public: BufferQueueStream(size_t capacity, size_t default_size) @@ -304,15 +354,6 @@ class BufferQueueStream : public rtc::StreamInterface { rtc::BufferQueue buffer_; }; -class SSLDummyStreamDTLS : public SSLDummyStreamBase { - public: - SSLDummyStreamDTLS(SSLStreamAdapterTestBase* test, - absl::string_view side, - BufferQueueStream* in, - BufferQueueStream* out) - : SSLDummyStreamBase(test, side, in, out) {} -}; - static const int kFifoBufferSize = 4096; static const int kBufferCapacity = 1; static const size_t kDefaultBufferSize = 2048; @@ -391,11 +432,10 @@ class SSLStreamAdapterTestBase : public ::testing::Test, : new ScopedFieldTrials(server_experiment)); server_ssl_ = rtc::SSLStreamAdapter::Create(CreateServerStream()); } - - client_ssl_->SignalEvent.connect(this, - &SSLStreamAdapterTestBase::OnClientEvent); - server_ssl_->SignalEvent.connect(this, - &SSLStreamAdapterTestBase::OnServerEvent); + client_ssl_->SetEventCallback( + [this](int events, int err) { OnClientEvent(events, err); }); + server_ssl_->SetEventCallback( + [this](int events, int err) { OnServerEvent(events, err); }); } // Recreate the client/server identities with the specified validity period. @@ -648,7 +688,7 @@ class SSLStreamAdapterTestBase : public ::testing::Test, } } - rtc::StreamResult DataWritten(SSLDummyStreamBase* from, + rtc::StreamResult DataWritten(SSLDummyStream* from, const void* data, size_t data_len, size_t& written, @@ -756,13 +796,12 @@ class SSLStreamAdapterTestBase : public ::testing::Test, virtual void TestTransfer(int size) = 0; private: - void OnClientEvent(rtc::StreamInterface* stream, int sig, int err) { - RTC_DCHECK_EQ(stream, client_ssl_.get()); + void OnClientEvent(int sig, int err) { RTC_LOG(LS_VERBOSE) << "SSLStreamAdapterTestBase::OnClientEvent sig=" << sig; if (sig & rtc::SE_READ) { - ReadData(stream); + ReadData(client_ssl_.get()); } if (sig & rtc::SE_WRITE) { @@ -770,12 +809,11 @@ class SSLStreamAdapterTestBase : public ::testing::Test, } } - void OnServerEvent(rtc::StreamInterface* stream, int sig, int err) { - RTC_DCHECK_EQ(stream, server_ssl_.get()); + void OnServerEvent(int sig, int err) { RTC_LOG(LS_VERBOSE) << "SSLStreamAdapterTestBase::OnServerEvent sig=" << sig; if (sig & rtc::SE_READ) { - ReadData(stream); + ReadData(server_ssl_.get()); } } @@ -819,18 +857,16 @@ class SSLStreamAdapterTestTLS "", false, ::testing::get<0>(GetParam()), - ::testing::get<1>(GetParam())), - client_buffer_(kFifoBufferSize), - server_buffer_(kFifoBufferSize) {} + ::testing::get<1>(GetParam())) {} std::unique_ptr CreateClientStream() override final { return absl::WrapUnique( - new SSLDummyStreamTLS(this, "c2s", &client_buffer_, &server_buffer_)); + new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_)); } std::unique_ptr CreateServerStream() override final { return absl::WrapUnique( - new SSLDummyStreamTLS(this, "s2c", &server_buffer_, &client_buffer_)); + new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_)); } // Test data transfer for TLS @@ -930,8 +966,10 @@ class SSLStreamAdapterTestTLS } private: - rtc::FifoBuffer client_buffer_; - rtc::FifoBuffer server_buffer_; + StreamWrapper client_buffer_{ + std::make_unique(kFifoBufferSize)}; + StreamWrapper server_buffer_{ + std::make_unique(kFifoBufferSize)}; rtc::MemoryStream send_stream_; rtc::MemoryStream recv_stream_; }; @@ -940,8 +978,6 @@ class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase { public: SSLStreamAdapterTestDTLSBase(rtc::KeyParams param1, rtc::KeyParams param2) : SSLStreamAdapterTestBase("", "", true, param1, param2), - client_buffer_(kBufferCapacity, kDefaultBufferSize), - server_buffer_(kBufferCapacity, kDefaultBufferSize), packet_size_(1000), count_(0), sent_(0) {} @@ -949,20 +985,18 @@ class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase { SSLStreamAdapterTestDTLSBase(absl::string_view cert_pem, absl::string_view private_key_pem) : SSLStreamAdapterTestBase(cert_pem, private_key_pem, true), - client_buffer_(kBufferCapacity, kDefaultBufferSize), - server_buffer_(kBufferCapacity, kDefaultBufferSize), packet_size_(1000), count_(0), sent_(0) {} std::unique_ptr CreateClientStream() override final { return absl::WrapUnique( - new SSLDummyStreamDTLS(this, "c2s", &client_buffer_, &server_buffer_)); + new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_)); } std::unique_ptr CreateServerStream() override final { return absl::WrapUnique( - new SSLDummyStreamDTLS(this, "s2c", &server_buffer_, &client_buffer_)); + new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_)); } void WriteData() override { @@ -1052,8 +1086,10 @@ class SSLStreamAdapterTestDTLSBase : public SSLStreamAdapterTestBase { } protected: - BufferQueueStream client_buffer_; - BufferQueueStream server_buffer_; + StreamWrapper client_buffer_{ + std::make_unique(kBufferCapacity, kDefaultBufferSize)}; + StreamWrapper server_buffer_{ + std::make_unique(kBufferCapacity, kDefaultBufferSize)}; private: size_t packet_size_; @@ -1075,9 +1111,9 @@ class SSLStreamAdapterTestDTLS : SSLStreamAdapterTestDTLSBase(cert_pem, private_key_pem) {} }; -rtc::StreamResult SSLDummyStreamBase::Write(rtc::ArrayView data, - size_t& written, - int& error) { +rtc::StreamResult SSLDummyStream::Write(rtc::ArrayView data, + size_t& written, + int& error) { RTC_LOG(LS_VERBOSE) << "Writing to loopback " << data.size(); if (first_packet_) { diff --git a/rtc_base/stream.h b/rtc_base/stream.h index 4b2236a86e..8eb800c4f9 100644 --- a/rtc_base/stream.h +++ b/rtc_base/stream.h @@ -13,9 +13,11 @@ #include +#include "absl/functional/any_invocable.h" #include "api/array_view.h" #include "api/sequence_checker.h" #include "rtc_base/buffer.h" +#include "rtc_base/logging.h" #include "rtc_base/system/no_unique_address.h" #include "rtc_base/system/rtc_export.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -83,15 +85,24 @@ class RTC_EXPORT StreamInterface { // signalled as a result of this call. virtual void Close() = 0; - // Streams may signal one or more StreamEvents to indicate state changes. - // The first argument identifies the stream on which the state change occured. - // The second argument is a bit-wise combination of StreamEvents. - // If SE_CLOSE is signalled, then the third argument is the associated error - // code. Otherwise, the value is undefined. - // Note: Not all streams will support asynchronous event signalling. However, - // SS_OPENING and SR_BLOCK returned from stream member functions imply that - // certain events will be raised in the future. - sigslot::signal3 SignalEvent; + // Streams may issue one or more events to indicate state changes to a + // provided callback. + // The first argument is a bit-wise combination of `StreamEvent` flags. + // If SE_CLOSE is set, then the second argument is the associated error code. + // Otherwise, the value of the second parameter is undefined and should be + // set to 0. + // Note: Not all streams support callbacks. However, SS_OPENING and + // SR_BLOCK returned from member functions imply that certain callbacks will + // be made in the future. + void SetEventCallback(absl::AnyInvocable callback) { + RTC_DCHECK_RUN_ON(&callback_sequence_); + RTC_DCHECK(!callback_ || !callback); + callback_ = std::move(callback); + } + + // TODO(bugs.webrtc.org/11943): Remove after updating downstream code. + sigslot::signal3 SignalEvent + [[deprecated("Use SetEventCallback instead")]]; // Return true if flush is successful. virtual bool Flush(); @@ -126,13 +137,23 @@ class RTC_EXPORT StreamInterface { // Utility function for derived classes. void FireEvent(int stream_events, int err) RTC_RUN_ON(&callback_sequence_) { + if (callback_) { + callback_(stream_events, err); + } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" // TODO(tommi): This is for backwards compatibility only while `SignalEvent` - // is being replaced by `SetEventHandler`. + // is being replaced by `SetEventCallback`. SignalEvent(this, stream_events, err); +#pragma clang diagnostic pop } RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker callback_sequence_{ webrtc::SequenceChecker::kDetached}; + + private: + absl::AnyInvocable callback_ + RTC_GUARDED_BY(&callback_sequence_) = nullptr; }; } // namespace rtc