Add retry logic to newt

This commit is contained in:
Owen Schwartz 2024-12-08 16:24:17 -05:00
parent 082ebae0bb
commit e99853422c
No known key found for this signature in database
GPG key ID: 8271FDFFD9E0CCBD
2 changed files with 138 additions and 64 deletions

28
main.go
View file

@ -379,24 +379,28 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
}
})
client.OnConnect(func() error {
publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", publicKey),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
})
// Connect to the WebSocket server
if err := client.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer client.Close()
publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
// TODO: how to retry?
err = client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", publicKey),
})
if err != nil {
logger.Info("Failed to send message: %v", err)
}
logger.Info("Sent registration message")
// Wait for interrupt signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

View file

@ -9,6 +9,7 @@ import (
"newt/logger"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
@ -20,6 +21,12 @@ type Client struct {
handlers map[string]MessageHandler
done chan struct{}
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
reconnectMux sync.RWMutex
onConnect func() error
}
type ClientOption func(*Client)
@ -33,6 +40,10 @@ func WithBaseURL(url string) ClientOption {
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
// NewClient creates a new Newt client
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
config := &Config{
@ -42,10 +53,12 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
}
client := &Client{
config: config,
baseURL: endpoint, // default value
handlers: make(map[string]MessageHandler),
done: make(chan struct{}),
config: config,
baseURL: endpoint, // default value
handlers: make(map[string]MessageHandler),
done: make(chan struct{}),
reconnectInterval: 10 * time.Second,
isConnected: false,
}
// Apply options before loading config
@ -63,54 +76,7 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
// Get token for authentication
token, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
logger.Info("Using token: %s", token)
// Update config with new token and save
c.config.Token = token
if err := c.saveConfig(); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return fmt.Errorf("failed to parse base URL: %w", err)
}
// Determine WebSocket protocol based on HTTP protocol
wsProtocol := "wss"
if baseURL.Scheme == "http" {
wsProtocol = "ws"
}
// Create WebSocket URL using the hostname without path
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
u, err := url.Parse(wsURL)
if err != nil {
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
}
// Add token to query parameters
q := u.Query()
q.Set("token", token)
u.RawQuery = q.Encode()
// Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
logger.Info("Connected to WebSocket")
c.conn = conn
go c.readPump()
go c.connectWithRetry()
return nil
}
@ -248,3 +214,107 @@ func (c *Client) getToken() (string, error) {
return tokenResp.Data.Token, nil
}
func (c *Client) connectWithRetry() {
for {
select {
case <-c.done:
return
default:
err := c.establishConnection()
if err != nil {
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
return
}
}
}
func (c *Client) establishConnection() error {
// Get token for authentication
token, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return fmt.Errorf("failed to parse base URL: %w", err)
}
// Determine WebSocket protocol based on HTTP protocol
wsProtocol := "wss"
if baseURL.Scheme == "http" {
wsProtocol = "ws"
}
// Create WebSocket URL
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
u, err := url.Parse(wsURL)
if err != nil {
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
}
// Add token to query parameters
q := u.Query()
q.Set("token", token)
u.RawQuery = q.Encode()
// Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
c.conn = conn
c.setConnected(true)
// Start the ping monitor
go c.pingMonitor()
// Start the read pump
go c.readPump()
if c.onConnect != nil {
if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err)
}
}
return nil
}
func (c *Client) pingMonitor() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
logger.Error("Ping failed: %v", err)
c.reconnect()
return
}
}
}
}
func (c *Client) reconnect() {
c.setConnected(false)
if c.conn != nil {
c.conn.Close()
}
go c.connectWithRetry()
}
func (c *Client) setConnected(status bool) {
c.reconnectMux.Lock()
defer c.reconnectMux.Unlock()
c.isConnected = status
}