package main

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"git.daviestechlabs.io/daviestechlabs/handler-base/clients"
	"git.daviestechlabs.io/daviestechlabs/handler-base/messages"
	"github.com/vmihailenco/msgpack/v5"
)

// ────────────────────────────────────────────────────────────────────────────
// E2E tests: exercise the full chat pipeline with mock backends
// ────────────────────────────────────────────────────────────────────────────

// mockBackends starts httptest servers simulating all downstream services.
type mockBackends struct {
	Embeddings *httptest.Server
	Reranker   *httptest.Server
	LLM        *httptest.Server
	TTS        *httptest.Server
}

// newMockBackends spins up one httptest server per downstream service and
// registers t.Cleanup so each is torn down when the test finishes.
func newMockBackends(t *testing.T) *mockBackends {
	t.Helper()
	mb := &mockBackends{}

	// Embeddings: always returns a fixed 4-dimensional vector.
	mb.Embeddings = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{
			"data": []map[string]any{
				{"embedding": []float64{0.1, 0.2, 0.3, 0.4}},
			},
		})
	}))
	t.Cleanup(mb.Embeddings.Close)

	// Reranker: always scores document 0 highest.
	mb.Reranker = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{
			"results": []map[string]any{
				{"index": 0, "relevance_score": 0.95},
			},
		})
	}))
	t.Cleanup(mb.Reranker.Close)

	// LLM: drains the request body, then returns a canned completion.
	mb.LLM = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		_ = json.NewDecoder(r.Body).Decode(&body)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{
					"content": "Paris is the capital of France.",
				}},
			},
		})
	}))
	t.Cleanup(mb.LLM.Close)

	// TTS: a few bytes standing in for synthesized audio.
	mb.TTS = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte{0xDE, 0xAD, 0xBE, 0xEF})
	}))
	t.Cleanup(mb.TTS.Close)

	return mb
}

// TestChatPipeline_LLMOnly exercises the plain (non-RAG) request path.
func TestChatPipeline_LLMOnly(t *testing.T) {
	backends := newMockBackends(t)
	llm := clients.NewLLMClient(backends.LLM.URL, 5*time.Second)

	// Simulate what main.go does for a non-RAG request.
	response, err := llm.Generate(context.Background(), "What is the capital of France?", "", "")
	if err != nil {
		t.Fatal(err)
	}
	if response != "Paris is the capital of France." {
		t.Errorf("response = %q", response)
	}
}

// TestChatPipeline_WithRAG walks the full embed → rerank → generate flow.
func TestChatPipeline_WithRAG(t *testing.T) {
	backends := newMockBackends(t)
	embeddings := clients.NewEmbeddingsClient(backends.Embeddings.URL, 5*time.Second, "bge")
	reranker := clients.NewRerankerClient(backends.Reranker.URL, 5*time.Second)
	llm := clients.NewLLMClient(backends.LLM.URL, 5*time.Second)
	ctx := context.Background()

	// 1. Embed query
	embedding, err := embeddings.EmbedSingle(ctx, "What is the capital of France?")
	if err != nil {
		t.Fatal(err)
	}
	if len(embedding) == 0 {
		t.Fatal("empty embedding")
	}

	// 2. Rerank (with mock documents)
	corpus := []string{"France is a country in Europe", "Paris is its capital"}
	results, err := reranker.Rerank(ctx, "capital of France", corpus, 2)
	if err != nil {
		t.Fatal(err)
	}
	if len(results) == 0 {
		t.Fatal("no rerank results")
	}
	if results[0].Score == 0 {
		t.Error("expected non-zero score")
	}

	// 3. Generate with context
	contextText := results[0].Document
	response, err := llm.Generate(ctx, "capital of France?", contextText, "")
	if err != nil {
		t.Fatal(err)
	}
	if response == "" {
		t.Error("empty response")
	}
}

// TestChatPipeline_WithTTS generates a reply and synthesizes audio from it.
func TestChatPipeline_WithTTS(t *testing.T) {
	backends := newMockBackends(t)
	llm := clients.NewLLMClient(backends.LLM.URL, 5*time.Second)
	tts := clients.NewTTSClient(backends.TTS.URL, 5*time.Second, "en")
	ctx := context.Background()

	response, err := llm.Generate(ctx, "hello", "", "")
	if err != nil {
		t.Fatal(err)
	}
	audio, err := tts.Synthesize(ctx, response, "en", "")
	if err != nil {
		t.Fatal(err)
	}
	if len(audio) == 0 {
		t.Error("empty audio")
	}
}

// TestChatPipeline_LLMTimeout verifies the client times out on a slow LLM.
func TestChatPipeline_LLMTimeout(t *testing.T) {
	// Simulate slow LLM.
	slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(200 * time.Millisecond)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{"content": "late response"}},
			},
		})
	}))
	defer slow.Close()

	// 100ms client timeout vs 200ms server delay — must fail.
	llm := clients.NewLLMClient(slow.URL, 100*time.Millisecond)
	_, err := llm.Generate(context.Background(), "hello", "", "")
	if err == nil {
		t.Error("expected timeout error")
	}
}

// TestChatPipeline_TypedDecoding verifies typed struct decoding from msgpack
// (same path as OnTypedMessage).
func TestChatPipeline_TypedDecoding(t *testing.T) {
	payload := map[string]any{
		"request_id":       "req-e2e-001",
		"user_id":          "user-1",
		"message":          "hello",
		"premium":          true,
		"enable_rag":       false,
		"enable_streaming": false,
		"system_prompt":    "Be brief.",
	}
	packed, _ := msgpack.Marshal(payload)

	var req messages.ChatRequest
	if err := msgpack.Unmarshal(packed, &req); err != nil {
		t.Fatal(err)
	}
	if req.RequestID != "req-e2e-001" {
		t.Errorf("RequestID = %q", req.RequestID)
	}
	if req.UserID != "user-1" {
		t.Errorf("UserID = %q", req.UserID)
	}
	if req.EffectiveQuery() != "hello" {
		t.Errorf("query = %q", req.EffectiveQuery())
	}
	if req.EnableRAG {
		t.Error("EnableRAG should be false")
	}
	if req.SystemPrompt != "Be brief." {
		t.Errorf("SystemPrompt = %q", req.SystemPrompt)
	}
}

// ────────────────────────────────────────────────────────────────────────────
// Streaming tests: exercise StreamGenerate path (the real SSE pipeline)
// ────────────────────────────────────────────────────────────────────────────

// sseChunk builds an OpenAI-compatible SSE data line.
func sseChunk(content string) string {
	return fmt.Sprintf("data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", content)
}

// newStreamingLLM creates a mock LLM server that responds with SSE-streamed tokens.
func newStreamingLLM(t *testing.T, tokens []string) *httptest.Server { t.Helper() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req map[string]any _ = json.NewDecoder(r.Body).Decode(&req) // Verify stream=true was requested. if stream, ok := req["stream"].(bool); !ok || !stream { t.Error("expected stream=true in request body") } w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") flusher, _ := w.(http.Flusher) // Role-only chunk (should be skipped by StreamGenerate) _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n") if flusher != nil { flusher.Flush() } for _, tok := range tokens { _, _ = fmt.Fprint(w, sseChunk(tok)) if flusher != nil { flusher.Flush() } } _, _ = fmt.Fprintf(w, "data: [DONE]\n\n") if flusher != nil { flusher.Flush() } })) t.Cleanup(srv.Close) return srv } func TestChatPipeline_StreamGenerate(t *testing.T) { tokens := []string{"Paris", " is", " the", " capital", " of", " France", "."} srv := newStreamingLLM(t, tokens) llm := clients.NewLLMClient(srv.URL, 5*time.Second) var mu sync.Mutex var received []string full, err := llm.StreamGenerate(context.Background(), "capital of France?", "", "", func(token string) { mu.Lock() defer mu.Unlock() received = append(received, token) }) if err != nil { t.Fatal(err) } if full != "Paris is the capital of France." 
{ t.Errorf("full = %q", full) } if len(received) != len(tokens) { t.Errorf("callback count = %d, want %d", len(received), len(tokens)) } for i, tok := range tokens { if received[i] != tok { t.Errorf("token[%d] = %q, want %q", i, received[i], tok) } } } func TestChatPipeline_StreamWithSystemPrompt(t *testing.T) { srv := newStreamingLLM(t, []string{"Hello", "!"}) llm := clients.NewLLMClient(srv.URL, 5*time.Second) full, err := llm.StreamGenerate(context.Background(), "greet me", "", "You are a friendly assistant.", func(token string) {}) if err != nil { t.Fatal(err) } if full != "Hello!" { t.Errorf("full = %q", full) } } func TestChatPipeline_StreamWithRAGContext(t *testing.T) { m := newMockBackends(t) srv := newStreamingLLM(t, []string{"The", " answer", " is", " 42"}) embeddings := clients.NewEmbeddingsClient(m.Embeddings.URL, 5*time.Second, "bge") llm := clients.NewLLMClient(srv.URL, 5*time.Second) ctx := context.Background() // 1. Embed embedding, err := embeddings.EmbedSingle(ctx, "deep thought") if err != nil { t.Fatal(err) } if len(embedding) == 0 { t.Fatal("empty embedding") } // 2. 
Stream with context var tokens []string full, err := llm.StreamGenerate(ctx, "deep thought", "The answer to everything is 42.", "", func(tok string) { tokens = append(tokens, tok) }) if err != nil { t.Fatal(err) } if full != "The answer is 42" { t.Errorf("full = %q", full) } if len(tokens) != 4 { t.Errorf("token count = %d, want 4", len(tokens)) } } func TestChatPipeline_StreamNilCallback(t *testing.T) { srv := newStreamingLLM(t, []string{"ok"}) llm := clients.NewLLMClient(srv.URL, 5*time.Second) full, err := llm.StreamGenerate(context.Background(), "test", "", "", nil) if err != nil { t.Fatal(err) } if full != "ok" { t.Errorf("full = %q", full) } } func TestChatPipeline_StreamTimeout(t *testing.T) { slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(200 * time.Millisecond) w.Header().Set("Content-Type", "text/event-stream") _, _ = fmt.Fprint(w, sseChunk("late")) _, _ = fmt.Fprint(w, "data: [DONE]\n\n") })) defer slow.Close() llm := clients.NewLLMClient(slow.URL, 100*time.Millisecond) _, err := llm.StreamGenerate(context.Background(), "hello", "", "", nil) if err == nil { t.Error("expected timeout error") } } func TestChatPipeline_StreamHTTPError(t *testing.T) { errSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte("internal error")) })) defer errSrv.Close() llm := clients.NewLLMClient(errSrv.URL, 5*time.Second) _, err := llm.StreamGenerate(context.Background(), "hello", "", "", nil) if err == nil { t.Error("expected error for HTTP 500") } if !strings.Contains(err.Error(), "500") { t.Errorf("error = %q, should mention status 500", err) } } func TestChatPipeline_StreamContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // cancel immediately srv := newStreamingLLM(t, []string{"should", "not", "arrive"}) llm := clients.NewLLMClient(srv.URL, 5*time.Second) _, err := 
llm.StreamGenerate(ctx, "hello", "", "", nil) if err == nil { t.Error("expected context canceled error") } } func TestChatPipeline_StreamFallbackToNonStreaming(t *testing.T) { // Simulate the branching in main.go: non-streaming uses Generate(), // streaming uses StreamGenerate(). Verify both paths work from same mock. m := newMockBackends(t) streamSrv := newStreamingLLM(t, []string{"streamed", " answer"}) nonStreamLLM := clients.NewLLMClient(m.LLM.URL, 5*time.Second) streamLLM := clients.NewLLMClient(streamSrv.URL, 5*time.Second) ctx := context.Background() // Non-streaming path resp1, err := nonStreamLLM.Generate(ctx, "hello", "", "") if err != nil { t.Fatal(err) } if resp1 != "Paris is the capital of France." { t.Errorf("non-stream = %q", resp1) } // Streaming path var tokens []string resp2, err := streamLLM.StreamGenerate(ctx, "hello", "", "", func(tok string) { tokens = append(tokens, tok) }) if err != nil { t.Fatal(err) } if resp2 != "streamed answer" { t.Errorf("stream = %q", resp2) } if len(tokens) != 2 { t.Errorf("token count = %d", len(tokens)) } } // ──────────────────────────────────────────────────────────────────────────── // Benchmark: full chat pipeline overhead (mock backends) // ──────────────────────────────────────────────────────────────────────────── func BenchmarkChatPipeline_LLMOnly(b *testing.B) { llmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"answer"}}]}`)) })) defer llmSrv.Close() llm := clients.NewLLMClient(llmSrv.URL, 10*time.Second) ctx := context.Background() b.ResetTimer() for b.Loop() { _, _ = llm.Generate(ctx, "question", "", "") } } func BenchmarkChatPipeline_RAGFlow(b *testing.B) { embedSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"data":[{"embedding":[0.1,0.2]}]}`)) })) defer embedSrv.Close() rerankSrv := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"results":[{"index":0,"relevance_score":0.9}]}`)) })) defer rerankSrv.Close() llmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"answer"}}]}`)) })) defer llmSrv.Close() embed := clients.NewEmbeddingsClient(embedSrv.URL, 10*time.Second, "bge") rerank := clients.NewRerankerClient(rerankSrv.URL, 10*time.Second) llm := clients.NewLLMClient(llmSrv.URL, 10*time.Second) ctx := context.Background() b.ResetTimer() for b.Loop() { _, _ = embed.EmbedSingle(ctx, "question") _, _ = rerank.Rerank(ctx, "question", []string{"doc1", "doc2"}, 2) _, _ = llm.Generate(ctx, "question", "context", "") } } func BenchmarkChatPipeline_StreamGenerate(b *testing.B) { tokens := []string{"one", " two", " three", " four", " five"} srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") for _, tok := range tokens { _, _ = fmt.Fprintf(w, "data: {\"choices\":[{\"delta\":{\"content\":%q}}]}\n\n", tok) } _, _ = fmt.Fprint(w, "data: [DONE]\n\n") })) defer srv.Close() llm := clients.NewLLMClient(srv.URL, 10*time.Second) ctx := context.Background() b.ResetTimer() for b.Loop() { _, _ = llm.StreamGenerate(ctx, "question", "", "", func(string) {}) } }