dcsctp: Add OnTotalBufferedAmountLow in Send Queue

This is similar to Change-Id: I12a16f44f775da3711f3aa52a68a0bf24f70d2f8
but with the entire send buffer as scope, not a single stream.

This can be used by clients to take alternate action (such as delaying
transmission or using other buffering) if the send buffer ever becomes
full, as they can now be notified when the send buffer is no longer
full.

Bug: webrtc:12794
Change-Id: Icf3be3b118888ffb5ced955fd7ba4826a37140f9
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/220360
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#34143}
This commit is contained in:
Victor Boivie 2021-05-26 19:48:55 +02:00 committed by WebRTC LUCI CQ
parent 791adafa09
commit bd9031bf22
6 changed files with 107 additions and 35 deletions

View file

@ -167,9 +167,12 @@ DcSctpSocket::DcSctpSocket(absl::string_view log_prefix,
TimerOptions(options.t2_shutdown_timeout, TimerOptions(options.t2_shutdown_timeout,
TimerBackoffAlgorithm::kExponential, TimerBackoffAlgorithm::kExponential,
options.max_retransmissions))), options.max_retransmissions))),
send_queue_(log_prefix_, send_queue_(
options_.max_send_buffer_size, log_prefix_,
[](StreamID stream_id) {}) {} options_.max_send_buffer_size,
[](StreamID stream_id) {},
/*total_buffered_amount_low_threshold=*/0,
[]() {}) {}
std::string DcSctpSocket::log_prefix() const { std::string DcSctpSocket::log_prefix() const {
return log_prefix_ + "[" + std::string(ToString(state_)) + "] "; return log_prefix_ + "[" + std::string(ToString(state_)) + "] ";

View file

@ -44,6 +44,7 @@ class MockSendQueue : public SendQueue {
MOCK_METHOD(void, RollbackResetStreams, (), (override)); MOCK_METHOD(void, RollbackResetStreams, (), (override));
MOCK_METHOD(void, Reset, (), (override)); MOCK_METHOD(void, Reset, (), (override));
MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override)); MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override));
MOCK_METHOD(size_t, total_buffered_amount, (), (const, override));
MOCK_METHOD(size_t, MOCK_METHOD(size_t,
buffered_amount_low_threshold, buffered_amount_low_threshold,
(StreamID stream_id), (StreamID stream_id),

View file

@ -39,6 +39,7 @@ RRSendQueue::OutgoingStream::GetFirstNonExpiredMessage(TimeMs now) {
*item.expires_at <= now) { *item.expires_at <= now) {
// TODO(boivie): This should be reported to the client. // TODO(boivie): This should be reported to the client.
buffered_amount_.Decrease(item.remaining_size); buffered_amount_.Decrease(item.remaining_size);
total_buffered_amount_.Decrease(item.remaining_size);
items_.pop_front(); items_.pop_front();
continue; continue;
} }
@ -50,6 +51,14 @@ RRSendQueue::OutgoingStream::GetFirstNonExpiredMessage(TimeMs now) {
return nullptr; return nullptr;
} }
bool RRSendQueue::IsConsistent() const {
size_t total_buffered_amount = 0;
for (const auto& stream_entry : streams_) {
total_buffered_amount += stream_entry.second.buffered_amount().value();
}
return total_buffered_amount == total_buffered_amount_.value();
}
bool RRSendQueue::OutgoingStream::IsConsistent() const { bool RRSendQueue::OutgoingStream::IsConsistent() const {
size_t bytes = 0; size_t bytes = 0;
for (const auto& item : items_) { for (const auto& item : items_) {
@ -80,6 +89,7 @@ void RRSendQueue::OutgoingStream::Add(DcSctpMessage message,
absl::optional<TimeMs> expires_at, absl::optional<TimeMs> expires_at,
const SendOptions& send_options) { const SendOptions& send_options) {
buffered_amount_.Increase(message.payload().size()); buffered_amount_.Increase(message.payload().size());
total_buffered_amount_.Increase(message.payload().size());
items_.emplace_back(std::move(message), expires_at, send_options); items_.emplace_back(std::move(message), expires_at, send_options);
RTC_DCHECK(IsConsistent()); RTC_DCHECK(IsConsistent());
@ -141,6 +151,7 @@ absl::optional<SendQueue::DataToSend> RRSendQueue::OutgoingStream::Produce(
FSN fsn(item->current_fsn); FSN fsn(item->current_fsn);
item->current_fsn = FSN(*item->current_fsn + 1); item->current_fsn = FSN(*item->current_fsn + 1);
buffered_amount_.Decrease(payload.size()); buffered_amount_.Decrease(payload.size());
total_buffered_amount_.Decrease(payload.size());
SendQueue::DataToSend chunk(Data(stream_id, item->ssn.value_or(SSN(0)), SendQueue::DataToSend chunk(Data(stream_id, item->ssn.value_or(SSN(0)),
item->message_id.value(), fsn, ppid, item->message_id.value(), fsn, ppid,
@ -172,6 +183,7 @@ bool RRSendQueue::OutgoingStream::Discard(IsUnordered unordered,
if (item.send_options.unordered == unordered && if (item.send_options.unordered == unordered &&
item.message_id.has_value() && *item.message_id == message_id) { item.message_id.has_value() && *item.message_id == message_id) {
buffered_amount_.Decrease(item.remaining_size); buffered_amount_.Decrease(item.remaining_size);
total_buffered_amount_.Decrease(item.remaining_size);
items_.pop_front(); items_.pop_front();
// As the item still existed, it had unsent data. // As the item still existed, it had unsent data.
result = true; result = true;
@ -193,6 +205,7 @@ void RRSendQueue::OutgoingStream::Pause() {
for (auto it = items_.begin(); it != items_.end();) { for (auto it = items_.begin(); it != items_.end();) {
if (it->remaining_offset == 0) { if (it->remaining_offset == 0) {
buffered_amount_.Decrease(it->remaining_size); buffered_amount_.Decrease(it->remaining_size);
total_buffered_amount_.Decrease(it->remaining_size);
it = items_.erase(it); it = items_.erase(it);
} else { } else {
++it; ++it;
@ -208,6 +221,8 @@ void RRSendQueue::OutgoingStream::Reset() {
auto& item = items_.front(); auto& item = items_.front();
buffered_amount_.Increase(item.message.payload().size() - buffered_amount_.Increase(item.message.payload().size() -
item.remaining_size); item.remaining_size);
total_buffered_amount_.Increase(item.message.payload().size() -
item.remaining_size);
item.remaining_offset = 0; item.remaining_offset = 0;
item.remaining_size = item.message.payload().size(); item.remaining_size = item.message.payload().size();
item.message_id = absl::nullopt; item.message_id = absl::nullopt;
@ -243,25 +258,15 @@ void RRSendQueue::Add(TimeMs now,
} }
GetOrCreateStreamInfo(message.stream_id()) GetOrCreateStreamInfo(message.stream_id())
.Add(std::move(message), expires_at, send_options); .Add(std::move(message), expires_at, send_options);
} RTC_DCHECK(IsConsistent());
size_t RRSendQueue::total_bytes() const {
// TODO(boivie): Have the current size as a member variable, so that it's not
// calculated for every operation.
size_t bytes = 0;
for (const auto& stream : streams_) {
bytes += stream.second.buffered_amount().value();
}
return bytes;
} }
bool RRSendQueue::IsFull() const { bool RRSendQueue::IsFull() const {
return total_bytes() >= buffer_size_; return total_buffered_amount() >= buffer_size_;
} }
bool RRSendQueue::IsEmpty() const { bool RRSendQueue::IsEmpty() const {
return total_bytes() == 0; return total_buffered_amount() == 0;
} }
absl::optional<SendQueue::DataToSend> RRSendQueue::Produce( absl::optional<SendQueue::DataToSend> RRSendQueue::Produce(
@ -279,7 +284,7 @@ absl::optional<SendQueue::DataToSend> RRSendQueue::Produce(
next_stream_id_ = StreamID(*it->first + 1); next_stream_id_ = StreamID(*it->first + 1);
} }
} }
RTC_DCHECK(IsConsistent());
return data; return data;
} }
@ -312,6 +317,7 @@ void RRSendQueue::PrepareResetStreams(rtc::ArrayView<const StreamID> streams) {
for (StreamID stream_id : streams) { for (StreamID stream_id : streams) {
GetOrCreateStreamInfo(stream_id).Pause(); GetOrCreateStreamInfo(stream_id).Pause();
} }
RTC_DCHECK(IsConsistent());
} }
bool RRSendQueue::CanResetStreams() const { bool RRSendQueue::CanResetStreams() const {
@ -328,15 +334,19 @@ bool RRSendQueue::CanResetStreams() const {
void RRSendQueue::CommitResetStreams() { void RRSendQueue::CommitResetStreams() {
Reset(); Reset();
RTC_DCHECK(IsConsistent());
} }
void RRSendQueue::RollbackResetStreams() { void RRSendQueue::RollbackResetStreams() {
for (auto& stream_entry : streams_) { for (auto& stream_entry : streams_) {
stream_entry.second.Resume(); stream_entry.second.Resume();
} }
RTC_DCHECK(IsConsistent());
} }
void RRSendQueue::Reset() { void RRSendQueue::Reset() {
// Recalculate buffered amount, as partially sent messages may have been put
// fully back in the queue.
for (auto& stream_entry : streams_) { for (auto& stream_entry : streams_) {
OutgoingStream& stream = stream_entry.second; OutgoingStream& stream = stream_entry.second;
stream.Reset(); stream.Reset();
@ -373,7 +383,9 @@ RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo(
return streams_ return streams_
.emplace(stream_id, .emplace(stream_id,
[this, stream_id]() { on_buffered_amount_low_(stream_id); }) OutgoingStream(
[this, stream_id]() { on_buffered_amount_low_(stream_id); },
total_buffered_amount_))
.first->second; .first->second;
} }
} // namespace dcsctp } // namespace dcsctp

View file

@ -46,10 +46,15 @@ class RRSendQueue : public SendQueue {
RRSendQueue(absl::string_view log_prefix, RRSendQueue(absl::string_view log_prefix,
size_t buffer_size, size_t buffer_size,
std::function<void(StreamID)> on_buffered_amount_low) std::function<void(StreamID)> on_buffered_amount_low,
size_t total_buffered_amount_low_threshold,
std::function<void()> on_total_buffered_amount_low)
: log_prefix_(std::string(log_prefix) + "fcfs: "), : log_prefix_(std::string(log_prefix) + "fcfs: "),
buffer_size_(buffer_size), buffer_size_(buffer_size),
on_buffered_amount_low_(std::move(on_buffered_amount_low)) {} on_buffered_amount_low_(std::move(on_buffered_amount_low)),
total_buffered_amount_(std::move(on_total_buffered_amount_low)) {
total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold);
}
// Indicates if the buffer is full. Note that it's up to the caller to ensure // Indicates if the buffer is full. Note that it's up to the caller to ensure
// that the buffer is not full prior to adding new items to it. // that the buffer is not full prior to adding new items to it.
@ -76,12 +81,12 @@ class RRSendQueue : public SendQueue {
void RollbackResetStreams() override; void RollbackResetStreams() override;
void Reset() override; void Reset() override;
size_t buffered_amount(StreamID stream_id) const override; size_t buffered_amount(StreamID stream_id) const override;
size_t total_buffered_amount() const override {
return total_buffered_amount_.value();
}
size_t buffered_amount_low_threshold(StreamID stream_id) const override; size_t buffered_amount_low_threshold(StreamID stream_id) const override;
void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
// The size of the buffer, in "payload bytes".
size_t total_bytes() const;
private: private:
// Represents a value and a "low threshold" that when the value reaches or // Represents a value and a "low threshold" that when the value reaches or
// goes under the "low threshold", will trigger `on_threshold_reached` // goes under the "low threshold", will trigger `on_threshold_reached`
@ -109,8 +114,10 @@ class RRSendQueue : public SendQueue {
// Per-stream information. // Per-stream information.
class OutgoingStream { class OutgoingStream {
public: public:
explicit OutgoingStream(std::function<void()> on_buffered_amount_low) explicit OutgoingStream(std::function<void()> on_buffered_amount_low,
: buffered_amount_(std::move(on_buffered_amount_low)) {} ThresholdWatcher& total_buffered_amount)
: buffered_amount_(std::move(on_buffered_amount_low)),
total_buffered_amount_(total_buffered_amount) {}
// Enqueues a message to this stream. // Enqueues a message to this stream.
void Add(DcSctpMessage message, void Add(DcSctpMessage message,
@ -182,8 +189,13 @@ class RRSendQueue : public SendQueue {
// The current amount of buffered data. // The current amount of buffered data.
ThresholdWatcher buffered_amount_; ThresholdWatcher buffered_amount_;
// Reference to the total buffered amount, which is updated directly by each
// stream.
ThresholdWatcher& total_buffered_amount_;
}; };
bool IsConsistent() const;
OutgoingStream& GetOrCreateStreamInfo(StreamID stream_id); OutgoingStream& GetOrCreateStreamInfo(StreamID stream_id);
absl::optional<DataToSend> Produce( absl::optional<DataToSend> Produce(
std::map<StreamID, OutgoingStream>::iterator it, std::map<StreamID, OutgoingStream>::iterator it,
@ -192,10 +204,18 @@ class RRSendQueue : public SendQueue {
const std::string log_prefix_; const std::string log_prefix_;
const size_t buffer_size_; const size_t buffer_size_;
// Called when the buffered amount is below what has been set using // Called when the buffered amount is below what has been set using
// `SetBufferedAmountLowThreshold`. // `SetBufferedAmountLowThreshold`.
const std::function<void(StreamID)> on_buffered_amount_low_; const std::function<void(StreamID)> on_buffered_amount_low_;
// Called when the total buffered amount is below what has been set using
// `SetTotalBufferedAmountLowThreshold`.
const std::function<void()> on_total_buffered_amount_low_;
// The total amount of buffer data, for all streams.
ThresholdWatcher total_buffered_amount_;
// The next stream to send chunks from. // The next stream to send chunks from.
StreamID next_stream_id_ = StreamID(0); StreamID next_stream_id_ = StreamID(0);

View file

@ -31,17 +31,24 @@ constexpr TimeMs kNow = TimeMs(0);
constexpr StreamID kStreamID(1); constexpr StreamID kStreamID(1);
constexpr PPID kPPID(53); constexpr PPID kPPID(53);
constexpr size_t kMaxQueueSize = 1000; constexpr size_t kMaxQueueSize = 1000;
constexpr size_t kBufferedAmountLowThreshold = 500;
constexpr size_t kOneFragmentPacketSize = 100; constexpr size_t kOneFragmentPacketSize = 100;
constexpr size_t kTwoFragmentPacketSize = 101; constexpr size_t kTwoFragmentPacketSize = 101;
class RRSendQueueTest : public testing::Test { class RRSendQueueTest : public testing::Test {
protected: protected:
RRSendQueueTest() RRSendQueueTest()
: buf_("log: ", kMaxQueueSize, on_buffered_amount_low_.AsStdFunction()) {} : buf_("log: ",
kMaxQueueSize,
on_buffered_amount_low_.AsStdFunction(),
kBufferedAmountLowThreshold,
on_total_buffered_amount_low_.AsStdFunction()) {}
const DcSctpOptions options_; const DcSctpOptions options_;
testing::NiceMock<testing::MockFunction<void(StreamID)>> testing::NiceMock<testing::MockFunction<void(StreamID)>>
on_buffered_amount_low_; on_buffered_amount_low_;
testing::NiceMock<testing::MockFunction<void()>>
on_total_buffered_amount_low_;
RRSendQueue buf_; RRSendQueue buf_;
}; };
@ -272,13 +279,13 @@ TEST_F(RRSendQueueTest, DiscardPartialPackets) {
TEST_F(RRSendQueueTest, PrepareResetStreamsDiscardsStream) { TEST_F(RRSendQueueTest, PrepareResetStreamsDiscardsStream) {
buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 3})); buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 3}));
buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), {1, 2, 3, 4, 5})); buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), {1, 2, 3, 4, 5}));
EXPECT_EQ(buf_.total_bytes(), 8u); EXPECT_EQ(buf_.total_buffered_amount(), 8u);
buf_.PrepareResetStreams(std::vector<StreamID>({StreamID(1)})); buf_.PrepareResetStreams(std::vector<StreamID>({StreamID(1)}));
EXPECT_EQ(buf_.total_bytes(), 5u); EXPECT_EQ(buf_.total_buffered_amount(), 5u);
buf_.CommitResetStreams(); buf_.CommitResetStreams();
buf_.PrepareResetStreams(std::vector<StreamID>({StreamID(2)})); buf_.PrepareResetStreams(std::vector<StreamID>({StreamID(2)}));
EXPECT_EQ(buf_.total_bytes(), 0u); EXPECT_EQ(buf_.total_buffered_amount(), 0u);
} }
TEST_F(RRSendQueueTest, PrepareResetStreamsNotPartialPackets) { TEST_F(RRSendQueueTest, PrepareResetStreamsNotPartialPackets) {
@ -290,30 +297,30 @@ TEST_F(RRSendQueueTest, PrepareResetStreamsNotPartialPackets) {
absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50);
ASSERT_TRUE(chunk_one.has_value()); ASSERT_TRUE(chunk_one.has_value());
EXPECT_EQ(chunk_one->data.stream_id, kStreamID); EXPECT_EQ(chunk_one->data.stream_id, kStreamID);
EXPECT_EQ(buf_.total_bytes(), 2 * payload.size() - 50); EXPECT_EQ(buf_.total_buffered_amount(), 2 * payload.size() - 50);
StreamID stream_ids[] = {StreamID(1)}; StreamID stream_ids[] = {StreamID(1)};
buf_.PrepareResetStreams(stream_ids); buf_.PrepareResetStreams(stream_ids);
EXPECT_EQ(buf_.total_bytes(), payload.size() - 50); EXPECT_EQ(buf_.total_buffered_amount(), payload.size() - 50);
} }
TEST_F(RRSendQueueTest, EnqueuedItemsArePausedDuringStreamReset) { TEST_F(RRSendQueueTest, EnqueuedItemsArePausedDuringStreamReset) {
std::vector<uint8_t> payload(50); std::vector<uint8_t> payload(50);
buf_.PrepareResetStreams(std::vector<StreamID>({StreamID(1)})); buf_.PrepareResetStreams(std::vector<StreamID>({StreamID(1)}));
EXPECT_EQ(buf_.total_bytes(), 0u); EXPECT_EQ(buf_.total_buffered_amount(), 0u);
buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload));
EXPECT_EQ(buf_.total_bytes(), payload.size()); EXPECT_EQ(buf_.total_buffered_amount(), payload.size());
EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value());
buf_.CommitResetStreams(); buf_.CommitResetStreams();
EXPECT_EQ(buf_.total_bytes(), payload.size()); EXPECT_EQ(buf_.total_buffered_amount(), payload.size());
absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50);
ASSERT_TRUE(chunk_one.has_value()); ASSERT_TRUE(chunk_one.has_value());
EXPECT_EQ(chunk_one->data.stream_id, kStreamID); EXPECT_EQ(chunk_one->data.stream_id, kStreamID);
EXPECT_EQ(buf_.total_bytes(), 0u); EXPECT_EQ(buf_.total_buffered_amount(), 0u);
} }
TEST_F(RRSendQueueTest, CommittingResetsSSN) { TEST_F(RRSendQueueTest, CommittingResetsSSN) {
@ -633,5 +640,31 @@ TEST_F(RRSendQueueTest, TriggersOnBufferedAmountLowOnThresholdChanged) {
buf_.SetBufferedAmountLowThreshold(StreamID(1), 0); buf_.SetBufferedAmountLowThreshold(StreamID(1), 0);
} }
TEST_F(RRSendQueueTest,
OnTotalBufferedAmountLowDoesNotTriggerOnBufferFillingUp) {
EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(0);
std::vector<uint8_t> payload(kBufferedAmountLowThreshold - 1);
buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload));
EXPECT_EQ(buf_.total_buffered_amount(), payload.size());
// Will not trigger if going above but never below.
buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID,
std::vector<uint8_t>(kOneFragmentPacketSize)));
}
TEST_F(RRSendQueueTest, TriggersOnTotalBufferedAmountLowWhenCrossing) {
EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(0);
std::vector<uint8_t> payload(kBufferedAmountLowThreshold);
buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload));
EXPECT_EQ(buf_.total_buffered_amount(), payload.size());
// Reaches it.
buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, std::vector<uint8_t>(1)));
// Drain it a bit - will trigger.
EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(1);
absl::optional<SendQueue::DataToSend> chunk_two =
buf_.Produce(kNow, kOneFragmentPacketSize);
}
} // namespace } // namespace
} // namespace dcsctp } // namespace dcsctp

View file

@ -113,6 +113,9 @@ class SendQueue {
// e.g. inflight. // e.g. inflight.
virtual size_t buffered_amount(StreamID stream_id) const = 0; virtual size_t buffered_amount(StreamID stream_id) const = 0;
// Returns the total amount of buffer data, for all streams.
virtual size_t total_buffered_amount() const = 0;
// Returns the limit for the `OnBufferedAmountLow` event. Default value is 0. // Returns the limit for the `OnBufferedAmountLow` event. Default value is 0.
virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0; virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0;