diff --git a/api/transport/data_channel_transport_interface.h b/api/transport/data_channel_transport_interface.h index 27d7de6364..2476166c7b 100644 --- a/api/transport/data_channel_transport_interface.h +++ b/api/transport/data_channel_transport_interface.h @@ -86,6 +86,10 @@ class DataChannelSink { // TODO(https://crbug.com/webrtc/10360): Make pure virtual when all // consumers updated. virtual void OnTransportClosed(RTCError error) {} + + // The data channel's buffered_amount has fallen to or below the threshold + // set when calling `SetBufferedAmountLowThreshold` + virtual void OnBufferedAmountLow(int channel_id) = 0; }; // Transport for data channels. @@ -120,6 +124,8 @@ class DataChannelTransportInterface { virtual bool IsReadyToSend() const = 0; virtual size_t buffered_amount(int channel_id) const = 0; + virtual size_t buffered_amount_low_threshold(int channel_id) const = 0; + virtual void SetBufferedAmountLowThreshold(int channel_id, size_t bytes) = 0; }; } // namespace webrtc diff --git a/media/sctp/dcsctp_transport.cc b/media/sctp/dcsctp_transport.cc index 52890b3bd2..53a535f245 100644 --- a/media/sctp/dcsctp_transport.cc +++ b/media/sctp/dcsctp_transport.cc @@ -381,6 +381,18 @@ size_t DcSctpTransport::buffered_amount(int sid) const { return socket_->buffered_amount(dcsctp::StreamID(sid)); } +size_t DcSctpTransport::buffered_amount_low_threshold(int sid) const { + if (!socket_) + return 0; + return socket_->buffered_amount_low_threshold(dcsctp::StreamID(sid)); +} + +void DcSctpTransport::SetBufferedAmountLowThreshold(int sid, size_t bytes) { + if (!socket_) + return; + socket_->SetBufferedAmountLowThreshold(dcsctp::StreamID(sid), bytes); +} + void DcSctpTransport::set_debug_name_for_testing(const char* debug_name) { debug_name_ = debug_name; } @@ -446,6 +458,13 @@ void DcSctpTransport::OnTotalBufferedAmountLow() { } } +void DcSctpTransport::OnBufferedAmountLow(dcsctp::StreamID stream_id) { + RTC_DCHECK_RUN_ON(network_thread_); + if (data_channel_sink_) { + data_channel_sink_->OnBufferedAmountLow(*stream_id); + } +} + void DcSctpTransport::OnMessageReceived(dcsctp::DcSctpMessage message) { RTC_DCHECK_RUN_ON(network_thread_); RTC_DLOG(LS_VERBOSE) << debug_name_ << "->OnMessageReceived(sid=" diff --git a/media/sctp/dcsctp_transport.h b/media/sctp/dcsctp_transport.h index 21a6a9513f..958c54bb69 100644 --- a/media/sctp/dcsctp_transport.h +++ b/media/sctp/dcsctp_transport.h @@ -67,6 +67,8 @@ class DcSctpTransport : public cricket::SctpTransportInternal, absl::optional max_outbound_streams() const override; absl::optional max_inbound_streams() const override; size_t buffered_amount(int sid) const override; + size_t buffered_amount_low_threshold(int sid) const override; + void SetBufferedAmountLowThreshold(int sid, size_t bytes) override; void set_debug_name_for_testing(const char* debug_name) override; private: @@ -78,6 +80,7 @@ class DcSctpTransport : public cricket::SctpTransportInternal, dcsctp::TimeMs TimeMillis() override; uint32_t GetRandomInt(uint32_t low, uint32_t high) override; void OnTotalBufferedAmountLow() override; + void OnBufferedAmountLow(dcsctp::StreamID stream_id) override; void OnMessageReceived(dcsctp::DcSctpMessage message) override; void OnError(dcsctp::ErrorKind error, absl::string_view message) override; void OnAborted(dcsctp::ErrorKind error, absl::string_view message) override; diff --git a/media/sctp/dcsctp_transport_unittest.cc b/media/sctp/dcsctp_transport_unittest.cc index 9642cf6205..adc8c125da 100644 --- a/media/sctp/dcsctp_transport_unittest.cc +++ b/media/sctp/dcsctp_transport_unittest.cc @@ -45,6 +45,7 @@ class MockDataChannelSink : public DataChannelSink { MOCK_METHOD(void, OnChannelClosed, (int)); MOCK_METHOD(void, OnReadyToSend, ()); MOCK_METHOD(void, OnTransportClosed, (RTCError)); + MOCK_METHOD(void, OnBufferedAmountLow, (int channel_id), (override)); }; static_assert(!std::is_abstract_v); diff --git a/media/sctp/sctp_transport_internal.h b/media/sctp/sctp_transport_internal.h index 705f5bd3e6..62bb5e5f26 100644 --- a/media/sctp/sctp_transport_internal.h +++ b/media/sctp/sctp_transport_internal.h @@ -142,6 +142,8 @@ class SctpTransportInternal { virtual absl::optional max_inbound_streams() const = 0; // Returns the amount of buffered data in the send queue for a stream. virtual size_t buffered_amount(int sid) const = 0; + virtual size_t buffered_amount_low_threshold(int sid) const = 0; + virtual void SetBufferedAmountLowThreshold(int sid, size_t bytes) = 0; // Helper for debugging. virtual void set_debug_name_for_testing(const char* debug_name) = 0; diff --git a/pc/data_channel_controller.cc b/pc/data_channel_controller.cc index fbe639f96b..b95ee8d4a3 100644 --- a/pc/data_channel_controller.cc +++ b/pc/data_channel_controller.cc @@ -97,6 +97,26 @@ size_t DataChannelController::buffered_amount(StreamId sid) const { return data_channel_transport_->buffered_amount(sid.stream_id_int()); } +size_t DataChannelController::buffered_amount_low_threshold( + StreamId sid) const { + RTC_DCHECK_RUN_ON(network_thread()); + if (!data_channel_transport_) { + return 0; + } + return data_channel_transport_->buffered_amount_low_threshold( + sid.stream_id_int()); +} + +void DataChannelController::SetBufferedAmountLowThreshold(StreamId sid, + size_t bytes) { + RTC_DCHECK_RUN_ON(network_thread()); + if (!data_channel_transport_) { + return; + } + data_channel_transport_->SetBufferedAmountLowThreshold(sid.stream_id_int(), + bytes); +} + void DataChannelController::OnDataReceived( int channel_id, DataMessageType type, @@ -171,6 +191,16 @@ void DataChannelController::OnTransportClosed(RTCError error) { } } +void DataChannelController::OnBufferedAmountLow(int channel_id) { + RTC_DCHECK_RUN_ON(network_thread()); + auto it = absl::c_find_if(sctp_data_channels_n_, [&](const auto& c) { + return c->sid_n().has_value() && c->sid_n()->stream_id_int() == channel_id; + }); + + if (it != sctp_data_channels_n_.end()) + (*it)->OnBufferedAmountLow(); +} + void DataChannelController::SetupDataChannelTransport_n( DataChannelTransportInterface* transport) { RTC_DCHECK_RUN_ON(network_thread()); diff --git a/pc/data_channel_controller.h b/pc/data_channel_controller.h index d2a9a1a135..fe1024db6d 100644 --- a/pc/data_channel_controller.h +++ b/pc/data_channel_controller.h @@ -55,6 +55,8 @@ class DataChannelController : public SctpDataChannelControllerInterface, void OnChannelStateChanged(SctpDataChannel* channel, DataChannelInterface::DataState state) override; size_t buffered_amount(StreamId sid) const override; + size_t buffered_amount_low_threshold(StreamId sid) const override; + void SetBufferedAmountLowThreshold(StreamId sid, size_t bytes) override; // Implements DataChannelSink. void OnDataReceived(int channel_id, @@ -64,6 +66,7 @@ class DataChannelController : public SctpDataChannelControllerInterface, void OnChannelClosed(int channel_id) override; void OnReadyToSend() override; void OnTransportClosed(RTCError error) override; + void OnBufferedAmountLow(int channel_id) override; // Called as part of destroying the owning PeerConnection. void PrepareForShutdown(); diff --git a/pc/data_channel_controller_unittest.cc b/pc/data_channel_controller_unittest.cc index caf9a76c41..f49d2e6db5 100644 --- a/pc/data_channel_controller_unittest.cc +++ b/pc/data_channel_controller_unittest.cc @@ -42,6 +42,14 @@ class MockDataChannelTransport : public DataChannelTransportInterface { MOCK_METHOD(void, SetDataSink, (DataChannelSink * sink), (override)); MOCK_METHOD(bool, IsReadyToSend, (), (const, override)); MOCK_METHOD(size_t, buffered_amount, (int channel_id), (const, override)); + MOCK_METHOD(size_t, + buffered_amount_low_threshold, + (int channel_id), + (const, override)); + MOCK_METHOD(void, + SetBufferedAmountLowThreshold, + (int channel_id, size_t bytes), + (override)); }; // Convenience class for tests to ensure that shutdown methods for DCC diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index 7ec314d2f7..e496382c35 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -662,6 +662,10 @@ void SctpDataChannel::OnTransportChannelClosed(RTCError error) { CloseAbruptlyWithError(std::move(error)); } +void SctpDataChannel::OnBufferedAmountLow() { + RTC_DCHECK_RUN_ON(network_thread_); +} + DataChannelStats SctpDataChannel::GetStats() const { RTC_DCHECK_RUN_ON(network_thread_); DataChannelStats stats{internal_id_, id(), label(), diff --git a/pc/sctp_data_channel.h b/pc/sctp_data_channel.h index 0be234bd16..fcd088ce04 100644 --- a/pc/sctp_data_channel.h +++ b/pc/sctp_data_channel.h @@ -56,6 +56,8 @@ class SctpDataChannelControllerInterface { virtual void OnChannelStateChanged(SctpDataChannel* data_channel, DataChannelInterface::DataState state) = 0; virtual size_t buffered_amount(StreamId sid) const = 0; + virtual size_t buffered_amount_low_threshold(StreamId sid) const = 0; + virtual void SetBufferedAmountLowThreshold(StreamId sid, size_t bytes) = 0; protected: virtual ~SctpDataChannelControllerInterface() {} @@ -208,6 +210,9 @@ class SctpDataChannel : public DataChannelInterface { // This method makes sure the DataChannel is disconnected and changes state // to kClosed. void OnTransportChannelClosed(RTCError error); + // Called when the amount of data buffered to be sent falls to or below the + // threshold set when calling `SetBufferedAmountLowThreshold`. + void OnBufferedAmountLow(); DataChannelStats GetStats() const; diff --git a/pc/sctp_transport.cc b/pc/sctp_transport.cc index 5f505e0296..eb60930389 100644 --- a/pc/sctp_transport.cc +++ b/pc/sctp_transport.cc @@ -106,6 +106,17 @@ size_t SctpTransport::buffered_amount(int channel_id) const { return internal_sctp_transport_->buffered_amount(channel_id); } +size_t SctpTransport::buffered_amount_low_threshold(int channel_id) const { + RTC_DCHECK_RUN_ON(owner_thread_); + return internal_sctp_transport_->buffered_amount_low_threshold(channel_id); +} + +void SctpTransport::SetBufferedAmountLowThreshold(int channel_id, + size_t bytes) { + RTC_DCHECK_RUN_ON(owner_thread_); + internal_sctp_transport_->SetBufferedAmountLowThreshold(channel_id, bytes); +} + rtc::scoped_refptr SctpTransport::dtls_transport() const { RTC_DCHECK_RUN_ON(owner_thread_); diff --git a/pc/sctp_transport.h b/pc/sctp_transport.h index 79cb3aed2c..60434d829f 100644 --- a/pc/sctp_transport.h +++ b/pc/sctp_transport.h @@ -53,6 +53,8 @@ class SctpTransport : public SctpTransportInterface, void SetDataSink(DataChannelSink* sink) override; bool IsReadyToSend() const override; size_t buffered_amount(int channel_id) const override; + size_t buffered_amount_low_threshold(int channel_id) const override; + void SetBufferedAmountLowThreshold(int channel_id, size_t bytes) override; // Internal functions void Clear(); diff --git a/pc/sctp_transport_unittest.cc b/pc/sctp_transport_unittest.cc index f0401c1b10..2849889992 100644 --- a/pc/sctp_transport_unittest.cc +++ b/pc/sctp_transport_unittest.cc @@ -64,6 +64,8 @@ class FakeCricketSctpTransport : public cricket::SctpTransportInternal { return max_inbound_streams_; } size_t buffered_amount(int sid) const override { return 0; } + size_t buffered_amount_low_threshold(int sid) const override { return 0; } + void SetBufferedAmountLowThreshold(int sid, size_t bytes) override {} void SendSignalAssociationChangeCommunicationUp() { ASSERT_TRUE(on_connected_callback_); diff --git a/pc/test/fake_data_channel_controller.h b/pc/test/fake_data_channel_controller.h index c65449b010..3f62660922 100644 --- a/pc/test/fake_data_channel_controller.h +++ b/pc/test/fake_data_channel_controller.h @@ -129,6 +129,11 @@ class FakeDataChannelController } size_t buffered_amount(webrtc::StreamId sid) const override { return 0; } + size_t buffered_amount_low_threshold(webrtc::StreamId sid) const override { + return 0; + } + void SetBufferedAmountLowThreshold(webrtc::StreamId sid, + size_t bytes) override {} // Set true to emulate the SCTP stream being blocked by congestion control. void set_send_blocked(bool blocked) { diff --git a/test/pc/sctp/fake_sctp_transport.h b/test/pc/sctp/fake_sctp_transport.h index 6aef57a241..41b7a962f1 100644 --- a/test/pc/sctp/fake_sctp_transport.h +++ b/test/pc/sctp/fake_sctp_transport.h @@ -49,6 +49,8 @@ class FakeSctpTransport : public cricket::SctpTransportInternal { return absl::nullopt; } size_t buffered_amount(int sid) const override { return 0; } + size_t buffered_amount_low_threshold(int sid) const override { return 0; } + void SetBufferedAmountLowThreshold(int sid, size_t bytes) override {} int local_port() const { RTC_DCHECK(local_port_); return *local_port_;