sctp: Fix data channel closing sequence

When an SCTP stream is closing, a stream reset needs
to be sent from both ends.
The remote was not sending a stream reset and quickly
opening another stream with the same StreamID could
cause SCTP errors.

Bug: webrtc:13994
Change-Id: I3abc74ddc88b3fcf7e6495d76e7d77f52280b5d1
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/260922
Commit-Queue: Florent Castelli <orphis@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#36773}
This commit is contained in:
Florent Castelli 2022-05-03 00:24:15 +02:00 committed by WebRTC LUCI CQ
parent 95b1a3497c
commit e3b74f8e61
9 changed files with 223 additions and 10 deletions

View file

@ -423,6 +423,7 @@ if (rtc_build_dcsctp) {
"../rtc_base:socket", "../rtc_base:socket",
"../rtc_base:stringutils", "../rtc_base:stringutils",
"../rtc_base:threading", "../rtc_base:threading",
"../rtc_base/containers:flat_set",
"../rtc_base/task_utils:pending_task_safety_flag", "../rtc_base/task_utils:pending_task_safety_flag",
"../rtc_base/task_utils:to_queued_task", "../rtc_base/task_utils:to_queued_task",
"../rtc_base/third_party/sigslot:sigslot", "../rtc_base/third_party/sigslot:sigslot",
@ -692,6 +693,16 @@ if (rtc_include_tests) {
if (is_ios) { if (is_ios) {
deps += [ ":rtc_media_unittests_bundle_data" ] deps += [ ":rtc_media_unittests_bundle_data" ]
} }
if (rtc_build_dcsctp) {
sources += [ "sctp/dcsctp_transport_unittest.cc" ]
deps += [
":rtc_data_dcsctp_transport",
"../net/dcsctp/public:factory",
"../net/dcsctp/public:mocks",
"../net/dcsctp/public:socket",
]
}
} }
} }
} }

View file

@ -116,10 +116,21 @@ bool IsEmptyPPID(dcsctp::PPID ppid) {
DcSctpTransport::DcSctpTransport(rtc::Thread* network_thread, DcSctpTransport::DcSctpTransport(rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport, rtc::PacketTransportInternal* transport,
Clock* clock) Clock* clock)
: DcSctpTransport(network_thread,
transport,
clock,
std::make_unique<dcsctp::DcSctpSocketFactory>()) {}
DcSctpTransport::DcSctpTransport(
rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport,
Clock* clock,
std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory)
: network_thread_(network_thread), : network_thread_(network_thread),
transport_(transport), transport_(transport),
clock_(clock), clock_(clock),
random_(clock_->TimeInMicroseconds()), random_(clock_->TimeInMicroseconds()),
socket_factory_(std::move(socket_factory)),
task_queue_timeout_factory_( task_queue_timeout_factory_(
*network_thread, *network_thread,
[this]() { return TimeMillis(); }, [this]() { return TimeMillis(); },
@ -175,9 +186,8 @@ bool DcSctpTransport::Start(int local_sctp_port,
std::make_unique<dcsctp::TextPcapPacketObserver>(debug_name_); std::make_unique<dcsctp::TextPcapPacketObserver>(debug_name_);
} }
dcsctp::DcSctpSocketFactory factory; socket_ = socket_factory_->Create(debug_name_, *this,
socket_ = std::move(packet_observer), options);
factory.Create(debug_name_, *this, std::move(packet_observer), options);
} else { } else {
if (local_sctp_port != socket_->options().local_port || if (local_sctp_port != socket_->options().local_port ||
remote_sctp_port != socket_->options().remote_port) { remote_sctp_port != socket_->options().remote_port) {
@ -202,6 +212,7 @@ bool DcSctpTransport::OpenStream(int sid) {
<< "): Transport is not started."; << "): Transport is not started.";
return false; return false;
} }
local_close_.erase(dcsctp::StreamID(static_cast<uint16_t>(sid)));
return true; return true;
} }
@ -213,6 +224,7 @@ bool DcSctpTransport::ResetStream(int sid) {
return false; return false;
} }
dcsctp::StreamID streams[1] = {dcsctp::StreamID(static_cast<uint16_t>(sid))}; dcsctp::StreamID streams[1] = {dcsctp::StreamID(static_cast<uint16_t>(sid))};
local_close_.insert(streams[0]);
socket_->ResetStreams(streams); socket_->ResetStreams(streams);
return true; return true;
} }
@ -472,7 +484,11 @@ void DcSctpTransport::OnStreamsResetPerformed(
RTC_LOG(LS_INFO) << debug_name_ RTC_LOG(LS_INFO) << debug_name_
<< "->OnStreamsResetPerformed(...): Outgoing stream reset" << "->OnStreamsResetPerformed(...): Outgoing stream reset"
<< ", sid=" << stream_id.value(); << ", sid=" << stream_id.value();
SignalClosingProcedureComplete(stream_id.value()); if (!local_close_.contains(stream_id)) {
// When the close was not initiated locally, we can signal the end of the
// data channel close procedure when the remote ACKs the reset.
SignalClosingProcedureComplete(stream_id.value());
}
} }
} }
@ -482,8 +498,18 @@ void DcSctpTransport::OnIncomingStreamsReset(
RTC_LOG(LS_INFO) << debug_name_ RTC_LOG(LS_INFO) << debug_name_
<< "->OnIncomingStreamsReset(...): Incoming stream reset" << "->OnIncomingStreamsReset(...): Incoming stream reset"
<< ", sid=" << stream_id.value(); << ", sid=" << stream_id.value();
SignalClosingProcedureStartedRemotely(stream_id.value()); if (!local_close_.contains(stream_id)) {
SignalClosingProcedureComplete(stream_id.value()); // When receiving an incoming stream reset event for a non local close
// procedure, the transport needs to reset the stream in the other
// direction too.
dcsctp::StreamID streams[1] = {stream_id};
socket_->ResetStreams(streams);
SignalClosingProcedureStartedRemotely(stream_id.value());
} else {
// The close procedure that was initiated locally is complete when we
// receive and incoming reset event.
SignalClosingProcedureComplete(stream_id.value());
}
} }
} }

View file

@ -21,9 +21,11 @@
#include "media/sctp/sctp_transport_internal.h" #include "media/sctp/sctp_transport_internal.h"
#include "net/dcsctp/public/dcsctp_options.h" #include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/dcsctp_socket.h" #include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/dcsctp_socket_factory.h"
#include "net/dcsctp/public/types.h" #include "net/dcsctp/public/types.h"
#include "net/dcsctp/timer/task_queue_timeout.h" #include "net/dcsctp/timer/task_queue_timeout.h"
#include "p2p/base/packet_transport_internal.h" #include "p2p/base/packet_transport_internal.h"
#include "rtc_base/containers/flat_set.h"
#include "rtc_base/copy_on_write_buffer.h" #include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/random.h" #include "rtc_base/random.h"
#include "rtc_base/third_party/sigslot/sigslot.h" #include "rtc_base/third_party/sigslot/sigslot.h"
@ -39,6 +41,10 @@ class DcSctpTransport : public cricket::SctpTransportInternal,
DcSctpTransport(rtc::Thread* network_thread, DcSctpTransport(rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport, rtc::PacketTransportInternal* transport,
Clock* clock); Clock* clock);
DcSctpTransport(rtc::Thread* network_thread,
rtc::PacketTransportInternal* transport,
Clock* clock,
std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory);
~DcSctpTransport() override; ~DcSctpTransport() override;
// cricket::SctpTransportInternal // cricket::SctpTransportInternal
@ -99,11 +105,13 @@ class DcSctpTransport : public cricket::SctpTransportInternal,
Clock* clock_; Clock* clock_;
Random random_; Random random_;
std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory_;
dcsctp::TaskQueueTimeoutFactory task_queue_timeout_factory_; dcsctp::TaskQueueTimeoutFactory task_queue_timeout_factory_;
std::unique_ptr<dcsctp::DcSctpSocketInterface> socket_; std::unique_ptr<dcsctp::DcSctpSocketInterface> socket_;
std::string debug_name_ = "DcSctpTransport"; std::string debug_name_ = "DcSctpTransport";
rtc::CopyOnWriteBuffer receive_buffer_; rtc::CopyOnWriteBuffer receive_buffer_;
flat_set<dcsctp::StreamID> local_close_;
bool ready_to_send_data_ = false; bool ready_to_send_data_ = false;
}; };

View file

@ -0,0 +1,129 @@
/*
* Copyright 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "media/sctp/dcsctp_transport.h"
#include <memory>
#include <utility>
#include "net/dcsctp/public/mock_dcsctp_socket.h"
#include "net/dcsctp/public/mock_dcsctp_socket_factory.h"
#include "p2p/base/fake_packet_transport.h"
#include "test/gtest.h"
using ::testing::ByMove;
using ::testing::DoAll;
using ::testing::ElementsAre;
using ::testing::InSequence;
using ::testing::Invoke;
using ::testing::NiceMock;
using ::testing::Return;
namespace webrtc {
namespace {
class SctpInternalTransportObserver : public sigslot::has_slots<> {
public:
MOCK_METHOD(void, OnSignalReadyToSendData, ());
MOCK_METHOD(void, OnSignalAssociationChangeCommunicationUp, ());
MOCK_METHOD(void, OnSignalClosingProcedureStartedRemotely, (int));
MOCK_METHOD(void, OnSignalClosingProcedureComplete, (int));
};
class Peer {
public:
Peer() : fake_packet_transport_("transport"), simulated_clock_(1000) {
auto socket_ptr = std::make_unique<dcsctp::MockDcSctpSocket>();
socket_ = socket_ptr.get();
auto mock_dcsctp_socket_factory =
std::make_unique<dcsctp::MockDcSctpSocketFactory>();
EXPECT_CALL(*mock_dcsctp_socket_factory, Create)
.Times(1)
.WillOnce(Return(ByMove(std::move(socket_ptr))));
sctp_transport_ = std::make_unique<webrtc::DcSctpTransport>(
rtc::Thread::Current(), &fake_packet_transport_, &simulated_clock_,
std::move(mock_dcsctp_socket_factory));
sctp_transport_->SignalAssociationChangeCommunicationUp.connect(
static_cast<SctpInternalTransportObserver*>(&observer_),
&SctpInternalTransportObserver::OnSignalReadyToSendData);
sctp_transport_->SignalAssociationChangeCommunicationUp.connect(
static_cast<SctpInternalTransportObserver*>(&observer_),
&SctpInternalTransportObserver::
OnSignalAssociationChangeCommunicationUp);
sctp_transport_->SignalClosingProcedureStartedRemotely.connect(
static_cast<SctpInternalTransportObserver*>(&observer_),
&SctpInternalTransportObserver::
OnSignalClosingProcedureStartedRemotely);
sctp_transport_->SignalClosingProcedureComplete.connect(
static_cast<SctpInternalTransportObserver*>(&observer_),
&SctpInternalTransportObserver::OnSignalClosingProcedureComplete);
}
rtc::FakePacketTransport fake_packet_transport_;
webrtc::SimulatedClock simulated_clock_;
dcsctp::MockDcSctpSocket* socket_;
std::unique_ptr<webrtc::DcSctpTransport> sctp_transport_;
NiceMock<SctpInternalTransportObserver> observer_;
};
} // namespace
TEST(DcSctpTransportTest, OpenSequence) {
Peer peer_a;
peer_a.fake_packet_transport_.SetWritable(true);
EXPECT_CALL(*peer_a.socket_, Connect)
.Times(1)
.WillOnce(Invoke(peer_a.sctp_transport_.get(),
&dcsctp::DcSctpSocketCallbacks::OnConnected));
EXPECT_CALL(peer_a.observer_, OnSignalReadyToSendData);
EXPECT_CALL(peer_a.observer_, OnSignalAssociationChangeCommunicationUp);
peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024);
}
TEST(DcSctpTransportTest, CloseSequence) {
Peer peer_a;
Peer peer_b;
peer_a.fake_packet_transport_.SetDestination(&peer_b.fake_packet_transport_,
false);
{
InSequence sequence;
EXPECT_CALL(*peer_a.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1))))
.WillOnce(DoAll(
Invoke(peer_b.sctp_transport_.get(),
&dcsctp::DcSctpSocketCallbacks::OnIncomingStreamsReset),
Invoke(peer_a.sctp_transport_.get(),
&dcsctp::DcSctpSocketCallbacks::OnStreamsResetPerformed),
Return(dcsctp::ResetStreamsStatus::kPerformed)));
EXPECT_CALL(*peer_b.socket_, ResetStreams(ElementsAre(dcsctp::StreamID(1))))
.WillOnce(DoAll(
Invoke(peer_a.sctp_transport_.get(),
&dcsctp::DcSctpSocketCallbacks::OnIncomingStreamsReset),
Invoke(peer_b.sctp_transport_.get(),
&dcsctp::DcSctpSocketCallbacks::OnStreamsResetPerformed),
Return(dcsctp::ResetStreamsStatus::kPerformed)));
EXPECT_CALL(peer_a.observer_, OnSignalClosingProcedureComplete(1));
EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureComplete(1));
EXPECT_CALL(peer_b.observer_, OnSignalClosingProcedureStartedRemotely(1));
}
peer_a.sctp_transport_->Start(5000, 5000, 256 * 1024);
peer_b.sctp_transport_->Start(5000, 5000, 256 * 1024);
peer_a.sctp_transport_->OpenStream(1);
peer_a.sctp_transport_->ResetStream(1);
}
} // namespace webrtc

View file

@ -57,8 +57,12 @@ rtc_source_set("factory") {
rtc_source_set("mocks") { rtc_source_set("mocks") {
testonly = true testonly = true
sources = [ "mock_dcsctp_socket.h" ] sources = [
"mock_dcsctp_socket.h",
"mock_dcsctp_socket_factory.h",
]
deps = [ deps = [
":factory",
":socket", ":socket",
"../../../test:test_support", "../../../test:test_support",
] ]

View file

@ -20,6 +20,9 @@
#include "net/dcsctp/socket/dcsctp_socket.h" #include "net/dcsctp/socket/dcsctp_socket.h"
namespace dcsctp { namespace dcsctp {
DcSctpSocketFactory::~DcSctpSocketFactory() = default;
std::unique_ptr<DcSctpSocketInterface> DcSctpSocketFactory::Create( std::unique_ptr<DcSctpSocketInterface> DcSctpSocketFactory::Create(
absl::string_view log_prefix, absl::string_view log_prefix,
DcSctpSocketCallbacks& callbacks, DcSctpSocketCallbacks& callbacks,

View file

@ -20,7 +20,8 @@
namespace dcsctp { namespace dcsctp {
class DcSctpSocketFactory { class DcSctpSocketFactory {
public: public:
std::unique_ptr<DcSctpSocketInterface> Create( virtual ~DcSctpSocketFactory();
virtual std::unique_ptr<DcSctpSocketInterface> Create(
absl::string_view log_prefix, absl::string_view log_prefix,
DcSctpSocketCallbacks& callbacks, DcSctpSocketCallbacks& callbacks,
std::unique_ptr<PacketObserver> packet_observer, std::unique_ptr<PacketObserver> packet_observer,

View file

@ -0,0 +1,33 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_
#define NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_
#include <memory>
#include "net/dcsctp/public/dcsctp_socket_factory.h"
#include "test/gmock.h"
namespace dcsctp {
class MockDcSctpSocketFactory : public DcSctpSocketFactory {
public:
MOCK_METHOD(std::unique_ptr<DcSctpSocketInterface>,
Create,
(absl::string_view log_prefix,
DcSctpSocketCallbacks& callbacks,
std::unique_ptr<PacketObserver> packet_observer,
const DcSctpOptions& options),
(override));
};
} // namespace dcsctp
#endif // NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_

View file

@ -923,7 +923,6 @@ public class PeerConnectionEndToEndTest {
answeringExpectations.expectStateChange(DataChannel.State.CLOSING); answeringExpectations.expectStateChange(DataChannel.State.CLOSING);
offeringExpectations.expectStateChange(DataChannel.State.CLOSED); offeringExpectations.expectStateChange(DataChannel.State.CLOSED);
answeringExpectations.expectStateChange(DataChannel.State.CLOSED); answeringExpectations.expectStateChange(DataChannel.State.CLOSED);
answeringExpectations.dataChannel.close();
offeringExpectations.dataChannel.close(); offeringExpectations.dataChannel.close();
assertTrue(offeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); assertTrue(offeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS));
assertTrue(answeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); assertTrue(answeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS));
@ -1094,7 +1093,6 @@ public class PeerConnectionEndToEndTest {
answeringExpectations.expectStateChange(DataChannel.State.CLOSING); answeringExpectations.expectStateChange(DataChannel.State.CLOSING);
offeringExpectations.expectStateChange(DataChannel.State.CLOSED); offeringExpectations.expectStateChange(DataChannel.State.CLOSED);
answeringExpectations.expectStateChange(DataChannel.State.CLOSED); answeringExpectations.expectStateChange(DataChannel.State.CLOSED);
answeringExpectations.dataChannel.close();
offeringExpectations.dataChannel.close(); offeringExpectations.dataChannel.close();
assertTrue(offeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); assertTrue(offeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS));
assertTrue(answeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS)); assertTrue(answeringExpectations.waitForAllExpectationsToBeSatisfied(DEFAULT_TIMEOUT_SECONDS));