Merge pull request #26 from progressive-kiwi/feat-mtls-support

Feat: mTLS support
This commit is contained in:
Owen Schwartz 2025-04-02 21:23:17 -04:00 committed by GitHub
commit e7c8dbc1c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 290 additions and 31 deletions

3
.gitignore vendored
View file

@ -1,3 +1,6 @@
newt newt
.DS_Store .DS_Store
bin/ bin/
.idea
*.iml
certs/

View file

@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod
- `dns`: DNS server to use to resolve the endpoint - `dns`: DNS server to use to resolve the endpoint
- `log-level` (optional): The log level to use. Default: INFO - `log-level` (optional): The log level to use. Default: INFO
- `updown` (optional): A script to be called when targets are added or removed. - `updown` (optional): A script to be called when targets are added or removed.
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
Example: - Example:
```bash ```bash
./newt \ ./newt \
@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0.
You can look at updown.py as a reference script to get started! You can look at updown.py as a reference script to get started!
### mTLS
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
* Only PKCS12 (.p12 or .pfx) file format is accepted
* The PKCS12 file must contain:
* Private key
* Public certificate
* CA certificate
* Encrypted PKCS12 files are currently not supported
Examples:
```bash
./newt \
--id 31frd0uzbjvp721 \
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
--endpoint https://example.com \
--tls-client-cert ./client.p12
```
```yaml
services:
newt:
image: fosrl/newt
container_name: newt
restart: unless-stopped
environment:
- PANGOLIN_ENDPOINT=https://example.com
- NEWT_ID=2ix2t8xk22ubpfy
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
- TLS_CLIENT_CERT=./client.p12
```
## Build ## Build
### Container ### Container

1
go.mod
View file

@ -10,6 +10,7 @@ require (
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
software.sslmate.com/src/go-pkcs12 v0.5.0
) )
require ( require (

2
go.sum
View file

@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

16
main.go
View file

@ -350,6 +350,7 @@ var (
err error err error
logLevel string logLevel string
updownScript string updownScript string
tlsPrivateKey string
) )
func main() { func main() {
@ -361,6 +362,7 @@ func main() {
dns = os.Getenv("DNS") dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL") logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT") updownScript = os.Getenv("UPDOWN_SCRIPT")
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@ -383,6 +385,9 @@ func main() {
if updownScript == "" { if updownScript == "" {
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed") flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
} }
if tlsPrivateKey == "" {
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@ -408,12 +413,16 @@ func main() {
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
} }
var opt websocket.ClientOption
if tlsPrivateKey != "" {
opt = websocket.WithTLSConfig(tlsPrivateKey)
}
// Create a new client // Create a new client
client, err := websocket.NewClient( client, err := websocket.NewClient(
id, // CLI arg takes precedence id, // CLI arg takes precedence
secret, // CLI arg takes precedence secret, // CLI arg takes precedence
endpoint, endpoint,
opt,
) )
if err != nil { if err != nil {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
@ -642,11 +651,14 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
// Wait for interrupt signal // Wait for interrupt signal
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh sigReceived := <-sigCh
// Cleanup // Cleanup
logger.Info("Received %s signal, stopping", sigReceived.String())
if dev != nil {
dev.Close() dev.Close()
} }
}
func parseTargetData(data interface{}) (TargetData, error) { func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData var targetData TargetData

125
self-signed-certs-for-mtls.sh Executable file
View file

@ -0,0 +1,125 @@
#!/usr/bin/env bash
set -eu
echo -n "Enter username for certs (eg alice): "
read CERT_USERNAME
echo
echo -n "Enter domain of user (eg example.com): "
read DOMAIN
echo
# Prompt for password at the start
echo -n "Enter password for certificate: "
read -s PASSWORD
echo
echo -n "Confirm password: "
read -s PASSWORD2
echo
if [ "$PASSWORD" != "$PASSWORD2" ]; then
echo "Passwords don't match!"
exit 1
fi
CA_DIR="./certs/ca"
CLIENT_DIR="./certs/clients"
FILE_PREFIX=$(echo "$CERT_USERNAME-at-$DOMAIN" | sed 's/\./-/')
mkdir -p "$CA_DIR"
mkdir -p "$CLIENT_DIR"
if [ ! -f "$CA_DIR/ca.crt" ]; then
# Generate CA private key
openssl genrsa -out "$CA_DIR/ca.key" 4096
echo "CA key ✅"
# Generate CA root certificate
openssl req -x509 -new -nodes \
-key "$CA_DIR/ca.key" \
-sha256 \
-days 3650 \
-out "$CA_DIR/ca.crt" \
-subj "/C=US/ST=State/L=City/O=Organization/OU=Unit/CN=ca.$DOMAIN"
echo "CA cert ✅"
fi
# Generate client private key
openssl genrsa -aes256 -passout pass:"$PASSWORD" -out "$CLIENT_DIR/$FILE_PREFIX.key" 2048
echo "Client key ✅"
# Generate client Certificate Signing Request (CSR)
openssl req -new \
-key "$CLIENT_DIR/$FILE_PREFIX.key" \
-out "$CLIENT_DIR/$FILE_PREFIX.csr" \
-passin pass:"$PASSWORD" \
-subj "/C=US/ST=State/L=City/O=Organization/OU=Unit/CN=$CERT_USERNAME@$DOMAIN"
echo "Client cert ✅"
echo -n "Signing client cert..."
# Create client certificate configuration file
cat > "$CLIENT_DIR/$FILE_PREFIX.ext" << EOF
authorityKeyIdentifier=keyid,issuer
basicConstraints=CA:FALSE
keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment
subjectAltName = @alt_names
[alt_names]
DNS.1 = $DOMAIN
EOF
# Generate client certificate signed by CA
openssl x509 -req \
-in "$CLIENT_DIR/$FILE_PREFIX.csr" \
-CA "$CA_DIR/ca.crt" \
-CAkey "$CA_DIR/ca.key" \
-CAcreateserial \
-out "$CLIENT_DIR/$FILE_PREFIX.crt" \
-days 365 \
-sha256 \
-extfile "$CLIENT_DIR/$FILE_PREFIX.ext"
# Verify the client certificate
openssl verify -CAfile "$CA_DIR/ca.crt" "$CLIENT_DIR/$FILE_PREFIX.crt"
echo "Signed ✅"
# Create encrypted PEM bundle
openssl rsa -in "$CLIENT_DIR/$FILE_PREFIX.key" -passin pass:"$PASSWORD" \
| cat "$CLIENT_DIR/$FILE_PREFIX.crt" - > "$CLIENT_DIR/$FILE_PREFIX-bundle.enc.pem"
# Convert to PKCS12
echo "Converting to PKCS12 format..."
openssl pkcs12 -export \
-out "$CLIENT_DIR/$FILE_PREFIX.enc.p12" \
-inkey "$CLIENT_DIR/$FILE_PREFIX.key" \
-in "$CLIENT_DIR/$FILE_PREFIX.crt" \
-certfile "$CA_DIR/ca.crt" \
-name "$CERT_USERNAME@$DOMAIN" \
-passin pass:"$PASSWORD" \
-passout pass:"$PASSWORD"
echo "Converted to encrypted p12 for macOS ✅"
# Convert to PKCS12 format without encryption
echo "Converting to non-encrypted PKCS12 format..."
openssl pkcs12 -export \
-out "$CLIENT_DIR/$FILE_PREFIX.p12" \
-inkey "$CLIENT_DIR/$FILE_PREFIX.key" \
-in "$CLIENT_DIR/$FILE_PREFIX.crt" \
-certfile "$CA_DIR/ca.crt" \
-name "$CERT_USERNAME@$DOMAIN" \
-passin pass:"$PASSWORD" \
-passout pass:""
echo "Converted to non-encrypted p12 ✅"
# Clean up intermediate files
rm "$CLIENT_DIR/$FILE_PREFIX.csr" "$CLIENT_DIR/$FILE_PREFIX.ext" "$CA_DIR/ca.srl"
echo
echo
echo "CA certificate: $CA_DIR/ca.crt"
echo "CA private key: $CA_DIR/ca.key"
echo "Client certificate: $CLIENT_DIR/$FILE_PREFIX.crt"
echo "Client private key: $CLIENT_DIR/$FILE_PREFIX.key"
echo "Client cert bundle: $CLIENT_DIR/$FILE_PREFIX.p12"
echo "Client cert bundle (encrypted): $CLIENT_DIR/$FILE_PREFIX.enc.p12"

View file

@ -2,16 +2,19 @@ package websocket
import ( import (
"bytes" "bytes"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"os"
"software.sslmate.com/src/go-pkcs12"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -22,7 +25,6 @@ type Client struct {
handlers map[string]MessageHandler handlers map[string]MessageHandler
done chan struct{} done chan struct{}
handlersMux sync.RWMutex handlersMux sync.RWMutex
reconnectInterval time.Duration reconnectInterval time.Duration
isConnected bool isConnected bool
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
@ -41,6 +43,12 @@ func WithBaseURL(url string) ClientOption {
} }
} }
func WithTLSConfig(tlsClientCertPath string) ClientOption {
return func(c *Client) {
c.config.TlsClientCert = tlsClientCertPath
}
}
func (c *Client) OnConnect(callback func() error) { func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback c.onConnect = callback
} }
@ -63,9 +71,14 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
} }
// Apply options before loading config // Apply options before loading config
if opts != nil {
for _, opt := range opts { for _, opt := range opts {
if opt == nil {
continue
}
opt(client) opt(client)
} }
}
// Load existing config if available // Load existing config if available
if err := client.loadConfig(); err != nil { if err := client.loadConfig(); err != nil {
@ -149,6 +162,14 @@ func (c *Client) getToken() (string, error) {
// Ensure we have the base URL without trailing slashes // Ensure we have the base URL without trailing slashes
baseEndpoint := strings.TrimRight(baseURL.String(), "/") baseEndpoint := strings.TrimRight(baseURL.String(), "/")
var tlsConfig *tls.Config = nil
if c.config.TlsClientCert != "" {
tlsConfig, err = loadClientCertificate(c.config.TlsClientCert)
if err != nil {
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
}
}
// If we already have a token, try to use it // If we already have a token, try to use it
if c.config.Token != "" { if c.config.Token != "" {
tokenCheckData := map[string]interface{}{ tokenCheckData := map[string]interface{}{
@ -177,6 +198,11 @@ func (c *Client) getToken() (string, error) {
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
if tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to check token validity: %w", err) return "", fmt.Errorf("failed to check token validity: %w", err)
@ -220,6 +246,11 @@ func (c *Client) getToken() (string, error) {
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
if tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err) return "", fmt.Errorf("failed to request new token: %w", err)
@ -295,7 +326,16 @@ func (c *Client) establishConnection() error {
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
// Connect to WebSocket // Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) dialer := websocket.DefaultDialer
if c.config.TlsClientCert != "" {
logger.Info("Adding tls to req")
tlsConfig, err := loadClientCertificate(c.config.TlsClientCert)
if err != nil {
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
}
dialer.TLSClientConfig = tlsConfig
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err) return fmt.Errorf("failed to connect to WebSocket: %w", err)
} }
@ -353,3 +393,42 @@ func (c *Client) setConnected(status bool) {
defer c.reconnectMux.Unlock() defer c.reconnectMux.Unlock()
c.isConnected = status c.isConnected = status
} }
// LoadClientCertificate Helper method to load client certificates
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
}
// Parse PKCS12 with empty password for non-encrypted files
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
if err != nil {
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
}
// Create certificate
cert := tls.Certificate{
Certificate: [][]byte{certificate.Raw},
PrivateKey: privateKey,
}
// Optional: Add CA certificates if present
rootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
}
if len(caCerts) > 0 {
for _, caCert := range caCerts {
rootCAs.AddCert(caCert)
}
}
// Create TLS configuration
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: rootCAs,
}, nil
}

View file

@ -54,6 +54,9 @@ func (c *Client) loadConfig() error {
if c.config.Secret == "" { if c.config.Secret == "" {
c.config.Secret = config.Secret c.config.Secret = config.Secret
} }
if c.config.TlsClientCert == "" {
c.config.TlsClientCert = config.TlsClientCert
}
if c.config.Endpoint == "" { if c.config.Endpoint == "" {
c.config.Endpoint = config.Endpoint c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint c.baseURL = config.Endpoint

View file

@ -5,6 +5,7 @@ type Config struct {
Secret string `json:"secret"` Secret string `json:"secret"`
Token string `json:"token"` Token string `json:"token"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
TlsClientCert string `json:"tlsClientCert"`
} }
type TokenResponse struct { type TokenResponse struct {