diff --git a/cmd/server/main.go b/cmd/server/main.go index 23dad0f..7b5a371 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "fmt" "io" "log" @@ -91,9 +92,13 @@ func startHttpProxy(port string) { } slug := parts[0] + log.Printf("Request Host: %s, Extracted Slug: %s", host, slug) + tunnel, ok := manager.Get(slug) if !ok { + log.Printf("Tunnel not found for slug: %s", slug) // Serve 404 page + w.WriteHeader(http.StatusNotFound) http.ServeFile(w, r, "./cmd/server/static/404.html") return } @@ -125,21 +130,15 @@ func startHttpProxy(port string) { // But typically `gliderlabs/ssh` is for allowing the server to be a jump host. // We want to be an HTTP Gateway. - // 1. Hijack the connection to handle bidirectional traffic (WebSockets) - hijacker, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Hijacking not supported", http.StatusInternalServerError) - return - } - clientConn, bufrw, err := hijacker.Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return - } - defer clientConn.Close() + // Panic recovery + defer func() { + if r := recover(); r != nil { + log.Printf("Recovered from panic in handler: %v", r) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }() // 2. Open channel to client - // "forwarded-tcpip" arguments: destHost := "0.0.0.0" destPort := tunnel.LocalPort srcHost := "127.0.0.1" @@ -161,36 +160,82 @@ func startHttpProxy(port string) { defer ch.Close() go gossh.DiscardRequests(reqs) - var wg sync.WaitGroup - wg.Add(2) + // Check if it is a WebSocket Upgrade + isWebSocket := false + if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { + isWebSocket = true + } - // 3. Browser -> Backend (Write request + Copy raw stream) - go func() { - defer wg.Done() - // Write the initial request (Method, Path, Headers) - // This sets up the handshake or request. - // Note: We use r.Write to reconstruct the request line and headers. - // For WebSockets, the Body is empty, so this writes headers and returns. - // For POSTs, it writes headers and tries to copy Body. - if err := r.Write(ch); err != nil { - log.Printf("Error writing request to backend: %v", err) + if isWebSocket { + // WEBSOCKET STRATEGY: Hijack and bidirectional copy + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Hijacking not supported", http.StatusInternalServerError) return } - // Important: Continue copying any subsequent data (like WebSocket frames) - // from the hijacked buffer/connection to the channel. - io.Copy(ch, bufrw) - // e.g. when browser closes or stops sending, we are done here. - }() + clientConn, bufrw, err := hijacker.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + defer clientConn.Close() - // 4. Backend -> Browser (Copy raw stream) - go func() { - defer wg.Done() - io.Copy(clientConn, ch) - // When backend closes connection, close browser connection - clientConn.Close() - }() + // Reconstruct request line and headers to send to backend + // We can use r.Write but it writes to the channel + if err := r.Write(ch); err != nil { + log.Printf("Error writing websocket request to backend: %v", err) + return + } - wg.Wait() + var wg sync.WaitGroup + wg.Add(2) + + // Copy existing buffer from hijack + future reads -> backend + go func() { + defer wg.Done() + if bufrw.Reader.Buffered() > 0 { + io.CopyN(ch, bufrw, int64(bufrw.Reader.Buffered())) + } + io.Copy(ch, clientConn) + }() + + // Backend -> Browser + go func() { + defer wg.Done() + io.Copy(clientConn, ch) + }() + + wg.Wait() + return + } + + // STANDARD HTTP STRATEGY: Request/Response Tunneling + // 1. Write Request to Channel (simulating wire) + if err := r.Write(ch); err != nil { + log.Printf("Error writing request to backend: %v", err) + http.Error(w, "Gateway Error", http.StatusBadGateway) + return + } + + // 2. Read Response from Channel (parsing wire) + resp, err := http.ReadResponse(bufio.NewReader(ch), r) + if err != nil { + log.Printf("Error reading response from backend: %v", err) + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + defer resp.Body.Close() + + // 3. Copy Headers to `w` + for k, vv := range resp.Header { + for _, v := range vv { + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + + // 4. Copy Body to `w` + io.Copy(w, resp.Body) })