newt/main.go
2024-11-18 22:08:42 -05:00

357 lines
8.3 KiB
Go

package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"log"
"math/rand"
"net/netip"
"newt/proxy"
"newt/websocket"
"os"
"os/signal"
"strings"
"syscall"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WgData struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"`
}
type TargetsByType struct {
UDP []string `json:"udp"`
TCP []string `json:"tcp"`
}
type TargetData struct {
Targets []string `json:"targets"`
}
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
log.Fatal("Error decoding base64:", err)
}
// Convert to hex
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) {
socket, err := tnet.Dial("ping4", dst)
if err != nil {
log.Panic(err)
}
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
}
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
socket.SetReadDeadline(time.Now().Add(time.Second * 10))
start := time.Now()
_, err = socket.Write(icmpBytes)
if err != nil {
log.Panic(err)
}
n, err := socket.Read(icmpBytes[:])
if err != nil {
log.Panic(err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
log.Panic(err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
log.Panicf("invalid reply type: %v", replyPacket)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
log.Panicf("invalid ping reply: %v", replyPing)
}
log.Printf("Ping latency: %v", time.Since(start))
}
func main() {
var (
dns string
id string
secret string
privateKey wgtypes.Key
err error
)
flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use")
flag.StringVar(&id, "id", "", "Newt ID")
flag.StringVar(&secret, "secret", "", "Newt secret")
flag.Parse()
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
log.Fatalf("Failed to generate private key: %v", err)
}
// Create a new client
client, err := websocket.NewClient(
// the id and secret from the params
id,
secret,
websocket.WithBaseURL("http://localhost:3000/api/v1"),
)
if err != nil {
log.Fatal(err)
}
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
var dev *device.Device
var pm *proxy.ProxyManager
var connected bool
var wgData WgData
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
if connected {
log.Printf("Already connected! Put I will send a ping anyway...")
ping(tnet, wgData.ServerIP)
return
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
log.Printf("Error marshaling data: %v", err)
return
}
if err := json.Unmarshal(jsonData, &wgData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
return
}
log.Printf("Received: %+v", msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
[]netip.Addr{netip.MustParseAddr(dns)},
1420)
if err != nil {
log.Panic(err)
}
// Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
// Configure WireGuard
config := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.PublicKey), wgData.ServerIP, wgData.Endpoint)
err = dev.IpcSet(config)
if err != nil {
log.Panic(err)
}
// Bring up the device
err = dev.Up()
if err != nil {
log.Panic(err)
}
// Ping to bring the tunnel up on the server side quickly
ping(tnet, wgData.ServerIP)
// Create proxy manager
pm = proxy.NewProxyManager(tnet)
connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
// Connect to the WebSocket server
if err := client.Connect(); err != nil {
log.Fatal(err)
}
defer client.Close()
// TODO: how to retry?
err = client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", privateKey),
})
if err != nil {
log.Printf("Failed to send message: %v", err)
}
// Wait for interrupt signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
// Cleanup
dev.Close()
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
log.Printf("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
// Stop the proxy manager before adding new targets
err := pm.Stop()
if err != nil {
log.Panic(err)
}
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 2 {
log.Printf("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
log.Printf("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1]
pm.AddTarget(proto, tunnelIP, port, target)
} else if action == "remove" {
pm.RemoveTarget(proto, tunnelIP, port)
}
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
return nil
}