package tunnel

import (
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)

// logToFile appends a timestamped line to client_debug.log in the current
// working directory. It is best-effort: if the file cannot be opened or
// written, the message is dropped so debug logging never interferes with
// the tunnel.
func logToFile(msg string) {
	f, err := os.OpenFile("client_debug.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return
	}
	defer f.Close()
	_, _ = f.WriteString(time.Now().Format(time.RFC3339) + " " + msg + "\n")
}

// Client handles the SSH connection and forwards tunneled connections to a
// local service.
type Client struct {
	ServerAddr string

	// LocalPort identifies the local service: a bare port ("8080"), ":8080"
	// or "host:8080". A missing host means localhost.
	LocalPort string

	AuthToken string

	// HostHeader, if non-empty, replaces the Host header of forwarded HTTP
	// requests.
	HostHeader string

	// LocalHTTPS makes the client connect to the local service over TLS.
	// Port 443 implies TLS even when this is false.
	LocalHTTPS bool

	SSHClient *ssh.Client
	Listener  net.Listener

	// Events carries log lines and request summaries for the TUI.
	Events chan string

	// Metrics receives a byte count for each direction of a forwarded
	// connection once that direction finishes.
	Metrics chan int64

	// PublicURL is the URL accessible from the internet.
	PublicURL string
}

// NewClient returns a Client for the given tunnel server and local service.
// Events and Metrics are buffered with capacity 10; handleConnection sends on
// them with blocking sends, so a consumer must keep draining both.
func NewClient(serverAddr, localPort, authToken, hostHeader string, localHTTPS bool) *Client {
	return &Client{
		ServerAddr: serverAddr,
		LocalPort:  localPort,
		AuthToken:  authToken,
		HostHeader: hostHeader,
		LocalHTTPS: localHTTPS,
		Events:     make(chan string, 10),
		Metrics:    make(chan int64, 10),
	}
}

// ... Start() code ...

// ... acceptLoop() code ...

// handleConnection proxies one tunneled connection to the local service.
//
// It dials the local service (over TLS when LocalHTTPS is set or the port is
// 443), forwards the first chunk read from remoteConn (with the Host header
// rewritten and X-Forwarded-Proto added when it is an HTTP/1.x request),
// reports the connection on Events, then copies bytes in both directions.
// Both connections are closed when it returns.
//
// NOTE(review): only the first chunk is inspected, so later requests on a
// keep-alive connection are forwarded with their original headers.
func (c *Client) handleConnection(remoteConn net.Conn) {
	defer remoteConn.Close()

	addr := localServiceAddr(c.LocalPort)
	useTLS := c.LocalHTTPS || c.LocalPort == "443" || strings.HasSuffix(c.LocalPort, ":443")

	var localConn net.Conn
	var err error
	if useTLS {
		// Local dev servers usually have self-signed certificates, so the
		// local service's certificate is not verified.
		localConn, err = tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
	} else {
		localConn, err = net.Dial("tcp", addr)
	}
	if err != nil {
		errMsg := fmt.Sprintf("Failed to dial local (TLS=%v): %s", useTLS, err)
		c.Events <- errMsg
		logToFile(errMsg)
		return
	}
	defer localConn.Close()
	logToFile(fmt.Sprintf("Dialed local service successfully (TLS=%v)", useTLS))
	logToFile("Handling new connection")

	// Read the first chunk so HTTP headers can be inspected. The short
	// deadline bounds the wait: with protocols where the server speaks first
	// nothing arrives here, and the connection must still be proxied.
	buf := make([]byte, 8192)
	remoteConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	n, err := remoteConn.Read(buf)
	remoteConn.SetReadDeadline(time.Time{}) // clear the deadline for the copy phase
	if err != nil && err != io.EOF && !os.IsTimeout(err) {
		logToFile(fmt.Sprintf("Error reading from remote: %v", err))
		return
	}
	if n == 0 && err == io.EOF {
		return // remote closed without sending anything
	}

	payload := string(buf[:n])
	method, target, isHTTP := httpRequestLine(payload)

	first := buf[:n]
	if isHTTP {
		rewritten := c.rewriteRequestHeaders(payload)
		logToFile("Rewritten Headers:\n" + rewritten)
		first = []byte(rewritten)
	}
	var forwarded int64
	if len(first) > 0 {
		w, err := localConn.Write(first)
		forwarded = int64(w)
		if err != nil {
			logToFile(fmt.Sprintf("Error writing to local: %v", err))
			return
		}
	}

	// Structured event for the TUI. The status is always reported as 200
	// because the local service's response is not inspected.
	if isHTTP {
		c.Events <- fmt.Sprintf("HTTP|%s|%s|%d", method, target, 200)
	} else {
		c.Events <- "TCP|||Connection"
	}

	// Copy local -> remote in the background. When the local side is done,
	// close remoteConn so the remote peer sees EOF and the copy below
	// unblocks; otherwise the connection would stay open until the remote
	// peer hung up on its own. Copy errors are expected at teardown (the
	// other direction closing a connection) and are ignored.
	go func() {
		n, _ := io.Copy(remoteConn, localConn)
		remoteConn.Close()
		c.Metrics <- n
	}()
	n2, _ := io.Copy(localConn, remoteConn)
	c.Metrics <- forwarded + n2
}

// rewriteRequestHeaders rewrites the header section of an HTTP request chunk.
// The Host header is replaced with HostHeader, or with the local service
// address when HostHeader is empty. "X-Forwarded-Proto: https" is added
// before the blank line that ends the headers unless the request already has
// that header; this prevents redirect loops in apps that force HTTPS. Lines
// after the blank line (body bytes) are copied unchanged. If the chunk ends
// before the blank line, X-Forwarded-Proto is not added.
func (c *Client) rewriteRequestHeaders(payload string) string {
	host := c.HostHeader
	if host == "" {
		host = localServiceAddr(c.LocalPort)
	}

	lines := strings.Split(payload, "\r\n")
	out := make([]string, 0, len(lines)+1)
	hasProto := false
	for i, line := range lines {
		switch {
		case i == 0:
			out = append(out, line) // request line
		case line == "":
			// End of headers: everything from here on is passed through.
			if !hasProto {
				out = append(out, "X-Forwarded-Proto: https")
			}
			out = append(out, lines[i:]...)
			return strings.Join(out, "\r\n")
		case headerNameIs(line, "host"):
			out = append(out, "Host: "+host)
		default:
			if headerNameIs(line, "x-forwarded-proto") {
				hasProto = true
			}
			out = append(out, line)
		}
	}
	return strings.Join(out, "\r\n")
}

// localServiceAddr converts a LocalPort value into a dialable address. It
// accepts a bare port ("8080"), ":port" or "host:port"; a missing host
// becomes localhost.
func localServiceAddr(localPort string) string {
	host, port, err := net.SplitHostPort(localPort)
	if err != nil {
		// No host part (e.g. "8080"): treat the whole value as the port.
		return net.JoinHostPort("localhost", localPort)
	}
	if host == "" {
		host = "localhost"
	}
	return net.JoinHostPort(host, port)
}

// httpRequestLine reports whether payload starts with an HTTP/1.x request
// line ("METHOD TARGET HTTP/1.x") and, if so, returns its method and target.
func httpRequestLine(payload string) (method, target string, ok bool) {
	line := payload
	if i := strings.IndexByte(line, '\n'); i >= 0 {
		line = line[:i]
	}
	parts := strings.Fields(line)
	if len(parts) != 3 || !strings.HasPrefix(parts[2], "HTTP/1.") {
		return "", "", false
	}
	return parts[0], parts[1], true
}

// headerNameIs reports whether the header line has the given field name,
// compared case-insensitively.
func headerNameIs(line, name string) bool {
	return len(line) > len(name) && line[len(name)] == ':' && strings.EqualFold(line[:len(name)], name)
}