mirror of
https://github.com/fosrl/newt.git
synced 2025-05-13 05:30:39 +01:00
Polish; add remove
This commit is contained in:
parent
2e5531b4a5
commit
055d50d1d3
5 changed files with 146 additions and 98 deletions
1
go.mod
1
go.mod
|
@ -14,5 +14,6 @@ require (
|
||||||
golang.org/x/sys v0.26.0 // indirect
|
golang.org/x/sys v0.26.0 // indirect
|
||||||
golang.org/x/time v0.7.0 // indirect
|
golang.org/x/time v0.7.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||||
)
|
)
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -14,5 +14,7 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
|
||||||
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||||
|
|
207
main.go
207
main.go
|
@ -28,10 +28,20 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgData struct {
|
type WgData struct {
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
ServerIP string `json:"serverIP"`
|
ServerIP string `json:"serverIP"`
|
||||||
TunnelIP string `json:"tunnelIP"`
|
TunnelIP string `json:"tunnelIP"`
|
||||||
|
Targets TargetsByType `json:"targets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetsByType struct {
|
||||||
|
UDP []string `json:"udp"`
|
||||||
|
TCP []string `json:"tcp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetData struct {
|
||||||
|
Targets []string `json:"targets"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func fixKey(key string) string {
|
func fixKey(key string) string {
|
||||||
|
@ -177,6 +187,15 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
|
||||||
pm = proxy.NewProxyManager(tnet)
|
pm = proxy.NewProxyManager(tnet)
|
||||||
|
|
||||||
connected = true
|
connected = true
|
||||||
|
|
||||||
|
// add the targets if there are any
|
||||||
|
if len(wgData.Targets.TCP) > 0 {
|
||||||
|
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(wgData.Targets.UDP) > 0 {
|
||||||
|
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
||||||
|
@ -188,55 +207,14 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type TargetData struct {
|
targetData, err := parseTargetData(msg.Data)
|
||||||
Targets []string `json:"targets"`
|
|
||||||
}
|
|
||||||
// Define a struct for the expected data structure
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error marshaling data: %v", err)
|
log.Printf("Error parsing target data: %v", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse into our target structure
|
|
||||||
var targetData TargetData
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
log.Printf("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
|
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
||||||
// Stop the proxy manager before adding new targets
|
|
||||||
err = pm.Stop()
|
|
||||||
if err != nil {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range targetData.Targets {
|
|
||||||
// Split the first number off of the target with : separator and use as the port
|
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
log.Printf("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
|
||||||
port := 0
|
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Invalid port: %s", parts[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
target := parts[1]
|
|
||||||
pm.AddTarget("tcp", wgData.TunnelIP, port, target)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -249,51 +227,54 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type TargetData struct {
|
targetData, err := parseTargetData(msg.Data)
|
||||||
Targets []string `json:"targets"`
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error marshaling data: %v", err)
|
log.Printf("Error parsing target data: %v", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var targetData TargetData
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
log.Printf("Error unmarshaling target data: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
err = pm.Stop()
|
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
||||||
if err != nil {
|
}
|
||||||
log.Panic(err)
|
})
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range targetData.Targets {
|
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||||
// Split the first number off of the target with : separator and use as the port
|
log.Printf("Received: %+v", msg)
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
log.Printf("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
// if there is no wgData or pm, we can't add targets
|
||||||
port := 0
|
if wgData.TunnelIP == "" || pm == nil {
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
log.Printf("No tunnel IP or proxy manager available")
|
||||||
if err != nil {
|
return
|
||||||
log.Printf("Invalid port: %s", parts[0])
|
}
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
target := parts[1]
|
targetData, err := parseTargetData(msg.Data)
|
||||||
pm.AddTarget("udp", wgData.TunnelIP, port, target)
|
if err != nil {
|
||||||
}
|
log.Printf("Error parsing target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err = pm.Start()
|
if len(targetData.Targets) > 0 {
|
||||||
if err != nil {
|
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
||||||
log.Panic(err)
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
|
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
||||||
|
log.Printf("Received: %+v", msg)
|
||||||
|
|
||||||
|
// if there is no wgData or pm, we can't add targets
|
||||||
|
if wgData.TunnelIP == "" || pm == nil {
|
||||||
|
log.Printf("No tunnel IP or proxy manager available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetData, err := parseTargetData(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error parsing target data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(targetData.Targets) > 0 {
|
||||||
|
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -303,10 +284,9 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
// TODO: we need to send the public key to the server to trigger it to respond to create the tunnel
|
|
||||||
// TODO: how to retry?
|
// TODO: how to retry?
|
||||||
err = client.SendMessage("newt/wg/register", map[string]interface{}{
|
err = client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
"content": "Hello, World!",
|
"publicKey": fmt.Sprintf("%s", privateKey),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to send message: %v", err)
|
log.Printf("Failed to send message: %v", err)
|
||||||
|
@ -320,3 +300,58 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
|
||||||
// Cleanup
|
// Cleanup
|
||||||
dev.Close()
|
dev.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseTargetData(data interface{}) (TargetData, error) {
|
||||||
|
var targetData TargetData
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error marshaling data: %v", err)
|
||||||
|
return targetData, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
||||||
|
log.Printf("Error unmarshaling target data: %v", err)
|
||||||
|
return targetData, err
|
||||||
|
}
|
||||||
|
return targetData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||||
|
|
||||||
|
// Stop the proxy manager before adding new targets
|
||||||
|
err := pm.Stop()
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range targetData.Targets {
|
||||||
|
// Split the first number off of the target with : separator and use as the port
|
||||||
|
parts := strings.Split(t, ":")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
log.Printf("Invalid target format: %s", t)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the port as an int
|
||||||
|
port := 0
|
||||||
|
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Invalid port: %s", parts[0])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if action == "add" {
|
||||||
|
target := parts[1]
|
||||||
|
pm.AddTarget(proto, tunnelIP, port, target)
|
||||||
|
} else if action == "remove" {
|
||||||
|
pm.RemoveTarget(proto, tunnelIP, port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pm.Start()
|
||||||
|
if err != nil {
|
||||||
|
log.Panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -32,22 +32,33 @@ func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target stri
|
||||||
pm.targets = append(pm.targets, newTarget)
|
pm.targets = append(pm.targets, newTarget)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) RemoveTarget(listen string, port int) error {
|
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
protocol = strings.ToLower(protocol)
|
||||||
|
if protocol != "tcp" && protocol != "udp" {
|
||||||
|
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
|
||||||
for i, target := range pm.targets {
|
for i, target := range pm.targets {
|
||||||
if target.Listen == listen && target.Port == port {
|
if target.Listen == listen &&
|
||||||
|
target.Port == port &&
|
||||||
|
strings.ToLower(target.Protocol) == protocol {
|
||||||
// Signal the serving goroutine to stop
|
// Signal the serving goroutine to stop
|
||||||
close(target.cancel)
|
close(target.cancel)
|
||||||
|
|
||||||
// Close the listener/connection
|
// Close the appropriate listener/connection based on protocol
|
||||||
target.Lock()
|
target.Lock()
|
||||||
if target.listener != nil {
|
switch protocol {
|
||||||
target.listener.Close()
|
case "tcp":
|
||||||
}
|
if target.listener != nil {
|
||||||
if target.udpConn != nil {
|
target.listener.Close()
|
||||||
target.udpConn.Close()
|
}
|
||||||
|
case "udp":
|
||||||
|
if target.udpConn != nil {
|
||||||
|
target.udpConn.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
target.Unlock()
|
target.Unlock()
|
||||||
|
|
||||||
|
@ -57,7 +68,7 @@ func (pm *ProxyManager) RemoveTarget(listen string, port int) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("target not found for %s:%d", listen, port)
|
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) Start() error {
|
func (pm *ProxyManager) Start() error {
|
||||||
|
|
|
@ -2,7 +2,6 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -29,7 +28,7 @@ func getConfigPath() string {
|
||||||
|
|
||||||
func (c *Client) loadConfig() error {
|
func (c *Client) loadConfig() error {
|
||||||
configPath := getConfigPath()
|
configPath := getConfigPath()
|
||||||
data, err := ioutil.ReadFile(configPath)
|
data, err := os.ReadFile(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return nil
|
return nil
|
||||||
|
@ -53,5 +52,5 @@ func (c *Client) saveConfig() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return ioutil.WriteFile(configPath, data, 0644)
|
return os.WriteFile(configPath, data, 0644)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue