From 9b3c82648b1daa452423de3bb67e04e1f1612ba2 Mon Sep 17 00:00:00 2001 From: progressive-kiwi Date: Mon, 31 Mar 2025 00:06:40 +0200 Subject: [PATCH] feat/mtls-support --- README.md | 37 +++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 ++ main.go | 34 ++++++++++++++++------- websocket/client.go | 67 +++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 127 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 8f0d1c3..127a77c 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod - `dns`: DNS server to use to resolve the endpoint - `log-level` (optional): The log level to use. Default: INFO - `updown` (optional): A script to be called when targets are added or removed. - -Example: +- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls) + +- Example: ```bash ./newt \ @@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0. You can look at updown.py as a reference script to get started! +### mTLS +Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate. +* Only PKCS12 (.p12 or .pfx) file format is accepted +* The PKCS12 file must contain: + * Private key + * Public certificate + * CA certificate +* Encrypted PKCS12 files are currently not supported + +Examples: + +```bash +./newt \ +--id 31frd0uzbjvp721 \ +--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \ +--endpoint https://example.com \ +--tls-client-cert /client.p12 +``` + +```yaml +services: + newt: + image: fosrl/newt + container_name: newt + restart: unless-stopped + environment: + - PANGOLIN_ENDPOINT=https://example.com + - NEWT_ID=2ix2t8xk22ubpfy + - NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 + - TLS_CLIENT_CERT=/client.p12 +``` + ## Build ### Container diff --git a/go.mod b/go.mod index 33c593f..af10435 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 + software.sslmate.com/src/go-pkcs12 v0.5.0 ) require ( diff --git a/go.sum b/go.sum index 2328634..35a4b31 100644 --- a/go.sum +++ b/go.sum @@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= +software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M= +software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/main.go b/main.go index cf3f062..bec5554 100644 --- a/main.go +++ b/main.go @@ -246,16 +246,17 @@ func resolveDomain(domain string) (string, error) { } var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - updownScript string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + updownScript string + tlsPrivateKey string ) func main() { @@ -267,6 +268,7 @@ func main() { dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") updownScript = os.Getenv("UPDOWN_SCRIPT") + tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -289,6 +291,9 @@ func main() { if updownScript == "" { flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed") } + if tlsPrivateKey == "" { + flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -314,12 +319,21 @@ func main() { if err != nil { logger.Fatal("Failed to generate private key: %v", err) } + var opt websocket.ClientOption + if tlsPrivateKey != "" { + tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey) + if err != nil { + logger.Fatal("Failed to load client certificate: %v", err) + } + opt = websocket.WithTLSConfig(tlsConfig) + } // Create a new client client, err := websocket.NewClient( id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, + opt, ) if err != nil { logger.Fatal("Failed to create client: %v", err) diff --git a/websocket/client.go b/websocket/client.go index 022a489..0f491d3 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -2,16 +2,19 @@ package websocket import ( "bytes" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "net/http" "net/url" + "os" + "software.sslmate.com/src/go-pkcs12" "strings" "sync" "time" "github.com/fosrl/newt/logger" - "github.com/gorilla/websocket" ) @@ -22,6 +25,7 @@ type Client struct { handlers map[string]MessageHandler done chan struct{} handlersMux sync.RWMutex + tlsConfig *tls.Config reconnectInterval time.Duration isConnected bool @@ -41,6 +45,12 @@ func WithBaseURL(url string) ClientOption { } } +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) { + c.tlsConfig = tlsConfig + } +} + func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } @@ -177,6 +187,12 @@ func (c *Client) getToken() (string, error) { // Make the request client := &http.Client{} + if c.tlsConfig != nil { + logger.Info("Adding tls to req") + client.Transport = &http.Transport{ + TLSClientConfig: c.tlsConfig, + } + } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("failed to check token validity: %w", err) @@ -220,6 +236,11 @@ func (c *Client) getToken() (string, error) { // Make the request client := &http.Client{} + if c.tlsConfig != nil { + client.Transport = &http.Transport{ + TLSClientConfig: c.tlsConfig, + } + } resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("failed to request new token: %w", err) @@ -295,7 +316,11 @@ func (c *Client) establishConnection() error { u.RawQuery = q.Encode() // Connect to WebSocket - conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + dialer := websocket.DefaultDialer + if c.tlsConfig != nil { + dialer.TLSClientConfig = c.tlsConfig + } + conn, _, err := dialer.Dial(u.String(), nil) if err != nil { return fmt.Errorf("failed to connect to WebSocket: %w", err) } @@ -353,3 +378,41 @@ func (c *Client) setConnected(status bool) { defer c.reconnectMux.Unlock() c.isConnected = status } + +// LoadClientCertificate Helper method to load client certificates +func LoadClientCertificate(p12Path string) (*tls.Config, error) { + // Read the PKCS12 file + p12Data, err := os.ReadFile(p12Path) + if err != nil { + return nil, fmt.Errorf("failed to read PKCS12 file: %w", err) + } + + // Parse PKCS12 with empty password for non-encrypted files + privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "") + if err != nil { + return nil, fmt.Errorf("failed to decode PKCS12: %w", err) + } + + // Create certificate + cert := tls.Certificate{ + Certificate: [][]byte{certificate.Raw}, + PrivateKey: privateKey, + } + + // Optional: Add CA certificates if present + rootCAs, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("failed to load system cert pool: %w", err) + } + if len(caCerts) > 0 { + for _, caCert := range caCerts { + rootCAs.AddCert(caCert) + } + } + + // Create TLS configuration + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: rootCAs, + }, nil +}