diff --git a/.gitignore b/.gitignore index 8b1c477..100fc81 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ newt .DS_Store -bin/ \ No newline at end of file +bin/ +nohup.out \ No newline at end of file diff --git a/go.mod b/go.mod index 33c593f..1912c9c 100644 --- a/go.mod +++ b/go.mod @@ -5,18 +5,27 @@ go 1.23.1 toolchain go1.23.2 require ( + github.com/google/gopacket v1.1.19 github.com/gorilla/websocket v1.5.3 - golang.org/x/net v0.30.0 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa + golang.org/x/net v0.33.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.6.0 // indirect - golang.org/x/crypto v0.28.0 // indirect - golang.org/x/sys v0.26.0 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) diff --git a/go.sum b/go.sum index 2328634..386e554 100644 --- a/go.sum +++ b/go.sum @@ -2,21 +2,55 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/main.go b/main.go index 9a08b8c..afd6412 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "os" "os/exec" "os/signal" + "runtime" "strconv" "strings" "syscall" @@ -21,6 +22,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" + "github.com/fosrl/newt/wg" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" @@ -55,7 +57,7 @@ func fixKey(key string) string { // Decode from base64 decoded, err := base64.StdEncoding.DecodeString(key) if err != nil { - logger.Fatal("Error decoding base64:", err) + logger.Fatal("Error decoding base64") } // Convert to hex @@ -213,6 +215,9 @@ func resolveDomain(domain string) (string, error) { host = strings.TrimPrefix(host, "https://") } + // if there are any trailing slashes, remove them + host = strings.TrimSuffix(host, "/") + // Lookup IP addresses ips, err := net.LookupIP(host) if err != nil { @@ -246,16 +251,18 @@ func resolveDomain(domain string) (string, error) { } var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - updownScript string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + updownScript string + interfaceName string + generateAndSaveKeyTo string ) func main() { @@ -267,6 +274,8 @@ func main() { dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") updownScript = os.Getenv("UPDOWN_SCRIPT") + interfaceName = os.Getenv("INTERFACE") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -289,6 +298,12 @@ func main() { if updownScript == "" { flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed") } + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -325,6 +340,7 @@ func main() { logger.Fatal("Failed to create client: %v", err) } + var wgService *wg.WireGuardService // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net @@ -333,6 +349,30 @@ func main() { var connected bool var wgData WgData + if generateAndSaveKeyTo != "" { + // make sure we are running on linux + if runtime.GOOS != "linux" { + logger.Fatal("Tunnel management is only supported on Linux right now!") + os.Exit(1) + } + + var host = endpoint + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + + host = strings.TrimSuffix(host, "/") + + // Create WireGuard service + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + defer wgService.Close() + } + client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") if pm != nil { @@ -420,6 +460,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( if err != nil { // Handle complete failure after all retries logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + fmt.Sprintf("%s", privateKey) } if !connected { @@ -441,6 +482,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) } + // first make sure the wpgService has a port + if wgService != nil { + // add a udp proxy for localost and the wgService port + // TODO: make sure this port is not used in a target + pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("localhost:%d", wgService.Port)) + } + err = pm.Start() if err != nil { logger.Error("Failed to start proxy manager: %v", err) @@ -539,6 +587,10 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( return err } + if wgService != nil { + wgService.LoadRemoteConfig() + } + logger.Info("Sent registration message") return nil }) diff --git a/network/network.go b/network/network.go new file mode 100644 index 0000000..0703e8b --- /dev/null +++ b/network/network.go @@ -0,0 +1,202 @@ +package network + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "log" + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/vishvananda/netlink" + "golang.org/x/net/bpf" + "golang.org/x/net/ipv4" +) + +const ( + udpProtocol = 17 + // EmptyUDPSize is the size of an empty UDP packet + EmptyUDPSize = 28 + timeout = time.Second * 10 +) + +// Server stores data relating to the server +type Server struct { + Hostname string + Addr *net.IPAddr + Port uint16 +} + +// PeerNet stores data about a peer's endpoint +type PeerNet struct { + Resolved bool + IP net.IP + Port uint16 + NewtID string +} + +// GetClientIP gets source ip address that will be used when sending data to dstIP +func GetClientIP(dstIP net.IP) net.IP { + routes, err := netlink.RouteGet(dstIP) + if err != nil { + log.Fatalln("Error getting route:", err) + } + return routes[0].Src +} + +// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr +func HostToAddr(hostStr string) *net.IPAddr { + remoteAddrs, err := net.LookupHost(hostStr) + if err != nil { + log.Fatalln("Error parsing remote address:", err) + } + + for _, addrStr := range remoteAddrs { + if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { + return remoteAddr + } + } + return nil +} + +// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering +func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn { + packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) + if err != nil { + log.Fatalln("Error creating packetConn:", err) + } + + rawConn, err := ipv4.NewRawConn(packetConn) + if err != nil { + log.Fatalln("Error creating rawConn:", err) + } + + ApplyBPF(rawConn, server, client) + + return rawConn +} + +// ApplyBPF constructs a BPF program and applies it to the RawConn +func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) { + const ipv4HeaderLen = 20 + const srcIPOffset = 12 + const srcPortOffset = ipv4HeaderLen + 0 + const dstPortOffset = ipv4HeaderLen + 2 + + ipArr := []byte(server.Addr.IP.To4()) + ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) + + bpfRaw, err := bpf.Assemble([]bpf.Instruction{ + bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, + + bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, + + bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, + + bpf.RetConstant{Val: 1<<(8*4) - 1}, + bpf.RetConstant{Val: 0}, + }) + + if err != nil { + log.Fatalln("Error assembling BPF:", err) + } + + err = rawConn.SetBPF(bpfRaw) + if err != nil { + log.Fatalln("Error setting BPF:", err) + } +} + +// MakePacket constructs a request packet to send to the server +func MakePacket(payload []byte, server *Server, client *PeerNet) []byte { + buf := gopacket.NewSerializeBuffer() + + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + ipHeader := layers.IPv4{ + SrcIP: client.IP, + DstIP: server.Addr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + + udpHeader := layers.UDP{ + SrcPort: layers.UDPPort(client.Port), + DstPort: layers.UDPPort(server.Port), + } + + payloadLayer := gopacket.Payload(payload) + + udpHeader.SetNetworkLayerForChecksum(&ipHeader) + + gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) + + return buf.Bytes() +} + +// SendPacket sends packet to the Server +func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error { + fullPacket := MakePacket(packet, server, client) + _, err := conn.WriteToIP(fullPacket, server.Addr) + return err +} + +// SendDataPacket sends a JSON payload to the Server +func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + return SendPacket(jsonData, conn, server, client) +} + +// RecvPacket receives a UDP packet from server +func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) { + err := conn.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + return nil, 0, err + } + + response := make([]byte, 4096) + n, err := conn.Read(response) + if err != nil { + return nil, n, err + } + return response, n, nil +} + +// RecvDataPacket receives and unmarshals a JSON packet from server +func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) { + response, n, err := RecvPacket(conn, server, client) + if err != nil { + return nil, err + } + + // Extract payload from UDP packet + payload := response[EmptyUDPSize:n] + return payload, nil +} + +// ParseResponse takes a response packet and parses it into an IP and port +func ParseResponse(response []byte) (net.IP, uint16) { + ip := net.IP(response[:4]) + port := binary.BigEndian.Uint16(response[4:6]) + return ip, port +} + +func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) { + srcIP = net.IP(response[12:16]) + srcPort = binary.BigEndian.Uint16(response[20:22]) + dstPort = binary.BigEndian.Uint16(response[22:24]) + return +} diff --git a/websocket/client.go b/websocket/client.go index 022a489..2706eee 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -292,6 +292,7 @@ func (c *Client) establishConnection() error { // Add token to query parameters q := u.Query() q.Set("token", token) + q.Set("clientType", "newt") u.RawQuery = q.Encode() // Connect to WebSocket diff --git a/wg/wg.go b/wg/wg.go new file mode 100644 index 0000000..bcb7cda --- /dev/null +++ b/wg/wg.go @@ -0,0 +1,689 @@ +package wg + +import ( + "encoding/json" + "fmt" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "github.com/fosrl/newt/websocket" + "github.com/vishvananda/netlink" + "golang.org/x/exp/rand" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type WgConfig struct { + ListenPort int `json:"listenPort"` + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` + Endpoint string `json:"endpoint"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +type PeerReading struct { + BytesReceived int64 + BytesTransmitted int64 + LastChecked time.Time +} + +type WireGuardService struct { + interfaceName string + mtu int + client *websocket.Client + wgClient *wgctrl.Client + config WgConfig + key wgtypes.Key + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + stopHolepunch chan struct{} + host string +} + +// Add this type definition +type fixedPortBind struct { + port uint16 + conn.Bind +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + // Create a slice of all ports in the range + portRange := make([]uint16, maxPort-minPort+1) + for i := range portRange { + portRange[i] = minPort + uint16(i) + } + + // Fisher-Yates shuffle to randomize the port order + rand.Seed(uint64(time.Now().UnixNano())) + for i := len(portRange) - 1; i > 0; i-- { + j := rand.Intn(i + 1) + portRange[i], portRange[j] = portRange[j], portRange[i] + } + + // Try each port in the randomized order + for _, port := range portRange { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + continue // Port is in use or there was an error, try next port + } + _ = conn.SetDeadline(time.Now()) + conn.Close() + return port, nil + } + + return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) +} + +func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { + wgClient, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("failed to create WireGuard client: %v", err) + } + + var key wgtypes.Key + // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + // generate a new private key + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Fatal("Failed to generate private key: %v", err) + } + // save the key to the file + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + logger.Fatal("Failed to save private key: %v", err) + } + } else { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + logger.Fatal("Failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(string(keyData)) + if err != nil { + logger.Fatal("Failed to parse private key: %v", err) + } + } + + port, err := FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + return nil, err + } + + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + wgClient: wgClient, + key: key, + newtId: newtId, + lastReadings: make(map[string]PeerReading), + Port: port, + stopHolepunch: make(chan struct{}), + host: host, + } + + // Register websocket handlers + wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) + wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) + + return service, nil +} + +func (s *WireGuardService) Close() { + s.wgClient.Close() +} + +func (s *WireGuardService) LoadRemoteConfig() error { + + err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Requesting WireGuard configuration from remote server") + + go s.periodicBandwidthCheck() + + return nil +} + +func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { + var config WgConfig + + logger.Info("Received message: %v", msg) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + s.config = config + + // Ensure the WireGuard interface and peers are configured + if err := s.ensureWireguardInterface(config); err != nil { + logger.Error("Failed to ensure WireGuard interface: %v", err) + } + + if err := s.ensureWireguardPeers(config.Peers); err != nil { + logger.Error("Failed to ensure WireGuard peers: %v", err) + } + + if err := s.sendUDPHolePunch(s.host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + + // start the UDP holepunch + go s.keepSendingUDPHolePunch(s.host) +} + +func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { + // Check if the WireGuard interface exists + _, err := netlink.LinkByName(s.interfaceName) + if err != nil { + if _, ok := err.(netlink.LinkNotFoundError); ok { + // Interface doesn't exist, so create it + err = s.createWireGuardInterface() + if err != nil { + logger.Fatal("Failed to create WireGuard interface: %v", err) + } + logger.Info("Created WireGuard interface %s\n", s.interfaceName) + } else { + logger.Fatal("Error checking for WireGuard interface: %v", err) + } + } else { + logger.Info("WireGuard interface %s already exists\n", s.interfaceName) + + // get the exising wireguard port + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the existing port + s.Port = uint16(device.ListenPort) + logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port) + + return nil + } + + logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) + // Assign IP address to the interface + err = s.assignIPAddress(wgconfig.IpAddress) + if err != nil { + logger.Fatal("Failed to assign IP address: %v", err) + } + + // Check if the interface already exists + _, err = s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("interface %s does not exist", s.interfaceName) + } + + // Parse the private key + key, err := wgtypes.ParseKey(s.key.String()) + if err != nil { + return fmt.Errorf("failed to parse private key: %v", err) + } + + config := wgtypes.Config{ + PrivateKey: &key, + ListenPort: new(int), + } + + // Use the service's fixed port instead of the config port + *config.ListenPort = int(s.Port) + + // Create and configure the WireGuard interface + err = s.wgClient.ConfigureDevice(s.interfaceName, config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // bring up the interface + link, err := netlink.LinkByName(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + if err := netlink.LinkSetMTU(link, s.mtu); err != nil { + return fmt.Errorf("failed to set MTU: %v", err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + // if err := s.ensureMSSClamping(); err != nil { + // logger.Warn("Failed to ensure MSS clamping: %v", err) + // } + + logger.Info("WireGuard interface %s created and configured", s.interfaceName) + + return nil +} + +func (s *WireGuardService) createWireGuardInterface() error { + wgLink := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName}, + LinkType: "wireguard", + } + return netlink.LinkAdd(wgLink) +} + +func (s *WireGuardService) assignIPAddress(ipAddress string) error { + link, err := netlink.LinkByName(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + addr, err := netlink.ParseAddr(ipAddress) + if err != nil { + return fmt.Errorf("failed to parse IP address: %v", err) + } + + return netlink.AddrAdd(link, addr) +} + +func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { + // get the current peers + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the peer public keys + var currentPeers []string + for _, peer := range device.Peers { + currentPeers = append(currentPeers, peer.PublicKey.String()) + } + + // remove any peers that are not in the config + for _, peer := range currentPeers { + found := false + for _, configPeer := range peers { + if peer == configPeer.PublicKey { + found = true + break + } + } + if !found { + err := s.removePeer(peer) + if err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + } + } + + // add any peers that are in the config but not in the current peers + for _, configPeer := range peers { + found := false + for _, peer := range currentPeers { + if configPeer.PublicKey == peer { + found = true + break + } + } + if !found { + err := s.addPeer(configPeer) + if err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + } + + return nil +} + +func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { + var peer Peer + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + if err := json.Unmarshal(jsonData, &peer); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + } + + err = s.addPeer(peer) + if err != nil { + logger.Info("Error adding peer: %v", err) + return + } +} + +func (s *WireGuardService) addPeer(peer Peer) error { + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // parse allowed IPs into array of net.IPNet + var allowedIPs []net.IPNet + for _, ipStr := range peer.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + return fmt.Errorf("failed to parse allowed IP: %v", err) + } + allowedIPs = append(allowedIPs, *ipNet) + } + // add keep alive using *time.Duration of 1 second + keepalive := time.Second + + var peerConfig wgtypes.PeerConfig + if peer.Endpoint != "" { + endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint address: %w", err) + } + + // make the endpoint localhost to test + + peerConfig = wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + Endpoint: endpoint, + } + } else { + peerConfig = wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + } + logger.Info("Added peer with no endpoint!") + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + + logger.Info("Peer %s added successfully", peer.PublicKey) + + return nil +} + +func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { + // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } + type RemoveRequest struct { + PublicKey string `json:"publicKey"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + var request RemoveRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling data: %v", err) + return + } + + if err := s.removePeer(request.PublicKey); err != nil { + logger.Info("Error removing peer: %v", err) + return + } +} + +func (s *WireGuardService) removePeer(publicKey string) error { + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + logger.Info("Peer %s removed successfully", publicKey) + + return nil +} + +func (s *WireGuardService) periodicBandwidthCheck() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := s.reportPeerBandwidth(); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + } +} + +func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get device: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + for _, peer := range device.Peers { + publicKey := peer.PublicKey.String() + currentReading := PeerReading{ + BytesReceived: peer.ReceiveBytes, + BytesTransmitted: peer.TransmitBytes, + LastChecked: now, + } + + var bytesInDiff, bytesOutDiff float64 + lastReading, exists := s.lastReadings[publicKey] + + if exists { + timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() + if timeDiff > 0 { + // Calculate bytes transferred since last reading + bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + + // Handle counter wraparound (if the counter resets or overflows) + if bytesInDiff < 0 { + bytesInDiff = float64(currentReading.BytesReceived) + } + if bytesOutDiff < 0 { + bytesOutDiff = float64(currentReading.BytesTransmitted) + } + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + }) + } else { + // If readings are too close together or time hasn't passed, report 0 + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + } else { + // For first reading of a peer, report 0 to establish baseline + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + + // Update the last reading + s.lastReadings[publicKey] = currentReading + } + + // Clean up old peers + for publicKey := range s.lastReadings { + found := false + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + found = true + break + } + } + if !found { + delete(s.lastReadings, publicKey) + } + } + + return peerBandwidths, nil +} + +func (s *WireGuardService) reportPeerBandwidth() error { + bandwidths, err := s.calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ + "bandwidthData": bandwidths, + }) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + + return nil +} + +func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { + // Parse server address + serverSplit := strings.Split(serverAddr, ":") + if len(serverSplit) < 2 { + return fmt.Errorf("invalid server address format, expected hostname:port") + } + + serverHostname := serverSplit[0] + serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) + if err != nil { + return fmt.Errorf("failed to parse server port: %v", err) + } + + // Resolve server hostname to IP + serverIPAddr := network.HostToAddr(serverHostname) + if serverIPAddr == nil { + return fmt.Errorf("failed to resolve server hostname") + } + + // Get client IP based on route to server + clientIP := network.GetClientIP(serverIPAddr.IP) + + // Create server and client configs + server := &network.Server{ + Hostname: serverHostname, + Addr: serverIPAddr, + Port: uint16(serverPort), + } + + client := &network.PeerNet{ + IP: clientIP, + Port: s.Port, + NewtID: s.newtId, + } + + // Setup raw connection with BPF filtering + rawConn := network.SetupRawConn(server, client) + defer rawConn.Close() + + // Create JSON payload + payload := struct { + NewtID string `json:"newtId"` + }{ + NewtID: s.newtId, + } + + // Send the packet using the raw connection + err = network.SendDataPacket(payload, rawConn, server, client) + if err != nil { + return fmt.Errorf("failed to send UDP packet: %v", err) + } + + return nil +} + +func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-ticker.C: + if err := s.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +}