diff --git a/rtc_tools/data_channel_benchmark/data_channel_benchmark.cc b/rtc_tools/data_channel_benchmark/data_channel_benchmark.cc index 33776f37aa..fa0b6ca9c4 100644 --- a/rtc_tools/data_channel_benchmark/data_channel_benchmark.cc +++ b/rtc_tools/data_channel_benchmark/data_channel_benchmark.cc @@ -71,15 +71,16 @@ struct SetupMessage { } }; -class DataChannelObserverImpl : public webrtc::DataChannelObserver { +class DataChannelServerObserverImpl : public webrtc::DataChannelObserver { public: - explicit DataChannelObserverImpl(webrtc::DataChannelInterface* dc) - : dc_(dc), bytes_received_(0) {} + explicit DataChannelServerObserverImpl(webrtc::DataChannelInterface* dc, + rtc::Thread* signaling_thread) + : dc_(dc), signaling_thread_(signaling_thread) {} + void OnStateChange() override { - RTC_LOG(LS_INFO) << "State changed to " << dc_->state(); + RTC_LOG(LS_INFO) << "Server state changed to " << dc_->state(); switch (dc_->state()) { case webrtc::DataChannelInterface::DataState::kOpen: - open_event_.Set(); break; case webrtc::DataChannelInterface::DataState::kClosed: closed_event_.Set(); @@ -88,67 +89,136 @@ class DataChannelObserverImpl : public webrtc::DataChannelObserver { break; } } - void OnMessage(const webrtc::DataBuffer& buffer) override { - bytes_received_ += buffer.data.size(); - if (bytes_received_threshold_ && - bytes_received_ >= bytes_received_threshold_) { - bytes_received_event_.Set(); - } - if (setup_message_.empty() && !buffer.binary) { - setup_message_.assign(buffer.data.cdata(), buffer.data.size()); + void OnMessage(const webrtc::DataBuffer& buffer) override { + if (!buffer.binary) { + std::string setup_message(buffer.data.cdata(), buffer.data.size()); + setup_ = SetupMessage::FromString(setup_message); + remaining_data_ = setup_.transfer_size; setup_message_event_.Set(); } } + void OnBufferedAmountChange(uint64_t sent_data_size) override { - if (dc_->buffered_amount() < - webrtc::DataChannelInterface::MaxSendQueueSize() / 2) - low_buffered_threshold_event_.Set(); - else - low_buffered_threshold_event_.Reset(); + remaining_data_ -= sent_data_size; + // Allow the transport buffer to be drained before starting again. + if (buffer_ && dc_->buffered_amount() <= ok_to_resume_sending_threshold_) { + total_queued_up_ += buffer_->size(); + dc_->SendAsync(*buffer_, [this, buffer = buffer_](webrtc::RTCError err) { + OnSendAsyncComplete(err, buffer); + }); + buffer_ = nullptr; + } } - bool WaitForOpenState() { - return dc_->state() == webrtc::DataChannelInterface::DataState::kOpen || - open_event_.Wait(rtc::Event::kForever); - } - bool WaitForClosedState() { - return dc_->state() == webrtc::DataChannelInterface::DataState::kClosed || - closed_event_.Wait(rtc::Event::kForever); - } + bool IsOkToCallOnTheNetworkThread() override { return true; } - // Set how many received bytes are required until - // WaitForBytesReceivedThreshold return true. - void SetBytesReceivedThreshold(uint64_t bytes_received_threshold) { - bytes_received_threshold_ = bytes_received_threshold; - if (bytes_received_ >= bytes_received_threshold_) - bytes_received_event_.Set(); - } - // Wait until the received byte count reaches the desired value. - bool WaitForBytesReceivedThreshold() { - return (bytes_received_threshold_ && - bytes_received_ >= bytes_received_threshold_) || - bytes_received_event_.Wait(rtc::Event::kForever); - } + bool WaitForClosedState() { return closed_event_.Wait(rtc::Event::kForever); } - bool WaitForLowbufferedThreshold() { - return low_buffered_threshold_event_.Wait(rtc::Event::kForever); - } - std::string SetupMessage() { return setup_message_; } bool WaitForSetupMessage() { return setup_message_event_.Wait(rtc::Event::kForever); } + void StartSending() { + RTC_CHECK(remaining_data_) << "Error: no data to send"; + std::string data(std::min(setup_.packet_size, remaining_data_), '0'); + webrtc::DataBuffer* data_buffer = + new webrtc::DataBuffer(rtc::CopyOnWriteBuffer(data), true); + total_queued_up_ = data_buffer->size(); + dc_->SendAsync(*data_buffer, + [this, data_buffer = data_buffer](webrtc::RTCError err) { + OnSendAsyncComplete(err, data_buffer); + }); + } + + const struct SetupMessage& parameters() const { return setup_; } + private: - webrtc::DataChannelInterface* dc_; - rtc::Event open_event_; + void OnSendAsyncComplete(webrtc::RTCError error, webrtc::DataBuffer* buffer) { + total_queued_up_ -= buffer->size(); + if (!error.ok()) { + RTC_CHECK_EQ(error.type(), webrtc::RTCErrorType::RESOURCE_EXHAUSTED); + RTC_CHECK(!buffer_); + // Buffer saturated. Retry when OnBufferedAmountChange() detects we can. + buffer_ = buffer; + return; + } + signaling_thread_->PostTask([this, buffer = buffer, + remaining_data = remaining_data_]() { + fprintf(stderr, "Progress: %zu / %zu (%zu%%)\n", + (setup_.transfer_size - remaining_data), setup_.transfer_size, + (100 - remaining_data * 100 / setup_.transfer_size)); + + if (!remaining_data) { + RTC_CHECK(!total_queued_up_); + // We're done. + delete buffer; + return; + } + + if (remaining_data < buffer->data.size()) { + buffer->data.SetSize(remaining_data); + } + + total_queued_up_ += buffer->size(); + dc_->SendAsync(*buffer, [this, buffer = buffer](webrtc::RTCError err) { + OnSendAsyncComplete(err, buffer); + }); + }); + } + + webrtc::DataChannelInterface* const dc_; + rtc::Thread* const signaling_thread_; rtc::Event closed_event_; - rtc::Event bytes_received_event_; - absl::optional bytes_received_threshold_; - uint64_t bytes_received_; - rtc::Event low_buffered_threshold_event_; - std::string setup_message_; rtc::Event setup_message_event_; + size_t remaining_data_ = 0u; + size_t total_queued_up_ = 0u; + struct SetupMessage setup_; + webrtc::DataBuffer* buffer_ = nullptr; + const uint64_t ok_to_resume_sending_threshold_ = + webrtc::DataChannelInterface::MaxSendQueueSize() / 2; +}; + +class DataChannelClientObserverImpl : public webrtc::DataChannelObserver { + public: + explicit DataChannelClientObserverImpl(webrtc::DataChannelInterface* dc, + uint64_t bytes_received_threshold) + : dc_(dc), bytes_received_threshold_(bytes_received_threshold) {} + + void OnStateChange() override { + RTC_LOG(LS_INFO) << "Client state changed to " << dc_->state(); + switch (dc_->state()) { + case webrtc::DataChannelInterface::DataState::kOpen: + open_event_.Set(); + break; + default: + break; + } + } + + void OnMessage(const webrtc::DataBuffer& buffer) override { + bytes_received_ += buffer.data.size(); + if (bytes_received_ >= bytes_received_threshold_) { + bytes_received_event_.Set(); + } + } + + void OnBufferedAmountChange(uint64_t sent_data_size) override {} + bool IsOkToCallOnTheNetworkThread() override { return true; } + + bool WaitForOpenState() { return open_event_.Wait(rtc::Event::kForever); } + + // Wait until the received byte count reaches the desired value. + bool WaitForBytesReceivedThreshold() { + return bytes_received_event_.Wait(rtc::Event::kForever); + } + + private: + webrtc::DataChannelInterface* const dc_; + rtc::Event open_event_; + rtc::Event bytes_received_event_; + const uint64_t bytes_received_threshold_; + uint64_t bytes_received_ = 0u; }; int RunServer() { @@ -163,7 +233,9 @@ int RunServer() { auto grpc_server = webrtc::GrpcSignalingServerInterface::Create( [factory = rtc::scoped_refptr( - factory)](webrtc::SignalingInterface* signaling) { + factory), + signaling_thread = + signaling_thread.get()](webrtc::SignalingInterface* signaling) { webrtc::PeerConnectionClient client(factory.get(), signaling); client.StartPeerConnection(); auto peer_connection = client.peerConnection(); @@ -171,9 +243,11 @@ int RunServer() { // Set up the data channel auto dc_or_error = peer_connection->CreateDataChannelOrError("benchmark", nullptr); + RTC_CHECK(dc_or_error.ok()); auto data_channel = dc_or_error.MoveValue(); auto data_channel_observer = - std::make_unique(data_channel.get()); + std::make_unique( + data_channel.get(), signaling_thread); data_channel->RegisterObserver(data_channel_observer.get()); absl::Cleanup unregister_observer( [data_channel] { data_channel->UnregisterObserver(); }); @@ -183,36 +257,14 @@ int RunServer() { // should be. // First message is "packet_size,transfer_size". data_channel_observer->WaitForSetupMessage(); - auto parameters = - SetupMessage::FromString(data_channel_observer->SetupMessage()); // Wait for the sender and receiver peers to stabilize (send all ACKs) // This makes it easier to isolate the sending part when profiling. absl::SleepFor(absl::Seconds(1)); - std::string data(parameters.packet_size, '0'); - size_t remaining_data = parameters.transfer_size; - auto begin_time = webrtc::Clock::GetRealTimeClock()->CurrentTime(); - while (remaining_data) { - if (remaining_data < data.size()) - data.resize(remaining_data); - - rtc::CopyOnWriteBuffer buffer(data); - webrtc::DataBuffer data_buffer(buffer, true); - if (!data_channel->Send(data_buffer)) { - // If the send() call failed, the buffers are full. - // We wait until there's more room. - data_channel_observer->WaitForLowbufferedThreshold(); - continue; - } - remaining_data -= buffer.size(); - fprintf(stderr, "Progress: %zu / %zu (%zu%%)\n", - (parameters.transfer_size - remaining_data), - parameters.transfer_size, - (100 - remaining_data * 100 / parameters.transfer_size)); - } + data_channel_observer->StartSending(); // Receiver signals the data channel close event when it has received // all the data it requested. @@ -220,8 +272,10 @@ int RunServer() { auto end_time = webrtc::Clock::GetRealTimeClock()->CurrentTime(); auto duration_ms = (end_time - begin_time).ms(); - double throughput = (parameters.transfer_size / 1024. / 1024.) / - (duration_ms / 1000.); + double throughput = + (data_channel_observer->parameters().transfer_size / 1024. / + 1024.) / + (duration_ms / 1000.); printf("Elapsed time: %zums %gMiB/s\n", duration_ms, throughput); }, port, oneshot); @@ -231,7 +285,7 @@ int RunServer() { grpc_server->Wait(); } - signaling_thread->Quit(); + signaling_thread->Stop(); return 0; } @@ -251,13 +305,18 @@ int RunClient() { webrtc::PeerConnectionClient client(factory.get(), grpc_client->signaling_client()); + std::unique_ptr observer; + // Set up the callback to receive the data channel from the sender. rtc::scoped_refptr data_channel; rtc::Event got_data_channel; client.SetOnDataChannel( - [&data_channel, &got_data_channel]( - rtc::scoped_refptr channel) { - data_channel = channel; + [&](rtc::scoped_refptr channel) { + data_channel = std::move(channel); + // DataChannel needs an observer to drain the read queue. + observer = std::make_unique( + data_channel.get(), transfer_size); + data_channel->RegisterObserver(observer.get()); got_data_channel.Set(); }); @@ -270,16 +329,12 @@ int RunClient() { // Wait for the data channel to be received got_data_channel.Wait(rtc::Event::kForever); - // DataChannel needs an observer to start draining the read queue - DataChannelObserverImpl observer(data_channel.get()); - observer.SetBytesReceivedThreshold(transfer_size); - data_channel->RegisterObserver(&observer); absl::Cleanup unregister_observer( [data_channel] { data_channel->UnregisterObserver(); }); // Send a configuration string to the server to tell it to send // 'packet_size' bytes packets and send a total of 'transfer_size' MB. - observer.WaitForOpenState(); + observer->WaitForOpenState(); SetupMessage setup_message = { .packet_size = packet_size, .transfer_size = transfer_size, @@ -290,14 +345,14 @@ int RunClient() { } // Wait until we have received all the data - observer.WaitForBytesReceivedThreshold(); + observer->WaitForBytesReceivedThreshold(); // Close the data channel, signaling to the server we have received // all the requested data. data_channel->Close(); } - signaling_thread->Quit(); + signaling_thread->Stop(); return 0; }