// Source: grokway/internal/tunnel/client.go (Go, 225 lines, 5.9 KiB).

package tunnel
import (
"fmt"
"io"
"net"
"strings"
"time"
"os"
"golang.org/x/crypto/ssh"
)
// logToFile appends one line, prefixed with an RFC 3339 timestamp, to
// client_debug.log in the current working directory, creating the file if
// needed. Logging is best effort: if the file cannot be opened or written the
// message is silently dropped so debug logging never disrupts the tunnel.
func logToFile(msg string) {
	f, err := os.OpenFile("client_debug.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return // best effort: nothing sensible to do with a logging failure
	}
	defer f.Close()
	// Write errors are ignored for the same best-effort reason.
	_, _ = f.WriteString(time.Now().Format(time.RFC3339) + " " + msg + "\n")
}
// Client handles the SSH connection to a grokway server and forwards the
// connections the server sends back through the tunnel to a local port.
type Client struct {
	// ServerAddr is the host:port of the SSH server passed to ssh.Dial.
	ServerAddr string
	// LocalPort is the port of the local service; forwarded connections are
	// dialed to localhost:<LocalPort>.
	LocalPort string
	// AuthToken is sent as the SSH password for the "grokway" user.
	AuthToken string
	// HostHeader, when non-empty, replaces the Host header of forwarded HTTP
	// requests; otherwise "localhost:<LocalPort>" is used.
	HostHeader string
	// SSHClient is the SSH connection, set by Start after a successful dial.
	SSHClient *ssh.Client
	// Listener accepts the connections forwarded by the server; set by Start.
	Listener net.Listener
	// Events carries human-readable log/event lines for the TUI.
	Events chan string
	// Metrics carries byte counts copied by each direction of a forwarded
	// connection.
	Metrics chan int64
	// PublicURL is the URL accessible from the internet, set by Start.
	PublicURL string
}
// NewClient returns a Client that will tunnel through the SSH server at
// serverAddr (authenticating with authToken) to the local service on
// localPort. A non-empty hostHeader overrides the Host header of forwarded
// HTTP requests. The Events and Metrics channels are created with a buffer of
// 10 entries each; the connection itself is opened later by Start.
func NewClient(serverAddr, localPort, authToken, hostHeader string) *Client {
	c := new(Client)
	c.ServerAddr, c.LocalPort = serverAddr, localPort
	c.AuthToken, c.HostHeader = authToken, hostHeader
	c.Events = make(chan string, 10)
	c.Metrics = make(chan int64, 10)
	return c
}
// Start connects to the tunnel server over SSH and asks it to listen on a
// server-chosen port (reverse forwarding). It then queries the server for the
// assigned subdomain slug and starts accepting forwarded connections in a
// background goroutine.
//
// On success c.SSHClient, c.Listener and c.PublicURL are set. If the server
// refuses port forwarding, the SSH connection is closed before the error is
// returned, so it is not leaked. Progress messages are sent on c.Events; those
// sends block while the channel's buffer is full and nobody is draining it.
func (c *Client) Start() error {
	config := &ssh.ClientConfig{
		User: "grokway",
		Auth: []ssh.AuthMethod{
			ssh.Password(c.AuthToken),
		},
		// NOTE: the server host key is not verified (development only); this
		// leaves the connection open to man-in-the-middle attacks.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         5 * time.Second,
	}

	c.Events <- fmt.Sprintf("Connecting to %s...", c.ServerAddr)
	client, err := ssh.Dial("tcp", c.ServerAddr, config)
	if err != nil {
		return fmt.Errorf("connecting to %s: %w", c.ServerAddr, err)
	}
	c.SSHClient = client
	c.Events <- "SSH Connected!"

	// Reverse forwarding: bind 0.0.0.0 on the server with port 0 so the
	// server picks the port.
	listener, err := client.Listen("tcp", "0.0.0.0:0")
	if err != nil {
		// Don't leak the SSH connection when forwarding is refused.
		client.Close()
		return fmt.Errorf("failed to request port forwarding: %w", err)
	}
	c.Listener = listener

	// Ask the server which subdomain slug it assigned to this session. An
	// empty reply is treated like a failure so the URL never starts with ".".
	slug := "test-slug" // fallback when the server gives no usable answer
	ok, reply, err := client.SendRequest("grokway-whoami", true, nil)
	if err == nil && ok && len(reply) > 0 {
		slug = string(reply)
		c.Events <- fmt.Sprintf("Server assigned domain: %s", slug)
	} else {
		c.Events <- "Failed to query domain from server, using fallback"
	}

	// The public URL is https://<slug>.<host part of ServerAddr>, with
	// "localhost" used when no host part can be extracted.
	host, _, _ := net.SplitHostPort(c.ServerAddr)
	if host == "" {
		host = "localhost"
	}
	c.PublicURL = fmt.Sprintf("https://%s.%s", slug, host)
	c.Events <- fmt.Sprintf("Tunnel established! Public URL: %s", c.PublicURL)

	go c.acceptLoop()
	return nil
}
// acceptLoop serves c.Listener: each accepted connection is announced on
// c.Events and handed to its own handleConnection goroutine. The loop ends at
// the first Accept error (for example when the listener is closed), which is
// reported on c.Events before returning.
func (c *Client) acceptLoop() {
	var (
		conn net.Conn
		err  error
	)
	for conn, err = c.Listener.Accept(); err == nil; conn, err = c.Listener.Accept() {
		c.Events <- "New Request received"
		go c.handleConnection(conn)
	}
	c.Events <- fmt.Sprintf("Accept error: %s", err)
}
// handleConnection proxies one connection forwarded by the tunnel server to
// the local service at localhost:<c.LocalPort>.
//
// It reads the first chunk sent by the remote peer. If that chunk starts with
// an HTTP/1.x request line, the request head is rewritten before it reaches
// the local service: Host becomes c.HostHeader (or "localhost:<c.LocalPort>"
// when c.HostHeader is empty), and "X-Forwarded-Proto: https" is added when
// absent. Any other payload is forwarded unchanged.
//
// After that, bytes are copied in both directions. When one side stops
// sending, only the write half of the other side is closed, so its peer sees
// EOF while the opposite direction can still drain. Both byte counts are
// reported on c.Metrics.
//
// Only the first chunk is inspected. Later requests on a kept-alive connection,
// and header bytes beyond that chunk, are passed through unchanged.
func (c *Client) handleConnection(remoteConn net.Conn) {
	defer remoteConn.Close()

	localConn, err := net.Dial("tcp", net.JoinHostPort("localhost", c.LocalPort))
	if err != nil {
		errMsg := fmt.Sprintf("Failed to dial local: %s", err)
		c.Events <- errMsg
		logToFile(errMsg)
		return
	}
	defer localConn.Close()
	logToFile("Dialed local service successfully")
	logToFile("Handling new connection")

	// Read the first chunk so an HTTP request head can be rewritten before it
	// reaches the local service. The deadline is best effort (its error is
	// ignored); if the connection type does not support deadlines, Read simply
	// blocks until the peer sends something.
	buf := make([]byte, 8192)
	_ = remoteConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	n, err := remoteConn.Read(buf)
	_ = remoteConn.SetReadDeadline(time.Time{})
	timedOut := os.IsTimeout(err)
	if err != nil && err != io.EOF && !timedOut {
		logToFile(fmt.Sprintf("Error reading from remote: %v", err))
		return
	}
	if n == 0 && !timedOut {
		// The peer went away without sending anything.
		return
	}
	// A timeout with no data is not an error: the protocol may expect the
	// server to speak first, so just fall through and pipe the connection.
	first := buf[:n]

	method, target, isHTTP := splitHTTPRequestLine(string(first))
	if isHTTP {
		host := c.HostHeader
		if host == "" {
			host = "localhost:" + c.LocalPort
		}
		rewritten := rewriteTunnelRequestHead(string(first), host)
		logToFile("Rewritten Headers:\n" + rewritten)
		first = []byte(rewritten)
	}
	if len(first) > 0 {
		if _, err := localConn.Write(first); err != nil {
			logToFile(fmt.Sprintf("Error writing to local: %v", err))
			return
		}
	}

	// Structured event for the TUI. The response is not parsed, so the status
	// field is always reported as 200.
	if isHTTP {
		c.Events <- fmt.Sprintf("HTTP|%s|%s|%d", method, target, 200)
	} else {
		c.Events <- "TCP|||Connection"
	}

	// halfClose shuts down only the write side when the connection supports
	// it (e.g. *net.TCPConn), falling back to a full Close otherwise.
	halfClose := func(conn net.Conn) {
		if cw, ok := conn.(interface{ CloseWrite() error }); ok {
			_ = cw.CloseWrite()
			return
		}
		_ = conn.Close()
	}

	toRemote := make(chan int64, 1)
	go func() {
		copied, _ := io.Copy(remoteConn, localConn)
		halfClose(remoteConn) // let the remote peer see the end of the response
		toRemote <- copied
	}()
	toLocal, _ := io.Copy(localConn, remoteConn)
	halfClose(localConn) // let the local service see the end of the request
	// The bytes already forwarded from the first chunk are counted too.
	c.Metrics <- int64(len(first)) + toLocal
	c.Metrics <- <-toRemote
}

// splitHTTPRequestLine reports whether payload begins with a complete HTTP/1.x
// request line ("METHOD SP TARGET SP HTTP/1.x" terminated by CRLF). If it
// does, it also returns the method and request target.
func splitHTTPRequestLine(payload string) (method, target string, ok bool) {
	line, _, found := strings.Cut(payload, "\r\n")
	if !found {
		return "", "", false
	}
	parts := strings.Fields(line)
	if len(parts) != 3 || !strings.HasPrefix(parts[2], "HTTP/1.") {
		return "", "", false
	}
	return parts[0], parts[1], true
}

// rewriteTunnelRequestHead rewrites the HTTP/1.x request head at the start of
// chunk, which must begin with a CRLF-terminated request line.
//
// Every Host header is replaced by "Host: "+host. "X-Forwarded-Proto: https"
// is inserted right after the request line unless the head already carries
// that header; header order is not significant. Bytes after the blank line
// that ends the head (the start of the body) are returned untouched. When the
// head is not fully contained in chunk, its last, possibly cut-off, line is
// left as is.
func rewriteTunnelRequestHead(chunk, host string) string {
	head, body, complete := strings.Cut(chunk, "\r\n\r\n")
	lines := strings.Split(head, "\r\n")
	end := len(lines)
	if !complete {
		end-- // the last line may be only part of a header
	}

	hasProto := false
	for i := 1; i < end; i++ {
		name, _, ok := strings.Cut(lines[i], ":")
		if !ok {
			continue
		}
		switch {
		case strings.EqualFold(name, "Host"):
			lines[i] = "Host: " + host
		case strings.EqualFold(name, "X-Forwarded-Proto"):
			hasProto = true
		}
	}

	var b strings.Builder
	b.Grow(len(chunk) + len(host) + 32)
	b.WriteString(lines[0])
	if !hasProto {
		b.WriteString("\r\nX-Forwarded-Proto: https")
	}
	for _, line := range lines[1:] {
		b.WriteString("\r\n")
		b.WriteString(line)
	}
	if complete {
		b.WriteString("\r\n\r\n")
		b.WriteString(body)
	}
	return b.String()
}