Replace RTCCertificateGeneratorCallback interface with an AnyInvocable

follow up of the https://webrtc-review.googlesource.com/c/src/+/272402

Bug: None
Change-Id: Ie47aff9fccdb4037c1f560801c780dd549b373ae
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/272553
Commit-Queue: Danil Chapovalov <danilchap@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37870}
This commit is contained in:
Danil Chapovalov 2022-08-22 16:39:34 +02:00 committed by WebRTC LUCI CQ
parent 2020767ddd
commit b7da81621c
9 changed files with 63 additions and 103 deletions

View file

@ -12,6 +12,7 @@
#define P2P_BASE_TRANSPORT_DESCRIPTION_FACTORY_H_ #define P2P_BASE_TRANSPORT_DESCRIPTION_FACTORY_H_
#include <memory> #include <memory>
#include <utility>
#include "api/field_trials_view.h" #include "api/field_trials_view.h"
#include "p2p/base/ice_credentials_iterator.h" #include "p2p/base/ice_credentials_iterator.h"
@ -51,9 +52,8 @@ class TransportDescriptionFactory {
// Specifies the transport security policy to use. // Specifies the transport security policy to use.
void set_secure(SecurePolicy s) { secure_ = s; } void set_secure(SecurePolicy s) { secure_ = s; }
// Specifies the certificate to use (only used when secure != SEC_DISABLED). // Specifies the certificate to use (only used when secure != SEC_DISABLED).
void set_certificate( void set_certificate(rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { certificate_ = std::move(certificate);
certificate_ = certificate;
} }
// Creates a transport description suitable for use in an offer. // Creates a transport description suitable for use in an offer.

View file

@ -1221,7 +1221,7 @@ void SdpOfferAnswerHandler::Initialize(
webrtc_session_desc_factory_ = webrtc_session_desc_factory_ =
std::make_unique<WebRtcSessionDescriptionFactory>( std::make_unique<WebRtcSessionDescriptionFactory>(
context, this, pc_->session_id(), pc_->dtls_enabled(), context, this, pc_->session_id(), pc_->dtls_enabled(),
std::move(dependencies.cert_generator), certificate, std::move(dependencies.cert_generator), std::move(certificate),
[this](const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { [this](const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
RTC_DCHECK_RUN_ON(signaling_thread()); RTC_DCHECK_RUN_ON(signaling_thread());
transport_controller_s()->SetLocalCertificate(certificate); transport_controller_s()->SetLocalCertificate(certificate);

View file

@ -135,11 +135,9 @@ class FakeRTCCertificateGenerator
int generated_certificates() { return generated_certificates_; } int generated_certificates() { return generated_certificates_; }
int generated_failures() { return generated_failures_; } int generated_failures() { return generated_failures_; }
void GenerateCertificateAsync( void GenerateCertificateAsync(const rtc::KeyParams& key_params,
const rtc::KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms, const absl::optional<uint64_t>& expires_ms,
const rtc::scoped_refptr<rtc::RTCCertificateGeneratorCallback>& callback) Callback callback) override {
override {
// The certificates are created from constant PEM strings and use its coded // The certificates are created from constant PEM strings and use its coded
// expiration time, we do not support modifying it. // expiration time, we do not support modifying it.
RTC_DCHECK(!expires_ms); RTC_DCHECK(!expires_ms);
@ -154,7 +152,7 @@ class FakeRTCCertificateGenerator
} }
rtc::KeyType key_type = key_params.type(); rtc::KeyType key_type = key_params.type();
webrtc::TaskQueueBase::Current()->PostTask( webrtc::TaskQueueBase::Current()->PostTask(
[this, key_type, callback]() mutable { [this, key_type, callback = std::move(callback)]() mutable {
GenerateCertificate(key_type, std::move(callback)); GenerateCertificate(key_type, std::move(callback));
}); });
} }
@ -190,9 +188,7 @@ class FakeRTCCertificateGenerator
return get_pem(key_type).certificate(); return get_pem(key_type).certificate();
} }
void GenerateCertificate( void GenerateCertificate(rtc::KeyType key_type, Callback callback) {
rtc::KeyType key_type,
rtc::scoped_refptr<rtc::RTCCertificateGeneratorCallback> callback) {
// If the certificate generation should be stalled, re-post this same // If the certificate generation should be stalled, re-post this same
// message to the queue with a small delay so as to wait in a loop until // message to the queue with a small delay so as to wait in a loop until
// set_should_wait(false) is called. // set_should_wait(false) is called.
@ -206,13 +202,13 @@ class FakeRTCCertificateGenerator
} }
if (should_fail_) { if (should_fail_) {
++generated_failures_; ++generated_failures_;
callback->OnFailure(); std::move(callback)(nullptr);
} else { } else {
rtc::scoped_refptr<rtc::RTCCertificate> certificate = rtc::scoped_refptr<rtc::RTCCertificate> certificate =
rtc::RTCCertificate::FromPEM(get_pem(key_type)); rtc::RTCCertificate::FromPEM(get_pem(key_type));
RTC_DCHECK(certificate); RTC_DCHECK(certificate);
++generated_certificates_; ++generated_certificates_;
callback->OnSuccess(certificate); std::move(callback)(std::move(certificate));
} }
} }

View file

@ -85,29 +85,6 @@ struct CreateSessionDescriptionMsg : public rtc::MessageData {
}; };
} // namespace } // namespace
class WebRtcSessionDescriptionFactory::WebRtcCertificateGeneratorCallback
: public rtc::RTCCertificateGeneratorCallback {
public:
explicit WebRtcCertificateGeneratorCallback(
rtc::WeakPtr<WebRtcSessionDescriptionFactory> ptr)
: weak_ptr_(std::move(ptr)) {}
// `rtc::RTCCertificateGeneratorCallback` overrides.
void OnSuccess(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
if (weak_ptr_) {
weak_ptr_->SetCertificate(certificate);
}
}
void OnFailure() override {
if (weak_ptr_) {
weak_ptr_->OnCertificateRequestFailed();
}
}
private:
rtc::WeakPtr<WebRtcSessionDescriptionFactory> weak_ptr_;
};
// static // static
void WebRtcSessionDescriptionFactory::CopyCandidatesFromSessionDescription( void WebRtcSessionDescriptionFactory::CopyCandidatesFromSessionDescription(
const SessionDescriptionInterface* source_desc, const SessionDescriptionInterface* source_desc,
@ -145,7 +122,7 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory(
const std::string& session_id, const std::string& session_id,
bool dtls_enabled, bool dtls_enabled,
std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator, std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator,
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate, rtc::scoped_refptr<rtc::RTCCertificate> certificate,
std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)> std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)>
on_certificate_ready, on_certificate_ready,
const FieldTrialsView& field_trials) const FieldTrialsView& field_trials)
@ -186,18 +163,26 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory(
// Generate certificate. // Generate certificate.
certificate_request_state_ = CERTIFICATE_WAITING; certificate_request_state_ = CERTIFICATE_WAITING;
auto callback = rtc::make_ref_counted<WebRtcCertificateGeneratorCallback>( auto callback = [weak_ptr = weak_factory_.GetWeakPtr()](
weak_factory_.GetWeakPtr()); rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
if (!weak_ptr) {
return;
}
if (certificate) {
weak_ptr->SetCertificate(std::move(certificate));
} else {
weak_ptr->OnCertificateRequestFailed();
}
};
rtc::KeyParams key_params = rtc::KeyParams(); rtc::KeyParams key_params = rtc::KeyParams();
RTC_LOG(LS_VERBOSE) RTC_LOG(LS_VERBOSE)
<< "DTLS-SRTP enabled; sending DTLS identity request (key type: " << "DTLS-SRTP enabled; sending DTLS identity request (key type: "
<< key_params.type() << ")."; << key_params.type() << ").";
// Request certificate. This happens asynchronously, so that the caller gets // Request certificate. This happens asynchronously on a different thread.
// a chance to connect to `SignalCertificateReady`.
cert_generator_->GenerateCertificateAsync(key_params, absl::nullopt, cert_generator_->GenerateCertificateAsync(key_params, absl::nullopt,
callback); std::move(callback));
} }
} }
@ -477,7 +462,7 @@ void WebRtcSessionDescriptionFactory::OnCertificateRequestFailed() {
} }
void WebRtcSessionDescriptionFactory::SetCertificate( void WebRtcSessionDescriptionFactory::SetCertificate(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
RTC_DCHECK(certificate); RTC_DCHECK(certificate);
RTC_LOG(LS_VERBOSE) << "Setting new certificate."; RTC_LOG(LS_VERBOSE) << "Setting new certificate.";
@ -485,7 +470,7 @@ void WebRtcSessionDescriptionFactory::SetCertificate(
on_certificate_ready_(certificate); on_certificate_ready_(certificate);
transport_desc_factory_.set_certificate(certificate); transport_desc_factory_.set_certificate(std::move(certificate));
transport_desc_factory_.set_secure(cricket::SEC_ENABLED); transport_desc_factory_.set_secure(cricket::SEC_ENABLED);
while (!create_session_description_requests_.empty()) { while (!create_session_description_requests_.empty()) {

View file

@ -51,7 +51,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
const std::string& session_id, const std::string& session_id,
bool dtls_enabled, bool dtls_enabled,
std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator, std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator,
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate, rtc::scoped_refptr<rtc::RTCCertificate> certificate,
std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)> std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)>
on_certificate_ready, on_certificate_ready,
const FieldTrialsView& field_trials); const FieldTrialsView& field_trials);
@ -98,9 +98,6 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
CERTIFICATE_FAILED, CERTIFICATE_FAILED,
}; };
// DTLS certificate request callback class.
class WebRtcCertificateGeneratorCallback;
struct CreateSessionDescriptionRequest { struct CreateSessionDescriptionRequest {
enum Type { enum Type {
kOffer, kOffer,
@ -132,8 +129,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
std::unique_ptr<SessionDescriptionInterface> description); std::unique_ptr<SessionDescriptionInterface> description);
void OnCertificateRequestFailed(); void OnCertificateRequestFailed();
void SetCertificate( void SetCertificate(rtc::scoped_refptr<rtc::RTCCertificate> certificate);
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate);
std::queue<CreateSessionDescriptionRequest> std::queue<CreateSessionDescriptionRequest>
create_session_description_requests_; create_session_description_requests_;

View file

@ -1128,6 +1128,7 @@ rtc_library("rtc_base") {
absl_deps = [ absl_deps = [
"//third_party/abseil-cpp/absl/algorithm:container", "//third_party/abseil-cpp/absl/algorithm:container",
"//third_party/abseil-cpp/absl/container:flat_hash_map", "//third_party/abseil-cpp/absl/container:flat_hash_map",
"//third_party/abseil-cpp/absl/functional:any_invocable",
"//third_party/abseil-cpp/absl/memory", "//third_party/abseil-cpp/absl/memory",
"//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/strings",
"//third_party/abseil-cpp/absl/types:optional", "//third_party/abseil-cpp/absl/types:optional",

View file

@ -71,21 +71,18 @@ RTCCertificateGenerator::RTCCertificateGenerator(Thread* signaling_thread,
void RTCCertificateGenerator::GenerateCertificateAsync( void RTCCertificateGenerator::GenerateCertificateAsync(
const KeyParams& key_params, const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms, const absl::optional<uint64_t>& expires_ms,
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) { RTCCertificateGenerator::Callback callback) {
RTC_DCHECK(signaling_thread_->IsCurrent()); RTC_DCHECK(signaling_thread_->IsCurrent());
RTC_DCHECK(callback); RTC_DCHECK(callback);
// Create a new `RTCCertificateGenerationTask` for this generation request. It
// is reference counted and referenced by the message data, ensuring it lives
// until the task has completed (independent of `RTCCertificateGenerator`).
worker_thread_->PostTask([key_params, expires_ms, worker_thread_->PostTask([key_params, expires_ms,
signaling_thread = signaling_thread_, signaling_thread = signaling_thread_,
cb = callback]() { cb = std::move(callback)]() mutable {
scoped_refptr<RTCCertificate> certificate = scoped_refptr<RTCCertificate> certificate =
RTCCertificateGenerator::GenerateCertificate(key_params, expires_ms); RTCCertificateGenerator::GenerateCertificate(key_params, expires_ms);
signaling_thread->PostTask( signaling_thread->PostTask(
[cert = std::move(certificate), cb = std::move(cb)]() { [cert = std::move(certificate), cb = std::move(cb)]() mutable {
cert ? cb->OnSuccess(cert) : cb->OnFailure(); std::move(cb)(std::move(cert));
}); });
}); });
} }

View file

@ -13,9 +13,9 @@
#include <stdint.h> #include <stdint.h>
#include "absl/functional/any_invocable.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "api/scoped_refptr.h" #include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/rtc_certificate.h" #include "rtc_base/rtc_certificate.h"
#include "rtc_base/ssl_identity.h" #include "rtc_base/ssl_identity.h"
#include "rtc_base/system/rtc_export.h" #include "rtc_base/system/rtc_export.h"
@ -23,21 +23,15 @@
namespace rtc { namespace rtc {
// See `RTCCertificateGeneratorInterface::GenerateCertificateAsync`.
class RTCCertificateGeneratorCallback : public RefCountInterface {
public:
virtual void OnSuccess(const scoped_refptr<RTCCertificate>& certificate) = 0;
virtual void OnFailure() = 0;
protected:
~RTCCertificateGeneratorCallback() override {}
};
// Generates `RTCCertificate`s. // Generates `RTCCertificate`s.
// See `RTCCertificateGenerator` for the WebRTC repo's implementation. // See `RTCCertificateGenerator` for the WebRTC repo's implementation.
class RTCCertificateGeneratorInterface { class RTCCertificateGeneratorInterface {
public: public:
virtual ~RTCCertificateGeneratorInterface() {} // Functor that will be called when certificate is generated asynchroniosly.
// Called with nullptr as the parameter on failure.
using Callback = absl::AnyInvocable<void(scoped_refptr<RTCCertificate>) &&>;
virtual ~RTCCertificateGeneratorInterface() = default;
// Generates a certificate asynchronously on the worker thread. // Generates a certificate asynchronously on the worker thread.
// Must be called on the signaling thread. The `callback` is invoked with the // Must be called on the signaling thread. The `callback` is invoked with the
@ -47,7 +41,7 @@ class RTCCertificateGeneratorInterface {
virtual void GenerateCertificateAsync( virtual void GenerateCertificateAsync(
const KeyParams& key_params, const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms, const absl::optional<uint64_t>& expires_ms,
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) = 0; Callback callback) = 0;
}; };
// Standard implementation of `RTCCertificateGeneratorInterface`. // Standard implementation of `RTCCertificateGeneratorInterface`.
@ -74,10 +68,9 @@ class RTC_EXPORT RTCCertificateGenerator
// that many milliseconds from now. `expires_ms` is limited to a year, a // that many milliseconds from now. `expires_ms` is limited to a year, a
// larger value than that is clamped down to a year. If `expires_ms` is not // larger value than that is clamped down to a year. If `expires_ms` is not
// specified, a default expiration time is used. // specified, a default expiration time is used.
void GenerateCertificateAsync( void GenerateCertificateAsync(const KeyParams& key_params,
const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms, const absl::optional<uint64_t>& expires_ms,
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) override; Callback callback) override;
private: private:
Thread* const signaling_thread_; Thread* const signaling_thread_;

View file

@ -21,7 +21,7 @@
namespace rtc { namespace rtc {
class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback { class RTCCertificateGeneratorFixture {
public: public:
RTCCertificateGeneratorFixture() RTCCertificateGeneratorFixture()
: signaling_thread_(Thread::Current()), : signaling_thread_(Thread::Current()),
@ -32,21 +32,16 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
generator_.reset( generator_.reset(
new RTCCertificateGenerator(signaling_thread_, worker_thread_.get())); new RTCCertificateGenerator(signaling_thread_, worker_thread_.get()));
} }
~RTCCertificateGeneratorFixture() override {}
RTCCertificateGenerator* generator() const { return generator_.get(); } RTCCertificateGenerator* generator() const { return generator_.get(); }
RTCCertificate* certificate() const { return certificate_.get(); } RTCCertificate* certificate() const { return certificate_.get(); }
void OnSuccess(const scoped_refptr<RTCCertificate>& certificate) override { RTCCertificateGeneratorInterface::Callback OnGenerated() {
return [this](scoped_refptr<RTCCertificate> certificate) mutable {
RTC_CHECK(signaling_thread_->IsCurrent()); RTC_CHECK(signaling_thread_->IsCurrent());
RTC_CHECK(certificate); certificate_ = std::move(certificate);
certificate_ = certificate;
generate_async_completed_ = true;
}
void OnFailure() override {
RTC_CHECK(signaling_thread_->IsCurrent());
certificate_ = nullptr;
generate_async_completed_ = true; generate_async_completed_ = true;
};
} }
bool GenerateAsyncCompleted() { bool GenerateAsyncCompleted() {
@ -69,14 +64,11 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
class RTCCertificateGeneratorTest : public ::testing::Test { class RTCCertificateGeneratorTest : public ::testing::Test {
public: public:
RTCCertificateGeneratorTest()
: fixture_(make_ref_counted<RTCCertificateGeneratorFixture>()) {}
protected: protected:
static constexpr int kGenerationTimeoutMs = 10000; static constexpr int kGenerationTimeoutMs = 10000;
rtc::AutoThread main_thread_; rtc::AutoThread main_thread_;
scoped_refptr<RTCCertificateGeneratorFixture> fixture_; RTCCertificateGeneratorFixture fixture_;
}; };
TEST_F(RTCCertificateGeneratorTest, GenerateECDSA) { TEST_F(RTCCertificateGeneratorTest, GenerateECDSA) {
@ -90,16 +82,16 @@ TEST_F(RTCCertificateGeneratorTest, GenerateRSA) {
} }
TEST_F(RTCCertificateGeneratorTest, GenerateAsyncECDSA) { TEST_F(RTCCertificateGeneratorTest, GenerateAsyncECDSA) {
EXPECT_FALSE(fixture_->certificate()); EXPECT_FALSE(fixture_.certificate());
fixture_->generator()->GenerateCertificateAsync(KeyParams::ECDSA(), fixture_.generator()->GenerateCertificateAsync(
absl::nullopt, fixture_); KeyParams::ECDSA(), absl::nullopt, fixture_.OnGenerated());
// Until generation has completed, the certificate is null. Since this is an // Until generation has completed, the certificate is null. Since this is an
// async call, generation must not have completed until we process messages // async call, generation must not have completed until we process messages
// posted to this thread (which is done by `EXPECT_TRUE_WAIT`). // posted to this thread (which is done by `EXPECT_TRUE_WAIT`).
EXPECT_FALSE(fixture_->GenerateAsyncCompleted()); EXPECT_FALSE(fixture_.GenerateAsyncCompleted());
EXPECT_FALSE(fixture_->certificate()); EXPECT_FALSE(fixture_.certificate());
EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs); EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs);
EXPECT_TRUE(fixture_->certificate()); EXPECT_TRUE(fixture_.certificate());
} }
TEST_F(RTCCertificateGeneratorTest, GenerateWithExpires) { TEST_F(RTCCertificateGeneratorTest, GenerateWithExpires) {
@ -136,10 +128,10 @@ TEST_F(RTCCertificateGeneratorTest, GenerateWithInvalidParamsShouldFail) {
EXPECT_FALSE(RTCCertificateGenerator::GenerateCertificate(invalid_params, EXPECT_FALSE(RTCCertificateGenerator::GenerateCertificate(invalid_params,
absl::nullopt)); absl::nullopt));
fixture_->generator()->GenerateCertificateAsync(invalid_params, absl::nullopt, fixture_.generator()->GenerateCertificateAsync(invalid_params, absl::nullopt,
fixture_); fixture_.OnGenerated());
EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs); EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs);
EXPECT_FALSE(fixture_->certificate()); EXPECT_FALSE(fixture_.certificate());
} }
} // namespace rtc } // namespace rtc