diff --git a/clients/clients.go b/clients/clients.go
index 2cb14c7..59cab82 100644
--- a/clients/clients.go
+++ b/clients/clients.go
@@ -6,6 +6,7 @@ package clients
 
 import (
+	"bufio"
 	"bytes"
 	"context"
 	"encoding/json"
@@ -14,6 +15,7 @@ import (
 	"mime/multipart"
 	"net/http"
 	"net/url"
+	"strings"
 	"sync"
 	"time"
 )
@@ -142,6 +144,36 @@ func (h *httpClient) do(req *http.Request) ([]byte, error) {
 	return body, nil
 }
 
+// postJSONStream sends a JSON POST and returns the raw *http.Response so the
+// caller can read the body incrementally (e.g. for SSE streaming). The caller
+// is responsible for closing resp.Body.
+func (h *httpClient) postJSONStream(ctx context.Context, path string, body any) (*http.Response, error) {
+	buf := getBuf()
+	defer putBuf(buf)
+	if err := json.NewEncoder(buf).Encode(body); err != nil {
+		return nil, fmt.Errorf("marshal: %w", err)
+	}
+	// Copy to a non-pooled buffer so we can safely return the pool buffer.
+	payload := make([]byte, buf.Len())
+	copy(payload, buf.Bytes())
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(payload))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	resp, err := h.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
+	}
+	if resp.StatusCode >= 400 {
+		respBody, _ := io.ReadAll(resp.Body)
+		_ = resp.Body.Close()
+		return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(respBody))
+	}
+	return resp, nil
+}
+
 func (h *httpClient) healthCheck(ctx context.Context) bool {
 	data, err := h.get(ctx, "/health", nil)
 	_ = data
@@ -320,6 +352,77 @@ func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string
 	return resp.Choices[0].Message.Content, nil
 }
 
+// StreamGenerate sends a streaming chat completion request and calls onToken
+// for each content delta received via SSE. Returns the fully assembled text.
+// The onToken callback is invoked synchronously on the calling goroutine; it
+// should be fast (e.g. publish a NATS message).
+func (c *LLMClient) StreamGenerate(ctx context.Context, prompt string, context_ string, systemPrompt string, onToken func(token string)) (string, error) {
+	msgs := buildMessages(prompt, context_, systemPrompt)
+	payload := map[string]any{
+		"model":       c.Model,
+		"messages":    msgs,
+		"max_tokens":  c.MaxTokens,
+		"temperature": c.Temperature,
+		"top_p":       c.TopP,
+		"stream":      true,
+	}
+
+	resp, err := c.postJSONStream(ctx, "/v1/chat/completions", payload)
+	if err != nil {
+		return "", err
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	var full strings.Builder
+	scanner := bufio.NewScanner(resp.Body)
+	// Allow SSE lines up to 1 MiB; the 64 KiB bufio default would abort the
+	// stream with bufio.ErrTooLong on large token batches.
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+
+	for scanner.Scan() {
+		// Servers may terminate SSE lines with CRLF; Scanner strips the \n
+		// but keeps the \r, which would otherwise corrupt the JSON payload
+		// and the [DONE] sentinel comparison.
+		line := strings.TrimSuffix(scanner.Text(), "\r")
+		if !strings.HasPrefix(line, "data:") {
+			continue
+		}
+		// The space after "data:" is optional per the SSE spec.
+		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+		if data == "[DONE]" {
+			break
+		}
+		var chunk struct {
+			Choices []struct {
+				Delta struct {
+					Content string `json:"content"`
+				} `json:"delta"`
+			} `json:"choices"`
+		}
+		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+			continue // skip malformed chunks
+		}
+		if len(chunk.Choices) == 0 {
+			continue
+		}
+		token := chunk.Choices[0].Delta.Content
+		if token == "" {
+			continue
+		}
+		full.WriteString(token)
+		if onToken != nil {
+			onToken(token)
+		}
+	}
+	if err := scanner.Err(); err != nil {
+		// If we already collected some text, return it with the error.
+		if full.Len() > 0 {
+			return full.String(), fmt.Errorf("stream interrupted: %w", err)
+		}
+		return "", fmt.Errorf("stream read: %w", err)
+	}
+	return full.String(), nil
+}
+
 func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
 	var msgs []ChatMessage
 	if systemPrompt != "" {
diff --git a/clients/clients_test.go b/clients/clients_test.go
index 0406552..9c24887 100644
--- a/clients/clients_test.go
+++ b/clients/clients_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -299,6 +300,293 @@ func TestLLMClient_GenerateNoChoices(t *testing.T) {
 	}
 }
 
+// ────────────────────────────────────────────────────────────────────────────
+// LLM client — StreamGenerate
+// ────────────────────────────────────────────────────────────────────────────
+
+// sseChunk builds an OpenAI-compatible SSE chat.completion.chunk line.
+func sseChunk(content string) string {
+	chunk := map[string]any{
+		"choices": []map[string]any{
+			{"delta": map[string]any{"content": content}},
+		},
+	}
+	b, _ := json.Marshal(chunk)
+	return "data: " + string(b) + "\n\n"
+}
+
+func TestLLMClient_StreamGenerate(t *testing.T) {
+	tokens := []string{"Hello", " world", "!"}
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/chat/completions" {
+			t.Errorf("path = %q", r.URL.Path)
+		}
+		var req map[string]any
+		_ = json.NewDecoder(r.Body).Decode(&req)
+		if req["stream"] != true {
+			t.Errorf("stream = %v, want true", req["stream"])
+		}
+
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		for _, tok := range tokens {
+			_, _ = w.Write([]byte(sseChunk(tok)))
+			flusher.Flush()
+		}
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	var received []string
+	result, err := c.StreamGenerate(context.Background(), "hi", "", "", func(tok string) {
+		received = append(received, tok)
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if result != "Hello world!" {
+		t.Errorf("result = %q, want %q", result, "Hello world!")
+	}
+	if len(received) != 3 {
+		t.Fatalf("callback count = %d, want 3", len(received))
+	}
+	if received[0] != "Hello" || received[1] != " world" || received[2] != "!" {
+		t.Errorf("received = %v", received)
+	}
+}
+
+func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var req map[string]any
+		_ = json.NewDecoder(r.Body).Decode(&req)
+		msgs, _ := req["messages"].([]any)
+		if len(msgs) != 2 {
+			t.Errorf("expected system+user, got %d messages", len(msgs))
+		}
+		first, _ := msgs[0].(map[string]any)
+		if first["role"] != "system" || first["content"] != "You are a DM" {
+			t.Errorf("system msg = %v", first)
+		}
+
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		_, _ = w.Write([]byte(sseChunk("ok")))
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if result != "ok" {
+		t.Errorf("result = %q", result)
+	}
+}
+
+func TestLLMClient_StreamGenerateNilCallback(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		_, _ = w.Write([]byte(sseChunk("token")))
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	// nil callback should not panic
+	result, err := c.StreamGenerate(context.Background(), "hi", "", "", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if result != "token" {
+		t.Errorf("result = %q", result)
+	}
+}
+
+func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) {
+	// SSE chunks with empty content should be silently skipped.
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		// role-only chunk (no content) — common for first chunk from vLLM
+		_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n"))
+		// empty content string
+		_, _ = w.Write([]byte(sseChunk("")))
+		// real token
+		_, _ = w.Write([]byte(sseChunk("hello")))
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	var count int
+	result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
+		count++
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if result != "hello" {
+		t.Errorf("result = %q", result)
+	}
+	if count != 1 {
+		t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", count)
+	}
+}
+
+func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) {
+	// Malformed JSON should be skipped without error.
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		_, _ = w.Write([]byte("data: {invalid json}\n\n"))
+		_, _ = w.Write([]byte("data: {\"choices\":[]}\n\n")) // empty choices
+		_, _ = w.Write([]byte(sseChunk("good")))
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if result != "good" {
+		t.Errorf("result = %q, want %q", result, "good")
+	}
+}
+
+func TestLLMClient_StreamGenerateHTTPError(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(500)
+		_, _ = w.Write([]byte("internal server error"))
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	_, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
+	if err == nil {
+		t.Fatal("expected error for 500")
+	}
+	if !strings.Contains(err.Error(), "500") {
+		t.Errorf("error should contain 500: %v", err)
+	}
+}
+
+func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) {
+	started := make(chan struct{})
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		// Send several tokens so the client receives some before cancel.
+		for i := range 20 {
+			_, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i))))
+			flusher.Flush()
+		}
+		close(started)
+		// Block until client cancels
+		<-r.Context().Done()
+	}))
+	defer ts.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	c := NewLLMClient(ts.URL, 10*time.Second)
+
+	var streamErr error
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		_, streamErr = c.StreamGenerate(ctx, "q", "", "", nil)
+	}()
+
+	<-started
+	cancel()
+	<-done
+
+	// After cancel the stream should return an error (context canceled or
+	// stream interrupted). The exact partial text depends on timing.
+	if streamErr == nil {
+		t.Error("expected error after context cancel")
+	}
+}
+
+func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) {
+	// Lines without "data: " prefix should be silently ignored (comments, blank lines, event IDs).
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		_, _ = w.Write([]byte(": this is an SSE comment\n\n"))
+		_, _ = w.Write([]byte("event: message\n"))
+		_, _ = w.Write([]byte(sseChunk("word")))
+		_, _ = w.Write([]byte("\n")) // blank line
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if result != "word" {
+		t.Errorf("result = %q, want %q", result, "word")
+	}
+}
+
+func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
+	// Verify token ordering and full assembly with many chunks.
+	n := 100
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		flusher, _ := w.(http.Flusher)
+		for i := range n {
+			tok := fmt.Sprintf("t%d ", i)
+			_, _ = w.Write([]byte(sseChunk(tok)))
+			flusher.Flush()
+		}
+		_, _ = w.Write([]byte("data: [DONE]\n\n"))
+		flusher.Flush()
+	}))
+	defer ts.Close()
+
+	c := NewLLMClient(ts.URL, 5*time.Second)
+	var mu sync.Mutex
+	var order []int
+	result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
+		var idx int
+		fmt.Sscanf(tok, "t%d ", &idx)
+		mu.Lock()
+		order = append(order, idx)
+		mu.Unlock()
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Verify all tokens arrived in order
+	if len(order) != n {
+		t.Fatalf("got %d tokens, want %d", len(order), n)
+	}
+	for i, v := range order {
+		if v != i {
+			t.Errorf("order[%d] = %d", i, v)
+			break
+		}
+	}
+	// Quick sanity: result should start with "t0 " and end with last token
+	if !strings.HasPrefix(result, "t0 ") {
+		t.Errorf("result prefix = %q", result[:10])
+	}
+}
+
 // ────────────────────────────────────────────────────────────────────────────
 // TTS client
 // ────────────────────────────────────────────────────────────────────────────