From b41570eb2ca5ac28ae30c929d885d556e7992744 Mon Sep 17 00:00:00 2001 From: progressive-kiwi Date: Tue, 1 Apr 2025 20:43:42 +0200 Subject: [PATCH] feat/mtls-support-cert: config support --- main.go | 7 +----- websocket/client.go | 56 +++++++++++++++++++++++++++++---------------- websocket/config.go | 7 ++++++ websocket/types.go | 9 ++++---- 4 files changed, 49 insertions(+), 30 deletions(-) diff --git a/main.go b/main.go index bec5554..4feb325 100644 --- a/main.go +++ b/main.go @@ -321,13 +321,8 @@ func main() { } var opt websocket.ClientOption if tlsPrivateKey != "" { - tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey) - if err != nil { - logger.Fatal("Failed to load client certificate: %v", err) - } - opt = websocket.WithTLSConfig(tlsConfig) + opt = websocket.WithTLSConfig(tlsPrivateKey) } - // Create a new client client, err := websocket.NewClient( id, // CLI arg takes precedence diff --git a/websocket/client.go b/websocket/client.go index 0f491d3..894e3bc 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -19,14 +19,12 @@ import ( ) type Client struct { - conn *websocket.Conn - config *Config - baseURL string - handlers map[string]MessageHandler - done chan struct{} - handlersMux sync.RWMutex - tlsConfig *tls.Config - + conn *websocket.Conn + config *Config + baseURL string + handlers map[string]MessageHandler + done chan struct{} + handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool reconnectMux sync.RWMutex @@ -45,9 +43,9 @@ func WithBaseURL(url string) ClientOption { } } -func WithTLSConfig(tlsConfig *tls.Config) ClientOption { +func WithTLSConfig(tlsClientCertPath string) ClientOption { return func(c *Client) { - c.tlsConfig = tlsConfig + c.config.TlsClientCert = tlsClientCertPath } } @@ -73,8 +71,13 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C } // Apply options before loading config - for _, opt := range opts { - opt(client) + if opts != nil { + for _, opt := range opts { + if opt == nil { + continue + } + opt(client) + } } // Load existing config if available @@ -187,10 +190,13 @@ func (c *Client) getToken() (string, error) { // Make the request client := &http.Client{} - if c.tlsConfig != nil { - logger.Info("Adding tls to req") + if c.config.TlsClientCert != "" { + tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert) + if err != nil { + return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) + } client.Transport = &http.Transport{ - TLSClientConfig: c.tlsConfig, + TLSClientConfig: tlsConfig, } } resp, err := client.Do(req) @@ -236,9 +242,13 @@ func (c *Client) getToken() (string, error) { // Make the request client := &http.Client{} - if c.tlsConfig != nil { + if c.config.TlsClientCert != "" { + tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert) + if err != nil { + return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) + } client.Transport = &http.Transport{ - TLSClientConfig: c.tlsConfig, + TLSClientConfig: tlsConfig, } } resp, err := client.Do(req) @@ -317,8 +327,13 @@ func (c *Client) establishConnection() error { // Connect to WebSocket dialer := websocket.DefaultDialer - if c.tlsConfig != nil { - dialer.TLSClientConfig = c.tlsConfig + if c.config.TlsClientCert != "" { + logger.Info("Adding tls to req") + tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert) + if err != nil { + return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err) + } + dialer.TLSClientConfig = tlsConfig } conn, _, err := dialer.Dial(u.String(), nil) if err != nil { @@ -381,6 +396,7 @@ func (c *Client) setConnected(status bool) { // LoadClientCertificate Helper method to load client certificates func LoadClientCertificate(p12Path string) (*tls.Config, error) { + logger.Info("Loading tls-client-cert %s", p12Path) // Read the PKCS12 file p12Data, err := os.ReadFile(p12Path) if err != nil { @@ -392,7 +408,7 @@ func LoadClientCertificate(p12Path string) (*tls.Config, error) { if err != nil { return nil, fmt.Errorf("failed to decode PKCS12: %w", err) } - + // Create certificate cert := tls.Certificate{ Certificate: [][]byte{certificate.Raw}, diff --git a/websocket/config.go b/websocket/config.go index 794ff1e..b8dac85 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -54,6 +54,13 @@ func (c *Client) loadConfig() error { if c.config.Secret == "" { c.config.Secret = config.Secret } + if c.config.TlsClientCert == "" { + c.config.TlsClientCert = config.TlsClientCert + } + if c.config.Endpoint == "" { + c.config.Endpoint = config.Endpoint + c.baseURL = config.Endpoint + } if c.config.Endpoint == "" { c.config.Endpoint = config.Endpoint c.baseURL = config.Endpoint diff --git a/websocket/types.go b/websocket/types.go index 084465a..0ea24fc 100644 --- a/websocket/types.go +++ b/websocket/types.go @@ -1,10 +1,11 @@ package websocket type Config struct { - NewtID string `json:"newtId"` - Secret string `json:"secret"` - Token string `json:"token"` - Endpoint string `json:"endpoint"` + NewtID string `json:"newtId"` + Secret string `json:"secret"` + Token string `json:"token"` + Endpoint string `json:"endpoint"` + TlsClientCert string `json:"tlsClientCert"` } type TokenResponse struct {