mirror of
https://github.com/fosrl/newt.git
synced 2025-05-15 06:30:38 +01:00
feat/mtls-support
This commit is contained in:
parent
2ff8df9a8d
commit
9b3c82648b
5 changed files with 127 additions and 14 deletions
35
README.md
35
README.md
|
@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||||
- `dns`: DNS server to use to resolve the endpoint
|
- `dns`: DNS server to use to resolve the endpoint
|
||||||
- `log-level` (optional): The log level to use. Default: INFO
|
- `log-level` (optional): The log level to use. Default: INFO
|
||||||
- `updown` (optional): A script to be called when targets are added or removed.
|
- `updown` (optional): A script to be called when targets are added or removed.
|
||||||
|
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
|
||||||
|
|
||||||
Example:
|
- Example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./newt \
|
./newt \
|
||||||
|
@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0.
|
||||||
|
|
||||||
You can look at updown.py as a reference script to get started!
|
You can look at updown.py as a reference script to get started!
|
||||||
|
|
||||||
|
### mTLS
|
||||||
|
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
|
||||||
|
* Only PKCS12 (.p12 or .pfx) file format is accepted
|
||||||
|
* The PKCS12 file must contain:
|
||||||
|
* Private key
|
||||||
|
* Public certificate
|
||||||
|
* CA certificate
|
||||||
|
* Encrypted PKCS12 files are currently not supported
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./newt \
|
||||||
|
--id 31frd0uzbjvp721 \
|
||||||
|
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||||
|
--endpoint https://example.com \
|
||||||
|
--tls-client-cert /client.p12
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
services:
|
||||||
|
newt:
|
||||||
|
image: fosrl/newt
|
||||||
|
container_name: newt
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
|
- NEWT_ID=2ix2t8xk22ubpfy
|
||||||
|
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||||
|
- TLS_CLIENT_CERT=/client.p12
|
||||||
|
```
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Container
|
### Container
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -10,6 +10,7 @@ require (
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.5.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||||
|
|
14
main.go
14
main.go
|
@ -256,6 +256,7 @@ var (
|
||||||
err error
|
err error
|
||||||
logLevel string
|
logLevel string
|
||||||
updownScript string
|
updownScript string
|
||||||
|
tlsPrivateKey string
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -267,6 +268,7 @@ func main() {
|
||||||
dns = os.Getenv("DNS")
|
dns = os.Getenv("DNS")
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
logLevel = os.Getenv("LOG_LEVEL")
|
||||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||||
|
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||||
|
@ -289,6 +291,9 @@ func main() {
|
||||||
if updownScript == "" {
|
if updownScript == "" {
|
||||||
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
||||||
}
|
}
|
||||||
|
if tlsPrivateKey == "" {
|
||||||
|
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
|
||||||
|
}
|
||||||
|
|
||||||
// do a --version check
|
// do a --version check
|
||||||
version := flag.Bool("version", false, "Print the version")
|
version := flag.Bool("version", false, "Print the version")
|
||||||
|
@ -314,12 +319,21 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
logger.Fatal("Failed to generate private key: %v", err)
|
||||||
}
|
}
|
||||||
|
var opt websocket.ClientOption
|
||||||
|
if tlsPrivateKey != "" {
|
||||||
|
tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to load client certificate: %v", err)
|
||||||
|
}
|
||||||
|
opt = websocket.WithTLSConfig(tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new client
|
// Create a new client
|
||||||
client, err := websocket.NewClient(
|
client, err := websocket.NewClient(
|
||||||
id, // CLI arg takes precedence
|
id, // CLI arg takes precedence
|
||||||
secret, // CLI arg takes precedence
|
secret, // CLI arg takes precedence
|
||||||
endpoint,
|
endpoint,
|
||||||
|
opt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
|
|
|
@ -2,16 +2,19 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"software.sslmate.com/src/go-pkcs12"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,6 +25,7 @@ type Client struct {
|
||||||
handlers map[string]MessageHandler
|
handlers map[string]MessageHandler
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
handlersMux sync.RWMutex
|
handlersMux sync.RWMutex
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
reconnectInterval time.Duration
|
reconnectInterval time.Duration
|
||||||
isConnected bool
|
isConnected bool
|
||||||
|
@ -41,6 +45,12 @@ func WithBaseURL(url string) ClientOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.tlsConfig = tlsConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) OnConnect(callback func() error) {
|
func (c *Client) OnConnect(callback func() error) {
|
||||||
c.onConnect = callback
|
c.onConnect = callback
|
||||||
}
|
}
|
||||||
|
@ -177,6 +187,12 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
if c.tlsConfig != nil {
|
||||||
|
logger.Info("Adding tls to req")
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: c.tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to check token validity: %w", err)
|
return "", fmt.Errorf("failed to check token validity: %w", err)
|
||||||
|
@ -220,6 +236,11 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
if c.tlsConfig != nil {
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: c.tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||||
|
@ -295,7 +316,11 @@ func (c *Client) establishConnection() error {
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
// Connect to WebSocket
|
// Connect to WebSocket
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
dialer := websocket.DefaultDialer
|
||||||
|
if c.tlsConfig != nil {
|
||||||
|
dialer.TLSClientConfig = c.tlsConfig
|
||||||
|
}
|
||||||
|
conn, _, err := dialer.Dial(u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -353,3 +378,41 @@ func (c *Client) setConnected(status bool) {
|
||||||
defer c.reconnectMux.Unlock()
|
defer c.reconnectMux.Unlock()
|
||||||
c.isConnected = status
|
c.isConnected = status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadClientCertificate Helper method to load client certificates
|
||||||
|
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
|
// Read the PKCS12 file
|
||||||
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PKCS12 with empty password for non-encrypted files
|
||||||
|
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
cert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{certificate.Raw},
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Add CA certificates if present
|
||||||
|
rootCAs, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
|
||||||
|
}
|
||||||
|
if len(caCerts) > 0 {
|
||||||
|
for _, caCert := range caCerts {
|
||||||
|
rootCAs.AddCert(caCert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS configuration
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue