diff --git a/pc/jsep_transport.cc b/pc/jsep_transport.cc index faff4e8cf4..a6c9ebc10e 100644 --- a/pc/jsep_transport.cc +++ b/pc/jsep_transport.cc @@ -96,7 +96,8 @@ JsepTransport::JsepTransport( : nullptr), sctp_transport_(sctp_transport ? rtc::make_ref_counted( - std::move(sctp_transport)) + std::move(sctp_transport), + rtp_dtls_transport_) : nullptr), rtcp_mux_active_callback_(std::move(rtcp_mux_active_callback)) { TRACE_EVENT0("webrtc", "JsepTransport::JsepTransport"); @@ -118,10 +119,6 @@ JsepTransport::JsepTransport( RTC_DCHECK(!unencrypted_rtp_transport); RTC_DCHECK(!sdes_transport); } - - if (sctp_transport_) { - sctp_transport_->SetDtlsTransport(rtp_dtls_transport_); - } } JsepTransport::~JsepTransport() { diff --git a/pc/sctp_transport.cc b/pc/sctp_transport.cc index eb60930389..6c5e66fe6c 100644 --- a/pc/sctp_transport.cc +++ b/pc/sctp_transport.cc @@ -22,19 +22,27 @@ namespace webrtc { SctpTransport::SctpTransport( - std::unique_ptr internal) + std::unique_ptr internal, + rtc::scoped_refptr dtls_transport) : owner_thread_(rtc::Thread::Current()), - info_(SctpTransportState::kNew), - internal_sctp_transport_(std::move(internal)) { + info_(SctpTransportState::kConnecting, + dtls_transport, + /*max_message_size=*/absl::nullopt, + /*max_channels=*/absl::nullopt), + internal_sctp_transport_(std::move(internal)), + dtls_transport_(dtls_transport) { RTC_DCHECK(internal_sctp_transport_.get()); + RTC_DCHECK(dtls_transport_.get()); + + dtls_transport_->internal()->SubscribeDtlsTransportState( + [this](cricket::DtlsTransportInternal* transport, + DtlsTransportState state) { + OnDtlsStateChange(transport, state); + }); + + internal_sctp_transport_->SetDtlsTransport(dtls_transport->internal()); internal_sctp_transport_->SetOnConnectedCallback( [this]() { OnAssociationChangeCommunicationUp(); }); - - if (dtls_transport_) { - UpdateInformation(SctpTransportState::kConnecting); - } else { - UpdateInformation(SctpTransportState::kNew); - } } SctpTransport::~SctpTransport() { @@ -134,31 +142,6 @@ void SctpTransport::Clear() { UpdateInformation(SctpTransportState::kClosed); } -void SctpTransport::SetDtlsTransport( - rtc::scoped_refptr transport) { - RTC_DCHECK_RUN_ON(owner_thread_); - SctpTransportState next_state = info_.state(); - dtls_transport_ = transport; - if (internal_sctp_transport_) { - if (transport) { - internal_sctp_transport_->SetDtlsTransport(transport->internal()); - - transport->internal()->SubscribeDtlsTransportState( - [this](cricket::DtlsTransportInternal* transport, - DtlsTransportState state) { - OnDtlsStateChange(transport, state); - }); - if (info_.state() == SctpTransportState::kNew) { - next_state = SctpTransportState::kConnecting; - } - } else { - internal_sctp_transport_->SetDtlsTransport(nullptr); - } - } - - UpdateInformation(next_state); -} - void SctpTransport::Start(int local_port, int remote_port, int max_message_size) { diff --git a/pc/sctp_transport.h b/pc/sctp_transport.h index 60434d829f..5508843162 100644 --- a/pc/sctp_transport.h +++ b/pc/sctp_transport.h @@ -35,8 +35,8 @@ namespace webrtc { class SctpTransport : public SctpTransportInterface, public DataChannelTransportInterface { public: - explicit SctpTransport( - std::unique_ptr internal); + SctpTransport(std::unique_ptr internal, + rtc::scoped_refptr dtls_transport); // SctpTransportInterface rtc::scoped_refptr dtls_transport() const override; @@ -58,7 +58,6 @@ class SctpTransport : public SctpTransportInterface, // Internal functions void Clear(); - void SetDtlsTransport(rtc::scoped_refptr); // Initialize the cricket::SctpTransport. This can be called from // the signaling thread. void Start(int local_port, int remote_port, int max_message_size); diff --git a/pc/sctp_transport_unittest.cc b/pc/sctp_transport_unittest.cc index 2849889992..4eb83d375a 100644 --- a/pc/sctp_transport_unittest.cc +++ b/pc/sctp_transport_unittest.cc @@ -117,19 +117,16 @@ class SctpTransportTest : public ::testing::Test { SctpTransportObserverInterface* observer() { return &observer_; } void CreateTransport() { - auto cricket_sctp_transport = - absl::WrapUnique(new FakeCricketSctpTransport()); - transport_ = - rtc::make_ref_counted(std::move(cricket_sctp_transport)); - } - - void AddDtlsTransport() { std::unique_ptr cricket_transport = std::make_unique( "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP); dtls_transport_ = rtc::make_ref_counted(std::move(cricket_transport)); - transport_->SetDtlsTransport(dtls_transport_); + + auto cricket_sctp_transport = + absl::WrapUnique(new FakeCricketSctpTransport()); + transport_ = rtc::make_ref_counted( + std::move(cricket_sctp_transport), dtls_transport_); } void CompleteSctpHandshake() { @@ -152,13 +149,20 @@ class SctpTransportTest : public ::testing::Test { TEST(SctpTransportSimpleTest, CreateClearDelete) { rtc::AutoThread main_thread; + std::unique_ptr cricket_transport = + std::make_unique("audio", + cricket::ICE_CANDIDATE_COMPONENT_RTP); + rtc::scoped_refptr dtls_transport = + rtc::make_ref_counted(std::move(cricket_transport)); + std::unique_ptr fake_cricket_sctp_transport = absl::WrapUnique(new FakeCricketSctpTransport()); rtc::scoped_refptr sctp_transport = rtc::make_ref_counted( - std::move(fake_cricket_sctp_transport)); + std::move(fake_cricket_sctp_transport), dtls_transport); ASSERT_TRUE(sctp_transport->internal()); - ASSERT_EQ(SctpTransportState::kNew, sctp_transport->Information().state()); + ASSERT_EQ(SctpTransportState::kConnecting, + sctp_transport->Information().state()); sctp_transport->Clear(); ASSERT_FALSE(sctp_transport->internal()); ASSERT_EQ(SctpTransportState::kClosed, sctp_transport->Information().state()); @@ -167,18 +171,15 @@ TEST(SctpTransportSimpleTest, CreateClearDelete) { TEST_F(SctpTransportTest, EventsObservedWhenConnecting) { CreateTransport(); transport()->RegisterObserver(observer()); - AddDtlsTransport(); CompleteSctpHandshake(); ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(), kDefaultTimeout); - EXPECT_THAT(observer_.States(), ElementsAre(SctpTransportState::kConnecting, - SctpTransportState::kConnected)); + EXPECT_THAT(observer_.States(), ElementsAre(SctpTransportState::kConnected)); } TEST_F(SctpTransportTest, CloseWhenClearing) { CreateTransport(); transport()->RegisterObserver(observer()); - AddDtlsTransport(); CompleteSctpHandshake(); ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(), kDefaultTimeout); @@ -190,7 +191,6 @@ TEST_F(SctpTransportTest, CloseWhenClearing) { TEST_F(SctpTransportTest, MaxChannelsSignalled) { CreateTransport(); transport()->RegisterObserver(observer()); - AddDtlsTransport(); EXPECT_FALSE(transport()->Information().MaxChannels()); EXPECT_FALSE(observer_.LastReceivedInformation().MaxChannels()); CompleteSctpHandshake(); @@ -206,7 +206,6 @@ TEST_F(SctpTransportTest, MaxChannelsSignalled) { TEST_F(SctpTransportTest, CloseWhenTransportCloses) { CreateTransport(); transport()->RegisterObserver(observer()); - AddDtlsTransport(); CompleteSctpHandshake(); ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(), kDefaultTimeout);