Merge pull request #20 from fosrl/dev

Cleanup & Updown Script
This commit is contained in:
Owen Schwartz 2025-03-09 23:28:44 -04:00 committed by GitHub
commit 623be5ea0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 187 additions and 22 deletions

View file

@ -36,7 +36,8 @@ When Newt receives WireGuard control messages, it will use the information encod
- `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands. - `secret`: A unique secret (not shared and kept private) used to authenticate the client ID with the websocket in order to receive commands.
- `dns`: DNS server to use to resolve the endpoint - `dns`: DNS server to use to resolve the endpoint
- `log-level` (optional): The log level to use. Default: INFO - `log-level` (optional): The log level to use. Default: INFO
- `updown` (optional): A script to be called when targets are added or removed.
Example: Example:
```bash ```bash
@ -92,6 +93,20 @@ WantedBy=multi-user.target
Make sure to `mv ./newt /usr/local/bin/newt`! Make sure to `mv ./newt /usr/local/bin/newt`!
### Updown
You can pass in a updown script for Newt to call when it is adding or removing a target:
`--updown "python3 test.py"`
It will get called with args when a target is added:
`python3 test.py add tcp localhost:8556`
`python3 test.py remove tcp localhost:8556`
Returning a string from the script in the format of a target (`ip:dst` so `10.0.0.1:8080`) it will override the target and use this value instead to proxy.
You can look at updown.py as a reference script to get started!
## Build ## Build
### Container ### Container

14
go.mod
View file

@ -4,17 +4,19 @@ go 1.23.1
toolchain go1.23.2 toolchain go1.23.2
require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 require (
github.com/gorilla/websocket v1.5.3
golang.org/x/net v0.30.0
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
)
require ( require (
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/gorilla/websocket v1.5.3 // indirect github.com/google/go-cmp v0.6.0 // indirect
golang.org/x/crypto v0.28.0 // indirect golang.org/x/crypto v0.28.0 // indirect
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sys v0.26.0 // indirect golang.org/x/sys v0.26.0 // indirect
golang.org/x/time v0.7.0 // indirect golang.org/x/time v0.7.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
) )

4
go.sum
View file

@ -1,11 +1,11 @@
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=

97
main.go
View file

@ -11,6 +11,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"os/exec"
"os/signal" "os/signal"
"strconv" "strconv"
"strings" "strings"
@ -244,19 +245,20 @@ func resolveDomain(domain string) (string, error) {
return ipAddr, nil return ipAddr, nil
} }
func main() { var (
var ( endpoint string
endpoint string id string
id string secret string
secret string mtu string
mtu string mtuInt int
mtuInt int dns string
dns string privateKey wgtypes.Key
privateKey wgtypes.Key err error
err error logLevel string
logLevel string updownScript string
) )
func main() {
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
endpoint = os.Getenv("PANGOLIN_ENDPOINT") endpoint = os.Getenv("PANGOLIN_ENDPOINT")
id = os.Getenv("NEWT_ID") id = os.Getenv("NEWT_ID")
@ -264,6 +266,7 @@ func main() {
mtu = os.Getenv("MTU") mtu = os.Getenv("MTU")
dns = os.Getenv("DNS") dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL") logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@ -283,6 +286,9 @@ func main() {
if logLevel == "" { if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
} }
if updownScript == "" {
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@ -586,6 +592,18 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
if action == "add" { if action == "add" {
target := parts[1] + ":" + parts[2] target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
newTarget, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
} else if newTarget != "" {
processedTarget = newTarget
}
}
// Only remove the specific target if it exists // Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port) err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil { if err != nil {
@ -596,10 +614,21 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
} }
// Add the new target // Add the new target
pm.AddTarget(proto, tunnelIP, port, target) pm.AddTarget(proto, tunnelIP, port, processedTarget)
} else if action == "remove" { } else if action == "remove" {
logger.Info("Removing target with port %d", port) logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
}
}
err := pm.RemoveTarget(proto, tunnelIP, port) err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil { if err != nil {
logger.Error("Failed to remove target: %v", err) logger.Error("Failed to remove target: %v", err)
@ -610,3 +639,45 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
return nil return nil
} }
func executeUpdownScript(action, proto, target string) (string, error) {
if updownScript == "" {
return target, nil
}
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
parts := strings.Fields(updownScript)
if len(parts) == 0 {
return target, fmt.Errorf("invalid updown script command")
}
var cmd *exec.Cmd
if len(parts) == 1 {
// If it's a single executable
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
cmd = exec.Command(parts[0], action, proto, target)
} else {
// If it includes interpreter and script
args := append(parts[1:], action, proto, target)
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
cmd = exec.Command(parts[0], args...)
}
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
exitErr.ExitCode(), string(exitErr.Stderr))
}
return "", fmt.Errorf("updown script execution failed: %v", err)
}
// If the script returns a new target, use it
newTarget := strings.TrimSpace(string(output))
if newTarget != "" {
logger.Info("Updown script returned new target: %s", newTarget)
return newTarget, nil
}
return target, nil
}

77
updown.py Normal file
View file

@ -0,0 +1,77 @@
"""
Sample updown script for Newt proxy
Usage: update.py <action> <protocol> <target>
Parameters:
- action: 'add' or 'remove'
- protocol: 'tcp' or 'udp'
- target: the target address in format 'host:port'
If the action is 'add', the script can return a modified target that
will be used instead of the original.
"""
import sys
import logging
import json
from datetime import datetime
# Configure logging
LOG_FILE = "/tmp/newt-updown.log"
logging.basicConfig(
filename=LOG_FILE,
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def log_event(action, protocol, target):
"""Log each event to a file for auditing purposes"""
timestamp = datetime.now().isoformat()
event = {
"timestamp": timestamp,
"action": action,
"protocol": protocol,
"target": target
}
logging.info(json.dumps(event))
def handle_add(protocol, target):
"""Handle 'add' action"""
logging.info(f"Adding {protocol} target: {target}")
def handle_remove(protocol, target):
"""Handle 'remove' action"""
logging.info(f"Removing {protocol} target: {target}")
# For remove action, no return value is expected or used
def main():
# Check arguments
if len(sys.argv) != 4:
logging.error(f"Invalid arguments: {sys.argv}")
sys.exit(1)
action = sys.argv[1]
protocol = sys.argv[2]
target = sys.argv[3]
# Log the event
log_event(action, protocol, target)
# Handle the action
if action == "add":
new_target = handle_add(protocol, target)
# Print the new target to stdout (if empty, no change will be made)
if new_target and new_target != target:
print(new_target)
elif action == "remove":
handle_remove(protocol, target)
else:
logging.error(f"Unknown action: {action}")
sys.exit(1)
if __name__ == "__main__":
try:
main()
except Exception as e:
logging.error(f"Unhandled exception: {e}")
sys.exit(1)