diff --git a/main.go b/main.go index b65acaf..6787ad5 100644 --- a/main.go +++ b/main.go @@ -379,24 +379,28 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( } }) + client.OnConnect(func() error { + publicKey := privateKey.PublicKey() + logger.Debug("Public key: %s", publicKey) + + err := client.SendMessage("newt/wg/register", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", publicKey), + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Sent registration message") + return nil + }) + // Connect to the WebSocket server if err := client.Connect(); err != nil { logger.Fatal("Failed to connect to server: %v", err) } defer client.Close() - publicKey := privateKey.PublicKey() - logger.Debug("Public key: %s", publicKey) - // TODO: how to retry? - err = client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": fmt.Sprintf("%s", publicKey), - }) - if err != nil { - logger.Info("Failed to send message: %v", err) - } - - logger.Info("Sent registration message") - // Wait for interrupt signal sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) diff --git a/websocket/client.go b/websocket/client.go index 94351ac..c89e88c 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -9,6 +9,7 @@ import ( "newt/logger" "strings" "sync" + "time" "github.com/gorilla/websocket" ) @@ -20,6 +21,12 @@ type Client struct { handlers map[string]MessageHandler done chan struct{} handlersMux sync.RWMutex + + reconnectInterval time.Duration + isConnected bool + reconnectMux sync.RWMutex + + onConnect func() error } type ClientOption func(*Client) @@ -33,6 +40,10 @@ func WithBaseURL(url string) ClientOption { } } +func (c *Client) OnConnect(callback func() error) { + c.onConnect = callback +} + // NewClient creates a new Newt client func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { config := &Config{ @@ -42,10 +53,12 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C } client := &Client{ - config: config, - baseURL: endpoint, // default value - handlers: make(map[string]MessageHandler), - done: make(chan struct{}), + config: config, + baseURL: endpoint, // default value + handlers: make(map[string]MessageHandler), + done: make(chan struct{}), + reconnectInterval: 10 * time.Second, + isConnected: false, } // Apply options before loading config @@ -63,54 +76,7 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C // Connect establishes the WebSocket connection func (c *Client) Connect() error { - // Get token for authentication - token, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - - logger.Info("Using token: %s", token) - - // Update config with new token and save - c.config.Token = token - if err := c.saveConfig(); err != nil { - return fmt.Errorf("failed to save config: %w", err) - } - - // Parse the base URL to determine protocol and hostname - baseURL, err := url.Parse(c.baseURL) - if err != nil { - return fmt.Errorf("failed to parse base URL: %w", err) - } - - // Determine WebSocket protocol based on HTTP protocol - wsProtocol := "wss" - if baseURL.Scheme == "http" { - wsProtocol = "ws" - } - - // Create WebSocket URL using the hostname without path - wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host) - u, err := url.Parse(wsURL) - if err != nil { - return fmt.Errorf("failed to parse WebSocket URL: %w", err) - } - - // Add token to query parameters - q := u.Query() - q.Set("token", token) - u.RawQuery = q.Encode() - - // Connect to WebSocket - conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) - if err != nil { - return fmt.Errorf("failed to connect to WebSocket: %w", err) - } - - logger.Info("Connected to WebSocket") - - c.conn = conn - go c.readPump() + go c.connectWithRetry() return nil } @@ -248,3 +214,107 @@ func (c *Client) getToken() (string, error) { return tokenResp.Data.Token, nil } + +func (c *Client) connectWithRetry() { + for { + select { + case <-c.done: + return + default: + err := c.establishConnection() + if err != nil { + logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) + time.Sleep(c.reconnectInterval) + continue + } + return + } + } +} + +func (c *Client) establishConnection() error { + // Get token for authentication + token, err := c.getToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + // Parse the base URL to determine protocol and hostname + baseURL, err := url.Parse(c.baseURL) + if err != nil { + return fmt.Errorf("failed to parse base URL: %w", err) + } + + // Determine WebSocket protocol based on HTTP protocol + wsProtocol := "wss" + if baseURL.Scheme == "http" { + wsProtocol = "ws" + } + + // Create WebSocket URL + wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host) + u, err := url.Parse(wsURL) + if err != nil { + return fmt.Errorf("failed to parse WebSocket URL: %w", err) + } + + // Add token to query parameters + q := u.Query() + q.Set("token", token) + u.RawQuery = q.Encode() + + // Connect to WebSocket + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return fmt.Errorf("failed to connect to WebSocket: %w", err) + } + + c.conn = conn + c.setConnected(true) + + // Start the ping monitor + go c.pingMonitor() + // Start the read pump + go c.readPump() + + if c.onConnect != nil { + if err := c.onConnect(); err != nil { + logger.Error("OnConnect callback failed: %v", err) + } + } + + return nil +} + +func (c *Client) pingMonitor() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.done: + return + case <-ticker.C: + if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + logger.Error("Ping failed: %v", err) + c.reconnect() + return + } + } + } +} + +func (c *Client) reconnect() { + c.setConnected(false) + if c.conn != nil { + c.conn.Close() + } + + go c.connectWithRetry() +} + +func (c *Client) setConnected(status bool) { + c.reconnectMux.Lock() + defer c.reconnectMux.Unlock() + c.isConnected = status +}