diff --git a/clients/clients.go b/clients/clients.go index e3b7afa..50a589c 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -1,389 +1,437 @@ // Package clients provides HTTP client wrappers for AI/ML backend services. +// +// All clients share a single [http.Transport] for connection pooling across +// the process. Request and response bodies are serialized through pooled +// [bytes.Buffer]s to reduce GC pressure. package clients import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "mime/multipart" - "net/http" - "net/url" - "time" +"bytes" +"context" +"encoding/json" +"fmt" +"io" +"mime/multipart" +"net/http" +"net/url" +"sync" +"time" ) -// httpClient is a shared interface for all service clients. +// ─── Shared transport & buffer pool ───────────────────────────────────────── + +// SharedTransport is the process-wide HTTP transport used by every service +// client. Tweak pool sizes here rather than creating per-client transports. +var SharedTransport = &http.Transport{ +MaxIdleConns: 100, +MaxIdleConnsPerHost: 10, +IdleConnTimeout: 90 * time.Second, +DisableCompression: true, // in-cluster traffic; skip gzip overhead +} + +// bufPool recycles *bytes.Buffer to avoid per-request allocations. +var bufPool = sync.Pool{ +New: func() any { return new(bytes.Buffer) }, +} + +func getBuf() *bytes.Buffer { +buf := bufPool.Get().(*bytes.Buffer) +buf.Reset() +return buf +} + +func putBuf(buf *bytes.Buffer) { +if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB +return +} +bufPool.Put(buf) +} + +// ─── httpClient base ──────────────────────────────────────────────────────── + +// httpClient is the shared base for all service clients. 
type httpClient struct { - client *http.Client - baseURL string +client *http.Client +baseURL string } func newHTTPClient(baseURL string, timeout time.Duration) *httpClient { - return &httpClient{ - client: &http.Client{Timeout: timeout}, - baseURL: baseURL, - } +return &httpClient{ +client: &http.Client{ +Timeout: timeout, +Transport: SharedTransport, +}, +baseURL: baseURL, +} } func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) { - data, err := json.Marshal(body) - if err != nil { - return nil, fmt.Errorf("marshal: %w", err) - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(data)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - return h.do(req) +buf := getBuf() +defer putBuf(buf) +if err := json.NewEncoder(buf).Encode(body); err != nil { +return nil, fmt.Errorf("marshal: %w", err) +} +req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf) +if err != nil { +return nil, err +} +req.Header.Set("Content-Type", "application/json") +return h.do(req) } func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) { - u := h.baseURL + path - if len(params) > 0 { - u += "?" + params.Encode() - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) - if err != nil { - return nil, err - } - return h.do(req) +u := h.baseURL + path +if len(params) > 0 { +u += "?" 
+ params.Encode() +} +req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) +if err != nil { +return nil, err +} +return h.do(req) } func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) { - return h.get(ctx, path, params) +return h.get(ctx, path, params) } func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) { - var buf bytes.Buffer - w := multipart.NewWriter(&buf) - part, err := w.CreateFormFile(fieldName, fileName) - if err != nil { - return nil, err - } - if _, err := part.Write(fileData); err != nil { - return nil, err - } - for k, v := range fields { - _ = w.WriteField(k, v) - } - _ = w.Close() +buf := getBuf() +defer putBuf(buf) +w := multipart.NewWriter(buf) +part, err := w.CreateFormFile(fieldName, fileName) +if err != nil { +return nil, err +} +if _, err := part.Write(fileData); err != nil { +return nil, err +} +for k, v := range fields { +_ = w.WriteField(k, v) +} +_ = w.Close() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, &buf) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", w.FormDataContentType()) - return h.do(req) +req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf) +if err != nil { +return nil, err +} +req.Header.Set("Content-Type", w.FormDataContentType()) +return h.do(req) } func (h *httpClient) do(req *http.Request) ([]byte, error) { - resp, err := h.client.Do(req) - if err != nil { - return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err) - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read body: %w", err) - } - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body)) - } - return body, nil +resp, err := h.client.Do(req) +if err != nil { +return nil, fmt.Errorf("http %s %s: 
%w", req.Method, req.URL.Path, err) +} +defer resp.Body.Close() + +buf := getBuf() +defer putBuf(buf) +if _, err := io.Copy(buf, resp.Body); err != nil { +return nil, fmt.Errorf("read body: %w", err) +} + +// Return a copy so the pooled buffer can be safely recycled. +body := make([]byte, buf.Len()) +copy(body, buf.Bytes()) + +if resp.StatusCode >= 400 { +return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body)) +} +return body, nil } func (h *httpClient) healthCheck(ctx context.Context) bool { - data, err := h.get(ctx, "/health", nil) - _ = data - return err == nil +data, err := h.get(ctx, "/health", nil) +_ = data +return err == nil } -// --- Embeddings Client --- +// ─── Embeddings Client ────────────────────────────────────────────────────── // EmbeddingsClient calls the embeddings service (Infinity/BGE). type EmbeddingsClient struct { - *httpClient - Model string +*httpClient +Model string } // NewEmbeddingsClient creates an embeddings client. func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient { - if model == "" { - model = "bge" - } - return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model} +if model == "" { +model = "bge" +} +return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model} } // Embed generates embeddings for a list of texts. 
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) { - body, err := c.postJSON(ctx, "/embeddings", map[string]any{ - "input": texts, - "model": c.Model, - }) - if err != nil { - return nil, err - } - var resp struct { - Data []struct { - Embedding []float64 `json:"embedding"` - } `json:"data"` - } - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - result := make([][]float64, len(resp.Data)) - for i, d := range resp.Data { - result[i] = d.Embedding - } - return result, nil +body, err := c.postJSON(ctx, "/embeddings", map[string]any{ +"input": texts, +"model": c.Model, +}) +if err != nil { +return nil, err +} +var resp struct { +Data []struct { +Embedding []float64 `json:"embedding"` +} `json:"data"` +} +if err := json.Unmarshal(body, &resp); err != nil { +return nil, err +} +result := make([][]float64, len(resp.Data)) +for i, d := range resp.Data { +result[i] = d.Embedding +} +return result, nil } // EmbedSingle generates an embedding for a single text. func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) { - results, err := c.Embed(ctx, []string{text}) - if err != nil { - return nil, err - } - if len(results) == 0 { - return nil, fmt.Errorf("empty embedding result") - } - return results[0], nil +results, err := c.Embed(ctx, []string{text}) +if err != nil { +return nil, err +} +if len(results) == 0 { +return nil, fmt.Errorf("empty embedding result") +} +return results[0], nil } // Health checks if the embeddings service is healthy. func (c *EmbeddingsClient) Health(ctx context.Context) bool { - return c.healthCheck(ctx) +return c.healthCheck(ctx) } -// --- Reranker Client --- +// ─── Reranker Client ──────────────────────────────────────────────────────── // RerankerClient calls the reranker service (BGE Reranker). type RerankerClient struct { - *httpClient +*httpClient } // NewRerankerClient creates a reranker client. 
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient { - return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)} +return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)} } // RerankResult represents a reranked document. type RerankResult struct { - Index int `json:"index"` - Score float64 `json:"score"` - Document string `json:"document"` +Index int `json:"index"` +Score float64 `json:"score"` +Document string `json:"document"` } // Rerank reranks documents by relevance to the query. func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) { - payload := map[string]any{ - "query": query, - "documents": documents, - } - if topK > 0 { - payload["top_n"] = topK - } - body, err := c.postJSON(ctx, "/rerank", payload) - if err != nil { - return nil, err - } - var resp struct { - Results []struct { - Index int `json:"index"` - RelevanceScore float64 `json:"relevance_score"` - Score float64 `json:"score"` - } `json:"results"` - } - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - results := make([]RerankResult, len(resp.Results)) - for i, r := range resp.Results { - score := r.RelevanceScore - if score == 0 { - score = r.Score - } - doc := "" - if r.Index < len(documents) { - doc = documents[r.Index] - } - results[i] = RerankResult{Index: r.Index, Score: score, Document: doc} - } - return results, nil +payload := map[string]any{ +"query": query, +"documents": documents, +} +if topK > 0 { +payload["top_n"] = topK +} +body, err := c.postJSON(ctx, "/rerank", payload) +if err != nil { +return nil, err +} +var resp struct { +Results []struct { +Index int `json:"index"` +RelevanceScore float64 `json:"relevance_score"` +Score float64 `json:"score"` +} `json:"results"` +} +if err := json.Unmarshal(body, &resp); err != nil { +return nil, err +} +results := make([]RerankResult, len(resp.Results)) +for i, r := range resp.Results { +score := 
r.RelevanceScore +if score == 0 { +score = r.Score +} +doc := "" +if r.Index < len(documents) { +doc = documents[r.Index] +} +results[i] = RerankResult{Index: r.Index, Score: score, Document: doc} +} +return results, nil } -// --- LLM Client --- +// ─── LLM Client ───────────────────────────────────────────────────────────── // LLMClient calls the vLLM-compatible LLM service. type LLMClient struct { - *httpClient - Model string - MaxTokens int - Temperature float64 - TopP float64 +*httpClient +Model string +MaxTokens int +Temperature float64 +TopP float64 } // NewLLMClient creates an LLM client. func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient { - return &LLMClient{ - httpClient: newHTTPClient(baseURL, timeout), - Model: "default", - MaxTokens: 2048, - Temperature: 0.7, - TopP: 0.9, - } +return &LLMClient{ +httpClient: newHTTPClient(baseURL, timeout), +Model: "default", +MaxTokens: 2048, +Temperature: 0.7, +TopP: 0.9, +} } // ChatMessage is an OpenAI-compatible message. type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` +Role string `json:"role"` +Content string `json:"content"` } // Generate sends a chat completion request and returns the response text. 
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) { - messages := buildMessages(prompt, context_, systemPrompt) - payload := map[string]any{ - "model": c.Model, - "messages": messages, - "max_tokens": c.MaxTokens, - "temperature": c.Temperature, - "top_p": c.TopP, - } - body, err := c.postJSON(ctx, "/v1/chat/completions", payload) - if err != nil { - return "", err - } - var resp struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - } - if err := json.Unmarshal(body, &resp); err != nil { - return "", err - } - if len(resp.Choices) == 0 { - return "", fmt.Errorf("no choices in LLM response") - } - return resp.Choices[0].Message.Content, nil +messages := buildMessages(prompt, context_, systemPrompt) +payload := map[string]any{ +"model": c.Model, +"messages": messages, +"max_tokens": c.MaxTokens, +"temperature": c.Temperature, +"top_p": c.TopP, +} +body, err := c.postJSON(ctx, "/v1/chat/completions", payload) +if err != nil { +return "", err +} +var resp struct { +Choices []struct { +Message struct { +Content string `json:"content"` +} `json:"message"` +} `json:"choices"` +} +if err := json.Unmarshal(body, &resp); err != nil { +return "", err +} +if len(resp.Choices) == 0 { +return "", fmt.Errorf("no choices in LLM response") +} +return resp.Choices[0].Message.Content, nil } func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage { - var msgs []ChatMessage - if systemPrompt != "" { - msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt}) - } else if ctx != "" { - msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. 
If the context doesn't contain relevant information, say so."}) - } - if ctx != "" { - msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)}) - } else { - msgs = append(msgs, ChatMessage{Role: "user", Content: prompt}) - } - return msgs +var msgs []ChatMessage +if systemPrompt != "" { +msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt}) +} else if ctx != "" { +msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."}) +} +if ctx != "" { +msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)}) +} else { +msgs = append(msgs, ChatMessage{Role: "user", Content: prompt}) +} +return msgs } -// --- TTS Client --- +// ─── TTS Client ───────────────────────────────────────────────────────────── // TTSClient calls the TTS service (Coqui XTTS). type TTSClient struct { - *httpClient - Language string +*httpClient +Language string } // NewTTSClient creates a TTS client. func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient { - if language == "" { - language = "en" - } - return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language} +if language == "" { +language = "en" +} +return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language} } // Synthesize generates audio bytes from text. 
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) { - if language == "" { - language = c.Language - } - params := url.Values{ - "text": {text}, - "language_id": {language}, - } - if speaker != "" { - params.Set("speaker_id", speaker) - } - return c.getRaw(ctx, "/api/tts", params) +if language == "" { +language = c.Language +} +params := url.Values{ +"text": {text}, +"language_id": {language}, +} +if speaker != "" { +params.Set("speaker_id", speaker) +} +return c.getRaw(ctx, "/api/tts", params) } -// --- STT Client --- +// ─── STT Client ───────────────────────────────────────────────────────────── // STTClient calls the Whisper STT service. type STTClient struct { - *httpClient - Language string - Task string +*httpClient +Language string +Task string } // NewSTTClient creates an STT client. func NewSTTClient(baseURL string, timeout time.Duration) *STTClient { - return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"} +return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"} } // TranscribeResult holds transcription output. type TranscribeResult struct { - Text string `json:"text"` - Language string `json:"language,omitempty"` +Text string `json:"text"` +Language string `json:"language,omitempty"` } // Transcribe sends audio to Whisper and returns the transcription. 
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) { - if language == "" { - language = c.Language - } - fields := map[string]string{ - "response_format": "json", - } - if language != "" { - fields["language"] = language - } - endpoint := "/v1/audio/transcriptions" - if c.Task == "translate" { - endpoint = "/v1/audio/translations" - } - body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields) - if err != nil { - return nil, err - } - var result TranscribeResult - if err := json.Unmarshal(body, &result); err != nil { - return nil, err - } - return &result, nil +if language == "" { +language = c.Language +} +fields := map[string]string{ +"response_format": "json", +} +if language != "" { +fields["language"] = language +} +endpoint := "/v1/audio/transcriptions" +if c.Task == "translate" { +endpoint = "/v1/audio/translations" +} +body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields) +if err != nil { +return nil, err +} +var result TranscribeResult +if err := json.Unmarshal(body, &result); err != nil { +return nil, err +} +return &result, nil } -// --- Milvus Client --- +// ─── Milvus Client ────────────────────────────────────────────────────────── // MilvusClient provides vector search via the Milvus HTTP/gRPC API. // For the Go port we use the Milvus Go SDK. type MilvusClient struct { - Host string - Port int - Collection string - connected bool +Host string +Port int +Collection string +connected bool } // NewMilvusClient creates a Milvus client. func NewMilvusClient(host string, port int, collection string) *MilvusClient { - return &MilvusClient{Host: host, Port: port, Collection: collection} +return &MilvusClient{Host: host, Port: port, Collection: collection} } // SearchResult holds a single vector search hit. 
type SearchResult struct { - ID int64 `json:"id"` - Distance float64 `json:"distance"` - Score float64 `json:"score"` - Fields map[string]any `json:"fields,omitempty"` +ID int64 `json:"id"` +Distance float64 `json:"distance"` +Score float64 `json:"score"` +Fields map[string]any `json:"fields,omitempty"` } diff --git a/clients/clients_test.go b/clients/clients_test.go new file mode 100644 index 0000000..bb87404 --- /dev/null +++ b/clients/clients_test.go @@ -0,0 +1,506 @@ +package clients + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +// ──────────────────────────────────────────────────────────────────────────── +// Shared infrastructure tests +// ──────────────────────────────────────────────────────────────────────────── + +func TestSharedTransport(t *testing.T) { + // All clients created via newHTTPClient should share the same transport. + c1 := newHTTPClient("http://a:8000", 10*time.Second) + c2 := newHTTPClient("http://b:9000", 30*time.Second) + + if c1.client.Transport != c2.client.Transport { + t.Error("clients should share the same http.Transport") + } + if c1.client.Transport != SharedTransport { + t.Error("transport should be the package-level SharedTransport") + } +} + +func TestBufferPoolGetPut(t *testing.T) { + buf := getBuf() + if buf == nil { + t.Fatal("getBuf returned nil") + } + if buf.Len() != 0 { + t.Error("getBuf should return a reset buffer") + } + buf.WriteString("hello") + putBuf(buf) + + // On re-get, buffer should be reset. + buf2 := getBuf() + if buf2.Len() != 0 { + t.Error("re-acquired buffer should be reset") + } + putBuf(buf2) +} + +func TestBufferPoolOversizedDiscarded(t *testing.T) { + buf := getBuf() + // Grow beyond 1 MB threshold. + buf.Write(make([]byte, 2<<20)) + putBuf(buf) // should silently discard + + // Pool should still work — we get a fresh one. 
+ buf2 := getBuf() + if buf2.Len() != 0 { + t.Error("should get a fresh buffer") + } + putBuf(buf2) +} + +func TestBufferPoolConcurrency(t *testing.T) { + var wg sync.WaitGroup + for i := range 100 { + wg.Add(1) + go func(n int) { + defer wg.Done() + buf := getBuf() + buf.WriteString(strings.Repeat("x", n)) + putBuf(buf) + }(i) + } + wg.Wait() +} + +// ──────────────────────────────────────────────────────────────────────────── +// Embeddings client +// ──────────────────────────────────────────────────────────────────────────── + +func TestEmbeddingsClient_Embed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/embeddings" { + t.Errorf("path = %q, want /embeddings", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("method = %s, want POST", r.Method) + } + var req map[string]any + json.NewDecoder(r.Body).Decode(&req) + input, _ := req["input"].([]any) + if len(input) != 2 { + t.Errorf("input len = %d, want 2", len(input)) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"embedding": []float64{0.1, 0.2, 0.3}}, + {"embedding": []float64{0.4, 0.5, 0.6}}, + }, + }) + })) + defer ts.Close() + + c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge") + results, err := c.Embed(context.Background(), []string{"hello", "world"}) + if err != nil { + t.Fatal(err) + } + if len(results) != 2 { + t.Fatalf("len(results) = %d, want 2", len(results)) + } + if results[0][0] != 0.1 { + t.Errorf("results[0][0] = %f, want 0.1", results[0][0]) + } +} + +func TestEmbeddingsClient_EmbedSingle(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{ + {"embedding": []float64{1.0, 2.0}}, + }, + }) + })) + defer ts.Close() + + c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") + vec, err := 
c.EmbedSingle(context.Background(), "test") + if err != nil { + t.Fatal(err) + } + if len(vec) != 2 || vec[0] != 1.0 { + t.Errorf("vec = %v", vec) + } +} + +func TestEmbeddingsClient_EmbedEmpty(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{"data": []any{}}) + })) + defer ts.Close() + + c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") + _, err := c.EmbedSingle(context.Background(), "test") + if err == nil { + t.Error("expected error for empty embedding") + } +} + +func TestEmbeddingsClient_Health(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/health" { + w.WriteHeader(200) + return + } + w.WriteHeader(404) + })) + defer ts.Close() + + c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") + if !c.Health(context.Background()) { + t.Error("expected healthy") + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Reranker client +// ──────────────────────────────────────────────────────────────────────────── + +func TestRerankerClient_Rerank(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + json.NewDecoder(r.Body).Decode(&req) + if req["query"] != "test query" { + t.Errorf("query = %v", req["query"]) + } + json.NewEncoder(w).Encode(map[string]any{ + "results": []map[string]any{ + {"index": 1, "relevance_score": 0.95}, + {"index": 0, "relevance_score": 0.80}, + }, + }) + })) + defer ts.Close() + + c := NewRerankerClient(ts.URL, 5*time.Second) + docs := []string{"Paris is great", "France is in Europe"} + results, err := c.Rerank(context.Background(), "test query", docs, 2) + if err != nil { + t.Fatal(err) + } + if len(results) != 2 { + t.Fatalf("len = %d", len(results)) + } + if results[0].Score != 0.95 { + t.Errorf("score = %f, want 0.95", results[0].Score) + } + if 
results[0].Document != "France is in Europe" { + t.Errorf("document = %q", results[0].Document) + } +} + +func TestRerankerClient_RerankFallbackScore(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "results": []map[string]any{ + {"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score + }, + }) + })) + defer ts.Close() + + c := NewRerankerClient(ts.URL, 5*time.Second) + results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0) + if err != nil { + t.Fatal(err) + } + if results[0].Score != 0.77 { + t.Errorf("fallback score = %f, want 0.77", results[0].Score) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// LLM client +// ──────────────────────────────────────────────────────────────────────────── + +func TestLLMClient_Generate(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + t.Errorf("path = %q", r.URL.Path) + } + var req map[string]any + json.NewDecoder(r.Body).Decode(&req) + msgs, _ := req["messages"].([]any) + if len(msgs) == 0 { + t.Error("no messages in request") + } + + json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": "Paris is the capital of France."}}, + }, + }) + })) + defer ts.Close() + + c := NewLLMClient(ts.URL, 5*time.Second) + result, err := c.Generate(context.Background(), "capital of France?", "", "") + if err != nil { + t.Fatal(err) + } + if result != "Paris is the capital of France." 
{ + t.Errorf("result = %q", result) + } +} + +func TestLLMClient_GenerateWithContext(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + json.NewDecoder(r.Body).Decode(&req) + msgs, _ := req["messages"].([]any) + // Should have system + user message + if len(msgs) != 2 { + t.Errorf("expected 2 messages, got %d", len(msgs)) + } + json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": "answer with context"}}, + }, + }) + })) + defer ts.Close() + + c := NewLLMClient(ts.URL, 5*time.Second) + result, err := c.Generate(context.Background(), "question", "some context", "") + if err != nil { + t.Fatal(err) + } + if result != "answer with context" { + t.Errorf("result = %q", result) + } +} + +func TestLLMClient_GenerateNoChoices(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{"choices": []any{}}) + })) + defer ts.Close() + + c := NewLLMClient(ts.URL, 5*time.Second) + _, err := c.Generate(context.Background(), "q", "", "") + if err == nil { + t.Error("expected error for empty choices") + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// TTS client +// ──────────────────────────────────────────────────────────────────────────── + +func TestTTSClient_Synthesize(t *testing.T) { + expected := []byte{0xDE, 0xAD, 0xBE, 0xEF} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/tts" { + t.Errorf("path = %q", r.URL.Path) + } + if r.URL.Query().Get("text") != "hello world" { + t.Errorf("text = %q", r.URL.Query().Get("text")) + } + w.Write(expected) + })) + defer ts.Close() + + c := NewTTSClient(ts.URL, 5*time.Second, "en") + audio, err := c.Synthesize(context.Background(), "hello world", "", "") + if err != nil { + t.Fatal(err) + } + if 
!bytes.Equal(audio, expected) { + t.Errorf("audio = %x, want %x", audio, expected) + } +} + +func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("speaker_id") != "alice" { + t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id")) + } + w.Write([]byte{0x01}) + })) + defer ts.Close() + + c := NewTTSClient(ts.URL, 5*time.Second, "en") + _, err := c.Synthesize(context.Background(), "hi", "en", "alice") + if err != nil { + t.Fatal(err) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// STT client +// ──────────────────────────────────────────────────────────────────────────── + +func TestSTTClient_Transcribe(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/audio/transcriptions" { + t.Errorf("path = %q", r.URL.Path) + } + ct := r.Header.Get("Content-Type") + if !strings.Contains(ct, "multipart/form-data") { + t.Errorf("content-type = %q", ct) + } + // Verify the audio file is present. 
+ file, _, err := r.FormFile("file") + if err != nil { + t.Fatal(err) + } + data, _ := io.ReadAll(file) + if len(data) != 100 { + t.Errorf("file size = %d, want 100", len(data)) + } + + json.NewEncoder(w).Encode(map[string]string{"text": "hello world"}) + })) + defer ts.Close() + + c := NewSTTClient(ts.URL, 5*time.Second) + result, err := c.Transcribe(context.Background(), make([]byte, 100), "en") + if err != nil { + t.Fatal(err) + } + if result.Text != "hello world" { + t.Errorf("text = %q", result.Text) + } +} + +func TestSTTClient_TranscribeTranslate(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/audio/translations" { + t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]string{"text": "translated"}) + })) + defer ts.Close() + + c := NewSTTClient(ts.URL, 5*time.Second) + c.Task = "translate" + result, err := c.Transcribe(context.Background(), []byte{0x01}, "") + if err != nil { + t.Fatal(err) + } + if result.Text != "translated" { + t.Errorf("text = %q", result.Text) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// HTTP error handling +// ──────────────────────────────────────────────────────────────────────────── + +func TestHTTPError4xx(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(422) + w.Write([]byte(`{"error": "bad input"}`)) + })) + defer ts.Close() + + c := NewEmbeddingsClient(ts.URL, 5*time.Second, "") + _, err := c.Embed(context.Background(), []string{"test"}) + if err == nil { + t.Fatal("expected error for 422") + } + if !strings.Contains(err.Error(), "422") { + t.Errorf("error should contain status code: %v", err) + } +} + +func TestHTTPError5xx(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + 
w.Write([]byte("internal server error")) + })) + defer ts.Close() + + c := NewLLMClient(ts.URL, 5*time.Second) + _, err := c.Generate(context.Background(), "q", "", "") + if err == nil { + t.Fatal("expected error for 500") + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// buildMessages helper +// ──────────────────────────────────────────────────────────────────────────── + +func TestBuildMessages(t *testing.T) { + // No context, no system prompt → just user message + msgs := buildMessages("hello", "", "") + if len(msgs) != 1 || msgs[0].Role != "user" { + t.Errorf("expected 1 user msg, got %+v", msgs) + } + + // With system prompt + msgs = buildMessages("hello", "", "You are helpful") + if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" { + t.Errorf("expected system+user, got %+v", msgs) + } + + // With context, no system prompt → auto system prompt + msgs = buildMessages("question", "some context", "") + if len(msgs) != 2 || msgs[0].Role != "system" { + t.Errorf("expected auto system+user, got %+v", msgs) + } + if !strings.Contains(msgs[1].Content, "Context:") { + t.Error("user message should contain context") + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Benchmarks: pooled buffer vs direct allocation +// ──────────────────────────────────────────────────────────────────────────── + +func BenchmarkPostJSON(b *testing.B) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(io.Discard, r.Body) + w.Write([]byte(`{"ok":true}`)) + })) + defer ts.Close() + + c := newHTTPClient(ts.URL, 10*time.Second) + ctx := context.Background() + payload := map[string]any{ + "text": strings.Repeat("x", 1024), + "count": 42, + "enabled": true, + } + + b.ResetTimer() + for b.Loop() { + c.postJSON(ctx, "/test", payload) + } +} + +func BenchmarkBufferPool(b *testing.B) { + b.ResetTimer() + for b.Loop() { + buf := 
getBuf() + buf.WriteString(strings.Repeat("x", 4096)) + putBuf(buf) + } +} + +func BenchmarkBufferPoolParallel(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + buf := getBuf() + buf.WriteString(strings.Repeat("x", 4096)) + putBuf(buf) + } + }) +} diff --git a/config/config.go b/config/config.go index 7888bad..bfb6ffa 100644 --- a/config/config.go +++ b/config/config.go @@ -1,145 +1,268 @@ -// Package config provides environment-based configuration for handler services. +// Package config provides environment-based configuration for handler services +// with optional live reload of secrets and service endpoints. package config import ( - "os" - "strconv" - "time" +"context" +"log/slog" +"os" +"path/filepath" +"strconv" +"strings" +"sync" +"time" + +"github.com/fsnotify/fsnotify" ) // Settings holds base configuration for all handler services. -// Values are loaded from environment variables with sensible defaults. +// Fields in the "hot-reload" section are protected by a RWMutex and can be +// updated at runtime via WatchSecrets(). All other fields are immutable +// after Load() returns. 
type Settings struct { - // Service identification - ServiceName string - ServiceVersion string - ServiceNamespace string - DeploymentEnv string +// Service identification (immutable) +ServiceName string +ServiceVersion string +ServiceNamespace string +DeploymentEnv string - // NATS configuration - NATSURL string - NATSUser string - NATSPassword string - NATSQueueGroup string +// NATS configuration (immutable) +NATSURL string +NATSUser string +NATSPassword string +NATSQueueGroup string - // Redis/Valkey configuration - RedisURL string - RedisPassword string +// Redis/Valkey configuration (immutable) +RedisURL string +RedisPassword string - // Milvus configuration - MilvusHost string - MilvusPort int - MilvusCollection string +// Milvus configuration (immutable) +MilvusHost string +MilvusPort int +MilvusCollection string - // Service endpoints - EmbeddingsURL string - RerankerURL string - LLMURL string - TTSURL string - STTURL string +// OpenTelemetry configuration (immutable) +OTELEnabled bool +OTELEndpoint string +OTELUseHTTP bool - // OpenTelemetry configuration - OTELEnabled bool - OTELEndpoint string - OTELUseHTTP bool +// HyperDX configuration (immutable) +HyperDXEnabled bool +HyperDXAPIKey string +HyperDXEndpoint string - // HyperDX configuration - HyperDXEnabled bool - HyperDXAPIKey string - HyperDXEndpoint string +// MLflow configuration (immutable) +MLflowTrackingURI string +MLflowExperimentName string +MLflowEnabled bool - // MLflow configuration - MLflowTrackingURI string - MLflowExperimentName string - MLflowEnabled bool +// Health check configuration (immutable) +HealthPort int +HealthPath string +ReadyPath string - // Health check configuration - HealthPort int - HealthPath string - ReadyPath string +// Timeouts (immutable) +HTTPTimeout time.Duration +NATSTimeout time.Duration - // Timeouts - HTTPTimeout time.Duration - NATSTimeout time.Duration +// Hot-reloadable fields — access via getter methods. 
+mu sync.RWMutex +embeddingsURL string +rerankerURL string +llmURL string +ttsURL string +sttURL string + +// Secrets path for file-based hot reload (Kubernetes secret mounts) +SecretsPath string } // Load creates a Settings populated from environment variables with defaults. func Load() *Settings { - return &Settings{ - ServiceName: getEnv("SERVICE_NAME", "handler"), - ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), - ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), - DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), +return &Settings{ +ServiceName: getEnv("SERVICE_NAME", "handler"), +ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), +ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), +DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), - NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), - NATSUser: getEnv("NATS_USER", ""), - NATSPassword: getEnv("NATS_PASSWORD", ""), - NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), +NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), +NATSUser: getEnv("NATS_USER", ""), +NATSPassword: getEnv("NATS_PASSWORD", ""), +NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), - RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), - RedisPassword: getEnv("REDIS_PASSWORD", ""), +RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), +RedisPassword: getEnv("REDIS_PASSWORD", ""), - MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), - MilvusPort: getEnvInt("MILVUS_PORT", 19530), - MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), +MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), +MilvusPort: getEnvInt("MILVUS_PORT", 19530), +MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), - EmbeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), - RerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), - LLMURL: 
getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), - TTSURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), - STTURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), +embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), +rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), +llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), +ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), +sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), - OTELEnabled: getEnvBool("OTEL_ENABLED", true), - OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), - OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), +OTELEnabled: getEnvBool("OTEL_ENABLED", true), +OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), +OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), - HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), - HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), - HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), +HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), +HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), +HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), - MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), - MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), - MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), +MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), +MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), +MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), - HealthPort: getEnvInt("HEALTH_PORT", 8080), - HealthPath: getEnv("HEALTH_PATH", "/health"), - ReadyPath: getEnv("READY_PATH", "/ready"), +HealthPort: 
getEnvInt("HEALTH_PORT", 8080), +HealthPath: getEnv("HEALTH_PATH", "/health"), +ReadyPath: getEnv("READY_PATH", "/ready"), - HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), - NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), - } +HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), +NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), + +SecretsPath: getEnv("SECRETS_PATH", ""), +} +} + +// EmbeddingsURL returns the current embeddings service URL (thread-safe). +func (s *Settings) EmbeddingsURL() string { +s.mu.RLock() +defer s.mu.RUnlock() +return s.embeddingsURL +} + +// RerankerURL returns the current reranker service URL (thread-safe). +func (s *Settings) RerankerURL() string { +s.mu.RLock() +defer s.mu.RUnlock() +return s.rerankerURL +} + +// LLMURL returns the current LLM service URL (thread-safe). +func (s *Settings) LLMURL() string { +s.mu.RLock() +defer s.mu.RUnlock() +return s.llmURL +} + +// TTSURL returns the current TTS service URL (thread-safe). +func (s *Settings) TTSURL() string { +s.mu.RLock() +defer s.mu.RUnlock() +return s.ttsURL +} + +// STTURL returns the current STT service URL (thread-safe). +func (s *Settings) STTURL() string { +s.mu.RLock() +defer s.mu.RUnlock() +return s.sttURL +} + +// WatchSecrets watches the SecretsPath directory for changes and reloads +// hot-reloadable fields. Blocks until ctx is cancelled. 
+func (s *Settings) WatchSecrets(ctx context.Context) { +if s.SecretsPath == "" { +return +} + +watcher, err := fsnotify.NewWatcher() +if err != nil { +slog.Error("config: failed to create fsnotify watcher", "error", err) +return +} +defer func() { _ = watcher.Close() }() + +if err := watcher.Add(s.SecretsPath); err != nil { +slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath) +return +} + +slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath) + +for { +select { +case event, ok := <-watcher.Events: +if !ok { +return +} +if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) { +s.reloadFromSecrets() +} +case err, ok := <-watcher.Errors: +if !ok { +return +} +slog.Error("config: fsnotify error", "error", err) +case <-ctx.Done(): +return +} +} +} + +// reloadFromSecrets reads hot-reloadable values from the secrets directory. +func (s *Settings) reloadFromSecrets() { +s.mu.Lock() +defer s.mu.Unlock() + +updated := 0 +reload := func(filename string, target *string) { +path := filepath.Join(s.SecretsPath, filename) +data, err := os.ReadFile(path) +if err != nil { +return +} +val := strings.TrimSpace(string(data)) +if val != "" && val != *target { +*target = val +updated++ +slog.Info("config: reloaded secret", "key", filename) +} +} + +reload("embeddings-url", &s.embeddingsURL) +reload("reranker-url", &s.rerankerURL) +reload("llm-url", &s.llmURL) +reload("tts-url", &s.ttsURL) +reload("stt-url", &s.sttURL) + +if updated > 0 { +slog.Info("config: secrets reloaded", "updated", updated) +} } func getEnv(key, fallback string) string { - if v := os.Getenv(key); v != "" { - return v - } - return fallback +if v := os.Getenv(key); v != "" { +return v +} +return fallback } func getEnvInt(key string, fallback int) int { - if v := os.Getenv(key); v != "" { - if i, err := strconv.Atoi(v); err == nil { - return i - } - } - return fallback +if v := os.Getenv(key); v != "" { +if i, err := strconv.Atoi(v); err == nil { +return 
i +} +} +return fallback } func getEnvBool(key string, fallback bool) bool { - if v := os.Getenv(key); v != "" { - if b, err := strconv.ParseBool(v); err == nil { - return b - } - } - return fallback +if v := os.Getenv(key); v != "" { +if b, err := strconv.ParseBool(v); err == nil { +return b +} +} +return fallback } func getEnvDuration(key string, fallback time.Duration) time.Duration { - if v := os.Getenv(key); v != "" { - if f, err := strconv.ParseFloat(v, 64); err == nil { - return time.Duration(f * float64(time.Second)) - } - } - return fallback +if v := os.Getenv(key); v != "" { +if f, err := strconv.ParseFloat(v, 64); err == nil { +return time.Duration(f * float64(time.Second)) +} +} +return fallback } diff --git a/config/config_test.go b/config/config_test.go index 8032f1b..fa450a0 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,42 +1,123 @@ package config import ( - "os" - "testing" - "time" +"os" +"path/filepath" +"testing" +"time" ) func TestLoadDefaults(t *testing.T) { - s := Load() - if s.ServiceName != "handler" { - t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) - } - if s.HealthPort != 8080 { - t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) - } - if s.HTTPTimeout != 60*time.Second { - t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) - } +s := Load() +if s.ServiceName != "handler" { +t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) +} +if s.HealthPort != 8080 { +t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) +} +if s.HTTPTimeout != 60*time.Second { +t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) +} } func TestLoadFromEnv(t *testing.T) { - os.Setenv("SERVICE_NAME", "test-svc") - os.Setenv("HEALTH_PORT", "9090") - os.Setenv("OTEL_ENABLED", "false") - defer func() { - os.Unsetenv("SERVICE_NAME") - os.Unsetenv("HEALTH_PORT") - os.Unsetenv("OTEL_ENABLED") - }() +t.Setenv("SERVICE_NAME", "test-svc") 
+t.Setenv("HEALTH_PORT", "9090") +t.Setenv("OTEL_ENABLED", "false") - s := Load() - if s.ServiceName != "test-svc" { - t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) - } - if s.HealthPort != 9090 { - t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) - } - if s.OTELEnabled { - t.Error("expected OTELEnabled false") - } +s := Load() +if s.ServiceName != "test-svc" { +t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) +} +if s.HealthPort != 9090 { +t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) +} +if s.OTELEnabled { +t.Error("expected OTELEnabled false") +} +} + +func TestURLGetters(t *testing.T) { +s := Load() +if s.EmbeddingsURL() == "" { +t.Error("EmbeddingsURL should have a default") +} +if s.RerankerURL() == "" { +t.Error("RerankerURL should have a default") +} +if s.LLMURL() == "" { +t.Error("LLMURL should have a default") +} +if s.TTSURL() == "" { +t.Error("TTSURL should have a default") +} +if s.STTURL() == "" { +t.Error("STTURL should have a default") +} +} + +func TestURLGettersFromEnv(t *testing.T) { +t.Setenv("EMBEDDINGS_URL", "http://embed:8000") +t.Setenv("LLM_URL", "http://llm:9000") + +s := Load() +if s.EmbeddingsURL() != "http://embed:8000" { +t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL()) +} +if s.LLMURL() != "http://llm:9000" { +t.Errorf("expected custom LLMURL, got %q", s.LLMURL()) +} +} + +func TestReloadFromSecrets(t *testing.T) { +dir := t.TempDir() + +// Write initial secret files +writeSecret(t, dir, "embeddings-url", "http://old-embed:8000") +writeSecret(t, dir, "llm-url", "http://old-llm:9000") + +s := Load() +s.SecretsPath = dir +s.reloadFromSecrets() + +if s.EmbeddingsURL() != "http://old-embed:8000" { +t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL()) +} +if s.LLMURL() != "http://old-llm:9000" { +t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL()) +} + +// Simulate secret update +writeSecret(t, dir, "embeddings-url", 
"http://new-embed:8000") +s.reloadFromSecrets() + +if s.EmbeddingsURL() != "http://new-embed:8000" { +t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL()) +} +// LLM should remain unchanged +if s.LLMURL() != "http://old-llm:9000" { +t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL()) +} +} + +func TestReloadFromSecretsNoPath(t *testing.T) { +s := Load() +s.SecretsPath = "" +// Should not panic +s.reloadFromSecrets() +} + +func TestGetEnvDuration(t *testing.T) { +t.Setenv("TEST_DUR", "30") +d := getEnvDuration("TEST_DUR", 10*time.Second) +if d != 30*time.Second { +t.Errorf("expected 30s, got %v", d) +} +} + +func writeSecret(t *testing.T, dir, name, value string) { +t.Helper() +if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil { +t.Fatal(err) +} } diff --git a/go.mod b/go.mod index e2442c9..f93c209 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( require ( github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum index 4a3f959..b9a1b68 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= 
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= diff --git a/handler/handler_test.go b/handler/handler_test.go new file mode 100644 index 0000000..d7b21ee --- /dev/null +++ b/handler/handler_test.go @@ -0,0 +1,201 @@ +package handler + +import ( + "context" + "testing" + + "github.com/nats-io/nats.go" + "github.com/vmihailenco/msgpack/v5" + + "git.daviestechlabs.io/daviestechlabs/handler-base/config" +) + +// ──────────────────────────────────────────────────────────────────────────── +// Handler construction tests +// ──────────────────────────────────────────────────────────────────────────── + +func TestNewHandler(t *testing.T) { + cfg := config.Load() + cfg.ServiceName = "test-handler" + cfg.NATSQueueGroup = "test-group" + + h := New("ai.test.subject", cfg) + if h.Subject != "ai.test.subject" { + t.Errorf("Subject = %q", h.Subject) + } + if h.QueueGroup != "test-group" { + t.Errorf("QueueGroup = %q", h.QueueGroup) + } + if h.Settings.ServiceName != "test-handler" { + t.Errorf("ServiceName = %q", h.Settings.ServiceName) + } +} + +func TestNewHandlerNilSettings(t *testing.T) { + h := New("ai.test", nil) + if h.Settings == nil { + t.Fatal("Settings should be loaded automatically") + } + if h.Settings.ServiceName != "handler" { + t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName) + } +} + +func TestCallbackRegistration(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + setupCalled := false + h.OnSetup(func(ctx context.Context) error { + setupCalled = true + return nil + }) + + teardownCalled := false + h.OnTeardown(func(ctx context.Context) error { + teardownCalled = true + return nil + }) + + h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { + return nil, nil + }) + + if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil { + t.Error("callbacks should not be nil after registration") + } + + // Verify setup/teardown work when called 
directly. + h.onSetup(context.Background()) + h.onTeardown(context.Background()) + if !setupCalled || !teardownCalled { + t.Error("callbacks should have been invoked") + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// wrapHandler dispatch tests (unit test the message decode + dispatch logic) +// ──────────────────────────────────────────────────────────────────────────── + +func TestWrapHandler_ValidMessage(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + var receivedData map[string]any + h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { + receivedData = data + return map[string]any{"status": "ok"}, nil + }) + + // Encode a message the same way services would. + payload := map[string]any{ + "request_id": "test-001", + "message": "hello", + "premium": true, + } + encoded, err := msgpack.Marshal(payload) + if err != nil { + t.Fatal(err) + } + + // Call wrapHandler directly without NATS. 
+ handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{ + Subject: "ai.test.user.42.message", + Data: encoded, + }) + + if receivedData == nil { + t.Fatal("handler was not called") + } + if receivedData["request_id"] != "test-001" { + t.Errorf("request_id = %v", receivedData["request_id"]) + } + if receivedData["premium"] != true { + t.Errorf("premium = %v", receivedData["premium"]) + } +} + +func TestWrapHandler_InvalidMsgpack(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + handlerCalled := false + h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { + handlerCalled = true + return nil, nil + }) + + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{ + Subject: "ai.test", + Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack + }) + + if handlerCalled { + t.Error("handler should not be called for invalid msgpack") + } +} + +func TestWrapHandler_HandlerError(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { + return nil, context.DeadlineExceeded + }) + + encoded, _ := msgpack.Marshal(map[string]any{"key": "val"}) + handler := h.wrapHandler(context.Background()) + + // Should not panic even when handler returns error. + handler(&nats.Msg{ + Subject: "ai.test", + Data: encoded, + }) +} + +func TestWrapHandler_NilResponse(t *testing.T) { + cfg := config.Load() + h := New("ai.test", cfg) + + h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { + return nil, nil // fire-and-forget style + }) + + encoded, _ := msgpack.Marshal(map[string]any{"x": 1}) + handler := h.wrapHandler(context.Background()) + + // Should not panic with nil response and no reply subject. 
+ handler(&nats.Msg{ + Subject: "ai.test", + Data: encoded, + }) +} + +// ──────────────────────────────────────────────────────────────────────────── +// Benchmark: message decode + dispatch overhead +// ──────────────────────────────────────────────────────────────────────────── + +func BenchmarkWrapHandler(b *testing.B) { + cfg := config.Load() + h := New("ai.test", cfg) + h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { + return map[string]any{"ok": true}, nil + }) + + payload := map[string]any{ + "request_id": "bench-001", + "message": "What is the capital of France?", + "premium": true, + "top_k": 10, + } + encoded, _ := msgpack.Marshal(payload) + handler := h.wrapHandler(context.Background()) + msg := &nats.Msg{Subject: "ai.test", Data: encoded} + + b.ResetTimer() + for b.Loop() { + handler(msg) + } +} diff --git a/messages/bench_test.go b/messages/bench_test.go new file mode 100644 index 0000000..7310641 --- /dev/null +++ b/messages/bench_test.go @@ -0,0 +1,515 @@ +// Package messages benchmarks compare three serialization strategies: +// +// 1. msgpack map[string]any — the old approach (dynamic, no types) +// 2. msgpack typed struct — the new approach (compile-time safe, short keys) +// 3. protobuf — optional future migration +// +// Run with: +// +// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt +// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt +package messages + +import ( + "testing" + "time" + + "github.com/vmihailenco/msgpack/v5" + "google.golang.org/protobuf/proto" + + pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto" +) + +// ──────────────────────────────────────────────────────────────────────────── +// Test fixtures — equivalent data across all three encodings +// ──────────────────────────────────────────────────────────────────────────── + +// chatRequestMap is the legacy map[string]any representation. 
+func chatRequestMap() map[string]any { + return map[string]any{ + "request_id": "req-abc-123", + "user_id": "user-42", + "message": "What is the capital of France?", + "query": "", + "premium": true, + "enable_rag": true, + "enable_reranker": true, + "enable_streaming": false, + "top_k": 10, + "collection": "documents", + "enable_tts": false, + "system_prompt": "You are a helpful assistant.", + "response_subject": "ai.chat.response.req-abc-123", + } +} + +// chatRequestStruct is the typed struct representation. +func chatRequestStruct() ChatRequest { + return ChatRequest{ + RequestID: "req-abc-123", + UserID: "user-42", + Message: "What is the capital of France?", + Premium: true, + EnableRAG: true, + EnableReranker: true, + TopK: 10, + Collection: "documents", + SystemPrompt: "You are a helpful assistant.", + ResponseSubject: "ai.chat.response.req-abc-123", + } +} + +// chatRequestProto is the protobuf representation. +func chatRequestProto() *pb.ChatRequest { + return &pb.ChatRequest{ + RequestId: "req-abc-123", + UserId: "user-42", + Message: "What is the capital of France?", + Premium: true, + EnableRag: true, + EnableReranker: true, + TopK: 10, + Collection: "documents", + SystemPrompt: "You are a helpful assistant.", + ResponseSubject: "ai.chat.response.req-abc-123", + } +} + +// voiceResponseMap is a voice response with a 16 KB audio payload. 
+func voiceResponseMap() map[string]any { + return map[string]any{ + "request_id": "vr-001", + "response": "The capital of France is Paris.", + "audio": make([]byte, 16384), + "transcription": "What is the capital of France?", + } +} + +func voiceResponseStruct() VoiceResponse { + return VoiceResponse{ + RequestID: "vr-001", + Response: "The capital of France is Paris.", + Audio: make([]byte, 16384), + Transcription: "What is the capital of France?", + } +} + +func voiceResponseProto() *pb.VoiceResponse { + return &pb.VoiceResponse{ + RequestId: "vr-001", + Response: "The capital of France is Paris.", + Audio: make([]byte, 16384), + Transcription: "What is the capital of France?", + } +} + +// ttsChunkMap simulates a streaming audio chunk (~32 KB). +func ttsChunkMap() map[string]any { + return map[string]any{ + "session_id": "tts-sess-99", + "chunk_index": 3, + "total_chunks": 12, + "audio_b64": string(make([]byte, 32768)), // old: base64 string + "is_last": false, + "timestamp": time.Now().Unix(), + "sample_rate": 24000, + } +} + +func ttsChunkStruct() TTSAudioChunk { + return TTSAudioChunk{ + SessionID: "tts-sess-99", + ChunkIndex: 3, + TotalChunks: 12, + Audio: make([]byte, 32768), // new: raw bytes + IsLast: false, + Timestamp: time.Now().Unix(), + SampleRate: 24000, + } +} + +func ttsChunkProto() *pb.TTSAudioChunk { + return &pb.TTSAudioChunk{ + SessionId: "tts-sess-99", + ChunkIndex: 3, + TotalChunks: 12, + Audio: make([]byte, 32768), + IsLast: false, + Timestamp: time.Now().Unix(), + SampleRate: 24000, + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Wire-size comparison (run once, printed by TestWireSize) +// ──────────────────────────────────────────────────────────────────────────── + +func TestWireSize(t *testing.T) { + tests := []struct { + name string + mapData any + structVal any + protoMsg proto.Message + }{ + {"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()}, + {"VoiceResponse", 
voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()}, + {"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()}, + } + + for _, tt := range tests { + mapBytes, _ := msgpack.Marshal(tt.mapData) + structBytes, _ := msgpack.Marshal(tt.structVal) + protoBytes, _ := proto.Marshal(tt.protoMsg) + + t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)", + tt.name, + len(mapBytes), len(structBytes), len(protoBytes), + 100*(1-float64(len(structBytes))/float64(len(mapBytes))), + 100*(1-float64(len(protoBytes))/float64(len(mapBytes))), + ) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Encode benchmarks +// ──────────────────────────────────────────────────────────────────────────── + +func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) { + data := chatRequestMap() + b.ResetTimer() + for b.Loop() { + msgpack.Marshal(data) + } +} + +func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) { + data := chatRequestStruct() + b.ResetTimer() + for b.Loop() { + msgpack.Marshal(data) + } +} + +func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) { + data := chatRequestProto() + b.ResetTimer() + for b.Loop() { + proto.Marshal(data) + } +} + +func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) { + data := voiceResponseMap() + b.ResetTimer() + for b.Loop() { + msgpack.Marshal(data) + } +} + +func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) { + data := voiceResponseStruct() + b.ResetTimer() + for b.Loop() { + msgpack.Marshal(data) + } +} + +func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) { + data := voiceResponseProto() + b.ResetTimer() + for b.Loop() { + proto.Marshal(data) + } +} + +func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) { + data := ttsChunkMap() + b.ResetTimer() + for b.Loop() { + msgpack.Marshal(data) + } +} + +func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) { + data := ttsChunkStruct() + b.ResetTimer() + 
for b.Loop() { + msgpack.Marshal(data) + } +} + +func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) { + data := ttsChunkProto() + b.ResetTimer() + for b.Loop() { + proto.Marshal(data) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Decode benchmarks +// ──────────────────────────────────────────────────────────────────────────── + +func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) { + encoded, _ := msgpack.Marshal(chatRequestMap()) + b.ResetTimer() + for b.Loop() { + var m map[string]any + msgpack.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) { + encoded, _ := msgpack.Marshal(chatRequestStruct()) + b.ResetTimer() + for b.Loop() { + var m ChatRequest + msgpack.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) { + encoded, _ := proto.Marshal(chatRequestProto()) + b.ResetTimer() + for b.Loop() { + var m pb.ChatRequest + proto.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) { + encoded, _ := msgpack.Marshal(voiceResponseMap()) + b.ResetTimer() + for b.Loop() { + var m map[string]any + msgpack.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) { + encoded, _ := msgpack.Marshal(voiceResponseStruct()) + b.ResetTimer() + for b.Loop() { + var m VoiceResponse + msgpack.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) { + encoded, _ := proto.Marshal(voiceResponseProto()) + b.ResetTimer() + for b.Loop() { + var m pb.VoiceResponse + proto.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) { + encoded, _ := msgpack.Marshal(ttsChunkMap()) + b.ResetTimer() + for b.Loop() { + var m map[string]any + msgpack.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) { + encoded, _ := msgpack.Marshal(ttsChunkStruct()) + b.ResetTimer() + for 
b.Loop() { + var m TTSAudioChunk + msgpack.Unmarshal(encoded, &m) + } +} + +func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) { + encoded, _ := proto.Marshal(ttsChunkProto()) + b.ResetTimer() + for b.Loop() { + var m pb.TTSAudioChunk + proto.Unmarshal(encoded, &m) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Roundtrip benchmarks (encode + decode) +// ──────────────────────────────────────────────────────────────────────────── + +func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) { + data := chatRequestMap() + b.ResetTimer() + for b.Loop() { + enc, _ := msgpack.Marshal(data) + var dec map[string]any + msgpack.Unmarshal(enc, &dec) + } +} + +func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) { + data := chatRequestStruct() + b.ResetTimer() + for b.Loop() { + enc, _ := msgpack.Marshal(data) + var dec ChatRequest + msgpack.Unmarshal(enc, &dec) + } +} + +func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) { + data := chatRequestProto() + b.ResetTimer() + for b.Loop() { + enc, _ := proto.Marshal(data) + var dec pb.ChatRequest + proto.Unmarshal(enc, &dec) + } +} + +// ──────────────────────────────────────────────────────────────────────────── +// Typed struct unit tests — verify roundtrip correctness +// ──────────────────────────────────────────────────────────────────────────── + +func TestRoundtrip_ChatRequest(t *testing.T) { + orig := chatRequestStruct() + data, err := msgpack.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec ChatRequest + if err := msgpack.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.RequestID != orig.RequestID { + t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID) + } + if dec.Message != orig.Message { + t.Errorf("Message = %q, want %q", dec.Message, orig.Message) + } + if dec.TopK != orig.TopK { + t.Errorf("TopK = %d, want %d", dec.TopK, orig.TopK) + } + if dec.Premium != orig.Premium { + t.Errorf("Premium = %v, want %v", 
dec.Premium, orig.Premium) + } + if dec.EffectiveQuery() != orig.Message { + t.Errorf("EffectiveQuery() = %q, want %q", dec.EffectiveQuery(), orig.Message) + } +} + +func TestRoundtrip_VoiceResponse(t *testing.T) { + orig := voiceResponseStruct() + data, err := msgpack.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec VoiceResponse + if err := msgpack.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.RequestID != orig.RequestID { + t.Errorf("RequestID mismatch") + } + if len(dec.Audio) != len(orig.Audio) { + t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio)) + } + if dec.Transcription != orig.Transcription { + t.Errorf("Transcription mismatch") + } +} + +func TestRoundtrip_TTSAudioChunk(t *testing.T) { + orig := ttsChunkStruct() + data, err := msgpack.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec TTSAudioChunk + if err := msgpack.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.SessionID != orig.SessionID { + t.Errorf("SessionID mismatch") + } + if dec.ChunkIndex != orig.ChunkIndex { + t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex) + } + if len(dec.Audio) != len(orig.Audio) { + t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio)) + } + if dec.SampleRate != orig.SampleRate { + t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate) + } +} + +func TestRoundtrip_PipelineTrigger(t *testing.T) { + orig := PipelineTrigger{ + RequestID: "pip-001", + Pipeline: "document-ingestion", + Parameters: map[string]any{"source": "s3://bucket/data"}, + } + data, err := msgpack.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec PipelineTrigger + if err := msgpack.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.Pipeline != orig.Pipeline { + t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline) + } + if dec.Parameters["source"] != orig.Parameters["source"] { + t.Errorf("Parameters[source] mismatch") + } +} + +func 
TestRoundtrip_STTTranscription(t *testing.T) { + orig := STTTranscription{ + SessionID: "stt-001", + Transcript: "hello world", + Sequence: 5, + IsPartial: false, + IsFinal: true, + Timestamp: time.Now().Unix(), + SpeakerID: "speaker-1", + HasVoiceActivity: true, + State: "listening", + } + data, err := msgpack.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec STTTranscription + if err := msgpack.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.Transcript != orig.Transcript { + t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript) + } + if dec.IsFinal != orig.IsFinal { + t.Error("IsFinal mismatch") + } +} + +func TestRoundtrip_ErrorResponse(t *testing.T) { + orig := ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"} + data, err := msgpack.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec ErrorResponse + if err := msgpack.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if !dec.Error || dec.Message != "something broke" || dec.Type != "InternalError" { + t.Errorf("ErrorResponse roundtrip mismatch: %+v", dec) + } +} + +func TestTimestamp(t *testing.T) { + ts := Timestamp() + now := time.Now().Unix() + if ts < now-1 || ts > now+1 { + t.Errorf("Timestamp() = %d, expected ~%d", ts, now) + } +} diff --git a/messages/messages.go b/messages/messages.go new file mode 100644 index 0000000..93a5cdb --- /dev/null +++ b/messages/messages.go @@ -0,0 +1,224 @@ +// Package messages defines typed NATS message structs for all services. +// +// Using typed structs with short msgpack field tags instead of map[string]any +// provides compile-time safety, smaller wire size (integer-like short keys vs +// full string keys), and faster encode/decode by avoiding interface{} boxing. +// +// Audio data uses raw []byte instead of base64-encoded strings — msgpack +// supports binary natively, eliminating the 33% base64 overhead. 
+package messages + +import "time" + +// ──────────────────────────────────────────────────────────────────────────── +// Pipeline Bridge +// ──────────────────────────────────────────────────────────────────────────── + +// PipelineTrigger is the request to start a pipeline. +type PipelineTrigger struct { +RequestID string `msgpack:"request_id" json:"request_id"` +Pipeline string `msgpack:"pipeline" json:"pipeline"` +Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"` +} + +// PipelineStatus is the response / status update for a pipeline run. +type PipelineStatus struct { +RequestID string `msgpack:"request_id" json:"request_id"` +Status string `msgpack:"status" json:"status"` +RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"` +Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"` +Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"` +SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"` +Error string `msgpack:"error,omitempty" json:"error,omitempty"` +AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"` +} + +// ──────────────────────────────────────────────────────────────────────────── +// Chat Handler +// ──────────────────────────────────────────────────────────────────────────── + +// ChatRequest is an incoming chat message. 
type ChatRequest struct {
	// Message is the primary user text. Query is only consulted by
	// EffectiveQuery when Message is empty. Fields tagged omitempty are
	// optional and are left out of the encoded form when empty.
	// NOTE(review): ResponseSubject is presumably the NATS subject the reply
	// should be published on — confirm against the handler.
	RequestID       string `msgpack:"request_id" json:"request_id"`
	UserID          string `msgpack:"user_id" json:"user_id"`
	Message         string `msgpack:"message" json:"message"`
	Query           string `msgpack:"query,omitempty" json:"query,omitempty"`
	Premium         bool   `msgpack:"premium,omitempty" json:"premium,omitempty"`
	EnableRAG       bool   `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"`
	EnableReranker  bool   `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"`
	EnableStreaming bool   `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"`
	TopK            int    `msgpack:"top_k,omitempty" json:"top_k,omitempty"`
	Collection      string `msgpack:"collection,omitempty" json:"collection,omitempty"`
	EnableTTS       bool   `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"`
	SystemPrompt    string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
	ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"`
}

// EffectiveQuery returns Message, or Query when Message is empty.
//
// A nil receiver yields "" instead of panicking, matching the nil-safe
// getters of the generated protobuf ChatRequest.
func (c *ChatRequest) EffectiveQuery() string {
	if c == nil {
		return ""
	}
	if c.Message != "" {
		return c.Message
	}
	return c.Query
}

// ChatResponse is the full reply to a chat request.
type ChatResponse struct {
	// Audio holds raw bytes: msgpack writes them as binary, whereas
	// encoding/json still renders a []byte as a base64 string.
	// NOTE(review): Response and ResponseText are both always encoded; how
	// they differ is not visible in this file — confirm with the producer.
	UserID       string   `msgpack:"user_id" json:"user_id"`
	Response     string   `msgpack:"response" json:"response"`
	ResponseText string   `msgpack:"response_text" json:"response_text"`
	UsedRAG      bool     `msgpack:"used_rag" json:"used_rag"`
	RAGSources   []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
	Success      bool     `msgpack:"success" json:"success"`
	Audio        []byte   `msgpack:"audio,omitempty" json:"audio,omitempty"`
	Error        string   `msgpack:"error,omitempty" json:"error,omitempty"`
}

// ChatStreamChunk is a single streaming chunk from an LLM response.
type ChatStreamChunk struct {
	// Every field is always encoded (no omitempty).
	// NOTE(review): Timestamp is presumably Unix seconds, as returned by this
	// package's Timestamp helper — confirm with the producer.
	RequestID string `msgpack:"request_id" json:"request_id"`
	Type      string `msgpack:"type" json:"type"`
	Content   string `msgpack:"content" json:"content"`
	Done      bool   `msgpack:"done" json:"done"`
	Timestamp int64  `msgpack:"timestamp" json:"timestamp"`
}

// ────────────────────────────────────────────────────────────────────────────
// Voice Assistant
// ────────────────────────────────────────────────────────────────────────────

// VoiceRequest is an incoming voice-to-voice request.
type VoiceRequest struct {
	// Audio holds raw bytes rather than a base64 string; Language and
	// Collection are optional and omitted from the encoding when empty.
	RequestID  string `msgpack:"request_id" json:"request_id"`
	Audio      []byte `msgpack:"audio" json:"audio"`
	Language   string `msgpack:"language,omitempty" json:"language,omitempty"`
	Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
}

// VoiceResponse is the reply to a voice request.
type VoiceResponse struct {
	// Audio holds raw bytes and is always encoded; Transcription, Sources and
	// Error are omitted from the encoding when empty.
	RequestID     string           `msgpack:"request_id" json:"request_id"`
	Response      string           `msgpack:"response" json:"response"`
	Audio         []byte           `msgpack:"audio" json:"audio"`
	Transcription string           `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
	Sources       []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
	Error         string           `msgpack:"error,omitempty" json:"error,omitempty"`
}

// DocumentSource is a RAG search result source.
type DocumentSource struct {
	// NOTE(review): the scale and ordering of Score are not defined in this
	// file — confirm with the retrieval service before comparing values.
	Text  string  `msgpack:"text" json:"text"`
	Score float64 `msgpack:"score" json:"score"`
}

// ────────────────────────────────────────────────────────────────────────────
// TTS Module
// ────────────────────────────────────────────────────────────────────────────

// TTSRequest is a text-to-speech synthesis request.
type TTSRequest struct {
	// Only Text is always encoded; the remaining fields are omitted when empty.
	// NOTE(review): SpeakerWavB64 looks like a base64-encoded WAV carried in a
	// string, unlike the raw []byte Audio fields of the other messages —
	// confirm whether that is intentional.
	Text          string `msgpack:"text" json:"text"`
	Speaker       string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
	Language      string `msgpack:"language,omitempty" json:"language,omitempty"`
	SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
	Stream        bool   `msgpack:"stream,omitempty" json:"stream,omitempty"`
}

// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
type TTSAudioChunk struct {
	// Audio holds raw bytes. ChunkIndex, TotalChunks and SampleRate are plain
	// int here, while the generated protobuf TTSAudioChunk uses int32.
	// NOTE(review): whether ChunkIndex is zero-based is not visible in this
	// file — confirm with the TTS producer.
	SessionID   string `msgpack:"session_id" json:"session_id"`
	ChunkIndex  int    `msgpack:"chunk_index" json:"chunk_index"`
	TotalChunks int    `msgpack:"total_chunks" json:"total_chunks"`
	Audio       []byte `msgpack:"audio" json:"audio"`
	IsLast      bool   `msgpack:"is_last" json:"is_last"`
	Timestamp   int64  `msgpack:"timestamp" json:"timestamp"`
	SampleRate  int    `msgpack:"sample_rate" json:"sample_rate"`
}

// TTSFullResponse is a non-streamed TTS response (whole audio).
type TTSFullResponse struct {
	SessionID  string `msgpack:"session_id" json:"session_id"`
	Audio      []byte `msgpack:"audio" json:"audio"`
	Timestamp  int64  `msgpack:"timestamp" json:"timestamp"`
	SampleRate int    `msgpack:"sample_rate" json:"sample_rate"`
}

// TTSStatus is a TTS processing status update.
type TTSStatus struct {
	SessionID string `msgpack:"session_id" json:"session_id"`
	Status    string `msgpack:"status" json:"status"`
	Message   string `msgpack:"message" json:"message"`
	Timestamp int64  `msgpack:"timestamp" json:"timestamp"`
}

// TTSVoiceListResponse is the reply to a voice list request.
type TTSVoiceListResponse struct {
	// CustomVoices has no omitempty, so a nil slice is still encoded (as
	// null by encoding/json) rather than dropped.
	DefaultSpeaker string         `msgpack:"default_speaker" json:"default_speaker"`
	CustomVoices   []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
	LastRefresh    int64          `msgpack:"last_refresh" json:"last_refresh"`
	Timestamp      int64          `msgpack:"timestamp" json:"timestamp"`
}

// TTSVoiceInfo is summary info about a custom voice.
type TTSVoiceInfo struct {
	// NOTE(review): CreatedAt is a string; its time format is not defined in
	// this file — confirm before parsing it.
	Name      string `msgpack:"name" json:"name"`
	Language  string `msgpack:"language" json:"language"`
	ModelType string `msgpack:"model_type" json:"model_type"`
	CreatedAt string `msgpack:"created_at" json:"created_at"`
}

// TTSVoiceRefreshResponse is the reply to a voice refresh request.
type TTSVoiceRefreshResponse struct {
	// NOTE(review): Count presumably equals len(CustomVoices) — confirm with
	// the TTS module.
	Count        int            `msgpack:"count" json:"count"`
	CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
	Timestamp    int64          `msgpack:"timestamp" json:"timestamp"`
}

// ────────────────────────────────────────────────────────────────────────────
// STT Module
// ────────────────────────────────────────────────────────────────────────────

// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
type STTStreamMessage struct {
	// Type is always encoded; Audio, State and SpeakerID are omitted when
	// empty. Audio holds raw bytes.
	// NOTE(review): Type presumably selects which optional field is
	// meaningful — confirm the set of values with the STT module.
	Type      string `msgpack:"type" json:"type"`
	Audio     []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
	State     string `msgpack:"state,omitempty" json:"state,omitempty"`
	SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
}

// STTTranscription is the transcription result published by the STT module.
type STTTranscription struct {
	// Every field is always encoded (no omitempty).
	// NOTE(review): IsPartial and IsFinal are independent booleans here;
	// whether they are mutually exclusive is not enforced in this file.
	SessionID        string `msgpack:"session_id" json:"session_id"`
	Transcript       string `msgpack:"transcript" json:"transcript"`
	Sequence         int    `msgpack:"sequence" json:"sequence"`
	IsPartial        bool   `msgpack:"is_partial" json:"is_partial"`
	IsFinal          bool   `msgpack:"is_final" json:"is_final"`
	Timestamp        int64  `msgpack:"timestamp" json:"timestamp"`
	SpeakerID        string `msgpack:"speaker_id" json:"speaker_id"`
	HasVoiceActivity bool   `msgpack:"has_voice_activity" json:"has_voice_activity"`
	State            string `msgpack:"state" json:"state"`
}

// STTInterrupt is published when the STT module detects a user interrupt.
+type STTInterrupt struct { +SessionID string `msgpack:"session_id" json:"session_id"` +Type string `msgpack:"type" json:"type"` +Timestamp int64 `msgpack:"timestamp" json:"timestamp"` +SpeakerID string `msgpack:"speaker_id" json:"speaker_id"` +} + +// ──────────────────────────────────────────────────────────────────────────── +// Common / Error +// ──────────────────────────────────────────────────────────────────────────── + +// ErrorResponse is the standard error reply from any handler. +type ErrorResponse struct { +Error bool `msgpack:"error" json:"error"` +Message string `msgpack:"message" json:"message"` +Type string `msgpack:"type" json:"type"` +} + +// Timestamp returns the current Unix timestamp (helper for message construction). +func Timestamp() int64 { +return time.Now().Unix() +} diff --git a/messages/proto/messages.pb.go b/messages/proto/messages.pb.go new file mode 100644 index 0000000..c90db85 --- /dev/null +++ b/messages/proto/messages.pb.go @@ -0,0 +1,1738 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc v6.30.2 +// source: messages/proto/messages.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type PipelineTrigger struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Pipeline string `protobuf:"bytes,2,opt,name=pipeline,proto3" json:"pipeline,omitempty"` + Parameters map[string]string `protobuf:"bytes,3,rep,name=parameters,proto3" json:"parameters,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PipelineTrigger) Reset() { + *x = PipelineTrigger{} + mi := &file_messages_proto_messages_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PipelineTrigger) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PipelineTrigger) ProtoMessage() {} + +func (x *PipelineTrigger) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PipelineTrigger.ProtoReflect.Descriptor instead. 
+func (*PipelineTrigger) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{0} +} + +func (x *PipelineTrigger) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *PipelineTrigger) GetPipeline() string { + if x != nil { + return x.Pipeline + } + return "" +} + +func (x *PipelineTrigger) GetParameters() map[string]string { + if x != nil { + return x.Parameters + } + return nil +} + +type PipelineStatus struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + RunId string `protobuf:"bytes,3,opt,name=run_id,json=runId,proto3" json:"run_id,omitempty"` + Engine string `protobuf:"bytes,4,opt,name=engine,proto3" json:"engine,omitempty"` + Pipeline string `protobuf:"bytes,5,opt,name=pipeline,proto3" json:"pipeline,omitempty"` + SubmittedAt string `protobuf:"bytes,6,opt,name=submitted_at,json=submittedAt,proto3" json:"submitted_at,omitempty"` + Error string `protobuf:"bytes,7,opt,name=error,proto3" json:"error,omitempty"` + AvailablePipelines []string `protobuf:"bytes,8,rep,name=available_pipelines,json=availablePipelines,proto3" json:"available_pipelines,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PipelineStatus) Reset() { + *x = PipelineStatus{} + mi := &file_messages_proto_messages_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PipelineStatus) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PipelineStatus) ProtoMessage() {} + +func (x *PipelineStatus) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil 
{ + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PipelineStatus.ProtoReflect.Descriptor instead. +func (*PipelineStatus) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{1} +} + +func (x *PipelineStatus) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *PipelineStatus) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +func (x *PipelineStatus) GetRunId() string { + if x != nil { + return x.RunId + } + return "" +} + +func (x *PipelineStatus) GetEngine() string { + if x != nil { + return x.Engine + } + return "" +} + +func (x *PipelineStatus) GetPipeline() string { + if x != nil { + return x.Pipeline + } + return "" +} + +func (x *PipelineStatus) GetSubmittedAt() string { + if x != nil { + return x.SubmittedAt + } + return "" +} + +func (x *PipelineStatus) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +func (x *PipelineStatus) GetAvailablePipelines() []string { + if x != nil { + return x.AvailablePipelines + } + return nil +} + +type ChatRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` + Query string `protobuf:"bytes,4,opt,name=query,proto3" json:"query,omitempty"` + Premium bool `protobuf:"varint,5,opt,name=premium,proto3" json:"premium,omitempty"` + EnableRag bool `protobuf:"varint,6,opt,name=enable_rag,json=enableRag,proto3" json:"enable_rag,omitempty"` + EnableReranker bool `protobuf:"varint,7,opt,name=enable_reranker,json=enableReranker,proto3" json:"enable_reranker,omitempty"` + EnableStreaming bool 
`protobuf:"varint,8,opt,name=enable_streaming,json=enableStreaming,proto3" json:"enable_streaming,omitempty"` + TopK int32 `protobuf:"varint,9,opt,name=top_k,json=topK,proto3" json:"top_k,omitempty"` + Collection string `protobuf:"bytes,10,opt,name=collection,proto3" json:"collection,omitempty"` + EnableTts bool `protobuf:"varint,11,opt,name=enable_tts,json=enableTts,proto3" json:"enable_tts,omitempty"` + SystemPrompt string `protobuf:"bytes,12,opt,name=system_prompt,json=systemPrompt,proto3" json:"system_prompt,omitempty"` + ResponseSubject string `protobuf:"bytes,13,opt,name=response_subject,json=responseSubject,proto3" json:"response_subject,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatRequest) Reset() { + *x = ChatRequest{} + mi := &file_messages_proto_messages_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatRequest) ProtoMessage() {} + +func (x *ChatRequest) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatRequest.ProtoReflect.Descriptor instead. 
+func (*ChatRequest) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{2} +} + +func (x *ChatRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *ChatRequest) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *ChatRequest) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *ChatRequest) GetQuery() string { + if x != nil { + return x.Query + } + return "" +} + +func (x *ChatRequest) GetPremium() bool { + if x != nil { + return x.Premium + } + return false +} + +func (x *ChatRequest) GetEnableRag() bool { + if x != nil { + return x.EnableRag + } + return false +} + +func (x *ChatRequest) GetEnableReranker() bool { + if x != nil { + return x.EnableReranker + } + return false +} + +func (x *ChatRequest) GetEnableStreaming() bool { + if x != nil { + return x.EnableStreaming + } + return false +} + +func (x *ChatRequest) GetTopK() int32 { + if x != nil { + return x.TopK + } + return 0 +} + +func (x *ChatRequest) GetCollection() string { + if x != nil { + return x.Collection + } + return "" +} + +func (x *ChatRequest) GetEnableTts() bool { + if x != nil { + return x.EnableTts + } + return false +} + +func (x *ChatRequest) GetSystemPrompt() string { + if x != nil { + return x.SystemPrompt + } + return "" +} + +func (x *ChatRequest) GetResponseSubject() string { + if x != nil { + return x.ResponseSubject + } + return "" +} + +type ChatResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + Response string `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"` + ResponseText string `protobuf:"bytes,3,opt,name=response_text,json=responseText,proto3" json:"response_text,omitempty"` + UsedRag bool `protobuf:"varint,4,opt,name=used_rag,json=usedRag,proto3" json:"used_rag,omitempty"` + 
RagSources []string `protobuf:"bytes,5,rep,name=rag_sources,json=ragSources,proto3" json:"rag_sources,omitempty"` + Success bool `protobuf:"varint,6,opt,name=success,proto3" json:"success,omitempty"` + Audio []byte `protobuf:"bytes,7,opt,name=audio,proto3" json:"audio,omitempty"` + Error string `protobuf:"bytes,8,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatResponse) Reset() { + *x = ChatResponse{} + mi := &file_messages_proto_messages_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatResponse) ProtoMessage() {} + +func (x *ChatResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatResponse.ProtoReflect.Descriptor instead. 
+func (*ChatResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{3} +} + +func (x *ChatResponse) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *ChatResponse) GetResponse() string { + if x != nil { + return x.Response + } + return "" +} + +func (x *ChatResponse) GetResponseText() string { + if x != nil { + return x.ResponseText + } + return "" +} + +func (x *ChatResponse) GetUsedRag() bool { + if x != nil { + return x.UsedRag + } + return false +} + +func (x *ChatResponse) GetRagSources() []string { + if x != nil { + return x.RagSources + } + return nil +} + +func (x *ChatResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *ChatResponse) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *ChatResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type ChatStreamChunk struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Type string `protobuf:"bytes,2,opt,name=type,proto3" json:"type,omitempty"` + Content string `protobuf:"bytes,3,opt,name=content,proto3" json:"content,omitempty"` + Done bool `protobuf:"varint,4,opt,name=done,proto3" json:"done,omitempty"` + Timestamp int64 `protobuf:"varint,5,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatStreamChunk) Reset() { + *x = ChatStreamChunk{} + mi := &file_messages_proto_messages_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatStreamChunk) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatStreamChunk) ProtoMessage() {} + +func (x *ChatStreamChunk) ProtoReflect() protoreflect.Message { + mi := 
&file_messages_proto_messages_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatStreamChunk.ProtoReflect.Descriptor instead. +func (*ChatStreamChunk) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{4} +} + +func (x *ChatStreamChunk) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *ChatStreamChunk) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +func (x *ChatStreamChunk) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *ChatStreamChunk) GetDone() bool { + if x != nil { + return x.Done + } + return false +} + +func (x *ChatStreamChunk) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +type VoiceRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Audio []byte `protobuf:"bytes,2,opt,name=audio,proto3" json:"audio,omitempty"` + Language string `protobuf:"bytes,3,opt,name=language,proto3" json:"language,omitempty"` + Collection string `protobuf:"bytes,4,opt,name=collection,proto3" json:"collection,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *VoiceRequest) Reset() { + *x = VoiceRequest{} + mi := &file_messages_proto_messages_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *VoiceRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*VoiceRequest) ProtoMessage() {} + +func (x *VoiceRequest) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[5] + if x != nil { + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use VoiceRequest.ProtoReflect.Descriptor instead. +func (*VoiceRequest) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{5} +} + +func (x *VoiceRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *VoiceRequest) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *VoiceRequest) GetLanguage() string { + if x != nil { + return x.Language + } + return "" +} + +func (x *VoiceRequest) GetCollection() string { + if x != nil { + return x.Collection + } + return "" +} + +type VoiceResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Response string `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"` + Audio []byte `protobuf:"bytes,3,opt,name=audio,proto3" json:"audio,omitempty"` + Transcription string `protobuf:"bytes,4,opt,name=transcription,proto3" json:"transcription,omitempty"` + Sources []*DocumentSource `protobuf:"bytes,5,rep,name=sources,proto3" json:"sources,omitempty"` + Error string `protobuf:"bytes,6,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *VoiceResponse) Reset() { + *x = VoiceResponse{} + mi := &file_messages_proto_messages_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *VoiceResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*VoiceResponse) ProtoMessage() {} + +func (x *VoiceResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[6] + if x != nil { + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use VoiceResponse.ProtoReflect.Descriptor instead. +func (*VoiceResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{6} +} + +func (x *VoiceResponse) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *VoiceResponse) GetResponse() string { + if x != nil { + return x.Response + } + return "" +} + +func (x *VoiceResponse) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *VoiceResponse) GetTranscription() string { + if x != nil { + return x.Transcription + } + return "" +} + +func (x *VoiceResponse) GetSources() []*DocumentSource { + if x != nil { + return x.Sources + } + return nil +} + +func (x *VoiceResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type DocumentSource struct { + state protoimpl.MessageState `protogen:"open.v1"` + Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` + Score float64 `protobuf:"fixed64,2,opt,name=score,proto3" json:"score,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DocumentSource) Reset() { + *x = DocumentSource{} + mi := &file_messages_proto_messages_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DocumentSource) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DocumentSource) ProtoMessage() {} + +func (x *DocumentSource) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
DocumentSource.ProtoReflect.Descriptor instead. +func (*DocumentSource) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{7} +} + +func (x *DocumentSource) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +func (x *DocumentSource) GetScore() float64 { + if x != nil { + return x.Score + } + return 0 +} + +type TTSRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` + Speaker string `protobuf:"bytes,2,opt,name=speaker,proto3" json:"speaker,omitempty"` + Language string `protobuf:"bytes,3,opt,name=language,proto3" json:"language,omitempty"` + SpeakerWavB64 string `protobuf:"bytes,4,opt,name=speaker_wav_b64,json=speakerWavB64,proto3" json:"speaker_wav_b64,omitempty"` + Stream bool `protobuf:"varint,5,opt,name=stream,proto3" json:"stream,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSRequest) Reset() { + *x = TTSRequest{} + mi := &file_messages_proto_messages_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSRequest) ProtoMessage() {} + +func (x *TTSRequest) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSRequest.ProtoReflect.Descriptor instead. 
+func (*TTSRequest) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{8} +} + +func (x *TTSRequest) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +func (x *TTSRequest) GetSpeaker() string { + if x != nil { + return x.Speaker + } + return "" +} + +func (x *TTSRequest) GetLanguage() string { + if x != nil { + return x.Language + } + return "" +} + +func (x *TTSRequest) GetSpeakerWavB64() string { + if x != nil { + return x.SpeakerWavB64 + } + return "" +} + +func (x *TTSRequest) GetStream() bool { + if x != nil { + return x.Stream + } + return false +} + +type TTSAudioChunk struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + ChunkIndex int32 `protobuf:"varint,2,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` + TotalChunks int32 `protobuf:"varint,3,opt,name=total_chunks,json=totalChunks,proto3" json:"total_chunks,omitempty"` + Audio []byte `protobuf:"bytes,4,opt,name=audio,proto3" json:"audio,omitempty"` + IsLast bool `protobuf:"varint,5,opt,name=is_last,json=isLast,proto3" json:"is_last,omitempty"` + Timestamp int64 `protobuf:"varint,6,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + SampleRate int32 `protobuf:"varint,7,opt,name=sample_rate,json=sampleRate,proto3" json:"sample_rate,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSAudioChunk) Reset() { + *x = TTSAudioChunk{} + mi := &file_messages_proto_messages_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSAudioChunk) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSAudioChunk) ProtoMessage() {} + +func (x *TTSAudioChunk) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[9] + if x != nil { + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSAudioChunk.ProtoReflect.Descriptor instead. +func (*TTSAudioChunk) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{9} +} + +func (x *TTSAudioChunk) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *TTSAudioChunk) GetChunkIndex() int32 { + if x != nil { + return x.ChunkIndex + } + return 0 +} + +func (x *TTSAudioChunk) GetTotalChunks() int32 { + if x != nil { + return x.TotalChunks + } + return 0 +} + +func (x *TTSAudioChunk) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *TTSAudioChunk) GetIsLast() bool { + if x != nil { + return x.IsLast + } + return false +} + +func (x *TTSAudioChunk) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +func (x *TTSAudioChunk) GetSampleRate() int32 { + if x != nil { + return x.SampleRate + } + return 0 +} + +type TTSFullResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + Audio []byte `protobuf:"bytes,2,opt,name=audio,proto3" json:"audio,omitempty"` + Timestamp int64 `protobuf:"varint,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + SampleRate int32 `protobuf:"varint,4,opt,name=sample_rate,json=sampleRate,proto3" json:"sample_rate,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSFullResponse) Reset() { + *x = TTSFullResponse{} + mi := &file_messages_proto_messages_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSFullResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSFullResponse) ProtoMessage() {} + +func (x 
*TTSFullResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSFullResponse.ProtoReflect.Descriptor instead. +func (*TTSFullResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{10} +} + +func (x *TTSFullResponse) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *TTSFullResponse) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *TTSFullResponse) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +func (x *TTSFullResponse) GetSampleRate() int32 { + if x != nil { + return x.SampleRate + } + return 0 +} + +type TTSStatus struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` + Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` + Timestamp int64 `protobuf:"varint,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSStatus) Reset() { + *x = TTSStatus{} + mi := &file_messages_proto_messages_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSStatus) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSStatus) ProtoMessage() {} + +func (x *TTSStatus) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + 
ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSStatus.ProtoReflect.Descriptor instead. +func (*TTSStatus) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{11} +} + +func (x *TTSStatus) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *TTSStatus) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +func (x *TTSStatus) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *TTSStatus) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +type TTSVoiceInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Language string `protobuf:"bytes,2,opt,name=language,proto3" json:"language,omitempty"` + ModelType string `protobuf:"bytes,3,opt,name=model_type,json=modelType,proto3" json:"model_type,omitempty"` + CreatedAt string `protobuf:"bytes,4,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSVoiceInfo) Reset() { + *x = TTSVoiceInfo{} + mi := &file_messages_proto_messages_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSVoiceInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSVoiceInfo) ProtoMessage() {} + +func (x *TTSVoiceInfo) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSVoiceInfo.ProtoReflect.Descriptor instead. 
+func (*TTSVoiceInfo) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{12} +} + +func (x *TTSVoiceInfo) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *TTSVoiceInfo) GetLanguage() string { + if x != nil { + return x.Language + } + return "" +} + +func (x *TTSVoiceInfo) GetModelType() string { + if x != nil { + return x.ModelType + } + return "" +} + +func (x *TTSVoiceInfo) GetCreatedAt() string { + if x != nil { + return x.CreatedAt + } + return "" +} + +type TTSVoiceListResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + DefaultSpeaker string `protobuf:"bytes,1,opt,name=default_speaker,json=defaultSpeaker,proto3" json:"default_speaker,omitempty"` + CustomVoices []*TTSVoiceInfo `protobuf:"bytes,2,rep,name=custom_voices,json=customVoices,proto3" json:"custom_voices,omitempty"` + LastRefresh int64 `protobuf:"varint,3,opt,name=last_refresh,json=lastRefresh,proto3" json:"last_refresh,omitempty"` + Timestamp int64 `protobuf:"varint,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSVoiceListResponse) Reset() { + *x = TTSVoiceListResponse{} + mi := &file_messages_proto_messages_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSVoiceListResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSVoiceListResponse) ProtoMessage() {} + +func (x *TTSVoiceListResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSVoiceListResponse.ProtoReflect.Descriptor instead. 
+func (*TTSVoiceListResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{13} +} + +func (x *TTSVoiceListResponse) GetDefaultSpeaker() string { + if x != nil { + return x.DefaultSpeaker + } + return "" +} + +func (x *TTSVoiceListResponse) GetCustomVoices() []*TTSVoiceInfo { + if x != nil { + return x.CustomVoices + } + return nil +} + +func (x *TTSVoiceListResponse) GetLastRefresh() int64 { + if x != nil { + return x.LastRefresh + } + return 0 +} + +func (x *TTSVoiceListResponse) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +type TTSVoiceRefreshResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Count int32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` + CustomVoices []*TTSVoiceInfo `protobuf:"bytes,2,rep,name=custom_voices,json=customVoices,proto3" json:"custom_voices,omitempty"` + Timestamp int64 `protobuf:"varint,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TTSVoiceRefreshResponse) Reset() { + *x = TTSVoiceRefreshResponse{} + mi := &file_messages_proto_messages_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TTSVoiceRefreshResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSVoiceRefreshResponse) ProtoMessage() {} + +func (x *TTSVoiceRefreshResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSVoiceRefreshResponse.ProtoReflect.Descriptor instead. 
+func (*TTSVoiceRefreshResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{14} +} + +func (x *TTSVoiceRefreshResponse) GetCount() int32 { + if x != nil { + return x.Count + } + return 0 +} + +func (x *TTSVoiceRefreshResponse) GetCustomVoices() []*TTSVoiceInfo { + if x != nil { + return x.CustomVoices + } + return nil +} + +func (x *TTSVoiceRefreshResponse) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +type STTStreamMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` + Audio []byte `protobuf:"bytes,2,opt,name=audio,proto3" json:"audio,omitempty"` + State string `protobuf:"bytes,3,opt,name=state,proto3" json:"state,omitempty"` + SpeakerId string `protobuf:"bytes,4,opt,name=speaker_id,json=speakerId,proto3" json:"speaker_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *STTStreamMessage) Reset() { + *x = STTStreamMessage{} + mi := &file_messages_proto_messages_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *STTStreamMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*STTStreamMessage) ProtoMessage() {} + +func (x *STTStreamMessage) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use STTStreamMessage.ProtoReflect.Descriptor instead. 
+func (*STTStreamMessage) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{15} +} + +func (x *STTStreamMessage) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +func (x *STTStreamMessage) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *STTStreamMessage) GetState() string { + if x != nil { + return x.State + } + return "" +} + +func (x *STTStreamMessage) GetSpeakerId() string { + if x != nil { + return x.SpeakerId + } + return "" +} + +type STTTranscription struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + Transcript string `protobuf:"bytes,2,opt,name=transcript,proto3" json:"transcript,omitempty"` + Sequence int32 `protobuf:"varint,3,opt,name=sequence,proto3" json:"sequence,omitempty"` + IsPartial bool `protobuf:"varint,4,opt,name=is_partial,json=isPartial,proto3" json:"is_partial,omitempty"` + IsFinal bool `protobuf:"varint,5,opt,name=is_final,json=isFinal,proto3" json:"is_final,omitempty"` + Timestamp int64 `protobuf:"varint,6,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + SpeakerId string `protobuf:"bytes,7,opt,name=speaker_id,json=speakerId,proto3" json:"speaker_id,omitempty"` + HasVoiceActivity bool `protobuf:"varint,8,opt,name=has_voice_activity,json=hasVoiceActivity,proto3" json:"has_voice_activity,omitempty"` + State string `protobuf:"bytes,9,opt,name=state,proto3" json:"state,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *STTTranscription) Reset() { + *x = STTTranscription{} + mi := &file_messages_proto_messages_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *STTTranscription) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*STTTranscription) ProtoMessage() {} + +func (x 
*STTTranscription) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use STTTranscription.ProtoReflect.Descriptor instead. +func (*STTTranscription) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{16} +} + +func (x *STTTranscription) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *STTTranscription) GetTranscript() string { + if x != nil { + return x.Transcript + } + return "" +} + +func (x *STTTranscription) GetSequence() int32 { + if x != nil { + return x.Sequence + } + return 0 +} + +func (x *STTTranscription) GetIsPartial() bool { + if x != nil { + return x.IsPartial + } + return false +} + +func (x *STTTranscription) GetIsFinal() bool { + if x != nil { + return x.IsFinal + } + return false +} + +func (x *STTTranscription) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +func (x *STTTranscription) GetSpeakerId() string { + if x != nil { + return x.SpeakerId + } + return "" +} + +func (x *STTTranscription) GetHasVoiceActivity() bool { + if x != nil { + return x.HasVoiceActivity + } + return false +} + +func (x *STTTranscription) GetState() string { + if x != nil { + return x.State + } + return "" +} + +type STTInterrupt struct { + state protoimpl.MessageState `protogen:"open.v1"` + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + Type string `protobuf:"bytes,2,opt,name=type,proto3" json:"type,omitempty"` + Timestamp int64 `protobuf:"varint,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + SpeakerId string `protobuf:"bytes,4,opt,name=speaker_id,json=speakerId,proto3" json:"speaker_id,omitempty"` + unknownFields 
protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *STTInterrupt) Reset() { + *x = STTInterrupt{} + mi := &file_messages_proto_messages_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *STTInterrupt) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*STTInterrupt) ProtoMessage() {} + +func (x *STTInterrupt) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[17] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use STTInterrupt.ProtoReflect.Descriptor instead. +func (*STTInterrupt) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{17} +} + +func (x *STTInterrupt) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +func (x *STTInterrupt) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +func (x *STTInterrupt) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +func (x *STTInterrupt) GetSpeakerId() string { + if x != nil { + return x.SpeakerId + } + return "" +} + +type ErrorResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Error bool `protobuf:"varint,1,opt,name=error,proto3" json:"error,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ErrorResponse) Reset() { + *x = ErrorResponse{} + mi := &file_messages_proto_messages_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ErrorResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func 
(*ErrorResponse) ProtoMessage() {} + +func (x *ErrorResponse) ProtoReflect() protoreflect.Message { + mi := &file_messages_proto_messages_proto_msgTypes[18] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ErrorResponse.ProtoReflect.Descriptor instead. +func (*ErrorResponse) Descriptor() ([]byte, []int) { + return file_messages_proto_messages_proto_rawDescGZIP(), []int{18} +} + +func (x *ErrorResponse) GetError() bool { + if x != nil { + return x.Error + } + return false +} + +func (x *ErrorResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *ErrorResponse) GetType() string { + if x != nil { + return x.Type + } + return "" +} + +var File_messages_proto_messages_proto protoreflect.FileDescriptor + +const file_messages_proto_messages_proto_rawDesc = "" + + "\n" + + "\x1dmessages/proto/messages.proto\x12\bmessages\"\xd6\x01\n" + + "\x0fPipelineTrigger\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x1a\n" + + "\bpipeline\x18\x02 \x01(\tR\bpipeline\x12I\n" + + "\n" + + "parameters\x18\x03 \x03(\v2).messages.PipelineTrigger.ParametersEntryR\n" + + "parameters\x1a=\n" + + "\x0fParametersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\xfc\x01\n" + + "\x0ePipelineStatus\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x16\n" + + "\x06status\x18\x02 \x01(\tR\x06status\x12\x15\n" + + "\x06run_id\x18\x03 \x01(\tR\x05runId\x12\x16\n" + + "\x06engine\x18\x04 \x01(\tR\x06engine\x12\x1a\n" + + "\bpipeline\x18\x05 \x01(\tR\bpipeline\x12!\n" + + "\fsubmitted_at\x18\x06 \x01(\tR\vsubmittedAt\x12\x14\n" + + "\x05error\x18\a \x01(\tR\x05error\x12/\n" + + "\x13available_pipelines\x18\b \x03(\tR\x12availablePipelines\"\xa6\x03\n" + + "\vChatRequest\x12\x1d\n" + + "\n" + + 
"request_id\x18\x01 \x01(\tR\trequestId\x12\x17\n" + + "\auser_id\x18\x02 \x01(\tR\x06userId\x12\x18\n" + + "\amessage\x18\x03 \x01(\tR\amessage\x12\x14\n" + + "\x05query\x18\x04 \x01(\tR\x05query\x12\x18\n" + + "\apremium\x18\x05 \x01(\bR\apremium\x12\x1d\n" + + "\n" + + "enable_rag\x18\x06 \x01(\bR\tenableRag\x12'\n" + + "\x0fenable_reranker\x18\a \x01(\bR\x0eenableReranker\x12)\n" + + "\x10enable_streaming\x18\b \x01(\bR\x0fenableStreaming\x12\x13\n" + + "\x05top_k\x18\t \x01(\x05R\x04topK\x12\x1e\n" + + "\n" + + "collection\x18\n" + + " \x01(\tR\n" + + "collection\x12\x1d\n" + + "\n" + + "enable_tts\x18\v \x01(\bR\tenableTts\x12#\n" + + "\rsystem_prompt\x18\f \x01(\tR\fsystemPrompt\x12)\n" + + "\x10response_subject\x18\r \x01(\tR\x0fresponseSubject\"\xea\x01\n" + + "\fChatResponse\x12\x17\n" + + "\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1a\n" + + "\bresponse\x18\x02 \x01(\tR\bresponse\x12#\n" + + "\rresponse_text\x18\x03 \x01(\tR\fresponseText\x12\x19\n" + + "\bused_rag\x18\x04 \x01(\bR\ausedRag\x12\x1f\n" + + "\vrag_sources\x18\x05 \x03(\tR\n" + + "ragSources\x12\x18\n" + + "\asuccess\x18\x06 \x01(\bR\asuccess\x12\x14\n" + + "\x05audio\x18\a \x01(\fR\x05audio\x12\x14\n" + + "\x05error\x18\b \x01(\tR\x05error\"\x90\x01\n" + + "\x0fChatStreamChunk\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x12\n" + + "\x04type\x18\x02 \x01(\tR\x04type\x12\x18\n" + + "\acontent\x18\x03 \x01(\tR\acontent\x12\x12\n" + + "\x04done\x18\x04 \x01(\bR\x04done\x12\x1c\n" + + "\ttimestamp\x18\x05 \x01(\x03R\ttimestamp\"\x7f\n" + + "\fVoiceRequest\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x14\n" + + "\x05audio\x18\x02 \x01(\fR\x05audio\x12\x1a\n" + + "\blanguage\x18\x03 \x01(\tR\blanguage\x12\x1e\n" + + "\n" + + "collection\x18\x04 \x01(\tR\n" + + "collection\"\xd0\x01\n" + + "\rVoiceResponse\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x1a\n" + + "\bresponse\x18\x02 \x01(\tR\bresponse\x12\x14\n" + + 
"\x05audio\x18\x03 \x01(\fR\x05audio\x12$\n" + + "\rtranscription\x18\x04 \x01(\tR\rtranscription\x122\n" + + "\asources\x18\x05 \x03(\v2\x18.messages.DocumentSourceR\asources\x12\x14\n" + + "\x05error\x18\x06 \x01(\tR\x05error\":\n" + + "\x0eDocumentSource\x12\x12\n" + + "\x04text\x18\x01 \x01(\tR\x04text\x12\x14\n" + + "\x05score\x18\x02 \x01(\x01R\x05score\"\x96\x01\n" + + "\n" + + "TTSRequest\x12\x12\n" + + "\x04text\x18\x01 \x01(\tR\x04text\x12\x18\n" + + "\aspeaker\x18\x02 \x01(\tR\aspeaker\x12\x1a\n" + + "\blanguage\x18\x03 \x01(\tR\blanguage\x12&\n" + + "\x0fspeaker_wav_b64\x18\x04 \x01(\tR\rspeakerWavB64\x12\x16\n" + + "\x06stream\x18\x05 \x01(\bR\x06stream\"\xe0\x01\n" + + "\rTTSAudioChunk\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x1f\n" + + "\vchunk_index\x18\x02 \x01(\x05R\n" + + "chunkIndex\x12!\n" + + "\ftotal_chunks\x18\x03 \x01(\x05R\vtotalChunks\x12\x14\n" + + "\x05audio\x18\x04 \x01(\fR\x05audio\x12\x17\n" + + "\ais_last\x18\x05 \x01(\bR\x06isLast\x12\x1c\n" + + "\ttimestamp\x18\x06 \x01(\x03R\ttimestamp\x12\x1f\n" + + "\vsample_rate\x18\a \x01(\x05R\n" + + "sampleRate\"\x85\x01\n" + + "\x0fTTSFullResponse\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x14\n" + + "\x05audio\x18\x02 \x01(\fR\x05audio\x12\x1c\n" + + "\ttimestamp\x18\x03 \x01(\x03R\ttimestamp\x12\x1f\n" + + "\vsample_rate\x18\x04 \x01(\x05R\n" + + "sampleRate\"z\n" + + "\tTTSStatus\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x16\n" + + "\x06status\x18\x02 \x01(\tR\x06status\x12\x18\n" + + "\amessage\x18\x03 \x01(\tR\amessage\x12\x1c\n" + + "\ttimestamp\x18\x04 \x01(\x03R\ttimestamp\"|\n" + + "\fTTSVoiceInfo\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1a\n" + + "\blanguage\x18\x02 \x01(\tR\blanguage\x12\x1d\n" + + "\n" + + "model_type\x18\x03 \x01(\tR\tmodelType\x12\x1d\n" + + "\n" + + "created_at\x18\x04 \x01(\tR\tcreatedAt\"\xbd\x01\n" + + "\x14TTSVoiceListResponse\x12'\n" + + 
"\x0fdefault_speaker\x18\x01 \x01(\tR\x0edefaultSpeaker\x12;\n" + + "\rcustom_voices\x18\x02 \x03(\v2\x16.messages.TTSVoiceInfoR\fcustomVoices\x12!\n" + + "\flast_refresh\x18\x03 \x01(\x03R\vlastRefresh\x12\x1c\n" + + "\ttimestamp\x18\x04 \x01(\x03R\ttimestamp\"\x8a\x01\n" + + "\x17TTSVoiceRefreshResponse\x12\x14\n" + + "\x05count\x18\x01 \x01(\x05R\x05count\x12;\n" + + "\rcustom_voices\x18\x02 \x03(\v2\x16.messages.TTSVoiceInfoR\fcustomVoices\x12\x1c\n" + + "\ttimestamp\x18\x03 \x01(\x03R\ttimestamp\"q\n" + + "\x10STTStreamMessage\x12\x12\n" + + "\x04type\x18\x01 \x01(\tR\x04type\x12\x14\n" + + "\x05audio\x18\x02 \x01(\fR\x05audio\x12\x14\n" + + "\x05state\x18\x03 \x01(\tR\x05state\x12\x1d\n" + + "\n" + + "speaker_id\x18\x04 \x01(\tR\tspeakerId\"\xa8\x02\n" + + "\x10STTTranscription\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x1e\n" + + "\n" + + "transcript\x18\x02 \x01(\tR\n" + + "transcript\x12\x1a\n" + + "\bsequence\x18\x03 \x01(\x05R\bsequence\x12\x1d\n" + + "\n" + + "is_partial\x18\x04 \x01(\bR\tisPartial\x12\x19\n" + + "\bis_final\x18\x05 \x01(\bR\aisFinal\x12\x1c\n" + + "\ttimestamp\x18\x06 \x01(\x03R\ttimestamp\x12\x1d\n" + + "\n" + + "speaker_id\x18\a \x01(\tR\tspeakerId\x12,\n" + + "\x12has_voice_activity\x18\b \x01(\bR\x10hasVoiceActivity\x12\x14\n" + + "\x05state\x18\t \x01(\tR\x05state\"~\n" + + "\fSTTInterrupt\x12\x1d\n" + + "\n" + + "session_id\x18\x01 \x01(\tR\tsessionId\x12\x12\n" + + "\x04type\x18\x02 \x01(\tR\x04type\x12\x1c\n" + + "\ttimestamp\x18\x03 \x01(\x03R\ttimestamp\x12\x1d\n" + + "\n" + + "speaker_id\x18\x04 \x01(\tR\tspeakerId\"S\n" + + "\rErrorResponse\x12\x14\n" + + "\x05error\x18\x01 \x01(\bR\x05error\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\x12\x12\n" + + "\x04type\x18\x03 \x01(\tR\x04typeBBZ@git.daviestechlabs.io/daviestechlabs/handler-base/messages/protob\x06proto3" + +var ( + file_messages_proto_messages_proto_rawDescOnce sync.Once + file_messages_proto_messages_proto_rawDescData []byte +) + 
+func file_messages_proto_messages_proto_rawDescGZIP() []byte { + file_messages_proto_messages_proto_rawDescOnce.Do(func() { + file_messages_proto_messages_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_messages_proto_messages_proto_rawDesc), len(file_messages_proto_messages_proto_rawDesc))) + }) + return file_messages_proto_messages_proto_rawDescData +} + +var file_messages_proto_messages_proto_msgTypes = make([]protoimpl.MessageInfo, 20) +var file_messages_proto_messages_proto_goTypes = []any{ + (*PipelineTrigger)(nil), // 0: messages.PipelineTrigger + (*PipelineStatus)(nil), // 1: messages.PipelineStatus + (*ChatRequest)(nil), // 2: messages.ChatRequest + (*ChatResponse)(nil), // 3: messages.ChatResponse + (*ChatStreamChunk)(nil), // 4: messages.ChatStreamChunk + (*VoiceRequest)(nil), // 5: messages.VoiceRequest + (*VoiceResponse)(nil), // 6: messages.VoiceResponse + (*DocumentSource)(nil), // 7: messages.DocumentSource + (*TTSRequest)(nil), // 8: messages.TTSRequest + (*TTSAudioChunk)(nil), // 9: messages.TTSAudioChunk + (*TTSFullResponse)(nil), // 10: messages.TTSFullResponse + (*TTSStatus)(nil), // 11: messages.TTSStatus + (*TTSVoiceInfo)(nil), // 12: messages.TTSVoiceInfo + (*TTSVoiceListResponse)(nil), // 13: messages.TTSVoiceListResponse + (*TTSVoiceRefreshResponse)(nil), // 14: messages.TTSVoiceRefreshResponse + (*STTStreamMessage)(nil), // 15: messages.STTStreamMessage + (*STTTranscription)(nil), // 16: messages.STTTranscription + (*STTInterrupt)(nil), // 17: messages.STTInterrupt + (*ErrorResponse)(nil), // 18: messages.ErrorResponse + nil, // 19: messages.PipelineTrigger.ParametersEntry +} +var file_messages_proto_messages_proto_depIdxs = []int32{ + 19, // 0: messages.PipelineTrigger.parameters:type_name -> messages.PipelineTrigger.ParametersEntry + 7, // 1: messages.VoiceResponse.sources:type_name -> messages.DocumentSource + 12, // 2: messages.TTSVoiceListResponse.custom_voices:type_name -> messages.TTSVoiceInfo + 
12, // 3: messages.TTSVoiceRefreshResponse.custom_voices:type_name -> messages.TTSVoiceInfo + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_messages_proto_messages_proto_init() } +func file_messages_proto_messages_proto_init() { + if File_messages_proto_messages_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_messages_proto_messages_proto_rawDesc), len(file_messages_proto_messages_proto_rawDesc)), + NumEnums: 0, + NumMessages: 20, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_messages_proto_messages_proto_goTypes, + DependencyIndexes: file_messages_proto_messages_proto_depIdxs, + MessageInfos: file_messages_proto_messages_proto_msgTypes, + }.Build() + File_messages_proto_messages_proto = out.File + file_messages_proto_messages_proto_goTypes = nil + file_messages_proto_messages_proto_depIdxs = nil +} diff --git a/messages/proto/messages.proto b/messages/proto/messages.proto new file mode 100644 index 0000000..04e2c12 --- /dev/null +++ b/messages/proto/messages.proto @@ -0,0 +1,174 @@ +syntax = "proto3"; + +package messages; + +option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto"; + +// ── Pipeline Bridge ───────────────────────────────────────────────────────── + +message PipelineTrigger { + string request_id = 1; + string pipeline = 2; + map<string, string> parameters = 3; +} + +message PipelineStatus { + string request_id = 1; + string status = 2; + string run_id = 3; + string engine = 4; + string pipeline = 5; + string submitted_at = 6; + string error = 7; + repeated string available_pipelines = 8; +} + +// ── Chat Handler 
──────────────────────────────────────────────────────────── + +message ChatRequest { + string request_id = 1; + string user_id = 2; + string message = 3; + string query = 4; + bool premium = 5; + bool enable_rag = 6; + bool enable_reranker = 7; + bool enable_streaming = 8; + int32 top_k = 9; + string collection = 10; + bool enable_tts = 11; + string system_prompt = 12; + string response_subject = 13; +} + +message ChatResponse { + string user_id = 1; + string response = 2; + string response_text = 3; + bool used_rag = 4; + repeated string rag_sources = 5; + bool success = 6; + bytes audio = 7; + string error = 8; +} + +message ChatStreamChunk { + string request_id = 1; + string type = 2; + string content = 3; + bool done = 4; + int64 timestamp = 5; +} + +// ── Voice Assistant ───────────────────────────────────────────────────────── + +message VoiceRequest { + string request_id = 1; + bytes audio = 2; + string language = 3; + string collection = 4; +} + +message VoiceResponse { + string request_id = 1; + string response = 2; + bytes audio = 3; + string transcription = 4; + repeated DocumentSource sources = 5; + string error = 6; +} + +message DocumentSource { + string text = 1; + double score = 2; +} + +// ── TTS Module ────────────────────────────────────────────────────────────── + +message TTSRequest { + string text = 1; + string speaker = 2; + string language = 3; + string speaker_wav_b64 = 4; + bool stream = 5; +} + +message TTSAudioChunk { + string session_id = 1; + int32 chunk_index = 2; + int32 total_chunks = 3; + bytes audio = 4; + bool is_last = 5; + int64 timestamp = 6; + int32 sample_rate = 7; +} + +message TTSFullResponse { + string session_id = 1; + bytes audio = 2; + int64 timestamp = 3; + int32 sample_rate = 4; +} + +message TTSStatus { + string session_id = 1; + string status = 2; + string message = 3; + int64 timestamp = 4; +} + +message TTSVoiceInfo { + string name = 1; + string language = 2; + string model_type = 3; + string created_at = 4; +} 
+
+message TTSVoiceListResponse { // Voice listing: a default speaker plus a list of TTSVoiceInfo entries.
+  string default_speaker = 1;
+  repeated TTSVoiceInfo custom_voices = 2;
+  int64 last_refresh = 3; // NOTE(review): unit (s/ms/ns) is not stated — confirm.
+  int64 timestamp = 4;
+}
+
+message TTSVoiceRefreshResponse { // Refresh result: a count plus a list of TTSVoiceInfo entries.
+  int32 count = 1; // NOTE(review): presumably len(custom_voices) — confirm.
+  repeated TTSVoiceInfo custom_voices = 2;
+  int64 timestamp = 3;
+}
+
+// ── STT Module ──────────────────────────────────────────────────────────────
+
+message STTStreamMessage { // Stream message of the STT section; type (1) is a free-form string.
+  string type = 1;
+  bytes audio = 2;
+  string state = 3;
+  string speaker_id = 4;
+}
+
+message STTTranscription { // Transcription result keyed by session_id and sequence.
+  string session_id = 1;
+  string transcript = 2;
+  int32 sequence = 3;
+  bool is_partial = 4; // NOTE(review): is_partial and is_final are independent bools; whether both can be set is not stated — confirm.
+  bool is_final = 5;
+  int64 timestamp = 6;
+  string speaker_id = 7;
+  bool has_voice_activity = 8;
+  string state = 9;
+}
+
+message STTInterrupt { // Interrupt notification keyed by session_id.
+  string session_id = 1;
+  string type = 2;
+  int64 timestamp = 3;
+  string speaker_id = 4;
+}
+
+// ── Common ──────────────────────────────────────────────────────────────────
+
+message ErrorResponse { // Error payload: a bool flag, a message and a type string.
+  bool error = 1;
+  string message = 2;
+  string type = 3;
+}
diff --git a/natsutil/natsutil_test.go b/natsutil/natsutil_test.go
new file mode 100644
index 0000000..afb663f
--- /dev/null
+++ b/natsutil/natsutil_test.go
@@ -0,0 +1,290 @@
+package natsutil
+
+import (
+	"testing"
+
+	"github.com/vmihailenco/msgpack/v5"
+)
+
+// mustMarshal msgpack-encodes v and aborts the calling test or benchmark when
+// encoding fails, so fixture-setup errors are never silently discarded.
+func mustMarshal(tb testing.TB, v any) []byte {
+	tb.Helper()
+	data, err := msgpack.Marshal(v)
+	if err != nil {
+		tb.Fatalf("msgpack.Marshal(%T): %v", v, err)
+	}
+	return data
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// DecodeMsgpackMap tests
+// ────────────────────────────────────────────────────────────────────────────
+
+// TestDecodeMsgpackMap_Roundtrip checks that string and bool entries survive a
+// round trip through msgpack.Marshal and DecodeMsgpackMap.
+func TestDecodeMsgpackMap_Roundtrip(t *testing.T) {
+	orig := map[string]any{
+		"request_id": "req-001",
+		"user_id":    "user-42",
+		"premium":    true,
+		"top_k":      int64(10), // msgpack decodes ints as int64
+	}
+	data := mustMarshal(t, orig)
+
+	decoded, err := DecodeMsgpackMap(data)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if decoded["request_id"] != "req-001" {
+		t.Errorf("request_id = %v", decoded["request_id"])
+	}
+	if decoded["user_id"] != "user-42" {
+		t.Errorf("user_id = %v", decoded["user_id"])
+	}
+	if decoded["premium"] != true {
+		t.Errorf("premium = %v", decoded["premium"])
+	}
+}
+
+// TestDecodeMsgpackMap_Empty checks that an encoded empty map decodes to an
+// empty map without error.
+func TestDecodeMsgpackMap_Empty(t *testing.T) {
+	data := mustMarshal(t, map[string]any{})
+	m, err := DecodeMsgpackMap(data)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(m) != 0 {
+		t.Errorf("expected empty map, got %v", m)
+	}
+}
+
+// TestDecodeMsgpackMap_InvalidData checks that bytes which do not encode a map
+// are rejected with an error.
+func TestDecodeMsgpackMap_InvalidData(t *testing.T) {
+	_, err := DecodeMsgpackMap([]byte{0xFF, 0xFE})
+	if err == nil {
+		t.Error("expected error for invalid msgpack data")
+	}
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// DecodeMsgpack (typed struct) tests
+// ────────────────────────────────────────────────────────────────────────────
+
+// testMessage is the typed fixture used by the struct tests and benchmarks.
+type testMessage struct {
+	RequestID string `msgpack:"request_id"`
+	UserID    string `msgpack:"user_id"`
+	Count     int    `msgpack:"count"`
+	Active    bool   `msgpack:"active"`
+}
+
+// TestDecodeMsgpackTyped_Roundtrip checks that every testMessage field survives
+// a msgpack round trip.
+//
+// NOTE(review): the section header names DecodeMsgpack, but this test calls
+// msgpack.Unmarshal directly — confirm that is intentional.
+func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
+	orig := testMessage{
+		RequestID: "req-typed-001",
+		UserID:    "user-7",
+		Count:     42,
+		Active:    true,
+	}
+	data := mustMarshal(t, orig)
+
+	// Simulate nats.Msg data decoding.
+	var decoded testMessage
+	if err := msgpack.Unmarshal(data, &decoded); err != nil {
+		t.Fatal(err)
+	}
+
+	if decoded != orig {
+		t.Errorf("decoded = %+v, want %+v", decoded, orig)
+	}
+}
+
+// TestTypedStructDecodesMapEncoding verifies that a typed struct can be
+// decoded from data that was encoded as map[string]any (backwards compat).
+func TestTypedStructDecodesMapEncoding(t *testing.T) {
+	// Encode as map (the old way).
+	data := mustMarshal(t, map[string]any{
+		"request_id": "req-compat",
+		"user_id":    "user-compat",
+		"count":      int64(99),
+		"active":     false,
+	})
+
+	// Decode into typed struct (the new way).
+	var msg testMessage
+	if err := msgpack.Unmarshal(data, &msg); err != nil {
+		t.Fatal(err)
+	}
+
+	want := testMessage{RequestID: "req-compat", UserID: "user-compat", Count: 99}
+	if msg != want {
+		t.Errorf("msg = %+v, want %+v", msg, want)
+	}
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// Binary data tests (audio []byte in msgpack)
+// ────────────────────────────────────────────────────────────────────────────
+
+// audioMessage is the fixture for payloads that carry raw audio bytes.
+type audioMessage struct {
+	SessionID  string `msgpack:"session_id"`
+	Audio      []byte `msgpack:"audio"`
+	SampleRate int    `msgpack:"sample_rate"`
+}
+
+// TestBinaryDataRoundtrip checks that a 32 KiB []byte payload and its sibling
+// fields survive a msgpack round trip unchanged.
+func TestBinaryDataRoundtrip(t *testing.T) {
+	audio := make([]byte, 32768)
+	for i := range audio {
+		audio[i] = byte(i % 256)
+	}
+
+	orig := audioMessage{
+		SessionID:  "sess-audio-001",
+		Audio:      audio,
+		SampleRate: 24000,
+	}
+	data := mustMarshal(t, orig)
+
+	var decoded audioMessage
+	if err := msgpack.Unmarshal(data, &decoded); err != nil {
+		t.Fatal(err)
+	}
+
+	if decoded.SessionID != orig.SessionID {
+		t.Errorf("SessionID = %q, want %q", decoded.SessionID, orig.SessionID)
+	}
+	if decoded.SampleRate != orig.SampleRate {
+		t.Errorf("SampleRate = %d, want %d", decoded.SampleRate, orig.SampleRate)
+	}
+	if len(decoded.Audio) != len(orig.Audio) {
+		t.Fatalf("audio len = %d, want %d", len(decoded.Audio), len(orig.Audio))
+	}
+	for i := range decoded.Audio {
+		if decoded.Audio[i] != orig.Audio[i] {
+			t.Fatalf("audio[%d] = %d, want %d", i, decoded.Audio[i], orig.Audio[i])
+		}
+	}
+}
+
+// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string.
+func TestBinaryVsBase64Size(t *testing.T) {
+	audio := make([]byte, 16384)
+
+	// Old approach: base64 string in map. Only the length matters here, so a
+	// zero-filled placeholder of the padded base64 length (4 output bytes per
+	// 3 input bytes, rounded up) stands in for real encoded text.
+	b64Placeholder := make([]byte, (len(audio)+2)/3*4)
+	mapData := mustMarshal(t, map[string]any{
+		"session_id": "sess-1",
+		"audio_b64":  string(b64Placeholder),
+	})
+
+	// New approach: raw bytes in struct.
+	structData := mustMarshal(t, audioMessage{
+		SessionID: "sess-1",
+		Audio:     audio,
+	})
+
+	if len(structData) >= len(mapData) {
+		t.Errorf("raw-bytes-in-struct = %d bytes, want fewer than base64-in-map = %d bytes",
+			len(structData), len(mapData))
+	}
+	t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)",
+		len(mapData), len(structData),
+		100*(1-float64(len(structData))/float64(len(mapData))))
+}
+
+// ────────────────────────────────────────────────────────────────────────────
+// Benchmarks
+// ────────────────────────────────────────────────────────────────────────────
+
+// BenchmarkEncodeMap measures msgpack.Marshal on a map[string]any payload.
+func BenchmarkEncodeMap(b *testing.B) {
+	data := map[string]any{
+		"request_id": "req-bench",
+		"user_id":    "user-bench",
+		"message":    "What is the weather today?",
+		"premium":    true,
+		"top_k":      10,
+	}
+	for b.Loop() {
+		if _, err := msgpack.Marshal(data); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkEncodeStruct measures msgpack.Marshal on a typed testMessage.
+func BenchmarkEncodeStruct(b *testing.B) {
+	data := testMessage{
+		RequestID: "req-bench",
+		UserID:    "user-bench",
+		Count:     10,
+		Active:    true,
+	}
+	for b.Loop() {
+		if _, err := msgpack.Marshal(data); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkDecodeMap measures msgpack.Unmarshal into a map[string]any.
+func BenchmarkDecodeMap(b *testing.B) {
+	raw := mustMarshal(b, map[string]any{
+		"request_id": "req-bench",
+		"user_id":    "user-bench",
+		"message":    "What is the weather today?",
+		"premium":    true,
+		"top_k":      10,
+	})
+	for b.Loop() {
+		var m map[string]any
+		if err := msgpack.Unmarshal(raw, &m); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkDecodeStruct measures msgpack.Unmarshal into a typed testMessage.
+func BenchmarkDecodeStruct(b *testing.B) {
+	raw := mustMarshal(b, testMessage{
+		RequestID: "req-bench",
+		UserID:    "user-bench",
+		Count:     10,
+		Active:    true,
+	})
+	for b.Loop() {
+		var m testMessage
+		if err := msgpack.Unmarshal(raw, &m); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+// BenchmarkDecodeAudio32KB measures msgpack.Unmarshal of a message carrying
+// 32 KiB of raw audio bytes.
+func BenchmarkDecodeAudio32KB(b *testing.B) {
+	raw := mustMarshal(b, audioMessage{
+		SessionID:  "s1",
+		Audio:      make([]byte, 32768),
+		SampleRate: 24000,
+	})
+	for b.Loop() {
+		var m audioMessage
+		if err := msgpack.Unmarshal(raw, &m); err != nil {
+			b.Fatal(err)
+		}
+	}
+}