feat/mtls-support

This commit is contained in:
progressive-kiwi 2025-03-31 00:06:40 +02:00
parent 2ff8df9a8d
commit 9b3c82648b
5 changed files with 127 additions and 14 deletions

View file

@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod
- `dns`: DNS server to use to resolve the endpoint
- `log-level` (optional): The log level to use. Default: INFO
- `updown` (optional): A script to be called when targets are added or removed.
Example:
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
- Example:
```bash
./newt \
@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0.
You can look at updown.py as a reference script to get started!
### mTLS
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
* Only PKCS12 (.p12 or .pfx) file format is accepted
* The PKCS12 file must contain:
* Private key
* Public certificate
* CA certificate
* Encrypted PKCS12 files are currently not supported
Examples:
```bash
./newt \
--id 31frd0uzbjvp721 \
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
--endpoint https://example.com \
--tls-client-cert /client.p12
```
```yaml
services:
newt:
image: fosrl/newt
container_name: newt
restart: unless-stopped
environment:
- PANGOLIN_ENDPOINT=https://example.com
- NEWT_ID=2ix2t8xk22ubpfy
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
- TLS_CLIENT_CERT=/client.p12
```
## Build
### Container

1
go.mod
View file

@ -10,6 +10,7 @@ require (
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
software.sslmate.com/src/go-pkcs12 v0.5.0
)
require (

2
go.sum
View file

@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

34
main.go
View file

@ -246,16 +246,17 @@ func resolveDomain(domain string) (string, error) {
}
var (
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
updownScript string
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
updownScript string
tlsPrivateKey string
)
func main() {
@ -267,6 +268,7 @@ func main() {
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
updownScript = os.Getenv("UPDOWN_SCRIPT")
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@ -289,6 +291,9 @@ func main() {
if updownScript == "" {
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
}
if tlsPrivateKey == "" {
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@ -314,12 +319,21 @@ func main() {
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
var opt websocket.ClientOption
if tlsPrivateKey != "" {
tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey)
if err != nil {
logger.Fatal("Failed to load client certificate: %v", err)
}
opt = websocket.WithTLSConfig(tlsConfig)
}
// Create a new client
client, err := websocket.NewClient(
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
opt,
)
if err != nil {
logger.Fatal("Failed to create client: %v", err)

View file

@ -2,16 +2,19 @@ package websocket
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"software.sslmate.com/src/go-pkcs12"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket"
)
@ -22,6 +25,7 @@ type Client struct {
handlers map[string]MessageHandler
done chan struct{}
handlersMux sync.RWMutex
tlsConfig *tls.Config
reconnectInterval time.Duration
isConnected bool
@ -41,6 +45,12 @@ func WithBaseURL(url string) ClientOption {
}
}
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
return func(c *Client) {
c.tlsConfig = tlsConfig
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
@ -177,6 +187,12 @@ func (c *Client) getToken() (string, error) {
// Make the request
client := &http.Client{}
if c.tlsConfig != nil {
logger.Info("Adding tls to req")
client.Transport = &http.Transport{
TLSClientConfig: c.tlsConfig,
}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to check token validity: %w", err)
@ -220,6 +236,11 @@ func (c *Client) getToken() (string, error) {
// Make the request
client := &http.Client{}
if c.tlsConfig != nil {
client.Transport = &http.Transport{
TLSClientConfig: c.tlsConfig,
}
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
@ -295,7 +316,11 @@ func (c *Client) establishConnection() error {
u.RawQuery = q.Encode()
// Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
dialer := websocket.DefaultDialer
if c.tlsConfig != nil {
dialer.TLSClientConfig = c.tlsConfig
}
conn, _, err := dialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
@ -353,3 +378,41 @@ func (c *Client) setConnected(status bool) {
defer c.reconnectMux.Unlock()
c.isConnected = status
}
// LoadClientCertificate Helper method to load client certificates
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
}
// Parse PKCS12 with empty password for non-encrypted files
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
if err != nil {
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
}
// Create certificate
cert := tls.Certificate{
Certificate: [][]byte{certificate.Raw},
PrivateKey: privateKey,
}
// Optional: Add CA certificates if present
rootCAs, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
}
if len(caCerts) > 0 {
for _, caCert := range caCerts {
rootCAs.AddCert(caCert)
}
}
// Create TLS configuration
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: rootCAs,
}, nil
}