diff --git a/examples/stunserver/stunserver_main.cc b/examples/stunserver/stunserver_main.cc index 8180069bf0..ecf6c81ff1 100644 --- a/examples/stunserver/stunserver_main.cc +++ b/examples/stunserver/stunserver_main.cc @@ -29,7 +29,8 @@ int main(int argc, char* argv[]) { return 1; } - rtc::Thread* pthMain = rtc::Thread::Current(); + rtc::Thread* pthMain = rtc::ThreadManager::Instance()->WrapCurrentThread(); + RTC_DCHECK(pthMain); rtc::AsyncUDPSocket* server_socket = rtc::AsyncUDPSocket::Create(pthMain->socketserver(), server_addr); diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn index bfc0cb5186..444e2ad064 100644 --- a/p2p/BUILD.gn +++ b/p2p/BUILD.gn @@ -378,6 +378,7 @@ rtc_library("p2p_server_utils") { "../rtc_base:socket_address", "../rtc_base:ssl", "../rtc_base:stringutils", + "../rtc_base/network:received_packet", "../rtc_base/third_party/sigslot", ] absl_deps = [ diff --git a/p2p/base/p2p_transport_channel_unittest.cc b/p2p/base/p2p_transport_channel_unittest.cc index 9f8df961f7..61c0064aab 100644 --- a/p2p/base/p2p_transport_channel_unittest.cc +++ b/p2p/base/p2p_transport_channel_unittest.cc @@ -285,7 +285,7 @@ class P2PTransportChannelTestBase : public ::testing::Test, ss_(new rtc::FirewallSocketServer(nss_.get())), socket_factory_(new rtc::BasicPacketSocketFactory(ss_.get())), main_(ss_.get()), - stun_server_(TestStunServer::Create(ss_.get(), kStunAddr)), + stun_server_(TestStunServer::Create(ss_.get(), kStunAddr, main_)), turn_server_(&main_, ss_.get(), kTurnUdpIntAddr, kTurnUdpExtAddr), socks_server1_(ss_.get(), kSocksProxyAddrs[0], @@ -1025,7 +1025,7 @@ class P2PTransportChannelTestBase : public ::testing::Test, rtc::AutoSocketServerThread main_; rtc::scoped_refptr safety_ = PendingTaskSafetyFlag::Create(); - std::unique_ptr stun_server_; + TestStunServer::StunServerPtr stun_server_; TestTurnServer turn_server_; rtc::SocksProxyServer socks_server1_; rtc::SocksProxyServer socks_server2_; diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc index f3a01d08cd..96c1bd5ee1 100644 --- a/p2p/base/port_unittest.cc +++ b/p2p/base/port_unittest.cc @@ -421,7 +421,7 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> { nat_factory2_(ss_.get(), kNatAddr2, SocketAddress()), nat_socket_factory1_(&nat_factory1_), nat_socket_factory2_(&nat_factory2_), - stun_server_(TestStunServer::Create(ss_.get(), kStunAddr)), + stun_server_(TestStunServer::Create(ss_.get(), kStunAddr, main_)), turn_server_(&main_, ss_.get(), kTurnUdpIntAddr, kTurnUdpExtAddr), username_(rtc::CreateRandomString(ICE_UFRAG_LENGTH)), password_(rtc::CreateRandomString(ICE_PWD_LENGTH)), @@ -873,7 +873,7 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> { rtc::NATSocketFactory nat_factory2_; rtc::BasicPacketSocketFactory nat_socket_factory1_; rtc::BasicPacketSocketFactory nat_socket_factory2_; - std::unique_ptr stun_server_; + TestStunServer::StunServerPtr stun_server_; TestTurnServer turn_server_; std::string username_; std::string password_; diff --git a/p2p/base/stun_port_unittest.cc b/p2p/base/stun_port_unittest.cc index 5b10618337..04505d26ff 100644 --- a/p2p/base/stun_port_unittest.cc +++ b/p2p/base/stun_port_unittest.cc @@ -96,8 +96,10 @@ class StunPortTestBase : public ::testing::Test, public sigslot::has_slots<> { thread_(ss_.get()), network_(network), socket_factory_(ss_.get()), - stun_server_1_(cricket::TestStunServer::Create(ss_.get(), kStunAddr1)), - stun_server_2_(cricket::TestStunServer::Create(ss_.get(), kStunAddr2)), + stun_server_1_( + cricket::TestStunServer::Create(ss_.get(), kStunAddr1, thread_)), + stun_server_2_( + cricket::TestStunServer::Create(ss_.get(), kStunAddr2, thread_)), mdns_responder_provider_(new FakeMdnsResponderProvider()), done_(false), error_(false), @@ -226,14 +228,16 @@ class StunPortTestBase : public ::testing::Test, public sigslot::has_slots<> { cricket::TestStunServer* stun_server_1() { return stun_server_1_.get(); } cricket::TestStunServer* stun_server_2() { return stun_server_2_.get(); } + rtc::AutoSocketServerThread& thread() { return thread_; } + private: std::unique_ptr ss_; rtc::AutoSocketServerThread thread_; rtc::Network network_; rtc::BasicPacketSocketFactory socket_factory_; std::unique_ptr stun_port_; - std::unique_ptr stun_server_1_; - std::unique_ptr stun_server_2_; + cricket::TestStunServer::StunServerPtr stun_server_1_; + cricket::TestStunServer::StunServerPtr stun_server_2_; std::unique_ptr socket_; std::unique_ptr mdns_responder_provider_; bool done_; @@ -620,12 +624,12 @@ class StunIPv6PortTestBase : public StunPortTestBase { kIPv6LocalAddr.ipaddr(), 128), kIPv6LocalAddr.ipaddr()) { - stun_server_ipv6_1_.reset( - cricket::TestStunServer::Create(ss(), kIPv6StunAddr1)); + stun_server_ipv6_1_ = + cricket::TestStunServer::Create(ss(), kIPv6StunAddr1, thread()); } protected: - std::unique_ptr stun_server_ipv6_1_; + cricket::TestStunServer::StunServerPtr stun_server_ipv6_1_; }; class StunIPv6PortTestWithRealClock : public StunIPv6PortTestBase {}; diff --git a/p2p/base/stun_server.cc b/p2p/base/stun_server.cc index d09ff4bca1..e37c5baf26 100644 --- a/p2p/base/stun_server.cc +++ b/p2p/base/stun_server.cc @@ -14,43 +14,49 @@ #include #include "absl/strings/string_view.h" +#include "api/sequence_checker.h" +#include "rtc_base/async_packet_socket.h" #include "rtc_base/byte_buffer.h" #include "rtc_base/logging.h" +#include "rtc_base/network/received_packet.h" namespace cricket { StunServer::StunServer(rtc::AsyncUDPSocket* socket) : socket_(socket) { - socket_->SignalReadPacket.connect(this, &StunServer::OnPacket); + socket_->RegisterReceivedPacketCallback( + [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { + OnPacket(socket, packet); + }); } StunServer::~StunServer() { - socket_->SignalReadPacket.disconnect(this); + RTC_DCHECK_RUN_ON(&sequence_checker_); + socket_->DeregisterReceivedPacketCallback(); } void StunServer::OnPacket(rtc::AsyncPacketSocket* socket, - const char* buf, - size_t size, - const rtc::SocketAddress& remote_addr, - const int64_t& /* packet_time_us */) { + const rtc::ReceivedPacket& packet) { + RTC_DCHECK_RUN_ON(&sequence_checker_); // Parse the STUN message; eat any messages that fail to parse. - rtc::ByteBufferReader bbuf( - rtc::MakeArrayView(reinterpret_cast(buf), size)); + rtc::ByteBufferReader bbuf(packet.payload()); StunMessage msg; if (!msg.Read(&bbuf)) { return; } - // TODO(?): If unknown non-optional (<= 0x7fff) attributes are found, send a + // TODO(?): If unknown non-optional (<= 0x7fff) attributes are found, + // send a // 420 "Unknown Attribute" response. // Send the message to the appropriate handler function. switch (msg.type()) { case STUN_BINDING_REQUEST: - OnBindingRequest(&msg, remote_addr); + OnBindingRequest(&msg, packet.source_address()); break; default: - SendErrorResponse(msg, remote_addr, 600, "Operation Not Supported"); + SendErrorResponse(msg, packet.source_address(), 600, + "Operation Not Supported"); } } diff --git a/p2p/base/stun_server.h b/p2p/base/stun_server.h index 505773b052..f6a776c5f7 100644 --- a/p2p/base/stun_server.h +++ b/p2p/base/stun_server.h @@ -12,7 +12,6 @@ #define P2P_BASE_STUN_SERVER_H_ #include -#include #include @@ -21,26 +20,22 @@ #include "rtc_base/async_packet_socket.h" #include "rtc_base/async_udp_socket.h" #include "rtc_base/socket_address.h" -#include "rtc_base/third_party/sigslot/sigslot.h" namespace cricket { const int STUN_SERVER_PORT = 3478; -class StunServer : public sigslot::has_slots<> { +class StunServer { public: // Creates a STUN server, which will listen on the given socket. explicit StunServer(rtc::AsyncUDPSocket* socket); // Removes the STUN server from the socket and deletes the socket. - ~StunServer() override; + virtual ~StunServer(); protected: - // Slot for Socket.PacketRead: + // Callback for packets from socket. void OnPacket(rtc::AsyncPacketSocket* socket, - const char* buf, - size_t size, - const rtc::SocketAddress& remote_addr, - const int64_t& packet_time_us); + const rtc::ReceivedPacket& packet); // Handlers for the different types of STUN/TURN requests: virtual void OnBindingRequest(StunMessage* msg, @@ -64,6 +59,7 @@ class StunServer : public sigslot::has_slots<> { StunMessage* response) const; private: + webrtc::SequenceChecker sequence_checker_; std::unique_ptr socket_; }; diff --git a/p2p/base/test_stun_server.cc b/p2p/base/test_stun_server.cc index d4c3b2d851..a8a5c46f8c 100644 --- a/p2p/base/test_stun_server.cc +++ b/p2p/base/test_stun_server.cc @@ -10,21 +10,32 @@ #include "p2p/base/test_stun_server.h" +#include + #include "rtc_base/socket.h" #include "rtc_base/socket_server.h" namespace cricket { -TestStunServer* TestStunServer::Create(rtc::SocketServer* ss, - const rtc::SocketAddress& addr) { +std::unique_ptr> +TestStunServer::Create(rtc::SocketServer* ss, + const rtc::SocketAddress& addr, + rtc::Thread& network_thread) { rtc::Socket* socket = ss->CreateSocket(addr.family(), SOCK_DGRAM); rtc::AsyncUDPSocket* udp_socket = rtc::AsyncUDPSocket::Create(socket, addr); - - return new TestStunServer(udp_socket); + TestStunServer* server = nullptr; + network_thread.BlockingCall( + [&]() { server = new TestStunServer(udp_socket, network_thread); }); + std::unique_ptr> result( + server, [&](TestStunServer* server) { + network_thread.BlockingCall([server]() { delete server; }); + }); + return result; } void TestStunServer::OnBindingRequest(StunMessage* msg, const rtc::SocketAddress& remote_addr) { + RTC_DCHECK_RUN_ON(&network_thread_); if (fake_stun_addr_.IsNil()) { StunServer::OnBindingRequest(msg, remote_addr); } else { diff --git a/p2p/base/test_stun_server.h b/p2p/base/test_stun_server.h index 11ac620bb8..7bf7dc1dba 100644 --- a/p2p/base/test_stun_server.h +++ b/p2p/base/test_stun_server.h @@ -11,19 +11,25 @@ #ifndef P2P_BASE_TEST_STUN_SERVER_H_ #define P2P_BASE_TEST_STUN_SERVER_H_ +#include + #include "api/transport/stun.h" #include "p2p/base/stun_server.h" #include "rtc_base/async_udp_socket.h" #include "rtc_base/socket_address.h" #include "rtc_base/socket_server.h" +#include "rtc_base/thread.h" namespace cricket { // A test STUN server. Useful for unit tests. class TestStunServer : StunServer { public: - static TestStunServer* Create(rtc::SocketServer* ss, - const rtc::SocketAddress& addr); + using StunServerPtr = + std::unique_ptr>; + static StunServerPtr Create(rtc::SocketServer* ss, + const rtc::SocketAddress& addr, + rtc::Thread& network_thread); // Set a fake STUN address to return to the client. void set_fake_stun_addr(const rtc::SocketAddress& addr) { @@ -31,13 +37,17 @@ class TestStunServer : StunServer { } private: - explicit TestStunServer(rtc::AsyncUDPSocket* socket) : StunServer(socket) {} + static void DeleteOnNetworkThread(TestStunServer* server); + + TestStunServer(rtc::AsyncUDPSocket* socket, rtc::Thread& network_thread) + : StunServer(socket), network_thread_(network_thread) {} void OnBindingRequest(StunMessage* msg, const rtc::SocketAddress& remote_addr) override; private: rtc::SocketAddress fake_stun_addr_; + rtc::Thread& network_thread_; }; } // namespace cricket diff --git a/p2p/base/turn_port_unittest.cc b/p2p/base/turn_port_unittest.cc index e626947d88..e7efb5e594 100644 --- a/p2p/base/turn_port_unittest.cc +++ b/p2p/base/turn_port_unittest.cc @@ -932,7 +932,7 @@ class TurnLoggingIdValidator : public StunMessageObserver { } } } - void ReceivedChannelData(const char* data, size_t size) override {} + void ReceivedChannelData(rtc::ArrayView packet) override {} private: const char* expect_val_; @@ -1734,7 +1734,7 @@ class MessageObserver : public StunMessageObserver { } } - void ReceivedChannelData(const char* data, size_t size) override { + void ReceivedChannelData(rtc::ArrayView payload) override { if (channel_data_counter_ != nullptr) { (*channel_data_counter_)++; } diff --git a/p2p/base/turn_server.cc b/p2p/base/turn_server.cc index b0c895e782..3d633110a7 100644 --- a/p2p/base/turn_server.cc +++ b/p2p/base/turn_server.cc @@ -102,7 +102,11 @@ void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket, RTC_DCHECK_RUN_ON(thread_); RTC_DCHECK(server_sockets_.end() == server_sockets_.find(socket)); server_sockets_[socket] = proto; - socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket); + socket->RegisterReceivedPacketCallback( + [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { + RTC_DCHECK_RUN_ON(thread_); + OnInternalPacket(socket, packet); + }); } void TurnServer::AddInternalServerSocket( @@ -163,40 +167,35 @@ void TurnServer::OnInternalSocketClose(rtc::AsyncPacketSocket* socket, } void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket, - const char* data, - size_t size, - const rtc::SocketAddress& addr, - const int64_t& /* packet_time_us */) { + const rtc::ReceivedPacket& packet) { RTC_DCHECK_RUN_ON(thread_); // Fail if the packet is too small to even contain a channel header. - if (size < TURN_CHANNEL_HEADER_SIZE) { + if (packet.payload().size() < TURN_CHANNEL_HEADER_SIZE) { return; } InternalSocketMap::iterator iter = server_sockets_.find(socket); RTC_DCHECK(iter != server_sockets_.end()); - TurnServerConnection conn(addr, iter->second, socket); - uint16_t msg_type = rtc::GetBE16(data); + TurnServerConnection conn(packet.source_address(), iter->second, socket); + uint16_t msg_type = rtc::GetBE16(packet.payload().data()); if (!IsTurnChannelData(msg_type)) { // This is a STUN message. - HandleStunMessage(&conn, data, size); + HandleStunMessage(&conn, packet.payload()); } else { // This is a channel message; let the allocation handle it. TurnServerAllocation* allocation = FindAllocation(&conn); if (allocation) { - allocation->HandleChannelData(data, size); + allocation->HandleChannelData(packet.payload()); } if (stun_message_observer_ != nullptr) { - stun_message_observer_->ReceivedChannelData(data, size); + stun_message_observer_->ReceivedChannelData(packet.payload()); } } } void TurnServer::HandleStunMessage(TurnServerConnection* conn, - const char* data, - size_t size) { + rtc::ArrayView payload) { TurnMessage msg; - rtc::ByteBufferReader buf( - rtc::MakeArrayView(reinterpret_cast(data), size)); + rtc::ByteBufferReader buf(payload); if (!msg.Read(&buf) || (buf.Length() > 0)) { RTC_LOG(LS_WARNING) << "Received invalid STUN message"; return; @@ -232,7 +231,7 @@ void TurnServer::HandleStunMessage(TurnServerConnection* conn, // Ensure the message is authorized; only needed for requests. if (IsStunRequestType(msg.type())) { - if (!CheckAuthorization(conn, &msg, data, size, key)) { + if (!CheckAuthorization(conn, &msg, key)) { return; } } @@ -273,8 +272,6 @@ bool TurnServer::GetKey(const StunMessage* msg, std::string* key) { bool TurnServer::CheckAuthorization(TurnServerConnection* conn, StunMessage* msg, - const char* data, - size_t size, absl::string_view key) { // RFC 5389, 10.2.2. RTC_DCHECK(IsStunRequestType(msg->type())); @@ -517,7 +514,7 @@ void TurnServer::DestroyInternalSocket(rtc::AsyncPacketSocket* socket) { if (iter != server_sockets_.end()) { rtc::AsyncPacketSocket* socket = iter->first; socket->UnsubscribeCloseEvent(this); - socket->SignalReadPacket.disconnect(this); + socket->DeregisterReceivedPacketCallback(); server_sockets_.erase(iter); std::unique_ptr socket_to_delete = absl::WrapUnique(socket); @@ -562,8 +559,11 @@ TurnServerAllocation::TurnServerAllocation(TurnServer* server, conn_(conn), external_socket_(socket), key_(key) { - external_socket_->SignalReadPacket.connect( - this, &TurnServerAllocation::OnExternalPacket); + external_socket_->RegisterReceivedPacketCallback( + [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { + RTC_DCHECK_RUN_ON(thread_); + OnExternalPacket(socket, packet); + }); } TurnServerAllocation::~TurnServerAllocation() { @@ -759,14 +759,15 @@ void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) { SendResponse(&response); } -void TurnServerAllocation::HandleChannelData(const char* data, size_t size) { +void TurnServerAllocation::HandleChannelData( + rtc::ArrayView payload) { // Extract the channel number from the data. - uint16_t channel_id = rtc::GetBE16(data); + uint16_t channel_id = rtc::GetBE16(payload.data()); auto channel = FindChannel(channel_id); if (channel != channels_.end()) { // Send the data to the peer address. - SendExternal(data + TURN_CHANNEL_HEADER_SIZE, - size - TURN_CHANNEL_HEADER_SIZE, channel->peer); + SendExternal(payload.data() + TURN_CHANNEL_HEADER_SIZE, + payload.size() - TURN_CHANNEL_HEADER_SIZE, channel->peer); } else { RTC_LOG(LS_WARNING) << ToString() << ": Received channel data for invalid channel, id=" @@ -774,34 +775,31 @@ void TurnServerAllocation::HandleChannelData(const char* data, size_t size) { } } -void TurnServerAllocation::OnExternalPacket( - rtc::AsyncPacketSocket* socket, - const char* data, - size_t size, - const rtc::SocketAddress& addr, - const int64_t& /* packet_time_us */) { +void TurnServerAllocation::OnExternalPacket(rtc::AsyncPacketSocket* socket, + const rtc::ReceivedPacket& packet) { RTC_DCHECK(external_socket_.get() == socket); - auto channel = FindChannel(addr); + auto channel = FindChannel(packet.source_address()); if (channel != channels_.end()) { // There is a channel bound to this address. Send as a channel message. rtc::ByteBufferWriter buf; buf.WriteUInt16(channel->id); - buf.WriteUInt16(static_cast(size)); - buf.WriteBytes(data, size); + buf.WriteUInt16(static_cast(packet.payload().size())); + buf.WriteBytes(reinterpret_cast(packet.payload().data()), + packet.payload().size()); server_->Send(&conn_, buf); } else if (!server_->enable_permission_checks_ || - HasPermission(addr.ipaddr())) { + HasPermission(packet.source_address().ipaddr())) { // No channel, but a permission exists. Send as a data indication. TurnMessage msg(TURN_DATA_INDICATION); msg.AddAttribute(std::make_unique( - STUN_ATTR_XOR_PEER_ADDRESS, addr)); - msg.AddAttribute( - std::make_unique(STUN_ATTR_DATA, data, size)); + STUN_ATTR_XOR_PEER_ADDRESS, packet.source_address())); + msg.AddAttribute(std::make_unique( + STUN_ATTR_DATA, packet.payload().data(), packet.payload().size())); server_->SendStun(&conn_, &msg); } else { RTC_LOG(LS_WARNING) << ToString() << ": Received external packet without permission, peer=" - << addr.ToSensitiveString(); + << packet.source_address().ToSensitiveString(); } } diff --git a/p2p/base/turn_server.h b/p2p/base/turn_server.h index e951d089af..be42338a3b 100644 --- a/p2p/base/turn_server.h +++ b/p2p/base/turn_server.h @@ -26,6 +26,7 @@ #include "api/units/time_delta.h" #include "p2p/base/port_interface.h" #include "rtc_base/async_packet_socket.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/socket_address.h" #include "rtc_base/ssl_adapter.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -69,14 +70,14 @@ class TurnServerConnection { // handles TURN messages (via HandleTurnMessage) and channel data messages // (via HandleChannelData) for this allocation when received by the server. // The object informs the server when its lifetime timer expires. -class TurnServerAllocation : public sigslot::has_slots<> { +class TurnServerAllocation { public: TurnServerAllocation(TurnServer* server_, webrtc::TaskQueueBase* thread, const TurnServerConnection& conn, rtc::AsyncPacketSocket* server_socket, absl::string_view key); - ~TurnServerAllocation() override; + virtual ~TurnServerAllocation(); TurnServerConnection* conn() { return &conn_; } const std::string& key() const { return key_; } @@ -90,7 +91,7 @@ class TurnServerAllocation : public sigslot::has_slots<> { std::string ToString() const; void HandleTurnMessage(const TurnMessage* msg); - void HandleChannelData(const char* data, size_t size); + void HandleChannelData(rtc::ArrayView payload); private: struct Channel { @@ -114,10 +115,7 @@ class TurnServerAllocation : public sigslot::has_slots<> { void HandleChannelBindRequest(const TurnMessage* msg); void OnExternalPacket(rtc::AsyncPacketSocket* socket, - const char* data, - size_t size, - const rtc::SocketAddress& addr, - const int64_t& packet_time_us); + const rtc::ReceivedPacket& packet); static webrtc::TimeDelta ComputeLifetime(const TurnMessage& msg); bool HasPermission(const rtc::IPAddress& addr); @@ -171,7 +169,7 @@ class TurnRedirectInterface { class StunMessageObserver { public: virtual void ReceivedMessage(const TurnMessage* msg) = 0; - virtual void ReceivedChannelData(const char* data, size_t size) = 0; + virtual void ReceivedChannelData(rtc::ArrayView payload) = 0; virtual ~StunMessageObserver() {} }; @@ -266,14 +264,11 @@ class TurnServer : public sigslot::has_slots<> { private: // All private member functions and variables should have access restricted to // thread_. But compile-time annotations are missing for members access from - // TurnServerAllocation (via friend declaration), and the On* methods, which - // are called via sigslot. + // TurnServerAllocation (via friend declaration). + std::string GenerateNonce(int64_t now) const RTC_RUN_ON(thread_); void OnInternalPacket(rtc::AsyncPacketSocket* socket, - const char* data, - size_t size, - const rtc::SocketAddress& address, - const int64_t& packet_time_us); + const rtc::ReceivedPacket& packet) RTC_RUN_ON(thread_); void OnNewInternalConnection(rtc::Socket* socket); @@ -282,8 +277,8 @@ class TurnServer : public sigslot::has_slots<> { void OnInternalSocketClose(rtc::AsyncPacketSocket* socket, int err); void HandleStunMessage(TurnServerConnection* conn, - const char* data, - size_t size) RTC_RUN_ON(thread_); + rtc::ArrayView payload) + RTC_RUN_ON(thread_); void HandleBindingRequest(TurnServerConnection* conn, const StunMessage* msg) RTC_RUN_ON(thread_); void HandleAllocateRequest(TurnServerConnection* conn, @@ -293,8 +288,6 @@ class TurnServer : public sigslot::has_slots<> { bool GetKey(const StunMessage* msg, std::string* key) RTC_RUN_ON(thread_); bool CheckAuthorization(TurnServerConnection* conn, StunMessage* msg, - const char* data, - size_t size, absl::string_view key) RTC_RUN_ON(thread_); bool ValidateNonce(absl::string_view nonce) const RTC_RUN_ON(thread_); diff --git a/p2p/client/basic_port_allocator_unittest.cc b/p2p/client/basic_port_allocator_unittest.cc index 55222a1be2..defcab01c9 100644 --- a/p2p/client/basic_port_allocator_unittest.cc +++ b/p2p/client/basic_port_allocator_unittest.cc @@ -163,7 +163,7 @@ class BasicPortAllocatorTestBase : public ::testing::Test, // must be called. nat_factory_(vss_.get(), kNatUdpAddr, kNatTcpAddr), nat_socket_factory_(new rtc::BasicPacketSocketFactory(&nat_factory_)), - stun_server_(TestStunServer::Create(fss_.get(), kStunAddr)), + stun_server_(TestStunServer::Create(fss_.get(), kStunAddr, thread_)), turn_server_(rtc::Thread::Current(), fss_.get(), kTurnUdpIntAddr, @@ -521,7 +521,7 @@ class BasicPortAllocatorTestBase : public ::testing::Test, std::unique_ptr nat_server_; rtc::NATSocketFactory nat_factory_; std::unique_ptr nat_socket_factory_; - std::unique_ptr stun_server_; + TestStunServer::StunServerPtr stun_server_; TestTurnServer turn_server_; rtc::FakeNetworkManager network_manager_; std::unique_ptr allocator_; diff --git a/p2p/stunprober/stun_prober_unittest.cc b/p2p/stunprober/stun_prober_unittest.cc index ca20fccb6b..1aa2be2844 100644 --- a/p2p/stunprober/stun_prober_unittest.cc +++ b/p2p/stunprober/stun_prober_unittest.cc @@ -44,8 +44,10 @@ class StunProberTest : public ::testing::Test { : ss_(std::make_unique()), main_(ss_.get()), result_(StunProber::SUCCESS), - stun_server_1_(cricket::TestStunServer::Create(ss_.get(), kStunAddr1)), - stun_server_2_(cricket::TestStunServer::Create(ss_.get(), kStunAddr2)) { + stun_server_1_( + cricket::TestStunServer::Create(ss_.get(), kStunAddr1, main_)), + stun_server_2_( + cricket::TestStunServer::Create(ss_.get(), kStunAddr2, main_)) { stun_server_1_->set_fake_stun_addr(kStunMappedAddr); stun_server_2_->set_fake_stun_addr(kStunMappedAddr); rtc::InitializeSSL(); @@ -57,8 +59,8 @@ class StunProberTest : public ::testing::Test { void CreateProber(rtc::PacketSocketFactory* socket_factory, std::vector networks) { - prober_ = std::make_unique( - socket_factory, rtc::Thread::Current(), std::move(networks)); + prober_ = std::make_unique(socket_factory, &main_, + std::move(networks)); } void StartProbing(rtc::PacketSocketFactory* socket_factory, @@ -137,8 +139,8 @@ class StunProberTest : public ::testing::Test { std::unique_ptr prober_; int result_ = 0; bool stopped_ = false; - std::unique_ptr stun_server_1_; - std::unique_ptr stun_server_2_; + cricket::TestStunServer::StunServerPtr stun_server_1_; + cricket::TestStunServer::StunServerPtr stun_server_2_; StunProber::Stats stats_; }; diff --git a/pc/peer_connection_integrationtest.cc b/pc/peer_connection_integrationtest.cc index 1ea16a6cf9..029046e357 100644 --- a/pc/peer_connection_integrationtest.cc +++ b/pc/peer_connection_integrationtest.cc @@ -1755,8 +1755,8 @@ class PeerConnectionIntegrationIceStatesTest } void StartStunServer(const SocketAddress& server_address) { - stun_server_.reset( - cricket::TestStunServer::Create(firewall(), server_address)); + stun_server_ = cricket::TestStunServer::Create(firewall(), server_address, + *network_thread()); } bool TestIPv6() { @@ -1802,7 +1802,7 @@ class PeerConnectionIntegrationIceStatesTest private: uint32_t port_allocator_flags_; - std::unique_ptr stun_server_; + cricket::TestStunServer::StunServerPtr stun_server_; }; // Ensure FakeClockForTest is constructed first (see class for rationale). diff --git a/pc/slow_peer_connection_integration_test.cc b/pc/slow_peer_connection_integration_test.cc index 4e26283395..9e49291d94 100644 --- a/pc/slow_peer_connection_integration_test.cc +++ b/pc/slow_peer_connection_integration_test.cc @@ -262,8 +262,8 @@ class PeerConnectionIntegrationIceStatesTest } void StartStunServer(const SocketAddress& server_address) { - stun_server_.reset( - cricket::TestStunServer::Create(firewall(), server_address)); + stun_server_ = cricket::TestStunServer::Create(firewall(), server_address, + *network_thread()); } bool TestIPv6() { @@ -309,7 +309,7 @@ class PeerConnectionIntegrationIceStatesTest private: uint32_t port_allocator_flags_; - std::unique_ptr stun_server_; + cricket::TestStunServer::StunServerPtr stun_server_; }; // Ensure FakeClockForTest is constructed first (see class for rationale). diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 51e15b57f7..0baeb956c5 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -1760,6 +1760,7 @@ rtc_library("rtc_base_tests_utils") { "../test:scoped_key_value_config", "memory:always_valid_pointer", "memory:fifo_buffer", + "network:received_packet", "synchronization:mutex", "third_party/sigslot", ] diff --git a/rtc_base/test_echo_server.h b/rtc_base/test_echo_server.h index 82817624a5..d99ed72f00 100644 --- a/rtc_base/test_echo_server.h +++ b/rtc_base/test_echo_server.h @@ -21,6 +21,7 @@ #include "absl/memory/memory.h" #include "rtc_base/async_packet_socket.h" #include "rtc_base/async_tcp_socket.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/socket.h" #include "rtc_base/socket_address.h" #include "rtc_base/third_party/sigslot/sigslot.h" @@ -45,19 +46,17 @@ class TestEchoServer : public sigslot::has_slots<> { Socket* raw_socket = socket->Accept(nullptr); if (raw_socket) { AsyncTCPSocket* packet_socket = new AsyncTCPSocket(raw_socket); - packet_socket->SignalReadPacket.connect(this, &TestEchoServer::OnPacket); + packet_socket->RegisterReceivedPacketCallback( + [&](rtc::AsyncPacketSocket* socket, + const rtc::ReceivedPacket& packet) { OnPacket(socket, packet); }); packet_socket->SubscribeCloseEvent( this, [this](AsyncPacketSocket* s, int err) { OnClose(s, err); }); client_sockets_.push_back(packet_socket); } } - void OnPacket(AsyncPacketSocket* socket, - const char* buf, - size_t size, - const SocketAddress& remote_addr, - const int64_t& /* packet_time_us */) { + void OnPacket(AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { rtc::PacketOptions options; - socket->Send(buf, size, options); + socket->Send(packet.payload().data(), packet.payload().size(), options); } void OnClose(AsyncPacketSocket* socket, int err) { ClientList::iterator it = absl::c_find(client_sockets_, socket); diff --git a/test/network/BUILD.gn b/test/network/BUILD.gn index 2b4c39624f..b8255d35fd 100644 --- a/test/network/BUILD.gn +++ b/test/network/BUILD.gn @@ -76,6 +76,7 @@ rtc_library("emulated_network") { "../../rtc_base:task_queue_for_test", "../../rtc_base:threading", "../../rtc_base/memory:always_valid_pointer", + "../../rtc_base/network:received_packet", "../../rtc_base/synchronization:mutex", "../../rtc_base/system:no_unique_address", "../../rtc_base/task_utils:repeating_task", diff --git a/test/network/emulated_turn_server.cc b/test/network/emulated_turn_server.cc index 0bc7ec6e2a..93724ca8a3 100644 --- a/test/network/emulated_turn_server.cc +++ b/test/network/emulated_turn_server.cc @@ -14,6 +14,7 @@ #include #include "api/packet_socket_factory.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/strings/string_builder.h" #include "rtc_base/task_queue_for_test.h" @@ -22,55 +23,6 @@ namespace { static const char kTestRealm[] = "example.org"; static const char kTestSoftware[] = "TestTurnServer"; -// A wrapper class for copying data between an AsyncPacketSocket and a -// EmulatedEndpoint. This is used by the cricket::TurnServer when -// sending data back into the emulated network. -class AsyncPacketSocketWrapper : public rtc::AsyncPacketSocket { - public: - AsyncPacketSocketWrapper(webrtc::test::EmulatedTURNServer* turn_server, - webrtc::EmulatedEndpoint* endpoint, - uint16_t port) - : turn_server_(turn_server), - endpoint_(endpoint), - local_address_( - rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port)) {} - ~AsyncPacketSocketWrapper() { turn_server_->Unbind(local_address_); } - - rtc::SocketAddress GetLocalAddress() const override { return local_address_; } - rtc::SocketAddress GetRemoteAddress() const override { - return rtc::SocketAddress(); - } - int Send(const void* pv, - size_t cb, - const rtc::PacketOptions& options) override { - RTC_CHECK(false) << "TCP not implemented"; - return -1; - } - int SendTo(const void* pv, - size_t cb, - const rtc::SocketAddress& addr, - const rtc::PacketOptions& options) override { - // Copy from rtc::AsyncPacketSocket to EmulatedEndpoint. - rtc::CopyOnWriteBuffer buf(reinterpret_cast(pv), cb); - endpoint_->SendPacket(local_address_, addr, buf); - return cb; - } - int Close() override { return 0; } - - rtc::AsyncPacketSocket::State GetState() const override { - return rtc::AsyncPacketSocket::STATE_BOUND; - } - int GetOption(rtc::Socket::Option opt, int* value) override { return 0; } - int SetOption(rtc::Socket::Option opt, int value) override { return 0; } - int GetError() const override { return 0; } - void SetError(int error) override {} - - private: - webrtc::test::EmulatedTURNServer* const turn_server_; - webrtc::EmulatedEndpoint* const endpoint_; - const rtc::SocketAddress local_address_; -}; - // A wrapper class for cricket::TurnServer to allocate sockets. class PacketSocketFactoryWrapper : public rtc::PacketSocketFactory { public: @@ -116,6 +68,59 @@ class PacketSocketFactoryWrapper : public rtc::PacketSocketFactory { namespace webrtc { namespace test { +// A wrapper class for copying data between an AsyncPacketSocket and a +// EmulatedEndpoint. This is used by the cricket::TurnServer when +// sending data back into the emulated network. +class EmulatedTURNServer::AsyncPacketSocketWrapper + : public rtc::AsyncPacketSocket { + public: + AsyncPacketSocketWrapper(webrtc::test::EmulatedTURNServer* turn_server, + webrtc::EmulatedEndpoint* endpoint, + uint16_t port) + : turn_server_(turn_server), + endpoint_(endpoint), + local_address_( + rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port)) {} + ~AsyncPacketSocketWrapper() { turn_server_->Unbind(local_address_); } + + rtc::SocketAddress GetLocalAddress() const override { return local_address_; } + rtc::SocketAddress GetRemoteAddress() const override { + return rtc::SocketAddress(); + } + int Send(const void* pv, + size_t cb, + const rtc::PacketOptions& options) override { + RTC_CHECK(false) << "TCP not implemented"; + return -1; + } + int SendTo(const void* pv, + size_t cb, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options) override { + // Copy from rtc::AsyncPacketSocket to EmulatedEndpoint. + rtc::CopyOnWriteBuffer buf(reinterpret_cast(pv), cb); + endpoint_->SendPacket(local_address_, addr, buf); + return cb; + } + int Close() override { return 0; } + void NotifyPacketReceived(const rtc::ReceivedPacket& packet) { + rtc::AsyncPacketSocket::NotifyPacketReceived(packet); + } + + rtc::AsyncPacketSocket::State GetState() const override { + return rtc::AsyncPacketSocket::STATE_BOUND; + } + int GetOption(rtc::Socket::Option opt, int* value) override { return 0; } + int SetOption(rtc::Socket::Option opt, int value) override { return 0; } + int GetError() const override { return 0; } + void SetError(int error) override {} + + private: + webrtc::test::EmulatedTURNServer* const turn_server_; + webrtc::EmulatedEndpoint* const endpoint_; + const rtc::SocketAddress local_address_; +}; + EmulatedTURNServer::EmulatedTURNServer(std::unique_ptr thread, EmulatedEndpoint* client, EmulatedEndpoint* peer) @@ -170,9 +175,8 @@ void EmulatedTURNServer::OnPacketReceived(webrtc::EmulatedIpPacket packet) { RTC_DCHECK_RUN_ON(thread_.get()); auto it = sockets_.find(packet.to); if (it != sockets_.end()) { - it->second->SignalReadPacket( - it->second, reinterpret_cast(packet.cdata()), - packet.size(), packet.from, packet.arrival_time.ms()); + it->second->NotifyPacketReceived( + rtc::ReceivedPacket(packet.data, packet.from, packet.arrival_time)); } }); } diff --git a/test/network/emulated_turn_server.h b/test/network/emulated_turn_server.h index 9cb0ceabf6..de5d266897 100644 --- a/test/network/emulated_turn_server.h +++ b/test/network/emulated_turn_server.h @@ -84,7 +84,8 @@ class EmulatedTURNServer : public EmulatedTURNServerInterface, EmulatedEndpoint* const client_; EmulatedEndpoint* const peer_; std::unique_ptr turn_server_ RTC_GUARDED_BY(&thread_); - std::map sockets_ + class AsyncPacketSocketWrapper; + std::map sockets_ RTC_GUARDED_BY(&thread_); // Wraps a EmulatedEndpoint in a AsyncPacketSocket to bridge interaction