mirror of
https://github.com/fosrl/newt.git
synced 2025-05-12 21:20:39 +01:00
feat/mtls-support-cert: doc update, removing config.Endpoint loading duplicates, handling null-pointer case and some logging
This commit is contained in:
parent
b41570eb2c
commit
d28e3ca5e8
4 changed files with 20 additions and 21 deletions
|
@ -124,7 +124,7 @@ Examples:
|
||||||
--id 31frd0uzbjvp721 \
|
--id 31frd0uzbjvp721 \
|
||||||
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
--secret h51mmlknrvrwv8s4r1i210azhumt6isgbpyavxodibx1k2d6 \
|
||||||
--endpoint https://example.com \
|
--endpoint https://example.com \
|
||||||
--tls-client-cert /client.p12
|
--tls-client-cert ./client.p12
|
||||||
```
|
```
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
@ -137,7 +137,7 @@ services:
|
||||||
- PANGOLIN_ENDPOINT=https://example.com
|
- PANGOLIN_ENDPOINT=https://example.com
|
||||||
- NEWT_ID=2ix2t8xk22ubpfy
|
- NEWT_ID=2ix2t8xk22ubpfy
|
||||||
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
- NEWT_SECRET=nnisrfsdfc7prqsp9ewo1dvtvci50j5uiqotez00dgap0ii2
|
||||||
- TLS_CLIENT_CERT=/client.p12
|
- TLS_CLIENT_CERT=./client.p12
|
||||||
```
|
```
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
7
main.go
7
main.go
|
@ -561,10 +561,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||||
// Wait for interrupt signal
|
// Wait for interrupt signal
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
<-sigCh
|
sigReceived := <-sigCh
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
dev.Close()
|
logger.Info("Received %s signal, stopping", sigReceived.String())
|
||||||
|
if dev != nil {
|
||||||
|
dev.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
func parseTargetData(data interface{}) (TargetData, error) {
|
||||||
|
|
|
@ -162,6 +162,14 @@ func (c *Client) getToken() (string, error) {
|
||||||
// Ensure we have the base URL without trailing slashes
|
// Ensure we have the base URL without trailing slashes
|
||||||
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
|
||||||
|
|
||||||
|
var tlsConfig *tls.Config = nil
|
||||||
|
if c.config.TlsClientCert != "" {
|
||||||
|
tlsConfig, err = loadClientCertificate(c.config.TlsClientCert)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we already have a token, try to use it
|
// If we already have a token, try to use it
|
||||||
if c.config.Token != "" {
|
if c.config.Token != "" {
|
||||||
tokenCheckData := map[string]interface{}{
|
tokenCheckData := map[string]interface{}{
|
||||||
|
@ -190,11 +198,7 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
if c.config.TlsClientCert != "" {
|
if tlsConfig != nil {
|
||||||
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
|
||||||
}
|
|
||||||
client.Transport = &http.Transport{
|
client.Transport = &http.Transport{
|
||||||
TLSClientConfig: tlsConfig,
|
TLSClientConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
|
@ -242,11 +246,7 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
// Make the request
|
// Make the request
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
if c.config.TlsClientCert != "" {
|
if tlsConfig != nil {
|
||||||
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
|
||||||
}
|
|
||||||
client.Transport = &http.Transport{
|
client.Transport = &http.Transport{
|
||||||
TLSClientConfig: tlsConfig,
|
TLSClientConfig: tlsConfig,
|
||||||
}
|
}
|
||||||
|
@ -329,7 +329,7 @@ func (c *Client) establishConnection() error {
|
||||||
dialer := websocket.DefaultDialer
|
dialer := websocket.DefaultDialer
|
||||||
if c.config.TlsClientCert != "" {
|
if c.config.TlsClientCert != "" {
|
||||||
logger.Info("Adding tls to req")
|
logger.Info("Adding tls to req")
|
||||||
tlsConfig, err := LoadClientCertificate(c.config.TlsClientCert)
|
tlsConfig, err := loadClientCertificate(c.config.TlsClientCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
return fmt.Errorf("failed to load certificate %s: %w", c.config.TlsClientCert, err)
|
||||||
}
|
}
|
||||||
|
@ -395,7 +395,7 @@ func (c *Client) setConnected(status bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadClientCertificate Helper method to load client certificates
|
// LoadClientCertificate Helper method to load client certificates
|
||||||
func LoadClientCertificate(p12Path string) (*tls.Config, error) {
|
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
logger.Info("Loading tls-client-cert %s", p12Path)
|
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||||
// Read the PKCS12 file
|
// Read the PKCS12 file
|
||||||
p12Data, err := os.ReadFile(p12Path)
|
p12Data, err := os.ReadFile(p12Path)
|
||||||
|
@ -408,7 +408,7 @@ func LoadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
return nil, fmt.Errorf("failed to decode PKCS12: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create certificate
|
// Create certificate
|
||||||
cert := tls.Certificate{
|
cert := tls.Certificate{
|
||||||
Certificate: [][]byte{certificate.Raw},
|
Certificate: [][]byte{certificate.Raw},
|
||||||
|
|
|
@ -61,10 +61,6 @@ func (c *Client) loadConfig() error {
|
||||||
c.config.Endpoint = config.Endpoint
|
c.config.Endpoint = config.Endpoint
|
||||||
c.baseURL = config.Endpoint
|
c.baseURL = config.Endpoint
|
||||||
}
|
}
|
||||||
if c.config.Endpoint == "" {
|
|
||||||
c.config.Endpoint = config.Endpoint
|
|
||||||
c.baseURL = config.Endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue