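// Package xai implements a WebSocket client for the xAI Voice Agent
// realtime API. It streams PCM audio to the service and delivers audio,
// transcripts, and speech events back through user-provided callbacks.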
package xai

import (
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"log"
	"sync"

	"github.com/gorilla/websocket"
)

const (
	// WebSocket endpoint for xAI realtime API
	RealtimeURL = "wss://api.x.ai/v1/realtime"
)

// AudioHandler is called when audio is received from xAI
type AudioHandler func(pcm []int16)

// TranscriptHandler is called when transcript text is received
type TranscriptHandler func(text string)

// SpeechHandler is called when speech is detected (for interruptions)
type SpeechHandler func()

// Client manages a WebSocket connection to xAI Voice Agent API
type Client struct {
	apiKey string
	conn   *websocket.Conn
	mu     sync.Mutex

	// Callbacks
	onAudio         AudioHandler
	onTranscript    TranscriptHandler
	onSpeechStarted SpeechHandler

	// State
	connected bool
	done      chan struct{}
}

// New creates a new xAI client
func New(apiKey string) *Client {
	return &Client{
		apiKey: apiKey,
		done:   make(chan struct{}),
	}
}

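// Typical wiring, as an illustrative sketch only (the playback channel and
// the voice/system-prompt strings below are placeholders supplied by the
// caller, not part of this package). Handlers should be registered before
// Connect, since they are invoked from the receive goroutine:
//
//	client := xai.New(apiKey)
//	client.OnAudio(func(pcm []int16) { playback <- pcm })
//	client.OnTranscript(func(text string) { log.Print(text) })
//	if err := client.Connect(); err != nil {
//		log.Fatal(err)
//	}
//	defer client.Close()
//	if err := client.ConfigureSession("<voice>", "<system prompt>"); err != nil {
//		log.Fatal(err)
//	}
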
// OnAudio sets the callback for received audio
func (c *Client) OnAudio(handler AudioHandler) {
	c.onAudio = handler
}

// OnTranscript sets the callback for received transcripts
func (c *Client) OnTranscript(handler TranscriptHandler) {
	c.onTranscript = handler
}

// OnSpeechStarted sets the callback for when the user starts speaking (for interruptions)
func (c *Client) OnSpeechStarted(handler SpeechHandler) {
	c.onSpeechStarted = handler
}

// Connect establishes a WebSocket connection to xAI
func (c *Client) Connect() error {
	header := make(map[string][]string)
	header["Authorization"] = []string{"Bearer " + c.apiKey}

	dialer := websocket.Dialer{}
	conn, _, err := dialer.Dial(RealtimeURL, header)
	if err != nil {
		return fmt.Errorf("failed to connect to xAI: %w", err)
	}

	c.conn = conn
	c.connected = true

	// Start message receiver
	go c.receiveLoop()

	log.Println("[xAI] Connected to Voice Agent API")
	return nil
}

// ConfigureSession sets up the voice session
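// The marshalled session.update message has roughly the following shape.
// This is a sketch for orientation only: the exact field names depend on
// the JSON tags of SessionUpdate and the related message structs, which are
// defined elsewhere in this package.
//
//	{
//	  "type": "session.update",
//	  "session": {
//	    "voice": "...",
//	    "instructions": "...",
//	    "turn_detection": {"type": "server_vad"},
//	    "audio": {
//	      "input":  {"format": {"type": "audio/pcm", "rate": 48000}},
//	      "output": {"format": {"type": "audio/pcm", "rate": 48000}}
//	    }
//	  }
//	}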
func (c *Client) ConfigureSession(voice, instructions string) error {
	msg := SessionUpdate{
		Type: "session.update",
		Session: Session{
			Voice:        voice,
			Instructions: instructions,
			TurnDetection: &TurnDetection{
				Type: "server_vad",
			},
			Audio: &AudioConfig{
				Input: &AudioFormatConfig{
					Format: AudioFormat{Type: "audio/pcm", Rate: 48000},
				},
				Output: &AudioFormatConfig{
					Format: AudioFormat{Type: "audio/pcm", Rate: 48000},
				},
			},
		},
	}

	return c.sendJSON(msg)
}

// SendAudio sends PCM audio data to xAI
// pcm should be int16 samples at 48kHz mono
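// As a sizing example, a 20 ms frame at 48 kHz mono is 960 samples,
// i.e. 1920 bytes of PCM before base64 encoding.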
func (c *Client) SendAudio(pcm []int16) error {
	// Convert int16 slice to bytes (little endian)
	buf := make([]byte, len(pcm)*2)
	for i, sample := range pcm {
		binary.LittleEndian.PutUint16(buf[i*2:], uint16(sample))
	}

	// Encode to base64
	encoded := base64.StdEncoding.EncodeToString(buf)

	msg := InputAudioBufferAppend{
		Type:  "input_audio_buffer.append",
		Audio: encoded,
	}

	return c.sendJSON(msg)
}

// SendText sends a text message to trigger a Grok response
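// The call emits two messages back to back: a conversation.item.create
// carrying the user text, followed by a response.create asking for text and
// audio output. Roughly (field names again depend on the struct tags defined
// elsewhere in this package):
//
//	{"type": "conversation.item.create",
//	 "item": {"type": "message", "role": "user",
//	          "content": [{"type": "input_text", "text": "..."}]}}
//	{"type": "response.create", "response": {"modalities": ["text", "audio"]}}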
func (c *Client) SendText(text string) error {
	// Create conversation item with text
	createMsg := ConversationItemCreate{
		Type: "conversation.item.create",
		Item: ConversationItem{
			Type: "message",
			Role: "user",
			Content: []ItemContent{
				{Type: "input_text", Text: text},
			},
		},
	}

	if err := c.sendJSON(createMsg); err != nil {
		return err
	}

	// Request response
	responseMsg := ResponseCreate{
		Type: "response.create",
		Response: ResponseSettings{
			Modalities: []string{"text", "audio"},
		},
	}

	return c.sendJSON(responseMsg)
}

// Close closes the WebSocket connection
func (c *Client) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn != nil {
		// Close the done channel only once so repeated Close calls cannot panic.
		select {
		case <-c.done:
		default:
			close(c.done)
		}
		c.conn.Close()
		c.connected = false
		log.Println("[xAI] Connection closed")
	}
}

// IsConnected returns connection status
func (c *Client) IsConnected() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.connected
}

// sendJSON sends a JSON message over WebSocket
func (c *Client) sendJSON(v any) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		return fmt.Errorf("not connected")
	}

	data, err := json.Marshal(v)
	if err != nil {
		return err
	}

	return c.conn.WriteMessage(websocket.TextMessage, data)
}

// receiveLoop handles incoming messages from xAI
func (c *Client) receiveLoop() {
	defer func() {
		c.mu.Lock()
		c.connected = false
		c.mu.Unlock()
	}()

	for {
		select {
		case <-c.done:
			return
		default:
		}

		_, message, err := c.conn.ReadMessage()
		if err != nil {
			// Check if closed intentionally
			select {
			case <-c.done:
				return
			default:
			}

			if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
				log.Println("[xAI] Connection closed normally")
			} else {
				log.Printf("[xAI] Read error: %v", err)
			}
			return
		}

		c.handleMessage(message)
	}
}

// handleMessage processes an incoming WebSocket message
func (c *Client) handleMessage(data []byte) {
	// Parse base message to get type
	var base ServerMessage
	if err := json.Unmarshal(data, &base); err != nil {
		log.Printf("[xAI] Failed to parse message: %v", err)
		return
	}

	switch base.Type {
	case "session.updated":
		log.Println("[xAI] Session configured successfully")

	case "session.created":
		log.Println("[xAI] Session created")

	case "conversation.created":
		log.Println("[xAI] Conversation created")

	case "response.output_audio.delta":
		var msg ResponseOutputAudioDelta
		if err := json.Unmarshal(data, &msg); err != nil {
			log.Printf("[xAI] Failed to parse audio delta: %v", err)
			return
		}
		c.handleAudioDelta(msg.Delta)

	case "response.output_audio.done":
		// Audio stream complete for this response
		log.Println("[xAI] Audio response complete")

	case "response.output_audio_transcript.delta":
		// Forward the transcript fragment to the registered handler
		var raw map[string]any
		if err := json.Unmarshal(data, &raw); err != nil {
			log.Printf("[xAI] Failed to parse transcript delta: %v", err)
			return
		}
		if delta, ok := raw["delta"].(string); ok && c.onTranscript != nil {
			c.onTranscript(delta)
		}

	case "response.done":
		log.Println("[xAI] Response complete")

	case "input_audio_buffer.speech_started":
		log.Println("[xAI] Speech started (VAD)")
		if c.onSpeechStarted != nil {
			c.onSpeechStarted()
		}

	case "input_audio_buffer.speech_stopped":
		log.Println("[xAI] Speech stopped (VAD)")

	case "error":
		var msg ErrorMessage
		if err := json.Unmarshal(data, &msg); err == nil {
			log.Printf("[xAI] Error: %s - %s", msg.Error.Code, msg.Error.Message)
		}

	default:
		// Log unhandled message types for debugging
		log.Printf("[xAI] Received: %s", base.Type)
	}
}

// handleAudioDelta processes received audio data
func (c *Client) handleAudioDelta(base64Audio string) {
	if c.onAudio == nil {
		return
	}

	// Decode base64
	audioBytes, err := base64.StdEncoding.DecodeString(base64Audio)
	if err != nil {
		log.Printf("[xAI] Failed to decode audio: %v", err)
		return
	}

	// Convert bytes to int16 (little endian)
	pcm := make([]int16, len(audioBytes)/2)
	for i := 0; i < len(pcm); i++ {
		pcm[i] = int16(binary.LittleEndian.Uint16(audioBytes[i*2:]))
	}

	c.onAudio(pcm)
}