From f48502dd089bc9ebe68adf7e061a66327fa60a9c Mon Sep 17 00:00:00 2001 From: Owen Schwartz Date: Sat, 23 Nov 2024 17:34:58 -0500 Subject: [PATCH] Standardize logs --- logger/level.go | 27 ++++++++++ logger/logger.go | 106 ++++++++++++++++++++++++++++++++++++ main.go | 127 +++++++++++++++++++++++++++++--------------- proxy/types.go | 2 + websocket/client.go | 2 - 5 files changed, 219 insertions(+), 45 deletions(-) create mode 100644 logger/level.go create mode 100644 logger/logger.go diff --git a/logger/level.go b/logger/level.go new file mode 100644 index 0000000..175995f --- /dev/null +++ b/logger/level.go @@ -0,0 +1,27 @@ +package logger + +type LogLevel int + +const ( + DEBUG LogLevel = iota + INFO + WARN + ERROR + FATAL +) + +var levelStrings = map[LogLevel]string{ + DEBUG: "DEBUG", + INFO: "INFO", + WARN: "WARN", + ERROR: "ERROR", + FATAL: "FATAL", +} + +// String returns the string representation of the log level +func (l LogLevel) String() string { + if s, ok := levelStrings[l]; ok { + return s + } + return "UNKNOWN" +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..9ef486d --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,106 @@ +package logger + +import ( + "fmt" + "log" + "os" + "sync" + "time" +) + +// Logger struct holds the logger instance +type Logger struct { + logger *log.Logger + level LogLevel +} + +var ( + defaultLogger *Logger + once sync.Once +) + +// NewLogger creates a new logger instance +func NewLogger() *Logger { + return &Logger{ + logger: log.New(os.Stdout, "", 0), + level: DEBUG, + } +} + +// Init initializes the default logger +func Init() *Logger { + once.Do(func() { + defaultLogger = NewLogger() + }) + return defaultLogger +} + +// GetLogger returns the default logger instance +func GetLogger() *Logger { + if defaultLogger == nil { + Init() + } + return defaultLogger +} + +// SetLevel sets the minimum logging level +func (l *Logger) SetLevel(level LogLevel) { + l.level = level +} + +// log handles the actual logging +func (l *Logger) log(level LogLevel, format string, args ...interface{}) { + if level < l.level { + return + } + timestamp := time.Now().Format("2006/01/02 15:04:05") + message := fmt.Sprintf(format, args...) + l.logger.Printf("%s: %s %s", level.String(), timestamp, message) +} + +// Debug logs debug level messages +func (l *Logger) Debug(format string, args ...interface{}) { + l.log(DEBUG, format, args...) +} + +// Info logs info level messages +func (l *Logger) Info(format string, args ...interface{}) { + l.log(INFO, format, args...) +} + +// Warn logs warning level messages +func (l *Logger) Warn(format string, args ...interface{}) { + l.log(WARN, format, args...) +} + +// Error logs error level messages +func (l *Logger) Error(format string, args ...interface{}) { + l.log(ERROR, format, args...) +} + +// Fatal logs fatal level messages and exits +func (l *Logger) Fatal(format string, args ...interface{}) { + l.log(FATAL, format, args...) + os.Exit(1) +} + +// Global helper functions +func Debug(format string, args ...interface{}) { + GetLogger().Debug(format, args...) +} + +func Info(format string, args ...interface{}) { + GetLogger().Info(format, args...) +} + +func Warn(format string, args ...interface{}) { + GetLogger().Warn(format, args...) +} + +func Error(format string, args ...interface{}) { + GetLogger().Error(format, args...) +} + +func Fatal(format string, args ...interface{}) { + GetLogger().Fatal(format, args...) +} diff --git a/main.go b/main.go index 8129d4a..f9c0fe4 100644 --- a/main.go +++ b/main.go @@ -7,9 +7,9 @@ import ( "encoding/json" "flag" "fmt" - "log" "math/rand" "net/netip" + "newt/logger" "newt/proxy" "newt/websocket" "os" @@ -51,7 +51,7 @@ func fixKey(key string) string { // Decode from base64 decoded, err := base64.StdEncoding.DecodeString(key) if err != nil { - log.Fatal("Error decoding base64:", err) + logger.Fatal("Error decoding base64:", err) } // Convert to hex @@ -59,10 +59,10 @@ func fixKey(key string) string { } func ping(tnet *netstack.Net, dst string) { - log.Printf("Pinging %s", dst) + logger.Info("Pinging %s", dst) socket, err := tnet.Dial("ping4", dst) if err != nil { - log.Panic(err) + logger.Error("Failed to create ICMP socket: %v", err) } requestPing := icmp.Echo{ Seq: rand.Intn(1 << 16), @@ -73,24 +73,56 @@ func ping(tnet *netstack.Net, dst string) { start := time.Now() _, err = socket.Write(icmpBytes) if err != nil { - log.Panic(err) + logger.Error("Failed to write ICMP packet: %v", err) } n, err := socket.Read(icmpBytes[:]) if err != nil { - log.Panic(err) + logger.Error("Failed to read ICMP packet: %v", err) } replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) if err != nil { - log.Panic(err) + logger.Error("Failed to parse ICMP packet: %v", err) } replyPing, ok := replyPacket.Body.(*icmp.Echo) if !ok { - log.Panicf("invalid reply type: %v", replyPacket) + logger.Error("invalid reply type: %v", replyPacket) } if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - log.Panicf("invalid ping reply: %v", replyPing) + logger.Error("invalid ping reply: %v", replyPing) + } + logger.Info("Ping latency: %v", time.Since(start)) +} + +func parseLogLevel(level string) logger.LogLevel { + switch strings.ToUpper(level) { + case "DEBUG": + return logger.DEBUG + case "INFO": + return logger.INFO + case "WARN": + return logger.WARN + case "ERROR": + return logger.ERROR + case "FATAL": + return logger.FATAL + default: + return logger.INFO // default to INFO if invalid level provided + } +} + +func mapToWireGuardLogLevel(level logger.LogLevel) int { + switch level { + case logger.DEBUG: + return device.LogLevelVerbose + // case logger.INFO: + // return device.LogLevel + case logger.WARN: + return device.LogLevelError + case logger.ERROR, logger.FATAL: + return device.LogLevelSilent + default: + return device.LogLevelSilent } - log.Printf("Ping latency: %v", time.Since(start)) } func main() { @@ -101,18 +133,24 @@ func main() { dns string privateKey wgtypes.Key err error + logLevel string ) flag.StringVar(&endpoint, "endpoint", "http://localhost:3000/api/v1", "Endpoint of your pangolin server") flag.StringVar(&id, "id", "", "Newt ID") flag.StringVar(&secret, "secret", "", "Newt secret") flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") + flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") flag.Parse() + logger.Init() + loggerLevel := parseLogLevel(logLevel) + logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { - log.Fatalf("Failed to generate private key: %v", err) + logger.Fatal("Failed to generate private key: %v", err) } // Create a new client @@ -123,7 +161,7 @@ func main() { websocket.WithBaseURL(endpoint), // TODO: save the endpoint in the config file so we dont have to pass it in every time ) if err != nil { - log.Fatal(err) + logger.Fatal("Failed to create client: %v", err) } // Create TUN device and network stack @@ -137,33 +175,36 @@ func main() { // Register handlers for different message types client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { if connected { - log.Printf("Already connected! Put I will send a ping anyway...") + logger.Info("Already connected! Put I will send a ping anyway...") ping(tnet, wgData.ServerIP) return } jsonData, err := json.Marshal(msg.Data) if err != nil { - log.Printf("Error marshaling data: %v", err) + logger.Info("Error marshaling data: %v", err) return } if err := json.Unmarshal(jsonData, &wgData); err != nil { - log.Printf("Error unmarshaling target data: %v", err) + logger.Info("Error unmarshaling target data: %v", err) return } - log.Printf("Received: %+v", msg) + logger.Info("Received: %+v", msg) tun, tnet, err = netstack.CreateNetTUN( []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, []netip.Addr{netip.MustParseAddr(dns)}, 1420) if err != nil { - log.Panic(err) + logger.Error("Failed to create TUN device: %v", err) } // Create WireGuard device - dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger( + mapToWireGuardLogLevel(loggerLevel), + "wireguard: ", + )) // Configure WireGuard config := fmt.Sprintf(`private_key=%s @@ -174,16 +215,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( err = dev.IpcSet(config) if err != nil { - log.Panic(err) + logger.Error("Failed to configure WireGuard device: %v", err) } // Bring up the device err = dev.Up() if err != nil { - log.Panic(err) + logger.Error("Failed to bring up WireGuard device: %v", err) } - log.Printf("WireGuard device created. Lets ping the server now...") + logger.Info("WireGuard device created. Lets ping the server now...") // Ping to bring the tunnel up on the server side quickly ping(tnet, wgData.ServerIP) @@ -203,17 +244,17 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( }) client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { - log.Printf("Received: %+v", msg) + logger.Info("Received: %+v", msg) // if there is no wgData or pm, we can't add targets if wgData.TunnelIP == "" || pm == nil { - log.Printf("No tunnel IP or proxy manager available") + logger.Info("No tunnel IP or proxy manager available") return } targetData, err := parseTargetData(msg.Data) if err != nil { - log.Printf("Error parsing target data: %v", err) + logger.Info("Error parsing target data: %v", err) return } @@ -223,17 +264,17 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( }) client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) { - log.Printf("Received: %+v", msg) + logger.Info("Received: %+v", msg) // if there is no wgData or pm, we can't add targets if wgData.TunnelIP == "" || pm == nil { - log.Printf("No tunnel IP or proxy manager available") + logger.Info("No tunnel IP or proxy manager available") return } targetData, err := parseTargetData(msg.Data) if err != nil { - log.Printf("Error parsing target data: %v", err) + logger.Info("Error parsing target data: %v", err) return } @@ -243,17 +284,17 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( }) client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) { - log.Printf("Received: %+v", msg) + logger.Info("Received: %+v", msg) // if there is no wgData or pm, we can't add targets if wgData.TunnelIP == "" || pm == nil { - log.Printf("No tunnel IP or proxy manager available") + logger.Info("No tunnel IP or proxy manager available") return } targetData, err := parseTargetData(msg.Data) if err != nil { - log.Printf("Error parsing target data: %v", err) + logger.Info("Error parsing target data: %v", err) return } @@ -263,17 +304,17 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( }) client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) { - log.Printf("Received: %+v", msg) + logger.Info("Received: %+v", msg) // if there is no wgData or pm, we can't add targets if wgData.TunnelIP == "" || pm == nil { - log.Printf("No tunnel IP or proxy manager available") + logger.Info("No tunnel IP or proxy manager available") return } targetData, err := parseTargetData(msg.Data) if err != nil { - log.Printf("Error parsing target data: %v", err) + logger.Info("Error parsing target data: %v", err) return } @@ -284,18 +325,18 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( // Connect to the WebSocket server if err := client.Connect(); err != nil { - log.Fatal(err) + logger.Fatal("Failed to connect to server: %v", err) } defer client.Close() publicKey := privateKey.PublicKey() - log.Printf("Public key: %s", publicKey) + logger.Info("Public key: %s", publicKey) // TODO: how to retry? err = client.SendMessage("newt/wg/register", map[string]interface{}{ "publicKey": fmt.Sprintf("%s", publicKey), }) if err != nil { - log.Printf("Failed to send message: %v", err) + logger.Info("Failed to send message: %v", err) } // Wait for interrupt signal @@ -311,12 +352,12 @@ func parseTargetData(data interface{}) (TargetData, error) { var targetData TargetData jsonData, err := json.Marshal(data) if err != nil { - log.Printf("Error marshaling data: %v", err) + logger.Info("Error marshaling data: %v", err) return targetData, err } if err := json.Unmarshal(jsonData, &targetData); err != nil { - log.Printf("Error unmarshaling target data: %v", err) + logger.Info("Error unmarshaling target data: %v", err) return targetData, err } return targetData, nil @@ -327,14 +368,14 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto // Stop the proxy manager before adding new targets err := pm.Stop() if err != nil { - log.Panic(err) + logger.Error("Failed to stop proxy manager: %v", err) } for _, t := range targetData.Targets { // Split the first number off of the target with : separator and use as the port parts := strings.Split(t, ":") if len(parts) != 3 { - log.Printf("Invalid target format: %s", t) + logger.Info("Invalid target format: %s", t) continue } @@ -342,7 +383,7 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto port := 0 _, err := fmt.Sscanf(parts[0], "%d", &port) if err != nil { - log.Printf("Invalid port: %s", parts[0]) + logger.Info("Invalid port: %s", parts[0]) continue } @@ -351,14 +392,14 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto pm.RemoveTarget(proto, tunnelIP, port) // remove it first in case this is an update. we are kind of using the internal port as the "targetId" in the proxy pm.AddTarget(proto, tunnelIP, port, target) } else if action == "remove" { - log.Printf("Removing target with port %d", port) + logger.Info("Removing target with port %d", port) pm.RemoveTarget(proto, tunnelIP, port) } } err = pm.Start() if err != nil { - log.Panic(err) + logger.Error("Failed to start proxy manager: %v", err) } return nil diff --git a/proxy/types.go b/proxy/types.go index f1e334f..2886431 100644 --- a/proxy/types.go +++ b/proxy/types.go @@ -1,6 +1,7 @@ package proxy import ( + "log" "net" "sync" @@ -21,5 +22,6 @@ type ProxyTarget struct { type ProxyManager struct { targets []ProxyTarget tnet *netstack.Net + log *log.Logger sync.RWMutex // Protect access to targets slice } diff --git a/websocket/client.go b/websocket/client.go index 34c0665..749cb93 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "log" "net/http" "net/url" "sync" @@ -137,7 +136,6 @@ func (c *Client) readPump() { var msg WSMessage err := c.conn.ReadJSON(&msg) if err != nil { - log.Printf("read error: %v", err) return }