Compare commits
4 Commits
736922ca77
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
803b4049a6 | ||
|
|
c83fa3530d | ||
|
|
34af75aaa1 | ||
|
|
bec63c3283 |
@@ -37,6 +37,7 @@ func main() {
|
||||
tokenFlag := flag.String("token", "", "Authentication token (overrides config)")
|
||||
hostHeaderFlag := flag.String("host-header", "", "Custom Host header to send to local service")
|
||||
localHttpsFlag := flag.Bool("local-https", false, "Use HTTPS to connect to local service (implied if port is 443)")
|
||||
slugFlag := flag.String("slug", "", "Request a specific subdomain slug (e.g., myapp)")
|
||||
flag.Parse()
|
||||
|
||||
// Load config
|
||||
@@ -58,7 +59,7 @@ func main() {
|
||||
serverAddr = "localhost:2222"
|
||||
}
|
||||
|
||||
m := tui.InitialModel(*localPort, serverAddr, authToken, *hostHeaderFlag, *localHttpsFlag)
|
||||
m := tui.InitialModel(*localPort, serverAddr, authToken, *hostHeaderFlag, *localHttpsFlag, *slugFlag)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen())
|
||||
|
||||
if _, err := p.Run(); err != nil {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -43,6 +45,13 @@ var manager = &TunnelManager{
|
||||
tunnels: make(map[string]*Tunnel),
|
||||
}
|
||||
|
||||
// requestedSlugs stores slugs requested by clients before tcpip-forward
|
||||
var (
|
||||
requestedSlugs = make(map[*gossh.ServerConn]string)
|
||||
requestedSlugsMu sync.Mutex
|
||||
slugPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{1,30}[a-z0-9]$`)
|
||||
)
|
||||
|
||||
func (tm *TunnelManager) Register(id string, t *Tunnel) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
@@ -91,9 +100,13 @@ func startHttpProxy(port string) {
|
||||
}
|
||||
slug := parts[0]
|
||||
|
||||
log.Printf("Request Host: %s, Extracted Slug: %s", host, slug)
|
||||
|
||||
tunnel, ok := manager.Get(slug)
|
||||
if !ok {
|
||||
log.Printf("Tunnel not found for slug: %s", slug)
|
||||
// Serve 404 page
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
http.ServeFile(w, r, "./cmd/server/static/404.html")
|
||||
return
|
||||
}
|
||||
@@ -125,21 +138,15 @@ func startHttpProxy(port string) {
|
||||
// But typically `gliderlabs/ssh` is for allowing the server to be a jump host.
|
||||
// We want to be an HTTP Gateway.
|
||||
|
||||
// 1. Hijack the connection to handle bidirectional traffic (WebSockets)
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
// Panic recovery
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Recovered from panic in handler: %v", r)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
clientConn, bufrw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer clientConn.Close()
|
||||
}()
|
||||
|
||||
// 2. Open channel to client
|
||||
// "forwarded-tcpip" arguments:
|
||||
destHost := "0.0.0.0"
|
||||
destPort := tunnel.LocalPort
|
||||
srcHost := "127.0.0.1"
|
||||
@@ -161,36 +168,103 @@ func startHttpProxy(port string) {
|
||||
defer ch.Close()
|
||||
go gossh.DiscardRequests(reqs)
|
||||
|
||||
// Check if it is a WebSocket Upgrade
|
||||
isWebSocket := false
|
||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
||||
isWebSocket = true
|
||||
}
|
||||
|
||||
if isWebSocket {
|
||||
// WEBSOCKET STRATEGY: Hijack and bidirectional copy
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
clientConn, bufrw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
// Manual Request writing to avoid touching Body after Hijack/Panic
|
||||
// Request Line
|
||||
reqLine := fmt.Sprintf("%s %s %s\r\n", r.Method, r.RequestURI, r.Proto)
|
||||
if _, err := io.WriteString(ch, reqLine); err != nil {
|
||||
log.Printf("Error writing websocket request line: %v", err)
|
||||
return
|
||||
}
|
||||
// Headers
|
||||
if err := r.Header.Write(ch); err != nil {
|
||||
log.Printf("Error writing websocket headers: %v", err)
|
||||
return
|
||||
}
|
||||
// End of headers
|
||||
if _, err := io.WriteString(ch, "\r\n"); err != nil {
|
||||
log.Printf("Error writing websocket header terminator: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// 3. Browser -> Backend (Write request + Copy raw stream)
|
||||
// Copy existing buffer from hijack + future reads -> backend
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Write the initial request (Method, Path, Headers)
|
||||
// This sets up the handshake or request.
|
||||
// Note: We use r.Write to reconstruct the request line and headers.
|
||||
// For WebSockets, the Body is empty, so this writes headers and returns.
|
||||
// For POSTs, it writes headers and tries to copy Body.
|
||||
if err := r.Write(ch); err != nil {
|
||||
log.Printf("Error writing request to backend: %v", err)
|
||||
return
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Recovered from panic in WS writer: %v", r)
|
||||
}
|
||||
// Important: Continue copying any subsequent data (like WebSocket frames)
|
||||
// from the hijacked buffer/connection to the channel.
|
||||
io.Copy(ch, bufrw)
|
||||
// e.g. when browser closes or stops sending, we are done here.
|
||||
wg.Done()
|
||||
}()
|
||||
if bufrw.Reader.Buffered() > 0 {
|
||||
io.CopyN(ch, bufrw, int64(bufrw.Reader.Buffered()))
|
||||
}
|
||||
io.Copy(ch, clientConn)
|
||||
}()
|
||||
|
||||
// 4. Backend -> Browser (Copy raw stream)
|
||||
// Backend -> Browser
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Recovered from panic in WS reader: %v", r)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
io.Copy(clientConn, ch)
|
||||
// When backend closes connection, close browser connection
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
// STANDARD HTTP STRATEGY: Request/Response Tunneling
|
||||
// 1. Write Request to Channel (simulating wire)
|
||||
if err := r.Write(ch); err != nil {
|
||||
log.Printf("Error writing request to backend: %v", err)
|
||||
http.Error(w, "Gateway Error", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Read Response from Channel (parsing wire)
|
||||
resp, err := http.ReadResponse(bufio.NewReader(ch), r)
|
||||
if err != nil {
|
||||
log.Printf("Error reading response from backend: %v", err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 3. Copy Headers to `w`
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 4. Copy Body to `w`
|
||||
io.Copy(w, resp.Body)
|
||||
|
||||
})
|
||||
|
||||
@@ -239,28 +313,22 @@ func main() {
|
||||
// This is where clients say "Please listen on port X"
|
||||
sshServer.RequestHandlers = map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
// Parse payload
|
||||
// string address to bind (usually empty or 0.0.0.0 or 127.0.0.1)
|
||||
// uint32 port number to bind
|
||||
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
|
||||
// For Grokway, we ignore the requested port and assign a random subdomain/slug
|
||||
// Or we use the requested port if valid?
|
||||
// Let's assume we ignore it and generate a slug,
|
||||
// OR we use the port as the "slug" if it's special.
|
||||
// Check if client requested a specific slug
|
||||
requestedSlugsMu.Lock()
|
||||
slug, hasRequested := requestedSlugs[conn]
|
||||
if hasRequested {
|
||||
delete(requestedSlugs, conn)
|
||||
}
|
||||
requestedSlugsMu.Unlock()
|
||||
|
||||
// But wait, the client needs to know what we assigned!
|
||||
// Standard SSH response to tcpip-forward contains the BOUND PORT.
|
||||
// If we return a port, the client knows.
|
||||
// If we want to return a URL, standard SSH doesn't have a field for that.
|
||||
// But we can write it to the Session (stdout).
|
||||
|
||||
// Let's accept any request and map it to a random slug.
|
||||
slug := generateSlug(8)
|
||||
if !hasRequested {
|
||||
slug = generateSlug(8)
|
||||
}
|
||||
|
||||
log.Printf("Client requested forwarding. Assigning slug: %s", slug)
|
||||
|
||||
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
|
||||
manager.Register(slug, &Tunnel{
|
||||
ID: slug,
|
||||
LocalPort: 80, // We assume client forwards to port 80 locally? No,
|
||||
@@ -285,6 +353,24 @@ func main() {
|
||||
"cancel-tcpip-forward": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
return true, nil
|
||||
},
|
||||
"grokway-request-slug": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
slug := string(req.Payload)
|
||||
if !slugPattern.MatchString(slug) {
|
||||
log.Printf("Invalid slug format requested: %q", slug)
|
||||
return false, []byte("invalid slug format (lowercase alphanumeric and hyphens, 3-32 chars)")
|
||||
}
|
||||
// Check if slug is already in use
|
||||
if _, taken := manager.Get(slug); taken {
|
||||
log.Printf("Slug %q already in use", slug)
|
||||
return false, []byte("slug already in use")
|
||||
}
|
||||
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
requestedSlugsMu.Lock()
|
||||
requestedSlugs[conn] = slug
|
||||
requestedSlugsMu.Unlock()
|
||||
log.Printf("Client requested slug: %s", slug)
|
||||
return true, nil
|
||||
},
|
||||
"grokway-whoami": func(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
slug, ok := manager.FindSlugByConn(conn)
|
||||
|
||||
@@ -68,8 +68,8 @@ type LogMsg string
|
||||
type MetricMsg int64
|
||||
type ClearCopiedMsg struct{}
|
||||
|
||||
func InitialModel(localPort, serverAddr, authToken, hostHeader string, localHTTPS bool) Model {
|
||||
c := tunnel.NewClient(serverAddr, localPort, authToken, hostHeader, localHTTPS)
|
||||
func InitialModel(localPort, serverAddr, authToken, hostHeader string, localHTTPS bool, slug string) Model {
|
||||
c := tunnel.NewClient(serverAddr, localPort, authToken, hostHeader, localHTTPS, slug)
|
||||
return Model{
|
||||
Client: c,
|
||||
LogLines: []string{},
|
||||
|
||||
@@ -24,28 +24,31 @@ type Client struct {
|
||||
ServerAddr string
|
||||
LocalPort string
|
||||
AuthToken string
|
||||
HostHeader string // New field for custom Host header
|
||||
LocalHTTPS bool // New field: connect to local service with HTTPS
|
||||
HostHeader string // Custom Host header
|
||||
LocalHTTPS bool // Connect to local service with HTTPS
|
||||
Slug string // Requested slug (empty = server assigns random)
|
||||
SSHClient *ssh.Client
|
||||
Listener net.Listener
|
||||
Events chan string // Channel to send logs/events to TUI
|
||||
Metrics chan int64 // Channel to send bytes transferred
|
||||
PublicURL string // PublicURL is the URL accessible from the internet
|
||||
stopKeepAlive chan struct{} // Signal to stop keepalive goroutine
|
||||
}
|
||||
|
||||
func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool) *Client {
|
||||
func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool, slug string) *Client {
|
||||
return &Client{
|
||||
ServerAddr: serverAddr,
|
||||
LocalPort: localPort,
|
||||
AuthToken: authToken,
|
||||
HostHeader: hostHeader,
|
||||
LocalHTTPS: localHTTPS,
|
||||
Slug: slug,
|
||||
Events: make(chan string, 10),
|
||||
Metrics: make(chan int64, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Start() error {
|
||||
func (c *Client) connect() error {
|
||||
config := &ssh.ClientConfig{
|
||||
User: "grokway",
|
||||
Auth: []ssh.AuthMethod{
|
||||
@@ -63,6 +66,20 @@ func (c *Client) Start() error {
|
||||
c.SSHClient = client
|
||||
c.Events <- "SSH Connected!"
|
||||
|
||||
// Request a specific slug if provided
|
||||
if c.Slug != "" {
|
||||
ok, reply, err := client.SendRequest("grokway-request-slug", true, []byte(c.Slug))
|
||||
if err != nil || !ok {
|
||||
reason := string(reply)
|
||||
if reason == "" && err != nil {
|
||||
reason = err.Error()
|
||||
}
|
||||
c.Events <- fmt.Sprintf("Slug %q not available (%s), server will assign one", c.Slug, reason)
|
||||
} else {
|
||||
c.Events <- fmt.Sprintf("Slug %q reserved", c.Slug)
|
||||
}
|
||||
}
|
||||
|
||||
// Request remote listening (Reverse Forwarding)
|
||||
// Bind to 0.0.0.0 on server, random port (0)
|
||||
listener, err := client.Listen("tcp", "0.0.0.0:0")
|
||||
@@ -93,18 +110,86 @@ func (c *Client) Start() error {
|
||||
c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host)
|
||||
c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL)
|
||||
|
||||
go c.acceptLoop()
|
||||
// Start SSH keepalive to prevent idle disconnection
|
||||
c.stopKeepAlive = make(chan struct{})
|
||||
go c.keepAlive()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Start() error {
|
||||
if err := c.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
go c.acceptLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// keepAlive sends periodic SSH keepalive requests to prevent
|
||||
// firewalls/NATs from dropping idle connections.
|
||||
func (c *Client) keepAlive() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.stopKeepAlive:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if c.SSHClient != nil {
|
||||
_, _, err := c.SSHClient.SendRequest("keepalive@openssh.com", true, nil)
|
||||
if err != nil {
|
||||
logToFile(fmt.Sprintf("Keepalive failed: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeConnection() {
|
||||
if c.stopKeepAlive != nil {
|
||||
select {
|
||||
case <-c.stopKeepAlive:
|
||||
// Already closed
|
||||
default:
|
||||
close(c.stopKeepAlive)
|
||||
}
|
||||
}
|
||||
if c.Listener != nil {
|
||||
c.Listener.Close()
|
||||
}
|
||||
if c.SSHClient != nil {
|
||||
c.SSHClient.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) acceptLoop() {
|
||||
for {
|
||||
remoteConn, err := c.Listener.Accept()
|
||||
if err != nil {
|
||||
c.Events <- fmt.Sprintf("Accept error: %s", err)
|
||||
c.Events <- fmt.Sprintf("Connection lost: %s", err)
|
||||
c.closeConnection()
|
||||
|
||||
// Reconnect with backoff
|
||||
delay := 2 * time.Second
|
||||
maxDelay := 60 * time.Second
|
||||
for {
|
||||
c.Events <- fmt.Sprintf("Reconnecting in %s...", delay)
|
||||
time.Sleep(delay)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
c.Events <- fmt.Sprintf("Reconnection failed: %s", err)
|
||||
delay *= 2
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
continue
|
||||
}
|
||||
c.Events <- "Reconnected successfully!"
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.Events <- "New Request received"
|
||||
go c.handleConnection(remoteConn)
|
||||
|
||||
Reference in New Issue
Block a user