package clients import ( "bytes" "context" "encoding/json" "io" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" ) // ──────────────────────────────────────────────────────────────────────────── // Shared infrastructure tests // ──────────────────────────────────────────────────────────────────────────── func TestSharedTransport(t *testing.T) { // All clients created via newHTTPClient should share the same transport. c1 := newHTTPClient("http://a:8000", 10*time.Second) c2 := newHTTPClient("http://b:9000", 30*time.Second) if c1.client.Transport != c2.client.Transport { t.Error("clients should share the same http.Transport") } if c1.client.Transport != SharedTransport { t.Error("transport should be the package-level SharedTransport") } } func TestBufferPoolGetPut(t *testing.T) { buf := getBuf() if buf == nil { t.Fatal("getBuf returned nil") } if buf.Len() != 0 { t.Error("getBuf should return a reset buffer") } buf.WriteString("hello") putBuf(buf) // On re-get, buffer should be reset. buf2 := getBuf() if buf2.Len() != 0 { t.Error("re-acquired buffer should be reset") } putBuf(buf2) } func TestBufferPoolOversizedDiscarded(t *testing.T) { buf := getBuf() // Grow beyond 1 MB threshold. buf.Write(make([]byte, 2<<20)) putBuf(buf) // should silently discard // Pool should still work — we get a fresh one. buf2 := getBuf() if buf2.Len() != 0 { t.Error("should get a fresh buffer") } putBuf(buf2) } func TestBufferPoolConcurrency(t *testing.T) { var wg sync.WaitGroup for i := range 100 { wg.Add(1) go func(n int) { defer wg.Done() buf := getBuf() buf.WriteString(strings.Repeat("x", n)) putBuf(buf) }(i) } wg.Wait() } // ──────────────────────────────────────────────────────────────────────────── // Embeddings client // ──────────────────────────────────────────────────────────────────────────── func TestEmbeddingsClient_Embed(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/embeddings" { t.Errorf("path = %q, want /embeddings", r.URL.Path) } if r.Method != http.MethodPost { t.Errorf("method = %s, want POST", r.Method) } var req map[string]any json.NewDecoder(r.Body).Decode(&req) input, _ := req["input"].([]any) if len(input) != 2 { t.Errorf("input len = %d, want 2", len(input)) } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]any{ "data": []map[string]any{ {"embedding": []float64{0.1, 0.2, 0.3}}, {"embedding": []float64{0.4, 0.5, 0.6}}, }, }) })) defer ts.Close() c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge") results, err := c.Embed(context.Background(), []string{"hello", "world"}) if err != nil { t.Fatal(err) } if len(results) != 2 { t.Fatalf("len(results) = %d, want 2", len(results)) } if results[0][0] != 0.1 { t.Errorf("results[0][0] = %f, want 0.1", results[0][0]) } } func TestEmbeddingsClient_EmbedSingle(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{ "data": []map[string]any{ {"embedding": []float64{1.0, 2.0}}, }, }) })) defer ts.Close() c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") vec, err := c.EmbedSingle(context.Background(), "test") if err != nil { t.Fatal(err) } if len(vec) != 2 || vec[0] != 1.0 { t.Errorf("vec = %v", vec) } } func TestEmbeddingsClient_EmbedEmpty(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{"data": []any{}}) })) defer ts.Close() c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") _, err := c.EmbedSingle(context.Background(), "test") if err == nil { t.Error("expected error for empty embedding") } } func TestEmbeddingsClient_Health(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/health" { w.WriteHeader(200) return } w.WriteHeader(404) })) defer ts.Close() c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") if !c.Health(context.Background()) { t.Error("expected healthy") } } // ──────────────────────────────────────────────────────────────────────────── // Reranker client // ──────────────────────────────────────────────────────────────────────────── func TestRerankerClient_Rerank(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req map[string]any json.NewDecoder(r.Body).Decode(&req) if req["query"] != "test query" { t.Errorf("query = %v", req["query"]) } json.NewEncoder(w).Encode(map[string]any{ "results": []map[string]any{ {"index": 1, "relevance_score": 0.95}, {"index": 0, "relevance_score": 0.80}, }, }) })) defer ts.Close() c := NewRerankerClient(ts.URL, 5*time.Second) docs := []string{"Paris is great", "France is in Europe"} results, err := c.Rerank(context.Background(), "test query", docs, 2) if err != nil { t.Fatal(err) } if len(results) != 2 { t.Fatalf("len = %d", len(results)) } if results[0].Score != 0.95 { t.Errorf("score = %f, want 0.95", results[0].Score) } if results[0].Document != "France is in Europe" { t.Errorf("document = %q", results[0].Document) } } func TestRerankerClient_RerankFallbackScore(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{ "results": []map[string]any{ {"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score }, }) })) defer ts.Close() c := NewRerankerClient(ts.URL, 5*time.Second) results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0) if err != nil { t.Fatal(err) } if results[0].Score != 0.77 { t.Errorf("fallback score = %f, want 0.77", results[0].Score) } } // ──────────────────────────────────────────────────────────────────────────── // LLM client // ──────────────────────────────────────────────────────────────────────────── func TestLLMClient_Generate(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/chat/completions" { t.Errorf("path = %q", r.URL.Path) } var req map[string]any json.NewDecoder(r.Body).Decode(&req) msgs, _ := req["messages"].([]any) if len(msgs) == 0 { t.Error("no messages in request") } json.NewEncoder(w).Encode(map[string]any{ "choices": []map[string]any{ {"message": map[string]any{"content": "Paris is the capital of France."}}, }, }) })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) result, err := c.Generate(context.Background(), "capital of France?", "", "") if err != nil { t.Fatal(err) } if result != "Paris is the capital of France." { t.Errorf("result = %q", result) } } func TestLLMClient_GenerateWithContext(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req map[string]any json.NewDecoder(r.Body).Decode(&req) msgs, _ := req["messages"].([]any) // Should have system + user message if len(msgs) != 2 { t.Errorf("expected 2 messages, got %d", len(msgs)) } json.NewEncoder(w).Encode(map[string]any{ "choices": []map[string]any{ {"message": map[string]any{"content": "answer with context"}}, }, }) })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) result, err := c.Generate(context.Background(), "question", "some context", "") if err != nil { t.Fatal(err) } if result != "answer with context" { t.Errorf("result = %q", result) } } func TestLLMClient_GenerateNoChoices(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]any{"choices": []any{}}) })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) _, err := c.Generate(context.Background(), "q", "", "") if err == nil { t.Error("expected error for empty choices") } } // ──────────────────────────────────────────────────────────────────────────── // TTS client // ──────────────────────────────────────────────────────────────────────────── func TestTTSClient_Synthesize(t *testing.T) { expected := []byte{0xDE, 0xAD, 0xBE, 0xEF} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/api/tts" { t.Errorf("path = %q", r.URL.Path) } if r.URL.Query().Get("text") != "hello world" { t.Errorf("text = %q", r.URL.Query().Get("text")) } w.Write(expected) })) defer ts.Close() c := NewTTSClient(ts.URL, 5*time.Second, "en") audio, err := c.Synthesize(context.Background(), "hello world", "", "") if err != nil { t.Fatal(err) } if !bytes.Equal(audio, expected) { t.Errorf("audio = %x, want %x", audio, expected) } } func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("speaker_id") != "alice" { t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id")) } w.Write([]byte{0x01}) })) defer ts.Close() c := NewTTSClient(ts.URL, 5*time.Second, "en") _, err := c.Synthesize(context.Background(), "hi", "en", "alice") if err != nil { t.Fatal(err) } } // ──────────────────────────────────────────────────────────────────────────── // STT client // ──────────────────────────────────────────────────────────────────────────── func TestSTTClient_Transcribe(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/audio/transcriptions" { t.Errorf("path = %q", r.URL.Path) } ct := r.Header.Get("Content-Type") if !strings.Contains(ct, "multipart/form-data") { t.Errorf("content-type = %q", ct) } // Verify the audio file is present. file, _, err := r.FormFile("file") if err != nil { t.Fatal(err) } data, _ := io.ReadAll(file) if len(data) != 100 { t.Errorf("file size = %d, want 100", len(data)) } json.NewEncoder(w).Encode(map[string]string{"text": "hello world"}) })) defer ts.Close() c := NewSTTClient(ts.URL, 5*time.Second) result, err := c.Transcribe(context.Background(), make([]byte, 100), "en") if err != nil { t.Fatal(err) } if result.Text != "hello world" { t.Errorf("text = %q", result.Text) } } func TestSTTClient_TranscribeTranslate(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/audio/translations" { t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path) } json.NewEncoder(w).Encode(map[string]string{"text": "translated"}) })) defer ts.Close() c := NewSTTClient(ts.URL, 5*time.Second) c.Task = "translate" result, err := c.Transcribe(context.Background(), []byte{0x01}, "") if err != nil { t.Fatal(err) } if result.Text != "translated" { t.Errorf("text = %q", result.Text) } } // ──────────────────────────────────────────────────────────────────────────── // HTTP error handling // ──────────────────────────────────────────────────────────────────────────── func TestHTTPError4xx(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(422) w.Write([]byte(`{"error": "bad input"}`)) })) defer ts.Close() c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") _, err := c.Embed(context.Background(), []string{"test"}) if err == nil { t.Fatal("expected error for 422") } if !strings.Contains(err.Error(), "422") { t.Errorf("error should contain status code: %v", err) } } func TestHTTPError5xx(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(500) w.Write([]byte("internal server error")) })) defer ts.Close() c := NewLLMClient(ts.URL, 5*time.Second) _, err := c.Generate(context.Background(), "q", "", "") if err == nil { t.Fatal("expected error for 500") } } // ──────────────────────────────────────────────────────────────────────────── // buildMessages helper // ──────────────────────────────────────────────────────────────────────────── func TestBuildMessages(t *testing.T) { // No context, no system prompt → just user message msgs := buildMessages("hello", "", "") if len(msgs) != 1 || msgs[0].Role != "user" { t.Errorf("expected 1 user msg, got %+v", msgs) } // With system prompt msgs = buildMessages("hello", "", "You are helpful") if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" { t.Errorf("expected system+user, got %+v", msgs) } // With context, no system prompt → auto system prompt msgs = buildMessages("question", "some context", "") if len(msgs) != 2 || msgs[0].Role != "system" { t.Errorf("expected auto system+user, got %+v", msgs) } if !strings.Contains(msgs[1].Content, "Context:") { t.Error("user message should contain context") } } // ──────────────────────────────────────────────────────────────────────────── // Benchmarks: pooled buffer vs direct allocation // ──────────────────────────────────────────────────────────────────────────── func BenchmarkPostJSON(b *testing.B) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.Copy(io.Discard, r.Body) w.Write([]byte(`{"ok":true}`)) })) defer ts.Close() c := newHTTPClient(ts.URL, 10*time.Second) ctx := context.Background() payload := map[string]any{ "text": strings.Repeat("x", 1024), "count": 42, "enabled": true, } b.ResetTimer() for b.Loop() { c.postJSON(ctx, "/test", payload) } } func BenchmarkBufferPool(b *testing.B) { b.ResetTimer() for b.Loop() { buf := getBuf() buf.WriteString(strings.Repeat("x", 4096)) putBuf(buf) } } func BenchmarkBufferPoolParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { buf := getBuf() buf.WriteString(strings.Repeat("x", 4096)) putBuf(buf) } }) }