diff --git a/main.go b/main.go index 8e81054..139da58 100644 --- a/main.go +++ b/main.go @@ -347,7 +347,7 @@ func main() { if reachableAt != "" { logger.Info("Sending reachableAt to server: %s", reachableAt) // Create WireGuard service - wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, endpoint, id, client) if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } diff --git a/wg/wg.go b/wg/wg.go index 4f388c1..9b3a137 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -11,6 +11,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/websocket" "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -46,11 +47,31 @@ type WireGuardService struct { config WgConfig key wgtypes.Key reachableAt string + newtId string lastReadings map[string]PeerReading mu sync.Mutex + port uint16 } -func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { +// Add this type definition +type fixedPortBind struct { + port uint16 + conn.Bind +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + +func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, endpoint string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { wgClient, err := wgctrl.New() if err != nil { return nil, fmt.Errorf("failed to create WireGuard client: %v", err) @@ -87,7 +108,14 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene wgClient: wgClient, key: key, reachableAt: reachableAt, + newtId: newtId, lastReadings: make(map[string]PeerReading), + port: 21821, + } + + if err := service.sendUDPHolePunch(endpoint + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + // Continue anyway as this is just for NAT traversal } // Register websocket handlers @@ -185,12 +213,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { return fmt.Errorf("failed to parse private key: %v", err) } - // Create a new WireGuard configuration config := wgtypes.Config{ PrivateKey: &key, ListenPort: new(int), } - *config.ListenPort = wgconfig.ListenPort + + // Use the service's fixed port instead of the config port + *config.ListenPort = int(s.port) // Create and configure the WireGuard interface err = s.wgClient.ConfigureDevice(s.interfaceName, config) @@ -591,3 +620,40 @@ func (s *WireGuardService) reportPeerBandwidth() error { return nil } + +func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { + // Bind to specific local port + localAddr := &net.UDPAddr{ + Port: int(s.port), + IP: net.IPv4zero, + } + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %v", err) + } + + conn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return fmt.Errorf("failed to bind UDP socket: %v", err) + } + defer conn.Close() + + payload := struct { + NewtID string `json:"newtId"` + }{ + NewtID: s.newtId, + } + + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + _, err = conn.WriteToUDP(data, remoteAddr) + if err != nil { + return fmt.Errorf("failed to send UDP packet: %v", err) + } + + return nil +}