Refactor NatServer to use rtc::ReceivedPackets

Instead of using raw pointers.
Also, ensure callbacks are registered on the correct thread.
Nat servers are test only code.

Bug: webrtc:11943
Change-Id: Ib70a5966acb512f1a07212a07aaedab70aa20f9b
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/331260
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@google.com>
Cr-Commit-Position: refs/heads/main@{#41372}
This commit is contained in:
Per K 2023-12-13 08:10:06 +01:00 committed by WebRTC LUCI CQ
parent e6df126b79
commit a8cd2babcd
7 changed files with 84 additions and 57 deletions

View file

@ -620,8 +620,8 @@ class PortTest : public ::testing::Test, public sigslot::has_slots<> {
std::unique_ptr<rtc::NATServer> CreateNatServer(const SocketAddress& addr,
rtc::NATType type) {
return std::make_unique<rtc::NATServer>(type, ss_.get(), addr, addr,
ss_.get(), addr);
return std::make_unique<rtc::NATServer>(type, main_, ss_.get(), addr, addr,
main_, ss_.get(), addr);
}
static const char* StunName(NATType type) {
switch (type) {

View file

@ -496,8 +496,8 @@ class BasicPortAllocatorTestBase : public ::testing::Test,
bool with_nat) {
if (with_nat) {
nat_server_.reset(new rtc::NATServer(
rtc::NAT_OPEN_CONE, vss_.get(), kNatUdpAddr, kNatTcpAddr, vss_.get(),
rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0)));
rtc::NAT_OPEN_CONE, thread_, vss_.get(), kNatUdpAddr, kNatTcpAddr,
thread_, vss_.get(), rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0)));
} else {
nat_socket_factory_ =
std::make_unique<rtc::BasicPacketSocketFactory>(fss_.get());

View file

@ -10,12 +10,15 @@
#include "rtc_base/nat_server.h"
#include <cstddef>
#include <memory>
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/nat_socket_factory.h"
#include "rtc_base/network/received_packet.h"
#include "rtc_base/socket_adapters.h"
#include "rtc_base/socket_address.h"
namespace rtc {
@ -125,17 +128,27 @@ class NATProxyServer : public ProxyServer {
};
NATServer::NATServer(NATType type,
rtc::Thread& internal_socket_thread,
SocketFactory* internal,
const SocketAddress& internal_udp_addr,
const SocketAddress& internal_tcp_addr,
rtc::Thread& external_socket_thread,
SocketFactory* external,
const SocketAddress& external_ip)
: external_(external), external_ip_(external_ip.ipaddr(), 0) {
: internal_socket_thread_(internal_socket_thread),
external_socket_thread_(external_socket_thread),
external_(external),
external_ip_(external_ip.ipaddr(), 0) {
nat_ = NAT::Create(type);
udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
udp_server_socket_->SignalReadPacket.connect(this,
&NATServer::OnInternalUDPPacket);
internal_socket_thread_.BlockingCall([&] {
udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
udp_server_socket_->RegisterReceivedPacketCallback(
[&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) {
OnInternalUDPPacket(socket, packet);
});
});
tcp_proxy_server_ =
new NATProxyServer(internal, internal_tcp_addr, external, external_ip);
@ -156,10 +169,11 @@ NATServer::~NATServer() {
}
void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
const char* buf,
size_t size,
const SocketAddress& addr,
const int64_t& /* packet_time_us */) {
const rtc::ReceivedPacket& packet) {
RTC_DCHECK(internal_socket_thread_.IsCurrent());
const char* buf = reinterpret_cast<const char*>(packet.payload().data());
size_t size = packet.payload().size();
const SocketAddress& addr = packet.source_address();
// Read the intended destination from the wire.
SocketAddress dest_addr;
size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
@ -182,10 +196,8 @@ void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
}
void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket,
const char* buf,
size_t size,
const SocketAddress& remote_addr,
const int64_t& /* packet_time_us */) {
const rtc::ReceivedPacket& packet) {
RTC_DCHECK(external_socket_thread_.IsCurrent());
SocketAddress local_addr = socket->GetLocalAddress();
// Find the translation for this addresses.
@ -193,36 +205,46 @@ void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket,
RTC_DCHECK(iter != ext_map_->end());
// Allow the NAT to reject this packet.
if (ShouldFilterOut(iter->second, remote_addr)) {
RTC_LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
if (ShouldFilterOut(iter->second, packet.source_address())) {
RTC_LOG(LS_INFO) << "Packet from "
<< packet.source_address().ToSensitiveString()
<< " was filtered out by the NAT.";
return;
}
// Forward this packet to the internal address.
// First prepend the address in a quasi-STUN format.
std::unique_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
std::unique_ptr<char[]> real_buf(
new char[packet.payload().size() + kNATEncodedIPv6AddressSize]);
size_t addrlength = PackAddressForNAT(
real_buf.get(), size + kNATEncodedIPv6AddressSize, remote_addr);
real_buf.get(), packet.payload().size() + kNATEncodedIPv6AddressSize,
packet.source_address());
// Copy the data part after the address.
rtc::PacketOptions options;
memcpy(real_buf.get() + addrlength, buf, size);
udp_server_socket_->SendTo(real_buf.get(), size + addrlength,
memcpy(real_buf.get() + addrlength, packet.payload().data(),
packet.payload().size());
udp_server_socket_->SendTo(real_buf.get(),
packet.payload().size() + addrlength,
iter->second->route.source(), options);
}
void NATServer::Translate(const SocketAddressPair& route) {
AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
external_socket_thread_.BlockingCall([&] {
AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
if (!socket) {
RTC_LOG(LS_ERROR) << "Couldn't find a free port!";
return;
}
if (!socket) {
RTC_LOG(LS_ERROR) << "Couldn't find a free port!";
return;
}
TransEntry* entry = new TransEntry(route, socket, nat_);
(*int_map_)[route] = entry;
(*ext_map_)[socket->GetLocalAddress()] = entry;
socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket);
TransEntry* entry = new TransEntry(route, socket, nat_);
(*int_map_)[route] = entry;
(*ext_map_)[socket->GetLocalAddress()] = entry;
socket->RegisterReceivedPacketCallback(
[&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) {
OnExternalUDPPacket(socket, packet);
});
});
}
bool NATServer::ShouldFilterOut(TransEntry* entry,

View file

@ -58,15 +58,17 @@ struct AddrCmp {
const int NAT_SERVER_UDP_PORT = 4237;
const int NAT_SERVER_TCP_PORT = 4238;
class NATServer : public sigslot::has_slots<> {
class NATServer {
public:
NATServer(NATType type,
rtc::Thread& internal_socket_thread,
SocketFactory* internal,
const SocketAddress& internal_udp_addr,
const SocketAddress& internal_tcp_addr,
rtc::Thread& external_socket_thread,
SocketFactory* external,
const SocketAddress& external_ip);
~NATServer() override;
~NATServer();
NATServer(const NATServer&) = delete;
NATServer& operator=(const NATServer&) = delete;
@ -81,15 +83,9 @@ class NATServer : public sigslot::has_slots<> {
// Packets received on one of the networks.
void OnInternalUDPPacket(AsyncPacketSocket* socket,
const char* buf,
size_t size,
const SocketAddress& addr,
const int64_t& packet_time_us);
const rtc::ReceivedPacket& packet);
void OnExternalUDPPacket(AsyncPacketSocket* socket,
const char* buf,
size_t size,
const SocketAddress& remote_addr,
const int64_t& packet_time_us);
const rtc::ReceivedPacket& packet);
private:
typedef std::set<SocketAddress, AddrCmp> AddressSet;
@ -118,6 +114,8 @@ class NATServer : public sigslot::has_slots<> {
bool ShouldFilterOut(TransEntry* entry, const SocketAddress& ext_addr);
NAT* nat_;
rtc::Thread& internal_socket_thread_;
rtc::Thread& external_socket_thread_;
SocketFactory* external_;
SocketAddress external_ip_;
AsyncUDPSocket* udp_server_socket_;

View file

@ -368,7 +368,8 @@ NATSocketServer::Translator* NATSocketServer::AddTranslator(
if (nats_.Get(ext_ip))
return nullptr;
return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
return nats_.Add(
ext_ip, new Translator(this, type, int_ip, *msg_queue_, server_, ext_ip));
}
void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
@ -413,6 +414,7 @@ Socket* NATSocketServer::CreateInternalSocket(int family,
NATSocketServer::Translator::Translator(NATSocketServer* server,
NATType type,
const SocketAddress& int_ip,
Thread& external_socket_thread,
SocketFactory* ext_factory,
const SocketAddress& ext_ip)
: server_(server) {
@ -422,7 +424,8 @@ NATSocketServer::Translator::Translator(NATSocketServer* server,
internal_server_ = std::make_unique<VirtualSocketServer>();
internal_server_->SetMessageQueue(server_->queue());
nat_server_ = std::make_unique<NATServer>(
type, internal_server_.get(), int_ip, int_ip, ext_factory, ext_ip);
type, *server->queue(), internal_server_.get(), int_ip, int_ip,
external_socket_thread, ext_factory, ext_ip);
}
NATSocketServer::Translator::~Translator() {
@ -443,8 +446,8 @@ NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
return nullptr;
AddClient(ext_ip);
return nats_.Add(ext_ip,
new Translator(server_, type, int_ip, server_, ext_ip));
return nats_.Add(ext_ip, new Translator(server_, type, int_ip,
*server_->queue(), server_, ext_ip));
}
void NATSocketServer::Translator::RemoveTranslator(
const SocketAddress& ext_ip) {

View file

@ -102,6 +102,7 @@ class NATSocketServer : public SocketServer, public NATInternalSocketFactory {
Translator(NATSocketServer* server,
NATType type,
const SocketAddress& int_addr,
Thread& external_socket_thread,
SocketFactory* ext_factory,
const SocketAddress& ext_addr);
~Translator();

View file

@ -76,16 +76,17 @@ void TestSend(SocketServer* internal,
Thread th_int(internal);
Thread th_ext(external);
SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port
NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
th_int.Start();
th_ext.Start();
SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port
NATServer* nat =
new NATServer(nat_type, th_int, internal, server_addr, server_addr,
th_ext, external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
TestClient* in;
th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); });
@ -139,13 +140,13 @@ void TestRecv(SocketServer* internal,
SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port
NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
th_int.Start();
th_ext.Start();
NATServer* nat =
new NATServer(nat_type, th_int, internal, server_addr, server_addr,
th_ext, external, external_addrs[0]);
NATSocketFactory* natsf = new NATSocketFactory(
internal, nat->internal_udp_address(), nat->internal_tcp_address());
TestClient* in = nullptr;
th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); });
@ -355,9 +356,11 @@ class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> {
int_thread_(new Thread(int_vss_.get())),
ext_thread_(new Thread(ext_vss_.get())),
nat_(new NATServer(NAT_OPEN_CONE,
*int_thread_,
int_vss_.get(),
int_addr_,
int_addr_,
*ext_thread_,
ext_vss_.get(),
ext_addr_)),
natsf_(new NATSocketFactory(int_vss_.get(),