newt/proxy/manager.go
2024-11-04 00:26:45 -05:00

246 lines
4.9 KiB
Go

package proxy
import (
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"golang.zx2c4.com/wireguard/tun/netstack"
)
// NewProxyManager returns a ProxyManager that creates its listeners on the
// given netstack network. The manager starts with no targets registered.
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
	manager := new(ProxyManager)
	manager.tnet = tnet
	return manager
}
// AddTarget registers a proxy target that listens on listen:port using the
// given protocol and forwards to target.
//
// The entry is only appended to the manager's target list, together with a
// fresh cancel channel; no listener is opened here.
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) {
	pm.Lock()
	defer pm.Unlock()

	pm.targets = append(pm.targets, ProxyTarget{
		Protocol: protocol,
		Listen:   listen,
		Port:     port,
		Target:   target,
		cancel:   make(chan struct{}),
	})
}
// RemoveTarget stops and unregisters the first target whose Listen address
// and Port match the arguments.
//
// It closes the target's cancel channel, closes whichever of the TCP listener
// or UDP connection has been set on the target, and then deletes the entry
// from the manager's target list. An error is returned when no target matches.
func (pm *ProxyManager) RemoveTarget(listen string, port int) error {
	pm.Lock()
	defer pm.Unlock()

	for i := range pm.targets {
		// Operate on the slice element itself rather than a range copy: the
		// target carries its own lock, and locking a copy would not
		// synchronize with the serving goroutines, which lock the element
		// through a pointer into this slice.
		target := &pm.targets[i]
		if target.Listen != listen || target.Port != port {
			continue
		}

		// Signal the serving goroutine to stop.
		close(target.cancel)

		// Close the listener/connection so blocked Accept/ReadFrom calls return.
		target.Lock()
		if target.listener != nil {
			target.listener.Close()
		}
		if target.udpConn != nil {
			target.udpConn.Close()
		}
		target.Unlock()

		// Remove the target from the slice.
		// NOTE(review): this shifts every later element down one slot, while
		// Start hands serving goroutines pointers into this same slice; those
		// pointers then refer to a different target's data. Storing
		// *ProxyTarget in the slice would avoid that, but it changes the
		// type declared outside this function.
		pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
		return nil
	}
	return fmt.Errorf("target not found for %s:%d", listen, port)
}
// Start launches one serving goroutine per registered target, choosing the
// TCP or UDP server by the target's protocol (compared case-insensitively).
//
// It returns an error as soon as it meets a target with any other protocol;
// goroutines already launched for earlier targets keep running.
func (pm *ProxyManager) Start() error {
	pm.RLock()
	defer pm.RUnlock()

	for idx := range pm.targets {
		// Take the element's address so the serving goroutine updates the
		// stored target instead of a copy.
		entry := &pm.targets[idx]
		proto := strings.ToLower(entry.Protocol)

		if proto == "tcp" {
			go pm.serveTCP(entry)
			continue
		}
		if proto == "udp" {
			go pm.serveUDP(entry)
			continue
		}
		return fmt.Errorf("unsupported protocol: %s", entry.Protocol)
	}
	return nil
}
// serveTCP listens on the target's Listen:Port through the manager's netstack
// network and proxies every accepted connection to the target's destination.
//
// It runs until the target's cancel channel is closed, then waits for the
// connection handlers it started before returning. Listener start-up
// failures are logged and end the function.
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
	// Snapshot what the loop needs. target points into the manager's slice,
	// and RemoveTarget deletes entries by shifting later elements down, so the
	// slot may later hold another target's cancel channel and destination.
	cancel := target.cancel
	dest := target.Target
	listenIP, listenPort := target.Listen, target.Port

	listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
		IP:   net.ParseIP(listenIP),
		Port: listenPort,
	})
	if err != nil {
		log.Printf("Failed to start TCP listener for %s:%d: %v", listenIP, listenPort, err)
		return
	}
	// Logged only after the error check so a failed listen is not reported
	// as listening.
	log.Printf("Listening on %s:%d", listenIP, listenPort)

	// Publish the listener under the target's lock so RemoveTarget can close it.
	target.Lock()
	target.listener = listener
	target.Unlock()
	defer listener.Close()

	log.Printf("TCP proxy listening on %s", listener.Addr())

	// Channel to signal active connections to close.
	done := make(chan struct{})
	var activeConns sync.WaitGroup

	// On cancel, tell the handlers to stop and close the listener so the
	// blocked Accept below returns.
	go func() {
		<-cancel
		close(done)
		listener.Close()
	}()

	for {
		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-cancel:
				// Shutting down: wait for active connections to finish.
				activeConns.Wait()
				return
			default:
				log.Printf("Failed to accept TCP connection: %v", err)
				continue
			}
		}

		activeConns.Add(1)
		go func() {
			defer activeConns.Done()
			pm.handleTCPConnection(conn, dest, done)
		}()
	}
}
// handleTCPConnection proxies one accepted client connection to target.
//
// It dials target over TCP, copies bytes in both directions, and returns once
// both copies have ended. Both connections are closed on return. Closing done
// aborts the proxying: the connections are closed immediately, which makes the
// in-flight copies fail and the function return.
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
	defer clientConn.Close()

	serverConn, err := net.Dial("tcp", target)
	if err != nil {
		log.Printf("Failed to connect to target %s: %v", target, err)
		return
	}
	defer serverConn.Close()

	// Shutdown already requested: do not start copying at all.
	select {
	case <-done:
		return
	default:
	}

	// finished lets the watcher below exit when the copies end on their own.
	finished := make(chan struct{})
	defer close(finished)

	// io.Copy blocks inside Read/Write and cannot observe done by itself, so
	// closing the connections is what interrupts an active transfer.
	go func() {
		select {
		case <-done:
			clientConn.Close()
			serverConn.Close()
		case <-finished:
		}
	}()

	var wg sync.WaitGroup
	wg.Add(2)

	// Client -> Server. The copy error is ignored on purpose: it is the
	// normal way this loop ends when either side closes.
	go func() {
		defer wg.Done()
		io.Copy(serverConn, clientConn)
	}()

	// Server -> Client.
	go func() {
		defer wg.Done()
		io.Copy(clientConn, serverConn)
	}()

	wg.Wait()
}
// serveUDP listens for datagrams on the target's Listen:Port through the
// manager's netstack network and relays each one to the target's destination.
//
// Every datagram is handled in its own goroutine: it is sent to the
// destination over a new UDP socket, one response is read, and that response
// is written back to the original sender. The function runs until the
// target's cancel channel is closed. Listener start-up failures are logged
// and end the function.
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
	// Snapshot what the loop needs. target points into the manager's slice,
	// and RemoveTarget deletes entries by shifting later elements down, so the
	// slot may later hold another target's cancel channel and destination.
	cancel := target.cancel
	dest := target.Target

	addr := &net.UDPAddr{
		IP:   net.ParseIP(target.Listen),
		Port: target.Port,
	}
	conn, err := pm.tnet.ListenUDP(addr)
	if err != nil {
		log.Printf("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
		return
	}

	// Publish the connection under the target's lock so RemoveTarget can close it.
	target.Lock()
	target.udpConn = conn
	target.Unlock()
	defer conn.Close()

	log.Printf("UDP proxy listening on %s", conn.LocalAddr())

	// Close the socket on cancel so a blocked ReadFrom returns even when the
	// cancel arrived before udpConn was published above.
	go func() {
		<-cancel
		conn.Close()
	}()

	buffer := make([]byte, 65535)
	for {
		select {
		case <-cancel:
			return
		default:
		}

		n, remoteAddr, err := conn.ReadFrom(buffer)
		if err != nil {
			select {
			case <-cancel:
				return
			default:
				log.Printf("Failed to read UDP packet: %v", err)
				continue
			}
		}

		targetAddr, err := net.ResolveUDPAddr("udp", dest)
		if err != nil {
			log.Printf("Failed to resolve target address %s: %v", dest, err)
			continue
		}

		// Copy the datagram: buffer is overwritten by the next ReadFrom while
		// the goroutine below may still be sending this one.
		packet := make([]byte, n)
		copy(packet, buffer[:n])

		go func(data []byte, remote net.Addr) {
			targetConn, err := net.DialUDP("udp", nil, targetAddr)
			if err != nil {
				log.Printf("Failed to connect to target %s: %v", dest, err)
				return
			}
			defer targetConn.Close()

			// Skip the relay when shutdown has already been requested.
			select {
			case <-cancel:
				return
			default:
			}

			if _, err = targetConn.Write(data); err != nil {
				log.Printf("Failed to write to target: %v", err)
				return
			}

			// NOTE(review): this Read has no deadline, so the goroutine stays
			// blocked if the destination never answers.
			response := make([]byte, 65535)
			n, err := targetConn.Read(response)
			if err != nil {
				log.Printf("Failed to read response from target: %v", err)
				return
			}

			if _, err = conn.WriteTo(response[:n], remote); err != nil {
				log.Printf("Failed to write response to client: %v", err)
			}
		}(packet, remoteAddr)
	}
}