package xai

import (
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"log"
	"sync"

	"github.com/gorilla/websocket"
)

const (
	// RealtimeURL is the WebSocket endpoint for the xAI realtime API.
	RealtimeURL = "wss://api.x.ai/v1/realtime"
)

// AudioHandler is called when audio is received from xAI.
type AudioHandler func(pcm []int16)

// TranscriptHandler is called when transcript text is received.
type TranscriptHandler func(text string)

// SpeechHandler is called when speech is detected (for interruptions).
type SpeechHandler func()

// Client manages a WebSocket connection to the xAI Voice Agent API.
//
// The mutex guards conn, connected, and done against concurrent use by
// callers and the internal receive goroutine. Callback setters are expected
// to be called before Connect and are not guarded.
type Client struct {
	apiKey string
	conn   *websocket.Conn
	mu     sync.Mutex

	// Callbacks
	onAudio         AudioHandler
	onTranscript    TranscriptHandler
	onSpeechStarted SpeechHandler

	// State (guarded by mu)
	connected bool
	done      chan struct{}
}

// New creates a new xAI client with the given API key.
func New(apiKey string) *Client {
	return &Client{
		apiKey: apiKey,
		done:   make(chan struct{}),
	}
}

// OnAudio sets the callback for received audio.
func (c *Client) OnAudio(handler AudioHandler) {
	c.onAudio = handler
}

// OnTranscript sets the callback for received transcripts.
func (c *Client) OnTranscript(handler TranscriptHandler) {
	c.onTranscript = handler
}

// OnSpeechStarted sets the callback for when user starts speaking (for interruptions).
func (c *Client) OnSpeechStarted(handler SpeechHandler) {
	c.onSpeechStarted = handler
}

// Connect establishes the WebSocket connection to xAI and starts the
// background receive loop.
func (c *Client) Connect() error {
	header := make(map[string][]string)
	header["Authorization"] = []string{"Bearer " + c.apiKey}

	dialer := websocket.Dialer{}
	conn, _, err := dialer.Dial(RealtimeURL, header)
	if err != nil {
		return fmt.Errorf("failed to connect to xAI: %w", err)
	}

	// Publish connection state under the lock so IsConnected and Close
	// observe a consistent view (fixes a data race on connected/conn).
	c.mu.Lock()
	c.conn = conn
	c.connected = true
	c.mu.Unlock()

	// Start message receiver
	go c.receiveLoop()

	log.Println("[xAI] Connected to Voice Agent API")
	return nil
}

// ConfigureSession sets up the voice session: voice selection, system
// instructions, server-side VAD, and 48 kHz PCM audio in both directions.
func (c *Client) ConfigureSession(voice, instructions string) error {
	msg := SessionUpdate{
		Type: "session.update",
		Session: Session{
			Voice:        voice,
			Instructions: instructions,
			TurnDetection: &TurnDetection{
				Type: "server_vad",
			},
			Audio: &AudioConfig{
				Input: &AudioFormatConfig{
					Format: AudioFormat{Type: "audio/pcm", Rate: 48000},
				},
				Output: &AudioFormatConfig{
					Format: AudioFormat{Type: "audio/pcm", Rate: 48000},
				},
			},
		},
	}
	return c.sendJSON(msg)
}

// SendAudio sends PCM audio data to xAI.
// pcm should be int16 samples at 48kHz mono.
func (c *Client) SendAudio(pcm []int16) error {
	// Convert int16 slice to bytes (little endian)
	buf := make([]byte, len(pcm)*2)
	for i, sample := range pcm {
		binary.LittleEndian.PutUint16(buf[i*2:], uint16(sample))
	}

	// Encode to base64
	encoded := base64.StdEncoding.EncodeToString(buf)

	msg := InputAudioBufferAppend{
		Type:  "input_audio_buffer.append",
		Audio: encoded,
	}
	return c.sendJSON(msg)
}

// Close closes the WebSocket connection. It is safe to call multiple times;
// only the first call closes the done channel and the connection (the
// original unconditional close(c.done) panicked on a second Close).
func (c *Client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.connected {
		return
	}
	c.connected = false
	close(c.done)
	if c.conn != nil {
		c.conn.Close()
	}
	log.Println("[xAI] Connection closed")
}

// IsConnected returns connection status.
func (c *Client) IsConnected() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.connected
}

// sendJSON sends a JSON message over WebSocket. The mutex also serializes
// writers, as gorilla/websocket allows at most one concurrent writer.
func (c *Client) sendJSON(v any) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.conn == nil {
		return fmt.Errorf("not connected")
	}
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}
	return c.conn.WriteMessage(websocket.TextMessage, data)
}

// receiveLoop handles incoming messages from xAI until the connection is
// closed or a read error occurs. Runs in its own goroutine.
func (c *Client) receiveLoop() {
	defer func() {
		// Reset state under the lock (the unguarded write raced with
		// Close/IsConnected).
		c.mu.Lock()
		c.connected = false
		c.mu.Unlock()
	}()

	for {
		// Non-blocking check for shutdown; ReadMessage below also fails
		// once Close has closed the underlying connection.
		select {
		case <-c.done:
			return
		default:
		}

		_, message, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
				log.Println("[xAI] Connection closed normally")
			} else {
				log.Printf("[xAI] Read error: %v", err)
			}
			return
		}

		c.handleMessage(message)
	}
}

// handleMessage processes an incoming WebSocket message, dispatching on the
// server-assigned message type.
func (c *Client) handleMessage(data []byte) {
	// Parse base message to get type
	var base ServerMessage
	if err := json.Unmarshal(data, &base); err != nil {
		log.Printf("[xAI] Failed to parse message: %v", err)
		return
	}

	switch base.Type {
	case "session.updated":
		log.Println("[xAI] Session configured successfully")

	case "session.created":
		log.Println("[xAI] Session created")

	case "conversation.created":
		log.Println("[xAI] Conversation created")

	case "response.output_audio.delta":
		var msg ResponseOutputAudioDelta
		if err := json.Unmarshal(data, &msg); err != nil {
			log.Printf("[xAI] Failed to parse audio delta: %v", err)
			return
		}
		c.handleAudioDelta(msg.Delta)

	case "response.output_audio.done":
		// Audio stream complete for this response
		log.Println("[xAI] Audio response complete")

	case "response.output_audio_transcript.delta":
		// Forward incremental transcript text to the callback. The
		// unmarshal error was previously ignored; log and bail instead.
		var raw map[string]any
		if err := json.Unmarshal(data, &raw); err != nil {
			log.Printf("[xAI] Failed to parse transcript delta: %v", err)
			return
		}
		if delta, ok := raw["delta"].(string); ok && c.onTranscript != nil {
			c.onTranscript(delta)
		}

	case "response.done":
		log.Println("[xAI] Response complete")

	case "input_audio_buffer.speech_started":
		log.Println("[xAI] Speech started (VAD)")
		if c.onSpeechStarted != nil {
			c.onSpeechStarted()
		}

	case "input_audio_buffer.speech_stopped":
		log.Println("[xAI] Speech stopped (VAD)")

	case "error":
		var msg ErrorMessage
		if err := json.Unmarshal(data, &msg); err == nil {
			log.Printf("[xAI] Error: %s - %s", msg.Error.Code, msg.Error.Message)
		}

	default:
		// Log unhandled message types for debugging
		log.Printf("[xAI] Received: %s", base.Type)
	}
}

// handleAudioDelta decodes a base64 PCM payload into int16 samples
// (little endian) and hands them to the audio callback, if set.
func (c *Client) handleAudioDelta(base64Audio string) {
	if c.onAudio == nil {
		return
	}

	// Decode base64
	audioBytes, err := base64.StdEncoding.DecodeString(base64Audio)
	if err != nil {
		log.Printf("[xAI] Failed to decode audio: %v", err)
		return
	}

	// Convert bytes to int16 (little endian); a trailing odd byte, if any,
	// is dropped by the integer division.
	pcm := make([]int16, len(audioBytes)/2)
	for i := 0; i < len(pcm); i++ {
		pcm[i] = int16(binary.LittleEndian.Uint16(audioBytes[i*2:]))
	}

	c.onAudio(pcm)
}