mirror of
https://github.com/mollyim/webrtc.git
synced 2025-05-12 21:30:45 +01:00
Replace RTCCertificateGeneratorCallback interface with an AnyInvocable
follow up of the https://webrtc-review.googlesource.com/c/src/+/272402 Bug: None Change-Id: Ie47aff9fccdb4037c1f560801c780dd549b373ae Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/272553 Commit-Queue: Danil Chapovalov <danilchap@webrtc.org> Reviewed-by: Harald Alvestrand <hta@webrtc.org> Cr-Commit-Position: refs/heads/main@{#37870}
This commit is contained in:
parent
2020767ddd
commit
b7da81621c
9 changed files with 63 additions and 103 deletions
|
@ -12,6 +12,7 @@
|
|||
#define P2P_BASE_TRANSPORT_DESCRIPTION_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "api/field_trials_view.h"
|
||||
#include "p2p/base/ice_credentials_iterator.h"
|
||||
|
@ -51,9 +52,8 @@ class TransportDescriptionFactory {
|
|||
// Specifies the transport security policy to use.
|
||||
void set_secure(SecurePolicy s) { secure_ = s; }
|
||||
// Specifies the certificate to use (only used when secure != SEC_DISABLED).
|
||||
void set_certificate(
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
|
||||
certificate_ = certificate;
|
||||
void set_certificate(rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
|
||||
certificate_ = std::move(certificate);
|
||||
}
|
||||
|
||||
// Creates a transport description suitable for use in an offer.
|
||||
|
|
|
@ -1221,7 +1221,7 @@ void SdpOfferAnswerHandler::Initialize(
|
|||
webrtc_session_desc_factory_ =
|
||||
std::make_unique<WebRtcSessionDescriptionFactory>(
|
||||
context, this, pc_->session_id(), pc_->dtls_enabled(),
|
||||
std::move(dependencies.cert_generator), certificate,
|
||||
std::move(dependencies.cert_generator), std::move(certificate),
|
||||
[this](const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
|
||||
RTC_DCHECK_RUN_ON(signaling_thread());
|
||||
transport_controller_s()->SetLocalCertificate(certificate);
|
||||
|
|
|
@ -135,11 +135,9 @@ class FakeRTCCertificateGenerator
|
|||
int generated_certificates() { return generated_certificates_; }
|
||||
int generated_failures() { return generated_failures_; }
|
||||
|
||||
void GenerateCertificateAsync(
|
||||
const rtc::KeyParams& key_params,
|
||||
const absl::optional<uint64_t>& expires_ms,
|
||||
const rtc::scoped_refptr<rtc::RTCCertificateGeneratorCallback>& callback)
|
||||
override {
|
||||
void GenerateCertificateAsync(const rtc::KeyParams& key_params,
|
||||
const absl::optional<uint64_t>& expires_ms,
|
||||
Callback callback) override {
|
||||
// The certificates are created from constant PEM strings and use its coded
|
||||
// expiration time, we do not support modifying it.
|
||||
RTC_DCHECK(!expires_ms);
|
||||
|
@ -154,7 +152,7 @@ class FakeRTCCertificateGenerator
|
|||
}
|
||||
rtc::KeyType key_type = key_params.type();
|
||||
webrtc::TaskQueueBase::Current()->PostTask(
|
||||
[this, key_type, callback]() mutable {
|
||||
[this, key_type, callback = std::move(callback)]() mutable {
|
||||
GenerateCertificate(key_type, std::move(callback));
|
||||
});
|
||||
}
|
||||
|
@ -190,9 +188,7 @@ class FakeRTCCertificateGenerator
|
|||
return get_pem(key_type).certificate();
|
||||
}
|
||||
|
||||
void GenerateCertificate(
|
||||
rtc::KeyType key_type,
|
||||
rtc::scoped_refptr<rtc::RTCCertificateGeneratorCallback> callback) {
|
||||
void GenerateCertificate(rtc::KeyType key_type, Callback callback) {
|
||||
// If the certificate generation should be stalled, re-post this same
|
||||
// message to the queue with a small delay so as to wait in a loop until
|
||||
// set_should_wait(false) is called.
|
||||
|
@ -206,13 +202,13 @@ class FakeRTCCertificateGenerator
|
|||
}
|
||||
if (should_fail_) {
|
||||
++generated_failures_;
|
||||
callback->OnFailure();
|
||||
std::move(callback)(nullptr);
|
||||
} else {
|
||||
rtc::scoped_refptr<rtc::RTCCertificate> certificate =
|
||||
rtc::RTCCertificate::FromPEM(get_pem(key_type));
|
||||
RTC_DCHECK(certificate);
|
||||
++generated_certificates_;
|
||||
callback->OnSuccess(certificate);
|
||||
std::move(callback)(std::move(certificate));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -85,29 +85,6 @@ struct CreateSessionDescriptionMsg : public rtc::MessageData {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
class WebRtcSessionDescriptionFactory::WebRtcCertificateGeneratorCallback
|
||||
: public rtc::RTCCertificateGeneratorCallback {
|
||||
public:
|
||||
explicit WebRtcCertificateGeneratorCallback(
|
||||
rtc::WeakPtr<WebRtcSessionDescriptionFactory> ptr)
|
||||
: weak_ptr_(std::move(ptr)) {}
|
||||
// `rtc::RTCCertificateGeneratorCallback` overrides.
|
||||
void OnSuccess(
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
|
||||
if (weak_ptr_) {
|
||||
weak_ptr_->SetCertificate(certificate);
|
||||
}
|
||||
}
|
||||
void OnFailure() override {
|
||||
if (weak_ptr_) {
|
||||
weak_ptr_->OnCertificateRequestFailed();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
rtc::WeakPtr<WebRtcSessionDescriptionFactory> weak_ptr_;
|
||||
};
|
||||
|
||||
// static
|
||||
void WebRtcSessionDescriptionFactory::CopyCandidatesFromSessionDescription(
|
||||
const SessionDescriptionInterface* source_desc,
|
||||
|
@ -145,7 +122,7 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory(
|
|||
const std::string& session_id,
|
||||
bool dtls_enabled,
|
||||
std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator,
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate,
|
||||
rtc::scoped_refptr<rtc::RTCCertificate> certificate,
|
||||
std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)>
|
||||
on_certificate_ready,
|
||||
const FieldTrialsView& field_trials)
|
||||
|
@ -186,18 +163,26 @@ WebRtcSessionDescriptionFactory::WebRtcSessionDescriptionFactory(
|
|||
// Generate certificate.
|
||||
certificate_request_state_ = CERTIFICATE_WAITING;
|
||||
|
||||
auto callback = rtc::make_ref_counted<WebRtcCertificateGeneratorCallback>(
|
||||
weak_factory_.GetWeakPtr());
|
||||
auto callback = [weak_ptr = weak_factory_.GetWeakPtr()](
|
||||
rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
|
||||
if (!weak_ptr) {
|
||||
return;
|
||||
}
|
||||
if (certificate) {
|
||||
weak_ptr->SetCertificate(std::move(certificate));
|
||||
} else {
|
||||
weak_ptr->OnCertificateRequestFailed();
|
||||
}
|
||||
};
|
||||
|
||||
rtc::KeyParams key_params = rtc::KeyParams();
|
||||
RTC_LOG(LS_VERBOSE)
|
||||
<< "DTLS-SRTP enabled; sending DTLS identity request (key type: "
|
||||
<< key_params.type() << ").";
|
||||
|
||||
// Request certificate. This happens asynchronously, so that the caller gets
|
||||
// a chance to connect to `SignalCertificateReady`.
|
||||
// Request certificate. This happens asynchronously on a different thread.
|
||||
cert_generator_->GenerateCertificateAsync(key_params, absl::nullopt,
|
||||
callback);
|
||||
std::move(callback));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -477,7 +462,7 @@ void WebRtcSessionDescriptionFactory::OnCertificateRequestFailed() {
|
|||
}
|
||||
|
||||
void WebRtcSessionDescriptionFactory::SetCertificate(
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
|
||||
rtc::scoped_refptr<rtc::RTCCertificate> certificate) {
|
||||
RTC_DCHECK(certificate);
|
||||
RTC_LOG(LS_VERBOSE) << "Setting new certificate.";
|
||||
|
||||
|
@ -485,7 +470,7 @@ void WebRtcSessionDescriptionFactory::SetCertificate(
|
|||
|
||||
on_certificate_ready_(certificate);
|
||||
|
||||
transport_desc_factory_.set_certificate(certificate);
|
||||
transport_desc_factory_.set_certificate(std::move(certificate));
|
||||
transport_desc_factory_.set_secure(cricket::SEC_ENABLED);
|
||||
|
||||
while (!create_session_description_requests_.empty()) {
|
||||
|
|
|
@ -51,7 +51,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
|
|||
const std::string& session_id,
|
||||
bool dtls_enabled,
|
||||
std::unique_ptr<rtc::RTCCertificateGeneratorInterface> cert_generator,
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate,
|
||||
rtc::scoped_refptr<rtc::RTCCertificate> certificate,
|
||||
std::function<void(const rtc::scoped_refptr<rtc::RTCCertificate>&)>
|
||||
on_certificate_ready,
|
||||
const FieldTrialsView& field_trials);
|
||||
|
@ -98,9 +98,6 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
|
|||
CERTIFICATE_FAILED,
|
||||
};
|
||||
|
||||
// DTLS certificate request callback class.
|
||||
class WebRtcCertificateGeneratorCallback;
|
||||
|
||||
struct CreateSessionDescriptionRequest {
|
||||
enum Type {
|
||||
kOffer,
|
||||
|
@ -132,8 +129,7 @@ class WebRtcSessionDescriptionFactory : public rtc::MessageHandler {
|
|||
std::unique_ptr<SessionDescriptionInterface> description);
|
||||
|
||||
void OnCertificateRequestFailed();
|
||||
void SetCertificate(
|
||||
const rtc::scoped_refptr<rtc::RTCCertificate>& certificate);
|
||||
void SetCertificate(rtc::scoped_refptr<rtc::RTCCertificate> certificate);
|
||||
|
||||
std::queue<CreateSessionDescriptionRequest>
|
||||
create_session_description_requests_;
|
||||
|
|
|
@ -1128,6 +1128,7 @@ rtc_library("rtc_base") {
|
|||
absl_deps = [
|
||||
"//third_party/abseil-cpp/absl/algorithm:container",
|
||||
"//third_party/abseil-cpp/absl/container:flat_hash_map",
|
||||
"//third_party/abseil-cpp/absl/functional:any_invocable",
|
||||
"//third_party/abseil-cpp/absl/memory",
|
||||
"//third_party/abseil-cpp/absl/strings",
|
||||
"//third_party/abseil-cpp/absl/types:optional",
|
||||
|
|
|
@ -71,21 +71,18 @@ RTCCertificateGenerator::RTCCertificateGenerator(Thread* signaling_thread,
|
|||
void RTCCertificateGenerator::GenerateCertificateAsync(
|
||||
const KeyParams& key_params,
|
||||
const absl::optional<uint64_t>& expires_ms,
|
||||
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) {
|
||||
RTCCertificateGenerator::Callback callback) {
|
||||
RTC_DCHECK(signaling_thread_->IsCurrent());
|
||||
RTC_DCHECK(callback);
|
||||
|
||||
// Create a new `RTCCertificateGenerationTask` for this generation request. It
|
||||
// is reference counted and referenced by the message data, ensuring it lives
|
||||
// until the task has completed (independent of `RTCCertificateGenerator`).
|
||||
worker_thread_->PostTask([key_params, expires_ms,
|
||||
signaling_thread = signaling_thread_,
|
||||
cb = callback]() {
|
||||
cb = std::move(callback)]() mutable {
|
||||
scoped_refptr<RTCCertificate> certificate =
|
||||
RTCCertificateGenerator::GenerateCertificate(key_params, expires_ms);
|
||||
signaling_thread->PostTask(
|
||||
[cert = std::move(certificate), cb = std::move(cb)]() {
|
||||
cert ? cb->OnSuccess(cert) : cb->OnFailure();
|
||||
[cert = std::move(certificate), cb = std::move(cb)]() mutable {
|
||||
std::move(cb)(std::move(cert));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
|
@ -13,9 +13,9 @@
|
|||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "absl/functional/any_invocable.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "api/scoped_refptr.h"
|
||||
#include "rtc_base/ref_count.h"
|
||||
#include "rtc_base/rtc_certificate.h"
|
||||
#include "rtc_base/ssl_identity.h"
|
||||
#include "rtc_base/system/rtc_export.h"
|
||||
|
@ -23,21 +23,15 @@
|
|||
|
||||
namespace rtc {
|
||||
|
||||
// See `RTCCertificateGeneratorInterface::GenerateCertificateAsync`.
|
||||
class RTCCertificateGeneratorCallback : public RefCountInterface {
|
||||
public:
|
||||
virtual void OnSuccess(const scoped_refptr<RTCCertificate>& certificate) = 0;
|
||||
virtual void OnFailure() = 0;
|
||||
|
||||
protected:
|
||||
~RTCCertificateGeneratorCallback() override {}
|
||||
};
|
||||
|
||||
// Generates `RTCCertificate`s.
|
||||
// See `RTCCertificateGenerator` for the WebRTC repo's implementation.
|
||||
class RTCCertificateGeneratorInterface {
|
||||
public:
|
||||
virtual ~RTCCertificateGeneratorInterface() {}
|
||||
// Functor that will be called when certificate is generated asynchroniosly.
|
||||
// Called with nullptr as the parameter on failure.
|
||||
using Callback = absl::AnyInvocable<void(scoped_refptr<RTCCertificate>) &&>;
|
||||
|
||||
virtual ~RTCCertificateGeneratorInterface() = default;
|
||||
|
||||
// Generates a certificate asynchronously on the worker thread.
|
||||
// Must be called on the signaling thread. The `callback` is invoked with the
|
||||
|
@ -47,7 +41,7 @@ class RTCCertificateGeneratorInterface {
|
|||
virtual void GenerateCertificateAsync(
|
||||
const KeyParams& key_params,
|
||||
const absl::optional<uint64_t>& expires_ms,
|
||||
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) = 0;
|
||||
Callback callback) = 0;
|
||||
};
|
||||
|
||||
// Standard implementation of `RTCCertificateGeneratorInterface`.
|
||||
|
@ -74,10 +68,9 @@ class RTC_EXPORT RTCCertificateGenerator
|
|||
// that many milliseconds from now. `expires_ms` is limited to a year, a
|
||||
// larger value than that is clamped down to a year. If `expires_ms` is not
|
||||
// specified, a default expiration time is used.
|
||||
void GenerateCertificateAsync(
|
||||
const KeyParams& key_params,
|
||||
const absl::optional<uint64_t>& expires_ms,
|
||||
const scoped_refptr<RTCCertificateGeneratorCallback>& callback) override;
|
||||
void GenerateCertificateAsync(const KeyParams& key_params,
|
||||
const absl::optional<uint64_t>& expires_ms,
|
||||
Callback callback) override;
|
||||
|
||||
private:
|
||||
Thread* const signaling_thread_;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
|
||||
namespace rtc {
|
||||
|
||||
class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
|
||||
class RTCCertificateGeneratorFixture {
|
||||
public:
|
||||
RTCCertificateGeneratorFixture()
|
||||
: signaling_thread_(Thread::Current()),
|
||||
|
@ -32,21 +32,16 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
|
|||
generator_.reset(
|
||||
new RTCCertificateGenerator(signaling_thread_, worker_thread_.get()));
|
||||
}
|
||||
~RTCCertificateGeneratorFixture() override {}
|
||||
|
||||
RTCCertificateGenerator* generator() const { return generator_.get(); }
|
||||
RTCCertificate* certificate() const { return certificate_.get(); }
|
||||
|
||||
void OnSuccess(const scoped_refptr<RTCCertificate>& certificate) override {
|
||||
RTC_CHECK(signaling_thread_->IsCurrent());
|
||||
RTC_CHECK(certificate);
|
||||
certificate_ = certificate;
|
||||
generate_async_completed_ = true;
|
||||
}
|
||||
void OnFailure() override {
|
||||
RTC_CHECK(signaling_thread_->IsCurrent());
|
||||
certificate_ = nullptr;
|
||||
generate_async_completed_ = true;
|
||||
RTCCertificateGeneratorInterface::Callback OnGenerated() {
|
||||
return [this](scoped_refptr<RTCCertificate> certificate) mutable {
|
||||
RTC_CHECK(signaling_thread_->IsCurrent());
|
||||
certificate_ = std::move(certificate);
|
||||
generate_async_completed_ = true;
|
||||
};
|
||||
}
|
||||
|
||||
bool GenerateAsyncCompleted() {
|
||||
|
@ -69,14 +64,11 @@ class RTCCertificateGeneratorFixture : public RTCCertificateGeneratorCallback {
|
|||
|
||||
class RTCCertificateGeneratorTest : public ::testing::Test {
|
||||
public:
|
||||
RTCCertificateGeneratorTest()
|
||||
: fixture_(make_ref_counted<RTCCertificateGeneratorFixture>()) {}
|
||||
|
||||
protected:
|
||||
static constexpr int kGenerationTimeoutMs = 10000;
|
||||
|
||||
rtc::AutoThread main_thread_;
|
||||
scoped_refptr<RTCCertificateGeneratorFixture> fixture_;
|
||||
RTCCertificateGeneratorFixture fixture_;
|
||||
};
|
||||
|
||||
TEST_F(RTCCertificateGeneratorTest, GenerateECDSA) {
|
||||
|
@ -90,16 +82,16 @@ TEST_F(RTCCertificateGeneratorTest, GenerateRSA) {
|
|||
}
|
||||
|
||||
TEST_F(RTCCertificateGeneratorTest, GenerateAsyncECDSA) {
|
||||
EXPECT_FALSE(fixture_->certificate());
|
||||
fixture_->generator()->GenerateCertificateAsync(KeyParams::ECDSA(),
|
||||
absl::nullopt, fixture_);
|
||||
EXPECT_FALSE(fixture_.certificate());
|
||||
fixture_.generator()->GenerateCertificateAsync(
|
||||
KeyParams::ECDSA(), absl::nullopt, fixture_.OnGenerated());
|
||||
// Until generation has completed, the certificate is null. Since this is an
|
||||
// async call, generation must not have completed until we process messages
|
||||
// posted to this thread (which is done by `EXPECT_TRUE_WAIT`).
|
||||
EXPECT_FALSE(fixture_->GenerateAsyncCompleted());
|
||||
EXPECT_FALSE(fixture_->certificate());
|
||||
EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs);
|
||||
EXPECT_TRUE(fixture_->certificate());
|
||||
EXPECT_FALSE(fixture_.GenerateAsyncCompleted());
|
||||
EXPECT_FALSE(fixture_.certificate());
|
||||
EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs);
|
||||
EXPECT_TRUE(fixture_.certificate());
|
||||
}
|
||||
|
||||
TEST_F(RTCCertificateGeneratorTest, GenerateWithExpires) {
|
||||
|
@ -136,10 +128,10 @@ TEST_F(RTCCertificateGeneratorTest, GenerateWithInvalidParamsShouldFail) {
|
|||
EXPECT_FALSE(RTCCertificateGenerator::GenerateCertificate(invalid_params,
|
||||
absl::nullopt));
|
||||
|
||||
fixture_->generator()->GenerateCertificateAsync(invalid_params, absl::nullopt,
|
||||
fixture_);
|
||||
EXPECT_TRUE_WAIT(fixture_->GenerateAsyncCompleted(), kGenerationTimeoutMs);
|
||||
EXPECT_FALSE(fixture_->certificate());
|
||||
fixture_.generator()->GenerateCertificateAsync(invalid_params, absl::nullopt,
|
||||
fixture_.OnGenerated());
|
||||
EXPECT_TRUE_WAIT(fixture_.GenerateAsyncCompleted(), kGenerationTimeoutMs);
|
||||
EXPECT_FALSE(fixture_.certificate());
|
||||
}
|
||||
|
||||
} // namespace rtc
|
||||
|
|
Loading…
Reference in a new issue