Handle order of opertions of hole punch better

This commit is contained in:
Owen 2025-04-22 22:11:37 -04:00
parent 6a146ed371
commit 175718a48e
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
2 changed files with 74 additions and 31 deletions

View file

@ -403,7 +403,7 @@ func main() {
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key")
}
flag.BoolVar(&rm, "rm", true, "Remove the WireGuard interface")
flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface")
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
// do a --version check

103
wg/wg.go
View file

@ -60,6 +60,7 @@ type WireGuardService struct {
host string
serverPubKey string
token string
stopGetConfig chan struct{}
}
// Add this type definition
@ -168,12 +169,6 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
}
}
port, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
return nil, err
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
@ -181,10 +176,23 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
wgClient: wgClient,
key: key,
newtId: newtId,
lastReadings: make(map[string]PeerReading),
Port: port,
stopHolepunch: make(chan struct{}),
host: host,
lastReadings: make(map[string]PeerReading),
stopHolepunch: make(chan struct{}),
stopGetConfig: make(chan struct{}),
}
// Get the existing wireguard port (keep this part)
device, err := service.wgClient.Device(service.interfaceName)
if err == nil {
service.Port = uint16(device.ListenPort)
logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port)
} else {
service.Port, err = FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
return nil, err
}
}
// Register websocket handlers
@ -193,16 +201,35 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
if err := service.sendUDPHolePunch(service.host + ":21820"); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
// start the UDP holepunch
go service.keepSendingUDPHolePunch(service.host)
return service, nil
}
func (s *WireGuardService) Close(rm bool) {
select {
case <-s.stopGetConfig:
// Already closed, do nothing
default:
close(s.stopGetConfig)
}
s.wgClient.Close()
// Remove the WireGuard interface
if rm {
if err := s.removeInterface(); err != nil {
logger.Error("Failed to remove WireGuard interface: %v", err)
}
// Remove the private key file
if err := os.Remove(s.key.String()); err != nil {
logger.Error("Failed to remove private key file: %v", err)
}
}
}
@ -215,24 +242,15 @@ func (s *WireGuardService) SetToken(token string) {
}
func (s *WireGuardService) LoadRemoteConfig() error {
// get the exising wireguard port
device, err := s.wgClient.Device(s.interfaceName)
if err == nil {
s.Port = uint16(device.ListenPort)
logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port)
}
err = s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
"port": s.Port,
})
// Send the initial message
err := s.sendGetConfigMessage()
if err != nil {
logger.Error("Failed to send registration message: %v", err)
logger.Error("Failed to send initial get-config message: %v", err)
return err
}
logger.Info("Requesting WireGuard configuration from remote server")
// Start goroutine to periodically send the message until config is received
go s.keepSendingGetConfig()
go s.periodicBandwidthCheck()
@ -256,6 +274,8 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
}
s.config = config
close(s.stopGetConfig)
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
@ -264,13 +284,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
if err := s.sendUDPHolePunch(s.host + ":21820"); err != nil {
logger.Error("Failed to send UDP hole punch: %v", err)
}
// start the UDP holepunch
go s.keepSendingUDPHolePunch(s.host)
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
@ -932,3 +945,33 @@ func (s *WireGuardService) removeInterface() error {
return nil
}
func (s *WireGuardService) sendGetConfigMessage() error {
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
"port": s.Port,
})
if err != nil {
logger.Error("Failed to send get-config message: %v", err)
return err
}
logger.Info("Requesting WireGuard configuration from remote server")
return nil
}
func (s *WireGuardService) keepSendingGetConfig() {
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.stopGetConfig:
logger.Info("Stopping get-config messages")
return
case <-ticker.C:
if err := s.sendGetConfigMessage(); err != nil {
logger.Error("Failed to send periodic get-config: %v", err)
}
}
}
}