Handle freeing ports correctly

This commit is contained in:
Owen Schwartz 2024-11-23 18:20:56 -05:00
parent 99f7d13efe
commit 7c6c4237cf
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
3 changed files with 90 additions and 33 deletions

35
main.go
View file

@ -364,13 +364,6 @@ func parseTargetData(data interface{}) (TargetData, error) {
} }
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
// Stop the proxy manager before adding new targets
err := pm.Stop()
if err != nil {
logger.Error("Failed to stop proxy manager: %v", err)
}
for _, t := range targetData.Targets { for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port // Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":") parts := strings.Split(t, ":")
@ -389,17 +382,33 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
if action == "add" { if action == "add" {
target := parts[1] + ":" + parts[2] target := parts[1] + ":" + parts[2]
pm.RemoveTarget(proto, tunnelIP, port) // remove it first in case this is an update. we are kind of using the internal port as the "targetId" in the proxy // Only remove the specific target if it exists
pm.AddTarget(proto, tunnelIP, port, target) err := pm.RemoveTarget(proto, tunnelIP, port)
} else if action == "remove" { if err != nil {
logger.Info("Removing target with port %d", port) // Ignore "target not found" errors as this is expected for new targets
pm.RemoveTarget(proto, tunnelIP, port) if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
} }
} }
// Add the new target
pm.AddTarget(proto, tunnelIP, port, target)
// Start just this target by calling Start() on the proxy manager
// The Start() function is idempotent and will only start new targets
err = pm.Start() err = pm.Start()
if err != nil { if err != nil {
logger.Error("Failed to start proxy manager: %v", err) logger.Error("Failed to start proxy manager after adding target: %v", err)
return err
}
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
}
}
} }
return nil return nil

View file

@ -7,6 +7,7 @@ import (
"net" "net"
"strings" "strings"
"sync" "sync"
"time"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
) )
@ -27,6 +28,7 @@ func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target stri
Port: port, Port: port,
Target: target, Target: target,
cancel: make(chan struct{}), cancel: make(chan struct{}),
done: make(chan struct{}),
} }
pm.targets = append(pm.targets, newTarget) pm.targets = append(pm.targets, newTarget)
@ -45,23 +47,42 @@ func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
if target.Listen == listen && if target.Listen == listen &&
target.Port == port && target.Port == port &&
strings.ToLower(target.Protocol) == protocol { strings.ToLower(target.Protocol) == protocol {
// Signal the serving goroutine to stop // Signal the serving goroutine to stop
// close(target.cancel) select {
case <-target.cancel:
// Channel is already closed, no need to close it again
default:
close(target.cancel)
}
// Close the appropriate listener/connection based on protocol // Close the appropriate listener/connection based on protocol
target.Lock() target.Lock()
switch protocol { switch protocol {
case "tcp": case "tcp":
if target.listener != nil { if target.listener != nil {
select {
case <-target.cancel:
// Listener was already closed by Stop()
default:
target.listener.Close() target.listener.Close()
} }
}
case "udp": case "udp":
if target.udpConn != nil { if target.udpConn != nil {
select {
case <-target.cancel:
// Connection was already closed by Stop()
default:
target.udpConn.Close() target.udpConn.Close()
} }
} }
}
target.Unlock() target.Unlock()
// Wait for the target to fully stop
<-target.done
// Remove the target from the slice // Remove the target from the slice
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...) pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
return nil return nil
@ -76,7 +97,16 @@ func (pm *ProxyManager) Start() error {
defer pm.RUnlock() defer pm.RUnlock()
for i := range pm.targets { for i := range pm.targets {
target := &pm.targets[i] // Use pointer to modify the target in the slice target := &pm.targets[i]
// Skip already running targets
target.Lock()
if target.listener != nil || target.udpConn != nil {
target.Unlock()
continue
}
target.Unlock()
switch strings.ToLower(target.Protocol) { switch strings.ToLower(target.Protocol) {
case "tcp": case "tcp":
go pm.serveTCP(target) go pm.serveTCP(target)
@ -93,27 +123,36 @@ func (pm *ProxyManager) Stop() error {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
var wg sync.WaitGroup
for i := range pm.targets { for i := range pm.targets {
target := &pm.targets[i] target := &pm.targets[i]
close(target.cancel) wg.Add(1)
target.Lock() go func(t *ProxyTarget) {
if target.listener != nil { defer wg.Done()
target.listener.Close() close(t.cancel)
t.Lock()
if t.listener != nil {
t.listener.Close()
} }
if target.udpConn != nil { if t.udpConn != nil {
target.udpConn.Close() t.udpConn.Close()
} }
target.Unlock() t.Unlock()
// Wait for the target to fully stop
<-t.done
}(target)
} }
wg.Wait()
return nil return nil
} }
func (pm *ProxyManager) serveTCP(target *ProxyTarget) { func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
defer close(target.done) // Signal that this target is fully stopped
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
IP: net.ParseIP(target.Listen), IP: net.ParseIP(target.Listen),
Port: target.Port, Port: target.Port,
}) })
log.Printf("Listening on %s:%d", target.Listen, target.Port)
if err != nil { if err != nil {
log.Printf("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err) log.Printf("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
return return
@ -126,14 +165,13 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
defer listener.Close() defer listener.Close()
log.Printf("TCP proxy listening on %s", listener.Addr()) log.Printf("TCP proxy listening on %s", listener.Addr())
// Channel to signal active connections to close
done := make(chan struct{})
var activeConns sync.WaitGroup var activeConns sync.WaitGroup
acceptDone := make(chan struct{})
// Goroutine to handle shutdown signal // Goroutine to handle shutdown signal
go func() { go func() {
<-target.cancel <-target.cancel
close(done) close(acceptDone)
listener.Close() listener.Close()
}() }()
@ -147,6 +185,8 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
return return
default: default:
log.Printf("Failed to accept TCP connection: %v", err) log.Printf("Failed to accept TCP connection: %v", err)
// Don't return here, try to accept new connections
time.Sleep(time.Second)
continue continue
} }
} }
@ -154,7 +194,7 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
activeConns.Add(1) activeConns.Add(1)
go func() { go func() {
defer activeConns.Done() defer activeConns.Done()
pm.handleTCPConnection(conn, target.Target, done) pm.handleTCPConnection(conn, target.Target, acceptDone)
}() }()
} }
} }
@ -198,6 +238,8 @@ func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string,
} }
func (pm *ProxyManager) serveUDP(target *ProxyTarget) { func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
defer close(target.done) // Signal that this target is fully stopped
addr := &net.UDPAddr{ addr := &net.UDPAddr{
IP: net.ParseIP(target.Listen), IP: net.ParseIP(target.Listen),
Port: target.Port, Port: target.Port,
@ -217,16 +259,19 @@ func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
log.Printf("UDP proxy listening on %s", conn.LocalAddr()) log.Printf("UDP proxy listening on %s", conn.LocalAddr())
buffer := make([]byte, 65535) buffer := make([]byte, 65535)
var activeConns sync.WaitGroup
for { for {
select { select {
case <-target.cancel: case <-target.cancel:
activeConns.Wait() // Wait for all active UDP handlers to complete
return return
default: default:
n, remoteAddr, err := conn.ReadFrom(buffer) n, remoteAddr, err := conn.ReadFrom(buffer)
if err != nil { if err != nil {
select { select {
case <-target.cancel: case <-target.cancel:
activeConns.Wait()
return return
default: default:
log.Printf("Failed to read UDP packet: %v", err) log.Printf("Failed to read UDP packet: %v", err)
@ -240,7 +285,9 @@ func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
continue continue
} }
activeConns.Add(1)
go func(data []byte, remote net.Addr) { go func(data []byte, remote net.Addr) {
defer activeConns.Done()
targetConn, err := net.DialUDP("udp", nil, targetAddr) targetConn, err := net.DialUDP("udp", nil, targetAddr)
if err != nil { if err != nil {
log.Printf("Failed to connect to target %s: %v", target.Target, err) log.Printf("Failed to connect to target %s: %v", target.Target, err)

View file

@ -14,9 +14,10 @@ type ProxyTarget struct {
Port int Port int
Target string Target string
cancel chan struct{} // Channel to signal shutdown cancel chan struct{} // Channel to signal shutdown
done chan struct{} // Channel to signal completion
listener net.Listener // For TCP listener net.Listener // For TCP
udpConn net.PacketConn // For UDP udpConn net.PacketConn // For UDP
sync.Mutex // Protect access to connections sync.Mutex // Protect access to connection
} }
type ProxyManager struct { type ProxyManager struct {