newt/main.go
2024-11-03 16:34:23 -05:00

323 lines
7.1 KiB
Go

package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"flag"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/netip"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
// ProxyTarget describes one proxied service: a listener inside the tunnel and
// the backend addresses its traffic is forwarded to.
type ProxyTarget struct {
	// Protocol selects the listener type; Start accepts "tcp" and "udp"
	// (compared case-insensitively).
	Protocol string
	// Listen is the IP address the listener binds to; it is parsed with
	// net.ParseIP.
	Listen string
	// Targets are the backend addresses (host:port) traffic is forwarded to.
	Targets []string
}
// ProxyManager owns the set of configured proxy targets and the userspace
// network stack their listeners are created on.
type ProxyManager struct {
	// targets holds every entry registered through AddTarget, in call order.
	targets []ProxyTarget
	// tnet is the netstack network used to open the TCP and UDP listeners.
	tnet *netstack.Net
}
// NewProxyManager returns a ProxyManager with no targets whose listeners will
// be created on tnet.
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
	pm := new(ProxyManager)
	pm.tnet = tnet
	return pm
}
// AddTarget registers a proxy entry for the given protocol, listen address and
// backend targets. Nothing is started until Start is called. The targets slice
// is stored as passed, not copied.
func (pm *ProxyManager) AddTarget(protocol, listen string, targets []string) {
	entry := ProxyTarget{Protocol: protocol, Listen: listen, Targets: targets}
	pm.targets = append(pm.targets, entry)
}
// Start launches one listener goroutine per registered target and returns
// without waiting for them.
//
// All targets are validated before anything is launched, so an unsupported
// protocol is reported without leaving earlier listeners running. The
// protocol name is compared case-insensitively; only "tcp" and "udp" are
// accepted. Listener setup failures happen inside the goroutines and are
// logged there, not returned here.
func (pm *ProxyManager) Start() error {
	// First pass: reject the whole configuration if any entry is unusable.
	for _, target := range pm.targets {
		switch strings.ToLower(target.Protocol) {
		case "tcp", "udp":
		default:
			return fmt.Errorf("unsupported protocol: %s", target.Protocol)
		}
	}

	// Second pass: every protocol is known to be valid at this point.
	for _, target := range pm.targets {
		switch strings.ToLower(target.Protocol) {
		case "tcp":
			go pm.serveTCP(target)
		case "udp":
			go pm.serveUDP(target)
		}
	}
	return nil
}
// serveTCP opens a TCP listener on the tunnel netstack at target.Listen and
// hands every accepted connection to handleTCPConnection in its own goroutine.
//
// It is intended to run as a goroutine. It returns only when no targets are
// configured or the listener cannot be created; both cases are logged.
// Backends are used in rotation: each accepted connection is given the next
// entry of target.Targets.
func (pm *ProxyManager) serveTCP(target ProxyTarget) {
	if len(target.Targets) == 0 {
		log.Printf("No TCP targets configured for %s", target.Listen)
		return
	}

	listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
		IP:   net.ParseIP(target.Listen),
		Port: 0,
	})
	if err != nil {
		log.Printf("Failed to start TCP listener for %s: %v", target.Listen, err)
		return
	}
	defer listener.Close()
	log.Printf("TCP proxy listening on %s", listener.Addr())

	// Bounds for the pause after a failed Accept. Without a pause a
	// persistent error would make this loop spin and flood the log.
	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration

	// next is the index of the backend for the next accepted connection. It
	// lives here, in the single accept loop, so it persists across
	// connections without needing synchronization.
	next := 0
	for {
		clientConn, err := listener.Accept()
		if err != nil {
			acceptDelay *= 2
			if acceptDelay < minAcceptDelay {
				acceptDelay = minAcceptDelay
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			log.Printf("Failed to accept TCP connection: %v", err)
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0

		// Pass only the selected backend so the handler dials exactly it.
		go pm.handleTCPConnection(clientConn, target.Targets[next:next+1])
		next = (next + 1) % len(target.Targets)
	}
}
// handleTCPConnection proxies one accepted client connection to a backend.
//
// It dials the first entry of targets with net.Dial and relays bytes in both
// directions until both directions have finished. clientConn is always closed
// before returning. If targets is empty or the dial fails, the problem is
// logged and the client connection is dropped.
//
// No rotation happens here: an index kept in this function would be reset on
// every call, so a caller that wants to spread load must put the backend it
// selected first in targets.
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, targets []string) {
	defer clientConn.Close()

	if len(targets) == 0 {
		log.Printf("No targets configured for TCP connection from %v", clientConn.RemoteAddr())
		return
	}
	target := targets[0]

	serverConn, err := net.Dial("tcp", target)
	if err != nil {
		log.Printf("Failed to connect to target %s: %v", target, err)
		return
	}
	defer serverConn.Close()

	var wg sync.WaitGroup
	wg.Add(2)
	// Client -> Server
	go func() {
		defer wg.Done()
		relayTCP(serverConn, clientConn)
	}()
	// Server -> Client
	go func() {
		defer wg.Done()
		relayTCP(clientConn, serverConn)
	}()
	wg.Wait()
}

// relayTCP copies src to dst until src is exhausted or either side fails, then
// signals end-of-stream on dst. Without that signal the peer behind dst never
// learns the other side is finished, and the opposite relay can block forever.
func relayTCP(dst, src net.Conn) {
	// The copy error is deliberately dropped: it is normally just one side
	// hanging up, and the shutdown below is the same either way.
	_, _ = io.Copy(dst, src)

	// Prefer a half-close so data still in flight in the other direction can
	// drain; fall back to a full close when the connection cannot half-close.
	if hc, ok := dst.(interface{ CloseWrite() error }); ok {
		_ = hc.CloseWrite()
		return
	}
	_ = dst.Close()
}
// serveUDP opens a UDP listener on the tunnel netstack at target.Listen and
// forwards each received datagram to one of target.Targets, rotating through
// them per datagram.
//
// Every datagram is handled in its own goroutine, which sends it to the
// backend over a fresh UDP socket, waits for a single response datagram and
// writes that response back to the original sender.
//
// It is intended to run as a goroutine. It returns only when no targets are
// configured or the listener cannot be created; both cases are logged.
func (pm *ProxyManager) serveUDP(target ProxyTarget) {
	if len(target.Targets) == 0 {
		log.Printf("No UDP targets configured for %s", target.Listen)
		return
	}

	listener, err := pm.tnet.ListenUDP(&net.UDPAddr{
		IP:   net.ParseIP(target.Listen),
		Port: 0,
	})
	if err != nil {
		log.Printf("Failed to start UDP listener for %s: %v", target.Listen, err)
		return
	}
	defer listener.Close()
	log.Printf("UDP proxy listening on %s", listener.LocalAddr())

	const (
		// responseTimeout bounds how long a forwarding goroutine waits for
		// the backend, so an unanswered datagram cannot pin it forever.
		responseTimeout = 30 * time.Second
		// readErrorDelay keeps a persistent read error from turning the
		// receive loop into a busy loop.
		readErrorDelay = 100 * time.Millisecond
	)

	buffer := make([]byte, 65535)
	next := 0
	for {
		// Read from the UDP connection
		n, remoteAddr, err := listener.ReadFrom(buffer)
		if err != nil {
			log.Printf("Failed to read UDP packet: %v", err)
			time.Sleep(readErrorDelay)
			continue
		}

		backend := target.Targets[next]
		next = (next + 1) % len(target.Targets)

		backendAddr, err := net.ResolveUDPAddr("udp", backend)
		if err != nil {
			log.Printf("Failed to resolve target address %s: %v", backend, err)
			continue
		}

		// buffer is overwritten by the next ReadFrom, so the goroutine must
		// work on its own copy of the datagram.
		packet := make([]byte, n)
		copy(packet, buffer[:n])

		go func(data []byte, remote net.Addr, name string, addr *net.UDPAddr) {
			targetConn, err := net.DialUDP("udp", nil, addr)
			if err != nil {
				log.Printf("Failed to connect to target %s: %v", name, err)
				return
			}
			defer targetConn.Close()

			if _, err := targetConn.Write(data); err != nil {
				log.Printf("Failed to write to target: %v", err)
				return
			}

			if err := targetConn.SetReadDeadline(time.Now().Add(responseTimeout)); err != nil {
				log.Printf("Failed to set read deadline for target %s: %v", name, err)
				return
			}
			response := make([]byte, 65535)
			size, err := targetConn.Read(response)
			if err != nil {
				log.Printf("Failed to read response from target: %v", err)
				return
			}

			if _, err := listener.WriteTo(response[:size], remote); err != nil {
				log.Printf("Failed to write response to client: %v", err)
			}
		}(packet, remoteAddr, backend, backendAddr)
	}
}
// fixKey converts a standard-base64 key into its lowercase hexadecimal form,
// which is how main writes keys into the configuration passed to IpcSet.
//
// Leading and trailing whitespace is ignored. If key is not valid base64 the
// process is terminated via log.Fatalf.
func fixKey(key string) string {
	raw, err := base64.StdEncoding.DecodeString(strings.TrimSpace(key))
	if err != nil {
		// Fatalf rather than Fatal: Fatal joins a string operand and the
		// error with no separator, which ran the two together in the log.
		log.Fatalf("Error decoding base64: %v", err)
	}
	return hex.EncodeToString(raw)
}
// ping sends one ICMPv4 echo request to dst through the tunnel netstack and
// logs the round-trip time once a matching reply arrives.
//
// The reply must arrive within 10 seconds and must carry the same sequence
// number and payload as the request. Any failure (dial, marshal, write, read,
// parse or mismatch) panics via log.Panic or log.Panicf. The socket is closed
// on every path, including the panicking ones.
func ping(tnet *netstack.Net, dst string) {
	// protocolICMP is the IANA protocol number for ICMPv4, which
	// icmp.ParseMessage needs in order to interpret the reply.
	const protocolICMP = 1

	socket, err := tnet.Dial("ping4", dst)
	if err != nil {
		log.Panic(err)
	}
	defer socket.Close()

	requestPing := icmp.Echo{
		Seq:  rand.Intn(1 << 16),
		Data: []byte("gopher burrow"),
	}
	request := icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}
	icmpBytes, err := request.Marshal(nil)
	if err != nil {
		log.Panic(err)
	}

	if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
		log.Panic(err)
	}
	start := time.Now()
	if _, err := socket.Write(icmpBytes); err != nil {
		log.Panic(err)
	}

	// The request buffer is reused for the reply; a matching reply carries
	// the same payload, so it has the same length.
	n, err := socket.Read(icmpBytes)
	if err != nil {
		log.Panic(err)
	}

	replyPacket, err := icmp.ParseMessage(protocolICMP, icmpBytes[:n])
	if err != nil {
		log.Panic(err)
	}
	replyPing, ok := replyPacket.Body.(*icmp.Echo)
	if !ok {
		log.Panicf("invalid reply type: %v", replyPacket)
	}
	if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
		log.Panicf("invalid ping reply: %v", replyPing)
	}
	log.Printf("Ping latency: %v", time.Since(start))
}
// main parses the command-line flags, brings up a userspace WireGuard tunnel
// on a netstack TUN device, pings the server through the tunnel, starts the
// configured TCP/UDP proxies and then blocks until SIGINT or SIGTERM.
//
// Missing required flags and unparsable addresses are reported with
// log.Fatalf before any device is created. Failures after that point panic
// via log.Panic.
func main() {
	var (
		tunnelIP   string
		privateKey string
		publicKey  string
		endpoint   string
		tcpTargets string
		udpTargets string
		listenIP   string
		serverIP   string
		dns        string
	)
	flag.StringVar(&tunnelIP, "tunnel-ip", "", "Tunnel IP address")
	flag.StringVar(&privateKey, "private-key", "", "WireGuard private key")
	flag.StringVar(&publicKey, "public-key", "", "WireGuard public key")
	flag.StringVar(&endpoint, "endpoint", "", "WireGuard endpoint (host:port)")
	flag.StringVar(&tcpTargets, "tcp-targets", "", "Comma-separated list of TCP targets (host:port)")
	flag.StringVar(&udpTargets, "udp-targets", "", "Comma-separated list of UDP targets (host:port)")
	flag.StringVar(&listenIP, "listen-ip", "", "IP to listen for incoming connections")
	flag.StringVar(&serverIP, "server-ip", "", "IP to filter and ping on the server side. Inside tunnel...")
	flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
	flag.Parse()

	// These values are all needed to bring the tunnel up. Checking them here
	// replaces an opaque panic deep in address parsing or device setup with a
	// message that names the missing flag.
	required := []struct{ name, value string }{
		{"tunnel-ip", tunnelIP},
		{"private-key", privateKey},
		{"public-key", publicKey},
		{"endpoint", endpoint},
		{"server-ip", serverIP},
	}
	for _, r := range required {
		if r.value == "" {
			log.Fatalf("missing required flag -%s", r.name)
		}
	}

	tunnelAddr, err := netip.ParseAddr(tunnelIP)
	if err != nil {
		log.Fatalf("invalid -tunnel-ip %q: %v", tunnelIP, err)
	}
	dnsAddr, err := netip.ParseAddr(dns)
	if err != nil {
		log.Fatalf("invalid -dns %q: %v", dns, err)
	}

	// Create TUN device and network stack
	tun, tnet, err := netstack.CreateNetTUN(
		[]netip.Addr{tunnelAddr},
		[]netip.Addr{dnsAddr},
		1420)
	if err != nil {
		log.Panic(err)
	}

	// Create WireGuard device
	dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))

	// Configure WireGuard. Keys are supplied as base64 on the command line and
	// converted to hex by fixKey before being written into the configuration.
	config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5
`, fixKey(privateKey), fixKey(publicKey), serverIP, endpoint)
	if err := dev.IpcSet(config); err != nil {
		log.Panic(err)
	}

	// Bring up the device
	if err := dev.Up(); err != nil {
		log.Panic(err)
	}

	// Ping to bring the tunnel up on the server side quickly
	ping(tnet, serverIP)

	// Create proxy manager and register the configured targets
	pm := NewProxyManager(tnet)
	if tcpTargets != "" {
		pm.AddTarget("tcp", listenIP, strings.Split(tcpTargets, ","))
	}
	if udpTargets != "" {
		pm.AddTarget("udp", listenIP, strings.Split(udpTargets, ","))
	}

	// Start proxies
	if err := pm.Start(); err != nil {
		log.Panic(err)
	}

	// Wait for interrupt signal
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh

	// Cleanup
	dev.Close()
}