Compare commits

...

4 Commits

Author SHA1 Message Date
Jose Luis Montañes Ojados
803b4049a6 Add --slug flag to reuse a specific subdomain across sessions
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 01:27:17 +01:00
Jose Luis Montañes Ojados
c83fa3530d Add SSH keepalive and auto-reconnection to prevent idle disconnections
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 01:11:45 +01:00
Jose Luis Montañes Ojados
34af75aaa1 Fix panic in WebSocket handling by manually writing headers instead of using r.Write 2026-01-27 13:32:45 +01:00
Jose Luis Montañes Ojados
bec63c3283 Fix POST requests by using standard HTTP tunneling instead of unconditional Hijack 2026-01-27 03:38:22 +01:00
4 changed files with 245 additions and 73 deletions

View File

@@ -37,6 +37,7 @@ func main() {
tokenFlag := flag.String("token", "", "Authentication token (overrides config)") tokenFlag := flag.String("token", "", "Authentication token (overrides config)")
hostHeaderFlag := flag.String("host-header", "", "Custom Host header to send to local service") hostHeaderFlag := flag.String("host-header", "", "Custom Host header to send to local service")
localHttpsFlag := flag.Bool("local-https", false, "Use HTTPS to connect to local service (implied if port is 443)") localHttpsFlag := flag.Bool("local-https", false, "Use HTTPS to connect to local service (implied if port is 443)")
slugFlag := flag.String("slug", "", "Request a specific subdomain slug (e.g., myapp)")
flag.Parse() flag.Parse()
// Load config // Load config
@@ -58,7 +59,7 @@ func main() {
serverAddr = "localhost:2222" serverAddr = "localhost:2222"
} }
m := tui.InitialModel(*localPort, serverAddr, authToken, *hostHeaderFlag, *localHttpsFlag) m := tui.InitialModel(*localPort, serverAddr, authToken, *hostHeaderFlag, *localHttpsFlag, *slugFlag)
p := tea.NewProgram(m, tea.WithAltScreen()) p := tea.NewProgram(m, tea.WithAltScreen())
if _, err := p.Run(); err != nil { if _, err := p.Run(); err != nil {

View File

@@ -1,12 +1,14 @@
package main package main
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
"os" "os"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -43,6 +45,13 @@ var manager = &TunnelManager{
tunnels: make(map[string]*Tunnel), tunnels: make(map[string]*Tunnel),
} }
// requestedSlugs stores slugs requested by clients before tcpip-forward
var (
requestedSlugs = make(map[*gossh.ServerConn]string)
requestedSlugsMu sync.Mutex
slugPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{1,30}[a-z0-9]$`)
)
func (tm *TunnelManager) Register(id string, t *Tunnel) { func (tm *TunnelManager) Register(id string, t *Tunnel) {
tm.mu.Lock() tm.mu.Lock()
defer tm.mu.Unlock() defer tm.mu.Unlock()
@@ -91,9 +100,13 @@ func startHttpProxy(port string) {
} }
slug := parts[0] slug := parts[0]
log.Printf("Request Host: %s, Extracted Slug: %s", host, slug)
tunnel, ok := manager.Get(slug) tunnel, ok := manager.Get(slug)
if !ok { if !ok {
log.Printf("Tunnel not found for slug: %s", slug)
// Serve 404 page // Serve 404 page
w.WriteHeader(http.StatusNotFound)
http.ServeFile(w, r, "./cmd/server/static/404.html") http.ServeFile(w, r, "./cmd/server/static/404.html")
return return
} }
@@ -125,21 +138,15 @@ func startHttpProxy(port string) {
// But typically `gliderlabs/ssh` is for allowing the server to be a jump host. // But typically `gliderlabs/ssh` is for allowing the server to be a jump host.
// We want to be an HTTP Gateway. // We want to be an HTTP Gateway.
// 1. Hijack the connection to handle bidirectional traffic (WebSockets) // Panic recovery
hijacker, ok := w.(http.Hijacker) defer func() {
if !ok { if r := recover(); r != nil {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError) log.Printf("Recovered from panic in handler: %v", r)
return http.Error(w, "Internal Server Error", http.StatusInternalServerError)
} }
clientConn, bufrw, err := hijacker.Hijack() }()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
// 2. Open channel to client // 2. Open channel to client
// "forwarded-tcpip" arguments:
destHost := "0.0.0.0" destHost := "0.0.0.0"
destPort := tunnel.LocalPort destPort := tunnel.LocalPort
srcHost := "127.0.0.1" srcHost := "127.0.0.1"
@@ -161,36 +168,103 @@ func startHttpProxy(port string) {
defer ch.Close() defer ch.Close()
go gossh.DiscardRequests(reqs) go gossh.DiscardRequests(reqs)
// Check if it is a WebSocket Upgrade
isWebSocket := false
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
isWebSocket = true
}
if isWebSocket {
// WEBSOCKET STRATEGY: Hijack and bidirectional copy
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
return
}
clientConn, bufrw, err := hijacker.Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
// Manual Request writing to avoid touching Body after Hijack/Panic
// Request Line
reqLine := fmt.Sprintf("%s %s %s\r\n", r.Method, r.RequestURI, r.Proto)
if _, err := io.WriteString(ch, reqLine); err != nil {
log.Printf("Error writing websocket request line: %v", err)
return
}
// Headers
if err := r.Header.Write(ch); err != nil {
log.Printf("Error writing websocket headers: %v", err)
return
}
// End of headers
if _, err := io.WriteString(ch, "\r\n"); err != nil {
log.Printf("Error writing websocket header terminator: %v", err)
return
}
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
// 3. Browser -> Backend (Write request + Copy raw stream) // Copy existing buffer from hijack + future reads -> backend
go func() { go func() {
defer wg.Done() defer func() {
// Write the initial request (Method, Path, Headers) if r := recover(); r != nil {
// This sets up the handshake or request. log.Printf("Recovered from panic in WS writer: %v", r)
// Note: We use r.Write to reconstruct the request line and headers.
// For WebSockets, the Body is empty, so this writes headers and returns.
// For POSTs, it writes headers and tries to copy Body.
if err := r.Write(ch); err != nil {
log.Printf("Error writing request to backend: %v", err)
return
} }
// Important: Continue copying any subsequent data (like WebSocket frames) wg.Done()
// from the hijacked buffer/connection to the channel. }()
io.Copy(ch, bufrw) if bufrw.Reader.Buffered() > 0 {
// e.g. when browser closes or stops sending, we are done here. io.CopyN(ch, bufrw, int64(bufrw.Reader.Buffered()))
}
io.Copy(ch, clientConn)
}() }()
// 4. Backend -> Browser (Copy raw stream) // Backend -> Browser
go func() { go func() {
defer wg.Done() defer func() {
if r := recover(); r != nil {
log.Printf("Recovered from panic in WS reader: %v", r)
}
wg.Done()
}()
io.Copy(clientConn, ch) io.Copy(clientConn, ch)
// When backend closes connection, close browser connection
clientConn.Close()
}() }()
wg.Wait() wg.Wait()
return
}
// STANDARD HTTP STRATEGY: Request/Response Tunneling
// 1. Write Request to Channel (simulating wire)
if err := r.Write(ch); err != nil {
log.Printf("Error writing request to backend: %v", err)
http.Error(w, "Gateway Error", http.StatusBadGateway)
return
}
// 2. Read Response from Channel (parsing wire)
resp, err := http.ReadResponse(bufio.NewReader(ch), r)
if err != nil {
log.Printf("Error reading response from backend: %v", err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer resp.Body.Close()
// 3. Copy Headers to `w`
for k, vv := range resp.Header {
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
// 4. Copy Body to `w`
io.Copy(w, resp.Body)
}) })
@@ -239,28 +313,22 @@ func main() {
// This is where clients say "Please listen on port X" // This is where clients say "Please listen on port X"
sshServer.RequestHandlers = map[string]ssh.RequestHandler{ sshServer.RequestHandlers = map[string]ssh.RequestHandler{
"tcpip-forward": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) { "tcpip-forward": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
// Parse payload conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
// string address to bind (usually empty or 0.0.0.0 or 127.0.0.1)
// uint32 port number to bind
// For Grokway, we ignore the requested port and assign a random subdomain/slug // Check if client requested a specific slug
// Or we use the requested port if valid? requestedSlugsMu.Lock()
// Let's assume we ignore it and generate a slug, slug, hasRequested := requestedSlugs[conn]
// OR we use the port as the "slug" if it's special. if hasRequested {
delete(requestedSlugs, conn)
}
requestedSlugsMu.Unlock()
// But wait, the client needs to know what we assigned! if !hasRequested {
// Standard SSH response to tcpip-forward contains the BOUND PORT. slug = generateSlug(8)
// If we return a port, the client knows. }
// If we want to return a URL, standard SSH doesn't have a field for that.
// But we can write it to the Session (stdout).
// Let's accept any request and map it to a random slug.
slug := generateSlug(8)
log.Printf("Client requested forwarding. Assigning slug: %s", slug) log.Printf("Client requested forwarding. Assigning slug: %s", slug)
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
manager.Register(slug, &Tunnel{ manager.Register(slug, &Tunnel{
ID: slug, ID: slug,
LocalPort: 80, // We assume client forwards to port 80 locally? No, LocalPort: 80, // We assume client forwards to port 80 locally? No,
@@ -285,6 +353,24 @@ func main() {
"cancel-tcpip-forward": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) { "cancel-tcpip-forward": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
return true, nil return true, nil
}, },
"grokway-request-slug": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
slug := string(req.Payload)
if !slugPattern.MatchString(slug) {
log.Printf("Invalid slug format requested: %q", slug)
return false, []byte("invalid slug format (lowercase alphanumeric and hyphens, 3-32 chars)")
}
// Check if slug is already in use
if _, taken := manager.Get(slug); taken {
log.Printf("Slug %q already in use", slug)
return false, []byte("slug already in use")
}
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
requestedSlugsMu.Lock()
requestedSlugs[conn] = slug
requestedSlugsMu.Unlock()
log.Printf("Client requested slug: %s", slug)
return true, nil
},
"grokway-whoami": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) { "grokway-whoami": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
slug, ok := manager.FindSlugByConn(conn) slug, ok := manager.FindSlugByConn(conn)

View File

@@ -68,8 +68,8 @@ type LogMsg string
type MetricMsg int64 type MetricMsg int64
type ClearCopiedMsg struct{} type ClearCopiedMsg struct{}
func InitialModel(localPort, serverAddr, authToken, hostHeader string, localHTTPS bool) Model { func InitialModel(localPort, serverAddr, authToken, hostHeader string, localHTTPS bool, slug string) Model {
c := tunnel.NewClient(serverAddr, localPort, authToken, hostHeader, localHTTPS) c := tunnel.NewClient(serverAddr, localPort, authToken, hostHeader, localHTTPS, slug)
return Model{ return Model{
Client: c, Client: c,
LogLines: []string{}, LogLines: []string{},

View File

@@ -24,28 +24,31 @@ type Client struct {
ServerAddr string ServerAddr string
LocalPort string LocalPort string
AuthToken string AuthToken string
HostHeader string // New field for custom Host header HostHeader string // Custom Host header
LocalHTTPS bool // New field: connect to local service with HTTPS LocalHTTPS bool // Connect to local service with HTTPS
Slug string // Requested slug (empty = server assigns random)
SSHClient *ssh.Client SSHClient *ssh.Client
Listener net.Listener Listener net.Listener
Events chan string // Channel to send logs/events to TUI Events chan string // Channel to send logs/events to TUI
Metrics chan int64 // Channel to send bytes transferred Metrics chan int64 // Channel to send bytes transferred
PublicURL string // PublicURL is the URL accessible from the internet PublicURL string // PublicURL is the URL accessible from the internet
stopKeepAlive chan struct{} // Signal to stop keepalive goroutine
} }
func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool) *Client { func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool, slug string) *Client {
return &Client{ return &Client{
ServerAddr: serverAddr, ServerAddr: serverAddr,
LocalPort: localPort, LocalPort: localPort,
AuthToken: authToken, AuthToken: authToken,
HostHeader: hostHeader, HostHeader: hostHeader,
LocalHTTPS: localHTTPS, LocalHTTPS: localHTTPS,
Slug: slug,
Events: make(chan string, 10), Events: make(chan string, 10),
Metrics: make(chan int64, 10), Metrics: make(chan int64, 10),
} }
} }
func (c *Client) Start() error { func (c *Client) connect() error {
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: "grokway", User: "grokway",
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
@@ -63,6 +66,20 @@ func (c *Client) Start() error {
c.SSHClient = client c.SSHClient = client
c.Events <- "SSH Connected!" c.Events <- "SSH Connected!"
// Request a specific slug if provided
if c.Slug != "" {
ok, reply, err := client.SendRequest("grokway-request-slug", true, []byte(c.Slug))
if err != nil || !ok {
reason := string(reply)
if reason == "" && err != nil {
reason = err.Error()
}
c.Events <- fmt.Sprintf("Slug %q not available (%s), server will assign one", c.Slug, reason)
} else {
c.Events <- fmt.Sprintf("Slug %q reserved", c.Slug)
}
}
// Request remote listening (Reverse Forwarding) // Request remote listening (Reverse Forwarding)
// Bind to 0.0.0.0 on server, random port (0) // Bind to 0.0.0.0 on server, random port (0)
listener, err := client.Listen("tcp", "0.0.0.0:0") listener, err := client.Listen("tcp", "0.0.0.0:0")
@@ -93,18 +110,86 @@ func (c *Client) Start() error {
c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host) c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host)
c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL) c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL)
go c.acceptLoop() // Start SSH keepalive to prevent idle disconnection
c.stopKeepAlive = make(chan struct{})
go c.keepAlive()
return nil return nil
} }
func (c *Client) Start() error {
if err := c.connect(); err != nil {
return err
}
go c.acceptLoop()
return nil
}
// keepAlive sends periodic SSH keepalive requests to prevent
// firewalls/NATs from dropping idle connections.
func (c *Client) keepAlive() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.stopKeepAlive:
return
case <-ticker.C:
if c.SSHClient != nil {
_, _, err := c.SSHClient.SendRequest("keepalive@openssh.com", true, nil)
if err != nil {
logToFile(fmt.Sprintf("Keepalive failed: %v", err))
return
}
}
}
}
}
func (c *Client) closeConnection() {
if c.stopKeepAlive != nil {
select {
case <-c.stopKeepAlive:
// Already closed
default:
close(c.stopKeepAlive)
}
}
if c.Listener != nil {
c.Listener.Close()
}
if c.SSHClient != nil {
c.SSHClient.Close()
}
}
func (c *Client) acceptLoop() { func (c *Client) acceptLoop() {
for { for {
remoteConn, err := c.Listener.Accept() remoteConn, err := c.Listener.Accept()
if err != nil { if err != nil {
c.Events <- fmt.Sprintf("Accept error: %s", err) c.Events <- fmt.Sprintf("Connection lost: %s", err)
c.closeConnection()
// Reconnect with backoff
delay := 2 * time.Second
maxDelay := 60 * time.Second
for {
c.Events <- fmt.Sprintf("Reconnecting in %s...", delay)
time.Sleep(delay)
if err := c.connect(); err != nil {
c.Events <- fmt.Sprintf("Reconnection failed: %s", err)
delay *= 2
if delay > maxDelay {
delay = maxDelay
}
continue
}
c.Events <- "Reconnected successfully!"
break break
} }
continue
}
c.Events <- "New Request received" c.Events <- "New Request received"
go c.handleConnection(remoteConn) go c.handleConnection(remoteConn)