diff --git a/p2p/base/transport_description_factory.h b/p2p/base/transport_description_factory.h index b4d8822cd2..11352f88b4 100644 --- a/p2p/base/transport_description_factory.h +++ b/p2p/base/transport_description_factory.h @@ -12,6 +12,7 @@ #define P2P_BASE_TRANSPORT_DESCRIPTION_FACTORY_H_ #include +#include #include "api/field_trials_view.h" #include "p2p/base/ice_credentials_iterator.h" @@ -51,9 +52,8 @@ class TransportDescriptionFactory { // Specifies the transport security policy to use. void set_secure(SecurePolicy s) { secure_ = s; } // Specifies the certificate to use (only used when secure != SEC_DISABLED). - void set_certificate( - const rtc::scoped_refptr& certificate) { - certificate_ = certificate; + void set_certificate(rtc::scoped_refptr certificate) { + certificate_ = std::move(certificate); } // Creates a transport description suitable for use in an offer. diff --git a/pc/sdp_offer_answer.cc b/pc/sdp_offer_answer.cc index 3ea8240350..3c8c5062ce 100644 --- a/pc/sdp_offer_answer.cc +++ b/pc/sdp_offer_answer.cc @@ -1221,7 +1221,7 @@ void SdpOfferAnswerHandler::Initialize( webrtc_session_desc_factory_ = std::make_unique( context, this, pc_->session_id(), pc_->dtls_enabled(), - std::move(dependencies.cert_generator), certificate, + std::move(dependencies.cert_generator), std::move(certificate), [this](const rtc::scoped_refptr& certificate) { RTC_DCHECK_RUN_ON(signaling_thread()); transport_controller_s()->SetLocalCertificate(certificate); diff --git a/pc/test/fake_rtc_certificate_generator.h b/pc/test/fake_rtc_certificate_generator.h index 5f0667b26f..61da26a12f 100644 --- a/pc/test/fake_rtc_certificate_generator.h +++ b/pc/test/fake_rtc_certificate_generator.h @@ -135,11 +135,9 @@ class FakeRTCCertificateGenerator int generated_certificates() { return generated_certificates_; } int generated_failures() { return generated_failures_; } - void GenerateCertificateAsync( - const rtc::KeyParams& key_params, - const absl::optional& expires_ms, - const rtc::scoped_refptr& callback) - override { + void GenerateCertificateAsync(const rtc::KeyParams& key_params, + const absl::optional& expires_ms, + Callback callback) override { // The certificates are created from constant PEM strings and use its coded // expiration time, we do not support modifying it. RTC_DCHECK(!expires_ms); @@ -154,7 +152,7 @@ class FakeRTCCertificateGenerator } rtc::KeyType key_type = key_params.type(); webrtc::TaskQueueBase::Current()->PostTask( - [this, key_type, callback]() mutable { + [this, key_type, callback = std::move(callback)]() mutable { GenerateCertificate(key_type, std::move(callback)); }); } @@ -190,9 +188,7 @@ class FakeRTCCertificateGenerator return get_pem(key_type).certificate(); } - void GenerateCertificate( - rtc::KeyType key_type, - rtc::scoped_refptr callback) { + void GenerateCertificate(rtc::KeyType key_type, Callback callback) { // If the certificate generation should be stalled, re-post this same // message to the queue with a small delay so as to wait in a loop until // set_should_wait(false) is called. @@ -206,13 +202,13 @@ class FakeRTCCertificateGenerator } if (should_fail_) { ++generated_failures_; - callback->OnFailure(); + std::move(callback)(nullptr); } else { rtc::scoped_refptr certificate = rtc::RTCCertificate::FromPEM(get_pem(key_type)); RTC_DCHECK(certificate); ++generated_certificates_; - callback->OnSuccess(certificate); + std::move(callback)(std::move(certificate)); } } diff --git a/pc/webrtc_session_description_factory.cc b/pc/webrtc_session_description_factory.cc index eb7607ec85..363a7f71a0 100644 --- a/pc/webrtc_session_description_factory.cc +++ b/pc/webrtc_session_description_factory.cc @@ -85,29 +85,6 @@ struct CreateSessionDescriptionMsg : public rtc::MessageData { }; } // namespace -class WebRtcSessionDescriptionFactory::WebRtcCertificateGeneratorCallback - : public rtc::RTCCertificateGeneratorCallback { - public: - explicit WebRtcCertificateGeneratorCallback( - rtc::WeakPtr ptr) - : weak_ptr_(std::move(ptr)) {} - // `rtc::RTCCertificateGeneratorCallback` overrides. - void OnSuccess( - const rtc::scoped_refptr& certificate) override { - if (weak_ptr_) { - weak_ptr_->SetCertificate(certificate); - } - } - void OnFailure() override { - if (weak_ptr_) { - weak_ptr_->OnCertificateRequestFailed(); - } - } - - private: - rtc::WeakPtr weak_ptr_; -}; - // static void WebRtcSessionDescriptionFactory::CopyCandidatesFromSessionDescription( const SessionDescriptionInterface* source_desc, @@ -145,7 +122,7 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory( const std::string& session_id, bool dtls_enabled, std::unique_ptr cert_generator, - const rtc::scoped_refptr& certificate, + rtc::scoped_refptr certificate, std::function&)> on_certificate_ready, const FieldTrialsView& field_trials) @@ -186,18 +163,26 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory( // Generate certificate. certificate_request_state_ = CERTIFICATE_WAITING; - auto callback = rtc::make_ref_counted( - weak_factory_.GetWeakPtr()); + auto callback = [weak_ptr = weak_factory_.GetWeakPtr()]( + rtc::scoped_refptr certificate) { + if (!weak_ptr) { + return; + } + if (certificate) { + weak_ptr->SetCertificate(std::move(certificate)); + } else { + weak_ptr->OnCertificateRequestFailed(); + } + }; rtc::KeyParams key_params = rtc::KeyParams(); RTC_LOG(LS_VERBOSE) << "DTLS-SRTP enabled; sending DTLS identity request (key type: " << key_params.type() << ")."; - // Request certificate. This happens asynchronously, so that the caller gets - // a chance to connect to `SignalCertificateReady`. + // Request certificate. This happens asynchronously on a different thread. cert_generator_->GenerateCertificateAsync(key_params, absl::nullopt, - callback); + std::move(callback)); } } @@ -477,7 +462,7 @@ void WebRtcSessionDescriptionFactory::OnCertificateRequestFailed() { } void WebRtcSessionDescriptionFactory::SetCertificate( - const rtc::scoped_refptr& certificate) { + rtc::scoped_refptr certificate) { RTC_DCHECK(certificate); RTC_LOG(LS_VERBOSE) << "Setting new certificate."; @@ -485,7 +470,7 @@ void WebRtcSessionDescriptionFactory::SetCertificate( on_certificate_ready_(certificate); - transport_desc_factory_.set_certificate(certificate); + transport_desc_factory_.set_certificate(std::move(certificate)); transport_desc_factory_.set_secure(cricket::SEC_ENABLED); while (!create_session_description_requests_.empty()) { diff --git a/pc/webrtc_session_description_factory.h b/pc/webrtc_session_description_factory.h index e105f1666a..f85f712e59 100644 --- a/pc/webrtc_session_description_factory.h +++ b/pc/webrtc_session_description_factory.h @@ -51,7 +51,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler { const std::string& session_id, bool dtls_enabled, std::unique_ptr cert_generator, - const rtc::scoped_refptr& certificate, + rtc::scoped_refptr certificate, std::function&)> on_certificate_ready, const FieldTrialsView& field_trials); @@ -98,9 +98,6 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler { CERTIFICATE_FAILED, }; - // DTLS certificate request callback class. - class WebRtcCertificateGeneratorCallback; - struct CreateSessionDescriptionRequest { enum Type { kOffer, @@ -132,8 +129,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler { std::unique_ptr description); void OnCertificateRequestFailed(); - void SetCertificate( - const rtc::scoped_refptr& certificate); + void SetCertificate(rtc::scoped_refptr certificate); std::queue create_session_description_requests_; diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 27a790ed80..078c9ea151 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -1128,6 +1128,7 @@ rtc_library("rtc_base") { absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", "//third_party/abseil-cpp/absl/container:flat_hash_map", + "//third_party/abseil-cpp/absl/functional:any_invocable", "//third_party/abseil-cpp/absl/memory", "//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/types:optional", diff --git a/rtc_base/rtc_certificate_generator.cc b/rtc_base/rtc_certificate_generator.cc index d2856f7e3c..739890e2b8 100644 --- a/rtc_base/rtc_certificate_generator.cc +++ b/rtc_base/rtc_certificate_generator.cc @@ -71,21 +71,18 @@ RTCCertificateGenerator::RTCCertificateGenerator(Thread* signaling_thread, void RTCCertificateGenerator::GenerateCertificateAsync( const KeyParams& key_params, const absl::optional& expires_ms, - const scoped_refptr& callback) { + RTCCertificateGenerator::Callback callback) { RTC_DCHECK(signaling_thread_->IsCurrent()); RTC_DCHECK(callback); - // Create a new `RTCCertificateGenerationTask` for this generation request. It - // is reference counted and referenced by the message data, ensuring it lives - // until the task has completed (independent of `RTCCertificateGenerator`). worker_thread_->PostTask([key_params, expires_ms, signaling_thread = signaling_thread_, - cb = callback]() { + cb = std::move(callback)]() mutable { scoped_refptr certificate = RTCCertificateGenerator::GenerateCertificate(key_params, expires_ms); signaling_thread->PostTask( - [cert = std::move(certificate), cb = std::move(cb)]() { - cert ? cb->OnSuccess(cert) : cb->OnFailure(); + [cert = std::move(certificate), cb = std::move(cb)]() mutable { + std::move(cb)(std::move(cert)); }); }); } diff --git a/rtc_base/rtc_certificate_generator.h b/rtc_base/rtc_certificate_generator.h index 065b8b5002..a881f1a369 100644 --- a/rtc_base/rtc_certificate_generator.h +++ b/rtc_base/rtc_certificate_generator.h @@ -13,9 +13,9 @@ #include +#include "absl/functional/any_invocable.h" #include "absl/types/optional.h" #include "api/scoped_refptr.h" -#include "rtc_base/ref_count.h" #include "rtc_base/rtc_certificate.h" #include "rtc_base/ssl_identity.h" #include "rtc_base/system/rtc_export.h" @@ -23,21 +23,15 @@ namespace rtc { -// See `RTCCertificateGeneratorInterface::GenerateCertificateAsync`. -class RTCCertificateGeneratorCallback : public RefCountInterface { - public: - virtual void OnSuccess(const scoped_refptr& certificate) = 0; - virtual void OnFailure() = 0; - - protected: - ~RTCCertificateGeneratorCallback() override {} -}; - // Generates `RTCCertificate`s. // See `RTCCertificateGenerator` for the WebRTC repo's implementation. class RTCCertificateGeneratorInterface { public: - virtual ~RTCCertificateGeneratorInterface() {} + // Functor that will be called when certificate is generated asynchroniosly. + // Called with nullptr as the parameter on failure. + using Callback = absl::AnyInvocable) &&>; + + virtual ~RTCCertificateGeneratorInterface() = default; // Generates a certificate asynchronously on the worker thread. // Must be called on the signaling thread. The `callback` is invoked with the @@ -47,7 +41,7 @@ class RTCCertificateGeneratorInterface { virtual void GenerateCertificateAsync( const KeyParams& key_params, const absl::optional& expires_ms, - const scoped_refptr& callback) = 0; + Callback callback) = 0; }; // Standard implementation of `RTCCertificateGeneratorInterface`. @@ -74,10 +68,9 @@ class RTC_EXPORT RTCCertificateGenerator // that many milliseconds from now. `expires_ms` is limited to a year, a // larger value than that is clamped down to a year. If `expires_ms` is not // specified, a default expiration time is used. - void GenerateCertificateAsync( - const KeyParams& key_params, - const absl::optional& expires_ms, - const scoped_refptr& callback) override; + void GenerateCertificateAsync(const KeyParams& key_params, + const absl::optional& expires_ms, + Callback callback) override; private: Thread* const signaling_thread_; diff --git a/rtc_base/rtc_certificate_generator_unittest.cc b/rtc_base/rtc_certificate_generator_unittest.cc index 3d9df5875b..fb7ec913e5 100644 --- a/rtc_base/rtc_certificate_generator_unittest.cc +++ b/rtc_base/rtc_certificate_generator_unittest.cc @@ -21,7 +21,7 @@ namespace rtc { -class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback { +class RTCCertificateGeneratorFixture { public: RTCCertificateGeneratorFixture() : signaling_thread_(Thread::Current()), @@ -32,21 +32,16 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback { generator_.reset( new RTCCertificateGenerator(signaling_thread_, worker_thread_.get())); } - ~RTCCertificateGeneratorFixture() override {} RTCCertificateGenerator* generator() const { return generator_.get(); } RTCCertificate* certificate() const { return certificate_.get(); } - void OnSuccess(const scoped_refptr& certificate) override { - RTC_CHECK(signaling_thread_->IsCurrent()); - RTC_CHECK(certificate); - certificate_ = certificate; - generate_async_completed_ = true; - } - void OnFailure() override { - RTC_CHECK(signaling_thread_->IsCurrent()); - certificate_ = nullptr; - generate_async_completed_ = true; + RTCCertificateGeneratorInterface::Callback OnGenerated() { + return [this](scoped_refptr certificate) mutable { + RTC_CHECK(signaling_thread_->IsCurrent()); + certificate_ = std::move(certificate); + generate_async_completed_ = true; + }; } bool GenerateAsyncCompleted() { @@ -69,14 +64,11 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback { class RTCCertificateGeneratorTest : public ::testing::Test { public: - RTCCertificateGeneratorTest() - : fixture_(make_ref_counted()) {} - protected: static constexpr int kGenerationTimeoutMs = 10000; rtc::AutoThread main_thread_; - scoped_refptr fixture_; + RTCCertificateGeneratorFixture fixture_; }; TEST_F(RTCCertificateGeneratorTest, GenerateECDSA) { @@ -90,16 +82,16 @@ TEST_F(RTCCertificateGeneratorTest, GenerateRSA) { } TEST_F(RTCCertificateGeneratorTest, GenerateAsyncECDSA) { - EXPECT_FALSE(fixture_->certificate()); - fixture_->generator()->GenerateCertificateAsync(KeyParams::ECDSA(), - absl::nullopt, fixture_); + EXPECT_FALSE(fixture_.certificate()); + fixture_.generator()->GenerateCertificateAsync( + KeyParams::ECDSA(), absl::nullopt, fixture_.OnGenerated()); // Until generation has completed, the certificate is null. Since this is an // async call, generation must not have completed until we process messages // posted to this thread (which is done by `EXPECT_TRUE_WAIT`). - EXPECT_FALSE(fixture_->GenerateAsyncCompleted()); - EXPECT_FALSE(fixture_->certificate()); - EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs); - EXPECT_TRUE(fixture_->certificate()); + EXPECT_FALSE(fixture_.GenerateAsyncCompleted()); + EXPECT_FALSE(fixture_.certificate()); + EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs); + EXPECT_TRUE(fixture_.certificate()); } TEST_F(RTCCertificateGeneratorTest, GenerateWithExpires) { @@ -136,10 +128,10 @@ TEST_F(RTCCertificateGeneratorTest, GenerateWithInvalidParamsShouldFail) { EXPECT_FALSE(RTCCertificateGenerator::GenerateCertificate(invalid_params, absl::nullopt)); - fixture_->generator()->GenerateCertificateAsync(invalid_params, absl::nullopt, - fixture_); - EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs); - EXPECT_FALSE(fixture_->certificate()); + fixture_.generator()->GenerateCertificateAsync(invalid_params, absl::nullopt, + fixture_.OnGenerated()); + EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs); + EXPECT_FALSE(fixture_.certificate()); } } // namespace rtc