From 175718a48e4bbb58b9499d0b872d4fb0c63c8f2b Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 22 Apr 2025 22:11:37 -0400 Subject: [PATCH] Handle order of opertions of hole punch better --- main.go | 2 +- wg/wg.go | 103 +++++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 74 insertions(+), 31 deletions(-) diff --git a/main.go b/main.go index 97c9c33..d9a3cef 100644 --- a/main.go +++ b/main.go @@ -403,7 +403,7 @@ func main() { if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key") } - flag.BoolVar(&rm, "rm", true, "Remove the WireGuard interface") + flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface") flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface") // do a --version check diff --git a/wg/wg.go b/wg/wg.go index 20cb9cd..6e17880 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -60,6 +60,7 @@ type WireGuardService struct { host string serverPubKey string token string + stopGetConfig chan struct{} } // Add this type definition @@ -168,12 +169,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str } } - port, err := FindAvailableUDPPort(49152, 65535) - if err != nil { - fmt.Printf("Error finding available port: %v\n", err) - return nil, err - } - service := &WireGuardService{ interfaceName: interfaceName, mtu: mtu, @@ -181,10 +176,23 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str wgClient: wgClient, key: key, newtId: newtId, - lastReadings: make(map[string]PeerReading), - Port: port, - stopHolepunch: make(chan struct{}), host: host, + lastReadings: make(map[string]PeerReading), + stopHolepunch: make(chan struct{}), + stopGetConfig: make(chan struct{}), + } + + // Get the existing wireguard port (keep this part) + device, err := service.wgClient.Device(service.interfaceName) + if err == nil { + service.Port = uint16(device.ListenPort) + logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port) + } else { + service.Port, err = FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + return nil, err + } } // Register websocket handlers @@ -193,16 +201,35 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) + if err := service.sendUDPHolePunch(service.host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + + // start the UDP holepunch + go service.keepSendingUDPHolePunch(service.host) + return service, nil } func (s *WireGuardService) Close(rm bool) { + select { + case <-s.stopGetConfig: + // Already closed, do nothing + default: + close(s.stopGetConfig) + } + s.wgClient.Close() // Remove the WireGuard interface if rm { if err := s.removeInterface(); err != nil { logger.Error("Failed to remove WireGuard interface: %v", err) } + + // Remove the private key file + if err := os.Remove(s.key.String()); err != nil { + logger.Error("Failed to remove private key file: %v", err) + } } } @@ -215,24 +242,15 @@ func (s *WireGuardService) SetToken(token string) { } func (s *WireGuardService) LoadRemoteConfig() error { - - // get the exising wireguard port - device, err := s.wgClient.Device(s.interfaceName) - if err == nil { - s.Port = uint16(device.ListenPort) - logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port) - } - - err = s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ - "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), - "port": s.Port, - }) + // Send the initial message + err := s.sendGetConfigMessage() if err != nil { - logger.Error("Failed to send registration message: %v", err) + logger.Error("Failed to send initial get-config message: %v", err) return err } - logger.Info("Requesting WireGuard configuration from remote server") + // Start goroutine to periodically send the message until config is received + go s.keepSendingGetConfig() go s.periodicBandwidthCheck() @@ -256,6 +274,8 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { } s.config = config + close(s.stopGetConfig) + // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { logger.Error("Failed to ensure WireGuard interface: %v", err) @@ -264,13 +284,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { if err := s.ensureWireguardPeers(config.Peers); err != nil { logger.Error("Failed to ensure WireGuard peers: %v", err) } - - if err := s.sendUDPHolePunch(s.host + ":21820"); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - - // start the UDP holepunch - go s.keepSendingUDPHolePunch(s.host) } func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { @@ -932,3 +945,33 @@ func (s *WireGuardService) removeInterface() error { return nil } + +func (s *WireGuardService) sendGetConfigMessage() error { + err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + "port": s.Port, + }) + if err != nil { + logger.Error("Failed to send get-config message: %v", err) + return err + } + logger.Info("Requesting WireGuard configuration from remote server") + return nil +} + +func (s *WireGuardService) keepSendingGetConfig() { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.stopGetConfig: + logger.Info("Stopping get-config messages") + return + case <-ticker.C: + if err := s.sendGetConfigMessage(); err != nil { + logger.Error("Failed to send periodic get-config: %v", err) + } + } + } +}