diff --git a/go.mod b/go.mod index 2e3b644..cc28d40 100644 --- a/go.mod +++ b/go.mod @@ -14,5 +14,6 @@ require ( golang.org/x/sys v0.26.0 // indirect golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) diff --git a/go.sum b/go.sum index 682c547..571b2d6 100644 --- a/go.sum +++ b/go.sum @@ -14,5 +14,7 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/main.go b/main.go index 3ec52ca..95d9585 100644 --- a/main.go +++ b/main.go @@ -28,10 +28,20 @@ import ( ) type WgData struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - TunnelIP string `json:"tunnelIP"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + TunnelIP string `json:"tunnelIP"` + Targets TargetsByType `json:"targets"` +} + +type TargetsByType struct { + UDP []string `json:"udp"` + TCP []string `json:"tcp"` +} + +type TargetData struct { + Targets []string `json:"targets"` } func fixKey(key string) string { @@ -177,6 +187,15 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P pm = proxy.NewProxyManager(tnet) connected = true + + // add the targets if there are any + if len(wgData.Targets.TCP) > 0 { + updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP}) + } + + if len(wgData.Targets.UDP) > 0 { + updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) + } }) client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { @@ -188,55 +207,14 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P return } - type TargetData struct { - Targets []string `json:"targets"` - } - // Define a struct for the expected data structure - jsonData, err := json.Marshal(msg.Data) + targetData, err := parseTargetData(msg.Data) if err != nil { - log.Printf("Error marshaling data: %v", err) - return - } - - // Parse into our target structure - var targetData TargetData - if err := json.Unmarshal(jsonData, &targetData); err != nil { - log.Printf("Error unmarshaling target data: %v", err) + log.Printf("Error parsing target data: %v", err) return } if len(targetData.Targets) > 0 { - - // Stop the proxy manager before adding new targets - err = pm.Stop() - if err != nil { - log.Panic(err) - } - - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 2 { - log.Printf("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - log.Printf("Invalid port: %s", parts[0]) - continue - } - - target := parts[1] - pm.AddTarget("tcp", wgData.TunnelIP, port, target) - } - - err = pm.Start() - if err != nil { - log.Panic(err) - } + updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData) } }) @@ -249,51 +227,54 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P return } - type TargetData struct { - Targets []string `json:"targets"` - } - jsonData, err := json.Marshal(msg.Data) + targetData, err := parseTargetData(msg.Data) if err != nil { - log.Printf("Error marshaling data: %v", err) - return - } - - var targetData TargetData - if err := json.Unmarshal(jsonData, &targetData); err != nil { - log.Printf("Error unmarshaling target data: %v", err) + log.Printf("Error parsing target data: %v", err) return } if len(targetData.Targets) > 0 { - err = pm.Stop() - if err != nil { - log.Panic(err) - } + updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData) + } + }) - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 2 { - log.Printf("Invalid target format: %s", t) - continue - } + client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) { + log.Printf("Received: %+v", msg) - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - log.Printf("Invalid port: %s", parts[0]) - continue - } + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + log.Printf("No tunnel IP or proxy manager available") + return + } - target := parts[1] - pm.AddTarget("udp", wgData.TunnelIP, port, target) - } + targetData, err := parseTargetData(msg.Data) + if err != nil { + log.Printf("Error parsing target data: %v", err) + return + } - err = pm.Start() - if err != nil { - log.Panic(err) - } + if len(targetData.Targets) > 0 { + updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData) + } + }) + + client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) { + log.Printf("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + log.Printf("No tunnel IP or proxy manager available") + return + } + + targetData, err := parseTargetData(msg.Data) + if err != nil { + log.Printf("Error parsing target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData) } }) @@ -303,10 +284,9 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P } defer client.Close() - // TODO: we need to send the public key to the server to trigger it to respond to create the tunnel // TODO: how to retry? err = client.SendMessage("newt/wg/register", map[string]interface{}{ - "content": "Hello, World!", + "publicKey": fmt.Sprintf("%s", privateKey), }) if err != nil { log.Printf("Failed to send message: %v", err) @@ -320,3 +300,58 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P // Cleanup dev.Close() } + +func parseTargetData(data interface{}) (TargetData, error) { + var targetData TargetData + jsonData, err := json.Marshal(data) + if err != nil { + log.Printf("Error marshaling data: %v", err) + return targetData, err + } + + if err := json.Unmarshal(jsonData, &targetData); err != nil { + log.Printf("Error unmarshaling target data: %v", err) + return targetData, err + } + return targetData, nil +} + +func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { + + // Stop the proxy manager before adding new targets + err := pm.Stop() + if err != nil { + log.Panic(err) + } + + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 2 { + log.Printf("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + log.Printf("Invalid port: %s", parts[0]) + continue + } + + if action == "add" { + target := parts[1] + pm.AddTarget(proto, tunnelIP, port, target) + } else if action == "remove" { + pm.RemoveTarget(proto, tunnelIP, port) + } + } + + err = pm.Start() + if err != nil { + log.Panic(err) + } + + return nil +} diff --git a/proxy/manager.go b/proxy/manager.go index 45d667a..e54c312 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -32,22 +32,33 @@ func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target stri pm.targets = append(pm.targets, newTarget) } -func (pm *ProxyManager) RemoveTarget(listen string, port int) error { +func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error { pm.Lock() defer pm.Unlock() + protocol = strings.ToLower(protocol) + if protocol != "tcp" && protocol != "udp" { + return fmt.Errorf("unsupported protocol: %s", protocol) + } + for i, target := range pm.targets { - if target.Listen == listen && target.Port == port { + if target.Listen == listen && + target.Port == port && + strings.ToLower(target.Protocol) == protocol { // Signal the serving goroutine to stop close(target.cancel) - // Close the listener/connection + // Close the appropriate listener/connection based on protocol target.Lock() - if target.listener != nil { - target.listener.Close() - } - if target.udpConn != nil { - target.udpConn.Close() + switch protocol { + case "tcp": + if target.listener != nil { + target.listener.Close() + } + case "udp": + if target.udpConn != nil { + target.udpConn.Close() + } } target.Unlock() @@ -57,7 +68,7 @@ func (pm *ProxyManager) RemoveTarget(listen string, port int) error { } } - return fmt.Errorf("target not found for %s:%d", listen, port) + return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port) } func (pm *ProxyManager) Start() error { diff --git a/websocket/config.go b/websocket/config.go index b47e7be..efc6e22 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -2,7 +2,6 @@ package websocket import ( "encoding/json" - "io/ioutil" "log" "os" "path/filepath" @@ -29,7 +28,7 @@ func getConfigPath() string { func (c *Client) loadConfig() error { configPath := getConfigPath() - data, err := ioutil.ReadFile(configPath) + data, err := os.ReadFile(configPath) if err != nil { if os.IsNotExist(err) { return nil @@ -53,5 +52,5 @@ func (c *Client) saveConfig() error { if err != nil { return err } - return ioutil.WriteFile(configPath, data, 0644) + return os.WriteFile(configPath, data, 0644) }