All checks were successful
CI / Test (push) Successful in 3m2s
CI / Lint (push) Successful in 3m7s
CI / Release (push) Successful in 1m55s
CI / Notify Downstream (stt-module) (push) Successful in 1s
CI / Notify Downstream (voice-assistant) (push) Successful in 1s
CI / Notify (push) Successful in 2s
CI / Notify Downstream (chat-handler) (push) Successful in 1s
CI / Notify Downstream (pipeline-bridge) (push) Successful in 1s
CI / Notify Downstream (tts-module) (push) Successful in 1s
795 lines
25 KiB
Go
795 lines
25 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// Shared infrastructure tests
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestSharedTransport(t *testing.T) {
|
|
// All clients created via newHTTPClient should share the same transport.
|
|
c1 := newHTTPClient("http://a:8000", 10*time.Second)
|
|
c2 := newHTTPClient("http://b:9000", 30*time.Second)
|
|
|
|
if c1.client.Transport != c2.client.Transport {
|
|
t.Error("clients should share the same http.Transport")
|
|
}
|
|
if c1.client.Transport != SharedTransport {
|
|
t.Error("transport should be the package-level SharedTransport")
|
|
}
|
|
}
|
|
|
|
func TestBufferPoolGetPut(t *testing.T) {
|
|
buf := getBuf()
|
|
if buf == nil {
|
|
t.Fatal("getBuf returned nil")
|
|
}
|
|
if buf.Len() != 0 {
|
|
t.Error("getBuf should return a reset buffer")
|
|
}
|
|
buf.WriteString("hello")
|
|
putBuf(buf)
|
|
|
|
// On re-get, buffer should be reset.
|
|
buf2 := getBuf()
|
|
if buf2.Len() != 0 {
|
|
t.Error("re-acquired buffer should be reset")
|
|
}
|
|
putBuf(buf2)
|
|
}
|
|
|
|
func TestBufferPoolOversizedDiscarded(t *testing.T) {
|
|
buf := getBuf()
|
|
// Grow beyond 1 MB threshold.
|
|
buf.Write(make([]byte, 2<<20))
|
|
putBuf(buf) // should silently discard
|
|
|
|
// Pool should still work — we get a fresh one.
|
|
buf2 := getBuf()
|
|
if buf2.Len() != 0 {
|
|
t.Error("should get a fresh buffer")
|
|
}
|
|
putBuf(buf2)
|
|
}
|
|
|
|
func TestBufferPoolConcurrency(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
for i := range 100 {
|
|
wg.Add(1)
|
|
go func(n int) {
|
|
defer wg.Done()
|
|
buf := getBuf()
|
|
buf.WriteString(strings.Repeat("x", n))
|
|
putBuf(buf)
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// Embeddings client
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestEmbeddingsClient_Embed(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/embeddings" {
|
|
t.Errorf("path = %q, want /embeddings", r.URL.Path)
|
|
}
|
|
if r.Method != http.MethodPost {
|
|
t.Errorf("method = %s, want POST", r.Method)
|
|
}
|
|
var req map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
input, _ := req["input"].([]any)
|
|
if len(input) != 2 {
|
|
t.Errorf("input len = %d, want 2", len(input))
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"data": []map[string]any{
|
|
{"embedding": []float64{0.1, 0.2, 0.3}},
|
|
{"embedding": []float64{0.4, 0.5, 0.6}},
|
|
},
|
|
})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge")
|
|
results, err := c.Embed(context.Background(), []string{"hello", "world"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(results) != 2 {
|
|
t.Fatalf("len(results) = %d, want 2", len(results))
|
|
}
|
|
if results[0][0] != 0.1 {
|
|
t.Errorf("results[0][0] = %f, want 0.1", results[0][0])
|
|
}
|
|
}
|
|
|
|
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"data": []map[string]any{
|
|
{"embedding": []float64{1.0, 2.0}},
|
|
},
|
|
})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
|
vec, err := c.EmbedSingle(context.Background(), "test")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(vec) != 2 || vec[0] != 1.0 {
|
|
t.Errorf("vec = %v", vec)
|
|
}
|
|
}
|
|
|
|
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
|
_, err := c.EmbedSingle(context.Background(), "test")
|
|
if err == nil {
|
|
t.Error("expected error for empty embedding")
|
|
}
|
|
}
|
|
|
|
func TestEmbeddingsClient_Health(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
w.WriteHeader(200)
|
|
return
|
|
}
|
|
w.WriteHeader(404)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
|
if !c.Health(context.Background()) {
|
|
t.Error("expected healthy")
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// Reranker client
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestRerankerClient_Rerank(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
if req["query"] != "test query" {
|
|
t.Errorf("query = %v", req["query"])
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"results": []map[string]any{
|
|
{"index": 1, "relevance_score": 0.95},
|
|
{"index": 0, "relevance_score": 0.80},
|
|
},
|
|
})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewRerankerClient(ts.URL, 5*time.Second)
|
|
docs := []string{"Paris is great", "France is in Europe"}
|
|
results, err := c.Rerank(context.Background(), "test query", docs, 2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(results) != 2 {
|
|
t.Fatalf("len = %d", len(results))
|
|
}
|
|
if results[0].Score != 0.95 {
|
|
t.Errorf("score = %f, want 0.95", results[0].Score)
|
|
}
|
|
if results[0].Document != "France is in Europe" {
|
|
t.Errorf("document = %q", results[0].Document)
|
|
}
|
|
}
|
|
|
|
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"results": []map[string]any{
|
|
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
|
|
},
|
|
})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewRerankerClient(ts.URL, 5*time.Second)
|
|
results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if results[0].Score != 0.77 {
|
|
t.Errorf("fallback score = %f, want 0.77", results[0].Score)
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// LLM client
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestLLMClient_Generate(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/chat/completions" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
var req map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
msgs, _ := req["messages"].([]any)
|
|
if len(msgs) == 0 {
|
|
t.Error("no messages in request")
|
|
}
|
|
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"content": "Paris is the capital of France."}},
|
|
},
|
|
})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
result, err := c.Generate(context.Background(), "capital of France?", "", "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "Paris is the capital of France." {
|
|
t.Errorf("result = %q", result)
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_GenerateWithContext(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
msgs, _ := req["messages"].([]any)
|
|
// Should have system + user message
|
|
if len(msgs) != 2 {
|
|
t.Errorf("expected 2 messages, got %d", len(msgs))
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{
|
|
{"message": map[string]any{"content": "answer with context"}},
|
|
},
|
|
})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
result, err := c.Generate(context.Background(), "question", "some context", "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "answer with context" {
|
|
t.Errorf("result = %q", result)
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
_, err := c.Generate(context.Background(), "q", "", "")
|
|
if err == nil {
|
|
t.Error("expected error for empty choices")
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// LLM client — StreamGenerate
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
// sseChunk renders a single OpenAI-compatible chat.completion.chunk as an
// SSE "data:" event carrying the given delta content.
func sseChunk(content string) string {
	payload, _ := json.Marshal(map[string]any{
		"choices": []map[string]any{
			{"delta": map[string]any{"content": content}},
		},
	})
	return "data: " + string(payload) + "\n\n"
}
|
|
|
|
func TestLLMClient_StreamGenerate(t *testing.T) {
|
|
tokens := []string{"Hello", " world", "!"}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/chat/completions" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
var req map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
if req["stream"] != true {
|
|
t.Errorf("stream = %v, want true", req["stream"])
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
for _, tok := range tokens {
|
|
_, _ = w.Write([]byte(sseChunk(tok)))
|
|
flusher.Flush()
|
|
}
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
var received []string
|
|
result, err := c.StreamGenerate(context.Background(), "hi", "", "", func(tok string) {
|
|
received = append(received, tok)
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "Hello world!" {
|
|
t.Errorf("result = %q, want %q", result, "Hello world!")
|
|
}
|
|
if len(received) != 3 {
|
|
t.Fatalf("callback count = %d, want 3", len(received))
|
|
}
|
|
if received[0] != "Hello" || received[1] != " world" || received[2] != "!" {
|
|
t.Errorf("received = %v", received)
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var req map[string]any
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
msgs, _ := req["messages"].([]any)
|
|
if len(msgs) != 2 {
|
|
t.Errorf("expected system+user, got %d messages", len(msgs))
|
|
}
|
|
first, _ := msgs[0].(map[string]any)
|
|
if first["role"] != "system" || first["content"] != "You are a DM" {
|
|
t.Errorf("system msg = %v", first)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
_, _ = w.Write([]byte(sseChunk("ok")))
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "ok" {
|
|
t.Errorf("result = %q", result)
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_StreamGenerateNilCallback(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
_, _ = w.Write([]byte(sseChunk("token")))
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
// nil callback should not panic
|
|
result, err := c.StreamGenerate(context.Background(), "hi", "", "", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "token" {
|
|
t.Errorf("result = %q", result)
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) {
|
|
// SSE chunks with empty content should be silently skipped.
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
// role-only chunk (no content) — common for first chunk from vLLM
|
|
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"role\":\"assistant\"}}]}\n\n"))
|
|
// empty content string
|
|
_, _ = w.Write([]byte(sseChunk("")))
|
|
// real token
|
|
_, _ = w.Write([]byte(sseChunk("hello")))
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
var count int
|
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
|
count++
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "hello" {
|
|
t.Errorf("result = %q", result)
|
|
}
|
|
if count != 1 {
|
|
t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", count)
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) {
|
|
// Malformed JSON should be skipped without error.
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
_, _ = w.Write([]byte("data: {invalid json}\n\n"))
|
|
_, _ = w.Write([]byte("data: {\"choices\":[]}\n\n")) // empty choices
|
|
_, _ = w.Write([]byte(sseChunk("good")))
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "good" {
|
|
t.Errorf("result = %q, want %q", result, "good")
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_StreamGenerateHTTPError(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(500)
|
|
_, _ = w.Write([]byte("internal server error"))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
_, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for 500")
|
|
}
|
|
if !strings.Contains(err.Error(), "500") {
|
|
t.Errorf("error should contain 500: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestLLMClient_StreamGenerateContextCanceled verifies that canceling the
// request context mid-stream makes StreamGenerate return an error instead
// of hanging or reporting success with partial text.
func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) {
	// Closed by the handler once all tokens have been flushed, so the test
	// only cancels after the client has started receiving data.
	started := make(chan struct{})
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		// Send several tokens so the client receives some before cancel.
		for i := range 20 {
			_, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i))))
			flusher.Flush()
		}
		close(started)
		// Block until client cancels
		<-r.Context().Done()
	}))
	defer ts.Close()

	ctx, cancel := context.WithCancel(context.Background())
	c := NewLLMClient(ts.URL, 10*time.Second)

	var streamErr error
	done := make(chan struct{})
	go func() {
		defer close(done)
		// Run the stream on its own goroutine so the test goroutine can
		// cancel while the client is blocked reading the response body.
		_, streamErr = c.StreamGenerate(ctx, "q", "", "", nil)
	}()

	<-started
	cancel()
	<-done

	// After cancel the stream should return an error (context canceled or
	// stream interrupted). The exact partial text depends on timing.
	if streamErr == nil {
		t.Error("expected error after context cancel")
	}
}
|
|
|
|
func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) {
|
|
// Lines without "data: " prefix should be silently ignored (comments, blank lines, event IDs).
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
_, _ = w.Write([]byte(": this is an SSE comment\n\n"))
|
|
_, _ = w.Write([]byte("event: message\n"))
|
|
_, _ = w.Write([]byte(sseChunk("word")))
|
|
_, _ = w.Write([]byte("\n")) // blank line
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result != "word" {
|
|
t.Errorf("result = %q, want %q", result, "word")
|
|
}
|
|
}
|
|
|
|
func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
|
|
// Verify token ordering and full assembly with many chunks.
|
|
n := 100
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
flusher, _ := w.(http.Flusher)
|
|
for i := range n {
|
|
tok := fmt.Sprintf("t%d ", i)
|
|
_, _ = w.Write([]byte(sseChunk(tok)))
|
|
flusher.Flush()
|
|
}
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
flusher.Flush()
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
var mu sync.Mutex
|
|
var order []int
|
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
|
var idx int
|
|
_, _ = fmt.Sscanf(tok, "t%d ", &idx)
|
|
mu.Lock()
|
|
order = append(order, idx)
|
|
mu.Unlock()
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// Verify all tokens arrived in order
|
|
if len(order) != n {
|
|
t.Fatalf("got %d tokens, want %d", len(order), n)
|
|
}
|
|
for i, v := range order {
|
|
if v != i {
|
|
t.Errorf("order[%d] = %d", i, v)
|
|
break
|
|
}
|
|
}
|
|
// Quick sanity: result should start with "t0 " and end with last token
|
|
if !strings.HasPrefix(result, "t0 ") {
|
|
t.Errorf("result prefix = %q", result[:10])
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// TTS client
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestTTSClient_Synthesize(t *testing.T) {
|
|
expected := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/api/tts" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
if r.URL.Query().Get("text") != "hello world" {
|
|
t.Errorf("text = %q", r.URL.Query().Get("text"))
|
|
}
|
|
_, _ = w.Write(expected)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewTTSClient(ts.URL, 5*time.Second, "en")
|
|
audio, err := c.Synthesize(context.Background(), "hello world", "", "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !bytes.Equal(audio, expected) {
|
|
t.Errorf("audio = %x, want %x", audio, expected)
|
|
}
|
|
}
|
|
|
|
func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("speaker_id") != "alice" {
|
|
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
|
|
}
|
|
_, _ = w.Write([]byte{0x01})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewTTSClient(ts.URL, 5*time.Second, "en")
|
|
_, err := c.Synthesize(context.Background(), "hi", "en", "alice")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// STT client
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestSTTClient_Transcribe(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/audio/transcriptions" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
ct := r.Header.Get("Content-Type")
|
|
if !strings.Contains(ct, "multipart/form-data") {
|
|
t.Errorf("content-type = %q", ct)
|
|
}
|
|
// Verify the audio file is present.
|
|
file, _, err := r.FormFile("file")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
data, _ := io.ReadAll(file)
|
|
if len(data) != 100 {
|
|
t.Errorf("file size = %d, want 100", len(data))
|
|
}
|
|
|
|
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewSTTClient(ts.URL, 5*time.Second)
|
|
result, err := c.Transcribe(context.Background(), make([]byte, 100), "en")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result.Text != "hello world" {
|
|
t.Errorf("text = %q", result.Text)
|
|
}
|
|
}
|
|
|
|
func TestSTTClient_TranscribeTranslate(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/audio/translations" {
|
|
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
|
|
}
|
|
_ = json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewSTTClient(ts.URL, 5*time.Second)
|
|
c.Task = "translate"
|
|
result, err := c.Transcribe(context.Background(), []byte{0x01}, "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if result.Text != "translated" {
|
|
t.Errorf("text = %q", result.Text)
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// HTTP error handling
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestHTTPError4xx(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(422)
|
|
_, _ = w.Write([]byte(`{"error": "bad input"}`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
|
_, err := c.Embed(context.Background(), []string{"test"})
|
|
if err == nil {
|
|
t.Fatal("expected error for 422")
|
|
}
|
|
if !strings.Contains(err.Error(), "422") {
|
|
t.Errorf("error should contain status code: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHTTPError5xx(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(500)
|
|
_, _ = w.Write([]byte("internal server error"))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
|
_, err := c.Generate(context.Background(), "q", "", "")
|
|
if err == nil {
|
|
t.Fatal("expected error for 500")
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// buildMessages helper
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func TestBuildMessages(t *testing.T) {
|
|
// No context, no system prompt → just user message
|
|
msgs := buildMessages("hello", "", "")
|
|
if len(msgs) != 1 || msgs[0].Role != "user" {
|
|
t.Errorf("expected 1 user msg, got %+v", msgs)
|
|
}
|
|
|
|
// With system prompt
|
|
msgs = buildMessages("hello", "", "You are helpful")
|
|
if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" {
|
|
t.Errorf("expected system+user, got %+v", msgs)
|
|
}
|
|
|
|
// With context, no system prompt → auto system prompt
|
|
msgs = buildMessages("question", "some context", "")
|
|
if len(msgs) != 2 || msgs[0].Role != "system" {
|
|
t.Errorf("expected auto system+user, got %+v", msgs)
|
|
}
|
|
if !strings.Contains(msgs[1].Content, "Context:") {
|
|
t.Error("user message should contain context")
|
|
}
|
|
}
|
|
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
// Benchmarks: pooled buffer vs direct allocation
|
|
// ────────────────────────────────────────────────────────────────────────────
|
|
|
|
func BenchmarkPostJSON(b *testing.B) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, _ = io.Copy(io.Discard, r.Body)
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
c := newHTTPClient(ts.URL, 10*time.Second)
|
|
ctx := context.Background()
|
|
payload := map[string]any{
|
|
"text": strings.Repeat("x", 1024),
|
|
"count": 42,
|
|
"enabled": true,
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for b.Loop() {
|
|
_, _ = c.postJSON(ctx, "/test", payload)
|
|
}
|
|
}
|
|
|
|
func BenchmarkBufferPool(b *testing.B) {
|
|
b.ResetTimer()
|
|
for b.Loop() {
|
|
buf := getBuf()
|
|
buf.WriteString(strings.Repeat("x", 4096))
|
|
putBuf(buf)
|
|
}
|
|
}
|
|
|
|
func BenchmarkBufferPoolParallel(b *testing.B) {
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
for pb.Next() {
|
|
buf := getBuf()
|
|
buf.WriteString(strings.Repeat("x", 4096))
|
|
putBuf(buf)
|
|
}
|
|
})
|
|
}
|