diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc index 96c1bd5ee1..f5f3ee07dc 100644 --- a/p2p/base/port_unittest.cc +++ b/p2p/base/port_unittest.cc @@ -620,8 +620,8 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> { std::unique_ptr CreateNatServer(const SocketAddress& addr, rtc::NATType type) { - return std::make_unique(type, ss_.get(), addr, addr, - ss_.get(), addr); + return std::make_unique(type, main_, ss_.get(), addr, addr, + main_, ss_.get(), addr); } static const char* StunName(NATType type) { switch (type) { diff --git a/p2p/client/basic_port_allocator_unittest.cc b/p2p/client/basic_port_allocator_unittest.cc index defcab01c9..65f8e43895 100644 --- a/p2p/client/basic_port_allocator_unittest.cc +++ b/p2p/client/basic_port_allocator_unittest.cc @@ -496,8 +496,8 @@ class BasicPortAllocatorTestBase : public ::testing::Test, bool with_nat) { if (with_nat) { nat_server_.reset(new rtc::NATServer( - rtc::NAT_OPEN_CONE, vss_.get(), kNatUdpAddr, kNatTcpAddr, vss_.get(), - rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0))); + rtc::NAT_OPEN_CONE, thread_, vss_.get(), kNatUdpAddr, kNatTcpAddr, + thread_, vss_.get(), rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0))); } else { nat_socket_factory_ = std::make_unique(fss_.get()); diff --git a/rtc_base/nat_server.cc b/rtc_base/nat_server.cc index b818685efb..c274cedf18 100644 --- a/rtc_base/nat_server.cc +++ b/rtc_base/nat_server.cc @@ -10,12 +10,15 @@ #include "rtc_base/nat_server.h" +#include #include #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/nat_socket_factory.h" +#include "rtc_base/network/received_packet.h" #include "rtc_base/socket_adapters.h" +#include "rtc_base/socket_address.h" namespace rtc { @@ -125,17 +128,27 @@ class NATProxyServer : public ProxyServer { }; NATServer::NATServer(NATType type, + rtc::Thread& internal_socket_thread, SocketFactory* internal, const SocketAddress& internal_udp_addr, const SocketAddress& internal_tcp_addr, + rtc::Thread& external_socket_thread, SocketFactory* external, const SocketAddress& external_ip) - : external_(external), external_ip_(external_ip.ipaddr(), 0) { + : internal_socket_thread_(internal_socket_thread), + external_socket_thread_(external_socket_thread), + external_(external), + external_ip_(external_ip.ipaddr(), 0) { nat_ = NAT::Create(type); - udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr); - udp_server_socket_->SignalReadPacket.connect(this, - &NATServer::OnInternalUDPPacket); + internal_socket_thread_.BlockingCall([&] { + udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr); + udp_server_socket_->RegisterReceivedPacketCallback( + [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { + OnInternalUDPPacket(socket, packet); + }); + }); + tcp_proxy_server_ = new NATProxyServer(internal, internal_tcp_addr, external, external_ip); @@ -156,10 +169,11 @@ NATServer::~NATServer() { } void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket, - const char* buf, - size_t size, - const SocketAddress& addr, - const int64_t& /* packet_time_us */) { + const rtc::ReceivedPacket& packet) { + RTC_DCHECK(internal_socket_thread_.IsCurrent()); + const char* buf = reinterpret_cast(packet.payload().data()); + size_t size = packet.payload().size(); + const SocketAddress& addr = packet.source_address(); // Read the intended destination from the wire. SocketAddress dest_addr; size_t length = UnpackAddressFromNAT(buf, size, &dest_addr); @@ -182,10 +196,8 @@ void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket, } void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket, - const char* buf, - size_t size, - const SocketAddress& remote_addr, - const int64_t& /* packet_time_us */) { + const rtc::ReceivedPacket& packet) { + RTC_DCHECK(external_socket_thread_.IsCurrent()); SocketAddress local_addr = socket->GetLocalAddress(); // Find the translation for this addresses. @@ -193,36 +205,46 @@ void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket, RTC_DCHECK(iter != ext_map_->end()); // Allow the NAT to reject this packet. - if (ShouldFilterOut(iter->second, remote_addr)) { - RTC_LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString() + if (ShouldFilterOut(iter->second, packet.source_address())) { + RTC_LOG(LS_INFO) << "Packet from " + << packet.source_address().ToSensitiveString() << " was filtered out by the NAT."; return; } // Forward this packet to the internal address. // First prepend the address in a quasi-STUN format. - std::unique_ptr real_buf(new char[size + kNATEncodedIPv6AddressSize]); + std::unique_ptr real_buf( + new char[packet.payload().size() + kNATEncodedIPv6AddressSize]); size_t addrlength = PackAddressForNAT( - real_buf.get(), size + kNATEncodedIPv6AddressSize, remote_addr); + real_buf.get(), packet.payload().size() + kNATEncodedIPv6AddressSize, + packet.source_address()); // Copy the data part after the address. rtc::PacketOptions options; - memcpy(real_buf.get() + addrlength, buf, size); - udp_server_socket_->SendTo(real_buf.get(), size + addrlength, + memcpy(real_buf.get() + addrlength, packet.payload().data(), + packet.payload().size()); + udp_server_socket_->SendTo(real_buf.get(), + packet.payload().size() + addrlength, iter->second->route.source(), options); } void NATServer::Translate(const SocketAddressPair& route) { - AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_); + external_socket_thread_.BlockingCall([&] { + AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_); - if (!socket) { - RTC_LOG(LS_ERROR) << "Couldn't find a free port!"; - return; - } + if (!socket) { + RTC_LOG(LS_ERROR) << "Couldn't find a free port!"; + return; + } - TransEntry* entry = new TransEntry(route, socket, nat_); - (*int_map_)[route] = entry; - (*ext_map_)[socket->GetLocalAddress()] = entry; - socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket); + TransEntry* entry = new TransEntry(route, socket, nat_); + (*int_map_)[route] = entry; + (*ext_map_)[socket->GetLocalAddress()] = entry; + socket->RegisterReceivedPacketCallback( + [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) { + OnExternalUDPPacket(socket, packet); + }); + }); } bool NATServer::ShouldFilterOut(TransEntry* entry, diff --git a/rtc_base/nat_server.h b/rtc_base/nat_server.h index acbd62a092..d179efa567 100644 --- a/rtc_base/nat_server.h +++ b/rtc_base/nat_server.h @@ -58,15 +58,17 @@ struct AddrCmp { const int NAT_SERVER_UDP_PORT = 4237; const int NAT_SERVER_TCP_PORT = 4238; -class NATServer : public sigslot::has_slots<> { +class NATServer { public: NATServer(NATType type, + rtc::Thread& internal_socket_thread, SocketFactory* internal, const SocketAddress& internal_udp_addr, const SocketAddress& internal_tcp_addr, + rtc::Thread& external_socket_thread, SocketFactory* external, const SocketAddress& external_ip); - ~NATServer() override; + ~NATServer(); NATServer(const NATServer&) = delete; NATServer& operator=(const NATServer&) = delete; @@ -81,15 +83,9 @@ class NATServer : public sigslot::has_slots<> { // Packets received on one of the networks. void OnInternalUDPPacket(AsyncPacketSocket* socket, - const char* buf, - size_t size, - const SocketAddress& addr, - const int64_t& packet_time_us); + const rtc::ReceivedPacket& packet); void OnExternalUDPPacket(AsyncPacketSocket* socket, - const char* buf, - size_t size, - const SocketAddress& remote_addr, - const int64_t& packet_time_us); + const rtc::ReceivedPacket& packet); private: typedef std::set AddressSet; @@ -118,6 +114,8 @@ class NATServer : public sigslot::has_slots<> { bool ShouldFilterOut(TransEntry* entry, const SocketAddress& ext_addr); NAT* nat_; + rtc::Thread& internal_socket_thread_; + rtc::Thread& external_socket_thread_; SocketFactory* external_; SocketAddress external_ip_; AsyncUDPSocket* udp_server_socket_; diff --git a/rtc_base/nat_socket_factory.cc b/rtc_base/nat_socket_factory.cc index fe021b95ff..83ec2bc327 100644 --- a/rtc_base/nat_socket_factory.cc +++ b/rtc_base/nat_socket_factory.cc @@ -368,7 +368,8 @@ NATSocketServer::Translator* NATSocketServer::AddTranslator( if (nats_.Get(ext_ip)) return nullptr; - return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); + return nats_.Add( + ext_ip, new Translator(this, type, int_ip, *msg_queue_, server_, ext_ip)); } void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) { @@ -413,6 +414,7 @@ Socket* NATSocketServer::CreateInternalSocket(int family, NATSocketServer::Translator::Translator(NATSocketServer* server, NATType type, const SocketAddress& int_ip, + Thread& external_socket_thread, SocketFactory* ext_factory, const SocketAddress& ext_ip) : server_(server) { @@ -422,7 +424,8 @@ NATSocketServer::Translator::Translator(NATSocketServer* server, internal_server_ = std::make_unique(); internal_server_->SetMessageQueue(server_->queue()); nat_server_ = std::make_unique( - type, internal_server_.get(), int_ip, int_ip, ext_factory, ext_ip); + type, *server->queue(), internal_server_.get(), int_ip, int_ip, + external_socket_thread, ext_factory, ext_ip); } NATSocketServer::Translator::~Translator() { @@ -443,8 +446,8 @@ NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator( return nullptr; AddClient(ext_ip); - return nats_.Add(ext_ip, - new Translator(server_, type, int_ip, server_, ext_ip)); + return nats_.Add(ext_ip, new Translator(server_, type, int_ip, + *server_->queue(), server_, ext_ip)); } void NATSocketServer::Translator::RemoveTranslator( const SocketAddress& ext_ip) { diff --git a/rtc_base/nat_socket_factory.h b/rtc_base/nat_socket_factory.h index 0b301b5844..f803496b05 100644 --- a/rtc_base/nat_socket_factory.h +++ b/rtc_base/nat_socket_factory.h @@ -102,6 +102,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory { Translator(NATSocketServer* server, NATType type, const SocketAddress& int_addr, + Thread& external_socket_thread, SocketFactory* ext_factory, const SocketAddress& ext_addr); ~Translator(); diff --git a/rtc_base/nat_unittest.cc b/rtc_base/nat_unittest.cc index 432985d283..742e0d6ee7 100644 --- a/rtc_base/nat_unittest.cc +++ b/rtc_base/nat_unittest.cc @@ -76,16 +76,17 @@ void TestSend(SocketServer* internal, Thread th_int(internal); Thread th_ext(external); - SocketAddress server_addr = internal_addr; - server_addr.SetPort(0); // Auto-select a port - NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr, - external, external_addrs[0]); - NATSocketFactory* natsf = new NATSocketFactory( - internal, nat->internal_udp_address(), nat->internal_tcp_address()); - th_int.Start(); th_ext.Start(); + SocketAddress server_addr = internal_addr; + server_addr.SetPort(0); // Auto-select a port + NATServer* nat = + new NATServer(nat_type, th_int, internal, server_addr, server_addr, + th_ext, external, external_addrs[0]); + NATSocketFactory* natsf = new NATSocketFactory( + internal, nat->internal_udp_address(), nat->internal_tcp_address()); + TestClient* in; th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); }); @@ -139,13 +140,13 @@ void TestRecv(SocketServer* internal, SocketAddress server_addr = internal_addr; server_addr.SetPort(0); // Auto-select a port - NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr, - external, external_addrs[0]); - NATSocketFactory* natsf = new NATSocketFactory( - internal, nat->internal_udp_address(), nat->internal_tcp_address()); - th_int.Start(); th_ext.Start(); + NATServer* nat = + new NATServer(nat_type, th_int, internal, server_addr, server_addr, + th_ext, external, external_addrs[0]); + NATSocketFactory* natsf = new NATSocketFactory( + internal, nat->internal_udp_address(), nat->internal_tcp_address()); TestClient* in = nullptr; th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); }); @@ -355,9 +356,11 @@ class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> { int_thread_(new Thread(int_vss_.get())), ext_thread_(new Thread(ext_vss_.get())), nat_(new NATServer(NAT_OPEN_CONE, + *int_thread_, int_vss_.get(), int_addr_, int_addr_, + *ext_thread_, ext_vss_.get(), ext_addr_)), natsf_(new NATSocketFactory(int_vss_.get(),