diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn index 1601b49e39..657e73d344 100644 --- a/p2p/BUILD.gn +++ b/p2p/BUILD.gn @@ -628,9 +628,12 @@ rtc_library("packet_transport_internal") { deps = [ ":connection", ":port", + "../api:sequence_checker", "../rtc_base:async_packet_socket", + "../rtc_base:callback_list", "../rtc_base:network_route", "../rtc_base:socket", + "../rtc_base/network:received_packet", "../rtc_base/system:rtc_export", "../rtc_base/third_party/sigslot", ] @@ -1015,6 +1018,8 @@ if (rtc_include_tests) { "../api/units:time_delta", "../rtc_base:copy_on_write_buffer", "../rtc_base:task_queue_for_test", + "../rtc_base:timeutils", + "../rtc_base/network:received_packet", ] absl_deps = [ "//third_party/abseil-cpp/absl/algorithm:container", @@ -1107,6 +1112,7 @@ if (rtc_include_tests) { "base/dtls_transport_unittest.cc", "base/ice_credentials_iterator_unittest.cc", "base/p2p_transport_channel_unittest.cc", + "base/packet_transport_internal_unittest.cc", "base/port_allocator_unittest.cc", "base/port_unittest.cc", "base/pseudo_tcp_unittest.cc", diff --git a/p2p/base/fake_ice_transport.h b/p2p/base/fake_ice_transport.h index 6172ebb15b..285bfff59c 100644 --- a/p2p/base/fake_ice_transport.h +++ b/p2p/base/fake_ice_transport.h @@ -24,7 +24,9 @@ #include "api/units/time_delta.h" #include "p2p/base/ice_transport_internal.h" #include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/task_queue_for_test.h" +#include "rtc_base/time_utils.h" namespace cricket { using ::webrtc::SafeTask; @@ -391,8 +393,8 @@ class FakeIceTransport : public IceTransportInternal { RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) { if (dest_) { last_sent_packet_ = packet; - dest_->SignalReadPacket(dest_, packet.data(), packet.size(), - rtc::TimeMicros(), 0); + dest_->NotifyPacketReceived(rtc::ReceivedPacket::CreateFromLegacy( + packet.data(), packet.size(), rtc::TimeMicros())); } } diff --git a/p2p/base/fake_packet_transport.h b/p2p/base/fake_packet_transport.h index e80af0e008..29e9bc780e 100644 --- a/p2p/base/fake_packet_transport.h +++ b/p2p/base/fake_packet_transport.h @@ -98,6 +98,8 @@ class FakePacketTransport : public PacketTransportInternal { SignalNetworkRouteChanged(network_route); } + using PacketTransportInternal::NotifyPacketReceived; + private: void set_writable(bool writable) { if (writable_ == writable) { @@ -121,8 +123,8 @@ class FakePacketTransport : public PacketTransportInternal { void SendPacketInternal(const CopyOnWriteBuffer& packet) { last_sent_packet_ = packet; if (dest_) { - dest_->SignalReadPacket(dest_, packet.data(), packet.size(), - TimeMicros(), 0); + dest_->NotifyPacketReceived(rtc::ReceivedPacket::CreateFromLegacy( + packet.data(), packet.size(), rtc::TimeMicros())); } } diff --git a/p2p/base/p2p_transport_channel.cc b/p2p/base/p2p_transport_channel.cc index b0e19f6091..e3ac48c62a 100644 --- a/p2p/base/p2p_transport_channel.cc +++ b/p2p/base/p2p_transport_channel.cc @@ -2211,17 +2211,14 @@ void P2PTransportChannel::OnReadPacket(Connection* connection, last_data_received_ms_ = std::max(last_data_received_ms_, connection->last_data_received()); - SignalReadPacket( - this, reinterpret_cast(packet.payload().data()), - packet.payload().size(), - packet.arrival_time() ? packet.arrival_time()->us() : -1, 0); + NotifyPacketReceived(packet); - // May need to switch the sending connection based on the receiving media - // path if this is the controlled side. - if (ice_role_ == ICEROLE_CONTROLLED && connection != selected_connection_) { - ice_controller_->OnImmediateSwitchRequest(IceSwitchReason::DATA_RECEIVED, - connection); - } + // May need to switch the sending connection based on the receiving media + // path if this is the controlled side. + if (ice_role_ == ICEROLE_CONTROLLED && connection != selected_connection_) { + ice_controller_->OnImmediateSwitchRequest(IceSwitchReason::DATA_RECEIVED, + connection); + } } void P2PTransportChannel::OnSentPacket(const rtc::SentPacket& sent_packet) { diff --git a/p2p/base/packet_transport_internal.cc b/p2p/base/packet_transport_internal.cc index 0904cb2d3e..2e8c00a648 100644 --- a/p2p/base/packet_transport_internal.cc +++ b/p2p/base/packet_transport_internal.cc @@ -10,6 +10,9 @@ #include "p2p/base/packet_transport_internal.h" +#include "api/sequence_checker.h" +#include "rtc_base/network/received_packet.h" + namespace rtc { PacketTransportInternal::PacketTransportInternal() = default; @@ -24,4 +27,23 @@ absl::optional PacketTransportInternal::network_route() const { return absl::optional(); } +void PacketTransportInternal::NotifyPacketReceived( + const rtc::ReceivedPacket& packet) { + RTC_DCHECK_RUN_ON(&network_checker_); + if (!SignalReadPacket.is_empty()) { + // TODO(bugs.webrtc.org:15368): Replace with + // received_packet_callbacklist_. + int flags = 0; + if (packet.decryption_info() == rtc::ReceivedPacket::kSrtpEncrypted) { + flags = 1; + } + SignalReadPacket( + this, reinterpret_cast(packet.payload().data()), + packet.payload().size(), + packet.arrival_time() ? packet.arrival_time()->us() : -1, flags); + } else { + received_packet_callback_list_.Send(this, packet); + } +} + } // namespace rtc diff --git a/p2p/base/packet_transport_internal.h b/p2p/base/packet_transport_internal.h index 2ca47d533d..981554a2bf 100644 --- a/p2p/base/packet_transport_internal.h +++ b/p2p/base/packet_transport_internal.h @@ -12,11 +12,14 @@ #define P2P_BASE_PACKET_TRANSPORT_INTERNAL_H_ #include +#include #include #include "absl/types/optional.h" #include "p2p/base/port.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/callback_list.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/network_route.h" #include "rtc_base/socket.h" #include "rtc_base/system/rtc_export.h" @@ -78,7 +81,19 @@ class RTC_EXPORT PacketTransportInternal : public sigslot::has_slots<> { // Emitted when receiving state changes to true. sigslot::signal1 SignalReceivingState; + template + void RegisterReceivedPacketCallback(void* id, F&& callback) { + RTC_DCHECK_RUN_ON(&network_checker_); + received_packet_callback_list_.AddReceiver(id, std::forward(callback)); + } + void DeregisterReceivedPacketCallback(void* id) { + RTC_DCHECK_RUN_ON(&network_checker_); + received_packet_callback_list_.RemoveReceivers(id); + } + // Signalled each time a packet is received on this channel. + // TODO(bugs.webrtc.org:15368): Deprecate and remove. Replace with + // RegisterReceivedPacketCallback. sigslot::signal5 { protected: PacketTransportInternal(); ~PacketTransportInternal() override; + + void NotifyPacketReceived(const rtc::ReceivedPacket& packet); + + webrtc::SequenceChecker network_checker_{webrtc::SequenceChecker::kDetached}; + + private: + webrtc::CallbackList + received_packet_callback_list_ RTC_GUARDED_BY(&network_checker_); }; } // namespace rtc diff --git a/p2p/base/packet_transport_internal_unittest.cc b/p2p/base/packet_transport_internal_unittest.cc new file mode 100644 index 0000000000..f17e43f62e --- /dev/null +++ b/p2p/base/packet_transport_internal_unittest.cc @@ -0,0 +1,101 @@ +/* + * Copyright 2024 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "p2p/base/packet_transport_internal.h" + +#include "p2p/base/fake_packet_transport.h" +#include "rtc_base/gunit.h" +#include "rtc_base/network/received_packet.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "test/gmock.h" + +namespace { + +using ::testing::MockFunction; + +class SigslotPacketReceiver : public sigslot::has_slots<> { + public: + bool packet_received() const { return packet_received_; } + + void OnPacketReceived(rtc::PacketTransportInternal*, + const char*, + size_t, + const int64_t&, + int flags) { + packet_received_ = true; + flags_ = flags; + } + + bool packet_received_ = false; + int flags_ = -1; +}; + +TEST(PacketTransportInternal, + PacketFlagsCorrectWithUnDecryptedPacketUsingSigslot) { + rtc::FakePacketTransport packet_transport("test"); + SigslotPacketReceiver receiver; + packet_transport.SignalReadPacket.connect( + &receiver, &SigslotPacketReceiver::OnPacketReceived); + + packet_transport.NotifyPacketReceived( + rtc::ReceivedPacket({}, rtc::SocketAddress(), absl::nullopt, + rtc::ReceivedPacket::kNotDecrypted)); + ASSERT_TRUE(receiver.packet_received_); + EXPECT_EQ(receiver.flags_, 0); +} + +TEST(PacketTransportInternal, PacketFlagsCorrectWithSrtpPacketUsingSigslot) { + rtc::FakePacketTransport packet_transport("test"); + SigslotPacketReceiver receiver; + packet_transport.SignalReadPacket.connect( + &receiver, &SigslotPacketReceiver::OnPacketReceived); + + packet_transport.NotifyPacketReceived( + rtc::ReceivedPacket({}, rtc::SocketAddress(), absl::nullopt, + rtc::ReceivedPacket::kSrtpEncrypted)); + ASSERT_TRUE(receiver.packet_received_); + EXPECT_EQ(receiver.flags_, 1); +} + +TEST(PacketTransportInternal, PacketFlagsCorrectWithDtlsDecryptedUsingSigslot) { + rtc::FakePacketTransport packet_transport("test"); + SigslotPacketReceiver receiver; + packet_transport.SignalReadPacket.connect( + &receiver, &SigslotPacketReceiver::OnPacketReceived); + + packet_transport.NotifyPacketReceived( + rtc::ReceivedPacket({}, rtc::SocketAddress(), absl::nullopt, + rtc::ReceivedPacket::kDtlsDecrypted)); + ASSERT_TRUE(receiver.packet_received_); + EXPECT_EQ(receiver.flags_, 0); +} + +TEST(PacketTransportInternal, + NotifyPacketReceivedPassThrougPacketToRegisterListener) { + rtc::FakePacketTransport packet_transport("test"); + MockFunction + receiver; + + packet_transport.RegisterReceivedPacketCallback(&receiver, + receiver.AsStdFunction()); + EXPECT_CALL(receiver, Call) + .WillOnce( + [](rtc::PacketTransportInternal*, const rtc::ReceivedPacket& packet) { + EXPECT_EQ(packet.decryption_info(), + rtc::ReceivedPacket::kDtlsDecrypted); + }); + packet_transport.NotifyPacketReceived( + rtc::ReceivedPacket({}, rtc::SocketAddress(), absl::nullopt, + rtc::ReceivedPacket::kDtlsDecrypted)); + + packet_transport.DeregisterReceivedPacketCallback(&receiver); +} + +} // namespace diff --git a/rtc_base/network/received_packet.cc b/rtc_base/network/received_packet.cc index 95f5e22d3b..bf8a07ca89 100644 --- a/rtc_base/network/received_packet.cc +++ b/rtc_base/network/received_packet.cc @@ -13,15 +13,24 @@ #include #include "absl/types/optional.h" +#include "rtc_base/socket_address.h" namespace rtc { ReceivedPacket::ReceivedPacket(rtc::ArrayView payload, const SocketAddress& source_address, - absl::optional arrival_time) + absl::optional arrival_time, + DecryptionInfo decryption) : payload_(payload), arrival_time_(std::move(arrival_time)), - source_address_(source_address) {} + source_address_(source_address), + decryption_info_(decryption) {} + +ReceivedPacket ReceivedPacket::CopyAndSet( + DecryptionInfo decryption_info) const { + return ReceivedPacket(payload_, source_address_, arrival_time_, + decryption_info); +} // static ReceivedPacket ReceivedPacket::CreateFromLegacy( diff --git a/rtc_base/network/received_packet.h b/rtc_base/network/received_packet.h index d898ccb2e9..b5a6092d89 100644 --- a/rtc_base/network/received_packet.h +++ b/rtc_base/network/received_packet.h @@ -26,12 +26,21 @@ namespace rtc { // example it may contains STUN, SCTP, SRTP, RTP, RTCP.... etc. class RTC_EXPORT ReceivedPacket { public: + enum DecryptionInfo { + kNotDecrypted, // Payload has not yet been decrypted or encryption is not + // used. + kDtlsDecrypted, // Payload has been Dtls decrypted + kSrtpEncrypted // Payload is SRTP encrypted. + }; + // Caller must keep memory pointed to by payload and address valid for the // lifetime of this ReceivedPacket. - ReceivedPacket( - rtc::ArrayView payload, - const SocketAddress& source_address, - absl::optional arrival_time = absl::nullopt); + ReceivedPacket(rtc::ArrayView payload, + const SocketAddress& source_address, + absl::optional arrival_time = absl::nullopt, + DecryptionInfo decryption = kNotDecrypted); + + ReceivedPacket CopyAndSet(DecryptionInfo decryption_info) const; // Address/port of the packet sender. const SocketAddress& source_address() const { return source_address_; } @@ -43,6 +52,8 @@ class RTC_EXPORT ReceivedPacket { return arrival_time_; } + const DecryptionInfo& decryption_info() const { return decryption_info_; } + static ReceivedPacket CreateFromLegacy( const char* data, size_t size, @@ -62,6 +73,7 @@ class RTC_EXPORT ReceivedPacket { rtc::ArrayView payload_; absl::optional arrival_time_; const SocketAddress& source_address_; + DecryptionInfo decryption_info_; }; } // namespace rtc