mirror of
https://github.com/fosrl/newt.git
synced 2025-05-12 21:20:39 +01:00
Merge pull request #28 from fosrl/dev
MTLS, Connection Monitoring, time zone logger
This commit is contained in:
commit
a1a439c75c
10 changed files with 432 additions and 67 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -1,3 +1,6 @@
|
||||||
newt
|
newt
|
||||||
.DS_Store
|
.DS_Store
|
||||||
bin/
|
bin/
|
||||||
|
.idea
|
||||||
|
*.iml
|
||||||
|
certs/
|
37
README.md
37
README.md
|
@ -37,8 +37,9 @@ When Newt receives WireGuard control messages, it will use the information encod
|
||||||
- `dns`: DNS server to use to resolve the endpoint
|
- `dns`: DNS server to use to resolve the endpoint
|
||||||
- `log-level` (optional): The log level to use. Default: INFO
|
- `log-level` (optional): The log level to use. Default: INFO
|
||||||
- `updown` (optional): A script to be called when targets are added or removed.
|
- `updown` (optional): A script to be called when targets are added or removed.
|
||||||
|
- `tls-client-cert` (optional): Client certificate (p12 or pfx) for mTLS. See [mTLS](#mtls)
|
||||||
Example:
|
|
||||||
|
- Example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./newt \
|
./newt \
|
||||||
|
@ -107,6 +108,38 @@ Returning a string from the script in the format of a target (`ip:dst` so `10.0.
|
||||||
|
|
||||||
You can look at updown.py as a reference script to get started!
|
You can look at updown.py as a reference script to get started!
|
||||||
|
|
||||||
|
### mTLS
|
||||||
|
Newt supports mutual TLS (mTLS) authentication, if the server has been configured to request a client certificate.
|
||||||
|
* Only PKCS12 (.p12 or .pfx) file format is accepted
|
||||||
|
* The PKCS12 file must contain:
|
||||||
|
* Private key
|
||||||
|
* Public certificate
|
||||||
|
* CA certificate
|
||||||
|
* Encrypted PKCS12 files are currently not supported
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./newt \
|
||||||
|
--id 31frd0uzbjvp721 \
|
||||||
|
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||||
|
--endpoint https://example.com \
|
||||||
|
--tls-client-cert ./client.p12
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
services:
|
||||||
|
newt:
|
||||||
|
image: fosrl/newt
|
||||||
|
container_name: newt
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
|
- NEWT_ID=2ix2t8xk22ubpfy
|
||||||
|
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||||
|
- TLS_CLIENT_CERT=./client.p12
|
||||||
|
```
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Container
|
### Container
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -10,6 +10,7 @@ require (
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.5.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -20,3 +20,5 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.5.0 h1:EC6R394xgENTpZ4RltKydeDUjtlM5drOYIG9c6TVj2M=
|
||||||
|
software.sslmate.com/src/go-pkcs12 v0.5.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||||
|
|
|
@ -53,7 +53,23 @@ func (l *Logger) log(level LogLevel, format string, args ...interface{}) {
|
||||||
if level < l.level {
|
if level < l.level {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
timestamp := time.Now().Format("2006/01/02 15:04:05")
|
|
||||||
|
// Get timezone from environment variable or use local timezone
|
||||||
|
timezone := os.Getenv("LOGGER_TIMEZONE")
|
||||||
|
var location *time.Location
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if timezone != "" {
|
||||||
|
location, err = time.LoadLocation(timezone)
|
||||||
|
if err != nil {
|
||||||
|
// If invalid timezone, fall back to local
|
||||||
|
location = time.Local
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
location = time.Local
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp := time.Now().In(location).Format("2006/01/02 15:04:05")
|
||||||
message := fmt.Sprintf(format, args...)
|
message := fmt.Sprintf(format, args...)
|
||||||
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
|
l.logger.Printf("%s: %s %s", level.String(), timestamp, message)
|
||||||
}
|
}
|
||||||
|
|
198
main.go
198
main.go
|
@ -115,7 +115,12 @@ func ping(tnet *netstack.Net, dst string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
initialInterval := 10 * time.Second
|
||||||
|
maxInterval := 60 * time.Second
|
||||||
|
currentInterval := initialInterval
|
||||||
|
consecutiveFailures := 0
|
||||||
|
|
||||||
|
ticker := time.NewTicker(currentInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -124,8 +129,34 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
err := ping(tnet, serverIP)
|
err := ping(tnet, serverIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Periodic ping failed: %v", err)
|
consecutiveFailures++
|
||||||
|
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
|
||||||
|
consecutiveFailures, err)
|
||||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||||
|
|
||||||
|
// Increase interval if we have consistent failures, with a maximum cap
|
||||||
|
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
||||||
|
// Increase by 50% each time, up to the maximum
|
||||||
|
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
||||||
|
if currentInterval > maxInterval {
|
||||||
|
currentInterval = maxInterval
|
||||||
|
}
|
||||||
|
ticker.Reset(currentInterval)
|
||||||
|
logger.Info("Increased ping check interval to %v due to consecutive failures",
|
||||||
|
currentInterval)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// On success, if we've backed off, gradually return to normal interval
|
||||||
|
if currentInterval > initialInterval {
|
||||||
|
currentInterval = time.Duration(float64(currentInterval) * 0.8)
|
||||||
|
if currentInterval < initialInterval {
|
||||||
|
currentInterval = initialInterval
|
||||||
|
}
|
||||||
|
ticker.Reset(currentInterval)
|
||||||
|
logger.Info("Decreased ping check interval to %v after successful ping",
|
||||||
|
currentInterval)
|
||||||
|
}
|
||||||
|
consecutiveFailures = 0
|
||||||
}
|
}
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
logger.Info("Stopping ping check")
|
logger.Info("Stopping ping check")
|
||||||
|
@ -135,34 +166,97 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to track connection status and trigger reconnection as needed
|
||||||
|
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
|
||||||
|
const checkInterval = 30 * time.Second
|
||||||
|
connectionLost := false
|
||||||
|
ticker := time.NewTicker(checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
// Try a ping to see if connection is alive
|
||||||
|
err := ping(tnet, serverIP)
|
||||||
|
|
||||||
|
if err != nil && !connectionLost {
|
||||||
|
// We just lost connection
|
||||||
|
connectionLost = true
|
||||||
|
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||||
|
|
||||||
|
// Notify the user they might need to check their network
|
||||||
|
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
|
||||||
|
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
|
||||||
|
} else if err == nil && connectionLost {
|
||||||
|
// Connection has been restored
|
||||||
|
connectionLost = false
|
||||||
|
logger.Info("Connection to server restored!")
|
||||||
|
|
||||||
|
// Tell the server we're back
|
||||||
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": fmt.Sprintf("%s", privateKey.PublicKey()),
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message after reconnection: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Successfully re-registered with server after reconnection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||||
const (
|
const (
|
||||||
maxAttempts = 15
|
initialMaxAttempts = 15
|
||||||
retryDelay = 2 * time.Second
|
initialRetryDelay = 2 * time.Second
|
||||||
|
maxRetryDelay = 60 * time.Second // Cap the maximum delay
|
||||||
)
|
)
|
||||||
|
|
||||||
var lastErr error
|
attempt := 1
|
||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
retryDelay := initialRetryDelay
|
||||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
|
||||||
|
|
||||||
if err := ping(tnet, dst); err != nil {
|
|
||||||
lastErr = err
|
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
|
||||||
|
|
||||||
if attempt < maxAttempts {
|
|
||||||
time.Sleep(retryDelay)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return fmt.Errorf("all ping attempts failed after %d tries, last error: %w",
|
|
||||||
maxAttempts, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// First try with the initial parameters
|
||||||
|
logger.Info("Ping attempt %d", attempt)
|
||||||
|
if err := ping(tnet, dst); err == nil {
|
||||||
// Successful ping
|
// Successful ping
|
||||||
return nil
|
return nil
|
||||||
|
} else {
|
||||||
|
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
// Start a goroutine that will attempt pings indefinitely with increasing delays
|
||||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
go func() {
|
||||||
|
attempt = 2 // Continue from attempt 2
|
||||||
|
|
||||||
|
for {
|
||||||
|
logger.Info("Ping attempt %d", attempt)
|
||||||
|
|
||||||
|
if err := ping(tnet, dst); err != nil {
|
||||||
|
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||||
|
|
||||||
|
// Increase delay after certain thresholds but cap it
|
||||||
|
if attempt%5 == 0 && retryDelay < maxRetryDelay {
|
||||||
|
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
||||||
|
if retryDelay > maxRetryDelay {
|
||||||
|
retryDelay = maxRetryDelay
|
||||||
|
}
|
||||||
|
logger.Info("Increasing ping retry delay to %v", retryDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(retryDelay)
|
||||||
|
attempt++
|
||||||
|
} else {
|
||||||
|
// Successful ping
|
||||||
|
logger.Info("Ping succeeded after %d attempts", attempt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
|
||||||
|
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseLogLevel(level string) logger.LogLevel {
|
func parseLogLevel(level string) logger.LogLevel {
|
||||||
|
@ -246,16 +340,17 @@ func resolveDomain(domain string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
endpoint string
|
endpoint string
|
||||||
id string
|
id string
|
||||||
secret string
|
secret string
|
||||||
mtu string
|
mtu string
|
||||||
mtuInt int
|
mtuInt int
|
||||||
dns string
|
dns string
|
||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
err error
|
err error
|
||||||
logLevel string
|
logLevel string
|
||||||
updownScript string
|
updownScript string
|
||||||
|
tlsPrivateKey string
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -267,6 +362,7 @@ func main() {
|
||||||
dns = os.Getenv("DNS")
|
dns = os.Getenv("DNS")
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
logLevel = os.Getenv("LOG_LEVEL")
|
||||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||||
|
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||||
|
@ -289,6 +385,9 @@ func main() {
|
||||||
if updownScript == "" {
|
if updownScript == "" {
|
||||||
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
||||||
}
|
}
|
||||||
|
if tlsPrivateKey == "" {
|
||||||
|
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
|
||||||
|
}
|
||||||
|
|
||||||
// do a --version check
|
// do a --version check
|
||||||
version := flag.Bool("version", false, "Print the version")
|
version := flag.Bool("version", false, "Print the version")
|
||||||
|
@ -314,12 +413,16 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
logger.Fatal("Failed to generate private key: %v", err)
|
||||||
}
|
}
|
||||||
|
var opt websocket.ClientOption
|
||||||
|
if tlsPrivateKey != "" {
|
||||||
|
opt = websocket.WithTLSConfig(tlsPrivateKey)
|
||||||
|
}
|
||||||
// Create a new client
|
// Create a new client
|
||||||
client, err := websocket.NewClient(
|
client, err := websocket.NewClient(
|
||||||
id, // CLI arg takes precedence
|
id, // CLI arg takes precedence
|
||||||
secret, // CLI arg takes precedence
|
secret, // CLI arg takes precedence
|
||||||
endpoint,
|
endpoint,
|
||||||
|
opt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
|
@ -353,13 +456,8 @@ func main() {
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
logger.Info("Already connected! But I will send a ping anyway...")
|
logger.Info("Already connected! But I will send a ping anyway...")
|
||||||
// ping(tnet, wgData.ServerIP)
|
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
_ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue
|
||||||
if err != nil {
|
|
||||||
// Handle complete failure after all retries
|
|
||||||
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
|
||||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,17 +512,18 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||||
// Ping to bring the tunnel up on the server side quickly
|
|
||||||
// ping(tnet, wgData.ServerIP)
|
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
|
||||||
if err != nil {
|
|
||||||
// Handle complete failure after all retries
|
|
||||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||||
|
_ = pingWithRetry(tnet, wgData.ServerIP)
|
||||||
|
|
||||||
|
// Always mark as connected and start the proxy manager regardless of initial ping result
|
||||||
|
// as the pings will continue in the background
|
||||||
if !connected {
|
if !connected {
|
||||||
logger.Info("Starting ping check")
|
logger.Info("Starting ping check")
|
||||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
||||||
|
|
||||||
|
// Start connection monitoring in a separate goroutine
|
||||||
|
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create proxy manager
|
// Create proxy manager
|
||||||
|
@ -552,10 +651,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||||
// Wait for interrupt signal
|
// Wait for interrupt signal
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
<-sigCh
|
sigReceived := <-sigCh
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
dev.Close()
|
logger.Info("Received %s signal, stopping", sigReceived.String())
|
||||||
|
if dev != nil {
|
||||||
|
dev.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
func parseTargetData(data interface{}) (TargetData, error) {
|
||||||
|
|
125
self-signed-certs-for-mtls.sh
Executable file
125
self-signed-certs-for-mtls.sh
Executable file
|
@ -0,0 +1,125 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
echo -n "Enter username for certs (eg alice): "
|
||||||
|
read CERT_USERNAME
|
||||||
|
echo
|
||||||
|
|
||||||
|
echo -n "Enter domain of user (eg example.com): "
|
||||||
|
read DOMAIN
|
||||||
|
echo
|
||||||
|
|
||||||
|
# Prompt for password at the start
|
||||||
|
echo -n "Enter password for certificate: "
|
||||||
|
read -s PASSWORD
|
||||||
|
echo
|
||||||
|
echo -n "Confirm password: "
|
||||||
|
read -s PASSWORD2
|
||||||
|
echo
|
||||||
|
|
||||||
|
if [ "$PASSWORD" != "$PASSWORD2" ]; then
|
||||||
|
echo "Passwords don't match!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
CA_DIR="./certs/ca"
|
||||||
|
CLIENT_DIR="./certs/clients"
|
||||||
|
FILE_PREFIX=$(echo "$CERT_USERNAME-at-$DOMAIN" | sed 's/\./-/')
|
||||||
|
|
||||||
|
mkdir -p "$CA_DIR"
|
||||||
|
mkdir -p "$CLIENT_DIR"
|
||||||
|
|
||||||
|
if [ ! -f "$CA_DIR/ca.crt" ]; then
|
||||||
|
# Generate CA private key
|
||||||
|
openssl genrsa -out "$CA_DIR/ca.key" 4096
|
||||||
|
echo "CA key ✅"
|
||||||
|
|
||||||
|
# Generate CA root certificate
|
||||||
|
openssl req -x509 -new -nodes \
|
||||||
|
-key "$CA_DIR/ca.key" \
|
||||||
|
-sha256 \
|
||||||
|
-days 3650 \
|
||||||
|
-out "$CA_DIR/ca.crt" \
|
||||||
|
-subj "/C=US/ST=State/L=City/O=Organization/OU=Unit/CN=ca.$DOMAIN"
|
||||||
|
|
||||||
|
echo "CA cert ✅"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Generate client private key
|
||||||
|
openssl genrsa -aes256 -passout pass:"$PASSWORD" -out "$CLIENT_DIR/$FILE_PREFIX.key" 2048
|
||||||
|
echo "Client key ✅"
|
||||||
|
|
||||||
|
# Generate client Certificate Signing Request (CSR)
|
||||||
|
openssl req -new \
|
||||||
|
-key "$CLIENT_DIR/$FILE_PREFIX.key" \
|
||||||
|
-out "$CLIENT_DIR/$FILE_PREFIX.csr" \
|
||||||
|
-passin pass:"$PASSWORD" \
|
||||||
|
-subj "/C=US/ST=State/L=City/O=Organization/OU=Unit/CN=$CERT_USERNAME@$DOMAIN"
|
||||||
|
echo "Client cert ✅"
|
||||||
|
|
||||||
|
echo -n "Signing client cert..."
|
||||||
|
# Create client certificate configuration file
|
||||||
|
cat > "$CLIENT_DIR/$FILE_PREFIX.ext" << EOF
|
||||||
|
authorityKeyIdentifier=keyid,issuer
|
||||||
|
basicConstraints=CA:FALSE
|
||||||
|
keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment
|
||||||
|
subjectAltName = @alt_names
|
||||||
|
|
||||||
|
[alt_names]
|
||||||
|
DNS.1 = $DOMAIN
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Generate client certificate signed by CA
|
||||||
|
openssl x509 -req \
|
||||||
|
-in "$CLIENT_DIR/$FILE_PREFIX.csr" \
|
||||||
|
-CA "$CA_DIR/ca.crt" \
|
||||||
|
-CAkey "$CA_DIR/ca.key" \
|
||||||
|
-CAcreateserial \
|
||||||
|
-out "$CLIENT_DIR/$FILE_PREFIX.crt" \
|
||||||
|
-days 365 \
|
||||||
|
-sha256 \
|
||||||
|
-extfile "$CLIENT_DIR/$FILE_PREFIX.ext"
|
||||||
|
|
||||||
|
# Verify the client certificate
|
||||||
|
openssl verify -CAfile "$CA_DIR/ca.crt" "$CLIENT_DIR/$FILE_PREFIX.crt"
|
||||||
|
echo "Signed ✅"
|
||||||
|
|
||||||
|
# Create encrypted PEM bundle
|
||||||
|
openssl rsa -in "$CLIENT_DIR/$FILE_PREFIX.key" -passin pass:"$PASSWORD" \
|
||||||
|
| cat "$CLIENT_DIR/$FILE_PREFIX.crt" - > "$CLIENT_DIR/$FILE_PREFIX-bundle.enc.pem"
|
||||||
|
|
||||||
|
|
||||||
|
# Convert to PKCS12
|
||||||
|
echo "Converting to PKCS12 format..."
|
||||||
|
openssl pkcs12 -export \
|
||||||
|
-out "$CLIENT_DIR/$FILE_PREFIX.enc.p12" \
|
||||||
|
-inkey "$CLIENT_DIR/$FILE_PREFIX.key" \
|
||||||
|
-in "$CLIENT_DIR/$FILE_PREFIX.crt" \
|
||||||
|
-certfile "$CA_DIR/ca.crt" \
|
||||||
|
-name "$CERT_USERNAME@$DOMAIN" \
|
||||||
|
-passin pass:"$PASSWORD" \
|
||||||
|
-passout pass:"$PASSWORD"
|
||||||
|
echo "Converted to encrypted p12 for macOS ✅"
|
||||||
|
|
||||||
|
# Convert to PKCS12 format without encryption
|
||||||
|
echo "Converting to non-encrypted PKCS12 format..."
|
||||||
|
openssl pkcs12 -export \
|
||||||
|
-out "$CLIENT_DIR/$FILE_PREFIX.p12" \
|
||||||
|
-inkey "$CLIENT_DIR/$FILE_PREFIX.key" \
|
||||||
|
-in "$CLIENT_DIR/$FILE_PREFIX.crt" \
|
||||||
|
-certfile "$CA_DIR/ca.crt" \
|
||||||
|
-name "$CERT_USERNAME@$DOMAIN" \
|
||||||
|
-passin pass:"$PASSWORD" \
|
||||||
|
-passout pass:""
|
||||||
|
echo "Converted to non-encrypted p12 ✅"
|
||||||
|
|
||||||
|
# Clean up intermediate files
|
||||||
|
rm "$CLIENT_DIR/$FILE_PREFIX.csr" "$CLIENT_DIR/$FILE_PREFIX.ext" "$CA_DIR/ca.srl"
|
||||||
|
echo
|
||||||
|
echo
|
||||||
|
|
||||||
|
echo "CA certificate: $CA_DIR/ca.crt"
|
||||||
|
echo "CA private key: $CA_DIR/ca.key"
|
||||||
|
echo "Client certificate: $CLIENT_DIR/$FILE_PREFIX.crt"
|
||||||
|
echo "Client private key: $CLIENT_DIR/$FILE_PREFIX.key"
|
||||||
|
echo "Client cert bundle: $CLIENT_DIR/$FILE_PREFIX.p12"
|
||||||
|
echo "Client cert bundle (encrypted): $CLIENT_DIR/$FILE_PREFIX.enc.p12"
|
|
@ -2,27 +2,29 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"software.sslmate.com/src/go-pkcs12"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
conn *websocket.Conn
|
conn *websocket.Conn
|
||||||
config *Config
|
config *Config
|
||||||
baseURL string
|
baseURL string
|
||||||
handlers map[string]MessageHandler
|
handlers map[string]MessageHandler
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
handlersMux sync.RWMutex
|
handlersMux sync.RWMutex
|
||||||
|
|
||||||
reconnectInterval time.Duration
|
reconnectInterval time.Duration
|
||||||
isConnected bool
|
isConnected bool
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
|
@ -41,6 +43,12 @@ func WithBaseURL(url string) ClientOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithTLSConfig(tlsClientCertPath string) ClientOption {
|
||||||
|
return func(c *Client) {
|
||||||
|
c.config.TlsClientCert = tlsClientCertPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) OnConnect(callback func() error) {
|
func (c *Client) OnConnect(callback func() error) {
|
||||||
c.onConnect = callback
|
c.onConnect = callback
|
||||||
}
|
}
|
||||||
|
@ -63,8 +71,13 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options before loading config
|
// Apply options before loading config
|
||||||
for _, opt := range opts {
|
if opts != nil {
|
||||||
opt(client)
|
for _, opt := range opts {
|
||||||
|
if opt == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
opt(client)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load existing config if available
|
// Load existing config if available
|
||||||
|
@ -149,6 +162,14 @@ func (c *Client) getToken() (string, error) {
|
||||||
// Ensure we have the base URL without trailing slashes
|
// Ensure we have the base URL without trailing slashes
|
||||||
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||||
|
|
||||||
|
var tlsConfig *tls.Config = nil
|
||||||
|
if c.config.TlsClientCert != "" {
|
||||||
|
tlsConfig, err = loadClientCertificate(c.config.TlsClientCert)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we already have a token, try to use it
|
// If we already have a token, try to use it
|
||||||
if c.config.Token != "" {
|
if c.config.Token != "" {
|
||||||
tokenCheckData := map[string]interface{}{
|
tokenCheckData := map[string]interface{}{
|
||||||
|
@ -177,6 +198,11 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
if tlsConfig != nil {
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to check token validity: %w", err)
|
return "", fmt.Errorf("failed to check token validity: %w", err)
|
||||||
|
@ -220,6 +246,11 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
if tlsConfig != nil {
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to request new token: %w", err)
|
return "", fmt.Errorf("failed to request new token: %w", err)
|
||||||
|
@ -295,7 +326,16 @@ func (c *Client) establishConnection() error {
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
// Connect to WebSocket
|
// Connect to WebSocket
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
dialer := websocket.DefaultDialer
|
||||||
|
if c.config.TlsClientCert != "" {
|
||||||
|
logger.Info("Adding tls to req")
|
||||||
|
tlsConfig, err := loadClientCertificate(c.config.TlsClientCert)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
||||||
|
}
|
||||||
|
dialer.TLSClientConfig = tlsConfig
|
||||||
|
}
|
||||||
|
conn, _, err := dialer.Dial(u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -353,3 +393,42 @@ func (c *Client) setConnected(status bool) {
|
||||||
defer c.reconnectMux.Unlock()
|
defer c.reconnectMux.Unlock()
|
||||||
c.isConnected = status
|
c.isConnected = status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadClientCertificate Helper method to load client certificates
|
||||||
|
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
|
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||||
|
// Read the PKCS12 file
|
||||||
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read PKCS12 file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse PKCS12 with empty password for non-encrypted files
|
||||||
|
privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create certificate
|
||||||
|
cert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{certificate.Raw},
|
||||||
|
PrivateKey: privateKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Add CA certificates if present
|
||||||
|
rootCAs, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load system cert pool: %w", err)
|
||||||
|
}
|
||||||
|
if len(caCerts) > 0 {
|
||||||
|
for _, caCert := range caCerts {
|
||||||
|
rootCAs.AddCert(caCert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS configuration
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
|
@ -54,6 +54,9 @@ func (c *Client) loadConfig() error {
|
||||||
if c.config.Secret == "" {
|
if c.config.Secret == "" {
|
||||||
c.config.Secret = config.Secret
|
c.config.Secret = config.Secret
|
||||||
}
|
}
|
||||||
|
if c.config.TlsClientCert == "" {
|
||||||
|
c.config.TlsClientCert = config.TlsClientCert
|
||||||
|
}
|
||||||
if c.config.Endpoint == "" {
|
if c.config.Endpoint == "" {
|
||||||
c.config.Endpoint = config.Endpoint
|
c.config.Endpoint = config.Endpoint
|
||||||
c.baseURL = config.Endpoint
|
c.baseURL = config.Endpoint
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
NewtID string `json:"newtId"`
|
NewtID string `json:"newtId"`
|
||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
|
TlsClientCert string `json:"tlsClientCert"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
|
|
Loading…
Reference in a new issue