Compare commits

..

3 Commits

Author SHA1 Message Date
Jose Luis Montañes Ojados
c83fa3530d Add SSH keepalive and auto-reconnection to prevent idle disconnections
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 01:11:45 +01:00
Jose Luis Montañes Ojados
34af75aaa1 Fix panic in WebSocket handling by manually writing headers instead of using r.Write 2026-01-27 13:32:45 +01:00
Jose Luis Montañes Ojados
bec63c3283 Fix POST requests by using standard HTTP tunneling instead of unconditional Hijack 2026-01-27 03:38:22 +01:00
2 changed files with 177 additions and 42 deletions

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -91,9 +92,13 @@ func startHttpProxy(port string) {
} }
slug := parts[0] slug := parts[0]
log.Printf("Request Host: %s, Extracted Slug: %s", host, slug)
tunnel, ok := manager.Get(slug) tunnel, ok := manager.Get(slug)
if !ok { if !ok {
log.Printf("Tunnel not found for slug: %s", slug)
// Serve 404 page // Serve 404 page
w.WriteHeader(http.StatusNotFound)
http.ServeFile(w, r, "./cmd/server/static/404.html") http.ServeFile(w, r, "./cmd/server/static/404.html")
return return
} }
@@ -125,21 +130,15 @@ func startHttpProxy(port string) {
// But typically `gliderlabs/ssh` is for allowing the server to be a jump host. // But typically `gliderlabs/ssh` is for allowing the server to be a jump host.
// We want to be an HTTP Gateway. // We want to be an HTTP Gateway.
// 1. Hijack the connection to handle bidirectional traffic (WebSockets) // Panic recovery
hijacker, ok := w.(http.Hijacker) defer func() {
if !ok { if r := recover(); r != nil {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError) log.Printf("Recovered from panic in handler: %v", r)
return http.Error(w, "Internal Server Error", http.StatusInternalServerError)
} }
clientConn, bufrw, err := hijacker.Hijack() }()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
// 2. Open channel to client // 2. Open channel to client
// "forwarded-tcpip" arguments:
destHost := "0.0.0.0" destHost := "0.0.0.0"
destPort := tunnel.LocalPort destPort := tunnel.LocalPort
srcHost := "127.0.0.1" srcHost := "127.0.0.1"
@@ -161,36 +160,103 @@ func startHttpProxy(port string) {
defer ch.Close() defer ch.Close()
go gossh.DiscardRequests(reqs) go gossh.DiscardRequests(reqs)
var wg sync.WaitGroup // Check if it is a WebSocket Upgrade
wg.Add(2) isWebSocket := false
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
isWebSocket = true
}
// 3. Browser -> Backend (Write request + Copy raw stream) if isWebSocket {
go func() { // WEBSOCKET STRATEGY: Hijack and bidirectional copy
defer wg.Done() hijacker, ok := w.(http.Hijacker)
// Write the initial request (Method, Path, Headers) if !ok {
// This sets up the handshake or request. http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
// Note: We use r.Write to reconstruct the request line and headers.
// For WebSockets, the Body is empty, so this writes headers and returns.
// For POSTs, it writes headers and tries to copy Body.
if err := r.Write(ch); err != nil {
log.Printf("Error writing request to backend: %v", err)
return return
} }
// Important: Continue copying any subsequent data (like WebSocket frames) clientConn, bufrw, err := hijacker.Hijack()
// from the hijacked buffer/connection to the channel. if err != nil {
io.Copy(ch, bufrw) http.Error(w, err.Error(), http.StatusServiceUnavailable)
// e.g. when browser closes or stops sending, we are done here. return
}() }
defer clientConn.Close()
// 4. Backend -> Browser (Copy raw stream) // Manual Request writing to avoid touching Body after Hijack/Panic
go func() { // Request Line
defer wg.Done() reqLine := fmt.Sprintf("%s %s %s\r\n", r.Method, r.RequestURI, r.Proto)
io.Copy(clientConn, ch) if _, err := io.WriteString(ch, reqLine); err != nil {
// When backend closes connection, close browser connection log.Printf("Error writing websocket request line: %v", err)
clientConn.Close() return
}() }
// Headers
if err := r.Header.Write(ch); err != nil {
log.Printf("Error writing websocket headers: %v", err)
return
}
// End of headers
if _, err := io.WriteString(ch, "\r\n"); err != nil {
log.Printf("Error writing websocket header terminator: %v", err)
return
}
wg.Wait() var wg sync.WaitGroup
wg.Add(2)
// Copy existing buffer from hijack + future reads -> backend
go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered from panic in WS writer: %v", r)
}
wg.Done()
}()
if bufrw.Reader.Buffered() > 0 {
io.CopyN(ch, bufrw, int64(bufrw.Reader.Buffered()))
}
io.Copy(ch, clientConn)
}()
// Backend -> Browser
go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered from panic in WS reader: %v", r)
}
wg.Done()
}()
io.Copy(clientConn, ch)
}()
wg.Wait()
return
}
// STANDARD HTTP STRATEGY: Request/Response Tunneling
// 1. Write Request to Channel (simulating wire)
if err := r.Write(ch); err != nil {
log.Printf("Error writing request to backend: %v", err)
http.Error(w, "Gateway Error", http.StatusBadGateway)
return
}
// 2. Read Response from Channel (parsing wire)
resp, err := http.ReadResponse(bufio.NewReader(ch), r)
if err != nil {
log.Printf("Error reading response from backend: %v", err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// 3. Copy Headers to `w`
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// 4. Copy Body to `w`
io.Copy(w, resp.Body)
}) })

View File

@@ -31,6 +31,7 @@ type Client struct {
Events chan string // Channel to send logs/events to TUI Events chan string // Channel to send logs/events to TUI
Metrics chan int64 // Channel to send bytes transferred Metrics chan int64 // Channel to send bytes transferred
PublicURL string // PublicURL is the URL accessible from the internet PublicURL string // PublicURL is the URL accessible from the internet
stopKeepAlive chan struct{} // Signal to stop keepalive goroutine
} }
func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool) *Client { func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool) *Client {
@@ -45,7 +46,7 @@ func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS b
} }
} }
func (c *Client) Start() error { func (c *Client) connect() error {
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: "grokway", User: "grokway",
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
@@ -93,17 +94,85 @@ func (c *Client) Start() error {
c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host) c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host)
c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL) c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL)
go c.acceptLoop() // Start SSH keepalive to prevent idle disconnection
c.stopKeepAlive = make(chan struct{})
go c.keepAlive()
return nil return nil
} }
func (c *Client) Start() error {
if err := c.connect(); err != nil {
return err
}
go c.acceptLoop()
return nil
}
// keepAlive sends periodic SSH keepalive requests to prevent
// firewalls/NATs from dropping idle connections.
func (c *Client) keepAlive() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.stopKeepAlive:
return
case <-ticker.C:
if c.SSHClient != nil {
_, _, err := c.SSHClient.SendRequest("keepalive@openssh.com", true, nil)
if err != nil {
logToFile(fmt.Sprintf("Keepalive failed: %v", err))
return
}
}
}
}
}
func (c *Client) closeConnection() {
if c.stopKeepAlive != nil {
select {
case <-c.stopKeepAlive:
// Already closed
default:
close(c.stopKeepAlive)
}
}
if c.Listener != nil {
c.Listener.Close()
}
if c.SSHClient != nil {
c.SSHClient.Close()
}
}
func (c *Client) acceptLoop() { func (c *Client) acceptLoop() {
for { for {
remoteConn, err := c.Listener.Accept() remoteConn, err := c.Listener.Accept()
if err != nil { if err != nil {
c.Events <- fmt.Sprintf("Accept error: %s", err) c.Events <- fmt.Sprintf("Connection lost: %s", err)
break c.closeConnection()
// Reconnect with backoff
delay := 2 * time.Second
maxDelay := 60 * time.Second
for {
c.Events <- fmt.Sprintf("Reconnecting in %s...", delay)
time.Sleep(delay)
if err := c.connect(); err != nil {
c.Events <- fmt.Sprintf("Reconnection failed: %s", err)
delay *= 2
if delay > maxDelay {
delay = maxDelay
}
continue
}
c.Events <- "Reconnected successfully!"
break
}
continue
} }
c.Events <- "New Request received" c.Events <- "New Request received"