feat: add StreamGenerate for real SSE streaming from LLM
- Add postJSONStream() for incremental response body reading - Add LLMClient.StreamGenerate() with SSE parsing and onToken callback - Supports stream:true, parses data: lines, handles [DONE] sentinel - Graceful partial-text return on stream interruption - 9 new tests covering happy path, edge cases, cancellation
This commit is contained in:
@@ -6,6 +6,7 @@
|
|||||||
package clients
|
package clients
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -14,6 +15,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -142,6 +144,36 @@ func (h *httpClient) do(req *http.Request) ([]byte, error) {
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postJSONStream sends a JSON POST and returns the raw *http.Response so the
|
||||||
|
// caller can read the body incrementally (e.g. for SSE streaming). The caller
|
||||||
|
// is responsible for closing resp.Body.
|
||||||
|
func (h *httpClient) postJSONStream(ctx context.Context, path string, body any) (*http.Response, error) {
|
||||||
|
buf := getBuf()
|
||||||
|
defer putBuf(buf)
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal: %w", err)
|
||||||
|
}
|
||||||
|
// Copy to a non-pooled buffer so we can safely return the pool buffer.
|
||||||
|
payload := make([]byte, buf.Len())
|
||||||
|
copy(payload, buf.Bytes())
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(payload))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := h.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpClient) healthCheck(ctx context.Context) bool {
|
func (h *httpClient) healthCheck(ctx context.Context) bool {
|
||||||
data, err := h.get(ctx, "/health", nil)
|
data, err := h.get(ctx, "/health", nil)
|
||||||
_ = data
|
_ = data
|
||||||
@@ -320,6 +352,73 @@ func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string
|
|||||||
return resp.Choices[0].Message.Content, nil
|
return resp.Choices[0].Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamGenerate sends a streaming chat completion request and calls onToken
|
||||||
|
// for each content delta received via SSE. Returns the fully assembled text.
|
||||||
|
// The onToken callback is invoked synchronously on the calling goroutine; it
|
||||||
|
// should be fast (e.g. publish a NATS message).
|
||||||
|
func (c *LLMClient) StreamGenerate(ctx context.Context, prompt string, context_ string, systemPrompt string, onToken func(token string)) (string, error) {
|
||||||
|
msgs := buildMessages(prompt, context_, systemPrompt)
|
||||||
|
payload := map[string]any{
|
||||||
|
"model": c.Model,
|
||||||
|
"messages": msgs,
|
||||||
|
"max_tokens": c.MaxTokens,
|
||||||
|
"temperature": c.Temperature,
|
||||||
|
"top_p": c.TopP,
|
||||||
|
"stream": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.postJSONStream(ctx, "/v1/chat/completions", payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
var full strings.Builder
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
// SSE lines can be up to 64 KiB for large token batches.
|
||||||
|
scanner.Buffer(make([]byte, 0, 64*1024), 64*1024)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
if data == "[DONE]" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var chunk struct {
|
||||||
|
Choices []struct {
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||||
|
continue // skip malformed chunks
|
||||||
|
}
|
||||||
|
if len(chunk.Choices) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
token := chunk.Choices[0].Delta.Content
|
||||||
|
if token == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
full.WriteString(token)
|
||||||
|
if onToken != nil {
|
||||||
|
onToken(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
// If we already collected some text, return it with the error.
|
||||||
|
if full.Len() > 0 {
|
||||||
|
return full.String(), fmt.Errorf("stream interrupted: %w", err)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("stream read: %w", err)
|
||||||
|
}
|
||||||
|
return full.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
|
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
|
||||||
var msgs []ChatMessage
|
var msgs []ChatMessage
|
||||||
if systemPrompt != "" {
|
if systemPrompt != "" {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -299,6 +300,293 @@ func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// LLM client — StreamGenerate
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// sseChunk builds an OpenAI-compatible SSE chat.completion.chunk line.
|
||||||
|
func sseChunk(content string) string {
|
||||||
|
chunk := map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"delta": map[string]any{"content": content}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(chunk)
|
||||||
|
return "data: " + string(b) + "\n\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerate(t *testing.T) {
|
||||||
|
tokens := []string{"Hello", " world", "!"}
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/chat/completions" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if req["stream"] != true {
|
||||||
|
t.Errorf("stream = %v, want true", req["stream"])
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
for _, tok := range tokens {
|
||||||
|
_, _ = w.Write([]byte(sseChunk(tok)))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
var received []string
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "hi", "", "", func(tok string) {
|
||||||
|
received = append(received, tok)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "Hello world!" {
|
||||||
|
t.Errorf("result = %q, want %q", result, "Hello world!")
|
||||||
|
}
|
||||||
|
if len(received) != 3 {
|
||||||
|
t.Fatalf("callback count = %d, want 3", len(received))
|
||||||
|
}
|
||||||
|
if received[0] != "Hello" || received[1] != " world" || received[2] != "!" {
|
||||||
|
t.Errorf("received = %v", received)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
msgs, _ := req["messages"].([]any)
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Errorf("expected system+user, got %d messages", len(msgs))
|
||||||
|
}
|
||||||
|
first, _ := msgs[0].(map[string]any)
|
||||||
|
if first["role"] != "system" || first["content"] != "You are a DM" {
|
||||||
|
t.Errorf("system msg = %v", first)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte(sseChunk("ok")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "ok" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateNilCallback(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte(sseChunk("token")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
// nil callback should not panic
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "hi", "", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "token" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) {
|
||||||
|
// SSE chunks with empty content should be silently skipped.
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
// role-only chunk (no content) — common for first chunk from vLLM
|
||||||
|
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n"))
|
||||||
|
// empty content string
|
||||||
|
_, _ = w.Write([]byte(sseChunk("")))
|
||||||
|
// real token
|
||||||
|
_, _ = w.Write([]byte(sseChunk("hello")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
var count int
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
||||||
|
count++
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "hello" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) {
|
||||||
|
// Malformed JSON should be skipped without error.
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte("data: {invalid json}\n\n"))
|
||||||
|
_, _ = w.Write([]byte("data: {\"choices\":[]}\n\n")) // empty choices
|
||||||
|
_, _ = w.Write([]byte(sseChunk("good")))
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "good" {
|
||||||
|
t.Errorf("result = %q, want %q", result, "good")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateHTTPError(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
_, _ = w.Write([]byte("internal server error"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
_, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 500")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "500") {
|
||||||
|
t.Errorf("error should contain 500: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) {
|
||||||
|
started := make(chan struct{})
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
// Send several tokens so the client receives some before cancel.
|
||||||
|
for i := range 20 {
|
||||||
|
_, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i))))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
close(started)
|
||||||
|
// Block until client cancels
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
c := NewLLMClient(ts.URL, 10*time.Second)
|
||||||
|
|
||||||
|
var streamErr error
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
_, streamErr = c.StreamGenerate(ctx, "q", "", "", nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
cancel()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// After cancel the stream should return an error (context canceled or
|
||||||
|
// stream interrupted). The exact partial text depends on timing.
|
||||||
|
if streamErr == nil {
|
||||||
|
t.Error("expected error after context cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) {
|
||||||
|
// Lines without "data: " prefix should be silently ignored (comments, blank lines, event IDs).
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
_, _ = w.Write([]byte(": this is an SSE comment\n\n"))
|
||||||
|
_, _ = w.Write([]byte("event: message\n"))
|
||||||
|
_, _ = w.Write([]byte(sseChunk("word")))
|
||||||
|
_, _ = w.Write([]byte("\n")) // blank line
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "word" {
|
||||||
|
t.Errorf("result = %q, want %q", result, "word")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
|
||||||
|
// Verify token ordering and full assembly with many chunks.
|
||||||
|
n := 100
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
for i := range n {
|
||||||
|
tok := fmt.Sprintf("t%d ", i)
|
||||||
|
_, _ = w.Write([]byte(sseChunk(tok)))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
var mu sync.Mutex
|
||||||
|
var order []int
|
||||||
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
||||||
|
var idx int
|
||||||
|
fmt.Sscanf(tok, "t%d ", &idx)
|
||||||
|
mu.Lock()
|
||||||
|
order = append(order, idx)
|
||||||
|
mu.Unlock()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Verify all tokens arrived in order
|
||||||
|
if len(order) != n {
|
||||||
|
t.Fatalf("got %d tokens, want %d", len(order), n)
|
||||||
|
}
|
||||||
|
for i, v := range order {
|
||||||
|
if v != i {
|
||||||
|
t.Errorf("order[%d] = %d", i, v)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Quick sanity: result should start with "t0 " and end with last token
|
||||||
|
if !strings.HasPrefix(result, "t0 ") {
|
||||||
|
t.Errorf("result prefix = %q", result[:10])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// TTS client
|
// TTS client
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|||||||
Reference in New Issue
Block a user