feat: add e2e tests, perf benchmarks, and infrastructure improvements

- messages/bench_test.go: serialization benchmarks (msgpack map vs struct vs protobuf)
- clients/clients_test.go: HTTP client tests with pooling verification (20 tests)
- natsutil/natsutil_test.go: encode/decode roundtrip + binary data tests
- handler/handler_test.go: handler dispatch tests + benchmark
- config/config.go: live reload via fsnotify + RWMutex getter methods
- clients/clients.go: SharedTransport + sync.Pool buffer pooling
- messages/messages.go: typed structs with msgpack+json tags
- messages/proto/: protobuf schema + generated code

Benchmark baseline (ChatRequest roundtrip):
  MsgpackMap:    2949 ns/op, 36 allocs
  MsgpackStruct: 2030 ns/op, 13 allocs (31% faster, 64% fewer allocs)
  Protobuf:       793 ns/op,  8 allocs (73% faster, 78% fewer allocs)
This commit is contained in:
2026-02-20 06:44:37 -05:00
parent d321c9852b
commit 35912d5844
12 changed files with 4260 additions and 391 deletions

View File

@@ -1,389 +1,437 @@
// Package clients provides HTTP client wrappers for AI/ML backend services. // Package clients provides HTTP client wrappers for AI/ML backend services.
//
// All clients share a single [http.Transport] for connection pooling across
// the process. Request and response bodies are serialized through pooled
// [bytes.Buffer]s to reduce GC pressure.
package clients package clients
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/url" "net/url"
"time" "sync"
"time"
) )
// httpClient is a shared interface for all service clients. // ─── Shared transport & buffer pool ─────────────────────────────────────────
// SharedTransport is the process-wide HTTP transport used by every service
// client. Tweak pool sizes here rather than creating per-client transports.
//
// It is exported and mutable, so any change made to it affects every
// http.Client that was built with it.
var SharedTransport = &http.Transport{
MaxIdleConns: 100, // upper bound on idle keep-alive connections across all hosts
MaxIdleConnsPerHost: 10, // upper bound on idle keep-alive connections per host
IdleConnTimeout: 90 * time.Second, // how long an idle connection is kept before being closed
DisableCompression: true, // in-cluster traffic; skip gzip overhead
}
// bufPool hands out reusable *bytes.Buffer values so request and response
// bodies do not need a fresh buffer on every call.
var bufPool = sync.Pool{
	New: func() any { return &bytes.Buffer{} },
}

// getBuf takes a buffer from bufPool and returns it emptied. The buffer keeps
// whatever capacity it had when it was last returned to the pool.
func getBuf() *bytes.Buffer {
	b := bufPool.Get().(*bytes.Buffer)
	b.Reset()
	return b
}

// putBuf hands buf back to bufPool for reuse. Buffers whose capacity has grown
// past 1 MiB are dropped instead, so a single large payload does not pin that
// memory inside the pool.
func putBuf(buf *bytes.Buffer) {
	if buf.Cap() <= 1<<20 {
		bufPool.Put(buf)
	}
}
// ─── httpClient base ────────────────────────────────────────────────────────
// httpClient is the shared base for all service clients.
type httpClient struct { type httpClient struct {
client *http.Client client *http.Client
baseURL string baseURL string
} }
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient { func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
return &httpClient{ return &httpClient{
client: &http.Client{Timeout: timeout}, client: &http.Client{
baseURL: baseURL, Timeout: timeout,
} Transport: SharedTransport,
},
baseURL: baseURL,
}
} }
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) { func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
data, err := json.Marshal(body) buf := getBuf()
if err != nil { defer putBuf(buf)
return nil, fmt.Errorf("marshal: %w", err) if err := json.NewEncoder(buf).Encode(body); err != nil {
} return nil, fmt.Errorf("marshal: %w", err)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(data)) }
if err != nil { req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
return nil, err if err != nil {
} return nil, err
req.Header.Set("Content-Type", "application/json") }
return h.do(req) req.Header.Set("Content-Type", "application/json")
return h.do(req)
} }
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) { func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
u := h.baseURL + path u := h.baseURL + path
if len(params) > 0 { if len(params) > 0 {
u += "?" + params.Encode() u += "?" + params.Encode()
} }
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return h.do(req) return h.do(req)
} }
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) { func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
return h.get(ctx, path, params) return h.get(ctx, path, params)
} }
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) { func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
var buf bytes.Buffer buf := getBuf()
w := multipart.NewWriter(&buf) defer putBuf(buf)
part, err := w.CreateFormFile(fieldName, fileName) w := multipart.NewWriter(buf)
if err != nil { part, err := w.CreateFormFile(fieldName, fileName)
return nil, err if err != nil {
} return nil, err
if _, err := part.Write(fileData); err != nil { }
return nil, err if _, err := part.Write(fileData); err != nil {
} return nil, err
for k, v := range fields { }
_ = w.WriteField(k, v) for k, v := range fields {
} _ = w.WriteField(k, v)
_ = w.Close() }
_ = w.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, &buf) req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
return h.do(req) return h.do(req)
} }
func (h *httpClient) do(req *http.Request) ([]byte, error) { func (h *httpClient) do(req *http.Request) ([]byte, error) {
resp, err := h.client.Do(req) resp, err := h.client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err) return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil { buf := getBuf()
return nil, fmt.Errorf("read body: %w", err) defer putBuf(buf)
} if _, err := io.Copy(buf, resp.Body); err != nil {
if resp.StatusCode >= 400 { return nil, fmt.Errorf("read body: %w", err)
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body)) }
}
return body, nil // Return a copy so the pooled buffer can be safely recycled.
body := make([]byte, buf.Len())
copy(body, buf.Bytes())
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
}
return body, nil
} }
func (h *httpClient) healthCheck(ctx context.Context) bool { func (h *httpClient) healthCheck(ctx context.Context) bool {
data, err := h.get(ctx, "/health", nil) data, err := h.get(ctx, "/health", nil)
_ = data _ = data
return err == nil return err == nil
} }
// --- Embeddings Client --- // ─── Embeddings Client ──────────────────────────────────────────────────────
// EmbeddingsClient calls the embeddings service (Infinity/BGE). // EmbeddingsClient calls the embeddings service (Infinity/BGE).
type EmbeddingsClient struct { type EmbeddingsClient struct {
*httpClient *httpClient
Model string Model string
} }
// NewEmbeddingsClient creates an embeddings client. // NewEmbeddingsClient creates an embeddings client.
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient { func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
if model == "" { if model == "" {
model = "bge" model = "bge"
} }
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model} return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
} }
// Embed generates embeddings for a list of texts. // Embed generates embeddings for a list of texts.
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) { func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
body, err := c.postJSON(ctx, "/embeddings", map[string]any{ body, err := c.postJSON(ctx, "/embeddings", map[string]any{
"input": texts, "input": texts,
"model": c.Model, "model": c.Model,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
var resp struct { var resp struct {
Data []struct { Data []struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} `json:"data"` } `json:"data"`
} }
if err := json.Unmarshal(body, &resp); err != nil { if err := json.Unmarshal(body, &resp); err != nil {
return nil, err return nil, err
} }
result := make([][]float64, len(resp.Data)) result := make([][]float64, len(resp.Data))
for i, d := range resp.Data { for i, d := range resp.Data {
result[i] = d.Embedding result[i] = d.Embedding
} }
return result, nil return result, nil
} }
// EmbedSingle generates an embedding for a single text. // EmbedSingle generates an embedding for a single text.
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) { func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
results, err := c.Embed(ctx, []string{text}) results, err := c.Embed(ctx, []string{text})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(results) == 0 { if len(results) == 0 {
return nil, fmt.Errorf("empty embedding result") return nil, fmt.Errorf("empty embedding result")
} }
return results[0], nil return results[0], nil
} }
// Health checks if the embeddings service is healthy. // Health checks if the embeddings service is healthy.
func (c *EmbeddingsClient) Health(ctx context.Context) bool { func (c *EmbeddingsClient) Health(ctx context.Context) bool {
return c.healthCheck(ctx) return c.healthCheck(ctx)
} }
// --- Reranker Client --- // ─── Reranker Client ────────────────────────────────────────────────────────
// RerankerClient calls the reranker service (BGE Reranker). // RerankerClient calls the reranker service (BGE Reranker).
type RerankerClient struct { type RerankerClient struct {
*httpClient *httpClient
} }
// NewRerankerClient creates a reranker client. // NewRerankerClient creates a reranker client.
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient { func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)} return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
} }
// RerankResult represents a reranked document. // RerankResult represents a reranked document.
type RerankResult struct { type RerankResult struct {
Index int `json:"index"` Index int `json:"index"`
Score float64 `json:"score"` Score float64 `json:"score"`
Document string `json:"document"` Document string `json:"document"`
} }
// Rerank reranks documents by relevance to the query. // Rerank reranks documents by relevance to the query.
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) { func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
payload := map[string]any{ payload := map[string]any{
"query": query, "query": query,
"documents": documents, "documents": documents,
} }
if topK > 0 { if topK > 0 {
payload["top_n"] = topK payload["top_n"] = topK
} }
body, err := c.postJSON(ctx, "/rerank", payload) body, err := c.postJSON(ctx, "/rerank", payload)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var resp struct { var resp struct {
Results []struct { Results []struct {
Index int `json:"index"` Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"` RelevanceScore float64 `json:"relevance_score"`
Score float64 `json:"score"` Score float64 `json:"score"`
} `json:"results"` } `json:"results"`
} }
if err := json.Unmarshal(body, &resp); err != nil { if err := json.Unmarshal(body, &resp); err != nil {
return nil, err return nil, err
} }
results := make([]RerankResult, len(resp.Results)) results := make([]RerankResult, len(resp.Results))
for i, r := range resp.Results { for i, r := range resp.Results {
score := r.RelevanceScore score := r.RelevanceScore
if score == 0 { if score == 0 {
score = r.Score score = r.Score
} }
doc := "" doc := ""
if r.Index < len(documents) { if r.Index < len(documents) {
doc = documents[r.Index] doc = documents[r.Index]
} }
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc} results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
} }
return results, nil return results, nil
} }
// --- LLM Client --- // ─── LLM Client ─────────────────────────────────────────────────────────────
// LLMClient calls the vLLM-compatible LLM service. // LLMClient calls the vLLM-compatible LLM service.
type LLMClient struct { type LLMClient struct {
*httpClient *httpClient
Model string Model string
MaxTokens int MaxTokens int
Temperature float64 Temperature float64
TopP float64 TopP float64
} }
// NewLLMClient creates an LLM client. // NewLLMClient creates an LLM client.
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient { func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
return &LLMClient{ return &LLMClient{
httpClient: newHTTPClient(baseURL, timeout), httpClient: newHTTPClient(baseURL, timeout),
Model: "default", Model: "default",
MaxTokens: 2048, MaxTokens: 2048,
Temperature: 0.7, Temperature: 0.7,
TopP: 0.9, TopP: 0.9,
} }
} }
// ChatMessage is an OpenAI-compatible message. // ChatMessage is an OpenAI-compatible message.
type ChatMessage struct { type ChatMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
} }
// Generate sends a chat completion request and returns the response text. // Generate sends a chat completion request and returns the response text.
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) { func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
messages := buildMessages(prompt, context_, systemPrompt) messages := buildMessages(prompt, context_, systemPrompt)
payload := map[string]any{ payload := map[string]any{
"model": c.Model, "model": c.Model,
"messages": messages, "messages": messages,
"max_tokens": c.MaxTokens, "max_tokens": c.MaxTokens,
"temperature": c.Temperature, "temperature": c.Temperature,
"top_p": c.TopP, "top_p": c.TopP,
} }
body, err := c.postJSON(ctx, "/v1/chat/completions", payload) body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
if err != nil { if err != nil {
return "", err return "", err
} }
var resp struct { var resp struct {
Choices []struct { Choices []struct {
Message struct { Message struct {
Content string `json:"content"` Content string `json:"content"`
} `json:"message"` } `json:"message"`
} `json:"choices"` } `json:"choices"`
} }
if err := json.Unmarshal(body, &resp); err != nil { if err := json.Unmarshal(body, &resp); err != nil {
return "", err return "", err
} }
if len(resp.Choices) == 0 { if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in LLM response") return "", fmt.Errorf("no choices in LLM response")
} }
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage { func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
var msgs []ChatMessage var msgs []ChatMessage
if systemPrompt != "" { if systemPrompt != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt}) msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
} else if ctx != "" { } else if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."}) msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
} }
if ctx != "" { if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)}) msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
} else { } else {
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt}) msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
} }
return msgs return msgs
} }
// --- TTS Client --- // ─── TTS Client ─────────────────────────────────────────────────────────────
// TTSClient calls the TTS service (Coqui XTTS). // TTSClient calls the TTS service (Coqui XTTS).
type TTSClient struct { type TTSClient struct {
*httpClient *httpClient
Language string Language string
} }
// NewTTSClient creates a TTS client. // NewTTSClient creates a TTS client.
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient { func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
if language == "" { if language == "" {
language = "en" language = "en"
} }
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language} return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
} }
// Synthesize generates audio bytes from text. // Synthesize generates audio bytes from text.
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) { func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
if language == "" { if language == "" {
language = c.Language language = c.Language
} }
params := url.Values{ params := url.Values{
"text": {text}, "text": {text},
"language_id": {language}, "language_id": {language},
} }
if speaker != "" { if speaker != "" {
params.Set("speaker_id", speaker) params.Set("speaker_id", speaker)
} }
return c.getRaw(ctx, "/api/tts", params) return c.getRaw(ctx, "/api/tts", params)
} }
// --- STT Client --- // ─── STT Client ─────────────────────────────────────────────────────────────
// STTClient calls the Whisper STT service. // STTClient calls the Whisper STT service.
type STTClient struct { type STTClient struct {
*httpClient *httpClient
Language string Language string
Task string Task string
} }
// NewSTTClient creates an STT client. // NewSTTClient creates an STT client.
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient { func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"} return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
} }
// TranscribeResult holds transcription output. // TranscribeResult holds transcription output.
type TranscribeResult struct { type TranscribeResult struct {
Text string `json:"text"` Text string `json:"text"`
Language string `json:"language,omitempty"` Language string `json:"language,omitempty"`
} }
// Transcribe sends audio to Whisper and returns the transcription. // Transcribe sends audio to Whisper and returns the transcription.
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) { func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
if language == "" { if language == "" {
language = c.Language language = c.Language
} }
fields := map[string]string{ fields := map[string]string{
"response_format": "json", "response_format": "json",
} }
if language != "" { if language != "" {
fields["language"] = language fields["language"] = language
} }
endpoint := "/v1/audio/transcriptions" endpoint := "/v1/audio/transcriptions"
if c.Task == "translate" { if c.Task == "translate" {
endpoint = "/v1/audio/translations" endpoint = "/v1/audio/translations"
} }
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields) body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var result TranscribeResult var result TranscribeResult
if err := json.Unmarshal(body, &result); err != nil { if err := json.Unmarshal(body, &result); err != nil {
return nil, err return nil, err
} }
return &result, nil return &result, nil
} }
// --- Milvus Client --- // ─── Milvus Client ──────────────────────────────────────────────────────────
// MilvusClient provides vector search via the Milvus HTTP/gRPC API. // MilvusClient provides vector search via the Milvus HTTP/gRPC API.
// For the Go port we use the Milvus Go SDK. // For the Go port we use the Milvus Go SDK.
type MilvusClient struct { type MilvusClient struct {
Host string Host string
Port int Port int
Collection string Collection string
connected bool connected bool
} }
// NewMilvusClient creates a Milvus client. // NewMilvusClient creates a Milvus client.
func NewMilvusClient(host string, port int, collection string) *MilvusClient { func NewMilvusClient(host string, port int, collection string) *MilvusClient {
return &MilvusClient{Host: host, Port: port, Collection: collection} return &MilvusClient{Host: host, Port: port, Collection: collection}
} }
// SearchResult holds a single vector search hit. // SearchResult holds a single vector search hit.
type SearchResult struct { type SearchResult struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Distance float64 `json:"distance"` Distance float64 `json:"distance"`
Score float64 `json:"score"` Score float64 `json:"score"`
Fields map[string]any `json:"fields,omitempty"` Fields map[string]any `json:"fields,omitempty"`
} }

506
clients/clients_test.go Normal file
View File

@@ -0,0 +1,506 @@
package clients
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// ────────────────────────────────────────────────────────────────────────────
// Shared infrastructure tests
// ────────────────────────────────────────────────────────────────────────────
// TestSharedTransport verifies that clients built by newHTTPClient are wired
// to the package-level SharedTransport instead of a private transport, even
// when their base URLs and timeouts differ.
func TestSharedTransport(t *testing.T) {
	first := newHTTPClient("http://a:8000", 10*time.Second)
	second := newHTTPClient("http://b:9000", 30*time.Second)

	firstRT, secondRT := first.client.Transport, second.client.Transport
	if firstRT != secondRT {
		t.Error("clients should share the same http.Transport")
	}
	if firstRT != SharedTransport {
		t.Error("transport should be the package-level SharedTransport")
	}
}
// TestBufferPoolGetPut checks the basic pool contract: getBuf never returns
// nil, and every buffer it returns is empty, including one that was written
// to before being handed back with putBuf.
func TestBufferPoolGetPut(t *testing.T) {
	first := getBuf()
	switch {
	case first == nil:
		t.Fatal("getBuf returned nil")
	case first.Len() != 0:
		t.Error("getBuf should return a reset buffer")
	}

	// Dirty the buffer before returning it so a missing reset would show up
	// on the next acquisition.
	first.WriteString("hello")
	putBuf(first)

	second := getBuf()
	defer putBuf(second)
	if second.Len() != 0 {
		t.Error("re-acquired buffer should be reset")
	}
}
// TestBufferPoolOversizedDiscarded grows a buffer to 2 MiB, above the size
// putBuf is willing to keep, hands it back, and checks that the pool still
// serves an empty buffer afterwards.
func TestBufferPoolOversizedDiscarded(t *testing.T) {
	oversized := getBuf()
	payload := make([]byte, 2<<20) // twice the 1 MiB retention limit
	oversized.Write(payload)
	putBuf(oversized) // expected to be dropped rather than pooled

	next := getBuf()
	defer putBuf(next)
	if next.Len() != 0 {
		t.Error("should get a fresh buffer")
	}
}
func TestBufferPoolConcurrency(t *testing.T) {
var wg sync.WaitGroup
for i := range 100 {
wg.Add(1)
go func(n int) {
defer wg.Done()
buf := getBuf()
buf.WriteString(strings.Repeat("x", n))
putBuf(buf)
}(i)
}
wg.Wait()
}
// ────────────────────────────────────────────────────────────────────────────
// Embeddings client
// ────────────────────────────────────────────────────────────────────────────
// TestEmbeddingsClient_Embed checks that Embed POSTs the input texts and the
// model name to /embeddings and returns one vector per entry of the "data"
// array in the response.
func TestEmbeddingsClient_Embed(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/embeddings" {
			t.Errorf("path = %q, want /embeddings", r.URL.Path)
		}
		if r.Method != http.MethodPost {
			t.Errorf("method = %s, want POST", r.Method)
		}

		// Report a malformed body explicitly; otherwise it only shows up as a
		// misleading "input len = 0" failure below.
		var req map[string]any
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		input, _ := req["input"].([]any)
		if len(input) != 2 {
			t.Errorf("input len = %d, want 2", len(input))
		}
		if req["model"] != "bge" {
			t.Errorf("model = %v, want bge", req["model"])
		}

		w.Header().Set("Content-Type", "application/json")
		resp := map[string]any{
			"data": []map[string]any{
				{"embedding": []float64{0.1, 0.2, 0.3}},
				{"embedding": []float64{0.4, 0.5, 0.6}},
			},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer ts.Close()

	c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge")
	results, err := c.Embed(context.Background(), []string{"hello", "world"})
	if err != nil {
		t.Fatal(err)
	}
	if len(results) != 2 {
		t.Fatalf("len(results) = %d, want 2", len(results))
	}
	// Guard the element accesses below so a short vector fails the test
	// instead of panicking with an index-out-of-range.
	for i, vec := range results {
		if len(vec) != 3 {
			t.Fatalf("len(results[%d]) = %d, want 3", i, len(vec))
		}
	}
	if results[0][0] != 0.1 {
		t.Errorf("results[0][0] = %f, want 0.1", results[0][0])
	}
	if results[1][2] != 0.6 {
		t.Errorf("results[1][2] = %f, want 0.6", results[1][2])
	}
}
// TestEmbeddingsClient_EmbedSingle checks that EmbedSingle returns the first
// (and here only) vector from the service response.
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]any{
			"data": []map[string]any{
				{"embedding": []float64{1.0, 2.0}},
			},
		}
		json.NewEncoder(w).Encode(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	// An empty model name is passed on purpose; the constructor substitutes
	// its default.
	client := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	vec, err := client.EmbedSingle(context.Background(), "test")
	if err != nil {
		t.Fatal(err)
	}
	if ok := len(vec) == 2 && vec[0] == 1.0; !ok {
		t.Errorf("vec = %v", vec)
	}
}
// TestEmbeddingsClient_EmbedEmpty checks that EmbedSingle reports an error
// when the service answers with an empty "data" array.
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		empty := map[string]any{"data": []any{}}
		json.NewEncoder(w).Encode(empty)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	if _, err := client.EmbedSingle(context.Background(), "test"); err == nil {
		t.Error("expected error for empty embedding")
	}
}
// TestEmbeddingsClient_Health checks that Health reports true when the
// server answers 200 on /health. Every other path is answered with 404, so
// the check only passes if the client hits the expected endpoint.
func TestEmbeddingsClient_Health(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		status := 404
		if r.URL.Path == "/health" {
			status = 200
		}
		w.WriteHeader(status)
	}))
	defer srv.Close()

	client := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	if healthy := client.Health(context.Background()); !healthy {
		t.Error("expected healthy")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Reranker client
// ────────────────────────────────────────────────────────────────────────────
// TestRerankerClient_Rerank checks that Rerank sends the query and top_n to
// the service and maps each returned index back to the matching input
// document, keeping the order in which the service returned the results.
func TestRerankerClient_Rerank(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Surface a malformed body directly instead of as a confusing
		// "query = <nil>" failure.
		var req map[string]any
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		if req["query"] != "test query" {
			t.Errorf("query = %v", req["query"])
		}
		// encoding/json decodes JSON numbers into float64 when the target
		// type is any.
		if req["top_n"] != float64(2) {
			t.Errorf("top_n = %v, want 2", req["top_n"])
		}

		resp := map[string]any{
			"results": []map[string]any{
				{"index": 1, "relevance_score": 0.95},
				{"index": 0, "relevance_score": 0.80},
			},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer ts.Close()

	c := NewRerankerClient(ts.URL, 5*time.Second)
	docs := []string{"Paris is great", "France is in Europe"}
	results, err := c.Rerank(context.Background(), "test query", docs, 2)
	if err != nil {
		t.Fatal(err)
	}
	if len(results) != 2 {
		t.Fatalf("len = %d", len(results))
	}
	if results[0].Score != 0.95 {
		t.Errorf("score = %f, want 0.95", results[0].Score)
	}
	if results[0].Document != "France is in Europe" {
		t.Errorf("document = %q", results[0].Document)
	}
	if results[1].Index != 0 || results[1].Document != "Paris is great" {
		t.Errorf("results[1] = %+v", results[1])
	}
}
// TestRerankerClient_RerankFallbackScore checks that Rerank falls back to the
// "score" field when "relevance_score" is zero in the service response.
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]any{
			"results": []map[string]any{
				{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
			},
		}
		json.NewEncoder(w).Encode(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewRerankerClient(srv.URL, 5*time.Second)
	ranked, err := client.Rerank(context.Background(), "q", []string{"doc1"}, 0)
	if err != nil {
		t.Fatal(err)
	}
	if got := ranked[0].Score; got != 0.77 {
		t.Errorf("fallback score = %f, want 0.77", got)
	}
}
// ────────────────────────────────────────────────────────────────────────────
// LLM client
// ────────────────────────────────────────────────────────────────────────────
// TestLLMClient_Generate checks that Generate posts a non-empty message list
// to /v1/chat/completions and returns the content of the first choice.
func TestLLMClient_Generate(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/chat/completions" {
			t.Errorf("path = %q", r.URL.Path)
		}

		// Report a malformed body explicitly; otherwise it is only visible as
		// "no messages in request".
		var req map[string]any
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		msgs, _ := req["messages"].([]any)
		if len(msgs) == 0 {
			t.Error("no messages in request")
		}

		resp := map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{"content": "Paris is the capital of France."}},
			},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer ts.Close()

	c := NewLLMClient(ts.URL, 5*time.Second)
	result, err := c.Generate(context.Background(), "capital of France?", "", "")
	if err != nil {
		t.Fatal(err)
	}
	if result != "Paris is the capital of France." {
		t.Errorf("result = %q", result)
	}
}
// TestLLMClient_GenerateWithContext checks that supplying retrieval context
// without a system prompt makes Generate send two messages (a default system
// message plus the user message) and that the first choice is returned.
func TestLLMClient_GenerateWithContext(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report a malformed body explicitly; otherwise it is only visible as
		// "expected 2 messages, got 0".
		var req map[string]any
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode request body: %v", err)
		}
		msgs, _ := req["messages"].([]any)
		// Should have system + user message
		if len(msgs) != 2 {
			t.Errorf("expected 2 messages, got %d", len(msgs))
		}

		resp := map[string]any{
			"choices": []map[string]any{
				{"message": map[string]any{"content": "answer with context"}},
			},
		}
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer ts.Close()

	c := NewLLMClient(ts.URL, 5*time.Second)
	result, err := c.Generate(context.Background(), "question", "some context", "")
	if err != nil {
		t.Fatal(err)
	}
	if result != "answer with context" {
		t.Errorf("result = %q", result)
	}
}
// TestLLMClient_GenerateNoChoices checks that Generate returns an error when
// the service response contains an empty "choices" array.
func TestLLMClient_GenerateNoChoices(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		empty := map[string]any{"choices": []any{}}
		json.NewEncoder(w).Encode(empty)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)
	if _, err := client.Generate(context.Background(), "q", "", ""); err == nil {
		t.Error("expected error for empty choices")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// TTS client
// ────────────────────────────────────────────────────────────────────────────
// TestTTSClient_Synthesize checks that Synthesize requests /api/tts with the
// text as a query parameter and returns the raw response bytes unchanged.
func TestTTSClient_Synthesize(t *testing.T) {
	want := []byte{0xDE, 0xAD, 0xBE, 0xEF}

	handler := func(w http.ResponseWriter, r *http.Request) {
		if got := r.URL.Path; got != "/api/tts" {
			t.Errorf("path = %q", got)
		}
		if got := r.URL.Query().Get("text"); got != "hello world" {
			t.Errorf("text = %q", got)
		}
		w.Write(want)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewTTSClient(srv.URL, 5*time.Second, "en")
	got, err := client.Synthesize(context.Background(), "hello world", "", "")
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, want) {
		t.Errorf("audio = %x, want %x", got, want)
	}
}
// TestTTSClient_SynthesizeWithSpeaker checks that a non-empty speaker is
// forwarded to the service as the speaker_id query parameter.
func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if got := r.URL.Query().Get("speaker_id"); got != "alice" {
			t.Errorf("speaker_id = %q", got)
		}
		w.Write([]byte{0x01})
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewTTSClient(srv.URL, 5*time.Second, "en")
	if _, err := client.Synthesize(context.Background(), "hi", "en", "alice"); err != nil {
		t.Fatal(err)
	}
}
// ────────────────────────────────────────────────────────────────────────────
// STT client
// ────────────────────────────────────────────────────────────────────────────
// TestSTTClient_Transcribe checks that Transcribe uploads the audio as a
// multipart "file" part to /v1/audio/transcriptions, forwards the language
// field, and decodes the "text" field of the JSON response.
func TestSTTClient_Transcribe(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/audio/transcriptions" {
			t.Errorf("path = %q", r.URL.Path)
		}
		ct := r.Header.Get("Content-Type")
		if !strings.Contains(ct, "multipart/form-data") {
			t.Errorf("content-type = %q", ct)
		}

		// Verify the audio file is present. This handler runs on a server
		// goroutine, where t.Fatal must not be called; record the failure
		// with t.Errorf and answer with an error status so the client call
		// fails on the test goroutine instead.
		file, _, err := r.FormFile("file")
		if err != nil {
			t.Errorf("form file: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		defer file.Close()

		data, err := io.ReadAll(file)
		if err != nil {
			t.Errorf("read uploaded file: %v", err)
		}
		if len(data) != 100 {
			t.Errorf("file size = %d, want 100", len(data))
		}
		if got := r.FormValue("language"); got != "en" {
			t.Errorf("language = %q, want en", got)
		}

		if err := json.NewEncoder(w).Encode(map[string]string{"text": "hello world"}); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer ts.Close()

	c := NewSTTClient(ts.URL, 5*time.Second)
	result, err := c.Transcribe(context.Background(), make([]byte, 100), "en")
	if err != nil {
		t.Fatal(err)
	}
	if result.Text != "hello world" {
		t.Errorf("text = %q", result.Text)
	}
}
// TestSTTClient_TranscribeTranslate verifies that setting Task to
// "translate" routes the request to /v1/audio/translations.
func TestSTTClient_TranscribeTranslate(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if p := r.URL.Path; p != "/v1/audio/translations" {
			t.Errorf("path = %q, want /v1/audio/translations", p)
		}
		json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewSTTClient(srv.URL, 5*time.Second)
	client.Task = "translate"

	res, err := client.Transcribe(context.Background(), []byte{0x01}, "")
	if err != nil {
		t.Fatal(err)
	}
	if got := res.Text; got != "translated" {
		t.Errorf("text = %q", got)
	}
}
// ────────────────────────────────────────────────────────────────────────────
// HTTP error handling
// ────────────────────────────────────────────────────────────────────────────
// TestHTTPError4xx verifies that a 422 reply makes Embed return an error
// whose message mentions the status code.
func TestHTTPError4xx(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnprocessableEntity) // 422
		w.Write([]byte(`{"error": "bad input"}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewEmbeddingsClient(srv.URL, 5*time.Second, "")
	_, err := client.Embed(context.Background(), []string{"test"})
	switch {
	case err == nil:
		t.Fatal("expected error for 422")
	case !strings.Contains(err.Error(), "422"):
		t.Errorf("error should contain status code: %v", err)
	}
}
// TestHTTPError5xx verifies that a 500 reply makes Generate return an error.
func TestHTTPError5xx(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError) // 500
		w.Write([]byte("internal server error"))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)
	if _, err := client.Generate(context.Background(), "q", "", ""); err == nil {
		t.Fatal("expected error for 500")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// buildMessages helper
// ────────────────────────────────────────────────────────────────────────────
// TestBuildMessages covers the three input shapes exercised here: query only,
// query with an explicit system prompt, and query with context but no system
// prompt. Length mismatches are fatal so later index expressions cannot
// panic on a short slice.
func TestBuildMessages(t *testing.T) {
	// No context, no system prompt → just user message
	msgs := buildMessages("hello", "", "")
	if len(msgs) != 1 || msgs[0].Role != "user" {
		t.Errorf("expected 1 user msg, got %+v", msgs)
	}

	// With system prompt
	msgs = buildMessages("hello", "", "You are helpful")
	if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" {
		t.Errorf("expected system+user, got %+v", msgs)
	}

	// With context, no system prompt → auto system prompt
	msgs = buildMessages("question", "some context", "")
	if len(msgs) != 2 {
		// msgs[1] is dereferenced below; stop here rather than panic.
		t.Fatalf("expected auto system+user, got %+v", msgs)
	}
	if msgs[0].Role != "system" {
		t.Errorf("expected auto system+user, got %+v", msgs)
	}
	if !strings.Contains(msgs[1].Content, "Context:") {
		t.Error("user message should contain context")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmarks: pooled buffer vs direct allocation
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkPostJSON measures postJSON against a local server that drains the
// request body and answers with a small fixed JSON document.
func BenchmarkPostJSON(b *testing.B) {
	drain := func(w http.ResponseWriter, r *http.Request) {
		io.Copy(io.Discard, r.Body)
		w.Write([]byte(`{"ok":true}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(drain))
	defer srv.Close()

	var (
		client = newHTTPClient(srv.URL, 10*time.Second)
		ctx    = context.Background()
		body   = map[string]any{
			"text":    strings.Repeat("x", 1024),
			"count":   42,
			"enabled": true,
		}
	)

	b.ResetTimer()
	for b.Loop() {
		client.postJSON(ctx, "/test", body)
	}
}
// BenchmarkBufferPool measures one getBuf / write 4 KiB / putBuf cycle.
// The payload is built once outside the loop so the timed region contains
// only the pool round trip and the buffer write, not a fresh 4 KiB string
// allocation per iteration.
func BenchmarkBufferPool(b *testing.B) {
	payload := strings.Repeat("x", 4096)
	b.ResetTimer()
	for b.Loop() {
		buf := getBuf()
		buf.WriteString(payload)
		putBuf(buf)
	}
}
// BenchmarkBufferPoolParallel runs the getBuf / write 4 KiB / putBuf cycle
// from multiple goroutines via RunParallel. The payload string is shared and
// built once up front so per-iteration work is limited to the pool round trip
// and the buffer write.
func BenchmarkBufferPoolParallel(b *testing.B) {
	payload := strings.Repeat("x", 4096)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			buf := getBuf()
			buf.WriteString(payload)
			putBuf(buf)
		}
	})
}

View File

@@ -1,145 +1,268 @@
// Package config provides environment-based configuration for handler services. // Package config provides environment-based configuration for handler services
// with optional live reload of secrets and service endpoints.
package config package config
import ( import (
"os" "context"
"strconv" "log/slog"
"time" "os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
) )
// Settings holds base configuration for all handler services. // Settings holds base configuration for all handler services.
// Values are loaded from environment variables with sensible defaults. // Fields in the "hot-reload" section are protected by a RWMutex and can be
// updated at runtime via WatchSecrets(). All other fields are immutable
// after Load() returns.
type Settings struct { type Settings struct {
// Service identification // Service identification (immutable)
ServiceName string ServiceName string
ServiceVersion string ServiceVersion string
ServiceNamespace string ServiceNamespace string
DeploymentEnv string DeploymentEnv string
// NATS configuration // NATS configuration (immutable)
NATSURL string NATSURL string
NATSUser string NATSUser string
NATSPassword string NATSPassword string
NATSQueueGroup string NATSQueueGroup string
// Redis/Valkey configuration // Redis/Valkey configuration (immutable)
RedisURL string RedisURL string
RedisPassword string RedisPassword string
// Milvus configuration // Milvus configuration (immutable)
MilvusHost string MilvusHost string
MilvusPort int MilvusPort int
MilvusCollection string MilvusCollection string
// Service endpoints // OpenTelemetry configuration (immutable)
EmbeddingsURL string OTELEnabled bool
RerankerURL string OTELEndpoint string
LLMURL string OTELUseHTTP bool
TTSURL string
STTURL string
// OpenTelemetry configuration // HyperDX configuration (immutable)
OTELEnabled bool HyperDXEnabled bool
OTELEndpoint string HyperDXAPIKey string
OTELUseHTTP bool HyperDXEndpoint string
// HyperDX configuration // MLflow configuration (immutable)
HyperDXEnabled bool MLflowTrackingURI string
HyperDXAPIKey string MLflowExperimentName string
HyperDXEndpoint string MLflowEnabled bool
// MLflow configuration // Health check configuration (immutable)
MLflowTrackingURI string HealthPort int
MLflowExperimentName string HealthPath string
MLflowEnabled bool ReadyPath string
// Health check configuration // Timeouts (immutable)
HealthPort int HTTPTimeout time.Duration
HealthPath string NATSTimeout time.Duration
ReadyPath string
// Timeouts // Hot-reloadable fields — access via getter methods.
HTTPTimeout time.Duration mu sync.RWMutex
NATSTimeout time.Duration embeddingsURL string
rerankerURL string
llmURL string
ttsURL string
sttURL string
// Secrets path for file-based hot reload (Kubernetes secret mounts)
SecretsPath string
} }
// Load creates a Settings populated from environment variables with defaults. // Load creates a Settings populated from environment variables with defaults.
func Load() *Settings { func Load() *Settings {
return &Settings{ return &Settings{
ServiceName: getEnv("SERVICE_NAME", "handler"), ServiceName: getEnv("SERVICE_NAME", "handler"),
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
NATSUser: getEnv("NATS_USER", ""), NATSUser: getEnv("NATS_USER", ""),
NATSPassword: getEnv("NATS_PASSWORD", ""), NATSPassword: getEnv("NATS_PASSWORD", ""),
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
RedisPassword: getEnv("REDIS_PASSWORD", ""), RedisPassword: getEnv("REDIS_PASSWORD", ""),
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
MilvusPort: getEnvInt("MILVUS_PORT", 19530), MilvusPort: getEnvInt("MILVUS_PORT", 19530),
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
EmbeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
RerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
LLMURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
TTSURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
STTURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
OTELEnabled: getEnvBool("OTEL_ENABLED", true), OTELEnabled: getEnvBool("OTEL_ENABLED", true),
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
HealthPort: getEnvInt("HEALTH_PORT", 8080), HealthPort: getEnvInt("HEALTH_PORT", 8080),
HealthPath: getEnv("HEALTH_PATH", "/health"), HealthPath: getEnv("HEALTH_PATH", "/health"),
ReadyPath: getEnv("READY_PATH", "/ready"), ReadyPath: getEnv("READY_PATH", "/ready"),
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
}
SecretsPath: getEnv("SECRETS_PATH", ""),
}
}
// EmbeddingsURL reports the embeddings service URL currently in effect.
// The field is read while holding the settings read lock.
func (s *Settings) EmbeddingsURL() string {
	s.mu.RLock()
	current := s.embeddingsURL
	s.mu.RUnlock()
	return current
}
// RerankerURL reports the reranker service URL currently in effect.
// The field is read while holding the settings read lock.
func (s *Settings) RerankerURL() string {
	s.mu.RLock()
	current := s.rerankerURL
	s.mu.RUnlock()
	return current
}
// LLMURL reports the LLM service URL currently in effect.
// The field is read while holding the settings read lock.
func (s *Settings) LLMURL() string {
	s.mu.RLock()
	current := s.llmURL
	s.mu.RUnlock()
	return current
}
// TTSURL reports the TTS service URL currently in effect.
// The field is read while holding the settings read lock.
func (s *Settings) TTSURL() string {
	s.mu.RLock()
	current := s.ttsURL
	s.mu.RUnlock()
	return current
}
// STTURL reports the STT service URL currently in effect.
// The field is read while holding the settings read lock.
func (s *Settings) STTURL() string {
	s.mu.RLock()
	current := s.sttURL
	s.mu.RUnlock()
	return current
}
// WatchSecrets watches the SecretsPath directory and calls reloadFromSecrets
// whenever a create or write event is reported there.
//
// It returns immediately when SecretsPath is empty, or when the watcher
// cannot be created or attached (both failures are logged). Otherwise it
// blocks until ctx is cancelled or one of the watcher channels is closed.
func (s *Settings) WatchSecrets(ctx context.Context) {
	dir := s.SecretsPath
	if dir == "" {
		return
	}

	w, err := fsnotify.NewWatcher()
	if err != nil {
		slog.Error("config: failed to create fsnotify watcher", "error", err)
		return
	}
	defer func() { _ = w.Close() }()

	if addErr := w.Add(dir); addErr != nil {
		slog.Error("config: failed to watch secrets path", "error", addErr, "path", dir)
		return
	}
	slog.Info("config: watching secrets for hot reload", "path", dir)

	for {
		select {
		case <-ctx.Done():
			return

		case ev, open := <-w.Events:
			if !open {
				return
			}
			// Only creations and writes trigger a reload; other event
			// kinds are ignored.
			if !ev.Has(fsnotify.Create) && !ev.Has(fsnotify.Write) {
				continue
			}
			s.reloadFromSecrets()

		case watchErr, open := <-w.Errors:
			if !open {
				return
			}
			slog.Error("config: fsnotify error", "error", watchErr)
		}
	}
}
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
func (s *Settings) reloadFromSecrets() {
s.mu.Lock()
defer s.mu.Unlock()
updated := 0
reload := func(filename string, target *string) {
path := filepath.Join(s.SecretsPath, filename)
data, err := os.ReadFile(path)
if err != nil {
return
}
val := strings.TrimSpace(string(data))
if val != "" && val != *target {
*target = val
updated++
slog.Info("config: reloaded secret", "key", filename)
}
}
reload("embeddings-url", &s.embeddingsURL)
reload("reranker-url", &s.rerankerURL)
reload("llm-url", &s.llmURL)
reload("tts-url", &s.ttsURL)
reload("stt-url", &s.sttURL)
if updated > 0 {
slog.Info("config: secrets reloaded", "updated", updated)
}
} }
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v
} }
return fallback return fallback
} }
func getEnvInt(key string, fallback int) int { func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil { if i, err := strconv.Atoi(v); err == nil {
return i return i
} }
} }
return fallback return fallback
} }
func getEnvBool(key string, fallback bool) bool { func getEnvBool(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil { if b, err := strconv.ParseBool(v); err == nil {
return b return b
} }
} }
return fallback return fallback
} }
func getEnvDuration(key string, fallback time.Duration) time.Duration { func getEnvDuration(key string, fallback time.Duration) time.Duration {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil { if f, err := strconv.ParseFloat(v, 64); err == nil {
return time.Duration(f * float64(time.Second)) return time.Duration(f * float64(time.Second))
} }
} }
return fallback return fallback
} }

View File

@@ -1,42 +1,123 @@
package config package config
import ( import (
"os" "os"
"testing" "path/filepath"
"time" "testing"
"time"
) )
func TestLoadDefaults(t *testing.T) { func TestLoadDefaults(t *testing.T) {
s := Load() s := Load()
if s.ServiceName != "handler" { if s.ServiceName != "handler" {
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
} }
if s.HealthPort != 8080 { if s.HealthPort != 8080 {
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
} }
if s.HTTPTimeout != 60*time.Second { if s.HTTPTimeout != 60*time.Second {
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
} }
} }
func TestLoadFromEnv(t *testing.T) { func TestLoadFromEnv(t *testing.T) {
os.Setenv("SERVICE_NAME", "test-svc") t.Setenv("SERVICE_NAME", "test-svc")
os.Setenv("HEALTH_PORT", "9090") t.Setenv("HEALTH_PORT", "9090")
os.Setenv("OTEL_ENABLED", "false") t.Setenv("OTEL_ENABLED", "false")
defer func() {
os.Unsetenv("SERVICE_NAME")
os.Unsetenv("HEALTH_PORT")
os.Unsetenv("OTEL_ENABLED")
}()
s := Load() s := Load()
if s.ServiceName != "test-svc" { if s.ServiceName != "test-svc" {
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
} }
if s.HealthPort != 9090 { if s.HealthPort != 9090 {
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
} }
if s.OTELEnabled { if s.OTELEnabled {
t.Error("expected OTELEnabled false") t.Error("expected OTELEnabled false")
} }
}
// TestURLGetters verifies that every URL getter returns a non-empty value on
// a freshly loaded Settings.
func TestURLGetters(t *testing.T) {
	s := Load()

	getters := []struct {
		name string
		get  func() string
	}{
		{"EmbeddingsURL", s.EmbeddingsURL},
		{"RerankerURL", s.RerankerURL},
		{"LLMURL", s.LLMURL},
		{"TTSURL", s.TTSURL},
		{"STTURL", s.STTURL},
	}
	for _, g := range getters {
		if g.get() == "" {
			t.Error(g.name + " should have a default")
		}
	}
}
// TestURLGettersFromEnv verifies that EMBEDDINGS_URL and LLM_URL set in the
// environment are what the corresponding getters return after Load.
func TestURLGettersFromEnv(t *testing.T) {
	const (
		embedURL = "http://embed:8000"
		llmURL   = "http://llm:9000"
	)
	t.Setenv("EMBEDDINGS_URL", embedURL)
	t.Setenv("LLM_URL", llmURL)

	s := Load()
	if got := s.EmbeddingsURL(); got != embedURL {
		t.Errorf("expected custom EmbeddingsURL, got %q", got)
	}
	if got := s.LLMURL(); got != llmURL {
		t.Errorf("expected custom LLMURL, got %q", got)
	}
}
// TestReloadFromSecrets verifies that reloadFromSecrets picks up values from
// files in SecretsPath, applies a later change to one file, and leaves a
// field alone when its file has not changed.
func TestReloadFromSecrets(t *testing.T) {
	secrets := t.TempDir()

	// expect reports a mismatch using the same wording for every check.
	expect := func(what, got, want string) {
		t.Helper()
		if got != want {
			t.Errorf("expected %s, got %q", what, got)
		}
	}

	// Write initial secret files
	writeSecret(t, secrets, "embeddings-url", "http://old-embed:8000")
	writeSecret(t, secrets, "llm-url", "http://old-llm:9000")

	s := Load()
	s.SecretsPath = secrets
	s.reloadFromSecrets()
	expect("reloaded EmbeddingsURL", s.EmbeddingsURL(), "http://old-embed:8000")
	expect("reloaded LLMURL", s.LLMURL(), "http://old-llm:9000")

	// Simulate secret update
	writeSecret(t, secrets, "embeddings-url", "http://new-embed:8000")
	s.reloadFromSecrets()
	expect("updated EmbeddingsURL", s.EmbeddingsURL(), "http://new-embed:8000")

	// LLM should remain unchanged
	expect("unchanged LLMURL", s.LLMURL(), "http://old-llm:9000")
}
// TestReloadFromSecretsNoPath calls reloadFromSecrets with an empty
// SecretsPath; the test passes as long as the call returns.
func TestReloadFromSecretsNoPath(t *testing.T) {
	settings := Load()
	settings.SecretsPath = ""

	// Should not panic
	settings.reloadFromSecrets()
}
// TestGetEnvDuration verifies that a numeric environment value is read as a
// number of seconds and takes precedence over the fallback.
func TestGetEnvDuration(t *testing.T) {
	const key = "TEST_DUR"
	t.Setenv(key, "30")

	want := 30 * time.Second
	if got := getEnvDuration(key, 10*time.Second); got != want {
		t.Errorf("expected 30s, got %v", got)
	}
}
// writeSecret writes value to the file called name inside dir and fails the
// calling test immediately if the write returns an error. It is marked as a
// test helper so failures are attributed to the caller's line.
func writeSecret(t *testing.T, dir, name, value string) {
t.Helper()
// The file is created (or truncated) with mode 0644.
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
t.Fatal(err)
}
} }

1
go.mod
View File

@@ -17,6 +17,7 @@ require (
require ( require (
github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect

2
go.sum
View File

@@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=

201
handler/handler_test.go Normal file
View File

@@ -0,0 +1,201 @@
package handler
import (
"context"
"testing"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
)
// ────────────────────────────────────────────────────────────────────────────
// Handler construction tests
// ────────────────────────────────────────────────────────────────────────────
// TestNewHandler verifies that New records the subject and takes the queue
// group and settings from the supplied configuration.
func TestNewHandler(t *testing.T) {
	settings := config.Load()
	settings.ServiceName = "test-handler"
	settings.NATSQueueGroup = "test-group"

	h := New("ai.test.subject", settings)

	if got := h.Subject; got != "ai.test.subject" {
		t.Errorf("Subject = %q", got)
	}
	if got := h.QueueGroup; got != "test-group" {
		t.Errorf("QueueGroup = %q", got)
	}
	if got := h.Settings.ServiceName; got != "test-handler" {
		t.Errorf("ServiceName = %q", got)
	}
}
// TestNewHandlerNilSettings verifies that passing nil settings to New yields
// a handler whose Settings is populated with the default service name.
func TestNewHandlerNilSettings(t *testing.T) {
	h := New("ai.test", nil)

	settings := h.Settings
	if settings == nil {
		t.Fatal("Settings should be loaded automatically")
	}
	if name := settings.ServiceName; name != "handler" {
		t.Errorf("ServiceName = %q, want default", name)
	}
}
// TestCallbackRegistration verifies that OnSetup, OnTeardown and OnMessage
// store their callbacks on the handler and that the stored setup and
// teardown callbacks run when invoked.
func TestCallbackRegistration(t *testing.T) {
	cfg := config.Load()
	h := New("ai.test", cfg)

	setupCalled := false
	h.OnSetup(func(ctx context.Context) error {
		setupCalled = true
		return nil
	})
	teardownCalled := false
	h.OnTeardown(func(ctx context.Context) error {
		teardownCalled = true
		return nil
	})
	h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
		return nil, nil
	})

	// Fatal, not Error: the callbacks are invoked below and calling a nil
	// func would panic instead of producing a readable failure.
	if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil {
		t.Fatal("callbacks should not be nil after registration")
	}

	// Verify setup/teardown work when called directly.
	if err := h.onSetup(context.Background()); err != nil {
		t.Errorf("onSetup returned error: %v", err)
	}
	if err := h.onTeardown(context.Background()); err != nil {
		t.Errorf("onTeardown returned error: %v", err)
	}
	if !setupCalled || !teardownCalled {
		t.Error("callbacks should have been invoked")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// wrapHandler dispatch tests (unit test the message decode + dispatch logic)
// ────────────────────────────────────────────────────────────────────────────
func TestWrapHandler_ValidMessage(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
var receivedData map[string]any
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
receivedData = data
return map[string]any{"status": "ok"}, nil
})
// Encode a message the same way services would.
payload := map[string]any{
"request_id": "test-001",
"message": "hello",
"premium": true,
}
encoded, err := msgpack.Marshal(payload)
if err != nil {
t.Fatal(err)
}
// Call wrapHandler directly without NATS.
handler := h.wrapHandler(context.Background())
handler(&nats.Msg{
Subject: "ai.test.user.42.message",
Data: encoded,
})
if receivedData == nil {
t.Fatal("handler was not called")
}
if receivedData["request_id"] != "test-001" {
t.Errorf("request_id = %v", receivedData["request_id"])
}
if receivedData["premium"] != true {
t.Errorf("premium = %v", receivedData["premium"])
}
}
func TestWrapHandler_InvalidMsgpack(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
handlerCalled := false
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
handlerCalled = true
return nil, nil
})
handler := h.wrapHandler(context.Background())
handler(&nats.Msg{
Subject: "ai.test",
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack
})
if handlerCalled {
t.Error("handler should not be called for invalid msgpack")
}
}
// TestWrapHandler_HandlerError verifies that the wrapped handler returns
// normally when the OnMessage callback reports an error.
func TestWrapHandler_HandlerError(t *testing.T) {
	cfg := config.Load()
	h := New("ai.test", cfg)
	h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
		return nil, context.DeadlineExceeded
	})

	// A failed encode would feed an empty body to the handler and exercise
	// the decode-error path instead of the one under test.
	encoded, err := msgpack.Marshal(map[string]any{"key": "val"})
	if err != nil {
		t.Fatal(err)
	}

	handler := h.wrapHandler(context.Background())
	// Should not panic even when handler returns error.
	handler(&nats.Msg{
		Subject: "ai.test",
		Data:    encoded,
	})
}
// TestWrapHandler_NilResponse verifies that the wrapped handler returns
// normally when the callback yields a nil response and the message carries
// no reply subject.
func TestWrapHandler_NilResponse(t *testing.T) {
	cfg := config.Load()
	h := New("ai.test", cfg)
	h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
		return nil, nil // fire-and-forget style
	})

	// A failed encode would feed an empty body to the handler and exercise
	// the decode-error path instead of the one under test.
	encoded, err := msgpack.Marshal(map[string]any{"x": 1})
	if err != nil {
		t.Fatal(err)
	}

	handler := h.wrapHandler(context.Background())
	// Should not panic with nil response and no reply subject.
	handler(&nats.Msg{
		Subject: "ai.test",
		Data:    encoded,
	})
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmark: message decode + dispatch overhead
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkWrapHandler measures one pass of a pre-encoded msgpack message
// through the wrapped handler with a trivial OnMessage callback.
func BenchmarkWrapHandler(b *testing.B) {
	cfg := config.Load()
	h := New("ai.test", cfg)
	h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
		return map[string]any{"ok": true}, nil
	})

	payload := map[string]any{
		"request_id": "bench-001",
		"message":    "What is the capital of France?",
		"premium":    true,
		"top_k":      10,
	}
	// Abort on a bad fixture; otherwise the loop would time the
	// decode-failure path on an empty body.
	encoded, err := msgpack.Marshal(payload)
	if err != nil {
		b.Fatal(err)
	}

	handler := h.wrapHandler(context.Background())
	msg := &nats.Msg{Subject: "ai.test", Data: encoded}

	b.ResetTimer()
	for b.Loop() {
		handler(msg)
	}
}

515
messages/bench_test.go Normal file
View File

@@ -0,0 +1,515 @@
// Package messages benchmarks compare three serialization strategies:
//
// 1. msgpack map[string]any — the old approach (dynamic, no types)
// 2. msgpack typed struct — the new approach (compile-time safe, short keys)
// 3. protobuf — optional future migration
//
// Run with:
//
// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
package messages
import (
"testing"
"time"
"github.com/vmihailenco/msgpack/v5"
"google.golang.org/protobuf/proto"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto"
)
// ────────────────────────────────────────────────────────────────────────────
// Test fixtures — equivalent data across all three encodings
// ────────────────────────────────────────────────────────────────────────────
// chatRequestMap builds the chat request fixture in its legacy
// map[string]any form, with all thirteen keys present (including the empty
// and false ones).
func chatRequestMap() map[string]any {
	m := make(map[string]any, 13)

	// Identity and routing.
	m["request_id"] = "req-abc-123"
	m["user_id"] = "user-42"
	m["response_subject"] = "ai.chat.response.req-abc-123"

	// Content.
	m["message"] = "What is the capital of France?"
	m["query"] = ""
	m["system_prompt"] = "You are a helpful assistant."
	m["collection"] = "documents"
	m["top_k"] = 10

	// Feature flags.
	m["premium"] = true
	m["enable_rag"] = true
	m["enable_reranker"] = true
	m["enable_streaming"] = false
	m["enable_tts"] = false

	return m
}
// chatRequestStruct builds the chat request fixture as a typed ChatRequest.
// Fields not assigned here keep their zero values.
func chatRequestStruct() ChatRequest {
	var req ChatRequest

	req.RequestID = "req-abc-123"
	req.UserID = "user-42"
	req.ResponseSubject = "ai.chat.response.req-abc-123"

	req.Message = "What is the capital of France?"
	req.SystemPrompt = "You are a helpful assistant."
	req.Collection = "documents"
	req.TopK = 10

	req.Premium = true
	req.EnableRAG = true
	req.EnableReranker = true

	return req
}
// chatRequestProto builds the chat request fixture as a protobuf message.
// Fields not assigned here keep their zero values.
func chatRequestProto() *pb.ChatRequest {
	req := new(pb.ChatRequest)

	req.RequestId = "req-abc-123"
	req.UserId = "user-42"
	req.ResponseSubject = "ai.chat.response.req-abc-123"

	req.Message = "What is the capital of France?"
	req.SystemPrompt = "You are a helpful assistant."
	req.Collection = "documents"
	req.TopK = 10

	req.Premium = true
	req.EnableRag = true
	req.EnableReranker = true

	return req
}
// voiceResponseMap builds the voice response fixture as a map[string]any
// carrying a zero-filled 16384-byte audio slice.
func voiceResponseMap() map[string]any {
	audio := make([]byte, 16384)

	m := make(map[string]any, 4)
	m["request_id"] = "vr-001"
	m["transcription"] = "What is the capital of France?"
	m["response"] = "The capital of France is Paris."
	m["audio"] = audio
	return m
}
// voiceResponseStruct builds the voice response fixture as a typed
// VoiceResponse carrying a zero-filled 16384-byte audio slice.
func voiceResponseStruct() VoiceResponse {
	var resp VoiceResponse
	resp.RequestID = "vr-001"
	resp.Transcription = "What is the capital of France?"
	resp.Response = "The capital of France is Paris."
	resp.Audio = make([]byte, 16384)
	return resp
}
// voiceResponseProto builds the voice response fixture as a protobuf message
// carrying a zero-filled 16384-byte audio slice.
func voiceResponseProto() *pb.VoiceResponse {
	resp := new(pb.VoiceResponse)
	resp.RequestId = "vr-001"
	resp.Transcription = "What is the capital of France?"
	resp.Response = "The capital of France is Paris."
	resp.Audio = make([]byte, 16384)
	return resp
}
// ttsChunkMap builds the streaming audio chunk fixture as a map[string]any.
// The timestamp is taken from the clock on every call.
//
// NOTE(review): "audio_b64" holds 32768 NUL bytes converted to a string, not
// real base64 text; base64 of 32768 bytes would be 43692 characters, so this
// fixture may understate the legacy wire size — confirm whether that is
// intended.
func ttsChunkMap() map[string]any {
	audio := string(make([]byte, 32768)) // old: base64 string
	now := time.Now().Unix()

	m := make(map[string]any, 7)
	m["session_id"] = "tts-sess-99"
	m["chunk_index"] = 3
	m["total_chunks"] = 12
	m["is_last"] = false
	m["sample_rate"] = 24000
	m["timestamp"] = now
	m["audio_b64"] = audio
	return m
}
// ttsChunkStruct builds the streaming audio chunk fixture as a typed
// TTSAudioChunk with a zero-filled 32768-byte audio slice. The timestamp is
// taken from the clock on every call.
func ttsChunkStruct() TTSAudioChunk {
	var chunk TTSAudioChunk
	chunk.SessionID = "tts-sess-99"
	chunk.ChunkIndex = 3
	chunk.TotalChunks = 12
	chunk.IsLast = false
	chunk.SampleRate = 24000
	chunk.Timestamp = time.Now().Unix()
	chunk.Audio = make([]byte, 32768) // new: raw bytes
	return chunk
}
// ttsChunkProto builds the streaming audio chunk fixture as a protobuf
// message with a zero-filled 32768-byte audio slice. The timestamp is taken
// from the clock on every call.
func ttsChunkProto() *pb.TTSAudioChunk {
	chunk := new(pb.TTSAudioChunk)
	chunk.SessionId = "tts-sess-99"
	chunk.ChunkIndex = 3
	chunk.TotalChunks = 12
	chunk.IsLast = false
	chunk.SampleRate = 24000
	chunk.Timestamp = time.Now().Unix()
	chunk.Audio = make([]byte, 32768)
	return chunk
}
// ────────────────────────────────────────────────────────────────────────────
// Wire-size comparison (run once, printed by TestWireSize)
// ────────────────────────────────────────────────────────────────────────────
// TestWireSize encodes each fixture with all three strategies and logs the
// encoded sizes together with the savings relative to the msgpack map form.
// Any encoding failure aborts the test so that a zero-length result is never
// reported as a size.
func TestWireSize(t *testing.T) {
	tests := []struct {
		name      string
		mapData   any
		structVal any
		protoMsg  proto.Message
	}{
		{"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()},
		{"VoiceResponse", voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()},
		{"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()},
	}

	for _, tt := range tests {
		mapBytes, err := msgpack.Marshal(tt.mapData)
		if err != nil {
			t.Fatalf("%s: msgpack map encode: %v", tt.name, err)
		}
		structBytes, err := msgpack.Marshal(tt.structVal)
		if err != nil {
			t.Fatalf("%s: msgpack struct encode: %v", tt.name, err)
		}
		protoBytes, err := proto.Marshal(tt.protoMsg)
		if err != nil {
			t.Fatalf("%s: protobuf encode: %v", tt.name, err)
		}

		// Savings are expressed as a percentage of the map encoding's size.
		t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)",
			tt.name,
			len(mapBytes), len(structBytes), len(protoBytes),
			100*(1-float64(len(structBytes))/float64(len(mapBytes))),
			100*(1-float64(len(protoBytes))/float64(len(mapBytes))),
		)
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Encode benchmarks
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkEncode_ChatRequest_MsgpackMap measures msgpack encoding of the
// map-based chat request fixture; an encode error aborts the benchmark.
func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) {
	data := chatRequestMap()
	b.ResetTimer()
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_ChatRequest_MsgpackStruct measures msgpack encoding of the
// typed chat request fixture; an encode error aborts the benchmark.
func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) {
	data := chatRequestStruct()
	b.ResetTimer()
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_ChatRequest_Protobuf measures protobuf encoding of the chat
// request fixture; an encode error aborts the benchmark.
func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
	data := chatRequestProto()
	b.ResetTimer()
	for b.Loop() {
		if _, err := proto.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_VoiceResponse_MsgpackMap measures msgpack encoding of the
// map-based voice response fixture; an encode error aborts the benchmark.
func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) {
	data := voiceResponseMap()
	b.ResetTimer()
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_VoiceResponse_MsgpackStruct measures msgpack encoding of
// the typed voice response fixture; an encode error aborts the benchmark.
func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) {
	data := voiceResponseStruct()
	b.ResetTimer()
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_VoiceResponse_Protobuf measures protobuf encoding of the
// voice response fixture; an encode error aborts the benchmark.
func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
	data := voiceResponseProto()
	b.ResetTimer()
	for b.Loop() {
		if _, err := proto.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_TTSChunk_MsgpackMap measures msgpack encoding of the
// map-based audio chunk fixture; an encode error aborts the benchmark.
func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) {
	data := ttsChunkMap()
	b.ResetTimer()
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_TTSChunk_MsgpackStruct measures msgpack encoding of the
// typed audio chunk fixture; an encode error aborts the benchmark.
func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) {
	data := ttsChunkStruct()
	b.ResetTimer()
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncode_TTSChunk_Protobuf measures protobuf encoding of the audio
// chunk fixture; an encode error aborts the benchmark.
func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
	data := ttsChunkProto()
	b.ResetTimer()
	for b.Loop() {
		if _, err := proto.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Decode benchmarks
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkDecode_ChatRequest_MsgpackMap measures msgpack.Unmarshal of the
// encoded chatRequestMap fixture into a map[string]any.
func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) {
	encoded, err := msgpack.Marshal(chatRequestMap())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m map[string]any
		if err := msgpack.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_ChatRequest_MsgpackStruct measures msgpack.Unmarshal of the
// encoded chatRequestStruct fixture into a ChatRequest.
func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) {
	encoded, err := msgpack.Marshal(chatRequestStruct())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m ChatRequest
		if err := msgpack.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_ChatRequest_Protobuf measures proto.Unmarshal of the encoded
// chatRequestProto fixture into a pb.ChatRequest.
func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
	encoded, err := proto.Marshal(chatRequestProto())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m pb.ChatRequest
		if err := proto.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_VoiceResponse_MsgpackMap measures msgpack.Unmarshal of the
// encoded voiceResponseMap fixture into a map[string]any.
func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) {
	encoded, err := msgpack.Marshal(voiceResponseMap())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m map[string]any
		if err := msgpack.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_VoiceResponse_MsgpackStruct measures msgpack.Unmarshal of
// the encoded voiceResponseStruct fixture into a VoiceResponse.
func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) {
	encoded, err := msgpack.Marshal(voiceResponseStruct())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m VoiceResponse
		if err := msgpack.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_VoiceResponse_Protobuf measures proto.Unmarshal of the
// encoded voiceResponseProto fixture into a pb.VoiceResponse.
func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
	encoded, err := proto.Marshal(voiceResponseProto())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m pb.VoiceResponse
		if err := proto.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_TTSChunk_MsgpackMap measures msgpack.Unmarshal of the
// encoded ttsChunkMap fixture into a map[string]any.
func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) {
	encoded, err := msgpack.Marshal(ttsChunkMap())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m map[string]any
		if err := msgpack.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_TTSChunk_MsgpackStruct measures msgpack.Unmarshal of the
// encoded ttsChunkStruct fixture into a TTSAudioChunk.
func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) {
	encoded, err := msgpack.Marshal(ttsChunkStruct())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m TTSAudioChunk
		if err := msgpack.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecode_TTSChunk_Protobuf measures proto.Unmarshal of the encoded
// ttsChunkProto fixture into a pb.TTSAudioChunk.
func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
	encoded, err := proto.Marshal(ttsChunkProto())
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	// b.Loop resets the timer on its first call, excluding the setup above.
	for b.Loop() {
		var m pb.TTSAudioChunk
		if err := proto.Unmarshal(encoded, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Roundtrip benchmarks (encode + decode)
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkRoundtrip_ChatRequest_MsgpackMap measures one msgpack encode of the
// chatRequestMap fixture followed by a decode into a map[string]any.
func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) {
	data := chatRequestMap()
	// b.Loop resets the timer on its first call; no b.ResetTimer is needed.
	for b.Loop() {
		enc, err := msgpack.Marshal(data)
		if err != nil {
			b.Fatal(err)
		}
		var dec map[string]any
		if err := msgpack.Unmarshal(enc, &dec); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkRoundtrip_ChatRequest_MsgpackStruct measures one msgpack encode of
// the chatRequestStruct fixture followed by a decode into a ChatRequest.
func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) {
	data := chatRequestStruct()
	// b.Loop resets the timer on its first call; no b.ResetTimer is needed.
	for b.Loop() {
		enc, err := msgpack.Marshal(data)
		if err != nil {
			b.Fatal(err)
		}
		var dec ChatRequest
		if err := msgpack.Unmarshal(enc, &dec); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkRoundtrip_ChatRequest_Protobuf measures one protobuf encode of the
// chatRequestProto fixture followed by a decode into a pb.ChatRequest.
func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
	data := chatRequestProto()
	// b.Loop resets the timer on its first call; no b.ResetTimer is needed.
	for b.Loop() {
		enc, err := proto.Marshal(data)
		if err != nil {
			b.Fatal(err)
		}
		var dec pb.ChatRequest
		if err := proto.Unmarshal(enc, &dec); err != nil {
			b.Fatal(err)
		}
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Typed struct unit tests — verify roundtrip correctness
// ────────────────────────────────────────────────────────────────────────────
// TestRoundtrip_ChatRequest encodes the chatRequestStruct fixture with
// msgpack, decodes it into a fresh ChatRequest and spot-checks RequestID,
// Message, TopK, Premium and the EffectiveQuery helper.
func TestRoundtrip_ChatRequest(t *testing.T) {
	want := chatRequestStruct()

	encoded, err := msgpack.Marshal(want)
	if err != nil {
		t.Fatal(err)
	}

	var got ChatRequest
	err = msgpack.Unmarshal(encoded, &got)
	if err != nil {
		t.Fatal(err)
	}

	if got.RequestID != want.RequestID {
		t.Errorf("RequestID = %q, want %q", got.RequestID, want.RequestID)
	}
	if got.Message != want.Message {
		t.Errorf("Message = %q, want %q", got.Message, want.Message)
	}
	if got.TopK != want.TopK {
		t.Errorf("TopK = %d, want %d", got.TopK, want.TopK)
	}
	if got.Premium != want.Premium {
		t.Errorf("Premium = %v, want %v", got.Premium, want.Premium)
	}
	// The helper's result is compared against the fixture's Message.
	if got.EffectiveQuery() != want.Message {
		t.Errorf("EffectiveQuery() = %q, want %q", got.EffectiveQuery(), want.Message)
	}
}
// TestRoundtrip_VoiceResponse encodes the voiceResponseStruct fixture with
// msgpack, decodes it into a fresh VoiceResponse and checks RequestID, the
// audio payload (length and content) and Transcription.
func TestRoundtrip_VoiceResponse(t *testing.T) {
	orig := voiceResponseStruct()
	data, err := msgpack.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}
	var dec VoiceResponse
	if err := msgpack.Unmarshal(data, &dec); err != nil {
		t.Fatal(err)
	}
	if dec.RequestID != orig.RequestID {
		t.Errorf("RequestID mismatch")
	}
	if len(dec.Audio) != len(orig.Audio) {
		t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
	} else if string(dec.Audio) != string(orig.Audio) {
		// A length check alone would let corrupted bytes through; the string
		// conversion compares content without needing an extra import.
		t.Errorf("Audio content differs after roundtrip")
	}
	if dec.Transcription != orig.Transcription {
		t.Errorf("Transcription mismatch")
	}
}
// TestRoundtrip_TTSAudioChunk encodes the ttsChunkStruct fixture with msgpack,
// decodes it into a fresh TTSAudioChunk and checks SessionID, ChunkIndex, the
// audio payload (length and content) and SampleRate.
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
	orig := ttsChunkStruct()
	data, err := msgpack.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}
	var dec TTSAudioChunk
	if err := msgpack.Unmarshal(data, &dec); err != nil {
		t.Fatal(err)
	}
	if dec.SessionID != orig.SessionID {
		t.Errorf("SessionID mismatch")
	}
	if dec.ChunkIndex != orig.ChunkIndex {
		t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex)
	}
	if len(dec.Audio) != len(orig.Audio) {
		t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
	} else if string(dec.Audio) != string(orig.Audio) {
		// A length check alone would let corrupted bytes through; the string
		// conversion compares content without needing an extra import.
		t.Errorf("Audio content differs after roundtrip")
	}
	if dec.SampleRate != orig.SampleRate {
		t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate)
	}
}
// TestRoundtrip_PipelineTrigger encodes a PipelineTrigger with msgpack,
// decodes it again and checks RequestID, Pipeline and the "source" parameter.
func TestRoundtrip_PipelineTrigger(t *testing.T) {
	orig := PipelineTrigger{
		RequestID:  "pip-001",
		Pipeline:   "document-ingestion",
		Parameters: map[string]any{"source": "s3://bucket/data"},
	}
	data, err := msgpack.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}
	var dec PipelineTrigger
	if err := msgpack.Unmarshal(data, &dec); err != nil {
		t.Fatal(err)
	}
	if dec.RequestID != orig.RequestID {
		t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID)
	}
	if dec.Pipeline != orig.Pipeline {
		t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline)
	}
	// Both sides hold the value as `any`; the comparison is on the dynamic
	// string values.
	if dec.Parameters["source"] != orig.Parameters["source"] {
		t.Errorf("Parameters[source] mismatch")
	}
}
// TestRoundtrip_STTTranscription encodes a fully populated STTTranscription
// with msgpack, decodes it again and verifies that every field survives.
func TestRoundtrip_STTTranscription(t *testing.T) {
	orig := STTTranscription{
		SessionID:        "stt-001",
		Transcript:       "hello world",
		Sequence:         5,
		IsPartial:        false,
		IsFinal:          true,
		Timestamp:        time.Now().Unix(),
		SpeakerID:        "speaker-1",
		HasVoiceActivity: true,
		State:            "listening",
	}
	data, err := msgpack.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}
	var dec STTTranscription
	if err := msgpack.Unmarshal(data, &dec); err != nil {
		t.Fatal(err)
	}
	if dec.Transcript != orig.Transcript {
		t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript)
	}
	if dec.IsFinal != orig.IsFinal {
		t.Error("IsFinal mismatch")
	}
	// STTTranscription contains only comparable fields, so one struct
	// comparison covers the fields not checked individually above.
	if dec != orig {
		t.Errorf("STTTranscription roundtrip mismatch: got %+v, want %+v", dec, orig)
	}
}
// TestRoundtrip_ErrorResponse encodes an ErrorResponse with msgpack, decodes
// it again and verifies that all three fields carry the original values.
func TestRoundtrip_ErrorResponse(t *testing.T) {
	want := ErrorResponse{
		Error:   true,
		Message: "something broke",
		Type:    "InternalError",
	}

	encoded, err := msgpack.Marshal(want)
	if err != nil {
		t.Fatal(err)
	}

	var got ErrorResponse
	err = msgpack.Unmarshal(encoded, &got)
	if err != nil {
		t.Fatal(err)
	}

	// ErrorResponse has exactly these three comparable fields, so a struct
	// comparison is equivalent to checking each one.
	if got != want {
		t.Errorf("ErrorResponse roundtrip mismatch: %+v", got)
	}
}
// TestTimestamp checks that Timestamp returns a value within one second of
// the current Unix time.
func TestTimestamp(t *testing.T) {
	got := Timestamp()
	want := time.Now().Unix()
	// A drift of more than one second in either direction is a failure.
	if drift := want - got; drift > 1 || drift < -1 {
		t.Errorf("Timestamp() = %d, expected ~%d", got, want)
	}
}

224
messages/messages.go Normal file
View File

@@ -0,0 +1,224 @@
// Package messages defines typed NATS message structs for all services.
//
// Using typed structs instead of map[string]any provides compile-time safety,
// a smaller wire size (fields tagged omitempty are dropped when empty), and
// faster encode/decode by avoiding interface{} boxing. The msgpack tags use
// the same snake_case names as the JSON tags.
//
// Audio data uses raw []byte instead of base64-encoded strings — msgpack
// supports binary natively, eliminating the 33% base64 overhead.
package messages
import "time"
// ────────────────────────────────────────────────────────────────────────────
// Pipeline Bridge
// ────────────────────────────────────────────────────────────────────────────
// PipelineTrigger is the request to start a pipeline.
//
// RequestID and Pipeline are always encoded. Parameters is tagged omitempty,
// so it is left out of the msgpack and JSON encodings when empty.
// NOTE(review): Parameters values are untyped (any); which keys a pipeline
// expects is not visible here — confirm with the consumer.
type PipelineTrigger struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Pipeline string `msgpack:"pipeline" json:"pipeline"`
Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"`
}
// PipelineStatus is the response / status update for a pipeline run.
//
// RequestID and Status are always encoded; every other field is tagged
// omitempty and is dropped from the msgpack and JSON encodings when empty.
// SubmittedAt is a string; its time format is not visible here — TODO confirm.
type PipelineStatus struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Status string `msgpack:"status" json:"status"`
RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"`
Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"`
Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"`
SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"`
}
// ────────────────────────────────────────────────────────────────────────────
// Chat Handler
// ────────────────────────────────────────────────────────────────────────────
// ChatRequest is an incoming chat message.
//
// RequestID, UserID and Message are always encoded; every other field is
// tagged omitempty, so false flags, a zero TopK and empty strings are dropped
// from the msgpack and JSON encodings. Message and Query both carry user
// text; EffectiveQuery picks between them.
type ChatRequest struct {
RequestID string `msgpack:"request_id" json:"request_id"`
UserID string `msgpack:"user_id" json:"user_id"`
Message string `msgpack:"message" json:"message"`
Query string `msgpack:"query,omitempty" json:"query,omitempty"`
Premium bool `msgpack:"premium,omitempty" json:"premium,omitempty"`
EnableRAG bool `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"`
EnableReranker bool `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"`
EnableStreaming bool `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"`
TopK int `msgpack:"top_k,omitempty" json:"top_k,omitempty"`
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
EnableTTS bool `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"`
SystemPrompt string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"`
}
// EffectiveQuery returns the text to act on: Message when it is non-empty,
// otherwise Query (which may itself be empty).
func (c *ChatRequest) EffectiveQuery() string {
	if c.Message == "" {
		return c.Query
	}
	return c.Message
}
// ChatResponse is the full reply to a chat request.
//
// RAGSources, Audio and Error are tagged omitempty and are dropped from the
// encodings when empty; the other fields are always encoded. Audio is a raw
// byte slice.
// NOTE(review): Response and ResponseText are both always sent; how they
// differ is not visible here — confirm with the producer.
type ChatResponse struct {
UserID string `msgpack:"user_id" json:"user_id"`
Response string `msgpack:"response" json:"response"`
ResponseText string `msgpack:"response_text" json:"response_text"`
UsedRAG bool `msgpack:"used_rag" json:"used_rag"`
RAGSources []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
Success bool `msgpack:"success" json:"success"`
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
}
// ChatStreamChunk is a single streaming chunk from an LLM response.
//
// No field is tagged omitempty, so all five are always encoded.
// Timestamp is an int64; presumably Unix seconds as returned by Timestamp —
// TODO confirm.
type ChatStreamChunk struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Type string `msgpack:"type" json:"type"`
Content string `msgpack:"content" json:"content"`
Done bool `msgpack:"done" json:"done"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// ────────────────────────────────────────────────────────────────────────────
// Voice Assistant
// ────────────────────────────────────────────────────────────────────────────
// VoiceRequest is an incoming voice-to-voice request.
//
// Audio is a raw byte slice and, like RequestID, is always encoded. Language
// and Collection are tagged omitempty and are dropped when empty.
type VoiceRequest struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Audio []byte `msgpack:"audio" json:"audio"`
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
}
// VoiceResponse is the reply to a voice request.
//
// RequestID, Response and Audio (raw bytes) are always encoded.
// Transcription, Sources and Error are tagged omitempty and are dropped from
// the encodings when empty.
type VoiceResponse struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Response string `msgpack:"response" json:"response"`
Audio []byte `msgpack:"audio" json:"audio"`
Transcription string `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
Sources []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
}
// DocumentSource is a RAG search result source.
//
// Both fields are always encoded. Score is a float64; its range and whether
// higher means more relevant are not visible here — TODO confirm.
type DocumentSource struct {
Text string `msgpack:"text" json:"text"`
Score float64 `msgpack:"score" json:"score"`
}
// ────────────────────────────────────────────────────────────────────────────
// TTS Module
// ────────────────────────────────────────────────────────────────────────────
// TTSRequest is a text-to-speech synthesis request.
//
// Only Text is always encoded; the other fields are tagged omitempty.
// NOTE(review): SpeakerWavB64 is a string and, by its name and tag, appears
// to hold base64 text rather than raw bytes — confirm against the producer.
type TTSRequest struct {
Text string `msgpack:"text" json:"text"`
Speaker string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
Stream bool `msgpack:"stream,omitempty" json:"stream,omitempty"`
}
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
//
// No field is tagged omitempty, so all seven are always encoded, including a
// zero ChunkIndex and a false IsLast. Audio is a raw byte slice.
// Whether ChunkIndex is zero- or one-based is not visible here — TODO confirm.
type TTSAudioChunk struct {
SessionID string `msgpack:"session_id" json:"session_id"`
ChunkIndex int `msgpack:"chunk_index" json:"chunk_index"`
TotalChunks int `msgpack:"total_chunks" json:"total_chunks"`
Audio []byte `msgpack:"audio" json:"audio"`
IsLast bool `msgpack:"is_last" json:"is_last"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
}
// TTSFullResponse is a non-streamed TTS response (whole audio).
//
// All four fields are always encoded. Audio is a raw byte slice.
type TTSFullResponse struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Audio []byte `msgpack:"audio" json:"audio"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
}
// TTSStatus is a TTS processing status update.
//
// All four fields are always encoded. The set of Status values is not
// visible here — TODO confirm with the producer.
type TTSStatus struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Status string `msgpack:"status" json:"status"`
Message string `msgpack:"message" json:"message"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// TTSVoiceListResponse is the reply to a voice list request.
//
// All four fields are always encoded; CustomVoices has no omitempty tag, so
// it is encoded even when empty.
type TTSVoiceListResponse struct {
DefaultSpeaker string `msgpack:"default_speaker" json:"default_speaker"`
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
LastRefresh int64 `msgpack:"last_refresh" json:"last_refresh"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// TTSVoiceInfo is summary info about a custom voice.
//
// All four fields are strings and are always encoded. The format of
// CreatedAt is not visible here — TODO confirm.
type TTSVoiceInfo struct {
Name string `msgpack:"name" json:"name"`
Language string `msgpack:"language" json:"language"`
ModelType string `msgpack:"model_type" json:"model_type"`
CreatedAt string `msgpack:"created_at" json:"created_at"`
}
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
//
// All three fields are always encoded. Count presumably equals
// len(CustomVoices) — TODO confirm with the producer.
type TTSVoiceRefreshResponse struct {
Count int `msgpack:"count" json:"count"`
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// ────────────────────────────────────────────────────────────────────────────
// STT Module
// ────────────────────────────────────────────────────────────────────────────
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
//
// Type is always encoded. Audio (raw bytes), State and SpeakerID are tagged
// omitempty and are dropped when empty. Which Type values exist and which
// fields accompany each one is not visible here — TODO confirm.
type STTStreamMessage struct {
Type string `msgpack:"type" json:"type"`
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
State string `msgpack:"state,omitempty" json:"state,omitempty"`
SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
}
// STTTranscription is the transcription result published by the STT module.
//
// No field is tagged omitempty, so all nine are always encoded. Every field
// is a string, int, int64 or bool, which makes the struct comparable with ==.
type STTTranscription struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Transcript string `msgpack:"transcript" json:"transcript"`
Sequence int `msgpack:"sequence" json:"sequence"`
IsPartial bool `msgpack:"is_partial" json:"is_partial"`
IsFinal bool `msgpack:"is_final" json:"is_final"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
HasVoiceActivity bool `msgpack:"has_voice_activity" json:"has_voice_activity"`
State string `msgpack:"state" json:"state"`
}
// STTInterrupt is published when the STT module detects a user interrupt.
//
// All four fields are always encoded. The set of Type values is not visible
// here — TODO confirm with the producer.
type STTInterrupt struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Type string `msgpack:"type" json:"type"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
}
// ────────────────────────────────────────────────────────────────────────────
// Common / Error
// ────────────────────────────────────────────────────────────────────────────
// ErrorResponse is the standard error reply from any handler.
//
// All three fields are always encoded. Error is a bool flag, Message is the
// text, and Type is a string label (the test in this package uses
// "InternalError").
type ErrorResponse struct {
Error bool `msgpack:"error" json:"error"`
Message string `msgpack:"message" json:"message"`
Type string `msgpack:"type" json:"type"`
}
// Timestamp reports the current time as whole seconds since the Unix epoch.
// It is a convenience for filling the Timestamp field of outgoing messages.
func Timestamp() int64 {
	now := time.Now()
	return now.Unix()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,174 @@
syntax = "proto3";
package messages;
option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto";
// ── Pipeline Bridge ─────────────────────────────────────────────────────────
// PipelineTrigger is the request to start a pipeline.
// NOTE(review): parameter values are strings here, while the Go struct of the
// same name declares map[string]any — confirm non-string values are not needed.
message PipelineTrigger {
string request_id = 1;
string pipeline = 2;
map<string, string> parameters = 3;
}
// PipelineStatus is the response / status update for a pipeline run.
// Field names match the msgpack tags of the Go struct of the same name.
message PipelineStatus {
string request_id = 1;
string status = 2;
string run_id = 3;
string engine = 4;
string pipeline = 5;
string submitted_at = 6;
string error = 7;
repeated string available_pipelines = 8;
}
// ── Chat Handler ────────────────────────────────────────────────────────────
// ChatRequest is an incoming chat message.
// Field names match the msgpack tags of the Go struct of the same name;
// top_k is int32 here while the Go field is int.
message ChatRequest {
string request_id = 1;
string user_id = 2;
string message = 3;
string query = 4;
bool premium = 5;
bool enable_rag = 6;
bool enable_reranker = 7;
bool enable_streaming = 8;
int32 top_k = 9;
string collection = 10;
bool enable_tts = 11;
string system_prompt = 12;
string response_subject = 13;
}
// ChatResponse is the full reply to a chat request.
// audio is carried as raw bytes.
message ChatResponse {
string user_id = 1;
string response = 2;
string response_text = 3;
bool used_rag = 4;
repeated string rag_sources = 5;
bool success = 6;
bytes audio = 7;
string error = 8;
}
// ChatStreamChunk is a single streaming chunk from an LLM response.
message ChatStreamChunk {
string request_id = 1;
string type = 2;
string content = 3;
bool done = 4;
int64 timestamp = 5;
}
// ── Voice Assistant ─────────────────────────────────────────────────────────
// VoiceRequest is an incoming voice-to-voice request.
// audio is carried as raw bytes.
message VoiceRequest {
string request_id = 1;
bytes audio = 2;
string language = 3;
string collection = 4;
}
// VoiceResponse is the reply to a voice request.
// audio is carried as raw bytes; sources is a list of DocumentSource entries.
message VoiceResponse {
string request_id = 1;
string response = 2;
bytes audio = 3;
string transcription = 4;
repeated DocumentSource sources = 5;
string error = 6;
}
// DocumentSource is a RAG search result source: a text and its score.
message DocumentSource {
string text = 1;
double score = 2;
}
// ── TTS Module ──────────────────────────────────────────────────────────────
// TTSRequest is a text-to-speech synthesis request.
// NOTE(review): speaker_wav_b64 is a string; by its name it appears to hold
// base64 text rather than raw bytes — confirm.
message TTSRequest {
string text = 1;
string speaker = 2;
string language = 3;
string speaker_wav_b64 = 4;
bool stream = 5;
}
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
// audio is carried as raw bytes; the integer fields are int32 here while the
// Go struct of the same name uses int.
message TTSAudioChunk {
string session_id = 1;
int32 chunk_index = 2;
int32 total_chunks = 3;
bytes audio = 4;
bool is_last = 5;
int64 timestamp = 6;
int32 sample_rate = 7;
}
// TTSFullResponse is a non-streamed TTS response (whole audio).
message TTSFullResponse {
string session_id = 1;
bytes audio = 2;
int64 timestamp = 3;
int32 sample_rate = 4;
}
// TTSStatus is a TTS processing status update.
message TTSStatus {
string session_id = 1;
string status = 2;
string message = 3;
int64 timestamp = 4;
}
// TTSVoiceInfo is summary info about a custom voice.
message TTSVoiceInfo {
string name = 1;
string language = 2;
string model_type = 3;
string created_at = 4;
}
// TTSVoiceListResponse is the reply to a voice list request.
message TTSVoiceListResponse {
string default_speaker = 1;
repeated TTSVoiceInfo custom_voices = 2;
int64 last_refresh = 3;
int64 timestamp = 4;
}
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
message TTSVoiceRefreshResponse {
int32 count = 1;
repeated TTSVoiceInfo custom_voices = 2;
int64 timestamp = 3;
}
// ── STT Module ──────────────────────────────────────────────────────────────
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
// audio is carried as raw bytes.
message STTStreamMessage {
string type = 1;
bytes audio = 2;
string state = 3;
string speaker_id = 4;
}
// STTTranscription is the transcription result published by the STT module.
// sequence is int32 here while the Go struct of the same name uses int.
message STTTranscription {
string session_id = 1;
string transcript = 2;
int32 sequence = 3;
bool is_partial = 4;
bool is_final = 5;
int64 timestamp = 6;
string speaker_id = 7;
bool has_voice_activity = 8;
string state = 9;
}
// STTInterrupt is published when the STT module detects a user interrupt.
message STTInterrupt {
string session_id = 1;
string type = 2;
int64 timestamp = 3;
string speaker_id = 4;
}
// ── Common ──────────────────────────────────────────────────────────────────
// ErrorResponse is the standard error reply from any handler.
message ErrorResponse {
bool error = 1;
string message = 2;
string type = 3;
}

256
natsutil/natsutil_test.go Normal file
View File

@@ -0,0 +1,256 @@
package natsutil
import (
"testing"
"github.com/vmihailenco/msgpack/v5"
)
// ────────────────────────────────────────────────────────────────────────────
// DecodeMsgpackMap tests
// ────────────────────────────────────────────────────────────────────────────
// TestDecodeMsgpackMap_Roundtrip encodes a map with string, bool and int64
// values and verifies DecodeMsgpackMap returns every one of them unchanged.
func TestDecodeMsgpackMap_Roundtrip(t *testing.T) {
	orig := map[string]any{
		"request_id": "req-001",
		"user_id":    "user-42",
		"premium":    true,
		"top_k":      int64(10), // msgpack decodes ints as int64
	}
	data, err := msgpack.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}
	decoded, err := DecodeMsgpackMap(data)
	if err != nil {
		t.Fatal(err)
	}
	if decoded["request_id"] != "req-001" {
		t.Errorf("request_id = %v", decoded["request_id"])
	}
	if decoded["user_id"] != "user-42" {
		t.Errorf("user_id = %v", decoded["user_id"])
	}
	if decoded["premium"] != true {
		t.Errorf("premium = %v", decoded["premium"])
	}
	// Comparing `any` values checks the dynamic type too, so this pins the
	// int64 expectation stated on the fixture above.
	if decoded["top_k"] != int64(10) {
		t.Errorf("top_k = %v (%T), want int64(10)", decoded["top_k"], decoded["top_k"])
	}
}
// TestDecodeMsgpackMap_Empty verifies that an encoded empty map decodes
// without error to a map with no entries.
func TestDecodeMsgpackMap_Empty(t *testing.T) {
	data, err := msgpack.Marshal(map[string]any{})
	if err != nil {
		t.Fatal(err)
	}
	m, err := DecodeMsgpackMap(data)
	if err != nil {
		t.Fatal(err)
	}
	if len(m) != 0 {
		t.Errorf("expected empty map, got %v", m)
	}
}
// TestDecodeMsgpackMap_InvalidData verifies that DecodeMsgpackMap reports an
// error when handed bytes that do not encode a map.
func TestDecodeMsgpackMap_InvalidData(t *testing.T) {
	garbage := []byte{0xFF, 0xFE}
	if _, err := DecodeMsgpackMap(garbage); err == nil {
		t.Error("expected error for invalid msgpack data")
	}
}
// ────────────────────────────────────────────────────────────────────────────
// DecodeMsgpack (typed struct) tests
// ────────────────────────────────────────────────────────────────────────────
// testMessage is a small typed fixture with one string pair, one int and one
// bool field, used by the typed roundtrip tests and the struct benchmarks.
type testMessage struct {
RequestID string `msgpack:"request_id"`
UserID string `msgpack:"user_id"`
Count int `msgpack:"count"`
Active bool `msgpack:"active"`
}
// TestDecodeMsgpackTyped_Roundtrip encodes a testMessage with msgpack,
// decodes it into a fresh value and verifies all four fields.
func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
	orig := testMessage{
		RequestID: "req-typed-001",
		UserID:    "user-7",
		Count:     42,
		Active:    true,
	}
	data, err := msgpack.Marshal(orig)
	if err != nil {
		t.Fatal(err)
	}
	// Simulate nats.Msg data decoding.
	var decoded testMessage
	if err := msgpack.Unmarshal(data, &decoded); err != nil {
		t.Fatal(err)
	}
	if decoded.RequestID != orig.RequestID {
		t.Errorf("RequestID = %q, want %q", decoded.RequestID, orig.RequestID)
	}
	if decoded.UserID != orig.UserID {
		t.Errorf("UserID = %q, want %q", decoded.UserID, orig.UserID)
	}
	if decoded.Count != orig.Count {
		t.Errorf("Count = %d, want %d", decoded.Count, orig.Count)
	}
	if decoded.Active != orig.Active {
		t.Errorf("Active = %v, want %v", decoded.Active, orig.Active)
	}
}
// TestTypedStructDecodesMapEncoding verifies that a typed struct can be
// decoded from data that was encoded as map[string]any (backwards compat).
// All four keys of the map are checked against the decoded struct.
func TestTypedStructDecodesMapEncoding(t *testing.T) {
	// Encode as map (the old way).
	mapData := map[string]any{
		"request_id": "req-compat",
		"user_id":    "user-compat",
		"count":      int64(99),
		"active":     false,
	}
	data, err := msgpack.Marshal(mapData)
	if err != nil {
		t.Fatal(err)
	}
	// Decode into typed struct (the new way).
	var msg testMessage
	if err := msgpack.Unmarshal(data, &msg); err != nil {
		t.Fatal(err)
	}
	if msg.RequestID != "req-compat" {
		t.Errorf("RequestID = %q", msg.RequestID)
	}
	if msg.UserID != "user-compat" {
		t.Errorf("UserID = %q, want %q", msg.UserID, "user-compat")
	}
	if msg.Count != 99 {
		t.Errorf("Count = %d, want 99", msg.Count)
	}
	if msg.Active {
		t.Errorf("Active = %v, want false", msg.Active)
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Binary data tests (audio []byte in msgpack)
// ────────────────────────────────────────────────────────────────────────────
// audioMessage is a typed fixture whose Audio field is a raw byte slice; the
// binary-data tests and the audio benchmark use it.
type audioMessage struct {
SessionID string `msgpack:"session_id"`
Audio []byte `msgpack:"audio"`
SampleRate int `msgpack:"sample_rate"`
}
// TestBinaryDataRoundtrip fills a 32768-byte payload with a repeating
// 0..255 pattern, sends it through msgpack inside an audioMessage and
// verifies the decoded bytes match one by one.
func TestBinaryDataRoundtrip(t *testing.T) {
	const size = 32768
	payload := make([]byte, size)
	for idx := 0; idx < size; idx++ {
		payload[idx] = byte(idx % 256)
	}

	sent := audioMessage{
		SessionID:  "sess-audio-001",
		Audio:      payload,
		SampleRate: 24000,
	}
	wire, err := msgpack.Marshal(sent)
	if err != nil {
		t.Fatal(err)
	}

	var received audioMessage
	err = msgpack.Unmarshal(wire, &received)
	if err != nil {
		t.Fatal(err)
	}

	// Stop at the length check: indexing sent.Audio below would be unsafe if
	// the decoded slice were longer.
	if got, want := len(received.Audio), len(sent.Audio); got != want {
		t.Fatalf("audio len = %d, want %d", got, want)
	}
	for idx, got := range received.Audio {
		if want := sent.Audio[idx]; got != want {
			t.Fatalf("audio[%d] = %d, want %d", idx, got, want)
		}
	}
}
// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string.
// It logs both encoded sizes and fails if the raw-bytes struct encoding is not
// smaller than the base64-in-map encoding.
func TestBinaryVsBase64Size(t *testing.T) {
	audio := make([]byte, 16384)

	// Old approach: base64 string in map. Only the size matters here, so a
	// placeholder of the padded base64 length (4 output bytes for every 3
	// input bytes, rounded up) stands in for real encoded text.
	b64Placeholder := make([]byte, 4*((len(audio)+2)/3))
	mapMsg := map[string]any{
		"session_id": "sess-1",
		"audio_b64":  string(b64Placeholder),
	}
	mapData, err := msgpack.Marshal(mapMsg)
	if err != nil {
		t.Fatal(err)
	}

	// New approach: raw bytes in struct.
	structMsg := audioMessage{
		SessionID: "sess-1",
		Audio:     audio,
	}
	structData, err := msgpack.Marshal(structMsg)
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)",
		len(mapData), len(structData),
		100*(1-float64(len(structData))/float64(len(mapData))))

	if len(structData) >= len(mapData) {
		t.Errorf("raw-bytes encoding is %d bytes, want fewer than the base64 encoding's %d",
			len(structData), len(mapData))
	}
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmarks
// ────────────────────────────────────────────────────────────────────────────
// BenchmarkEncodeMap measures msgpack.Marshal on a five-key map[string]any.
func BenchmarkEncodeMap(b *testing.B) {
	data := map[string]any{
		"request_id": "req-bench",
		"user_id":    "user-bench",
		"message":    "What is the weather today?",
		"premium":    true,
		"top_k":      10,
	}
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkEncodeStruct measures msgpack.Marshal on a testMessage value.
func BenchmarkEncodeStruct(b *testing.B) {
	data := testMessage{
		RequestID: "req-bench",
		UserID:    "user-bench",
		Count:     10,
		Active:    true,
	}
	for b.Loop() {
		if _, err := msgpack.Marshal(data); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecodeMap measures msgpack.Unmarshal of an encoded five-key map
// into a map[string]any.
func BenchmarkDecodeMap(b *testing.B) {
	raw, err := msgpack.Marshal(map[string]any{
		"request_id": "req-bench",
		"user_id":    "user-bench",
		"message":    "What is the weather today?",
		"premium":    true,
		"top_k":      10,
	})
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	for b.Loop() {
		var m map[string]any
		if err := msgpack.Unmarshal(raw, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecodeStruct measures msgpack.Unmarshal of an encoded testMessage
// into a testMessage value.
func BenchmarkDecodeStruct(b *testing.B) {
	raw, err := msgpack.Marshal(testMessage{
		RequestID: "req-bench",
		UserID:    "user-bench",
		Count:     10,
		Active:    true,
	})
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	for b.Loop() {
		var m testMessage
		if err := msgpack.Unmarshal(raw, &m); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkDecodeAudio32KB measures msgpack.Unmarshal of an audioMessage
// carrying a 32768-byte Audio payload.
func BenchmarkDecodeAudio32KB(b *testing.B) {
	raw, err := msgpack.Marshal(audioMessage{
		SessionID:  "s1",
		Audio:      make([]byte, 32768),
		SampleRate: 24000,
	})
	if err != nil {
		// Without valid input the loop would time decoding of garbage.
		b.Fatal(err)
	}
	for b.Loop() {
		var m audioMessage
		if err := msgpack.Unmarshal(raw, &m); err != nil {
			b.Fatal(err)
		}
	}
}