feature/go-handler-refactor #1
@@ -1,4 +1,8 @@
|
|||||||
// Package clients provides HTTP client wrappers for AI/ML backend services.
|
// Package clients provides HTTP client wrappers for AI/ML backend services.
|
||||||
|
//
|
||||||
|
// All clients share a single [http.Transport] for connection pooling across
|
||||||
|
// the process. Request and response bodies are serialized through pooled
|
||||||
|
// [bytes.Buffer]s to reduce GC pressure.
|
||||||
package clients
|
package clients
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -10,10 +14,42 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// httpClient is a shared interface for all service clients.
|
// ─── Shared transport & buffer pool ─────────────────────────────────────────
|
||||||
|
|
||||||
|
// SharedTransport is the process-wide HTTP transport used by every service
|
||||||
|
// client. Tweak pool sizes here rather than creating per-client transports.
|
||||||
|
var SharedTransport = &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
DisableCompression: true, // in-cluster traffic; skip gzip overhead
|
||||||
|
}
|
||||||
|
|
||||||
|
// bufPool recycles *bytes.Buffer to avoid per-request allocations.
|
||||||
|
var bufPool = sync.Pool{
|
||||||
|
New: func() any { return new(bytes.Buffer) },
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBuf() *bytes.Buffer {
|
||||||
|
buf := bufPool.Get().(*bytes.Buffer)
|
||||||
|
buf.Reset()
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func putBuf(buf *bytes.Buffer) {
|
||||||
|
if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bufPool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── httpClient base ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// httpClient is the shared base for all service clients.
|
||||||
type httpClient struct {
|
type httpClient struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
baseURL string
|
baseURL string
|
||||||
@@ -21,17 +57,21 @@ type httpClient struct {
|
|||||||
|
|
||||||
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
|
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
|
||||||
return &httpClient{
|
return &httpClient{
|
||||||
client: &http.Client{Timeout: timeout},
|
client: &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
Transport: SharedTransport,
|
||||||
|
},
|
||||||
baseURL: baseURL,
|
baseURL: baseURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
|
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
|
||||||
data, err := json.Marshal(body)
|
buf := getBuf()
|
||||||
if err != nil {
|
defer putBuf(buf)
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
||||||
return nil, fmt.Errorf("marshal: %w", err)
|
return nil, fmt.Errorf("marshal: %w", err)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(data))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -56,8 +96,9 @@ func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
|
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
|
||||||
var buf bytes.Buffer
|
buf := getBuf()
|
||||||
w := multipart.NewWriter(&buf)
|
defer putBuf(buf)
|
||||||
|
w := multipart.NewWriter(buf)
|
||||||
part, err := w.CreateFormFile(fieldName, fileName)
|
part, err := w.CreateFormFile(fieldName, fileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -70,7 +111,7 @@ func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName s
|
|||||||
}
|
}
|
||||||
_ = w.Close()
|
_ = w.Close()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, &buf)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -84,10 +125,17 @@ func (h *httpClient) do(req *http.Request) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
buf := getBuf()
|
||||||
|
defer putBuf(buf)
|
||||||
|
if _, err := io.Copy(buf, resp.Body); err != nil {
|
||||||
return nil, fmt.Errorf("read body: %w", err)
|
return nil, fmt.Errorf("read body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return a copy so the pooled buffer can be safely recycled.
|
||||||
|
body := make([]byte, buf.Len())
|
||||||
|
copy(body, buf.Bytes())
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
|
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
@@ -100,7 +148,7 @@ func (h *httpClient) healthCheck(ctx context.Context) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Embeddings Client ---
|
// ─── Embeddings Client ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
|
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
|
||||||
type EmbeddingsClient struct {
|
type EmbeddingsClient struct {
|
||||||
@@ -157,7 +205,7 @@ func (c *EmbeddingsClient) Health(ctx context.Context) bool {
|
|||||||
return c.healthCheck(ctx)
|
return c.healthCheck(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Reranker Client ---
|
// ─── Reranker Client ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// RerankerClient calls the reranker service (BGE Reranker).
|
// RerankerClient calls the reranker service (BGE Reranker).
|
||||||
type RerankerClient struct {
|
type RerankerClient struct {
|
||||||
@@ -214,7 +262,7 @@ func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []s
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- LLM Client ---
|
// ─── LLM Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// LLMClient calls the vLLM-compatible LLM service.
|
// LLMClient calls the vLLM-compatible LLM service.
|
||||||
type LLMClient struct {
|
type LLMClient struct {
|
||||||
@@ -287,7 +335,7 @@ func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
|
|||||||
return msgs
|
return msgs
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- TTS Client ---
|
// ─── TTS Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// TTSClient calls the TTS service (Coqui XTTS).
|
// TTSClient calls the TTS service (Coqui XTTS).
|
||||||
type TTSClient struct {
|
type TTSClient struct {
|
||||||
@@ -318,7 +366,7 @@ func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker stri
|
|||||||
return c.getRaw(ctx, "/api/tts", params)
|
return c.getRaw(ctx, "/api/tts", params)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- STT Client ---
|
// ─── STT Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// STTClient calls the Whisper STT service.
|
// STTClient calls the Whisper STT service.
|
||||||
type STTClient struct {
|
type STTClient struct {
|
||||||
@@ -364,7 +412,7 @@ func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language strin
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Milvus Client ---
|
// ─── Milvus Client ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
|
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
|
||||||
// For the Go port we use the Milvus Go SDK.
|
// For the Go port we use the Milvus Go SDK.
|
||||||
|
|||||||
506
clients/clients_test.go
Normal file
506
clients/clients_test.go
Normal file
@@ -0,0 +1,506 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Shared infrastructure tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestSharedTransport(t *testing.T) {
|
||||||
|
// All clients created via newHTTPClient should share the same transport.
|
||||||
|
c1 := newHTTPClient("http://a:8000", 10*time.Second)
|
||||||
|
c2 := newHTTPClient("http://b:9000", 30*time.Second)
|
||||||
|
|
||||||
|
if c1.client.Transport != c2.client.Transport {
|
||||||
|
t.Error("clients should share the same http.Transport")
|
||||||
|
}
|
||||||
|
if c1.client.Transport != SharedTransport {
|
||||||
|
t.Error("transport should be the package-level SharedTransport")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferPoolGetPut(t *testing.T) {
|
||||||
|
buf := getBuf()
|
||||||
|
if buf == nil {
|
||||||
|
t.Fatal("getBuf returned nil")
|
||||||
|
}
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Error("getBuf should return a reset buffer")
|
||||||
|
}
|
||||||
|
buf.WriteString("hello")
|
||||||
|
putBuf(buf)
|
||||||
|
|
||||||
|
// On re-get, buffer should be reset.
|
||||||
|
buf2 := getBuf()
|
||||||
|
if buf2.Len() != 0 {
|
||||||
|
t.Error("re-acquired buffer should be reset")
|
||||||
|
}
|
||||||
|
putBuf(buf2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferPoolOversizedDiscarded(t *testing.T) {
|
||||||
|
buf := getBuf()
|
||||||
|
// Grow beyond 1 MB threshold.
|
||||||
|
buf.Write(make([]byte, 2<<20))
|
||||||
|
putBuf(buf) // should silently discard
|
||||||
|
|
||||||
|
// Pool should still work — we get a fresh one.
|
||||||
|
buf2 := getBuf()
|
||||||
|
if buf2.Len() != 0 {
|
||||||
|
t.Error("should get a fresh buffer")
|
||||||
|
}
|
||||||
|
putBuf(buf2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferPoolConcurrency(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := range 100 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := getBuf()
|
||||||
|
buf.WriteString(strings.Repeat("x", n))
|
||||||
|
putBuf(buf)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Embeddings client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_Embed(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/embeddings" {
|
||||||
|
t.Errorf("path = %q, want /embeddings", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("method = %s, want POST", r.Method)
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
input, _ := req["input"].([]any)
|
||||||
|
if len(input) != 2 {
|
||||||
|
t.Errorf("input len = %d, want 2", len(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": []map[string]any{
|
||||||
|
{"embedding": []float64{0.1, 0.2, 0.3}},
|
||||||
|
{"embedding": []float64{0.4, 0.5, 0.6}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge")
|
||||||
|
results, err := c.Embed(context.Background(), []string{"hello", "world"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("len(results) = %d, want 2", len(results))
|
||||||
|
}
|
||||||
|
if results[0][0] != 0.1 {
|
||||||
|
t.Errorf("results[0][0] = %f, want 0.1", results[0][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": []map[string]any{
|
||||||
|
{"embedding": []float64{1.0, 2.0}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
vec, err := c.EmbedSingle(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(vec) != 2 || vec[0] != 1.0 {
|
||||||
|
t.Errorf("vec = %v", vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
_, err := c.EmbedSingle(context.Background(), "test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty embedding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_Health(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/health" {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(404)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
if !c.Health(context.Background()) {
|
||||||
|
t.Error("expected healthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Reranker client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestRerankerClient_Rerank(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if req["query"] != "test query" {
|
||||||
|
t.Errorf("query = %v", req["query"])
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"results": []map[string]any{
|
||||||
|
{"index": 1, "relevance_score": 0.95},
|
||||||
|
{"index": 0, "relevance_score": 0.80},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewRerankerClient(ts.URL, 5*time.Second)
|
||||||
|
docs := []string{"Paris is great", "France is in Europe"}
|
||||||
|
results, err := c.Rerank(context.Background(), "test query", docs, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("len = %d", len(results))
|
||||||
|
}
|
||||||
|
if results[0].Score != 0.95 {
|
||||||
|
t.Errorf("score = %f, want 0.95", results[0].Score)
|
||||||
|
}
|
||||||
|
if results[0].Document != "France is in Europe" {
|
||||||
|
t.Errorf("document = %q", results[0].Document)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"results": []map[string]any{
|
||||||
|
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewRerankerClient(ts.URL, 5*time.Second)
|
||||||
|
results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if results[0].Score != 0.77 {
|
||||||
|
t.Errorf("fallback score = %f, want 0.77", results[0].Score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// LLM client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestLLMClient_Generate(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/chat/completions" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
msgs, _ := req["messages"].([]any)
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
t.Error("no messages in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": "Paris is the capital of France."}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.Generate(context.Background(), "capital of France?", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "Paris is the capital of France." {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_GenerateWithContext(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
msgs, _ := req["messages"].([]any)
|
||||||
|
// Should have system + user message
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Errorf("expected 2 messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": "answer with context"}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.Generate(context.Background(), "question", "some context", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "answer with context" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
_, err := c.Generate(context.Background(), "q", "", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty choices")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// TTS client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestTTSClient_Synthesize(t *testing.T) {
|
||||||
|
expected := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/tts" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.URL.Query().Get("text") != "hello world" {
|
||||||
|
t.Errorf("text = %q", r.URL.Query().Get("text"))
|
||||||
|
}
|
||||||
|
w.Write(expected)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewTTSClient(ts.URL, 5*time.Second, "en")
|
||||||
|
audio, err := c.Synthesize(context.Background(), "hello world", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(audio, expected) {
|
||||||
|
t.Errorf("audio = %x, want %x", audio, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("speaker_id") != "alice" {
|
||||||
|
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
|
||||||
|
}
|
||||||
|
w.Write([]byte{0x01})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewTTSClient(ts.URL, 5*time.Second, "en")
|
||||||
|
_, err := c.Synthesize(context.Background(), "hi", "en", "alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// STT client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestSTTClient_Transcribe(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/audio/transcriptions" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
ct := r.Header.Get("Content-Type")
|
||||||
|
if !strings.Contains(ct, "multipart/form-data") {
|
||||||
|
t.Errorf("content-type = %q", ct)
|
||||||
|
}
|
||||||
|
// Verify the audio file is present.
|
||||||
|
file, _, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, _ := io.ReadAll(file)
|
||||||
|
if len(data) != 100 {
|
||||||
|
t.Errorf("file size = %d, want 100", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewSTTClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.Transcribe(context.Background(), make([]byte, 100), "en")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result.Text != "hello world" {
|
||||||
|
t.Errorf("text = %q", result.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSTTClient_TranscribeTranslate(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/audio/translations" {
|
||||||
|
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewSTTClient(ts.URL, 5*time.Second)
|
||||||
|
c.Task = "translate"
|
||||||
|
result, err := c.Transcribe(context.Background(), []byte{0x01}, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result.Text != "translated" {
|
||||||
|
t.Errorf("text = %q", result.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// HTTP error handling
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestHTTPError4xx(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(422)
|
||||||
|
w.Write([]byte(`{"error": "bad input"}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
_, err := c.Embed(context.Background(), []string{"test"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 422")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "422") {
|
||||||
|
t.Errorf("error should contain status code: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPError5xx(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
w.Write([]byte("internal server error"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
_, err := c.Generate(context.Background(), "q", "", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 500")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// buildMessages helper
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestBuildMessages(t *testing.T) {
|
||||||
|
// No context, no system prompt → just user message
|
||||||
|
msgs := buildMessages("hello", "", "")
|
||||||
|
if len(msgs) != 1 || msgs[0].Role != "user" {
|
||||||
|
t.Errorf("expected 1 user msg, got %+v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With system prompt
|
||||||
|
msgs = buildMessages("hello", "", "You are helpful")
|
||||||
|
if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" {
|
||||||
|
t.Errorf("expected system+user, got %+v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With context, no system prompt → auto system prompt
|
||||||
|
msgs = buildMessages("question", "some context", "")
|
||||||
|
if len(msgs) != 2 || msgs[0].Role != "system" {
|
||||||
|
t.Errorf("expected auto system+user, got %+v", msgs)
|
||||||
|
}
|
||||||
|
if !strings.Contains(msgs[1].Content, "Context:") {
|
||||||
|
t.Error("user message should contain context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Benchmarks: pooled buffer vs direct allocation
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkPostJSON(b *testing.B) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
io.Copy(io.Discard, r.Body)
|
||||||
|
w.Write([]byte(`{"ok":true}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := newHTTPClient(ts.URL, 10*time.Second)
|
||||||
|
ctx := context.Background()
|
||||||
|
payload := map[string]any{
|
||||||
|
"text": strings.Repeat("x", 1024),
|
||||||
|
"count": 42,
|
||||||
|
"enabled": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
c.postJSON(ctx, "/test", payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBufferPool(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
buf := getBuf()
|
||||||
|
buf.WriteString(strings.Repeat("x", 4096))
|
||||||
|
putBuf(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBufferPoolParallel(b *testing.B) {
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
buf := getBuf()
|
||||||
|
buf.WriteString(strings.Repeat("x", 4096))
|
||||||
|
putBuf(buf)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
169
config/config.go
169
config/config.go
@@ -1,66 +1,80 @@
|
|||||||
// Package config provides environment-based configuration for handler services.
|
// Package config provides environment-based configuration for handler services
|
||||||
|
// with optional live reload of secrets and service endpoints.
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Settings holds base configuration for all handler services.
|
// Settings holds base configuration for all handler services.
|
||||||
// Values are loaded from environment variables with sensible defaults.
|
// Fields in the "hot-reload" section are protected by a RWMutex and can be
|
||||||
|
// updated at runtime via WatchSecrets(). All other fields are immutable
|
||||||
|
// after Load() returns.
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
// Service identification
|
// Service identification (immutable)
|
||||||
ServiceName string
|
ServiceName string
|
||||||
ServiceVersion string
|
ServiceVersion string
|
||||||
ServiceNamespace string
|
ServiceNamespace string
|
||||||
DeploymentEnv string
|
DeploymentEnv string
|
||||||
|
|
||||||
// NATS configuration
|
// NATS configuration (immutable)
|
||||||
NATSURL string
|
NATSURL string
|
||||||
NATSUser string
|
NATSUser string
|
||||||
NATSPassword string
|
NATSPassword string
|
||||||
NATSQueueGroup string
|
NATSQueueGroup string
|
||||||
|
|
||||||
// Redis/Valkey configuration
|
// Redis/Valkey configuration (immutable)
|
||||||
RedisURL string
|
RedisURL string
|
||||||
RedisPassword string
|
RedisPassword string
|
||||||
|
|
||||||
// Milvus configuration
|
// Milvus configuration (immutable)
|
||||||
MilvusHost string
|
MilvusHost string
|
||||||
MilvusPort int
|
MilvusPort int
|
||||||
MilvusCollection string
|
MilvusCollection string
|
||||||
|
|
||||||
// Service endpoints
|
// OpenTelemetry configuration (immutable)
|
||||||
EmbeddingsURL string
|
|
||||||
RerankerURL string
|
|
||||||
LLMURL string
|
|
||||||
TTSURL string
|
|
||||||
STTURL string
|
|
||||||
|
|
||||||
// OpenTelemetry configuration
|
|
||||||
OTELEnabled bool
|
OTELEnabled bool
|
||||||
OTELEndpoint string
|
OTELEndpoint string
|
||||||
OTELUseHTTP bool
|
OTELUseHTTP bool
|
||||||
|
|
||||||
// HyperDX configuration
|
// HyperDX configuration (immutable)
|
||||||
HyperDXEnabled bool
|
HyperDXEnabled bool
|
||||||
HyperDXAPIKey string
|
HyperDXAPIKey string
|
||||||
HyperDXEndpoint string
|
HyperDXEndpoint string
|
||||||
|
|
||||||
// MLflow configuration
|
// MLflow configuration (immutable)
|
||||||
MLflowTrackingURI string
|
MLflowTrackingURI string
|
||||||
MLflowExperimentName string
|
MLflowExperimentName string
|
||||||
MLflowEnabled bool
|
MLflowEnabled bool
|
||||||
|
|
||||||
// Health check configuration
|
// Health check configuration (immutable)
|
||||||
HealthPort int
|
HealthPort int
|
||||||
HealthPath string
|
HealthPath string
|
||||||
ReadyPath string
|
ReadyPath string
|
||||||
|
|
||||||
// Timeouts
|
// Timeouts (immutable)
|
||||||
HTTPTimeout time.Duration
|
HTTPTimeout time.Duration
|
||||||
NATSTimeout time.Duration
|
NATSTimeout time.Duration
|
||||||
|
|
||||||
|
// Hot-reloadable fields — access via getter methods.
|
||||||
|
mu sync.RWMutex
|
||||||
|
embeddingsURL string
|
||||||
|
rerankerURL string
|
||||||
|
llmURL string
|
||||||
|
ttsURL string
|
||||||
|
sttURL string
|
||||||
|
|
||||||
|
// Secrets path for file-based hot reload (Kubernetes secret mounts)
|
||||||
|
SecretsPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load creates a Settings populated from environment variables with defaults.
|
// Load creates a Settings populated from environment variables with defaults.
|
||||||
@@ -83,11 +97,11 @@ func Load() *Settings {
|
|||||||
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
||||||
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
||||||
|
|
||||||
EmbeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
||||||
RerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
||||||
LLMURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
||||||
TTSURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
||||||
STTURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
|
||||||
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
||||||
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
||||||
@@ -107,6 +121,115 @@ func Load() *Settings {
|
|||||||
|
|
||||||
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
||||||
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
||||||
|
|
||||||
|
SecretsPath: getEnv("SECRETS_PATH", ""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
|
||||||
|
func (s *Settings) EmbeddingsURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.embeddingsURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// RerankerURL returns the current reranker service URL (thread-safe).
|
||||||
|
func (s *Settings) RerankerURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.rerankerURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// LLMURL returns the current LLM service URL (thread-safe).
|
||||||
|
func (s *Settings) LLMURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.llmURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSURL returns the current TTS service URL (thread-safe).
|
||||||
|
func (s *Settings) TTSURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.ttsURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTURL returns the current STT service URL (thread-safe).
|
||||||
|
func (s *Settings) STTURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.sttURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchSecrets watches the SecretsPath directory for changes and reloads
|
||||||
|
// hot-reloadable fields. Blocks until ctx is cancelled.
|
||||||
|
func (s *Settings) WatchSecrets(ctx context.Context) {
|
||||||
|
if s.SecretsPath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("config: failed to create fsnotify watcher", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = watcher.Close() }()
|
||||||
|
|
||||||
|
if err := watcher.Add(s.SecretsPath); err != nil {
|
||||||
|
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
}
|
||||||
|
case err, ok := <-watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Error("config: fsnotify error", "error", err)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
|
||||||
|
func (s *Settings) reloadFromSecrets() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
updated := 0
|
||||||
|
reload := func(filename string, target *string) {
|
||||||
|
path := filepath.Join(s.SecretsPath, filename)
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val := strings.TrimSpace(string(data))
|
||||||
|
if val != "" && val != *target {
|
||||||
|
*target = val
|
||||||
|
updated++
|
||||||
|
slog.Info("config: reloaded secret", "key", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reload("embeddings-url", &s.embeddingsURL)
|
||||||
|
reload("reranker-url", &s.rerankerURL)
|
||||||
|
reload("llm-url", &s.llmURL)
|
||||||
|
reload("tts-url", &s.ttsURL)
|
||||||
|
reload("stt-url", &s.sttURL)
|
||||||
|
|
||||||
|
if updated > 0 {
|
||||||
|
slog.Info("config: secrets reloaded", "updated", updated)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -20,14 +21,9 @@ func TestLoadDefaults(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadFromEnv(t *testing.T) {
|
func TestLoadFromEnv(t *testing.T) {
|
||||||
os.Setenv("SERVICE_NAME", "test-svc")
|
t.Setenv("SERVICE_NAME", "test-svc")
|
||||||
os.Setenv("HEALTH_PORT", "9090")
|
t.Setenv("HEALTH_PORT", "9090")
|
||||||
os.Setenv("OTEL_ENABLED", "false")
|
t.Setenv("OTEL_ENABLED", "false")
|
||||||
defer func() {
|
|
||||||
os.Unsetenv("SERVICE_NAME")
|
|
||||||
os.Unsetenv("HEALTH_PORT")
|
|
||||||
os.Unsetenv("OTEL_ENABLED")
|
|
||||||
}()
|
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.ServiceName != "test-svc" {
|
if s.ServiceName != "test-svc" {
|
||||||
@@ -40,3 +36,88 @@ func TestLoadFromEnv(t *testing.T) {
|
|||||||
t.Error("expected OTELEnabled false")
|
t.Error("expected OTELEnabled false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestURLGetters(t *testing.T) {
|
||||||
|
s := Load()
|
||||||
|
if s.EmbeddingsURL() == "" {
|
||||||
|
t.Error("EmbeddingsURL should have a default")
|
||||||
|
}
|
||||||
|
if s.RerankerURL() == "" {
|
||||||
|
t.Error("RerankerURL should have a default")
|
||||||
|
}
|
||||||
|
if s.LLMURL() == "" {
|
||||||
|
t.Error("LLMURL should have a default")
|
||||||
|
}
|
||||||
|
if s.TTSURL() == "" {
|
||||||
|
t.Error("TTSURL should have a default")
|
||||||
|
}
|
||||||
|
if s.STTURL() == "" {
|
||||||
|
t.Error("STTURL should have a default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURLGettersFromEnv(t *testing.T) {
|
||||||
|
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
|
||||||
|
t.Setenv("LLM_URL", "http://llm:9000")
|
||||||
|
|
||||||
|
s := Load()
|
||||||
|
if s.EmbeddingsURL() != "http://embed:8000" {
|
||||||
|
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
|
}
|
||||||
|
if s.LLMURL() != "http://llm:9000" {
|
||||||
|
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadFromSecrets(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Write initial secret files
|
||||||
|
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
|
||||||
|
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
|
||||||
|
|
||||||
|
s := Load()
|
||||||
|
s.SecretsPath = dir
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
|
if s.EmbeddingsURL() != "http://old-embed:8000" {
|
||||||
|
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
|
}
|
||||||
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
|
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate secret update
|
||||||
|
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
|
if s.EmbeddingsURL() != "http://new-embed:8000" {
|
||||||
|
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
|
}
|
||||||
|
// LLM should remain unchanged
|
||||||
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
|
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadFromSecretsNoPath(t *testing.T) {
|
||||||
|
s := Load()
|
||||||
|
s.SecretsPath = ""
|
||||||
|
// Should not panic
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEnvDuration(t *testing.T) {
|
||||||
|
t.Setenv("TEST_DUR", "30")
|
||||||
|
d := getEnvDuration("TEST_DUR", 10*time.Second)
|
||||||
|
if d != 30*time.Second {
|
||||||
|
t.Errorf("expected 30s, got %v", d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSecret(t *testing.T, dir, name, value string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -17,6 +17,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
|
|||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
|||||||
201
handler/handler_test.go
Normal file
201
handler/handler_test.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Handler construction tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestNewHandler(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
cfg.ServiceName = "test-handler"
|
||||||
|
cfg.NATSQueueGroup = "test-group"
|
||||||
|
|
||||||
|
h := New("ai.test.subject", cfg)
|
||||||
|
if h.Subject != "ai.test.subject" {
|
||||||
|
t.Errorf("Subject = %q", h.Subject)
|
||||||
|
}
|
||||||
|
if h.QueueGroup != "test-group" {
|
||||||
|
t.Errorf("QueueGroup = %q", h.QueueGroup)
|
||||||
|
}
|
||||||
|
if h.Settings.ServiceName != "test-handler" {
|
||||||
|
t.Errorf("ServiceName = %q", h.Settings.ServiceName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHandlerNilSettings(t *testing.T) {
|
||||||
|
h := New("ai.test", nil)
|
||||||
|
if h.Settings == nil {
|
||||||
|
t.Fatal("Settings should be loaded automatically")
|
||||||
|
}
|
||||||
|
if h.Settings.ServiceName != "handler" {
|
||||||
|
t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackRegistration(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
setupCalled := false
|
||||||
|
h.OnSetup(func(ctx context.Context) error {
|
||||||
|
setupCalled = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
teardownCalled := false
|
||||||
|
h.OnTeardown(func(ctx context.Context) error {
|
||||||
|
teardownCalled = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil {
|
||||||
|
t.Error("callbacks should not be nil after registration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify setup/teardown work when called directly.
|
||||||
|
h.onSetup(context.Background())
|
||||||
|
h.onTeardown(context.Background())
|
||||||
|
if !setupCalled || !teardownCalled {
|
||||||
|
t.Error("callbacks should have been invoked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// wrapHandler dispatch tests (unit test the message decode + dispatch logic)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWrapHandler_ValidMessage(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
var receivedData map[string]any
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
receivedData = data
|
||||||
|
return map[string]any{"status": "ok"}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Encode a message the same way services would.
|
||||||
|
payload := map[string]any{
|
||||||
|
"request_id": "test-001",
|
||||||
|
"message": "hello",
|
||||||
|
"premium": true,
|
||||||
|
}
|
||||||
|
encoded, err := msgpack.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call wrapHandler directly without NATS.
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test.user.42.message",
|
||||||
|
Data: encoded,
|
||||||
|
})
|
||||||
|
|
||||||
|
if receivedData == nil {
|
||||||
|
t.Fatal("handler was not called")
|
||||||
|
}
|
||||||
|
if receivedData["request_id"] != "test-001" {
|
||||||
|
t.Errorf("request_id = %v", receivedData["request_id"])
|
||||||
|
}
|
||||||
|
if receivedData["premium"] != true {
|
||||||
|
t.Errorf("premium = %v", receivedData["premium"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapHandler_InvalidMsgpack(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
handlerCalled = true
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test",
|
||||||
|
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack
|
||||||
|
})
|
||||||
|
|
||||||
|
if handlerCalled {
|
||||||
|
t.Error("handler should not be called for invalid msgpack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapHandler_HandlerError(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
|
// Should not panic even when handler returns error.
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test",
|
||||||
|
Data: encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapHandler_NilResponse(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return nil, nil // fire-and-forget style
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
|
// Should not panic with nil response and no reply subject.
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test",
|
||||||
|
Data: encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Benchmark: message decode + dispatch overhead
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkWrapHandler(b *testing.B) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"request_id": "bench-001",
|
||||||
|
"message": "What is the capital of France?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
}
|
||||||
|
encoded, _ := msgpack.Marshal(payload)
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
handler(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
515
messages/bench_test.go
Normal file
515
messages/bench_test.go
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
// Package messages benchmarks compare three serialization strategies:
|
||||||
|
//
|
||||||
|
// 1. msgpack map[string]any — the old approach (dynamic, no types)
|
||||||
|
// 2. msgpack typed struct — the new approach (compile-time safe, short keys)
|
||||||
|
// 3. protobuf — optional future migration
|
||||||
|
//
|
||||||
|
// Run with:
|
||||||
|
//
|
||||||
|
// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
|
||||||
|
// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Test fixtures — equivalent data across all three encodings
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// chatRequestMap is the legacy map[string]any representation.
|
||||||
|
func chatRequestMap() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"request_id": "req-abc-123",
|
||||||
|
"user_id": "user-42",
|
||||||
|
"message": "What is the capital of France?",
|
||||||
|
"query": "",
|
||||||
|
"premium": true,
|
||||||
|
"enable_rag": true,
|
||||||
|
"enable_reranker": true,
|
||||||
|
"enable_streaming": false,
|
||||||
|
"top_k": 10,
|
||||||
|
"collection": "documents",
|
||||||
|
"enable_tts": false,
|
||||||
|
"system_prompt": "You are a helpful assistant.",
|
||||||
|
"response_subject": "ai.chat.response.req-abc-123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatRequestStruct is the typed struct representation.
|
||||||
|
func chatRequestStruct() ChatRequest {
|
||||||
|
return ChatRequest{
|
||||||
|
RequestID: "req-abc-123",
|
||||||
|
UserID: "user-42",
|
||||||
|
Message: "What is the capital of France?",
|
||||||
|
Premium: true,
|
||||||
|
EnableRAG: true,
|
||||||
|
EnableReranker: true,
|
||||||
|
TopK: 10,
|
||||||
|
Collection: "documents",
|
||||||
|
SystemPrompt: "You are a helpful assistant.",
|
||||||
|
ResponseSubject: "ai.chat.response.req-abc-123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatRequestProto is the protobuf representation.
|
||||||
|
func chatRequestProto() *pb.ChatRequest {
|
||||||
|
return &pb.ChatRequest{
|
||||||
|
RequestId: "req-abc-123",
|
||||||
|
UserId: "user-42",
|
||||||
|
Message: "What is the capital of France?",
|
||||||
|
Premium: true,
|
||||||
|
EnableRag: true,
|
||||||
|
EnableReranker: true,
|
||||||
|
TopK: 10,
|
||||||
|
Collection: "documents",
|
||||||
|
SystemPrompt: "You are a helpful assistant.",
|
||||||
|
ResponseSubject: "ai.chat.response.req-abc-123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// voiceResponseMap is a voice response with a 16 KB audio payload.
|
||||||
|
func voiceResponseMap() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"request_id": "vr-001",
|
||||||
|
"response": "The capital of France is Paris.",
|
||||||
|
"audio": make([]byte, 16384),
|
||||||
|
"transcription": "What is the capital of France?",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func voiceResponseStruct() VoiceResponse {
|
||||||
|
return VoiceResponse{
|
||||||
|
RequestID: "vr-001",
|
||||||
|
Response: "The capital of France is Paris.",
|
||||||
|
Audio: make([]byte, 16384),
|
||||||
|
Transcription: "What is the capital of France?",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func voiceResponseProto() *pb.VoiceResponse {
|
||||||
|
return &pb.VoiceResponse{
|
||||||
|
RequestId: "vr-001",
|
||||||
|
Response: "The capital of France is Paris.",
|
||||||
|
Audio: make([]byte, 16384),
|
||||||
|
Transcription: "What is the capital of France?",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ttsChunkMap simulates a streaming audio chunk (~32 KB).
|
||||||
|
func ttsChunkMap() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"session_id": "tts-sess-99",
|
||||||
|
"chunk_index": 3,
|
||||||
|
"total_chunks": 12,
|
||||||
|
"audio_b64": string(make([]byte, 32768)), // old: base64 string
|
||||||
|
"is_last": false,
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
"sample_rate": 24000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ttsChunkStruct() TTSAudioChunk {
|
||||||
|
return TTSAudioChunk{
|
||||||
|
SessionID: "tts-sess-99",
|
||||||
|
ChunkIndex: 3,
|
||||||
|
TotalChunks: 12,
|
||||||
|
Audio: make([]byte, 32768), // new: raw bytes
|
||||||
|
IsLast: false,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ttsChunkProto() *pb.TTSAudioChunk {
|
||||||
|
return &pb.TTSAudioChunk{
|
||||||
|
SessionId: "tts-sess-99",
|
||||||
|
ChunkIndex: 3,
|
||||||
|
TotalChunks: 12,
|
||||||
|
Audio: make([]byte, 32768),
|
||||||
|
IsLast: false,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Wire-size comparison (run once, printed by TestWireSize)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWireSize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mapData any
|
||||||
|
structVal any
|
||||||
|
protoMsg proto.Message
|
||||||
|
}{
|
||||||
|
{"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()},
|
||||||
|
{"VoiceResponse", voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()},
|
||||||
|
{"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
mapBytes, _ := msgpack.Marshal(tt.mapData)
|
||||||
|
structBytes, _ := msgpack.Marshal(tt.structVal)
|
||||||
|
protoBytes, _ := proto.Marshal(tt.protoMsg)
|
||||||
|
|
||||||
|
t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)",
|
||||||
|
tt.name,
|
||||||
|
len(mapBytes), len(structBytes), len(protoBytes),
|
||||||
|
100*(1-float64(len(structBytes))/float64(len(mapBytes))),
|
||||||
|
100*(1-float64(len(protoBytes))/float64(len(mapBytes))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Encode benchmarks
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) {
|
||||||
|
data := chatRequestMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) {
|
||||||
|
data := chatRequestStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
|
||||||
|
data := chatRequestProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
proto.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) {
|
||||||
|
data := voiceResponseMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) {
|
||||||
|
data := voiceResponseStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
|
||||||
|
data := voiceResponseProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
proto.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) {
|
||||||
|
data := ttsChunkMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) {
|
||||||
|
data := ttsChunkStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
|
||||||
|
data := ttsChunkProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
proto.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Decode benchmarks
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(chatRequestMap())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(chatRequestStruct())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m ChatRequest
|
||||||
|
msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
|
||||||
|
encoded, _ := proto.Marshal(chatRequestProto())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m pb.ChatRequest
|
||||||
|
proto.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(voiceResponseMap())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(voiceResponseStruct())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m VoiceResponse
|
||||||
|
msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
|
||||||
|
encoded, _ := proto.Marshal(voiceResponseProto())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m pb.VoiceResponse
|
||||||
|
proto.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(ttsChunkMap())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(ttsChunkStruct())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m TTSAudioChunk
|
||||||
|
msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
|
||||||
|
encoded, _ := proto.Marshal(ttsChunkProto())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m pb.TTSAudioChunk
|
||||||
|
proto.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Roundtrip benchmarks (encode + decode)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) {
|
||||||
|
data := chatRequestMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
enc, _ := msgpack.Marshal(data)
|
||||||
|
var dec map[string]any
|
||||||
|
msgpack.Unmarshal(enc, &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) {
|
||||||
|
data := chatRequestStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
enc, _ := msgpack.Marshal(data)
|
||||||
|
var dec ChatRequest
|
||||||
|
msgpack.Unmarshal(enc, &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
|
||||||
|
data := chatRequestProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
enc, _ := proto.Marshal(data)
|
||||||
|
var dec pb.ChatRequest
|
||||||
|
proto.Unmarshal(enc, &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Typed struct unit tests — verify roundtrip correctness
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestRoundtrip_ChatRequest(t *testing.T) {
|
||||||
|
orig := chatRequestStruct()
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec ChatRequest
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.RequestID != orig.RequestID {
|
||||||
|
t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID)
|
||||||
|
}
|
||||||
|
if dec.Message != orig.Message {
|
||||||
|
t.Errorf("Message = %q, want %q", dec.Message, orig.Message)
|
||||||
|
}
|
||||||
|
if dec.TopK != orig.TopK {
|
||||||
|
t.Errorf("TopK = %d, want %d", dec.TopK, orig.TopK)
|
||||||
|
}
|
||||||
|
if dec.Premium != orig.Premium {
|
||||||
|
t.Errorf("Premium = %v, want %v", dec.Premium, orig.Premium)
|
||||||
|
}
|
||||||
|
if dec.EffectiveQuery() != orig.Message {
|
||||||
|
t.Errorf("EffectiveQuery() = %q, want %q", dec.EffectiveQuery(), orig.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_VoiceResponse(t *testing.T) {
|
||||||
|
orig := voiceResponseStruct()
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec VoiceResponse
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.RequestID != orig.RequestID {
|
||||||
|
t.Errorf("RequestID mismatch")
|
||||||
|
}
|
||||||
|
if len(dec.Audio) != len(orig.Audio) {
|
||||||
|
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
|
||||||
|
}
|
||||||
|
if dec.Transcription != orig.Transcription {
|
||||||
|
t.Errorf("Transcription mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
|
||||||
|
orig := ttsChunkStruct()
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec TTSAudioChunk
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.SessionID != orig.SessionID {
|
||||||
|
t.Errorf("SessionID mismatch")
|
||||||
|
}
|
||||||
|
if dec.ChunkIndex != orig.ChunkIndex {
|
||||||
|
t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex)
|
||||||
|
}
|
||||||
|
if len(dec.Audio) != len(orig.Audio) {
|
||||||
|
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
|
||||||
|
}
|
||||||
|
if dec.SampleRate != orig.SampleRate {
|
||||||
|
t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_PipelineTrigger(t *testing.T) {
|
||||||
|
orig := PipelineTrigger{
|
||||||
|
RequestID: "pip-001",
|
||||||
|
Pipeline: "document-ingestion",
|
||||||
|
Parameters: map[string]any{"source": "s3://bucket/data"},
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec PipelineTrigger
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.Pipeline != orig.Pipeline {
|
||||||
|
t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline)
|
||||||
|
}
|
||||||
|
if dec.Parameters["source"] != orig.Parameters["source"] {
|
||||||
|
t.Errorf("Parameters[source] mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_STTTranscription(t *testing.T) {
|
||||||
|
orig := STTTranscription{
|
||||||
|
SessionID: "stt-001",
|
||||||
|
Transcript: "hello world",
|
||||||
|
Sequence: 5,
|
||||||
|
IsPartial: false,
|
||||||
|
IsFinal: true,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
SpeakerID: "speaker-1",
|
||||||
|
HasVoiceActivity: true,
|
||||||
|
State: "listening",
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec STTTranscription
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.Transcript != orig.Transcript {
|
||||||
|
t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript)
|
||||||
|
}
|
||||||
|
if dec.IsFinal != orig.IsFinal {
|
||||||
|
t.Error("IsFinal mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_ErrorResponse(t *testing.T) {
|
||||||
|
orig := ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec ErrorResponse
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !dec.Error || dec.Message != "something broke" || dec.Type != "InternalError" {
|
||||||
|
t.Errorf("ErrorResponse roundtrip mismatch: %+v", dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimestamp(t *testing.T) {
|
||||||
|
ts := Timestamp()
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if ts < now-1 || ts > now+1 {
|
||||||
|
t.Errorf("Timestamp() = %d, expected ~%d", ts, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
224
messages/messages.go
Normal file
224
messages/messages.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
// Package messages defines typed NATS message structs for all services.
|
||||||
|
//
|
||||||
|
// Using typed structs with short msgpack field tags instead of map[string]any
|
||||||
|
// provides compile-time safety, smaller wire size (integer-like short keys vs
|
||||||
|
// full string keys), and faster encode/decode by avoiding interface{} boxing.
|
||||||
|
//
|
||||||
|
// Audio data uses raw []byte instead of base64-encoded strings — msgpack
|
||||||
|
// supports binary natively, eliminating the 33% base64 overhead.
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Pipeline Bridge
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// PipelineTrigger is the request to start a pipeline.
|
||||||
|
type PipelineTrigger struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Pipeline string `msgpack:"pipeline" json:"pipeline"`
|
||||||
|
Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipelineStatus is the response / status update for a pipeline run.
|
||||||
|
type PipelineStatus struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Status string `msgpack:"status" json:"status"`
|
||||||
|
RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"`
|
||||||
|
Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"`
|
||||||
|
Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"`
|
||||||
|
SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"`
|
||||||
|
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
||||||
|
AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Chat Handler
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// ChatRequest is an incoming chat message.
|
||||||
|
type ChatRequest struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
UserID string `msgpack:"user_id" json:"user_id"`
|
||||||
|
Message string `msgpack:"message" json:"message"`
|
||||||
|
Query string `msgpack:"query,omitempty" json:"query,omitempty"`
|
||||||
|
Premium bool `msgpack:"premium,omitempty" json:"premium,omitempty"`
|
||||||
|
EnableRAG bool `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"`
|
||||||
|
EnableReranker bool `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"`
|
||||||
|
EnableStreaming bool `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"`
|
||||||
|
TopK int `msgpack:"top_k,omitempty" json:"top_k,omitempty"`
|
||||||
|
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
|
||||||
|
EnableTTS bool `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"`
|
||||||
|
SystemPrompt string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
|
||||||
|
ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveQuery returns Message or falls back to Query.
|
||||||
|
func (c *ChatRequest) EffectiveQuery() string {
|
||||||
|
if c.Message != "" {
|
||||||
|
return c.Message
|
||||||
|
}
|
||||||
|
return c.Query
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatResponse is the full reply to a chat request.
|
||||||
|
type ChatResponse struct {
|
||||||
|
UserID string `msgpack:"user_id" json:"user_id"`
|
||||||
|
Response string `msgpack:"response" json:"response"`
|
||||||
|
ResponseText string `msgpack:"response_text" json:"response_text"`
|
||||||
|
UsedRAG bool `msgpack:"used_rag" json:"used_rag"`
|
||||||
|
RAGSources []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
|
||||||
|
Success bool `msgpack:"success" json:"success"`
|
||||||
|
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
|
||||||
|
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStreamChunk is a single streaming chunk from an LLM response.
|
||||||
|
type ChatStreamChunk struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
Content string `msgpack:"content" json:"content"`
|
||||||
|
Done bool `msgpack:"done" json:"done"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Voice Assistant
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// VoiceRequest is an incoming voice-to-voice request.
|
||||||
|
type VoiceRequest struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
|
||||||
|
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoiceResponse is the reply to a voice request.
|
||||||
|
type VoiceResponse struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Response string `msgpack:"response" json:"response"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
Transcription string `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
|
||||||
|
Sources []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
|
||||||
|
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocumentSource is a RAG search result source.
|
||||||
|
type DocumentSource struct {
|
||||||
|
Text string `msgpack:"text" json:"text"`
|
||||||
|
Score float64 `msgpack:"score" json:"score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// TTS Module
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// TTSRequest is a text-to-speech synthesis request.
|
||||||
|
type TTSRequest struct {
|
||||||
|
Text string `msgpack:"text" json:"text"`
|
||||||
|
Speaker string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
|
||||||
|
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
|
||||||
|
SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
|
||||||
|
Stream bool `msgpack:"stream,omitempty" json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
|
||||||
|
type TTSAudioChunk struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
ChunkIndex int `msgpack:"chunk_index" json:"chunk_index"`
|
||||||
|
TotalChunks int `msgpack:"total_chunks" json:"total_chunks"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
IsLast bool `msgpack:"is_last" json:"is_last"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSFullResponse is a non-streamed TTS response (whole audio).
|
||||||
|
type TTSFullResponse struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSStatus is a TTS processing status update.
|
||||||
|
type TTSStatus struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Status string `msgpack:"status" json:"status"`
|
||||||
|
Message string `msgpack:"message" json:"message"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceListResponse is the reply to a voice list request.
|
||||||
|
type TTSVoiceListResponse struct {
|
||||||
|
DefaultSpeaker string `msgpack:"default_speaker" json:"default_speaker"`
|
||||||
|
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
|
||||||
|
LastRefresh int64 `msgpack:"last_refresh" json:"last_refresh"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceInfo is summary info about a custom voice.
|
||||||
|
type TTSVoiceInfo struct {
|
||||||
|
Name string `msgpack:"name" json:"name"`
|
||||||
|
Language string `msgpack:"language" json:"language"`
|
||||||
|
ModelType string `msgpack:"model_type" json:"model_type"`
|
||||||
|
CreatedAt string `msgpack:"created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
|
||||||
|
type TTSVoiceRefreshResponse struct {
|
||||||
|
Count int `msgpack:"count" json:"count"`
|
||||||
|
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// STT Module
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
|
||||||
|
type STTStreamMessage struct {
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
|
||||||
|
State string `msgpack:"state,omitempty" json:"state,omitempty"`
|
||||||
|
SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTTranscription is the transcription result published by the STT module.
|
||||||
|
type STTTranscription struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Transcript string `msgpack:"transcript" json:"transcript"`
|
||||||
|
Sequence int `msgpack:"sequence" json:"sequence"`
|
||||||
|
IsPartial bool `msgpack:"is_partial" json:"is_partial"`
|
||||||
|
IsFinal bool `msgpack:"is_final" json:"is_final"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
|
||||||
|
HasVoiceActivity bool `msgpack:"has_voice_activity" json:"has_voice_activity"`
|
||||||
|
State string `msgpack:"state" json:"state"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTInterrupt is published when the STT module detects a user interrupt.
|
||||||
|
type STTInterrupt struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Common / Error
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// ErrorResponse is the standard error reply from any handler.
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error bool `msgpack:"error" json:"error"`
|
||||||
|
Message string `msgpack:"message" json:"message"`
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timestamp returns the current Unix timestamp (helper for message construction).
|
||||||
|
func Timestamp() int64 {
|
||||||
|
return time.Now().Unix()
|
||||||
|
}
|
||||||
1738
messages/proto/messages.pb.go
Normal file
1738
messages/proto/messages.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
174
messages/proto/messages.proto
Normal file
174
messages/proto/messages.proto
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package messages;
|
||||||
|
|
||||||
|
option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto";
|
||||||
|
|
||||||
|
// ── Pipeline Bridge ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message PipelineTrigger {
|
||||||
|
string request_id = 1;
|
||||||
|
string pipeline = 2;
|
||||||
|
map<string, string> parameters = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PipelineStatus {
|
||||||
|
string request_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
string run_id = 3;
|
||||||
|
string engine = 4;
|
||||||
|
string pipeline = 5;
|
||||||
|
string submitted_at = 6;
|
||||||
|
string error = 7;
|
||||||
|
repeated string available_pipelines = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Chat Handler ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message ChatRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
string message = 3;
|
||||||
|
string query = 4;
|
||||||
|
bool premium = 5;
|
||||||
|
bool enable_rag = 6;
|
||||||
|
bool enable_reranker = 7;
|
||||||
|
bool enable_streaming = 8;
|
||||||
|
int32 top_k = 9;
|
||||||
|
string collection = 10;
|
||||||
|
bool enable_tts = 11;
|
||||||
|
string system_prompt = 12;
|
||||||
|
string response_subject = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChatResponse {
|
||||||
|
string user_id = 1;
|
||||||
|
string response = 2;
|
||||||
|
string response_text = 3;
|
||||||
|
bool used_rag = 4;
|
||||||
|
repeated string rag_sources = 5;
|
||||||
|
bool success = 6;
|
||||||
|
bytes audio = 7;
|
||||||
|
string error = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChatStreamChunk {
|
||||||
|
string request_id = 1;
|
||||||
|
string type = 2;
|
||||||
|
string content = 3;
|
||||||
|
bool done = 4;
|
||||||
|
int64 timestamp = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Voice Assistant ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message VoiceRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
string language = 3;
|
||||||
|
string collection = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message VoiceResponse {
|
||||||
|
string request_id = 1;
|
||||||
|
string response = 2;
|
||||||
|
bytes audio = 3;
|
||||||
|
string transcription = 4;
|
||||||
|
repeated DocumentSource sources = 5;
|
||||||
|
string error = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DocumentSource {
|
||||||
|
string text = 1;
|
||||||
|
double score = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── TTS Module ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message TTSRequest {
|
||||||
|
string text = 1;
|
||||||
|
string speaker = 2;
|
||||||
|
string language = 3;
|
||||||
|
string speaker_wav_b64 = 4;
|
||||||
|
bool stream = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSAudioChunk {
|
||||||
|
string session_id = 1;
|
||||||
|
int32 chunk_index = 2;
|
||||||
|
int32 total_chunks = 3;
|
||||||
|
bytes audio = 4;
|
||||||
|
bool is_last = 5;
|
||||||
|
int64 timestamp = 6;
|
||||||
|
int32 sample_rate = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSFullResponse {
|
||||||
|
string session_id = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
int32 sample_rate = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSStatus {
|
||||||
|
string session_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
string message = 3;
|
||||||
|
int64 timestamp = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSVoiceInfo {
|
||||||
|
string name = 1;
|
||||||
|
string language = 2;
|
||||||
|
string model_type = 3;
|
||||||
|
string created_at = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSVoiceListResponse {
|
||||||
|
string default_speaker = 1;
|
||||||
|
repeated TTSVoiceInfo custom_voices = 2;
|
||||||
|
int64 last_refresh = 3;
|
||||||
|
int64 timestamp = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSVoiceRefreshResponse {
|
||||||
|
int32 count = 1;
|
||||||
|
repeated TTSVoiceInfo custom_voices = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── STT Module ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message STTStreamMessage {
|
||||||
|
string type = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
string state = 3;
|
||||||
|
string speaker_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message STTTranscription {
|
||||||
|
string session_id = 1;
|
||||||
|
string transcript = 2;
|
||||||
|
int32 sequence = 3;
|
||||||
|
bool is_partial = 4;
|
||||||
|
bool is_final = 5;
|
||||||
|
int64 timestamp = 6;
|
||||||
|
string speaker_id = 7;
|
||||||
|
bool has_voice_activity = 8;
|
||||||
|
string state = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message STTInterrupt {
|
||||||
|
string session_id = 1;
|
||||||
|
string type = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
string speaker_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Common ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message ErrorResponse {
|
||||||
|
bool error = 1;
|
||||||
|
string message = 2;
|
||||||
|
string type = 3;
|
||||||
|
}
|
||||||
256
natsutil/natsutil_test.go
Normal file
256
natsutil/natsutil_test.go
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
package natsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// DecodeMsgpackMap tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestDecodeMsgpackMap_Roundtrip(t *testing.T) {
|
||||||
|
orig := map[string]any{
|
||||||
|
"request_id": "req-001",
|
||||||
|
"user_id": "user-42",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": int64(10), // msgpack decodes ints as int64
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := DecodeMsgpackMap(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded["request_id"] != "req-001" {
|
||||||
|
t.Errorf("request_id = %v", decoded["request_id"])
|
||||||
|
}
|
||||||
|
if decoded["premium"] != true {
|
||||||
|
t.Errorf("premium = %v", decoded["premium"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeMsgpackMap_Empty(t *testing.T) {
|
||||||
|
data, _ := msgpack.Marshal(map[string]any{})
|
||||||
|
m, err := DecodeMsgpackMap(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(m) != 0 {
|
||||||
|
t.Errorf("expected empty map, got %v", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeMsgpackMap_InvalidData(t *testing.T) {
|
||||||
|
_, err := DecodeMsgpackMap([]byte{0xFF, 0xFE})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid msgpack data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// DecodeMsgpack (typed struct) tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type testMessage struct {
|
||||||
|
RequestID string `msgpack:"request_id"`
|
||||||
|
UserID string `msgpack:"user_id"`
|
||||||
|
Count int `msgpack:"count"`
|
||||||
|
Active bool `msgpack:"active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
|
||||||
|
orig := testMessage{
|
||||||
|
RequestID: "req-typed-001",
|
||||||
|
UserID: "user-7",
|
||||||
|
Count: 42,
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate nats.Msg data decoding.
|
||||||
|
var decoded testMessage
|
||||||
|
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded.RequestID != orig.RequestID {
|
||||||
|
t.Errorf("RequestID = %q, want %q", decoded.RequestID, orig.RequestID)
|
||||||
|
}
|
||||||
|
if decoded.Count != orig.Count {
|
||||||
|
t.Errorf("Count = %d, want %d", decoded.Count, orig.Count)
|
||||||
|
}
|
||||||
|
if decoded.Active != orig.Active {
|
||||||
|
t.Errorf("Active = %v, want %v", decoded.Active, orig.Active)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTypedStructDecodesMapEncoding verifies that a typed struct can be
|
||||||
|
// decoded from data that was encoded as map[string]any (backwards compat).
|
||||||
|
func TestTypedStructDecodesMapEncoding(t *testing.T) {
|
||||||
|
// Encode as map (the old way).
|
||||||
|
mapData := map[string]any{
|
||||||
|
"request_id": "req-compat",
|
||||||
|
"user_id": "user-compat",
|
||||||
|
"count": int64(99),
|
||||||
|
"active": false,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(mapData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode into typed struct (the new way).
|
||||||
|
var msg testMessage
|
||||||
|
if err := msgpack.Unmarshal(data, &msg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.RequestID != "req-compat" {
|
||||||
|
t.Errorf("RequestID = %q", msg.RequestID)
|
||||||
|
}
|
||||||
|
if msg.Count != 99 {
|
||||||
|
t.Errorf("Count = %d, want 99", msg.Count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Binary data tests (audio []byte in msgpack)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type audioMessage struct {
|
||||||
|
SessionID string `msgpack:"session_id"`
|
||||||
|
Audio []byte `msgpack:"audio"`
|
||||||
|
SampleRate int `msgpack:"sample_rate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBinaryDataRoundtrip(t *testing.T) {
|
||||||
|
audio := make([]byte, 32768)
|
||||||
|
for i := range audio {
|
||||||
|
audio[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
orig := audioMessage{
|
||||||
|
SessionID: "sess-audio-001",
|
||||||
|
Audio: audio,
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded audioMessage
|
||||||
|
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decoded.Audio) != len(orig.Audio) {
|
||||||
|
t.Fatalf("audio len = %d, want %d", len(decoded.Audio), len(orig.Audio))
|
||||||
|
}
|
||||||
|
for i := range decoded.Audio {
|
||||||
|
if decoded.Audio[i] != orig.Audio[i] {
|
||||||
|
t.Fatalf("audio[%d] = %d, want %d", i, decoded.Audio[i], orig.Audio[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string.
|
||||||
|
func TestBinaryVsBase64Size(t *testing.T) {
|
||||||
|
audio := make([]byte, 16384)
|
||||||
|
|
||||||
|
// Old approach: base64 string in map.
|
||||||
|
import_b64 := make([]byte, (len(audio)*4+2)/3) // approximate base64 size
|
||||||
|
mapMsg := map[string]any{
|
||||||
|
"session_id": "sess-1",
|
||||||
|
"audio_b64": string(import_b64),
|
||||||
|
}
|
||||||
|
mapData, _ := msgpack.Marshal(mapMsg)
|
||||||
|
|
||||||
|
// New approach: raw bytes in struct.
|
||||||
|
structMsg := audioMessage{
|
||||||
|
SessionID: "sess-1",
|
||||||
|
Audio: audio,
|
||||||
|
}
|
||||||
|
structData, _ := msgpack.Marshal(structMsg)
|
||||||
|
|
||||||
|
t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)",
|
||||||
|
len(mapData), len(structData),
|
||||||
|
100*(1-float64(len(structData))/float64(len(mapData))))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Benchmarks
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkEncodeMap(b *testing.B) {
|
||||||
|
data := map[string]any{
|
||||||
|
"request_id": "req-bench",
|
||||||
|
"user_id": "user-bench",
|
||||||
|
"message": "What is the weather today?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
}
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncodeStruct(b *testing.B) {
|
||||||
|
data := testMessage{
|
||||||
|
RequestID: "req-bench",
|
||||||
|
UserID: "user-bench",
|
||||||
|
Count: 10,
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
for b.Loop() {
|
||||||
|
msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeMap(b *testing.B) {
|
||||||
|
raw, _ := msgpack.Marshal(map[string]any{
|
||||||
|
"request_id": "req-bench",
|
||||||
|
"user_id": "user-bench",
|
||||||
|
"message": "What is the weather today?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
})
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
msgpack.Unmarshal(raw, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeStruct(b *testing.B) {
|
||||||
|
raw, _ := msgpack.Marshal(testMessage{
|
||||||
|
RequestID: "req-bench",
|
||||||
|
UserID: "user-bench",
|
||||||
|
Count: 10,
|
||||||
|
Active: true,
|
||||||
|
})
|
||||||
|
for b.Loop() {
|
||||||
|
var m testMessage
|
||||||
|
msgpack.Unmarshal(raw, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeAudio32KB(b *testing.B) {
|
||||||
|
raw, _ := msgpack.Marshal(audioMessage{
|
||||||
|
SessionID: "s1",
|
||||||
|
Audio: make([]byte, 32768),
|
||||||
|
SampleRate: 24000,
|
||||||
|
})
|
||||||
|
for b.Loop() {
|
||||||
|
var m audioMessage
|
||||||
|
msgpack.Unmarshal(raw, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user