From d28e3ca5e8ae3e15f67902c1dd0eda9d56670fb1 Mon Sep 17 00:00:00 2001 From: progressive-kiwi Date: Wed, 2 Apr 2025 21:00:09 +0200 Subject: [PATCH] feat/mtls-support-cert: doc update, removing config.Endpoint loading duplicates, handling null-pointer case and some logging --- README.md | 4 ++-- main.go | 7 +++++-- websocket/client.go | 26 +++++++++++++------------- websocket/config.go | 4 ---- 4 files changed, 20 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 127a77c..7512476 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ Examples: --id 31frd0uzbjvp721 \ --secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \ --endpoint https://example.com \ ---tls-client-cert /client.p12 +--tls-client-cert ./client.p12 ``` ```yaml @@ -137,7 +137,7 @@ services: - PANGOLIN_ENDPOINT=https://example.com - NEWT_ID=2ix2t8xk22ubpfy - NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2 - - TLS_CLIENT_CERT=/client.p12 + - TLS_CLIENT_CERT=./client.p12 ``` ## Build diff --git a/main.go b/main.go index 4feb325..f0f65c7 100644 --- a/main.go +++ b/main.go @@ -561,10 +561,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( // Wait for interrupt signal sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh + sigReceived := <-sigCh // Cleanup - dev.Close() + logger.Info("Received %s signal, stopping", sigReceived.String()) + if dev != nil { + dev.Close() + } } func parseTargetData(data interface{}) (TargetData, error) { diff --git a/websocket/client.go b/websocket/client.go index 894e3bc..3d221e1 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -162,6 +162,14 @@ func (c *Client) getToken() (string, error) { // Ensure we have the base URL without trailing slashes baseEndpoint := strings.TrimRight(baseURL.String(), "/") + var tlsConfig *tls.Config = nil + if c.config.TlsClientCert != "" { + tlsConfig, err = loadClientCertificate(c.config.TlsClientCert) + if err != nil { + return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) + } + } + // If we already have a token, try to use it if c.config.Token != "" { tokenCheckData := map[string]interface{}{ @@ -190,11 +198,7 @@ func (c *Client) getToken() (string, error) { // Make the request client := &http.Client{} - if c.config.TlsClientCert != "" { - tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert) - if err != nil { - return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) - } + if tlsConfig != nil { client.Transport = &http.Transport{ TLSClientConfig: tlsConfig, } @@ -242,11 +246,7 @@ func (c *Client) getToken() (string, error) { // Make the request client := &http.Client{} - if c.config.TlsClientCert != "" { - tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert) - if err != nil { - return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) - } + if tlsConfig != nil { client.Transport = &http.Transport{ TLSClientConfig: tlsConfig, } @@ -329,7 +329,7 @@ func (c *Client) establishConnection() error { dialer := websocket.DefaultDialer if c.config.TlsClientCert != "" { logger.Info("Adding tls to req") - tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert) + tlsConfig, err := loadClientCertificate(c.config.TlsClientCert) if err != nil { return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) } @@ -395,7 +395,7 @@ func (c *Client) setConnected(status bool) { } // LoadClientCertificate Helper method to load client certificates -func LoadClientCertificate(p12Path string) (*tls.Config, error) { +func loadClientCertificate(p12Path string) (*tls.Config, error) { logger.Info("Loading tls-client-cert %s", p12Path) // Read the PKCS12 file p12Data, err := os.ReadFile(p12Path) @@ -408,7 +408,7 @@ func LoadClientCertificate(p12Path string) (*tls.Config, error) { if err != nil { return nil, fmt.Errorf("failed to decode PKCS12: %w", err) } - + // Create certificate cert := tls.Certificate{ Certificate: [][]byte{certificate.Raw}, diff --git a/websocket/config.go b/websocket/config.go index b8dac85..e2b0055 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -61,10 +61,6 @@ func (c *Client) loadConfig() error { c.config.Endpoint = config.Endpoint c.baseURL = config.Endpoint } - if c.config.Endpoint == "" { - c.config.Endpoint = config.Endpoint - c.baseURL = config.Endpoint - } return nil }