Refactor NatServer to use rtc::ReceivedPackets

Instead of using raw pointers.
Also, ensure callbacks are registered on the correct thread.
Nat servers are test only code.

Bug: webrtc:11943
Change-Id: Ib70a5966acb512f1a07212a07aaedab70aa20f9b
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/331260
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@google.com>
Cr-Commit-Position: refs/heads/main@{#41372}
This commit is contained in:
Per K 2023-12-13 08:10:06 +01:00 committed by WebRTC LUCI CQ
parent e6df126b79
commit a8cd2babcd
7 changed files with 84 additions and 57 deletions

View file

@ -620,8 +620,8 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> {
std::unique_ptr<rtc::NATServer> CreateNatServer(const SocketAddress& addr, std::unique_ptr<rtc::NATServer> CreateNatServer(const SocketAddress& addr,
rtc::NATType type) { rtc::NATType type) {
return std::make_unique<rtc::NATServer>(type, ss_.get(), addr, addr, return std::make_unique<rtc::NATServer>(type, main_, ss_.get(), addr, addr,
ss_.get(), addr); main_, ss_.get(), addr);
} }
static const char* StunName(NATType type) { static const char* StunName(NATType type) {
switch (type) { switch (type) {

View file

@ -496,8 +496,8 @@ class BasicPortAllocatorTestBase : public ::testing::Test,
bool with_nat) { bool with_nat) {
if (with_nat) { if (with_nat) {
nat_server_.reset(new rtc::NATServer( nat_server_.reset(new rtc::NATServer(
rtc::NAT_OPEN_CONE, vss_.get(), kNatUdpAddr, kNatTcpAddr, vss_.get(), rtc::NAT_OPEN_CONE, thread_, vss_.get(), kNatUdpAddr, kNatTcpAddr,
rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0))); thread_, vss_.get(), rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0)));
} else { } else {
nat_socket_factory_ = nat_socket_factory_ =
std::make_unique<rtc::BasicPacketSocketFactory>(fss_.get()); std::make_unique<rtc::BasicPacketSocketFactory>(fss_.get());

View file

@ -10,12 +10,15 @@
#include "rtc_base/nat_server.h" #include "rtc_base/nat_server.h"
#include <cstddef>
#include <memory> #include <memory>
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/nat_socket_factory.h" #include "rtc_base/nat_socket_factory.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/socket_adapters.h" #include "rtc_base/socket_adapters.h"
#include "rtc_base/socket_address.h"
namespace rtc { namespace rtc {
@ -125,17 +128,27 @@ class NATProxyServer : public ProxyServer {
}; };
NATServer::NATServer(NATType type, NATServer::NATServer(NATType type,
rtc::Thread& internal_socket_thread,
SocketFactory* internal, SocketFactory* internal,
const SocketAddress& internal_udp_addr, const SocketAddress& internal_udp_addr,
const SocketAddress& internal_tcp_addr, const SocketAddress& internal_tcp_addr,
rtc::Thread& external_socket_thread,
SocketFactory* external, SocketFactory* external,
const SocketAddress& external_ip) const SocketAddress& external_ip)
: external_(external), external_ip_(external_ip.ipaddr(), 0) { : internal_socket_thread_(internal_socket_thread),
external_socket_thread_(external_socket_thread),
external_(external),
external_ip_(external_ip.ipaddr(), 0) {
nat_ = NAT::Create(type); nat_ = NAT::Create(type);
internal_socket_thread_.BlockingCall([&] {
udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr); udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
udp_server_socket_->SignalReadPacket.connect(this, udp_server_socket_->RegisterReceivedPacketCallback(
&NATServer::OnInternalUDPPacket); [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) {
OnInternalUDPPacket(socket, packet);
});
});
tcp_proxy_server_ = tcp_proxy_server_ =
new NATProxyServer(internal, internal_tcp_addr, external, external_ip); new NATProxyServer(internal, internal_tcp_addr, external, external_ip);
@ -156,10 +169,11 @@ NATServer::~NATServer() {
} }
void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket, void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
const char* buf, const rtc::ReceivedPacket& packet) {
size_t size, RTC_DCHECK(internal_socket_thread_.IsCurrent());
const SocketAddress& addr, const char* buf = reinterpret_cast<const char*>(packet.payload().data());
const int64_t& /* packet_time_us */) { size_t size = packet.payload().size();
const SocketAddress& addr = packet.source_address();
// Read the intended destination from the wire. // Read the intended destination from the wire.
SocketAddress dest_addr; SocketAddress dest_addr;
size_t length = UnpackAddressFromNAT(buf, size, &dest_addr); size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
@ -182,10 +196,8 @@ void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
} }
void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket, void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket,
const char* buf, const rtc::ReceivedPacket& packet) {
size_t size, RTC_DCHECK(external_socket_thread_.IsCurrent());
const SocketAddress& remote_addr,
const int64_t& /* packet_time_us */) {
SocketAddress local_addr = socket->GetLocalAddress(); SocketAddress local_addr = socket->GetLocalAddress();
// Find the translation for this addresses. // Find the translation for this addresses.
@ -193,25 +205,31 @@ void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket,
RTC_DCHECK(iter != ext_map_->end()); RTC_DCHECK(iter != ext_map_->end());
// Allow the NAT to reject this packet. // Allow the NAT to reject this packet.
if (ShouldFilterOut(iter->second, remote_addr)) { if (ShouldFilterOut(iter->second, packet.source_address())) {
RTC_LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString() RTC_LOG(LS_INFO) << "Packet from "
<< packet.source_address().ToSensitiveString()
<< " was filtered out by the NAT."; << " was filtered out by the NAT.";
return; return;
} }
// Forward this packet to the internal address. // Forward this packet to the internal address.
// First prepend the address in a quasi-STUN format. // First prepend the address in a quasi-STUN format.
std::unique_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]); std::unique_ptr<char[]> real_buf(
new char[packet.payload().size() + kNATEncodedIPv6AddressSize]);
size_t addrlength = PackAddressForNAT( size_t addrlength = PackAddressForNAT(
real_buf.get(), size + kNATEncodedIPv6AddressSize, remote_addr); real_buf.get(), packet.payload().size() + kNATEncodedIPv6AddressSize,
packet.source_address());
// Copy the data part after the address. // Copy the data part after the address.
rtc::PacketOptions options; rtc::PacketOptions options;
memcpy(real_buf.get() + addrlength, buf, size); memcpy(real_buf.get() + addrlength, packet.payload().data(),
udp_server_socket_->SendTo(real_buf.get(), size + addrlength, packet.payload().size());
udp_server_socket_->SendTo(real_buf.get(),
packet.payload().size() + addrlength,
iter->second->route.source(), options); iter->second->route.source(), options);
} }
void NATServer::Translate(const SocketAddressPair& route) { void NATServer::Translate(const SocketAddressPair& route) {
external_socket_thread_.BlockingCall([&] {
AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_); AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
if (!socket) { if (!socket) {
@ -222,7 +240,11 @@ void NATServer::Translate(const SocketAddressPair& route) {
TransEntry* entry = new TransEntry(route, socket, nat_); TransEntry* entry = new TransEntry(route, socket, nat_);
(*int_map_)[route] = entry; (*int_map_)[route] = entry;
(*ext_map_)[socket->GetLocalAddress()] = entry; (*ext_map_)[socket->GetLocalAddress()] = entry;
socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket); socket->RegisterReceivedPacketCallback(
[&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) {
OnExternalUDPPacket(socket, packet);
});
});
} }
bool NATServer::ShouldFilterOut(TransEntry* entry, bool NATServer::ShouldFilterOut(TransEntry* entry,

View file

@ -58,15 +58,17 @@ struct AddrCmp {
const int NAT_SERVER_UDP_PORT = 4237; const int NAT_SERVER_UDP_PORT = 4237;
const int NAT_SERVER_TCP_PORT = 4238; const int NAT_SERVER_TCP_PORT = 4238;
class NATServer : public sigslot::has_slots<> { class NATServer {
public: public:
NATServer(NATType type, NATServer(NATType type,
rtc::Thread& internal_socket_thread,
SocketFactory* internal, SocketFactory* internal,
const SocketAddress& internal_udp_addr, const SocketAddress& internal_udp_addr,
const SocketAddress& internal_tcp_addr, const SocketAddress& internal_tcp_addr,
rtc::Thread& external_socket_thread,
SocketFactory* external, SocketFactory* external,
const SocketAddress& external_ip); const SocketAddress& external_ip);
~NATServer() override; ~NATServer();
NATServer(const NATServer&) = delete; NATServer(const NATServer&) = delete;
NATServer& operator=(const NATServer&) = delete; NATServer& operator=(const NATServer&) = delete;
@ -81,15 +83,9 @@ class NATServer : public sigslot::has_slots<> {
// Packets received on one of the networks. // Packets received on one of the networks.
void OnInternalUDPPacket(AsyncPacketSocket* socket, void OnInternalUDPPacket(AsyncPacketSocket* socket,
const char* buf, const rtc::ReceivedPacket& packet);
size_t size,
const SocketAddress& addr,
const int64_t& packet_time_us);
void OnExternalUDPPacket(AsyncPacketSocket* socket, void OnExternalUDPPacket(AsyncPacketSocket* socket,
const char* buf, const rtc::ReceivedPacket& packet);
size_t size,
const SocketAddress& remote_addr,
const int64_t& packet_time_us);
private: private:
typedef std::set<SocketAddress, AddrCmp> AddressSet; typedef std::set<SocketAddress, AddrCmp> AddressSet;
@ -118,6 +114,8 @@ class NATServer : public sigslot::has_slots<> {
bool ShouldFilterOut(TransEntry* entry, const SocketAddress& ext_addr); bool ShouldFilterOut(TransEntry* entry, const SocketAddress& ext_addr);
NAT* nat_; NAT* nat_;
rtc::Thread& internal_socket_thread_;
rtc::Thread& external_socket_thread_;
SocketFactory* external_; SocketFactory* external_;
SocketAddress external_ip_; SocketAddress external_ip_;
AsyncUDPSocket* udp_server_socket_; AsyncUDPSocket* udp_server_socket_;

View file

@ -368,7 +368,8 @@ NATSocketServer::Translator* NATSocketServer::AddTranslator(
if (nats_.Get(ext_ip)) if (nats_.Get(ext_ip))
return nullptr; return nullptr;
return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); return nats_.Add(
ext_ip, new Translator(this, type, int_ip, *msg_queue_, server_, ext_ip));
} }
void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) { void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
@ -413,6 +414,7 @@ Socket* NATSocketServer::CreateInternalSocket(int family,
NATSocketServer::Translator::Translator(NATSocketServer* server, NATSocketServer::Translator::Translator(NATSocketServer* server,
NATType type, NATType type,
const SocketAddress& int_ip, const SocketAddress& int_ip,
Thread& external_socket_thread,
SocketFactory* ext_factory, SocketFactory* ext_factory,
const SocketAddress& ext_ip) const SocketAddress& ext_ip)
: server_(server) { : server_(server) {
@ -422,7 +424,8 @@ NATSocketServer::Translator::Translator(NATSocketServer* server,
internal_server_ = std::make_unique<VirtualSocketServer>(); internal_server_ = std::make_unique<VirtualSocketServer>();
internal_server_->SetMessageQueue(server_->queue()); internal_server_->SetMessageQueue(server_->queue());
nat_server_ = std::make_unique<NATServer>( nat_server_ = std::make_unique<NATServer>(
type, internal_server_.get(), int_ip, int_ip, ext_factory, ext_ip); type, *server->queue(), internal_server_.get(), int_ip, int_ip,
external_socket_thread, ext_factory, ext_ip);
} }
NATSocketServer::Translator::~Translator() { NATSocketServer::Translator::~Translator() {
@ -443,8 +446,8 @@ NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
return nullptr; return nullptr;
AddClient(ext_ip); AddClient(ext_ip);
return nats_.Add(ext_ip, return nats_.Add(ext_ip, new Translator(server_, type, int_ip,
new Translator(server_, type, int_ip, server_, ext_ip)); *server_->queue(), server_, ext_ip));
} }
void NATSocketServer::Translator::RemoveTranslator( void NATSocketServer::Translator::RemoveTranslator(
const SocketAddress& ext_ip) { const SocketAddress& ext_ip) {

View file

@ -102,6 +102,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory {
Translator(NATSocketServer* server, Translator(NATSocketServer* server,
NATType type, NATType type,
const SocketAddress& int_addr, const SocketAddress& int_addr,
Thread& external_socket_thread,
SocketFactory* ext_factory, SocketFactory* ext_factory,
const SocketAddress& ext_addr); const SocketAddress& ext_addr);
~Translator(); ~Translator();

View file

@ -76,16 +76,17 @@ void TestSend(SocketServer* internal,
Thread th_int(internal); Thread th_int(internal);
Thread th_ext(external); Thread th_ext(external);
SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port
NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
th_int.Start(); th_int.Start();
th_ext.Start(); th_ext.Start();
SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port
NATServer* nat =
new NATServer(nat_type, th_int, internal, server_addr, server_addr,
th_ext, external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
TestClient* in; TestClient* in;
th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); }); th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); });
@ -139,13 +140,13 @@ void TestRecv(SocketServer* internal,
SocketAddress server_addr = internal_addr; SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port server_addr.SetPort(0); // Auto-select a port
NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
th_int.Start(); th_int.Start();
th_ext.Start(); th_ext.Start();
NATServer* nat =
new NATServer(nat_type, th_int, internal, server_addr, server_addr,
th_ext, external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
TestClient* in = nullptr; TestClient* in = nullptr;
th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); }); th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); });
@ -355,9 +356,11 @@ class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> {
int_thread_(new Thread(int_vss_.get())), int_thread_(new Thread(int_vss_.get())),
ext_thread_(new Thread(ext_vss_.get())), ext_thread_(new Thread(ext_vss_.get())),
nat_(new NATServer(NAT_OPEN_CONE, nat_(new NATServer(NAT_OPEN_CONE,
*int_thread_,
int_vss_.get(), int_vss_.get(),
int_addr_, int_addr_,
int_addr_, int_addr_,
*ext_thread_,
ext_vss_.get(), ext_vss_.get(),
ext_addr_)), ext_addr_)),
natsf_(new NATSocketFactory(int_vss_.get(), natsf_(new NATSocketFactory(int_vss_.get(),