Replace RTCCertificateGeneratorCallback interface with an AnyInvocable

follow up of the https://webrtc-review.googlesource.com/c/src/+/272402

Bug: None
Change-Id: Ie47aff9fccdb4037c1f560801c780dd549b373ae
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/272553
Commit-Queue: Danil Chapovalov <danilchap@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37870}
This commit is contained in:
Danil Chapovalov 2022-08-22 16:39:34 +02:00 committed by WebRTC LUCI CQ
parent 2020767ddd
commit b7da81621c
9 changed files with 63 additions and 103 deletions

View file

@ -12,6 +12,7 @@
#define P2P_BASE_TRANSPORT_DESCRIPTION_FACTORY_H_
#include <memory>
#include <utility>
#include "api/field_trials_view.h"
#include "p2p/base/ice_credentials_iterator.h"
@ -51,9 +52,8 @@ class TransportDescriptionFactory {
// Specifies the transport security policy to use.
void set_secure(SecurePolicy s) { secure_ = s; }
// Specifies the certificate to use (only used when secure != SEC_DISABLED).
void set_certificate(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
certificate_ = certificate;
void set_certificate(rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
certificate_ = std::move(certificate);
}
// Creates a transport description suitable for use in an offer.

View file

@ -1221,7 +1221,7 @@ void SdpOfferAnswerHandler::Initialize(
webrtc_session_desc_factory_ =
std::make_unique<WebRtcSessionDescriptionFactory>(
context, this, pc_->session_id(), pc_->dtls_enabled(),
std::move(dependencies.cert_generator), certificate,
std::move(dependencies.cert_generator), std::move(certificate),
[this](const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
RTC_DCHECK_RUN_ON(signaling_thread());
transport_controller_s()->SetLocalCertificate(certificate);

View file

@ -135,11 +135,9 @@ class FakeRTCCertificateGenerator
int generated_certificates() { return generated_certificates_; }
int generated_failures() { return generated_failures_; }
void GenerateCertificateAsync(
const rtc::KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms,
const rtc::scoped_refptr<rtc::RTCCertificateGeneratorCallback>& callback)
override {
void GenerateCertificateAsync(const rtc::KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms,
Callback callback) override {
// The certificates are created from constant PEM strings and use its coded
// expiration time, we do not support modifying it.
RTC_DCHECK(!expires_ms);
@ -154,7 +152,7 @@ class FakeRTCCertificateGenerator
}
rtc::KeyType key_type = key_params.type();
webrtc::TaskQueueBase::Current()->PostTask(
[this, key_type, callback]() mutable {
[this, key_type, callback = std::move(callback)]() mutable {
GenerateCertificate(key_type, std::move(callback));
});
}
@ -190,9 +188,7 @@ class FakeRTCCertificateGenerator
return get_pem(key_type).certificate();
}
void GenerateCertificate(
rtc::KeyType key_type,
rtc::scoped_refptr<rtc::RTCCertificateGeneratorCallback> callback) {
void GenerateCertificate(rtc::KeyType key_type, Callback callback) {
// If the certificate generation should be stalled, re-post this same
// message to the queue with a small delay so as to wait in a loop until
// set_should_wait(false) is called.
@ -206,13 +202,13 @@ class FakeRTCCertificateGenerator
}
if (should_fail_) {
++generated_failures_;
callback->OnFailure();
std::move(callback)(nullptr);
} else {
rtc::scoped_refptr<rtc::RTCCertificate> certificate =
rtc::RTCCertificate::FromPEM(get_pem(key_type));
RTC_DCHECK(certificate);
++generated_certificates_;
callback->OnSuccess(certificate);
std::move(callback)(std::move(certificate));
}
}

View file

@ -85,29 +85,6 @@ struct CreateSessionDescriptionMsg : public rtc::MessageData {
};
} // namespace
class WebRtcSessionDescriptionFactory::WebRtcCertificateGeneratorCallback
: public rtc::RTCCertificateGeneratorCallback {
public:
explicit WebRtcCertificateGeneratorCallback(
rtc::WeakPtr<WebRtcSessionDescriptionFactory> ptr)
: weak_ptr_(std::move(ptr)) {}
// `rtc::RTCCertificateGeneratorCallback` overrides.
void OnSuccess(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
if (weak_ptr_) {
weak_ptr_->SetCertificate(certificate);
}
}
void OnFailure() override {
if (weak_ptr_) {
weak_ptr_->OnCertificateRequestFailed();
}
}
private:
rtc::WeakPtr<WebRtcSessionDescriptionFactory> weak_ptr_;
};
// static
void WebRtcSessionDescriptionFactory::CopyCandidatesFromSessionDescription(
const SessionDescriptionInterface* source_desc,
@ -145,7 +122,7 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory(
const std::string& session_id,
bool dtls_enabled,
std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator,
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate,
rtc::scoped_refptr<rtc::RTCCertificate> certificate,
std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)>
on_certificate_ready,
const FieldTrialsView& field_trials)
@ -186,18 +163,26 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory(
// Generate certificate.
certificate_request_state_ = CERTIFICATE_WAITING;
auto callback = rtc::make_ref_counted<WebRtcCertificateGeneratorCallback>(
weak_factory_.GetWeakPtr());
auto callback = [weak_ptr = weak_factory_.GetWeakPtr()](
rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
if (!weak_ptr) {
return;
}
if (certificate) {
weak_ptr->SetCertificate(std::move(certificate));
} else {
weak_ptr->OnCertificateRequestFailed();
}
};
rtc::KeyParams key_params = rtc::KeyParams();
RTC_LOG(LS_VERBOSE)
<< "DTLS-SRTP enabled; sending DTLS identity request (key type: "
<< key_params.type() << ").";
// Request certificate. This happens asynchronously, so that the caller gets
// a chance to connect to `SignalCertificateReady`.
// Request certificate. This happens asynchronously on a different thread.
cert_generator_->GenerateCertificateAsync(key_params, absl::nullopt,
callback);
std::move(callback));
}
}
@ -477,7 +462,7 @@ void WebRtcSessionDescriptionFactory::OnCertificateRequestFailed() {
}
void WebRtcSessionDescriptionFactory::SetCertificate(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
RTC_DCHECK(certificate);
RTC_LOG(LS_VERBOSE) << "Setting new certificate.";
@ -485,7 +470,7 @@ void WebRtcSessionDescriptionFactory::SetCertificate(
on_certificate_ready_(certificate);
transport_desc_factory_.set_certificate(certificate);
transport_desc_factory_.set_certificate(std::move(certificate));
transport_desc_factory_.set_secure(cricket::SEC_ENABLED);
while (!create_session_description_requests_.empty()) {

View file

@ -51,7 +51,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
const std::string& session_id,
bool dtls_enabled,
std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator,
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate,
rtc::scoped_refptr<rtc::RTCCertificate> certificate,
std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)>
on_certificate_ready,
const FieldTrialsView& field_trials);
@ -98,9 +98,6 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
CERTIFICATE_FAILED,
};
// DTLS certificate request callback class.
class WebRtcCertificateGeneratorCallback;
struct CreateSessionDescriptionRequest {
enum Type {
kOffer,
@ -132,8 +129,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
std::unique_ptr<SessionDescriptionInterface> description);
void OnCertificateRequestFailed();
void SetCertificate(
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate);
void SetCertificate(rtc::scoped_refptr<rtc::RTCCertificate> certificate);
std::queue<CreateSessionDescriptionRequest>
create_session_description_requests_;

View file

@ -1128,6 +1128,7 @@ rtc_library("rtc_base") {
absl_deps = [
"//third_party/abseil-cpp/absl/algorithm:container",
"//third_party/abseil-cpp/absl/container:flat_hash_map",
"//third_party/abseil-cpp/absl/functional:any_invocable",
"//third_party/abseil-cpp/absl/memory",
"//third_party/abseil-cpp/absl/strings",
"//third_party/abseil-cpp/absl/types:optional",

View file

@ -71,21 +71,18 @@ RTCCertificateGenerator::RTCCertificateGenerator(Thread* signaling_thread,
void RTCCertificateGenerator::GenerateCertificateAsync(
const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms,
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) {
RTCCertificateGenerator::Callback callback) {
RTC_DCHECK(signaling_thread_->IsCurrent());
RTC_DCHECK(callback);
// Create a new `RTCCertificateGenerationTask` for this generation request. It
// is reference counted and referenced by the message data, ensuring it lives
// until the task has completed (independent of `RTCCertificateGenerator`).
worker_thread_->PostTask([key_params, expires_ms,
signaling_thread = signaling_thread_,
cb = callback]() {
cb = std::move(callback)]() mutable {
scoped_refptr<RTCCertificate> certificate =
RTCCertificateGenerator::GenerateCertificate(key_params, expires_ms);
signaling_thread->PostTask(
[cert = std::move(certificate), cb = std::move(cb)]() {
cert ? cb->OnSuccess(cert) : cb->OnFailure();
[cert = std::move(certificate), cb = std::move(cb)]() mutable {
std::move(cb)(std::move(cert));
});
});
}

View file

@ -13,9 +13,9 @@
#include <stdint.h>
#include "absl/functional/any_invocable.h"
#include "absl/types/optional.h"
#include "api/scoped_refptr.h"
#include "rtc_base/ref_count.h"
#include "rtc_base/rtc_certificate.h"
#include "rtc_base/ssl_identity.h"
#include "rtc_base/system/rtc_export.h"
@ -23,21 +23,15 @@
namespace rtc {
// See `RTCCertificateGeneratorInterface::GenerateCertificateAsync`.
class RTCCertificateGeneratorCallback : public RefCountInterface {
public:
virtual void OnSuccess(const scoped_refptr<RTCCertificate>& certificate) = 0;
virtual void OnFailure() = 0;
protected:
~RTCCertificateGeneratorCallback() override {}
};
// Generates `RTCCertificate`s.
// See `RTCCertificateGenerator` for the WebRTC repo's implementation.
class RTCCertificateGeneratorInterface {
public:
virtual ~RTCCertificateGeneratorInterface() {}
// Functor that will be called when certificate is generated asynchroniosly.
// Called with nullptr as the parameter on failure.
using Callback = absl::AnyInvocable<void(scoped_refptr<RTCCertificate>) &&>;
virtual ~RTCCertificateGeneratorInterface() = default;
// Generates a certificate asynchronously on the worker thread.
// Must be called on the signaling thread. The `callback` is invoked with the
@ -47,7 +41,7 @@ class RTCCertificateGeneratorInterface {
virtual void GenerateCertificateAsync(
const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms,
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) = 0;
Callback callback) = 0;
};
// Standard implementation of `RTCCertificateGeneratorInterface`.
@ -74,10 +68,9 @@ class RTC_EXPORT RTCCertificateGenerator
// that many milliseconds from now. `expires_ms` is limited to a year, a
// larger value than that is clamped down to a year. If `expires_ms` is not
// specified, a default expiration time is used.
void GenerateCertificateAsync(
const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms,
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) override;
void GenerateCertificateAsync(const KeyParams& key_params,
const absl::optional<uint64_t>& expires_ms,
Callback callback) override;
private:
Thread* const signaling_thread_;

View file

@ -21,7 +21,7 @@
namespace rtc {
class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
class RTCCertificateGeneratorFixture {
public:
RTCCertificateGeneratorFixture()
: signaling_thread_(Thread::Current()),
@ -32,21 +32,16 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
generator_.reset(
new RTCCertificateGenerator(signaling_thread_, worker_thread_.get()));
}
~RTCCertificateGeneratorFixture() override {}
RTCCertificateGenerator* generator() const { return generator_.get(); }
RTCCertificate* certificate() const { return certificate_.get(); }
void OnSuccess(const scoped_refptr<RTCCertificate>& certificate) override {
RTC_CHECK(signaling_thread_->IsCurrent());
RTC_CHECK(certificate);
certificate_ = certificate;
generate_async_completed_ = true;
}
void OnFailure() override {
RTC_CHECK(signaling_thread_->IsCurrent());
certificate_ = nullptr;
generate_async_completed_ = true;
RTCCertificateGeneratorInterface::Callback OnGenerated() {
return [this](scoped_refptr<RTCCertificate> certificate) mutable {
RTC_CHECK(signaling_thread_->IsCurrent());
certificate_ = std::move(certificate);
generate_async_completed_ = true;
};
}
bool GenerateAsyncCompleted() {
@ -69,14 +64,11 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
class RTCCertificateGeneratorTest : public ::testing::Test {
public:
RTCCertificateGeneratorTest()
: fixture_(make_ref_counted<RTCCertificateGeneratorFixture>()) {}
protected:
static constexpr int kGenerationTimeoutMs = 10000;
rtc::AutoThread main_thread_;
scoped_refptr<RTCCertificateGeneratorFixture> fixture_;
RTCCertificateGeneratorFixture fixture_;
};
TEST_F(RTCCertificateGeneratorTest, GenerateECDSA) {
@ -90,16 +82,16 @@ TEST_F(RTCCertificateGeneratorTest, GenerateRSA) {
}
TEST_F(RTCCertificateGeneratorTest, GenerateAsyncECDSA) {
EXPECT_FALSE(fixture_->certificate());
fixture_->generator()->GenerateCertificateAsync(KeyParams::ECDSA(),
absl::nullopt, fixture_);
EXPECT_FALSE(fixture_.certificate());
fixture_.generator()->GenerateCertificateAsync(
KeyParams::ECDSA(), absl::nullopt, fixture_.OnGenerated());
// Until generation has completed, the certificate is null. Since this is an
// async call, generation must not have completed until we process messages
// posted to this thread (which is done by `EXPECT_TRUE_WAIT`).
EXPECT_FALSE(fixture_->GenerateAsyncCompleted());
EXPECT_FALSE(fixture_->certificate());
EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs);
EXPECT_TRUE(fixture_->certificate());
EXPECT_FALSE(fixture_.GenerateAsyncCompleted());
EXPECT_FALSE(fixture_.certificate());
EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs);
EXPECT_TRUE(fixture_.certificate());
}
TEST_F(RTCCertificateGeneratorTest, GenerateWithExpires) {
@ -136,10 +128,10 @@ TEST_F(RTCCertificateGeneratorTest, GenerateWithInvalidParamsShouldFail) {
EXPECT_FALSE(RTCCertificateGenerator::GenerateCertificate(invalid_params,
absl::nullopt));
fixture_->generator()->GenerateCertificateAsync(invalid_params, absl::nullopt,
fixture_);
EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs);
EXPECT_FALSE(fixture_->certificate());
fixture_.generator()->GenerateCertificateAsync(invalid_params, absl::nullopt,
fixture_.OnGenerated());
EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs);
EXPECT_FALSE(fixture_.certificate());
}
} // namespace rtc