mirror of
https://github.com/fosrl/newt.git
synced 2025-05-13 13:40:39 +01:00
Merge branch 'dev' of https://github.com/fosrl/newt into dev
This commit is contained in:
commit
a1a3dd9ba2
7 changed files with 291 additions and 299 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
||||||
newt
|
newt
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
bin/
|
1
go.mod
1
go.mod
|
@ -10,6 +10,7 @@ require (
|
||||||
github.com/google/btree v1.1.2 // indirect
|
github.com/google/btree v1.1.2 // indirect
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
golang.org/x/crypto v0.28.0 // indirect
|
golang.org/x/crypto v0.28.0 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
|
||||||
golang.org/x/net v0.30.0 // indirect
|
golang.org/x/net v0.30.0 // indirect
|
||||||
golang.org/x/sys v0.26.0 // indirect
|
golang.org/x/sys v0.26.0 // indirect
|
||||||
golang.org/x/time v0.7.0 // indirect
|
golang.org/x/time v0.7.0 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -4,6 +4,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||||
|
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
|
||||||
|
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
|
||||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||||
|
|
24
main.go
24
main.go
|
@ -283,17 +283,21 @@ func main() {
|
||||||
if logLevel == "" {
|
if logLevel == "" {
|
||||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// do a --version check
|
||||||
|
version := flag.Bool("version", false, "Print the version")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
if *version {
|
||||||
|
fmt.Println("Newt version replaceme")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Init()
|
logger.Init()
|
||||||
loggerLevel := parseLogLevel(logLevel)
|
loggerLevel := parseLogLevel(logLevel)
|
||||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||||
|
|
||||||
// Validate required fields
|
|
||||||
if endpoint == "" || id == "" || secret == "" {
|
|
||||||
logger.Fatal("endpoint, id, and secret are required either via CLI flags or environment variables")
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse the mtu string into an int
|
// parse the mtu string into an int
|
||||||
mtuInt, err = strconv.Atoi(mtu)
|
mtuInt, err = strconv.Atoi(mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -455,11 +459,6 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
||||||
|
@ -480,11 +479,6 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||||
|
|
530
proxy/manager.go
530
proxy/manager.go
|
@ -9,326 +9,344 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Target represents a proxy target with its address and port
|
||||||
|
type Target struct {
|
||||||
|
Address string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyManager handles the creation and management of proxy connections
|
||||||
|
type ProxyManager struct {
|
||||||
|
tnet *netstack.Net
|
||||||
|
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
|
||||||
|
udpTargets map[string]map[int]string
|
||||||
|
listeners []*gonet.TCPListener
|
||||||
|
udpConns []*gonet.UDPConn
|
||||||
|
running bool
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxyManager creates a new proxy manager instance
|
||||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
||||||
return &ProxyManager{
|
return &ProxyManager{
|
||||||
tnet: tnet,
|
tnet: tnet,
|
||||||
|
tcpTargets: make(map[string]map[int]string),
|
||||||
|
udpTargets: make(map[string]map[int]string),
|
||||||
|
listeners: make([]*gonet.TCPListener, 0),
|
||||||
|
udpConns: make([]*gonet.UDPConn, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) {
|
// AddTarget adds a new target for proxying
|
||||||
pm.Lock()
|
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||||
defer pm.Unlock()
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target)
|
switch proto {
|
||||||
|
case "tcp":
|
||||||
newTarget := ProxyTarget{
|
if pm.tcpTargets[listenIP] == nil {
|
||||||
Protocol: protocol,
|
pm.tcpTargets[listenIP] = make(map[int]string)
|
||||||
Listen: listen,
|
}
|
||||||
Port: port,
|
pm.tcpTargets[listenIP][port] = targetAddr
|
||||||
Target: target,
|
case "udp":
|
||||||
cancel: make(chan struct{}),
|
if pm.udpTargets[listenIP] == nil {
|
||||||
done: make(chan struct{}),
|
pm.udpTargets[listenIP] = make(map[int]string)
|
||||||
|
}
|
||||||
|
pm.udpTargets[listenIP][port] = targetAddr
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.targets = append(pm.targets, newTarget)
|
if pm.running {
|
||||||
|
return pm.startTarget(proto, listenIP, port, targetAddr)
|
||||||
|
} else {
|
||||||
|
logger.Info("Not adding target because not running")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
|
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
||||||
pm.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
protocol = strings.ToLower(protocol)
|
switch proto {
|
||||||
if protocol != "tcp" && protocol != "udp" {
|
case "tcp":
|
||||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
if targets, ok := pm.tcpTargets[listenIP]; ok {
|
||||||
}
|
delete(targets, port)
|
||||||
|
// Remove and close the corresponding TCP listener
|
||||||
for i, target := range pm.targets {
|
for i, listener := range pm.listeners {
|
||||||
if target.Listen == listen &&
|
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
|
||||||
target.Port == port &&
|
listener.Close()
|
||||||
strings.ToLower(target.Protocol) == protocol {
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
// Remove from slice
|
||||||
// Signal the serving goroutine to stop
|
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||||
select {
|
break
|
||||||
case <-target.cancel:
|
|
||||||
// Channel is already closed, no need to close it again
|
|
||||||
default:
|
|
||||||
close(target.cancel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the appropriate listener/connection based on protocol
|
|
||||||
target.Lock()
|
|
||||||
switch protocol {
|
|
||||||
case "tcp":
|
|
||||||
if target.listener != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
// Listener was already closed by Stop()
|
|
||||||
default:
|
|
||||||
target.listener.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "udp":
|
|
||||||
if target.udpConn != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
// Connection was already closed by Stop()
|
|
||||||
default:
|
|
||||||
target.udpConn.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
// Wait for the target to fully stop
|
|
||||||
<-target.done
|
|
||||||
|
|
||||||
// Remove the target from the slice
|
|
||||||
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) Start() error {
|
|
||||||
pm.RLock()
|
|
||||||
defer pm.RUnlock()
|
|
||||||
|
|
||||||
for i := range pm.targets {
|
|
||||||
target := &pm.targets[i]
|
|
||||||
|
|
||||||
target.Lock()
|
|
||||||
// If target is already running, skip it
|
|
||||||
if target.listener != nil || target.udpConn != nil {
|
|
||||||
target.Unlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark the target as starting by creating a nil listener/connection
|
|
||||||
// This prevents other goroutines from trying to start it
|
|
||||||
if strings.ToLower(target.Protocol) == "tcp" {
|
|
||||||
target.listener = nil
|
|
||||||
} else {
|
} else {
|
||||||
target.udpConn = nil
|
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||||
}
|
}
|
||||||
target.Unlock()
|
case "udp":
|
||||||
|
if targets, ok := pm.udpTargets[listenIP]; ok {
|
||||||
|
delete(targets, port)
|
||||||
|
// Remove and close the corresponding UDP connection
|
||||||
|
for i, conn := range pm.udpConns {
|
||||||
|
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
|
||||||
|
conn.Close()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
// Remove from slice
|
||||||
|
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
switch strings.ToLower(target.Protocol) {
|
// Start begins listening for all configured proxy targets
|
||||||
case "tcp":
|
func (pm *ProxyManager) Start() error {
|
||||||
go pm.serveTCP(target)
|
pm.mutex.Lock()
|
||||||
case "udp":
|
defer pm.mutex.Unlock()
|
||||||
go pm.serveUDP(target)
|
|
||||||
default:
|
if pm.running {
|
||||||
return fmt.Errorf("unsupported protocol: %s", target.Protocol)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start TCP targets
|
||||||
|
for listenIP, targets := range pm.tcpTargets {
|
||||||
|
for port, targetAddr := range targets {
|
||||||
|
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
|
||||||
|
return fmt.Errorf("failed to start TCP target: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start UDP targets
|
||||||
|
for listenIP, targets := range pm.udpTargets {
|
||||||
|
for port, targetAddr := range targets {
|
||||||
|
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
|
||||||
|
return fmt.Errorf("failed to start UDP target: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.running = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) Stop() error {
|
func (pm *ProxyManager) Stop() error {
|
||||||
pm.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
if !pm.running {
|
||||||
for i := range pm.targets {
|
return nil
|
||||||
target := &pm.targets[i]
|
|
||||||
wg.Add(1)
|
|
||||||
go func(t *ProxyTarget) {
|
|
||||||
defer wg.Done()
|
|
||||||
close(t.cancel)
|
|
||||||
t.Lock()
|
|
||||||
if t.listener != nil {
|
|
||||||
t.listener.Close()
|
|
||||||
}
|
|
||||||
if t.udpConn != nil {
|
|
||||||
t.udpConn.Close()
|
|
||||||
}
|
|
||||||
t.Unlock()
|
|
||||||
// Wait for the target to fully stop
|
|
||||||
<-t.done
|
|
||||||
}(target)
|
|
||||||
}
|
}
|
||||||
wg.Wait()
|
|
||||||
|
// Set running to false first to signal handlers to stop
|
||||||
|
pm.running = false
|
||||||
|
|
||||||
|
// Close TCP listeners
|
||||||
|
for i := len(pm.listeners) - 1; i >= 0; i-- {
|
||||||
|
listener := pm.listeners[i]
|
||||||
|
if err := listener.Close(); err != nil {
|
||||||
|
logger.Error("Error closing TCP listener: %v", err)
|
||||||
|
}
|
||||||
|
// Remove from slice
|
||||||
|
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close UDP connections
|
||||||
|
for i := len(pm.udpConns) - 1; i >= 0; i-- {
|
||||||
|
conn := pm.udpConns[i]
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
logger.Error("Error closing UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
// Remove from slice
|
||||||
|
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the target maps
|
||||||
|
for k := range pm.tcpTargets {
|
||||||
|
delete(pm.tcpTargets, k)
|
||||||
|
}
|
||||||
|
for k := range pm.udpTargets {
|
||||||
|
delete(pm.udpTargets, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give active connections a chance to close gracefully
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||||
defer close(target.done) // Signal that this target is fully stopped
|
switch proto {
|
||||||
|
case "tcp":
|
||||||
|
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create TCP listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
|
pm.listeners = append(pm.listeners, listener)
|
||||||
IP: net.ParseIP(target.Listen),
|
go pm.handleTCPProxy(listener, targetAddr)
|
||||||
Port: target.Port,
|
|
||||||
})
|
case "udp":
|
||||||
if err != nil {
|
addr := &net.UDPAddr{Port: port}
|
||||||
logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
|
conn, err := pm.tnet.ListenUDP(addr)
|
||||||
return
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create UDP listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.udpConns = append(pm.udpConns, conn)
|
||||||
|
go pm.handleUDPProxy(conn, targetAddr)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||||
}
|
}
|
||||||
|
|
||||||
target.Lock()
|
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
|
||||||
target.listener = listener
|
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
defer listener.Close()
|
return nil
|
||||||
logger.Info("TCP proxy listening on %s", listener.Addr())
|
}
|
||||||
|
|
||||||
var activeConns sync.WaitGroup
|
|
||||||
acceptDone := make(chan struct{})
|
|
||||||
|
|
||||||
// Goroutine to handle shutdown signal
|
|
||||||
go func() {
|
|
||||||
<-target.cancel
|
|
||||||
close(acceptDone)
|
|
||||||
listener.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
|
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
select {
|
// Check if we're shutting down or the listener was closed
|
||||||
case <-target.cancel:
|
if !pm.running {
|
||||||
// Wait for active connections to finish
|
|
||||||
activeConns.Wait()
|
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
logger.Info("Failed to accept TCP connection: %v", err)
|
|
||||||
// Don't return here, try to accept new connections
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for specific network errors that indicate the listener is closed
|
||||||
|
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||||
|
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error("Error accepting TCP connection: %v", err)
|
||||||
|
// Don't hammer the CPU if we hit a temporary error
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
activeConns.Add(1)
|
|
||||||
go func() {
|
go func() {
|
||||||
defer activeConns.Done()
|
target, err := net.Dial("tcp", targetAddr)
|
||||||
pm.handleTCPConnection(conn, target.Target, acceptDone)
|
if err != nil {
|
||||||
|
logger.Error("Error connecting to target: %v", err)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a WaitGroup to ensure both copy operations complete
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
io.Copy(target, conn)
|
||||||
|
target.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
io.Copy(conn, target)
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for both copies to complete
|
||||||
|
wg.Wait()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
|
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||||
defer clientConn.Close()
|
buffer := make([]byte, 65507) // Max UDP packet size
|
||||||
|
clientConns := make(map[string]*net.UDPConn)
|
||||||
serverConn, err := net.Dial("tcp", target)
|
var clientsMutex sync.RWMutex
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to connect to target %s: %v", target, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer serverConn.Close()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
// Client -> Server
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
io.Copy(serverConn, clientConn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Server -> Client
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
io.Copy(clientConn, serverConn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
|
||||||
defer close(target.done) // Signal that this target is fully stopped
|
|
||||||
|
|
||||||
addr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP(target.Listen),
|
|
||||||
Port: target.Port,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := pm.tnet.ListenUDP(addr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
target.Lock()
|
|
||||||
target.udpConn = conn
|
|
||||||
target.Unlock()
|
|
||||||
|
|
||||||
defer conn.Close()
|
|
||||||
logger.Info("UDP proxy listening on %s", conn.LocalAddr())
|
|
||||||
|
|
||||||
buffer := make([]byte, 65535)
|
|
||||||
var activeConns sync.WaitGroup
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||||
case <-target.cancel:
|
if err != nil {
|
||||||
activeConns.Wait() // Wait for all active UDP handlers to complete
|
if !pm.running {
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
select {
|
|
||||||
case <-target.cancel:
|
|
||||||
activeConns.Wait()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
logger.Info("Failed to read UDP packet: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
targetAddr, err := net.ResolveUDPAddr("udp", target.Target)
|
// Check for connection closed conditions
|
||||||
|
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
|
logger.Info("UDP connection closed, stopping proxy handler")
|
||||||
|
|
||||||
|
// Clean up existing client connections
|
||||||
|
clientsMutex.Lock()
|
||||||
|
for _, targetConn := range clientConns {
|
||||||
|
targetConn.Close()
|
||||||
|
}
|
||||||
|
clientConns = nil
|
||||||
|
clientsMutex.Unlock()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Error("Error reading UDP packet: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
clientKey := remoteAddr.String()
|
||||||
|
clientsMutex.RLock()
|
||||||
|
targetConn, exists := clientConns[clientKey]
|
||||||
|
clientsMutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Failed to resolve target address %s: %v", target.Target, err)
|
logger.Error("Error resolving target address: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
activeConns.Add(1)
|
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
||||||
go func(data []byte, remote net.Addr) {
|
if err != nil {
|
||||||
defer activeConns.Done()
|
logger.Error("Error connecting to target: %v", err)
|
||||||
targetConn, err := net.DialUDP("udp", nil, targetAddr)
|
continue
|
||||||
if err != nil {
|
}
|
||||||
logger.Info("Failed to connect to target %s: %v", target.Target, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer targetConn.Close()
|
|
||||||
|
|
||||||
select {
|
clientsMutex.Lock()
|
||||||
case <-target.cancel:
|
clientConns[clientKey] = targetConn
|
||||||
return
|
clientsMutex.Unlock()
|
||||||
default:
|
|
||||||
_, err = targetConn.Write(data)
|
go func() {
|
||||||
|
buffer := make([]byte, 65507)
|
||||||
|
for {
|
||||||
|
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Failed to write to target: %v", err)
|
logger.Error("Error reading from target: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, 65535)
|
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
||||||
n, err := targetConn.Read(response)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Failed to read response from target: %v", err)
|
logger.Error("Error writing to client: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.WriteTo(response[:n], remote)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to write response to client: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}(buffer[:n], remoteAddr)
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = targetConn.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error writing to target: %v", err)
|
||||||
|
targetConn.Close()
|
||||||
|
clientsMutex.Lock()
|
||||||
|
delete(clientConns, clientKey)
|
||||||
|
clientsMutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ProxyTarget struct {
|
|
||||||
Protocol string
|
|
||||||
Listen string
|
|
||||||
Port int
|
|
||||||
Target string
|
|
||||||
cancel chan struct{} // Channel to signal shutdown
|
|
||||||
done chan struct{} // Channel to signal completion
|
|
||||||
listener net.Listener // For TCP
|
|
||||||
udpConn net.PacketConn // For UDP
|
|
||||||
sync.Mutex // Protect access to connection
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProxyManager struct {
|
|
||||||
targets []ProxyTarget
|
|
||||||
tnet *netstack.Net
|
|
||||||
log *log.Logger
|
|
||||||
sync.RWMutex // Protect access to targets slice
|
|
||||||
}
|
|
|
@ -305,6 +305,10 @@ func (c *Client) establishConnection() error {
|
||||||
go c.readPump()
|
go c.readPump()
|
||||||
|
|
||||||
if c.onConnect != nil {
|
if c.onConnect != nil {
|
||||||
|
err := c.saveConfig()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to save config: %v", err)
|
||||||
|
}
|
||||||
if err := c.onConnect(); err != nil {
|
if err := c.onConnect(); err != nil {
|
||||||
logger.Error("OnConnect callback failed: %v", err)
|
logger.Error("OnConnect callback failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue