feat: add StreamGenerate for real SSE streaming from LLM
Some checks failed
CI / Lint (push) Failing after 2m44s
CI / Test (push) Successful in 3m7s
CI / Release (push) Has been skipped
CI / Notify Downstream (chat-handler) (push) Has been skipped
CI / Notify Downstream (pipeline-bridge) (push) Has been skipped
CI / Notify Downstream (stt-module) (push) Has been skipped
CI / Notify Downstream (tts-module) (push) Has been skipped
CI / Notify Downstream (voice-assistant) (push) Has been skipped
CI / Notify (push) Successful in 2s

- Add postJSONStream() for incremental response body reading
- Add LLMClient.StreamGenerate() with SSE parsing and onToken callback
- Supports stream:true, parses data: lines, handles [DONE] sentinel
- Graceful partial-text return on stream interruption
- 9 new tests covering happy path, edge cases, cancellation
This commit is contained in:
2026-02-20 17:55:01 -05:00
parent fba7b62573
commit 3585d81ff5
2 changed files with 387 additions and 0 deletions

View File

@@ -6,6 +6,7 @@
package clients
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -14,6 +15,7 @@ import (
"mime/multipart"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
@@ -142,6 +144,36 @@ func (h *httpClient) do(req *http.Request) ([]byte, error) {
return body, nil
}
// postJSONStream sends a JSON POST request to baseURL+path and returns the raw
// *http.Response so the caller can read the body incrementally (e.g. for SSE
// streaming). On success the caller is responsible for closing resp.Body; on
// any error (including HTTP status >= 400) the body is already closed and a
// nil response is returned.
func (h *httpClient) postJSONStream(ctx context.Context, path string, body any) (*http.Response, error) {
	buf := getBuf()
	defer putBuf(buf)
	if err := json.NewEncoder(buf).Encode(body); err != nil {
		return nil, fmt.Errorf("marshal: %w", err)
	}
	// Copy to a non-pooled buffer so we can safely return the pool buffer
	// before the request (which reads the payload lazily) completes.
	payload := make([]byte, buf.Len())
	copy(payload, buf.Bytes())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := h.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
	}
	if resp.StatusCode >= 400 {
		// Cap the error-body read: we only need it for the error message,
		// and an unbounded ReadAll would let a misbehaving server make us
		// buffer an arbitrarily large payload.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 8*1024))
		_ = resp.Body.Close()
		return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(respBody))
	}
	return resp, nil
}
func (h *httpClient) healthCheck(ctx context.Context) bool {
data, err := h.get(ctx, "/health", nil)
_ = data
@@ -320,6 +352,73 @@ func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string
return resp.Choices[0].Message.Content, nil
}
// StreamGenerate sends a streaming chat completion request and calls onToken
// for each content delta received via SSE. Returns the fully assembled text.
// The onToken callback is invoked synchronously on the calling goroutine; it
// should be fast (e.g. publish a NATS message). If the stream is interrupted
// after some text was received, the partial text is returned alongside the
// error so callers can still use it.
func (c *LLMClient) StreamGenerate(ctx context.Context, prompt string, context_ string, systemPrompt string, onToken func(token string)) (string, error) {
	msgs := buildMessages(prompt, context_, systemPrompt)
	payload := map[string]any{
		"model":       c.Model,
		"messages":    msgs,
		"max_tokens":  c.MaxTokens,
		"temperature": c.Temperature,
		"top_p":       c.TopP,
		"stream":      true,
	}
	resp, err := c.postJSONStream(ctx, "/v1/chat/completions", payload)
	if err != nil {
		return "", err
	}
	defer func() { _ = resp.Body.Close() }()

	var full strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	// SSE lines can be up to 64 KiB for large token batches.
	scanner.Buffer(make([]byte, 0, 64*1024), 64*1024)
	for scanner.Scan() {
		line := scanner.Text()
		// The SSE spec makes the space after "data:" optional, so accept
		// both "data: {...}" and "data:{...}". TrimSpace also strips the
		// trailing \r left by CRLF-terminated event lines.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		if data == "[DONE]" {
			break
		}
		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
			} `json:"choices"`
		}
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed chunks rather than aborting the stream
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		token := chunk.Choices[0].Delta.Content
		if token == "" {
			// Role-only or finish_reason-only deltas carry no content.
			continue
		}
		full.WriteString(token)
		if onToken != nil {
			onToken(token)
		}
	}
	if err := scanner.Err(); err != nil {
		// If we already collected some text, return it with the error so
		// the caller can salvage the partial response.
		if full.Len() > 0 {
			return full.String(), fmt.Errorf("stream interrupted: %w", err)
		}
		return "", fmt.Errorf("stream read: %w", err)
	}
	return full.String(), nil
}
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
var msgs []ChatMessage
if systemPrompt != "" {