Polish; add remove

This commit is contained in:
Owen Schwartz 2024-11-18 22:08:42 -05:00
parent 2e5531b4a5
commit 055d50d1d3
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
5 changed files with 146 additions and 98 deletions

1
go.mod
View file

@ -14,5 +14,6 @@ require (
golang.org/x/sys v0.26.0 // indirect golang.org/x/sys v0.26.0 // indirect
golang.org/x/time v0.7.0 // indirect golang.org/x/time v0.7.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
) )

2
go.sum
View file

@ -14,5 +14,7 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=

191
main.go
View file

@ -32,6 +32,16 @@ type WgData struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"` ServerIP string `json:"serverIP"`
TunnelIP string `json:"tunnelIP"` TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"`
}
type TargetsByType struct {
UDP []string `json:"udp"`
TCP []string `json:"tcp"`
}
type TargetData struct {
Targets []string `json:"targets"`
} }
func fixKey(key string) string { func fixKey(key string) string {
@ -177,6 +187,15 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
pm = proxy.NewProxyManager(tnet) pm = proxy.NewProxyManager(tnet)
connected = true connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
}) })
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
@ -188,55 +207,14 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
return return
} }
type TargetData struct { targetData, err := parseTargetData(msg.Data)
Targets []string `json:"targets"`
}
// Define a struct for the expected data structure
jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
log.Printf("Error marshaling data: %v", err) log.Printf("Error parsing target data: %v", err)
return
}
// Parse into our target structure
var targetData TargetData
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
return return
} }
if len(targetData.Targets) > 0 { if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
// Stop the proxy manager before adding new targets
err = pm.Stop()
if err != nil {
log.Panic(err)
}
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 2 {
log.Printf("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
log.Printf("Invalid port: %s", parts[0])
continue
}
target := parts[1]
pm.AddTarget("tcp", wgData.TunnelIP, port, target)
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
} }
}) })
@ -249,51 +227,54 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
return return
} }
type TargetData struct { targetData, err := parseTargetData(msg.Data)
Targets []string `json:"targets"`
}
jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
log.Printf("Error marshaling data: %v", err) log.Printf("Error parsing target data: %v", err)
return
}
var targetData TargetData
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
return return
} }
if len(targetData.Targets) > 0 { if len(targetData.Targets) > 0 {
err = pm.Stop() updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil { if err != nil {
log.Panic(err) log.Printf("Error parsing target data: %v", err)
return
} }
for _, t := range targetData.Targets { if len(targetData.Targets) > 0 {
// Split the first number off of the target with : separator and use as the port updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
parts := strings.Split(t, ":") }
if len(parts) != 2 { })
log.Printf("Invalid target format: %s", t)
continue client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
} }
// Get the port as an int targetData, err := parseTargetData(msg.Data)
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil { if err != nil {
log.Printf("Invalid port: %s", parts[0]) log.Printf("Error parsing target data: %v", err)
continue return
} }
target := parts[1] if len(targetData.Targets) > 0 {
pm.AddTarget("udp", wgData.TunnelIP, port, target) updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
} }
}) })
@ -303,10 +284,9 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
} }
defer client.Close() defer client.Close()
// TODO: we need to send the public key to the server to trigger it to respond to create the tunnel
// TODO: how to retry? // TODO: how to retry?
err = client.SendMessage("newt/wg/register", map[string]interface{}{ err = client.SendMessage("newt/wg/register", map[string]interface{}{
"content": "Hello, World!", "publicKey": fmt.Sprintf("%s", privateKey),
}) })
if err != nil { if err != nil {
log.Printf("Failed to send message: %v", err) log.Printf("Failed to send message: %v", err)
@ -320,3 +300,58 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
// Cleanup // Cleanup
dev.Close() dev.Close()
} }
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
log.Printf("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
// Stop the proxy manager before adding new targets
err := pm.Stop()
if err != nil {
log.Panic(err)
}
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 2 {
log.Printf("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
log.Printf("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1]
pm.AddTarget(proto, tunnelIP, port, target)
} else if action == "remove" {
pm.RemoveTarget(proto, tunnelIP, port)
}
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
return nil
}

View file

@ -32,23 +32,34 @@ func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target stri
pm.targets = append(pm.targets, newTarget) pm.targets = append(pm.targets, newTarget)
} }
func (pm *ProxyManager) RemoveTarget(listen string, port int) error { func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
protocol = strings.ToLower(protocol)
if protocol != "tcp" && protocol != "udp" {
return fmt.Errorf("unsupported protocol: %s", protocol)
}
for i, target := range pm.targets { for i, target := range pm.targets {
if target.Listen == listen && target.Port == port { if target.Listen == listen &&
target.Port == port &&
strings.ToLower(target.Protocol) == protocol {
// Signal the serving goroutine to stop // Signal the serving goroutine to stop
close(target.cancel) close(target.cancel)
// Close the listener/connection // Close the appropriate listener/connection based on protocol
target.Lock() target.Lock()
switch protocol {
case "tcp":
if target.listener != nil { if target.listener != nil {
target.listener.Close() target.listener.Close()
} }
case "udp":
if target.udpConn != nil { if target.udpConn != nil {
target.udpConn.Close() target.udpConn.Close()
} }
}
target.Unlock() target.Unlock()
// Remove the target from the slice // Remove the target from the slice
@ -57,7 +68,7 @@ func (pm *ProxyManager) RemoveTarget(listen string, port int) error {
} }
} }
return fmt.Errorf("target not found for %s:%d", listen, port) return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
} }
func (pm *ProxyManager) Start() error { func (pm *ProxyManager) Start() error {

View file

@ -2,7 +2,6 @@ package websocket
import ( import (
"encoding/json" "encoding/json"
"io/ioutil"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@ -29,7 +28,7 @@ func getConfigPath() string {
func (c *Client) loadConfig() error { func (c *Client) loadConfig() error {
configPath := getConfigPath() configPath := getConfigPath()
data, err := ioutil.ReadFile(configPath) data, err := os.ReadFile(configPath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
@ -53,5 +52,5 @@ func (c *Client) saveConfig() error {
if err != nil { if err != nil {
return err return err
} }
return ioutil.WriteFile(configPath, data, 0644) return os.WriteFile(configPath, data, 0644)
} }