package clients

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"
)

// ────────────────────────────────────────────────────────────────────────────
// Shared infrastructure tests
// ────────────────────────────────────────────────────────────────────────────

// TestSharedTransport checks that every client built via newHTTPClient reuses
// the single package-level transport (one connection pool for all services).
func TestSharedTransport(t *testing.T) {
	// All clients created via newHTTPClient should share the same transport.
	first := newHTTPClient("http://a:8000", 10*time.Second)
	second := newHTTPClient("http://b:9000", 30*time.Second)
	if first.client.Transport != second.client.Transport {
		t.Error("clients should share the same http.Transport")
	}
	if first.client.Transport != SharedTransport {
		t.Error("transport should be the package-level SharedTransport")
	}
}

// TestBufferPoolGetPut checks that buffers come out of the pool non-nil and
// reset, even after a previous user wrote into them.
func TestBufferPoolGetPut(t *testing.T) {
	first := getBuf()
	if first == nil {
		t.Fatal("getBuf returned nil")
	}
	if first.Len() != 0 {
		t.Error("getBuf should return a reset buffer")
	}
	first.WriteString("hello")
	putBuf(first)

	// On re-get, buffer should be reset.
	second := getBuf()
	if second.Len() != 0 {
		t.Error("re-acquired buffer should be reset")
	}
	putBuf(second)
}

// TestBufferPoolOversizedDiscarded checks that returning an oversized buffer
// does not poison the pool: the next get still yields a usable, empty buffer.
func TestBufferPoolOversizedDiscarded(t *testing.T) {
	big := getBuf()
	// Grow beyond 1 MB threshold.
	big.Write(make([]byte, 2<<20))
	putBuf(big) // should silently discard

	// Pool should still work — we get a fresh one.
	fresh := getBuf()
	if fresh.Len() != 0 {
		t.Error("should get a fresh buffer")
	}
	putBuf(fresh)
}

// TestBufferPoolConcurrency hammers the pool from many goroutines; run with
// -race to surface data races in getBuf/putBuf.
func TestBufferPoolConcurrency(t *testing.T) {
	var wg sync.WaitGroup
	for i := range 100 {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			scratch := getBuf()
			scratch.WriteString(strings.Repeat("x", n))
			putBuf(scratch)
		}(i)
	}
	wg.Wait()
}

// ────────────────────────────────────────────────────────────────────────────
// Embeddings client
// ────────────────────────────────────────────────────────────────────────────

// TestEmbeddingsClient_Embed checks request shape (path, method, input list)
// and that the returned vectors are decoded in order.
func TestEmbeddingsClient_Embed(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/embeddings" {
			t.Errorf("path = %q, want /embeddings", r.URL.Path)
		}
		if r.Method != http.MethodPost {
			t.Errorf("method = %s, want POST", r.Method)
		}
		var body map[string]any
		_ = json.NewDecoder(r.Body).Decode(&body)
		inputs, _ := body["input"].([]any)
		if len(inputs) != 2 {
			t.Errorf("input len = %d, want 2", len(inputs))
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"data": []map[string]any{
				{"embedding": []float64{0.1, 0.2, 0.3}},
				{"embedding": []float64{0.4, 0.5, 0.6}},
			},
		})
	}))
	defer srv.Close()

	ec := NewEmbeddingsClient(srv.URL, 5*time.Second, "bge")
	vecs, err := ec.Embed(context.Background(), []string{"hello", "world"})
	if err != nil {
		t.Fatal(err)
	}
	if len(vecs) != 2 {
		t.Fatalf("len(results) = %d, want 2", len(vecs))
	}
	if vecs[0][0] != 0.1 {
		t.Errorf("results[0][0] = %f, want 0.1", vecs[0][0])
	}
}

// TestEmbeddingsClient_EmbedSingle checks the single-text convenience wrapper.
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{
			"data": []map[string]any{
				{"embedding": []float64{1.0, 2.0}},
			},
		})
	}))
	defer srv.Close()

	ec := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	vec, err := ec.EmbedSingle(context.Background(), "test")
	if err != nil {
		t.Fatal(err)
	}
	if len(vec) != 2 || vec[0] != 1.0 {
		t.Errorf("vec = %v", vec)
	}
}

// TestEmbeddingsClient_EmbedEmpty checks that an empty "data" array from the
// server surfaces as an error rather than a zero-length vector.
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
	}))
	defer srv.Close()

	ec := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	if _, err := ec.EmbedSingle(context.Background(), "test"); err == nil {
		t.Error("expected error for empty embedding")
	}
}

// TestEmbeddingsClient_Health checks that a 200 from /health reports healthy.
func TestEmbeddingsClient_Health(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/health" {
			w.WriteHeader(200)
			return
		}
		w.WriteHeader(404)
	}))
	defer srv.Close()

	ec := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	if !ec.Health(context.Background()) {
		t.Error("expected healthy")
	}
}

// ────────────────────────────────────────────────────────────────────────────
// Reranker client
// ────────────────────────────────────────────────────────────────────────────

// TestRerankerClient_Rerank checks that results are mapped back to the
// original documents by index and ordered by the server's scores.
func TestRerankerClient_Rerank(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		_ = json.NewDecoder(r.Body).Decode(&body)
		if body["query"] != "test query" {
			t.Errorf("query = %v", body["query"])
		}
		_ = json.NewEncoder(w).Encode(map[string]any{
			"results": []map[string]any{
				{"index": 1, "relevance_score": 0.95},
				{"index": 0, "relevance_score": 0.80},
			},
		})
	}))
	defer srv.Close()

	rc := NewRerankerClient(srv.URL, 5*time.Second)
	docs := []string{"Paris is great", "France is in Europe"}
	ranked, err := rc.Rerank(context.Background(), "test query", docs, 2)
	if err != nil {
		t.Fatal(err)
	}
	if len(ranked) != 2 {
		t.Fatalf("len = %d", len(ranked))
	}
	if ranked[0].Score != 0.95 {
		t.Errorf("score = %f, want 0.95", ranked[0].Score)
	}
	if ranked[0].Document != "France is in Europe" {
		t.Errorf("document = %q", ranked[0].Document)
	}
}

// TestRerankerClient_RerankFallbackScore checks that the "score" field is used
// when "relevance_score" is absent/zero.
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{
			"results": []map[string]any{
				{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
			},
		})
	}))
	defer srv.Close()

	rc := NewRerankerClient(srv.URL, 5*time.Second)
	ranked, err := rc.Rerank(context.Background(), "q", []string{"doc1"}, 0)
	if err != nil {
		t.Fatal(err)
	}
	if ranked[0].Score != 0.77 {
		t.Errorf("fallback score = %f, want 0.77", ranked[0].Score)
	}
}

// ────────────────────────────────────────────────────────────────────────────
// LLM client
// ────────────────────────────────────────────────────────────────────────────

// TestLLMClient_Generate checks the non-streaming chat-completions round trip.
func TestLLMClient_Generate(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/chat/completions" {
			t.Errorf("path = %q", r.URL.Path)
		}
		var body map[string]any
		_ = json.NewDecoder(r.Body).Decode(&body)
		msgs, _ := body["messages"].([]any)
		if len(msgs) == 0 {
			t.Error("no messages in request")
		}
		_ = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{"content": "Paris is the capital of France."}},
			},
		})
	}))
	defer srv.Close()

	lc := NewLLMClient(srv.URL, 5*time.Second)
	answer, err := lc.Generate(context.Background(), "capital of France?", "", "")
	if err != nil {
		t.Fatal(err)
	}
	if answer != "Paris is the capital of France." {
		t.Errorf("result = %q", answer)
	}
}

// TestLLMClient_GenerateWithContext checks that supplying retrieval context
// produces a system + user message pair.
func TestLLMClient_GenerateWithContext(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		_ = json.NewDecoder(r.Body).Decode(&body)
		msgs, _ := body["messages"].([]any)
		// Should have system + user message
		if len(msgs) != 2 {
			t.Errorf("expected 2 messages, got %d", len(msgs))
		}
		_ = json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{"content": "answer with context"}},
			},
		})
	}))
	defer srv.Close()

	lc := NewLLMClient(srv.URL, 5*time.Second)
	answer, err := lc.Generate(context.Background(), "question", "some context", "")
	if err != nil {
		t.Fatal(err)
	}
	if answer != "answer with context" {
		t.Errorf("result = %q", answer)
	}
}

// TestLLMClient_GenerateNoChoices checks that an empty "choices" array is an
// error rather than an empty string result.
func TestLLMClient_GenerateNoChoices(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
	}))
	defer srv.Close()

	lc := NewLLMClient(srv.URL, 5*time.Second)
	if _, err := lc.Generate(context.Background(), "q", "", ""); err == nil {
		t.Error("expected error for empty choices")
	}
}

// ────────────────────────────────────────────────────────────────────────────
// LLM client — StreamGenerate
// ────────────────────────────────────────────────────────────────────────────

// sseChunk builds an OpenAI-compatible SSE chat.completion.chunk line.
func sseChunk(content string) string { chunk := map[string]any{ "choices": []map[string]any{ {"delta": map[string]any{"content": content}}, }, } b, _ := json.Marshal(chunk) return "data: " + string(b) + "\n\n" } func TestLLMClient_StreamGenerate(t *testing.T) { tokens := []string{"Hello", " world", "!"} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/chat/completions" { t.Errorf("path = %q", r.URL.Path) } var req map[string]any _ = json.NewDecoder(r.Body).Decode(&req) if req["stream"] != true { t.Errorf("stream = %v, want true", req["stream"]) } w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) for _, tok := range tokens { _, _ = w.Write([]byte(sseChunk(tok))) flusher.Flush() } _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) var received []string result, err := c.StreamGenerate(context.Background(), "hi", "", "", func(tok string) { received = append(received, tok) }) if err != nil { t.Fatal(err) } if result != "Hello world!" { t.Errorf("result = %q, want %q", result, "Hello world!") } if len(received) != 3 { t.Fatalf("callback count = %d, want 3", len(received)) } if received[0] != "Hello" || received[1] != " world" || received[2] != "!" 
{ t.Errorf("received = %v", received) } } func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req map[string]any _ = json.NewDecoder(r.Body).Decode(&req) msgs, _ := req["messages"].([]any) if len(msgs) != 2 { t.Errorf("expected system+user, got %d messages", len(msgs)) } first, _ := msgs[0].(map[string]any) if first["role"] != "system" || first["content"] != "You are a DM" { t.Errorf("system msg = %v", first) } w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) _, _ = w.Write([]byte(sseChunk("ok"))) _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil) if err != nil { t.Fatal(err) } if result != "ok" { t.Errorf("result = %q", result) } } func TestLLMClient_StreamGenerateNilCallback(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) _, _ = w.Write([]byte(sseChunk("token"))) _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) // nil callback should not panic result, err := c.StreamGenerate(context.Background(), "hi", "", "", nil) if err != nil { t.Fatal(err) } if result != "token" { t.Errorf("result = %q", result) } } func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) { // SSE chunks with empty content should be silently skipped. 
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) // role-only chunk (no content) — common for first chunk from vLLM _, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n")) // empty content string _, _ = w.Write([]byte(sseChunk(""))) // real token _, _ = w.Write([]byte(sseChunk("hello"))) _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) var count int result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) { count++ }) if err != nil { t.Fatal(err) } if result != "hello" { t.Errorf("result = %q", result) } if count != 1 { t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", count) } } func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) { // Malformed JSON should be skipped without error. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) _, _ = w.Write([]byte("data: {invalid json}\n\n")) _, _ = w.Write([]byte("data: {\"choices\":[]}\n\n")) // empty choices _, _ = w.Write([]byte(sseChunk("good"))) _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) result, err := c.StreamGenerate(context.Background(), "q", "", "", nil) if err != nil { t.Fatal(err) } if result != "good" { t.Errorf("result = %q, want %q", result, "good") } } func TestLLMClient_StreamGenerateHTTPError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) _, _ = w.Write([]byte("internal server error")) })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) _, err := c.StreamGenerate(context.Background(), "q", "", "", nil) if err == nil { t.Fatal("expected error 
for 500") } if !strings.Contains(err.Error(), "500") { t.Errorf("error should contain 500: %v", err) } } func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) { started := make(chan struct{}) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) // Send several tokens so the client receives some before cancel. for i := range 20 { _, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i)))) flusher.Flush() } close(started) // Block until client cancels <-r.Context().Done() })) defer ts.Close() ctx, cancel := context.WithCancel(context.Background()) c := NewLLMClient(ts.URL, 10*time.Second) var streamErr error done := make(chan struct{}) go func() { defer close(done) _, streamErr = c.StreamGenerate(ctx, "q", "", "", nil) }() <-started cancel() <-done // After cancel the stream should return an error (context canceled or // stream interrupted). The exact partial text depends on timing. if streamErr == nil { t.Error("expected error after context cancel") } } func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) { // Lines without "data: " prefix should be silently ignored (comments, blank lines, event IDs). 
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) _, _ = w.Write([]byte(": this is an SSE comment\n\n")) _, _ = w.Write([]byte("event: message\n")) _, _ = w.Write([]byte(sseChunk("word"))) _, _ = w.Write([]byte("\n")) // blank line _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) result, err := c.StreamGenerate(context.Background(), "q", "", "", nil) if err != nil { t.Fatal(err) } if result != "word" { t.Errorf("result = %q, want %q", result, "word") } } func TestLLMClient_StreamGenerateManyTokens(t *testing.T) { // Verify token ordering and full assembly with many chunks. n := 100 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") flusher, _ := w.(http.Flusher) for i := range n { tok := fmt.Sprintf("t%d ", i) _, _ = w.Write([]byte(sseChunk(tok))) flusher.Flush() } _, _ = w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) var mu sync.Mutex var order []int result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) { var idx int _, _ = fmt.Sscanf(tok, "t%d ", &idx) mu.Lock() order = append(order, idx) mu.Unlock() }) if err != nil { t.Fatal(err) } // Verify all tokens arrived in order if len(order) != n { t.Fatalf("got %d tokens, want %d", len(order), n) } for i, v := range order { if v != i { t.Errorf("order[%d] = %d", i, v) break } } // Quick sanity: result should start with "t0 " and end with last token if !strings.HasPrefix(result, "t0 ") { t.Errorf("result prefix = %q", result[:10]) } } // ──────────────────────────────────────────────────────────────────────────── // TTS client // ──────────────────────────────────────────────────────────────────────────── func TestTTSClient_Synthesize(t 
*testing.T) { expected := []byte{0xDE, 0xAD, 0xBE, 0xEF} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/tts" { t.Errorf("path = %q", r.URL.Path) } if r.URL.Query().Get("text") != "hello world" { t.Errorf("text = %q", r.URL.Query().Get("text")) } _, _ = w.Write(expected) })) defer ts.Close() c := NewTTSClient(ts.URL, 5*time.Second, "en") audio, err := c.Synthesize(context.Background(), "hello world", "", "") if err != nil { t.Fatal(err) } if !bytes.Equal(audio, expected) { t.Errorf("audio = %x, want %x", audio, expected) } } func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("speaker_id") != "alice" { t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id")) } _, _ = w.Write([]byte{0x01}) })) defer ts.Close() c := NewTTSClient(ts.URL, 5*time.Second, "en") _, err := c.Synthesize(context.Background(), "hi", "en", "alice") if err != nil { t.Fatal(err) } } // ──────────────────────────────────────────────────────────────────────────── // STT client // ──────────────────────────────────────────────────────────────────────────── func TestSTTClient_Transcribe(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/audio/transcriptions" { t.Errorf("path = %q", r.URL.Path) } ct := r.Header.Get("Content-Type") if !strings.Contains(ct, "multipart/form-data") { t.Errorf("content-type = %q", ct) } // Verify the audio file is present. 
file, _, err := r.FormFile("file") if err != nil { t.Fatal(err) } data, _ := io.ReadAll(file) if len(data) != 100 { t.Errorf("file size = %d, want 100", len(data)) } _ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"}) })) defer ts.Close() c := NewSTTClient(ts.URL, 5*time.Second) result, err := c.Transcribe(context.Background(), make([]byte, 100), "en") if err != nil { t.Fatal(err) } if result.Text != "hello world" { t.Errorf("text = %q", result.Text) } } func TestSTTClient_TranscribeTranslate(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/audio/translations" { t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path) } _ = json.NewEncoder(w).Encode(map[string]string{"text": "translated"}) })) defer ts.Close() c := NewSTTClient(ts.URL, 5*time.Second) c.Task = "translate" result, err := c.Transcribe(context.Background(), []byte{0x01}, "") if err != nil { t.Fatal(err) } if result.Text != "translated" { t.Errorf("text = %q", result.Text) } } // ──────────────────────────────────────────────────────────────────────────── // HTTP error handling // ──────────────────────────────────────────────────────────────────────────── func TestHTTPError4xx(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(422) _, _ = w.Write([]byte(`{"error": "bad input"}`)) })) defer ts.Close() c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") _, err := c.Embed(context.Background(), []string{"test"}) if err == nil { t.Fatal("expected error for 422") } if !strings.Contains(err.Error(), "422") { t.Errorf("error should contain status code: %v", err) } } func TestHTTPError5xx(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) _, _ = w.Write([]byte("internal server error")) })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) _, err := 
c.Generate(context.Background(), "q", "", "") if err == nil { t.Fatal("expected error for 500") } } // ──────────────────────────────────────────────────────────────────────────── // buildMessages helper // ──────────────────────────────────────────────────────────────────────────── func TestBuildMessages(t *testing.T) { // No context, no system prompt → just user message msgs := buildMessages("hello", "", "") if len(msgs) != 1 || msgs[0].Role != "user" { t.Errorf("expected 1 user msg, got %+v", msgs) } // With system prompt msgs = buildMessages("hello", "", "You are helpful") if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" { t.Errorf("expected system+user, got %+v", msgs) } // With context, no system prompt → auto system prompt msgs = buildMessages("question", "some context", "") if len(msgs) != 2 || msgs[0].Role != "system" { t.Errorf("expected auto system+user, got %+v", msgs) } if !strings.Contains(msgs[1].Content, "Context:") { t.Error("user message should contain context") } } // ──────────────────────────────────────────────────────────────────────────── // Benchmarks: pooled buffer vs direct allocation // ──────────────────────────────────────────────────────────────────────────── func BenchmarkPostJSON(b *testing.B) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.Copy(io.Discard, r.Body) _, _ = w.Write([]byte(`{"ok":true}`)) })) defer ts.Close() c := newHTTPClient(ts.URL, 10*time.Second) ctx := context.Background() payload := map[string]any{ "text": strings.Repeat("x", 1024), "count": 42, "enabled": true, } b.ResetTimer() for b.Loop() { _, _ = c.postJSON(ctx, "/test", payload) } } func BenchmarkBufferPool(b *testing.B) { b.ResetTimer() for b.Loop() { buf := getBuf() buf.WriteString(strings.Repeat("x", 4096)) putBuf(buf) } } func BenchmarkBufferPoolParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { buf := getBuf() 
buf.WriteString(strings.Repeat("x", 4096)) putBuf(buf) } }) }