diff --git a/net/dcsctp/socket/BUILD.gn b/net/dcsctp/socket/BUILD.gn index 69307ac6b4..e7dfc582ed 100644 --- a/net/dcsctp/socket/BUILD.gn +++ b/net/dcsctp/socket/BUILD.gn @@ -160,6 +160,7 @@ rtc_library("dcsctp_socket") { "../tx:send_queue", ] sources = [ + "callback_deferrer.cc", "callback_deferrer.h", "dcsctp_socket.cc", "dcsctp_socket.h", diff --git a/net/dcsctp/socket/callback_deferrer.cc b/net/dcsctp/socket/callback_deferrer.cc new file mode 100644 index 0000000000..1b7fbacccb --- /dev/null +++ b/net/dcsctp/socket/callback_deferrer.cc @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/callback_deferrer.h" + +namespace dcsctp { +namespace { +// A wrapper around the move-only DcSctpMessage, to let it be captured in a +// lambda. +class MessageDeliverer { + public: + explicit MessageDeliverer(DcSctpMessage&& message) + : state_(rtc::make_ref_counted(std::move(message))) {} + + void Deliver(DcSctpSocketCallbacks& c) { + // Really ensure that it's only called once. + RTC_DCHECK(!state_->has_delivered); + state_->has_delivered = true; + c.OnMessageReceived(std::move(state_->message)); + } + + private: + struct State : public rtc::RefCountInterface { + explicit State(DcSctpMessage&& m) + : has_delivered(false), message(std::move(m)) {} + bool has_delivered; + DcSctpMessage message; + }; + rtc::scoped_refptr state_; +}; +} // namespace + +void CallbackDeferrer::TriggerDeferred() { + // Need to swap here. The client may call into the library from within a + // callback, and that might result in adding new callbacks to this instance, + // and the vector can't be modified while iterated on. + std::vector> deferred; + deferred.swap(deferred_); + + for (auto& cb : deferred) { + cb(underlying_); + } +} + +SendPacketStatus CallbackDeferrer::SendPacketWithStatus( + rtc::ArrayView data) { + // Will not be deferred - call directly. + return underlying_.SendPacketWithStatus(data); +} + +std::unique_ptr CallbackDeferrer::CreateTimeout() { + // Will not be deferred - call directly. + return underlying_.CreateTimeout(); +} + +TimeMs CallbackDeferrer::TimeMillis() { + // Will not be deferred - call directly. + return underlying_.TimeMillis(); +} + +uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) { + // Will not be deferred - call directly. + return underlying_.GetRandomInt(low, high); +} + +void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) { + deferred_.emplace_back( + [deliverer = MessageDeliverer(std::move(message))]( + DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); +} + +void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnError(error, message); + }); +} + +void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnAborted(error, message); + }); +} + +void CallbackDeferrer::OnConnected() { + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); +} + +void CallbackDeferrer::OnClosed() { + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); +} + +void CallbackDeferrer::OnConnectionRestarted() { + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); +} + +void CallbackDeferrer::OnStreamsResetFailed( + rtc::ArrayView outgoing_streams, + absl::string_view reason) { + deferred_.emplace_back( + [streams = std::vector(outgoing_streams.begin(), + outgoing_streams.end()), + reason = std::string(reason)](DcSctpSocketCallbacks& cb) { + cb.OnStreamsResetFailed(streams, reason); + }); +} + +void CallbackDeferrer::OnStreamsResetPerformed( + rtc::ArrayView outgoing_streams) { + deferred_.emplace_back( + [streams = std::vector(outgoing_streams.begin(), + outgoing_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); }); +} + +void CallbackDeferrer::OnIncomingStreamsReset( + rtc::ArrayView incoming_streams) { + deferred_.emplace_back( + [streams = std::vector(incoming_streams.begin(), + incoming_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); }); +} + +void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) { + deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { + cb.OnBufferedAmountLow(stream_id); + }); +} + +void CallbackDeferrer::OnTotalBufferedAmountLow() { + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); +} +} // namespace dcsctp diff --git a/net/dcsctp/socket/callback_deferrer.h b/net/dcsctp/socket/callback_deferrer.h index b3251c84d5..ab2739feb1 100644 --- a/net/dcsctp/socket/callback_deferrer.h +++ b/net/dcsctp/socket/callback_deferrer.h @@ -47,136 +47,30 @@ class CallbackDeferrer : public DcSctpSocketCallbacks { explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying) : underlying_(underlying) {} - void TriggerDeferred() { - // Need to swap here. The client may call into the library from within a - // callback, and that might result in adding new callbacks to this instance, - // and the vector can't be modified while iterated on. - std::vector> deferred; - deferred.swap(deferred_); - - for (auto& cb : deferred) { - cb(underlying_); - } - } + void TriggerDeferred(); + // Implementation of DcSctpSocketCallbacks SendPacketStatus SendPacketWithStatus( - rtc::ArrayView data) override { - // Will not be deferred - call directly. - return underlying_.SendPacketWithStatus(data); - } - - std::unique_ptr CreateTimeout() override { - // Will not be deferred - call directly. - return underlying_.CreateTimeout(); - } - - TimeMs TimeMillis() override { - // Will not be deferred - call directly. - return underlying_.TimeMillis(); - } - - uint32_t GetRandomInt(uint32_t low, uint32_t high) override { - // Will not be deferred - call directly. - return underlying_.GetRandomInt(low, high); - } - - void OnMessageReceived(DcSctpMessage message) override { - deferred_.emplace_back( - [deliverer = MessageDeliverer(std::move(message))]( - DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); - } - - void OnError(ErrorKind error, absl::string_view message) override { - deferred_.emplace_back( - [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { - cb.OnError(error, message); - }); - } - - void OnAborted(ErrorKind error, absl::string_view message) override { - deferred_.emplace_back( - [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { - cb.OnAborted(error, message); - }); - } - - void OnConnected() override { - deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); - } - - void OnClosed() override { - deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); - } - - void OnConnectionRestarted() override { - deferred_.emplace_back( - [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); - } - + rtc::ArrayView data) override; + std::unique_ptr CreateTimeout() override; + TimeMs TimeMillis() override; + uint32_t GetRandomInt(uint32_t low, uint32_t high) override; + void OnMessageReceived(DcSctpMessage message) override; + void OnError(ErrorKind error, absl::string_view message) override; + void OnAborted(ErrorKind error, absl::string_view message) override; + void OnConnected() override; + void OnClosed() override; + void OnConnectionRestarted() override; void OnStreamsResetFailed(rtc::ArrayView outgoing_streams, - absl::string_view reason) override { - deferred_.emplace_back( - [streams = std::vector(outgoing_streams.begin(), - outgoing_streams.end()), - reason = std::string(reason)](DcSctpSocketCallbacks& cb) { - cb.OnStreamsResetFailed(streams, reason); - }); - } - + absl::string_view reason) override; void OnStreamsResetPerformed( - rtc::ArrayView outgoing_streams) override { - deferred_.emplace_back( - [streams = std::vector(outgoing_streams.begin(), - outgoing_streams.end())]( - DcSctpSocketCallbacks& cb) { - cb.OnStreamsResetPerformed(streams); - }); - } - + rtc::ArrayView outgoing_streams) override; void OnIncomingStreamsReset( - rtc::ArrayView incoming_streams) override { - deferred_.emplace_back( - [streams = std::vector(incoming_streams.begin(), - incoming_streams.end())]( - DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); }); - } - - void OnBufferedAmountLow(StreamID stream_id) override { - deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { - cb.OnBufferedAmountLow(stream_id); - }); - } - - void OnTotalBufferedAmountLow() override { - deferred_.emplace_back( - [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); - } + rtc::ArrayView incoming_streams) override; + void OnBufferedAmountLow(StreamID stream_id) override; + void OnTotalBufferedAmountLow() override; private: - // A wrapper around the move-only DcSctpMessage, to let it be captured in a - // lambda. - class MessageDeliverer { - public: - explicit MessageDeliverer(DcSctpMessage&& message) - : state_(rtc::make_ref_counted(std::move(message))) {} - - void Deliver(DcSctpSocketCallbacks& c) { - // Really ensure that it's only called once. - RTC_DCHECK(!state_->has_delivered); - state_->has_delivered = true; - c.OnMessageReceived(std::move(state_->message)); - } - - private: - struct State : public rtc::RefCountInterface { - explicit State(DcSctpMessage&& m) - : has_delivered(false), message(std::move(m)) {} - bool has_delivered; - DcSctpMessage message; - }; - rtc::scoped_refptr state_; - }; - DcSctpSocketCallbacks& underlying_; std::vector> deferred_; };