feat: add e2e tests, perf benchmarks, and infrastructure improvements

- messages/bench_test.go: serialization benchmarks (msgpack map vs struct vs protobuf)
- clients/clients_test.go: HTTP client tests with pooling verification (20 tests)
- natsutil/natsutil_test.go: encode/decode roundtrip + binary data tests
- handler/handler_test.go: handler dispatch tests + benchmark
- config/config.go: live reload via fsnotify + RWMutex getter methods
- clients/clients.go: SharedTransport + sync.Pool buffer pooling
- messages/messages.go: typed structs with msgpack+json tags
- messages/proto/: protobuf schema + generated code

Benchmark baseline (ChatRequest roundtrip):
  MsgpackMap:    2949 ns/op, 36 allocs
  MsgpackStruct: 2030 ns/op, 13 allocs (31% faster, 64% fewer allocs)
  Protobuf:       793 ns/op,  8 allocs (73% faster, 78% fewer allocs)
This commit is contained in:
2026-02-20 06:44:37 -05:00
parent d321c9852b
commit 35912d5844
12 changed files with 4260 additions and 391 deletions

View File

@@ -1,389 +1,437 @@
// Package clients provides HTTP client wrappers for AI/ML backend services.
//
// All clients share a single [http.Transport] for connection pooling across
// the process. Request and response bodies are serialized through pooled
// [bytes.Buffer]s to reduce GC pressure.
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"time"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"sync"
"time"
)
// httpClient is a shared interface for all service clients.
// ─── Shared transport & buffer pool ─────────────────────────────────────────
// SharedTransport is the process-wide HTTP transport used by every service
// client. Tweak pool sizes here rather than creating per-client transports.
var SharedTransport = &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: true, // in-cluster traffic; skip gzip overhead
}
// bufPool recycles *bytes.Buffer to avoid per-request allocations.
var bufPool = sync.Pool{
New: func() any { return new(bytes.Buffer) },
}
func getBuf() *bytes.Buffer {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
return buf
}
func putBuf(buf *bytes.Buffer) {
if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
return
}
bufPool.Put(buf)
}
// ─── httpClient base ────────────────────────────────────────────────────────
// httpClient is the shared base for all service clients.
type httpClient struct {
client *http.Client
baseURL string
client *http.Client
baseURL string
}
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
return &httpClient{
client: &http.Client{Timeout: timeout},
baseURL: baseURL,
}
return &httpClient{
client: &http.Client{
Timeout: timeout,
Transport: SharedTransport,
},
baseURL: baseURL,
}
}
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return h.do(req)
buf := getBuf()
defer putBuf(buf)
if err := json.NewEncoder(buf).Encode(body); err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return h.do(req)
}
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
u := h.baseURL + path
if len(params) > 0 {
u += "?" + params.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, err
}
return h.do(req)
u := h.baseURL + path
if len(params) > 0 {
u += "?" + params.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, err
}
return h.do(req)
}
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
return h.get(ctx, path, params)
return h.get(ctx, path, params)
}
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
var buf bytes.Buffer
w := multipart.NewWriter(&buf)
part, err := w.CreateFormFile(fieldName, fileName)
if err != nil {
return nil, err
}
if _, err := part.Write(fileData); err != nil {
return nil, err
}
for k, v := range fields {
_ = w.WriteField(k, v)
}
_ = w.Close()
buf := getBuf()
defer putBuf(buf)
w := multipart.NewWriter(buf)
part, err := w.CreateFormFile(fieldName, fileName)
if err != nil {
return nil, err
}
if _, err := part.Write(fileData); err != nil {
return nil, err
}
for k, v := range fields {
_ = w.WriteField(k, v)
}
_ = w.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, &buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", w.FormDataContentType())
return h.do(req)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", w.FormDataContentType())
return h.do(req)
}
func (h *httpClient) do(req *http.Request) ([]byte, error) {
resp, err := h.client.Do(req)
if err != nil {
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
}
return body, nil
resp, err := h.client.Do(req)
if err != nil {
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
}
defer resp.Body.Close()
buf := getBuf()
defer putBuf(buf)
if _, err := io.Copy(buf, resp.Body); err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
// Return a copy so the pooled buffer can be safely recycled.
body := make([]byte, buf.Len())
copy(body, buf.Bytes())
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
}
return body, nil
}
func (h *httpClient) healthCheck(ctx context.Context) bool {
data, err := h.get(ctx, "/health", nil)
_ = data
return err == nil
data, err := h.get(ctx, "/health", nil)
_ = data
return err == nil
}
// --- Embeddings Client ---
// ─── Embeddings Client ──────────────────────────────────────────────────────
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
type EmbeddingsClient struct {
*httpClient
Model string
*httpClient
Model string
}
// NewEmbeddingsClient creates an embeddings client.
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
if model == "" {
model = "bge"
}
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
if model == "" {
model = "bge"
}
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
}
// Embed generates embeddings for a list of texts.
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
"input": texts,
"model": c.Model,
})
if err != nil {
return nil, err
}
var resp struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
result := make([][]float64, len(resp.Data))
for i, d := range resp.Data {
result[i] = d.Embedding
}
return result, nil
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
"input": texts,
"model": c.Model,
})
if err != nil {
return nil, err
}
var resp struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
result := make([][]float64, len(resp.Data))
for i, d := range resp.Data {
result[i] = d.Embedding
}
return result, nil
}
// EmbedSingle generates an embedding for a single text.
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
results, err := c.Embed(ctx, []string{text})
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, fmt.Errorf("empty embedding result")
}
return results[0], nil
results, err := c.Embed(ctx, []string{text})
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, fmt.Errorf("empty embedding result")
}
return results[0], nil
}
// Health checks if the embeddings service is healthy.
func (c *EmbeddingsClient) Health(ctx context.Context) bool {
return c.healthCheck(ctx)
return c.healthCheck(ctx)
}
// --- Reranker Client ---
// ─── Reranker Client ────────────────────────────────────────────────────────
// RerankerClient calls the reranker service (BGE Reranker).
type RerankerClient struct {
*httpClient
*httpClient
}
// NewRerankerClient creates a reranker client.
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
}
// RerankResult represents a reranked document.
type RerankResult struct {
Index int `json:"index"`
Score float64 `json:"score"`
Document string `json:"document"`
Index int `json:"index"`
Score float64 `json:"score"`
Document string `json:"document"`
}
// Rerank reranks documents by relevance to the query.
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
payload := map[string]any{
"query": query,
"documents": documents,
}
if topK > 0 {
payload["top_n"] = topK
}
body, err := c.postJSON(ctx, "/rerank", payload)
if err != nil {
return nil, err
}
var resp struct {
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
results := make([]RerankResult, len(resp.Results))
for i, r := range resp.Results {
score := r.RelevanceScore
if score == 0 {
score = r.Score
}
doc := ""
if r.Index < len(documents) {
doc = documents[r.Index]
}
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
}
return results, nil
payload := map[string]any{
"query": query,
"documents": documents,
}
if topK > 0 {
payload["top_n"] = topK
}
body, err := c.postJSON(ctx, "/rerank", payload)
if err != nil {
return nil, err
}
var resp struct {
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
results := make([]RerankResult, len(resp.Results))
for i, r := range resp.Results {
score := r.RelevanceScore
if score == 0 {
score = r.Score
}
doc := ""
if r.Index < len(documents) {
doc = documents[r.Index]
}
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
}
return results, nil
}
// --- LLM Client ---
// ─── LLM Client ─────────────────────────────────────────────────────────────
// LLMClient calls the vLLM-compatible LLM service.
type LLMClient struct {
*httpClient
Model string
MaxTokens int
Temperature float64
TopP float64
*httpClient
Model string
MaxTokens int
Temperature float64
TopP float64
}
// NewLLMClient creates an LLM client.
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
return &LLMClient{
httpClient: newHTTPClient(baseURL, timeout),
Model: "default",
MaxTokens: 2048,
Temperature: 0.7,
TopP: 0.9,
}
return &LLMClient{
httpClient: newHTTPClient(baseURL, timeout),
Model: "default",
MaxTokens: 2048,
Temperature: 0.7,
TopP: 0.9,
}
}
// ChatMessage is an OpenAI-compatible message.
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
}
// Generate sends a chat completion request and returns the response text.
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
messages := buildMessages(prompt, context_, systemPrompt)
payload := map[string]any{
"model": c.Model,
"messages": messages,
"max_tokens": c.MaxTokens,
"temperature": c.Temperature,
"top_p": c.TopP,
}
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
if err != nil {
return "", err
}
var resp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in LLM response")
}
return resp.Choices[0].Message.Content, nil
messages := buildMessages(prompt, context_, systemPrompt)
payload := map[string]any{
"model": c.Model,
"messages": messages,
"max_tokens": c.MaxTokens,
"temperature": c.Temperature,
"top_p": c.TopP,
}
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
if err != nil {
return "", err
}
var resp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in LLM response")
}
return resp.Choices[0].Message.Content, nil
}
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
var msgs []ChatMessage
if systemPrompt != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
} else if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
}
if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
} else {
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
}
return msgs
var msgs []ChatMessage
if systemPrompt != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
} else if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
}
if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
} else {
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
}
return msgs
}
// --- TTS Client ---
// ─── TTS Client ─────────────────────────────────────────────────────────────
// TTSClient calls the TTS service (Coqui XTTS).
type TTSClient struct {
*httpClient
Language string
*httpClient
Language string
}
// NewTTSClient creates a TTS client.
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
if language == "" {
language = "en"
}
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
if language == "" {
language = "en"
}
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
}
// Synthesize generates audio bytes from text.
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
if language == "" {
language = c.Language
}
params := url.Values{
"text": {text},
"language_id": {language},
}
if speaker != "" {
params.Set("speaker_id", speaker)
}
return c.getRaw(ctx, "/api/tts", params)
if language == "" {
language = c.Language
}
params := url.Values{
"text": {text},
"language_id": {language},
}
if speaker != "" {
params.Set("speaker_id", speaker)
}
return c.getRaw(ctx, "/api/tts", params)
}
// --- STT Client ---
// ─── STT Client ─────────────────────────────────────────────────────────────
// STTClient calls the Whisper STT service.
type STTClient struct {
*httpClient
Language string
Task string
*httpClient
Language string
Task string
}
// NewSTTClient creates an STT client.
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
}
// TranscribeResult holds transcription output.
type TranscribeResult struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
Text string `json:"text"`
Language string `json:"language,omitempty"`
}
// Transcribe sends audio to Whisper and returns the transcription.
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
if language == "" {
language = c.Language
}
fields := map[string]string{
"response_format": "json",
}
if language != "" {
fields["language"] = language
}
endpoint := "/v1/audio/transcriptions"
if c.Task == "translate" {
endpoint = "/v1/audio/translations"
}
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
if err != nil {
return nil, err
}
var result TranscribeResult
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return &result, nil
if language == "" {
language = c.Language
}
fields := map[string]string{
"response_format": "json",
}
if language != "" {
fields["language"] = language
}
endpoint := "/v1/audio/transcriptions"
if c.Task == "translate" {
endpoint = "/v1/audio/translations"
}
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
if err != nil {
return nil, err
}
var result TranscribeResult
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return &result, nil
}
// --- Milvus Client ---
// ─── Milvus Client ──────────────────────────────────────────────────────────
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
// For the Go port we use the Milvus Go SDK.
type MilvusClient struct {
Host string
Port int
Collection string
connected bool
Host string
Port int
Collection string
connected bool
}
// NewMilvusClient creates a Milvus client.
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
return &MilvusClient{Host: host, Port: port, Collection: collection}
return &MilvusClient{Host: host, Port: port, Collection: collection}
}
// SearchResult holds a single vector search hit.
type SearchResult struct {
ID int64 `json:"id"`
Distance float64 `json:"distance"`
Score float64 `json:"score"`
Fields map[string]any `json:"fields,omitempty"`
ID int64 `json:"id"`
Distance float64 `json:"distance"`
Score float64 `json:"score"`
Fields map[string]any `json:"fields,omitempty"`
}

506
clients/clients_test.go Normal file
View File

@@ -0,0 +1,506 @@
package clients
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// ────────────────────────────────────────────────────────────────────────────
// Shared infrastructure tests
// ────────────────────────────────────────────────────────────────────────────
func TestSharedTransport(t *testing.T) {
// All clients created via newHTTPClient should share the same transport.
c1 := newHTTPClient("http://a:8000", 10*time.Second)
c2 := newHTTPClient("http://b:9000", 30*time.Second)
if c1.client.Transport != c2.client.Transport {
t.Error("clients should share the same http.Transport")
}
if c1.client.Transport != SharedTransport {
t.Error("transport should be the package-level SharedTransport")
}
}
func TestBufferPoolGetPut(t *testing.T) {
buf := getBuf()
if buf == nil {
t.Fatal("getBuf returned nil")
}
if buf.Len() != 0 {
t.Error("getBuf should return a reset buffer")
}
buf.WriteString("hello")
putBuf(buf)
// On re-get, buffer should be reset.
buf2 := getBuf()
if buf2.Len() != 0 {
t.Error("re-acquired buffer should be reset")
}
putBuf(buf2)
}
func TestBufferPoolOversizedDiscarded(t *testing.T) {
buf := getBuf()
// Grow beyond 1 MB threshold.
buf.Write(make([]byte, 2<<20))
putBuf(buf) // should silently discard
// Pool should still work — we get a fresh one.
buf2 := getBuf()
if buf2.Len() != 0 {
t.Error("should get a fresh buffer")
}
putBuf(buf2)
}
func TestBufferPoolConcurrency(t *testing.T) {
var wg sync.WaitGroup
for i := range 100 {
wg.Add(1)
go func(n int) {
defer wg.Done()
buf := getBuf()
buf.WriteString(strings.Repeat("x", n))
putBuf(buf)
}(i)
}
wg.Wait()
}
// ────────────────────────────────────────────────────────────────────────────
// Embeddings client
// ────────────────────────────────────────────────────────────────────────────
func TestEmbeddingsClient_Embed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/embeddings" {
t.Errorf("path = %q, want /embeddings", r.URL.Path)
}
if r.Method != http.MethodPost {
t.Errorf("method = %s, want POST", r.Method)
}
var req map[string]any
json.NewDecoder(r.Body).Decode(&req)
input, _ := req["input"].([]any)
if len(input) != 2 {
t.Errorf("input len = %d, want 2", len(input))
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"data": []map[string]any{
{"embedding": []float64{0.1, 0.2, 0.3}},
{"embedding": []float64{0.4, 0.5, 0.6}},
},
})
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge")
results, err := c.Embed(context.Background(), []string{"hello", "world"})
if err != nil {
t.Fatal(err)
}
if len(results) != 2 {
t.Fatalf("len(results) = %d, want 2", len(results))
}
if results[0][0] != 0.1 {
t.Errorf("results[0][0] = %f, want 0.1", results[0][0])
}
}
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"data": []map[string]any{
{"embedding": []float64{1.0, 2.0}},
},
})
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
vec, err := c.EmbedSingle(context.Background(), "test")
if err != nil {
t.Fatal(err)
}
if len(vec) != 2 || vec[0] != 1.0 {
t.Errorf("vec = %v", vec)
}
}
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
_, err := c.EmbedSingle(context.Background(), "test")
if err == nil {
t.Error("expected error for empty embedding")
}
}
func TestEmbeddingsClient_Health(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/health" {
w.WriteHeader(200)
return
}
w.WriteHeader(404)
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
if !c.Health(context.Background()) {
t.Error("expected healthy")
}
}
// ────────────────────────────────────────────────────────────────────────────
// Reranker client
// ────────────────────────────────────────────────────────────────────────────
func TestRerankerClient_Rerank(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
json.NewDecoder(r.Body).Decode(&req)
if req["query"] != "test query" {
t.Errorf("query = %v", req["query"])
}
json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"index": 1, "relevance_score": 0.95},
{"index": 0, "relevance_score": 0.80},
},
})
}))
defer ts.Close()
c := NewRerankerClient(ts.URL, 5*time.Second)
docs := []string{"Paris is great", "France is in Europe"}
results, err := c.Rerank(context.Background(), "test query", docs, 2)
if err != nil {
t.Fatal(err)
}
if len(results) != 2 {
t.Fatalf("len = %d", len(results))
}
if results[0].Score != 0.95 {
t.Errorf("score = %f, want 0.95", results[0].Score)
}
if results[0].Document != "France is in Europe" {
t.Errorf("document = %q", results[0].Document)
}
}
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
},
})
}))
defer ts.Close()
c := NewRerankerClient(ts.URL, 5*time.Second)
results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0)
if err != nil {
t.Fatal(err)
}
if results[0].Score != 0.77 {
t.Errorf("fallback score = %f, want 0.77", results[0].Score)
}
}
// ────────────────────────────────────────────────────────────────────────────
// LLM client
// ────────────────────────────────────────────────────────────────────────────
func TestLLMClient_Generate(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/chat/completions" {
t.Errorf("path = %q", r.URL.Path)
}
var req map[string]any
json.NewDecoder(r.Body).Decode(&req)
msgs, _ := req["messages"].([]any)
if len(msgs) == 0 {
t.Error("no messages in request")
}
json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "Paris is the capital of France."}},
},
})
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.Generate(context.Background(), "capital of France?", "", "")
if err != nil {
t.Fatal(err)
}
if result != "Paris is the capital of France." {
t.Errorf("result = %q", result)
}
}
func TestLLMClient_GenerateWithContext(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
json.NewDecoder(r.Body).Decode(&req)
msgs, _ := req["messages"].([]any)
// Should have system + user message
if len(msgs) != 2 {
t.Errorf("expected 2 messages, got %d", len(msgs))
}
json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "answer with context"}},
},
})
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.Generate(context.Background(), "question", "some context", "")
if err != nil {
t.Fatal(err)
}
if result != "answer with context" {
t.Errorf("result = %q", result)
}
}
func TestLLMClient_GenerateNoChoices(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
_, err := c.Generate(context.Background(), "q", "", "")
if err == nil {
t.Error("expected error for empty choices")
}
}
// ────────────────────────────────────────────────────────────────────────────
// TTS client
// ────────────────────────────────────────────────────────────────────────────
func TestTTSClient_Synthesize(t *testing.T) {
expected := []byte{0xDE, 0xAD, 0xBE, 0xEF}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tts" {
t.Errorf("path = %q", r.URL.Path)
}
if r.URL.Query().Get("text") != "hello world" {
t.Errorf("text = %q", r.URL.Query().Get("text"))
}
w.Write(expected)
}))
defer ts.Close()
c := NewTTSClient(ts.URL, 5*time.Second, "en")
audio, err := c.Synthesize(context.Background(), "hello world", "", "")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(audio, expected) {
t.Errorf("audio = %x, want %x", audio, expected)
}
}
func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("speaker_id") != "alice" {
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
}
w.Write([]byte{0x01})
}))
defer ts.Close()
c := NewTTSClient(ts.URL, 5*time.Second, "en")
_, err := c.Synthesize(context.Background(), "hi", "en", "alice")
if err != nil {
t.Fatal(err)
}
}
// ────────────────────────────────────────────────────────────────────────────
// STT client
// ────────────────────────────────────────────────────────────────────────────
func TestSTTClient_Transcribe(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/transcriptions" {
t.Errorf("path = %q", r.URL.Path)
}
ct := r.Header.Get("Content-Type")
if !strings.Contains(ct, "multipart/form-data") {
t.Errorf("content-type = %q", ct)
}
// Verify the audio file is present.
file, _, err := r.FormFile("file")
if err != nil {
t.Fatal(err)
}
data, _ := io.ReadAll(file)
if len(data) != 100 {
t.Errorf("file size = %d, want 100", len(data))
}
json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
}))
defer ts.Close()
c := NewSTTClient(ts.URL, 5*time.Second)
result, err := c.Transcribe(context.Background(), make([]byte, 100), "en")
if err != nil {
t.Fatal(err)
}
if result.Text != "hello world" {
t.Errorf("text = %q", result.Text)
}
}
func TestSTTClient_TranscribeTranslate(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/translations" {
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
}))
defer ts.Close()
c := NewSTTClient(ts.URL, 5*time.Second)
c.Task = "translate"
result, err := c.Transcribe(context.Background(), []byte{0x01}, "")
if err != nil {
t.Fatal(err)
}
if result.Text != "translated" {
t.Errorf("text = %q", result.Text)
}
}
// ────────────────────────────────────────────────────────────────────────────
// HTTP error handling
// ────────────────────────────────────────────────────────────────────────────
func TestHTTPError4xx(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
w.Write([]byte(`{"error": "bad input"}`))
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
_, err := c.Embed(context.Background(), []string{"test"})
if err == nil {
t.Fatal("expected error for 422")
}
if !strings.Contains(err.Error(), "422") {
t.Errorf("error should contain status code: %v", err)
}
}
func TestHTTPError5xx(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
w.Write([]byte("internal server error"))
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
_, err := c.Generate(context.Background(), "q", "", "")
if err == nil {
t.Fatal("expected error for 500")
}
}
// ────────────────────────────────────────────────────────────────────────────
// buildMessages helper
// ────────────────────────────────────────────────────────────────────────────
func TestBuildMessages(t *testing.T) {
// No context, no system prompt → just user message
msgs := buildMessages("hello", "", "")
if len(msgs) != 1 || msgs[0].Role != "user" {
t.Errorf("expected 1 user msg, got %+v", msgs)
}
// With system prompt
msgs = buildMessages("hello", "", "You are helpful")
if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" {
t.Errorf("expected system+user, got %+v", msgs)
}
// With context, no system prompt → auto system prompt
msgs = buildMessages("question", "some context", "")
if len(msgs) != 2 || msgs[0].Role != "system" {
t.Errorf("expected auto system+user, got %+v", msgs)
}
if !strings.Contains(msgs[1].Content, "Context:") {
t.Error("user message should contain context")
}
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmarks: pooled buffer vs direct allocation
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkPostJSON(b *testing.B) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(io.Discard, r.Body)
w.Write([]byte(`{"ok":true}`))
}))
defer ts.Close()
c := newHTTPClient(ts.URL, 10*time.Second)
ctx := context.Background()
payload := map[string]any{
"text": strings.Repeat("x", 1024),
"count": 42,
"enabled": true,
}
b.ResetTimer()
for b.Loop() {
c.postJSON(ctx, "/test", payload)
}
}
func BenchmarkBufferPool(b *testing.B) {
b.ResetTimer()
for b.Loop() {
buf := getBuf()
buf.WriteString(strings.Repeat("x", 4096))
putBuf(buf)
}
}
func BenchmarkBufferPoolParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
buf := getBuf()
buf.WriteString(strings.Repeat("x", 4096))
putBuf(buf)
}
})
}