feat: add StreamGenerate for real SSE streaming from LLM
Some checks failed
CI / Lint (push) Failing after 2m44s
CI / Test (push) Successful in 3m7s
CI / Release (push) Has been skipped
CI / Notify Downstream (chat-handler) (push) Has been skipped
CI / Notify Downstream (pipeline-bridge) (push) Has been skipped
CI / Notify Downstream (stt-module) (push) Has been skipped
CI / Notify Downstream (tts-module) (push) Has been skipped
CI / Notify Downstream (voice-assistant) (push) Has been skipped
CI / Notify (push) Successful in 2s

- Add postJSONStream() for incremental response body reading
- Add LLMClient.StreamGenerate() with SSE parsing and onToken callback
- Supports stream:true, parses data: lines, handles [DONE] sentinel
- Graceful partial-text return on stream interruption
- 9 new tests covering happy path, edge cases, cancellation
This commit is contained in:
2026-02-20 17:55:01 -05:00
parent fba7b62573
commit 3585d81ff5
2 changed files with 387 additions and 0 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -299,6 +300,293 @@ func TestLLMClient_GenerateNoChoices(t *testing.T) {
}
}
// ────────────────────────────────────────────────────────────────────────────
// LLM client — StreamGenerate
// ────────────────────────────────────────────────────────────────────────────
// sseChunk builds an OpenAI-compatible SSE chat.completion.chunk line.
func sseChunk(content string) string {
chunk := map[string]any{
"choices": []map[string]any{
{"delta": map[string]any{"content": content}},
},
}
b, _ := json.Marshal(chunk)
return "data: " + string(b) + "\n\n"
}
func TestLLMClient_StreamGenerate(t *testing.T) {
tokens := []string{"Hello", " world", "!"}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/chat/completions" {
t.Errorf("path = %q", r.URL.Path)
}
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
if req["stream"] != true {
t.Errorf("stream = %v, want true", req["stream"])
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
for _, tok := range tokens {
_, _ = w.Write([]byte(sseChunk(tok)))
flusher.Flush()
}
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
var received []string
result, err := c.StreamGenerate(context.Background(), "hi", "", "", func(tok string) {
received = append(received, tok)
})
if err != nil {
t.Fatal(err)
}
if result != "Hello world!" {
t.Errorf("result = %q, want %q", result, "Hello world!")
}
if len(received) != 3 {
t.Fatalf("callback count = %d, want 3", len(received))
}
if received[0] != "Hello" || received[1] != " world" || received[2] != "!" {
t.Errorf("received = %v", received)
}
}
func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
msgs, _ := req["messages"].([]any)
if len(msgs) != 2 {
t.Errorf("expected system+user, got %d messages", len(msgs))
}
first, _ := msgs[0].(map[string]any)
if first["role"] != "system" || first["content"] != "You are a DM" {
t.Errorf("system msg = %v", first)
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
_, _ = w.Write([]byte(sseChunk("ok")))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil)
if err != nil {
t.Fatal(err)
}
if result != "ok" {
t.Errorf("result = %q", result)
}
}
func TestLLMClient_StreamGenerateNilCallback(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
_, _ = w.Write([]byte(sseChunk("token")))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
// nil callback should not panic
result, err := c.StreamGenerate(context.Background(), "hi", "", "", nil)
if err != nil {
t.Fatal(err)
}
if result != "token" {
t.Errorf("result = %q", result)
}
}
func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) {
// SSE chunks with empty content should be silently skipped.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
// role-only chunk (no content) — common for first chunk from vLLM
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n"))
// empty content string
_, _ = w.Write([]byte(sseChunk("")))
// real token
_, _ = w.Write([]byte(sseChunk("hello")))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
var count int
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
count++
})
if err != nil {
t.Fatal(err)
}
if result != "hello" {
t.Errorf("result = %q", result)
}
if count != 1 {
t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", count)
}
}
func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) {
// Malformed JSON should be skipped without error.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
_, _ = w.Write([]byte("data: {invalid json}\n\n"))
_, _ = w.Write([]byte("data: {\"choices\":[]}\n\n")) // empty choices
_, _ = w.Write([]byte(sseChunk("good")))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
if err != nil {
t.Fatal(err)
}
if result != "good" {
t.Errorf("result = %q, want %q", result, "good")
}
}
func TestLLMClient_StreamGenerateHTTPError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
_, _ = w.Write([]byte("internal server error"))
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
_, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
if err == nil {
t.Fatal("expected error for 500")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("error should contain 500: %v", err)
}
}
func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) {
started := make(chan struct{})
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
// Send several tokens so the client receives some before cancel.
for i := range 20 {
_, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i))))
flusher.Flush()
}
close(started)
// Block until client cancels
<-r.Context().Done()
}))
defer ts.Close()
ctx, cancel := context.WithCancel(context.Background())
c := NewLLMClient(ts.URL, 10*time.Second)
var streamErr error
done := make(chan struct{})
go func() {
defer close(done)
_, streamErr = c.StreamGenerate(ctx, "q", "", "", nil)
}()
<-started
cancel()
<-done
// After cancel the stream should return an error (context canceled or
// stream interrupted). The exact partial text depends on timing.
if streamErr == nil {
t.Error("expected error after context cancel")
}
}
func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) {
// Lines without "data: " prefix should be silently ignored (comments, blank lines, event IDs).
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
_, _ = w.Write([]byte(": this is an SSE comment\n\n"))
_, _ = w.Write([]byte("event: message\n"))
_, _ = w.Write([]byte(sseChunk("word")))
_, _ = w.Write([]byte("\n")) // blank line
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
if err != nil {
t.Fatal(err)
}
if result != "word" {
t.Errorf("result = %q, want %q", result, "word")
}
}
func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
// Verify token ordering and full assembly with many chunks.
n := 100
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
for i := range n {
tok := fmt.Sprintf("t%d ", i)
_, _ = w.Write([]byte(sseChunk(tok)))
flusher.Flush()
}
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
var mu sync.Mutex
var order []int
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
var idx int
fmt.Sscanf(tok, "t%d ", &idx)
mu.Lock()
order = append(order, idx)
mu.Unlock()
})
if err != nil {
t.Fatal(err)
}
// Verify all tokens arrived in order
if len(order) != n {
t.Fatalf("got %d tokens, want %d", len(order), n)
}
for i, v := range order {
if v != i {
t.Errorf("order[%d] = %d", i, v)
break
}
}
// Quick sanity: result should start with "t0 " and end with last token
if !strings.HasPrefix(result, "t0 ") {
t.Errorf("result prefix = %q", result[:10])
}
}
// ────────────────────────────────────────────────────────────────────────────
// TTS client
// ────────────────────────────────────────────────────────────────────────────