mirror of
https://github.com/fosrl/newt.git
synced 2025-05-13 13:40:39 +01:00
Add retry logic to newt
This commit is contained in:
parent
082ebae0bb
commit
e99853422c
2 changed files with 138 additions and 64 deletions
28
main.go
28
main.go
|
@ -379,24 +379,28 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
client.OnConnect(func() error {
|
||||||
|
publicKey := privateKey.PublicKey()
|
||||||
|
logger.Debug("Public key: %s", publicKey)
|
||||||
|
|
||||||
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
|
"publicKey": fmt.Sprintf("%s", publicKey),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Sent registration message")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
// Connect to the WebSocket server
|
||||||
if err := client.Connect(); err != nil {
|
if err := client.Connect(); err != nil {
|
||||||
logger.Fatal("Failed to connect to server: %v", err)
|
logger.Fatal("Failed to connect to server: %v", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
publicKey := privateKey.PublicKey()
|
|
||||||
logger.Debug("Public key: %s", publicKey)
|
|
||||||
// TODO: how to retry?
|
|
||||||
err = client.SendMessage("newt/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Failed to send message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Sent registration message")
|
|
||||||
|
|
||||||
// Wait for interrupt signal
|
// Wait for interrupt signal
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"newt/logger"
|
"newt/logger"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
@ -20,6 +21,12 @@ type Client struct {
|
||||||
handlers map[string]MessageHandler
|
handlers map[string]MessageHandler
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
handlersMux sync.RWMutex
|
handlersMux sync.RWMutex
|
||||||
|
|
||||||
|
reconnectInterval time.Duration
|
||||||
|
isConnected bool
|
||||||
|
reconnectMux sync.RWMutex
|
||||||
|
|
||||||
|
onConnect func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
|
@ -33,6 +40,10 @@ func WithBaseURL(url string) ClientOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) OnConnect(callback func() error) {
|
||||||
|
c.onConnect = callback
|
||||||
|
}
|
||||||
|
|
||||||
// NewClient creates a new Newt client
|
// NewClient creates a new Newt client
|
||||||
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
|
@ -46,6 +57,8 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
|
||||||
baseURL: endpoint, // default value
|
baseURL: endpoint, // default value
|
||||||
handlers: make(map[string]MessageHandler),
|
handlers: make(map[string]MessageHandler),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
|
reconnectInterval: 10 * time.Second,
|
||||||
|
isConnected: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options before loading config
|
// Apply options before loading config
|
||||||
|
@ -63,54 +76,7 @@ func NewClient(newtID, secret string, endpoint string, opts ...ClientOption) (*C
|
||||||
|
|
||||||
// Connect establishes the WebSocket connection
|
// Connect establishes the WebSocket connection
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
// Get token for authentication
|
go c.connectWithRetry()
|
||||||
token, err := c.getToken()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Using token: %s", token)
|
|
||||||
|
|
||||||
// Update config with new token and save
|
|
||||||
c.config.Token = token
|
|
||||||
if err := c.saveConfig(); err != nil {
|
|
||||||
return fmt.Errorf("failed to save config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the base URL to determine protocol and hostname
|
|
||||||
baseURL, err := url.Parse(c.baseURL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse base URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine WebSocket protocol based on HTTP protocol
|
|
||||||
wsProtocol := "wss"
|
|
||||||
if baseURL.Scheme == "http" {
|
|
||||||
wsProtocol = "ws"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create WebSocket URL using the hostname without path
|
|
||||||
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
|
|
||||||
u, err := url.Parse(wsURL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add token to query parameters
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("token", token)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
// Connect to WebSocket
|
|
||||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Connected to WebSocket")
|
|
||||||
|
|
||||||
c.conn = conn
|
|
||||||
go c.readPump()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,3 +214,107 @@ func (c *Client) getToken() (string, error) {
|
||||||
|
|
||||||
return tokenResp.Data.Token, nil
|
return tokenResp.Data.Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) connectWithRetry() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
err := c.establishConnection()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||||
|
time.Sleep(c.reconnectInterval)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) establishConnection() error {
|
||||||
|
// Get token for authentication
|
||||||
|
token, err := c.getToken()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the base URL to determine protocol and hostname
|
||||||
|
baseURL, err := url.Parse(c.baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse base URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine WebSocket protocol based on HTTP protocol
|
||||||
|
wsProtocol := "wss"
|
||||||
|
if baseURL.Scheme == "http" {
|
||||||
|
wsProtocol = "ws"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create WebSocket URL
|
||||||
|
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
|
||||||
|
u, err := url.Parse(wsURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add token to query parameters
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("token", token)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Connect to WebSocket
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn = conn
|
||||||
|
c.setConnected(true)
|
||||||
|
|
||||||
|
// Start the ping monitor
|
||||||
|
go c.pingMonitor()
|
||||||
|
// Start the read pump
|
||||||
|
go c.readPump()
|
||||||
|
|
||||||
|
if c.onConnect != nil {
|
||||||
|
if err := c.onConnect(); err != nil {
|
||||||
|
logger.Error("OnConnect callback failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) pingMonitor() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
|
||||||
|
logger.Error("Ping failed: %v", err)
|
||||||
|
c.reconnect()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) reconnect() {
|
||||||
|
c.setConnected(false)
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
go c.connectWithRetry()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) setConnected(status bool) {
|
||||||
|
c.reconnectMux.Lock()
|
||||||
|
defer c.reconnectMux.Unlock()
|
||||||
|
c.isConnected = status
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue