diff --git a/internal/tunnel/client.go b/internal/tunnel/client.go index 0d7b045..2a66a06 100644 --- a/internal/tunnel/client.go +++ b/internal/tunnel/client.go @@ -31,6 +31,7 @@ type Client struct { Events chan string // Channel to send logs/events to TUI Metrics chan int64 // Channel to send bytes transferred PublicURL string // PublicURL is the URL accessible from the internet + stopKeepAlive chan struct{} // Signal to stop keepalive goroutine } func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool) *Client { @@ -45,7 +46,7 @@ func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS b } } -func (c *Client) Start() error { +func (c *Client) connect() error { config := &ssh.ClientConfig{ User: "grokway", Auth: []ssh.AuthMethod{ @@ -93,17 +94,85 @@ func (c *Client) Start() error { c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host) c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL) - go c.acceptLoop() + // Start SSH keepalive to prevent idle disconnection + c.stopKeepAlive = make(chan struct{}) + go c.keepAlive() return nil } +func (c *Client) Start() error { + if err := c.connect(); err != nil { + return err + } + go c.acceptLoop() + return nil +} + +// keepAlive sends periodic SSH keepalive requests to prevent +// firewalls/NATs from dropping idle connections. +func (c *Client) keepAlive() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-c.stopKeepAlive: + return + case <-ticker.C: + if c.SSHClient != nil { + _, _, err := c.SSHClient.SendRequest("keepalive@openssh.com", true, nil) + if err != nil { + logToFile(fmt.Sprintf("Keepalive failed: %v", err)) + return + } + } + } + } +} + +func (c *Client) closeConnection() { + if c.stopKeepAlive != nil { + select { + case <-c.stopKeepAlive: + // Already closed + default: + close(c.stopKeepAlive) + } + } + if c.Listener != nil { + c.Listener.Close() + } + if c.SSHClient != nil { + c.SSHClient.Close() + } +} + func (c *Client) acceptLoop() { for { remoteConn, err := c.Listener.Accept() if err != nil { - c.Events <- fmt.Sprintf("Accept error: %s", err) - break + c.Events <- fmt.Sprintf("Connection lost: %s", err) + c.closeConnection() + + // Reconnect with backoff + delay := 2 * time.Second + maxDelay := 60 * time.Second + for { + c.Events <- fmt.Sprintf("Reconnecting in %s...", delay) + time.Sleep(delay) + + if err := c.connect(); err != nil { + c.Events <- fmt.Sprintf("Reconnection failed: %s", err) + delay *= 2 + if delay > maxDelay { + delay = maxDelay + } + continue + } + c.Events <- "Reconnected successfully!" + break + } + continue } c.Events <- "New Request received"