From 87d0545d2c58dd337e75b5a9c8eee922f2269c60 Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Sat, 21 Feb 2026 09:23:57 -0500 Subject: [PATCH] feat: replace fake streaming with real SSE StreamGenerate Use handler-base StreamGenerate() to publish real token-by-token ChatStreamChunk messages to NATS as they arrive from Ray Serve, instead of calling Generate() and splitting into 4-word chunks. Add 8 streaming tests: happy path, system prompt, RAG context, nil callback, timeout, HTTP error, context canceled, fallback. --- e2e_test.go | 237 ++++++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 51 +++++------ 2 files changed, 259 insertions(+), 29 deletions(-) diff --git a/e2e_test.go b/e2e_test.go index 98a9a19..02d0bad 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -3,8 +3,11 @@ package main import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "strings" + "sync" "testing" "time" @@ -198,6 +201,220 @@ func TestChatPipeline_TypedDecoding(t *testing.T) { } } +// ──────────────────────────────────────────────────────────────────────────── +// Streaming tests: exercise StreamGenerate path (the real SSE pipeline) +// ──────────────────────────────────────────────────────────────────────────── + +// sseChunk builds an OpenAI-compatible SSE data line. +func sseChunk(content string) string { + return fmt.Sprintf("data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", content) +} + +// newStreamingLLM creates a mock LLM server that responds with SSE-streamed tokens. +func newStreamingLLM(t *testing.T, tokens []string) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + // Verify stream=true was requested. + if stream, ok := req["stream"].(bool); !ok || !stream { + t.Error("expected stream=true in request body") + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + flusher, _ := w.(http.Flusher) + + // Role-only chunk (should be skipped by StreamGenerate) + _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n") + if flusher != nil { + flusher.Flush() + } + + for _, tok := range tokens { + _, _ = fmt.Fprint(w, sseChunk(tok)) + if flusher != nil { + flusher.Flush() + } + } + _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + })) + t.Cleanup(srv.Close) + return srv +} + +func TestChatPipeline_StreamGenerate(t *testing.T) { + tokens := []string{"Paris", " is", " the", " capital", " of", " France", "."} + srv := newStreamingLLM(t, tokens) + llm := clients.NewLLMClient(srv.URL, 5*time.Second) + + var mu sync.Mutex + var received []string + full, err := llm.StreamGenerate(context.Background(), "capital of France?", "", "", func(token string) { + mu.Lock() + defer mu.Unlock() + received = append(received, token) + }) + if err != nil { + t.Fatal(err) + } + if full != "Paris is the capital of France." { + t.Errorf("full = %q", full) + } + if len(received) != len(tokens) { + t.Errorf("callback count = %d, want %d", len(received), len(tokens)) + } + for i, tok := range tokens { + if received[i] != tok { + t.Errorf("token[%d] = %q, want %q", i, received[i], tok) + } + } +} + +func TestChatPipeline_StreamWithSystemPrompt(t *testing.T) { + srv := newStreamingLLM(t, []string{"Hello", "!"}) + llm := clients.NewLLMClient(srv.URL, 5*time.Second) + + full, err := llm.StreamGenerate(context.Background(), "greet me", "", "You are a friendly assistant.", func(token string) {}) + if err != nil { + t.Fatal(err) + } + if full != "Hello!" { + t.Errorf("full = %q", full) + } +} + +func TestChatPipeline_StreamWithRAGContext(t *testing.T) { + m := newMockBackends(t) + srv := newStreamingLLM(t, []string{"The", " answer", " is", " 42"}) + embeddings := clients.NewEmbeddingsClient(m.Embeddings.URL, 5*time.Second, "bge") + llm := clients.NewLLMClient(srv.URL, 5*time.Second) + + ctx := context.Background() + + // 1. Embed + embedding, err := embeddings.EmbedSingle(ctx, "deep thought") + if err != nil { + t.Fatal(err) + } + if len(embedding) == 0 { + t.Fatal("empty embedding") + } + + // 2. Stream with context + var tokens []string + full, err := llm.StreamGenerate(ctx, "deep thought", "The answer to everything is 42.", "", func(tok string) { + tokens = append(tokens, tok) + }) + if err != nil { + t.Fatal(err) + } + if full != "The answer is 42" { + t.Errorf("full = %q", full) + } + if len(tokens) != 4 { + t.Errorf("token count = %d, want 4", len(tokens)) + } +} + +func TestChatPipeline_StreamNilCallback(t *testing.T) { + srv := newStreamingLLM(t, []string{"ok"}) + llm := clients.NewLLMClient(srv.URL, 5*time.Second) + + full, err := llm.StreamGenerate(context.Background(), "test", "", "", nil) + if err != nil { + t.Fatal(err) + } + if full != "ok" { + t.Errorf("full = %q", full) + } +} + +func TestChatPipeline_StreamTimeout(t *testing.T) { + slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, sseChunk("late")) + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer slow.Close() + + llm := clients.NewLLMClient(slow.URL, 100*time.Millisecond) + _, err := llm.StreamGenerate(context.Background(), "hello", "", "", nil) + if err == nil { + t.Error("expected timeout error") + } +} + +func TestChatPipeline_StreamHTTPError(t *testing.T) { + errSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal error")) + })) + defer errSrv.Close() + + llm := clients.NewLLMClient(errSrv.URL, 5*time.Second) + _, err := llm.StreamGenerate(context.Background(), "hello", "", "", nil) + if err == nil { + t.Error("expected error for HTTP 500") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("error = %q, should mention status 500", err) + } +} + +func TestChatPipeline_StreamContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + srv := newStreamingLLM(t, []string{"should", "not", "arrive"}) + llm := clients.NewLLMClient(srv.URL, 5*time.Second) + + _, err := llm.StreamGenerate(ctx, "hello", "", "", nil) + if err == nil { + t.Error("expected context canceled error") + } +} + +func TestChatPipeline_StreamFallbackToNonStreaming(t *testing.T) { + // Simulate the branching in main.go: non-streaming uses Generate(), + // streaming uses StreamGenerate(). Verify both paths work from same mock. + m := newMockBackends(t) + streamSrv := newStreamingLLM(t, []string{"streamed", " answer"}) + + nonStreamLLM := clients.NewLLMClient(m.LLM.URL, 5*time.Second) + streamLLM := clients.NewLLMClient(streamSrv.URL, 5*time.Second) + + ctx := context.Background() + + // Non-streaming path + resp1, err := nonStreamLLM.Generate(ctx, "hello", "", "") + if err != nil { + t.Fatal(err) + } + if resp1 != "Paris is the capital of France." { + t.Errorf("non-stream = %q", resp1) + } + + // Streaming path + var tokens []string + resp2, err := streamLLM.StreamGenerate(ctx, "hello", "", "", func(tok string) { + tokens = append(tokens, tok) + }) + if err != nil { + t.Fatal(err) + } + if resp2 != "streamed answer" { + t.Errorf("stream = %q", resp2) + } + if len(tokens) != 2 { + t.Errorf("token count = %d", len(tokens)) + } +} + // ──────────────────────────────────────────────────────────────────────────── // Benchmark: full chat pipeline overhead (mock backends) // ──────────────────────────────────────────────────────────────────────────── @@ -245,3 +462,23 @@ func BenchmarkChatPipeline_RAGFlow(b *testing.B) { _, _ = llm.Generate(ctx, "question", "context", "") } } + +func BenchmarkChatPipeline_StreamGenerate(b *testing.B) { + tokens := []string{"one", " two", " three", " four", " five"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for _, tok := range tokens { + _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", tok) + } + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer srv.Close() + + llm := clients.NewLLMClient(srv.URL, 10*time.Second) + ctx := context.Background() + + b.ResetTimer() + for b.Loop() { + _, _ = llm.StreamGenerate(ctx, "question", "", "", func(string) {}) + } +} diff --git a/main.go b/main.go index e2beb31..cd1b106 100644 --- a/main.go +++ b/main.go @@ -157,8 +157,27 @@ func main() { } } - // 5. Generate LLM response - responseText, err := llm.Generate(ctx, query, contextText, systemPrompt) + // 5. Generate LLM response (streaming when requested) + var responseText string + if req.EnableStreaming { + streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID) + responseText, err = llm.StreamGenerate(ctx, query, contextText, systemPrompt, func(token string) { + _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ + RequestID: requestID, + Type: "chunk", + Content: token, + Timestamp: messages.Timestamp(), + }) + }) + _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ + RequestID: requestID, + Type: "done", + Done: true, + Timestamp: messages.Timestamp(), + }) + } else { + responseText, err = llm.Generate(ctx, query, contextText, systemPrompt) + } if err != nil { slog.Error("LLM generation failed", "error", err) return &messages.ChatResponse{ @@ -168,33 +187,7 @@ func main() { }, nil } - // 6. Stream chunks if requested - if req.EnableStreaming { - streamSubject := fmt.Sprintf("ai.chat.response.stream.%s", requestID) - words := strings.Fields(responseText) - chunkSize := 4 - for i := 0; i < len(words); i += chunkSize { - end := i + chunkSize - if end > len(words) { - end = len(words) - } - chunk := strings.Join(words[i:end], " ") - _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ - RequestID: requestID, - Type: "chunk", - Content: chunk, - Timestamp: messages.Timestamp(), - }) - } - _ = h.NATS.Publish(streamSubject, &messages.ChatStreamChunk{ - RequestID: requestID, - Type: "done", - Done: true, - Timestamp: messages.Timestamp(), - }) - } - - // 7. Optional TTS — audio as raw bytes (no base64) + // 6. Optional TTS — audio as raw bytes (no base64) var audio []byte if reqEnableTTS && tts != nil { audioBytes, err := tts.Synthesize(ctx, responseText, ttsLanguage, "")