diff --git a/main.go b/main.go index f9c0fe4..f912014 100644 --- a/main.go +++ b/main.go @@ -364,13 +364,6 @@ func parseTargetData(data interface{}) (TargetData, error) { } func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - - // Stop the proxy manager before adding new targets - err := pm.Stop() - if err != nil { - logger.Error("Failed to stop proxy manager: %v", err) - } - for _, t := range targetData.Targets { // Split the first number off of the target with : separator and use as the port parts := strings.Split(t, ":") @@ -389,18 +382,34 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto if action == "add" { target := parts[1] + ":" + parts[2] - pm.RemoveTarget(proto, tunnelIP, port) // remove it first in case this is an update. we are kind of using the internal port as the "targetId" in the proxy + // Only remove the specific target if it exists + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + // Ignore "target not found" errors as this is expected for new targets + if !strings.Contains(err.Error(), "target not found") { + logger.Error("Failed to remove existing target: %v", err) + } + } + + // Add the new target pm.AddTarget(proto, tunnelIP, port, target) + + // Start just this target by calling Start() on the proxy manager + // The Start() function is idempotent and will only start new targets + err = pm.Start() + if err != nil { + logger.Error("Failed to start proxy manager after adding target: %v", err) + return err + } } else if action == "remove" { logger.Info("Removing target with port %d", port) - pm.RemoveTarget(proto, tunnelIP, port) + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + logger.Error("Failed to remove target: %v", err) + return err + } } } - err = pm.Start() - if err != nil { - logger.Error("Failed to start proxy manager: %v", err) - } - return nil } diff --git a/proxy/manager.go b/proxy/manager.go index 2f4975a..0077be9 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -7,6 +7,7 @@ import ( "net" "strings" "sync" + "time" "golang.zx2c4.com/wireguard/tun/netstack" ) @@ -27,6 +28,7 @@ func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target stri Port: port, Target: target, cancel: make(chan struct{}), + done: make(chan struct{}), } pm.targets = append(pm.targets, newTarget) @@ -45,23 +47,42 @@ func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error { if target.Listen == listen && target.Port == port && strings.ToLower(target.Protocol) == protocol { + // Signal the serving goroutine to stop - // close(target.cancel) + select { + case <-target.cancel: + // Channel is already closed, no need to close it again + default: + close(target.cancel) + } // Close the appropriate listener/connection based on protocol target.Lock() switch protocol { case "tcp": if target.listener != nil { - target.listener.Close() + select { + case <-target.cancel: + // Listener was already closed by Stop() + default: + target.listener.Close() + } } case "udp": if target.udpConn != nil { - target.udpConn.Close() + select { + case <-target.cancel: + // Connection was already closed by Stop() + default: + target.udpConn.Close() + } } } target.Unlock() + // Wait for the target to fully stop + <-target.done + // Remove the target from the slice pm.targets = append(pm.targets[:i], pm.targets[i+1:]...) return nil @@ -76,7 +97,16 @@ func (pm *ProxyManager) Start() error { defer pm.RUnlock() for i := range pm.targets { - target := &pm.targets[i] // Use pointer to modify the target in the slice + target := &pm.targets[i] + + // Skip already running targets + target.Lock() + if target.listener != nil || target.udpConn != nil { + target.Unlock() + continue + } + target.Unlock() + switch strings.ToLower(target.Protocol) { case "tcp": go pm.serveTCP(target) @@ -93,27 +123,36 @@ func (pm *ProxyManager) Stop() error { pm.Lock() defer pm.Unlock() + var wg sync.WaitGroup for i := range pm.targets { target := &pm.targets[i] - close(target.cancel) - target.Lock() - if target.listener != nil { - target.listener.Close() - } - if target.udpConn != nil { - target.udpConn.Close() - } - target.Unlock() + wg.Add(1) + go func(t *ProxyTarget) { + defer wg.Done() + close(t.cancel) + t.Lock() + if t.listener != nil { + t.listener.Close() + } + if t.udpConn != nil { + t.udpConn.Close() + } + t.Unlock() + // Wait for the target to fully stop + <-t.done + }(target) } + wg.Wait() return nil } func (pm *ProxyManager) serveTCP(target *ProxyTarget) { + defer close(target.done) // Signal that this target is fully stopped + listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ IP: net.ParseIP(target.Listen), Port: target.Port, }) - log.Printf("Listening on %s:%d", target.Listen, target.Port) if err != nil { log.Printf("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err) return @@ -126,14 +165,13 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) { defer listener.Close() log.Printf("TCP proxy listening on %s", listener.Addr()) - // Channel to signal active connections to close - done := make(chan struct{}) var activeConns sync.WaitGroup + acceptDone := make(chan struct{}) // Goroutine to handle shutdown signal go func() { <-target.cancel - close(done) + close(acceptDone) listener.Close() }() @@ -147,6 +185,8 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) { return default: log.Printf("Failed to accept TCP connection: %v", err) + // Don't return here, try to accept new connections + time.Sleep(time.Second) continue } } @@ -154,7 +194,7 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) { activeConns.Add(1) go func() { defer activeConns.Done() - pm.handleTCPConnection(conn, target.Target, done) + pm.handleTCPConnection(conn, target.Target, acceptDone) }() } } @@ -198,6 +238,8 @@ func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, } func (pm *ProxyManager) serveUDP(target *ProxyTarget) { + defer close(target.done) // Signal that this target is fully stopped + addr := &net.UDPAddr{ IP: net.ParseIP(target.Listen), Port: target.Port, @@ -217,16 +259,19 @@ func (pm *ProxyManager) serveUDP(target *ProxyTarget) { log.Printf("UDP proxy listening on %s", conn.LocalAddr()) buffer := make([]byte, 65535) + var activeConns sync.WaitGroup for { select { case <-target.cancel: + activeConns.Wait() // Wait for all active UDP handlers to complete return default: n, remoteAddr, err := conn.ReadFrom(buffer) if err != nil { select { case <-target.cancel: + activeConns.Wait() return default: log.Printf("Failed to read UDP packet: %v", err) @@ -240,7 +285,9 @@ func (pm *ProxyManager) serveUDP(target *ProxyTarget) { continue } + activeConns.Add(1) go func(data []byte, remote net.Addr) { + defer activeConns.Done() targetConn, err := net.DialUDP("udp", nil, targetAddr) if err != nil { log.Printf("Failed to connect to target %s: %v", target.Target, err) diff --git a/proxy/types.go b/proxy/types.go index 2886431..10bd046 100644 --- a/proxy/types.go +++ b/proxy/types.go @@ -14,9 +14,10 @@ type ProxyTarget struct { Port int Target string cancel chan struct{} // Channel to signal shutdown + done chan struct{} // Channel to signal completion listener net.Listener // For TCP udpConn net.PacketConn // For UDP - sync.Mutex // Protect access to connections + sync.Mutex // Protect access to connection } type ProxyManager struct {