feat/mtls-support-cert: config support

This commit is contained in:
progressive-kiwi 2025-04-01 20:43:42 +02:00
parent 435b638701
commit b41570eb2c
4 changed files with 49 additions and 30 deletions

View file

@ -321,13 +321,8 @@ func main() {
} }
var opt websocket.ClientOption var opt websocket.ClientOption
if tlsPrivateKey != "" { if tlsPrivateKey != "" {
tlsConfig, err := websocket.LoadClientCertificate(tlsPrivateKey) opt = websocket.WithTLSConfig(tlsPrivateKey)
if err != nil {
logger.Fatal("Failed to load client certificate: %v", err)
}
opt = websocket.WithTLSConfig(tlsConfig)
} }
// Create a new client // Create a new client
client, err := websocket.NewClient( client, err := websocket.NewClient(
id, // CLI arg takes precedence id, // CLI arg takes precedence

View file

@ -19,14 +19,12 @@ import (
) )
type Client struct { type Client struct {
conn *websocket.Conn conn *websocket.Conn
config *Config config *Config
baseURL string baseURL string
handlers map[string]MessageHandler handlers map[string]MessageHandler
done chan struct{} done chan struct{}
handlersMux sync.RWMutex handlersMux sync.RWMutex
tlsConfig *tls.Config
reconnectInterval time.Duration reconnectInterval time.Duration
isConnected bool isConnected bool
reconnectMux sync.RWMutex reconnectMux sync.RWMutex
@ -45,9 +43,9 @@ func WithBaseURL(url string) ClientOption {
} }
} }
func WithTLSConfig(tlsConfig *tls.Config) ClientOption { func WithTLSConfig(tlsClientCertPath string) ClientOption {
return func(c *Client) { return func(c *Client) {
c.tlsConfig = tlsConfig c.config.TlsClientCert = tlsClientCertPath
} }
} }
@ -73,8 +71,13 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
} }
// Apply options before loading config // Apply options before loading config
for _, opt := range opts { if opts != nil {
opt(client) for _, opt := range opts {
if opt == nil {
continue
}
opt(client)
}
} }
// Load existing config if available // Load existing config if available
@ -187,10 +190,13 @@ func (c *Client) getToken() (string, error) {
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
if c.tlsConfig != nil { if c.config.TlsClientCert != "" {
logger.Info("Adding tls to req") tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
if err != nil {
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
}
client.Transport = &http.Transport{ client.Transport = &http.Transport{
TLSClientConfig: c.tlsConfig, TLSClientConfig: tlsConfig,
} }
} }
resp, err := client.Do(req) resp, err := client.Do(req)
@ -236,9 +242,13 @@ func (c *Client) getToken() (string, error) {
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
if c.tlsConfig != nil { if c.config.TlsClientCert != "" {
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
if err != nil {
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
}
client.Transport = &http.Transport{ client.Transport = &http.Transport{
TLSClientConfig: c.tlsConfig, TLSClientConfig: tlsConfig,
} }
} }
resp, err := client.Do(req) resp, err := client.Do(req)
@ -317,8 +327,13 @@ func (c *Client) establishConnection() error {
// Connect to WebSocket // Connect to WebSocket
dialer := websocket.DefaultDialer dialer := websocket.DefaultDialer
if c.tlsConfig != nil { if c.config.TlsClientCert != "" {
dialer.TLSClientConfig = c.tlsConfig logger.Info("Adding tls to req")
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
if err != nil {
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
}
dialer.TLSClientConfig = tlsConfig
} }
conn, _, err := dialer.Dial(u.String(), nil) conn, _, err := dialer.Dial(u.String(), nil)
if err != nil { if err != nil {
@ -381,6 +396,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates // LoadClientCertificate Helper method to load client certificates
func LoadClientCertificate(p12Path string) (*tls.Config, error) { func LoadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file // Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path) p12Data, err := os.ReadFile(p12Path)
if err != nil { if err != nil {

View file

@ -54,6 +54,13 @@ func (c *Client) loadConfig() error {
if c.config.Secret == "" { if c.config.Secret == "" {
c.config.Secret = config.Secret c.config.Secret = config.Secret
} }
if c.config.TlsClientCert == "" {
c.config.TlsClientCert = config.TlsClientCert
}
if c.config.Endpoint == "" {
c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint
}
if c.config.Endpoint == "" { if c.config.Endpoint == "" {
c.config.Endpoint = config.Endpoint c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint c.baseURL = config.Endpoint

View file

@ -1,10 +1,11 @@
package websocket package websocket
type Config struct { type Config struct {
NewtID string `json:"newtId"` NewtID string `json:"newtId"`
Secret string `json:"secret"` Secret string `json:"secret"`
Token string `json:"token"` Token string `json:"token"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
TlsClientCert string `json:"tlsClientCert"`
} }
type TokenResponse struct { type TokenResponse struct {