3 Commits

Author SHA1 Message Date
f1dd96a42b style: gofmt + fix errcheck lint warning
All checks were successful
CI / Test (push) Successful in 3m2s
CI / Lint (push) Successful in 3m7s
CI / Release (push) Successful in 1m55s
CI / Notify Downstream (stt-module) (push) Successful in 1s
CI / Notify Downstream (voice-assistant) (push) Successful in 1s
CI / Notify (push) Successful in 2s
CI / Notify Downstream (chat-handler) (push) Successful in 1s
CI / Notify Downstream (pipeline-bridge) (push) Successful in 1s
CI / Notify Downstream (tts-module) (push) Successful in 1s
2026-02-21 15:35:37 -05:00
13ef1df109 feat!: replace msgpack with protobuf for all NATS messages
Some checks failed
CI / Lint (push) Failing after 3m2s
CI / Test (push) Successful in 3m44s
CI / Release (push) Has been skipped
CI / Notify Downstream (chat-handler) (push) Has been skipped
CI / Notify Downstream (pipeline-bridge) (push) Has been skipped
CI / Notify Downstream (stt-module) (push) Has been skipped
CI / Notify Downstream (tts-module) (push) Has been skipped
CI / Notify Downstream (voice-assistant) (push) Has been skipped
CI / Notify (push) Successful in 1s
BREAKING CHANGE: All NATS message serialization now uses Protocol Buffers.
- Added proto/messages/v1/messages.proto with 22 message types
- Generated Go code at gen/messagespb/
- messages/ package now exports type aliases to proto types
- natsutil.Publish/Request/Decode use proto.Marshal/Unmarshal
- Removed legacy MessageHandler, OnMessage, wrapMapHandler
- TypedMessageHandler now returns (proto.Message, error)
- EffectiveQuery is now a free function: messages.EffectiveQuery(req)
- Removed msgpack dependency entirely
2026-02-21 14:58:05 -05:00
3585d81ff5 feat: add StreamGenerate for real SSE streaming from LLM
Some checks failed
CI / Lint (push) Failing after 2m44s
CI / Test (push) Successful in 3m7s
CI / Release (push) Has been skipped
CI / Notify Downstream (chat-handler) (push) Has been skipped
CI / Notify Downstream (pipeline-bridge) (push) Has been skipped
CI / Notify Downstream (stt-module) (push) Has been skipped
CI / Notify Downstream (tts-module) (push) Has been skipped
CI / Notify Downstream (voice-assistant) (push) Has been skipped
CI / Notify (push) Successful in 2s
- Add postJSONStream() for incremental response body reading
- Add LLMClient.StreamGenerate() with SSE parsing and onToken callback
- Supports stream:true, parses data: lines, handles [DONE] sentinel
- Graceful partial-text return on stream interruption
- 9 new tests covering happy path, edge cases, cancellation
2026-02-20 17:55:01 -05:00
16 changed files with 3293 additions and 1125 deletions

10
buf.gen.yaml Normal file
View File

@@ -0,0 +1,10 @@
version: v2
managed:
  enabled: true
  override:
    - file_option: go_package_prefix
      value: git.daviestechlabs.io/daviestechlabs/handler-base/gen
plugins:
  - protoc_builtin: go
    out: gen
    opt: paths=source_relative

9
buf.yaml Normal file
View File

@@ -0,0 +1,9 @@
version: v2
modules:
  - path: proto
lint:
  use:
    - STANDARD
breaking:
  use:
    - FILE

View File

@@ -6,6 +6,7 @@
package clients package clients
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
@@ -14,6 +15,7 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync" "sync"
"time" "time"
) )
@@ -142,6 +144,36 @@ func (h *httpClient) do(req *http.Request) ([]byte, error) {
return body, nil return body, nil
} }
// postJSONStream JSON-encodes body, POSTs it to baseURL+path and returns the
// raw *http.Response so the caller can read the body incrementally (e.g. for
// SSE streaming). On success the caller owns resp.Body and must close it.
//
// A response with status >= 400 is converted into an error: its body is read
// (up to a fixed cap), closed, and included in the error text, and a nil
// response is returned.
func (h *httpClient) postJSONStream(ctx context.Context, path string, body any) (*http.Response, error) {
	// Upper bound on how much of an error response body is buffered for the
	// error message. Without a cap, a misbehaving upstream could make this
	// call buffer an arbitrarily large payload just to report a failure.
	const maxErrBody = 64 << 10

	buf := getBuf()
	defer putBuf(buf)

	if err := json.NewEncoder(buf).Encode(body); err != nil {
		return nil, fmt.Errorf("marshal: %w", err)
	}

	// Copy to a non-pooled buffer so the request body never aliases memory
	// that putBuf hands back to the pool when this function returns.
	payload := make([]byte, buf.Len())
	copy(payload, buf.Bytes())

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := h.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
	}

	if resp.StatusCode >= 400 {
		// Best effort: a read error here only shortens the message, the
		// status code alone is enough to report the failure.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrBody))
		_ = resp.Body.Close()
		return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(respBody))
	}
	return resp, nil
}
func (h *httpClient) healthCheck(ctx context.Context) bool { func (h *httpClient) healthCheck(ctx context.Context) bool {
data, err := h.get(ctx, "/health", nil) data, err := h.get(ctx, "/health", nil)
_ = data _ = data
@@ -320,6 +352,73 @@ func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
// StreamGenerate sends a streaming chat completion request ("stream": true)
// to /v1/chat/completions and calls onToken for each non-empty content delta
// received via SSE. It returns the fully assembled text.
//
// onToken may be nil. When set, it is invoked synchronously on the calling
// goroutine; it should be fast (e.g. publish a NATS message).
//
// Only SSE "data" fields are interpreted; every other line is ignored. A data
// payload that is not valid JSON, has no choices, or carries empty content is
// skipped. Reading stops at the [DONE] sentinel or at the end of the body.
//
// If reading the body fails after some text was collected, that partial text
// is returned together with a "stream interrupted" error. If nothing was
// collected, the result is "" and a "stream read" error.
func (c *LLMClient) StreamGenerate(ctx context.Context, prompt string, context_ string, systemPrompt string, onToken func(token string)) (string, error) {
	// Shape of the only part of a chat.completion.chunk this client reads.
	type completionChunk struct {
		Choices []struct {
			Delta struct {
				Content string `json:"content"`
			} `json:"delta"`
		} `json:"choices"`
	}

	msgs := buildMessages(prompt, context_, systemPrompt)
	payload := map[string]any{
		"model":       c.Model,
		"messages":    msgs,
		"max_tokens":  c.MaxTokens,
		"temperature": c.Temperature,
		"top_p":       c.TopP,
		"stream":      true,
	}

	resp, err := c.postJSONStream(ctx, "/v1/chat/completions", payload)
	if err != nil {
		return "", err
	}
	defer func() { _ = resp.Body.Close() }()

	var full strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	// SSE lines can be up to 64 KiB for large token batches.
	scanner.Buffer(make([]byte, 0, 64*1024), 64*1024)

	for scanner.Scan() {
		// The SSE format is "field:value" with at most one optional space
		// after the colon, so accept both "data: x" and "data:x".
		data, isData := strings.CutPrefix(scanner.Text(), "data:")
		if !isData {
			continue
		}
		data = strings.TrimPrefix(data, " ")

		if strings.TrimSpace(data) == "[DONE]" {
			break
		}

		var chunk completionChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // skip malformed chunks
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		token := chunk.Choices[0].Delta.Content
		if token == "" {
			continue
		}

		full.WriteString(token)
		if onToken != nil {
			onToken(token)
		}
	}

	if err := scanner.Err(); err != nil {
		// If we already collected some text, return it with the error.
		if full.Len() > 0 {
			return full.String(), fmt.Errorf("stream interrupted: %w", err)
		}
		return "", fmt.Errorf("stream read: %w", err)
	}
	return full.String(), nil
}
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage { func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
var msgs []ChatMessage var msgs []ChatMessage
if systemPrompt != "" { if systemPrompt != "" {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -299,6 +300,293 @@ func TestLLMClient_GenerateNoChoices(t *testing.T) {
} }
} }
// ────────────────────────────────────────────────────────────────────────────
// LLM client — StreamGenerate
// ────────────────────────────────────────────────────────────────────────────
// sseChunk renders one OpenAI-compatible chat.completion.chunk as a complete
// SSE event: a "data: " line holding the JSON payload with content as the
// delta, followed by the blank line that terminates the event.
func sseChunk(content string) string {
	type delta struct {
		Content string `json:"content"`
	}
	type choice struct {
		Delta delta `json:"delta"`
	}
	type envelope struct {
		Choices []choice `json:"choices"`
	}

	// Marshalling these plain string-only structs is not expected to fail, so
	// the error is deliberately dropped in this test helper.
	encoded, _ := json.Marshal(envelope{
		Choices: []choice{{Delta: delta{Content: content}}},
	})

	event := append([]byte("data: "), encoded...)
	event = append(event, '\n', '\n')
	return string(event)
}
// TestLLMClient_StreamGenerate covers the happy path: the server checks the
// request path and the "stream" flag, then emits three content deltas and the
// [DONE] sentinel. The client must hand each delta to the callback in order
// and return their concatenation.
func TestLLMClient_StreamGenerate(t *testing.T) {
	tokens := []string{"Hello", " world", "!"}

	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/chat/completions" {
			t.Errorf("path = %q", r.URL.Path)
		}

		// The request body must ask for a streamed response.
		var body map[string]any
		_ = json.NewDecoder(r.Body).Decode(&body)
		if body["stream"] != true {
			t.Errorf("stream = %v, want true", body["stream"])
		}

		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		// Flush after every event so each token reaches the client separately.
		for i := 0; i < len(tokens); i++ {
			_, _ = w.Write([]byte(sseChunk(tokens[i])))
			flusher.Flush()
		}
		_, _ = w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)

	var got []string
	collect := func(tok string) { got = append(got, tok) }

	text, err := client.StreamGenerate(context.Background(), "hi", "", "", collect)
	if err != nil {
		t.Fatal(err)
	}
	if text != "Hello world!" {
		t.Errorf("result = %q, want %q", text, "Hello world!")
	}
	if len(got) != 3 {
		t.Fatalf("callback count = %d, want 3", len(got))
	}
	if !(got[0] == "Hello" && got[1] == " world" && got[2] == "!") {
		t.Errorf("received = %v", got)
	}
}
// TestLLMClient_StreamGenerateWithSystemPrompt verifies that a non-empty
// system prompt is sent as the first of two messages (system, then user) and
// that the streamed reply is still assembled correctly.
func TestLLMClient_StreamGenerateWithSystemPrompt(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req map[string]any
		_ = json.NewDecoder(r.Body).Decode(&req)
		msgs, _ := req["messages"].([]any)
		if len(msgs) != 2 {
			t.Errorf("expected system+user, got %d messages", len(msgs))
		}
		// t.Errorf does not stop the handler, so guard the index: with an
		// empty message list msgs[0] would panic inside the server goroutine
		// instead of producing a readable test failure.
		if len(msgs) > 0 {
			first, _ := msgs[0].(map[string]any)
			if first["role"] != "system" || first["content"] != "You are a DM" {
				t.Errorf("system msg = %v", first)
			}
		}

		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		_, _ = w.Write([]byte(sseChunk("ok")))
		_, _ = w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
	}))
	defer ts.Close()

	c := NewLLMClient(ts.URL, 5*time.Second)
	result, err := c.StreamGenerate(context.Background(), "roll dice", "", "You are a DM", nil)
	if err != nil {
		t.Fatal(err)
	}
	if result != "ok" {
		t.Errorf("result = %q", result)
	}
}
// TestLLMClient_StreamGenerateNilCallback checks that StreamGenerate accepts
// a nil onToken callback: it must not panic and must still return the
// streamed text.
func TestLLMClient_StreamGenerateNilCallback(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		events := []string{
			sseChunk("token"),
			"data: [DONE]\n\n",
		}
		for _, ev := range events {
			_, _ = w.Write([]byte(ev))
		}
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)

	// Passing nil as the callback is the behaviour under test.
	text, err := client.StreamGenerate(context.Background(), "hi", "", "", nil)
	if err != nil {
		t.Fatal(err)
	}
	if text != "token" {
		t.Errorf("result = %q", text)
	}
}
// TestLLMClient_StreamGenerateEmptyDelta checks that chunks without content
// are dropped silently: they must neither reach the callback nor appear in
// the assembled result.
func TestLLMClient_StreamGenerateEmptyDelta(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		events := []string{
			// role-only chunk (no content) — common for first chunk from vLLM
			`data: {"choices":[{"delta":{"role":"assistant"}}]}` + "\n\n",
			// empty content string
			sseChunk(""),
			// real token
			sseChunk("hello"),
			"data: [DONE]\n\n",
		}
		for _, ev := range events {
			_, _ = w.Write([]byte(ev))
		}
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)

	calls := 0
	text, err := client.StreamGenerate(context.Background(), "q", "", "", func(string) {
		calls++
	})
	if err != nil {
		t.Fatal(err)
	}
	if text != "hello" {
		t.Errorf("result = %q", text)
	}
	if calls != 1 {
		t.Errorf("callback count = %d, want 1 (empty deltas should be skipped)", calls)
	}
}
// TestLLMClient_StreamGenerateMalformedChunks checks that an undecodable JSON
// payload and a chunk with an empty choices list are both skipped without an
// error, leaving only the valid token in the result.
func TestLLMClient_StreamGenerateMalformedChunks(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		events := []string{
			"data: {invalid json}\n\n",
			`data: {"choices":[]}` + "\n\n", // empty choices
			sseChunk("good"),
			"data: [DONE]\n\n",
		}
		for _, ev := range events {
			_, _ = w.Write([]byte(ev))
		}
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)
	text, err := client.StreamGenerate(context.Background(), "q", "", "", nil)
	if err != nil {
		t.Fatal(err)
	}
	if text != "good" {
		t.Errorf("result = %q, want %q", text, "good")
	}
}
// TestLLMClient_StreamGenerateHTTPError checks that a 500 response from the
// server surfaces as an error whose text mentions the status code.
func TestLLMClient_StreamGenerateHTTPError(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("internal server error"))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)
	_, err := client.StreamGenerate(context.Background(), "q", "", "", nil)
	if err == nil {
		t.Fatal("expected error for 500")
	}
	if msg := err.Error(); !strings.Contains(msg, "500") {
		t.Errorf("error should contain 500: %v", err)
	}
}
func TestLLMClient_StreamGenerateContextCanceled(t *testing.T) {
started := make(chan struct{})
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
// Send several tokens so the client receives some before cancel.
for i := range 20 {
_, _ = w.Write([]byte(sseChunk(fmt.Sprintf("tok%d ", i))))
flusher.Flush()
}
close(started)
// Block until client cancels
<-r.Context().Done()
}))
defer ts.Close()
ctx, cancel := context.WithCancel(context.Background())
c := NewLLMClient(ts.URL, 10*time.Second)
var streamErr error
done := make(chan struct{})
go func() {
defer close(done)
_, streamErr = c.StreamGenerate(ctx, "q", "", "", nil)
}()
<-started
cancel()
<-done
// After cancel the stream should return an error (context canceled or
// stream interrupted). The exact partial text depends on timing.
if streamErr == nil {
t.Error("expected error after context cancel")
}
}
// TestLLMClient_StreamGenerateNoSSEPrefix checks that lines which are not
// data fields (an SSE comment, an event line, blank lines) are ignored, so
// only the single real token ends up in the result.
func TestLLMClient_StreamGenerateNoSSEPrefix(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		events := []string{
			": this is an SSE comment\n\n",
			"event: message\n",
			sseChunk("word"),
			"\n", // blank line
			"data: [DONE]\n\n",
		}
		for _, ev := range events {
			_, _ = w.Write([]byte(ev))
		}
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	client := NewLLMClient(srv.URL, 5*time.Second)
	text, err := client.StreamGenerate(context.Background(), "q", "", "", nil)
	if err != nil {
		t.Fatal(err)
	}
	if text != "word" {
		t.Errorf("result = %q, want %q", text, "word")
	}
}
func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
// Verify token ordering and full assembly with many chunks.
n := 100
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
for i := range n {
tok := fmt.Sprintf("t%d ", i)
_, _ = w.Write([]byte(sseChunk(tok)))
flusher.Flush()
}
_, _ = w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
var mu sync.Mutex
var order []int
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
var idx int
_, _ = fmt.Sscanf(tok, "t%d ", &idx)
mu.Lock()
order = append(order, idx)
mu.Unlock()
})
if err != nil {
t.Fatal(err)
}
// Verify all tokens arrived in order
if len(order) != n {
t.Fatalf("got %d tokens, want %d", len(order), n)
}
for i, v := range order {
if v != i {
t.Errorf("order[%d] = %d", i, v)
break
}
}
// Quick sanity: result should start with "t0 " and end with last token
if !strings.HasPrefix(result, "t0 ") {
t.Errorf("result prefix = %q", result[:10])
}
}
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// TTS client // TTS client
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────

View File

@@ -3,16 +3,16 @@
package config package config
import ( import (
"context" "context"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
) )
// Settings holds base configuration for all handler services. // Settings holds base configuration for all handler services.
@@ -20,249 +20,249 @@ import (
// updated at runtime via WatchSecrets(). All other fields are immutable // updated at runtime via WatchSecrets(). All other fields are immutable
// after Load() returns. // after Load() returns.
type Settings struct { type Settings struct {
// Service identification (immutable) // Service identification (immutable)
ServiceName string ServiceName string
ServiceVersion string ServiceVersion string
ServiceNamespace string ServiceNamespace string
DeploymentEnv string DeploymentEnv string
// NATS configuration (immutable) // NATS configuration (immutable)
NATSURL string NATSURL string
NATSUser string NATSUser string
NATSPassword string NATSPassword string
NATSQueueGroup string NATSQueueGroup string
// Redis/Valkey configuration (immutable) // Redis/Valkey configuration (immutable)
RedisURL string RedisURL string
RedisPassword string RedisPassword string
// Milvus configuration (immutable) // Milvus configuration (immutable)
MilvusHost string MilvusHost string
MilvusPort int MilvusPort int
MilvusCollection string MilvusCollection string
// OpenTelemetry configuration (immutable) // OpenTelemetry configuration (immutable)
OTELEnabled bool OTELEnabled bool
OTELEndpoint string OTELEndpoint string
OTELUseHTTP bool OTELUseHTTP bool
// HyperDX configuration (immutable) // HyperDX configuration (immutable)
HyperDXEnabled bool HyperDXEnabled bool
HyperDXAPIKey string HyperDXAPIKey string
HyperDXEndpoint string HyperDXEndpoint string
// MLflow configuration (immutable) // MLflow configuration (immutable)
MLflowTrackingURI string MLflowTrackingURI string
MLflowExperimentName string MLflowExperimentName string
MLflowEnabled bool MLflowEnabled bool
// Health check configuration (immutable) // Health check configuration (immutable)
HealthPort int HealthPort int
HealthPath string HealthPath string
ReadyPath string ReadyPath string
// Timeouts (immutable) // Timeouts (immutable)
HTTPTimeout time.Duration HTTPTimeout time.Duration
NATSTimeout time.Duration NATSTimeout time.Duration
// Hot-reloadable fields — access via getter methods. // Hot-reloadable fields — access via getter methods.
mu sync.RWMutex mu sync.RWMutex
embeddingsURL string embeddingsURL string
rerankerURL string rerankerURL string
llmURL string llmURL string
ttsURL string ttsURL string
sttURL string sttURL string
// Secrets path for file-based hot reload (Kubernetes secret mounts) // Secrets path for file-based hot reload (Kubernetes secret mounts)
SecretsPath string SecretsPath string
} }
// Load creates a Settings populated from environment variables with defaults. // Load creates a Settings populated from environment variables with defaults.
func Load() *Settings { func Load() *Settings {
return &Settings{ return &Settings{
ServiceName: getEnv("SERVICE_NAME", "handler"), ServiceName: getEnv("SERVICE_NAME", "handler"),
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
NATSUser: getEnv("NATS_USER", ""), NATSUser: getEnv("NATS_USER", ""),
NATSPassword: getEnv("NATS_PASSWORD", ""), NATSPassword: getEnv("NATS_PASSWORD", ""),
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
RedisPassword: getEnv("REDIS_PASSWORD", ""), RedisPassword: getEnv("REDIS_PASSWORD", ""),
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
MilvusPort: getEnvInt("MILVUS_PORT", 19530), MilvusPort: getEnvInt("MILVUS_PORT", 19530),
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
OTELEnabled: getEnvBool("OTEL_ENABLED", true), OTELEnabled: getEnvBool("OTEL_ENABLED", true),
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
HealthPort: getEnvInt("HEALTH_PORT", 8080), HealthPort: getEnvInt("HEALTH_PORT", 8080),
HealthPath: getEnv("HEALTH_PATH", "/health"), HealthPath: getEnv("HEALTH_PATH", "/health"),
ReadyPath: getEnv("READY_PATH", "/ready"), ReadyPath: getEnv("READY_PATH", "/ready"),
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
SecretsPath: getEnv("SECRETS_PATH", ""), SecretsPath: getEnv("SECRETS_PATH", ""),
} }
} }
// EmbeddingsURL returns the current embeddings service URL (thread-safe). // EmbeddingsURL returns the current embeddings service URL (thread-safe).
func (s *Settings) EmbeddingsURL() string { func (s *Settings) EmbeddingsURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.embeddingsURL return s.embeddingsURL
} }
// RerankerURL returns the current reranker service URL (thread-safe). // RerankerURL returns the current reranker service URL (thread-safe).
func (s *Settings) RerankerURL() string { func (s *Settings) RerankerURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.rerankerURL return s.rerankerURL
} }
// LLMURL returns the current LLM service URL (thread-safe). // LLMURL returns the current LLM service URL (thread-safe).
func (s *Settings) LLMURL() string { func (s *Settings) LLMURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.llmURL return s.llmURL
} }
// TTSURL returns the current TTS service URL (thread-safe). // TTSURL returns the current TTS service URL (thread-safe).
func (s *Settings) TTSURL() string { func (s *Settings) TTSURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.ttsURL return s.ttsURL
} }
// STTURL returns the current STT service URL (thread-safe). // STTURL returns the current STT service URL (thread-safe).
func (s *Settings) STTURL() string { func (s *Settings) STTURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.sttURL return s.sttURL
} }
// WatchSecrets watches the SecretsPath directory for changes and reloads // WatchSecrets watches the SecretsPath directory for changes and reloads
// hot-reloadable fields. Blocks until ctx is cancelled. // hot-reloadable fields. Blocks until ctx is cancelled.
func (s *Settings) WatchSecrets(ctx context.Context) { func (s *Settings) WatchSecrets(ctx context.Context) {
if s.SecretsPath == "" { if s.SecretsPath == "" {
return return
} }
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
slog.Error("config: failed to create fsnotify watcher", "error", err) slog.Error("config: failed to create fsnotify watcher", "error", err)
return return
} }
defer func() { _ = watcher.Close() }() defer func() { _ = watcher.Close() }()
if err := watcher.Add(s.SecretsPath); err != nil { if err := watcher.Add(s.SecretsPath); err != nil {
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath) slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
return return
} }
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath) slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
for { for {
select { select {
case event, ok := <-watcher.Events: case event, ok := <-watcher.Events:
if !ok { if !ok {
return return
} }
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) { if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
s.reloadFromSecrets() s.reloadFromSecrets()
} }
case err, ok := <-watcher.Errors: case err, ok := <-watcher.Errors:
if !ok { if !ok {
return return
} }
slog.Error("config: fsnotify error", "error", err) slog.Error("config: fsnotify error", "error", err)
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
// reloadFromSecrets reads hot-reloadable values from the secrets directory. // reloadFromSecrets reads hot-reloadable values from the secrets directory.
func (s *Settings) reloadFromSecrets() { func (s *Settings) reloadFromSecrets() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
updated := 0 updated := 0
reload := func(filename string, target *string) { reload := func(filename string, target *string) {
path := filepath.Join(s.SecretsPath, filename) path := filepath.Join(s.SecretsPath, filename)
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return return
} }
val := strings.TrimSpace(string(data)) val := strings.TrimSpace(string(data))
if val != "" && val != *target { if val != "" && val != *target {
*target = val *target = val
updated++ updated++
slog.Info("config: reloaded secret", "key", filename) slog.Info("config: reloaded secret", "key", filename)
} }
} }
reload("embeddings-url", &s.embeddingsURL) reload("embeddings-url", &s.embeddingsURL)
reload("reranker-url", &s.rerankerURL) reload("reranker-url", &s.rerankerURL)
reload("llm-url", &s.llmURL) reload("llm-url", &s.llmURL)
reload("tts-url", &s.ttsURL) reload("tts-url", &s.ttsURL)
reload("stt-url", &s.sttURL) reload("stt-url", &s.sttURL)
if updated > 0 { if updated > 0 {
slog.Info("config: secrets reloaded", "updated", updated) slog.Info("config: secrets reloaded", "updated", updated)
} }
} }
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v
} }
return fallback return fallback
} }
func getEnvInt(key string, fallback int) int { func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil { if i, err := strconv.Atoi(v); err == nil {
return i return i
} }
} }
return fallback return fallback
} }
func getEnvBool(key string, fallback bool) bool { func getEnvBool(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil { if b, err := strconv.ParseBool(v); err == nil {
return b return b
} }
} }
return fallback return fallback
} }
func getEnvDuration(key string, fallback time.Duration) time.Duration { func getEnvDuration(key string, fallback time.Duration) time.Duration {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil { if f, err := strconv.ParseFloat(v, 64); err == nil {
return time.Duration(f * float64(time.Second)) return time.Duration(f * float64(time.Second))
} }
} }
return fallback return fallback
} }

View File

@@ -1,123 +1,123 @@
package config package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
) )
func TestLoadDefaults(t *testing.T) { func TestLoadDefaults(t *testing.T) {
s := Load() s := Load()
if s.ServiceName != "handler" { if s.ServiceName != "handler" {
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
} }
if s.HealthPort != 8080 { if s.HealthPort != 8080 {
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
} }
if s.HTTPTimeout != 60*time.Second { if s.HTTPTimeout != 60*time.Second {
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
} }
} }
func TestLoadFromEnv(t *testing.T) { func TestLoadFromEnv(t *testing.T) {
t.Setenv("SERVICE_NAME", "test-svc") t.Setenv("SERVICE_NAME", "test-svc")
t.Setenv("HEALTH_PORT", "9090") t.Setenv("HEALTH_PORT", "9090")
t.Setenv("OTEL_ENABLED", "false") t.Setenv("OTEL_ENABLED", "false")
s := Load() s := Load()
if s.ServiceName != "test-svc" { if s.ServiceName != "test-svc" {
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
} }
if s.HealthPort != 9090 { if s.HealthPort != 9090 {
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
} }
if s.OTELEnabled { if s.OTELEnabled {
t.Error("expected OTELEnabled false") t.Error("expected OTELEnabled false")
} }
} }
func TestURLGetters(t *testing.T) { func TestURLGetters(t *testing.T) {
s := Load() s := Load()
if s.EmbeddingsURL() == "" { if s.EmbeddingsURL() == "" {
t.Error("EmbeddingsURL should have a default") t.Error("EmbeddingsURL should have a default")
} }
if s.RerankerURL() == "" { if s.RerankerURL() == "" {
t.Error("RerankerURL should have a default") t.Error("RerankerURL should have a default")
} }
if s.LLMURL() == "" { if s.LLMURL() == "" {
t.Error("LLMURL should have a default") t.Error("LLMURL should have a default")
} }
if s.TTSURL() == "" { if s.TTSURL() == "" {
t.Error("TTSURL should have a default") t.Error("TTSURL should have a default")
} }
if s.STTURL() == "" { if s.STTURL() == "" {
t.Error("STTURL should have a default") t.Error("STTURL should have a default")
} }
} }
func TestURLGettersFromEnv(t *testing.T) { func TestURLGettersFromEnv(t *testing.T) {
t.Setenv("EMBEDDINGS_URL", "http://embed:8000") t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
t.Setenv("LLM_URL", "http://llm:9000") t.Setenv("LLM_URL", "http://llm:9000")
s := Load() s := Load()
if s.EmbeddingsURL() != "http://embed:8000" { if s.EmbeddingsURL() != "http://embed:8000" {
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL()) t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
} }
if s.LLMURL() != "http://llm:9000" { if s.LLMURL() != "http://llm:9000" {
t.Errorf("expected custom LLMURL, got %q", s.LLMURL()) t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
} }
} }
func TestReloadFromSecrets(t *testing.T) { func TestReloadFromSecrets(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
// Write initial secret files // Write initial secret files
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000") writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
writeSecret(t, dir, "llm-url", "http://old-llm:9000") writeSecret(t, dir, "llm-url", "http://old-llm:9000")
s := Load() s := Load()
s.SecretsPath = dir s.SecretsPath = dir
s.reloadFromSecrets() s.reloadFromSecrets()
if s.EmbeddingsURL() != "http://old-embed:8000" { if s.EmbeddingsURL() != "http://old-embed:8000" {
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL()) t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
} }
if s.LLMURL() != "http://old-llm:9000" { if s.LLMURL() != "http://old-llm:9000" {
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL()) t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
} }
// Simulate secret update // Simulate secret update
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000") writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
s.reloadFromSecrets() s.reloadFromSecrets()
if s.EmbeddingsURL() != "http://new-embed:8000" { if s.EmbeddingsURL() != "http://new-embed:8000" {
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL()) t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
} }
// LLM should remain unchanged // LLM should remain unchanged
if s.LLMURL() != "http://old-llm:9000" { if s.LLMURL() != "http://old-llm:9000" {
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL()) t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
} }
} }
func TestReloadFromSecretsNoPath(t *testing.T) { func TestReloadFromSecretsNoPath(t *testing.T) {
s := Load() s := Load()
s.SecretsPath = "" s.SecretsPath = ""
// Should not panic // Should not panic
s.reloadFromSecrets() s.reloadFromSecrets()
} }
func TestGetEnvDuration(t *testing.T) { func TestGetEnvDuration(t *testing.T) {
t.Setenv("TEST_DUR", "30") t.Setenv("TEST_DUR", "30")
d := getEnvDuration("TEST_DUR", 10*time.Second) d := getEnvDuration("TEST_DUR", 10*time.Second)
if d != 30*time.Second { if d != 30*time.Second {
t.Errorf("expected 30s, got %v", d) t.Errorf("expected 30s, got %v", d)
} }
} }
// writeSecret writes value to the file dir/name, creating or truncating it,
// and aborts the calling test via t.Fatal if the write fails.
//
// New files are created with mode 0600: the fixture stands in for a secret,
// so it is kept owner-only (this also satisfies gosec G306). An existing
// file keeps whatever mode it already has, per os.WriteFile.
func writeSecret(t *testing.T, dir, name, value string) {
	t.Helper()

	path := filepath.Join(dir, name)
	if err := os.WriteFile(path, []byte(value), 0o600); err != nil {
		t.Fatal(err)
	}
}

File diff suppressed because it is too large Load Diff

6
go.mod
View File

@@ -3,8 +3,8 @@ module git.daviestechlabs.io/daviestechlabs/handler-base
go 1.25.1 go 1.25.1
require ( require (
github.com/fsnotify/fsnotify v1.9.0
github.com/nats-io/nats.go v1.48.0 github.com/nats-io/nats.go v1.48.0
github.com/vmihailenco/msgpack/v5 v5.4.1
go.opentelemetry.io/otel v1.40.0 go.opentelemetry.io/otel v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
@@ -12,12 +12,12 @@ require (
go.opentelemetry.io/otel/sdk v1.40.0 go.opentelemetry.io/otel/sdk v1.40.0
go.opentelemetry.io/otel/sdk/metric v1.40.0 go.opentelemetry.io/otel/sdk/metric v1.40.0
go.opentelemetry.io/otel/trace v1.40.0 go.opentelemetry.io/otel/trace v1.40.0
google.golang.org/protobuf v1.36.11
) )
require ( require (
github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
@@ -25,7 +25,6 @@ require (
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/nats-io/nkeys v0.4.11 // indirect github.com/nats-io/nkeys v0.4.11 // indirect
github.com/nats-io/nuid v1.0.1 // indirect github.com/nats-io/nuid v1.0.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect
@@ -36,5 +35,4 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/grpc v1.78.0 // indirect google.golang.org/grpc v1.78.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
) )

4
go.sum
View File

@@ -31,10 +31,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=

View File

@@ -10,22 +10,19 @@ import (
"syscall" "syscall"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"google.golang.org/protobuf/proto"
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
"git.daviestechlabs.io/daviestechlabs/handler-base/health" "git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
) )
// MessageHandler is the callback for processing decoded NATS messages. // TypedMessageHandler processes the raw NATS message.
// data is the msgpack-decoded map. Return a response map (or nil for no reply). // Services unmarshal msg.Data into their own typed structs via natsutil.Decode.
type MessageHandler func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) // Return a proto.Message (or nil for no reply).
type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (proto.Message, error)
// TypedMessageHandler processes the raw NATS message without pre-decoding to
// map[string]any. Services unmarshal msg.Data into their own typed structs,
// avoiding the double-decode overhead. Return any msgpack-serialisable value
// (a typed struct, map, or nil for no reply).
type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (any, error)
// SetupFunc is called once before the handler starts processing messages. // SetupFunc is called once before the handler starts processing messages.
type SetupFunc func(ctx context.Context) error type SetupFunc func(ctx context.Context) error
@@ -43,7 +40,6 @@ type Handler struct {
onSetup SetupFunc onSetup SetupFunc
onTeardown TeardownFunc onTeardown TeardownFunc
onMessage MessageHandler
onTypedMessage TypedMessageHandler onTypedMessage TypedMessageHandler
running bool running bool
} }
@@ -74,12 +70,7 @@ func (h *Handler) OnSetup(fn SetupFunc) { h.onSetup = fn }
// OnTeardown registers the teardown callback. // OnTeardown registers the teardown callback.
func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn } func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn }
// OnMessage registers the message handler callback. // OnTypedMessage registers the message handler callback.
func (h *Handler) OnMessage(fn MessageHandler) { h.onMessage = fn }
// OnTypedMessage registers a typed message handler. It replaces OnMessage —
// wrapHandler will skip the map[string]any decode and let the callback
// unmarshal msg.Data directly.
func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn } func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn }
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT. // Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
@@ -131,7 +122,7 @@ func (h *Handler) Run() error {
} }
// Subscribe // Subscribe
if h.onMessage == nil && h.onTypedMessage == nil { if h.onTypedMessage == nil {
return fmt.Errorf("no message handler registered") return fmt.Errorf("no message handler registered")
} }
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil { if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
@@ -161,26 +152,16 @@ func (h *Handler) Run() error {
} }
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback. // wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
// If OnTypedMessage was used, msg.Data is passed directly without map decode.
// If OnMessage was used, msg.Data is decoded to map[string]any first.
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler { func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
if h.onTypedMessage != nil {
return h.wrapTypedHandler(ctx)
}
return h.wrapMapHandler(ctx)
}
// wrapTypedHandler dispatches to the TypedMessageHandler (no map decode).
func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler {
return func(msg *nats.Msg) { return func(msg *nats.Msg) {
response, err := h.onTypedMessage(ctx, msg) response, err := h.onTypedMessage(ctx, msg)
if err != nil { if err != nil {
slog.Error("handler error", "subject", msg.Subject, "error", err) slog.Error("handler error", "subject", msg.Subject, "error", err)
if msg.Reply != "" { if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{ _ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{
"error": true, Error: true,
"message": err.Error(), Message: err.Error(),
"type": fmt.Sprintf("%T", err), Type: fmt.Sprintf("%T", err),
}) })
} }
return return
@@ -192,40 +173,3 @@ func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler {
} }
} }
} }
// wrapMapHandler dispatches to the legacy MessageHandler (decodes to map first).
func (h *Handler) wrapMapHandler(ctx context.Context) nats.MsgHandler {
return func(msg *nats.Msg) {
data, err := natsutil.DecodeMsgpackMap(msg.Data)
if err != nil {
slog.Error("failed to decode message", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": "DecodeError",
})
}
return
}
response, err := h.onMessage(ctx, msg, data)
if err != nil {
slog.Error("handler error", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": fmt.Sprintf("%T", err),
})
}
return
}
if response != nil && msg.Reply != "" {
if err := h.NATS.Publish(msg.Reply, response); err != nil {
slog.Error("failed to publish reply", "error", err)
}
}
}
}

View File

@@ -5,9 +5,11 @@ import (
"testing" "testing"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5" "google.golang.org/protobuf/proto"
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -57,11 +59,11 @@ func TestCallbackRegistration(t *testing.T) {
return nil return nil
}) })
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, nil return nil, nil
}) })
if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil { if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil {
t.Error("callbacks should not be nil after registration") t.Error("callbacks should not be nil after registration")
} }
@@ -77,8 +79,8 @@ func TestTypedMessageRegistration(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return map[string]any{"ok": true}, nil return &pb.ChatResponse{Response: "ok"}, nil
}) })
if h.onTypedMessage == nil { if h.onTypedMessage == nil {
@@ -94,19 +96,20 @@ func TestWrapHandler_ValidMessage(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
var receivedData map[string]any var receivedReq pb.ChatRequest
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
receivedData = data if err := natsutil.Decode(msg.Data, &receivedReq); err != nil {
return map[string]any{"status": "ok"}, nil return nil, err
}
return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil
}) })
// Encode a message the same way services would. // Encode a message the same way services would.
payload := map[string]any{ encoded, err := proto.Marshal(&pb.ChatRequest{
"request_id": "test-001", RequestId: "test-001",
"message": "hello", Message: "hello",
"premium": true, Premium: true,
} })
encoded, err := msgpack.Marshal(payload)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -118,47 +121,48 @@ func TestWrapHandler_ValidMessage(t *testing.T) {
Data: encoded, Data: encoded,
}) })
if receivedData == nil { if receivedReq.GetRequestId() != "test-001" {
t.Fatal("handler was not called") t.Errorf("request_id = %v", receivedReq.GetRequestId())
} }
if receivedData["request_id"] != "test-001" { if receivedReq.GetPremium() != true {
t.Errorf("request_id = %v", receivedData["request_id"]) t.Errorf("premium = %v", receivedReq.GetPremium())
}
if receivedData["premium"] != true {
t.Errorf("premium = %v", receivedData["premium"])
} }
} }
func TestWrapHandler_InvalidMsgpack(t *testing.T) { func TestWrapHandler_InvalidMessage(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
handlerCalled := false handlerCalled := false
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
handlerCalled = true handlerCalled = true
return nil, nil var req pb.ChatRequest
if err := natsutil.Decode(msg.Data, &req); err != nil {
return nil, err
}
return &pb.ChatResponse{}, nil
}) })
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{ handler(&nats.Msg{
Subject: "ai.test", Subject: "ai.test",
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf
}) })
if handlerCalled { // The handler IS called (wrapHandler doesn't pre-decode), but it should
t.Error("handler should not be called for invalid msgpack") // return an error from Decode. Either way no panic.
} _ = handlerCalled
} }
func TestWrapHandler_HandlerError(t *testing.T) { func TestWrapHandler_HandlerError(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, context.DeadlineExceeded return nil, context.DeadlineExceeded
}) })
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
// Should not panic even when handler returns error. // Should not panic even when handler returns error.
@@ -172,11 +176,11 @@ func TestWrapHandler_NilResponse(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, nil // fire-and-forget style return nil, nil // fire-and-forget style
}) })
encoded, _ := msgpack.Marshal(map[string]any{"x": 1}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
// Should not panic with nil response and no reply subject. // Should not panic with nil response and no reply subject.
@@ -190,63 +194,58 @@ func TestWrapHandler_NilResponse(t *testing.T) {
// wrapHandler dispatch tests — typed handler path // wrapHandler dispatch tests — typed handler path
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestWrapTypedHandler_ValidMessage(t *testing.T) { func TestWrapHandler_Typed(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
type testReq struct { var received pb.ChatRequest
RequestID string `msgpack:"request_id"` h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
Message string `msgpack:"message"` if err := natsutil.Decode(msg.Data, &received); err != nil {
}
var received testReq
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
if err := msgpack.Unmarshal(msg.Data, &received); err != nil {
return nil, err return nil, err
} }
return map[string]any{"status": "ok"}, nil return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil
}) })
encoded, _ := msgpack.Marshal(map[string]any{ encoded, _ := proto.Marshal(&pb.ChatRequest{
"request_id": "typed-001", RequestId: "typed-001",
"message": "hello typed", Message: "hello typed",
}) })
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{Subject: "ai.test", Data: encoded}) handler(&nats.Msg{Subject: "ai.test", Data: encoded})
if received.RequestID != "typed-001" { if received.GetRequestId() != "typed-001" {
t.Errorf("RequestID = %q", received.RequestID) t.Errorf("RequestId = %q", received.GetRequestId())
} }
if received.Message != "hello typed" { if received.GetMessage() != "hello typed" {
t.Errorf("Message = %q", received.Message) t.Errorf("Message = %q", received.GetMessage())
} }
} }
func TestWrapTypedHandler_Error(t *testing.T) { func TestWrapHandler_TypedError(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, context.DeadlineExceeded return nil, context.DeadlineExceeded
}) })
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
// Should not panic. // Should not panic.
handler(&nats.Msg{Subject: "ai.test", Data: encoded}) handler(&nats.Msg{Subject: "ai.test", Data: encoded})
} }
func TestWrapTypedHandler_NilResponse(t *testing.T) { func TestWrapHandler_TypedNilResponse(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, nil return nil, nil
}) })
encoded, _ := msgpack.Marshal(map[string]any{"x": 1}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{Subject: "ai.test", Data: encoded}) handler(&nats.Msg{Subject: "ai.test", Data: encoded})
} }
@@ -258,49 +257,18 @@ func TestWrapTypedHandler_NilResponse(t *testing.T) {
func BenchmarkWrapHandler(b *testing.B) { func BenchmarkWrapHandler(b *testing.B) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return map[string]any{"ok": true}, nil var req pb.ChatRequest
_ = natsutil.Decode(msg.Data, &req)
return &pb.ChatResponse{Response: "ok"}, nil
}) })
payload := map[string]any{ encoded, _ := proto.Marshal(&pb.ChatRequest{
"request_id": "bench-001", RequestId: "bench-001",
"message": "What is the capital of France?", Message: "What is the capital of France?",
"premium": true, Premium: true,
"top_k": 10, TopK: 10,
} })
encoded, _ := msgpack.Marshal(payload)
handler := h.wrapHandler(context.Background())
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
b.ResetTimer()
for b.Loop() {
handler(msg)
}
}
func BenchmarkWrapTypedHandler(b *testing.B) {
type benchReq struct {
RequestID string `msgpack:"request_id"`
Message string `msgpack:"message"`
Premium bool `msgpack:"premium"`
TopK int `msgpack:"top_k"`
}
cfg := config.Load()
h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
var req benchReq
_ = msgpack.Unmarshal(msg.Data, &req)
return map[string]any{"ok": true}, nil
})
payload := map[string]any{
"request_id": "bench-001",
"message": "What is the capital of France?",
"premium": true,
"top_k": 10,
}
encoded, _ := msgpack.Marshal(payload)
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
msg := &nats.Msg{Subject: "ai.test", Data: encoded} msg := &nats.Msg{Subject: "ai.test", Data: encoded}

View File

@@ -1,65 +1,24 @@
// Package messages benchmarks compare three serialization strategies: // Package messages benchmarks protobuf encoding/decoding of all message types.
//
// 1. msgpack map[string]any — the old approach (dynamic, no types)
// 2. msgpack typed struct — the new approach (compile-time safe, short keys)
// 3. protobuf — optional future migration
// //
// Run with: // Run with:
// //
// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt // go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt // # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
package messages package messages
import ( import (
"testing" "testing"
"time" "time"
"github.com/vmihailenco/msgpack/v5"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto" pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// Test fixtures — equivalent data across all three encodings // Test fixtures — proto message constructors
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// chatRequestMap is the legacy map[string]any representation.
func chatRequestMap() map[string]any {
return map[string]any{
"request_id": "req-abc-123",
"user_id": "user-42",
"message": "What is the capital of France?",
"query": "",
"premium": true,
"enable_rag": true,
"enable_reranker": true,
"enable_streaming": false,
"top_k": 10,
"collection": "documents",
"enable_tts": false,
"system_prompt": "You are a helpful assistant.",
"response_subject": "ai.chat.response.req-abc-123",
}
}
// chatRequestStruct is the typed struct representation.
func chatRequestStruct() ChatRequest {
return ChatRequest{
RequestID: "req-abc-123",
UserID: "user-42",
Message: "What is the capital of France?",
Premium: true,
EnableRAG: true,
EnableReranker: true,
TopK: 10,
Collection: "documents",
SystemPrompt: "You are a helpful assistant.",
ResponseSubject: "ai.chat.response.req-abc-123",
}
}
// chatRequestProto is the protobuf representation.
func chatRequestProto() *pb.ChatRequest { func chatRequestProto() *pb.ChatRequest {
return &pb.ChatRequest{ return &pb.ChatRequest{
RequestId: "req-abc-123", RequestId: "req-abc-123",
@@ -75,25 +34,6 @@ func chatRequestProto() *pb.ChatRequest {
} }
} }
// voiceResponseMap is a voice response with a 16 KB audio payload.
func voiceResponseMap() map[string]any {
return map[string]any{
"request_id": "vr-001",
"response": "The capital of France is Paris.",
"audio": make([]byte, 16384),
"transcription": "What is the capital of France?",
}
}
func voiceResponseStruct() VoiceResponse {
return VoiceResponse{
RequestID: "vr-001",
Response: "The capital of France is Paris.",
Audio: make([]byte, 16384),
Transcription: "What is the capital of France?",
}
}
func voiceResponseProto() *pb.VoiceResponse { func voiceResponseProto() *pb.VoiceResponse {
return &pb.VoiceResponse{ return &pb.VoiceResponse{
RequestId: "vr-001", RequestId: "vr-001",
@@ -103,31 +43,6 @@ func voiceResponseProto() *pb.VoiceResponse {
} }
} }
// ttsChunkMap simulates a streaming audio chunk (~32 KB).
func ttsChunkMap() map[string]any {
return map[string]any{
"session_id": "tts-sess-99",
"chunk_index": 3,
"total_chunks": 12,
"audio_b64": string(make([]byte, 32768)), // old: base64 string
"is_last": false,
"timestamp": time.Now().Unix(),
"sample_rate": 24000,
}
}
func ttsChunkStruct() TTSAudioChunk {
return TTSAudioChunk{
SessionID: "tts-sess-99",
ChunkIndex: 3,
TotalChunks: 12,
Audio: make([]byte, 32768), // new: raw bytes
IsLast: false,
Timestamp: time.Now().Unix(),
SampleRate: 24000,
}
}
func ttsChunkProto() *pb.TTSAudioChunk { func ttsChunkProto() *pb.TTSAudioChunk {
return &pb.TTSAudioChunk{ return &pb.TTSAudioChunk{
SessionId: "tts-sess-99", SessionId: "tts-sess-99",
@@ -146,27 +61,17 @@ func ttsChunkProto() *pb.TTSAudioChunk {
func TestWireSize(t *testing.T) { func TestWireSize(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mapData any protoMsg proto.Message
structVal any
protoMsg proto.Message
}{ }{
{"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()}, {"ChatRequest", chatRequestProto()},
{"VoiceResponse", voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()}, {"VoiceResponse", voiceResponseProto()},
{"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()}, {"TTSAudioChunk", ttsChunkProto()},
} }
for _, tt := range tests { for _, tt := range tests {
mapBytes, _ := msgpack.Marshal(tt.mapData)
structBytes, _ := msgpack.Marshal(tt.structVal)
protoBytes, _ := proto.Marshal(tt.protoMsg) protoBytes, _ := proto.Marshal(tt.protoMsg)
t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)",
tt.name,
len(mapBytes), len(structBytes), len(protoBytes),
100*(1-float64(len(structBytes))/float64(len(mapBytes))),
100*(1-float64(len(protoBytes))/float64(len(mapBytes))),
)
} }
} }
@@ -174,23 +79,7 @@ func TestWireSize(t *testing.T) {
// Encode benchmarks // Encode benchmarks
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) { func BenchmarkEncode_ChatRequest(b *testing.B) {
data := chatRequestMap()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) {
data := chatRequestStruct()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
data := chatRequestProto() data := chatRequestProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -198,23 +87,7 @@ func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
} }
} }
func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) { func BenchmarkEncode_VoiceResponse(b *testing.B) {
data := voiceResponseMap()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) {
data := voiceResponseStruct()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
data := voiceResponseProto() data := voiceResponseProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -222,23 +95,7 @@ func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
} }
} }
func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) { func BenchmarkEncode_TTSChunk(b *testing.B) {
data := ttsChunkMap()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) {
data := ttsChunkStruct()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
data := ttsChunkProto() data := ttsChunkProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -250,25 +107,7 @@ func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
// Decode benchmarks // Decode benchmarks
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) { func BenchmarkDecode_ChatRequest(b *testing.B) {
encoded, _ := msgpack.Marshal(chatRequestMap())
b.ResetTimer()
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) {
encoded, _ := msgpack.Marshal(chatRequestStruct())
b.ResetTimer()
for b.Loop() {
var m ChatRequest
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
encoded, _ := proto.Marshal(chatRequestProto()) encoded, _ := proto.Marshal(chatRequestProto())
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -277,25 +116,7 @@ func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
} }
} }
func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) { func BenchmarkDecode_VoiceResponse(b *testing.B) {
encoded, _ := msgpack.Marshal(voiceResponseMap())
b.ResetTimer()
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) {
encoded, _ := msgpack.Marshal(voiceResponseStruct())
b.ResetTimer()
for b.Loop() {
var m VoiceResponse
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
encoded, _ := proto.Marshal(voiceResponseProto()) encoded, _ := proto.Marshal(voiceResponseProto())
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -304,25 +125,7 @@ func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
} }
} }
func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) { func BenchmarkDecode_TTSChunk(b *testing.B) {
encoded, _ := msgpack.Marshal(ttsChunkMap())
b.ResetTimer()
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) {
encoded, _ := msgpack.Marshal(ttsChunkStruct())
b.ResetTimer()
for b.Loop() {
var m TTSAudioChunk
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
encoded, _ := proto.Marshal(ttsChunkProto()) encoded, _ := proto.Marshal(ttsChunkProto())
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -335,27 +138,7 @@ func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
// Roundtrip benchmarks (encode + decode) // Roundtrip benchmarks (encode + decode)
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) { func BenchmarkRoundtrip_ChatRequest(b *testing.B) {
data := chatRequestMap()
b.ResetTimer()
for b.Loop() {
enc, _ := msgpack.Marshal(data)
var dec map[string]any
_ = msgpack.Unmarshal(enc, &dec)
}
}
func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) {
data := chatRequestStruct()
b.ResetTimer()
for b.Loop() {
enc, _ := msgpack.Marshal(data)
var dec ChatRequest
_ = msgpack.Unmarshal(enc, &dec)
}
}
func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
data := chatRequestProto() data := chatRequestProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -366,143 +149,157 @@ func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// Typed struct unit tests — verify roundtrip correctness // Correctness tests — verify proto roundtrip
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestRoundtrip_ChatRequest(t *testing.T) { func TestRoundtrip_ChatRequest(t *testing.T) {
orig := chatRequestStruct() orig := chatRequestProto()
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec ChatRequest var dec pb.ChatRequest
if err := msgpack.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.RequestID != orig.RequestID { if dec.GetRequestId() != orig.GetRequestId() {
t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID) t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId())
} }
if dec.Message != orig.Message { if dec.GetMessage() != orig.GetMessage() {
t.Errorf("Message = %q, want %q", dec.Message, orig.Message) t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage())
} }
if dec.TopK != orig.TopK { if dec.GetTopK() != orig.GetTopK() {
t.Errorf("TopK = %d, want %d", dec.TopK, orig.TopK) t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK())
} }
if dec.Premium != orig.Premium { if dec.GetPremium() != orig.GetPremium() {
t.Errorf("Premium = %v, want %v", dec.Premium, orig.Premium) t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium())
} }
if dec.EffectiveQuery() != orig.Message { if EffectiveQuery(&dec) != orig.GetMessage() {
t.Errorf("EffectiveQuery() = %q, want %q", dec.EffectiveQuery(), orig.Message) t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage())
} }
} }
func TestRoundtrip_VoiceResponse(t *testing.T) { func TestRoundtrip_VoiceResponse(t *testing.T) {
orig := voiceResponseStruct() orig := voiceResponseProto()
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec VoiceResponse var dec pb.VoiceResponse
if err := msgpack.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.RequestID != orig.RequestID { if dec.GetRequestId() != orig.GetRequestId() {
t.Errorf("RequestID mismatch") t.Errorf("RequestId mismatch")
} }
if len(dec.Audio) != len(orig.Audio) { if len(dec.GetAudio()) != len(orig.GetAudio()) {
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio)) t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
} }
if dec.Transcription != orig.Transcription { if dec.GetTranscription() != orig.GetTranscription() {
t.Errorf("Transcription mismatch") t.Errorf("Transcription mismatch")
} }
} }
func TestRoundtrip_TTSAudioChunk(t *testing.T) { func TestRoundtrip_TTSAudioChunk(t *testing.T) {
orig := ttsChunkStruct() orig := ttsChunkProto()
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec TTSAudioChunk var dec pb.TTSAudioChunk
if err := msgpack.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.SessionID != orig.SessionID { if dec.GetSessionId() != orig.GetSessionId() {
t.Errorf("SessionID mismatch") t.Errorf("SessionId mismatch")
} }
if dec.ChunkIndex != orig.ChunkIndex { if dec.GetChunkIndex() != orig.GetChunkIndex() {
t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex) t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex())
} }
if len(dec.Audio) != len(orig.Audio) { if len(dec.GetAudio()) != len(orig.GetAudio()) {
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio)) t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
} }
if dec.SampleRate != orig.SampleRate { if dec.GetSampleRate() != orig.GetSampleRate() {
t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate) t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate())
} }
} }
func TestRoundtrip_PipelineTrigger(t *testing.T) { func TestRoundtrip_PipelineTrigger(t *testing.T) {
orig := PipelineTrigger{ orig := &pb.PipelineTrigger{
RequestID: "pip-001", RequestId: "pip-001",
Pipeline: "document-ingestion", Pipeline: "document-ingestion",
Parameters: map[string]any{"source": "s3://bucket/data"}, Parameters: map[string]string{"source": "s3://bucket/data"},
} }
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec PipelineTrigger var dec pb.PipelineTrigger
if err := msgpack.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.Pipeline != orig.Pipeline { if dec.GetPipeline() != orig.GetPipeline() {
t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline) t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline())
} }
if dec.Parameters["source"] != orig.Parameters["source"] { if dec.GetParameters()["source"] != orig.GetParameters()["source"] {
t.Errorf("Parameters[source] mismatch") t.Errorf("Parameters[source] mismatch")
} }
} }
func TestRoundtrip_STTTranscription(t *testing.T) { func TestRoundtrip_STTTranscription(t *testing.T) {
orig := STTTranscription{ orig := &pb.STTTranscription{
SessionID: "stt-001", SessionId: "stt-001",
Transcript: "hello world", Transcript: "hello world",
Sequence: 5, Sequence: 5,
IsPartial: false, IsPartial: false,
IsFinal: true, IsFinal: true,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
SpeakerID: "speaker-1", SpeakerId: "speaker-1",
HasVoiceActivity: true, HasVoiceActivity: true,
State: "listening", State: "listening",
} }
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec STTTranscription var dec pb.STTTranscription
if err := msgpack.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.Transcript != orig.Transcript { if dec.GetTranscript() != orig.GetTranscript() {
t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript) t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript())
} }
if dec.IsFinal != orig.IsFinal { if dec.GetIsFinal() != orig.GetIsFinal() {
t.Error("IsFinal mismatch") t.Error("IsFinal mismatch")
} }
} }
func TestRoundtrip_ErrorResponse(t *testing.T) { func TestRoundtrip_ErrorResponse(t *testing.T) {
orig := ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"} orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec ErrorResponse var dec pb.ErrorResponse
if err := msgpack.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !dec.Error || dec.Message != "something broke" || dec.Type != "InternalError" { if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" {
t.Errorf("ErrorResponse roundtrip mismatch: %+v", dec) t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec)
}
}
func TestEffectiveQuery_MessageSet(t *testing.T) {
req := &pb.ChatRequest{Message: "hello", Query: "world"}
if got := EffectiveQuery(req); got != "hello" {
t.Errorf("EffectiveQuery() = %q, want %q", got, "hello")
}
}
func TestEffectiveQuery_FallbackToQuery(t *testing.T) {
req := &pb.ChatRequest{Query: "world"}
if got := EffectiveQuery(req); got != "world" {
t.Errorf("EffectiveQuery() = %q, want %q", got, "world")
} }
} }

View File

@@ -1,224 +1,69 @@
// Package messages defines typed NATS message structs for all services. // Package messages re-exports protobuf message types and provides NATS
// subject constants plus helper functions.
// //
// Using typed structs with short msgpack field tags instead of map[string]any // The canonical type definitions live in the generated package
// provides compile-time safety, smaller wire size (integer-like short keys vs // gen/messagespb (from proto/messages/v1/messages.proto).
// full string keys), and faster encode/decode by avoiding interface{} boxing. // This package provides type aliases so existing callers can keep using
// // messages.ChatRequest, etc., while the wire format is now protobuf.
// Audio data uses raw []byte instead of base64-encoded strings — msgpack
// supports binary natively, eliminating the 33% base64 overhead.
package messages package messages
import "time" import (
"time"
// ──────────────────────────────────────────────────────────────────────────── pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
// Pipeline Bridge )
// ────────────────────────────────────────────────────────────────────────────
// PipelineTrigger is the request to start a pipeline. // ════════════════════════════════════════════════════════════════════════════
type PipelineTrigger struct { // Type aliases — use these or import gen/messagespb directly.
RequestID string `msgpack:"request_id" json:"request_id"` // ════════════════════════════════════════════════════════════════════════════
Pipeline string `msgpack:"pipeline" json:"pipeline"`
Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"`
}
// PipelineStatus is the response / status update for a pipeline run. // Common
type PipelineStatus struct { type ErrorResponse = pb.ErrorResponse
RequestID string `msgpack:"request_id" json:"request_id"`
Status string `msgpack:"status" json:"status"`
RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"`
Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"`
Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"`
SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"`
}
// ──────────────────────────────────────────────────────────────────────────── // Chat
// Chat Handler type LoginEvent = pb.LoginEvent
// ──────────────────────────────────────────────────────────────────────────── type GreetingRequest = pb.GreetingRequest
type GreetingResponse = pb.GreetingResponse
type ChatRequest = pb.ChatRequest
type ChatResponse = pb.ChatResponse
type ChatStreamChunk = pb.ChatStreamChunk
// ChatRequest is an incoming chat message. // Voice
type ChatRequest struct { type VoiceRequest = pb.VoiceRequest
RequestID string `msgpack:"request_id" json:"request_id"` type VoiceResponse = pb.VoiceResponse
UserID string `msgpack:"user_id" json:"user_id"` type DocumentSource = pb.DocumentSource
Message string `msgpack:"message" json:"message"`
Query string `msgpack:"query,omitempty" json:"query,omitempty"` // TTS
Premium bool `msgpack:"premium,omitempty" json:"premium,omitempty"` type TTSRequest = pb.TTSRequest
EnableRAG bool `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"` type TTSAudioChunk = pb.TTSAudioChunk
EnableReranker bool `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"` type TTSFullResponse = pb.TTSFullResponse
EnableStreaming bool `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"` type TTSStatus = pb.TTSStatus
TopK int `msgpack:"top_k,omitempty" json:"top_k,omitempty"` type TTSVoiceInfo = pb.TTSVoiceInfo
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"` type TTSVoiceListResponse = pb.TTSVoiceListResponse
EnableTTS bool `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"` type TTSVoiceRefreshResponse = pb.TTSVoiceRefreshResponse
SystemPrompt string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"` // STT
} type STTStreamMessage = pb.STTStreamMessage
type STTTranscription = pb.STTTranscription
type STTInterrupt = pb.STTInterrupt
// Pipeline
type PipelineTrigger = pb.PipelineTrigger
type PipelineStatus = pb.PipelineStatus
// ════════════════════════════════════════════════════════════════════════════
// Helpers
// ════════════════════════════════════════════════════════════════════════════
// EffectiveQuery returns Message or falls back to Query. // EffectiveQuery returns Message or falls back to Query.
func (c *ChatRequest) EffectiveQuery() string { func EffectiveQuery(c *ChatRequest) string {
if c.Message != "" { if c.GetMessage() != "" {
return c.Message return c.GetMessage()
} }
return c.Query return c.GetQuery()
} }
// ChatResponse is the full reply to a chat request. // Timestamp returns the current Unix timestamp.
type ChatResponse struct {
UserID string `msgpack:"user_id" json:"user_id"`
Response string `msgpack:"response" json:"response"`
ResponseText string `msgpack:"response_text" json:"response_text"`
UsedRAG bool `msgpack:"used_rag" json:"used_rag"`
RAGSources []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
Success bool `msgpack:"success" json:"success"`
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
}
// ChatStreamChunk is a single streaming chunk from an LLM response.
type ChatStreamChunk struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Type string `msgpack:"type" json:"type"`
Content string `msgpack:"content" json:"content"`
Done bool `msgpack:"done" json:"done"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// ────────────────────────────────────────────────────────────────────────────
// Voice Assistant
// ────────────────────────────────────────────────────────────────────────────
// VoiceRequest is an incoming voice-to-voice request.
type VoiceRequest struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Audio []byte `msgpack:"audio" json:"audio"`
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
}
// VoiceResponse is the reply to a voice request.
type VoiceResponse struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Response string `msgpack:"response" json:"response"`
Audio []byte `msgpack:"audio" json:"audio"`
Transcription string `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
Sources []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
}
// DocumentSource is a RAG search result source.
type DocumentSource struct {
Text string `msgpack:"text" json:"text"`
Score float64 `msgpack:"score" json:"score"`
}
// ────────────────────────────────────────────────────────────────────────────
// TTS Module
// ────────────────────────────────────────────────────────────────────────────
// TTSRequest is a text-to-speech synthesis request.
type TTSRequest struct {
Text string `msgpack:"text" json:"text"`
Speaker string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
Stream bool `msgpack:"stream,omitempty" json:"stream,omitempty"`
}
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
type TTSAudioChunk struct {
SessionID string `msgpack:"session_id" json:"session_id"`
ChunkIndex int `msgpack:"chunk_index" json:"chunk_index"`
TotalChunks int `msgpack:"total_chunks" json:"total_chunks"`
Audio []byte `msgpack:"audio" json:"audio"`
IsLast bool `msgpack:"is_last" json:"is_last"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
}
// TTSFullResponse is a non-streamed TTS response (whole audio).
type TTSFullResponse struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Audio []byte `msgpack:"audio" json:"audio"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
}
// TTSStatus is a TTS processing status update.
type TTSStatus struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Status string `msgpack:"status" json:"status"`
Message string `msgpack:"message" json:"message"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// TTSVoiceListResponse is the reply to a voice list request.
type TTSVoiceListResponse struct {
DefaultSpeaker string `msgpack:"default_speaker" json:"default_speaker"`
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
LastRefresh int64 `msgpack:"last_refresh" json:"last_refresh"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// TTSVoiceInfo is summary info about a custom voice.
type TTSVoiceInfo struct {
Name string `msgpack:"name" json:"name"`
Language string `msgpack:"language" json:"language"`
ModelType string `msgpack:"model_type" json:"model_type"`
CreatedAt string `msgpack:"created_at" json:"created_at"`
}
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
type TTSVoiceRefreshResponse struct {
Count int `msgpack:"count" json:"count"`
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// ────────────────────────────────────────────────────────────────────────────
// STT Module
// ────────────────────────────────────────────────────────────────────────────
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
type STTStreamMessage struct {
Type string `msgpack:"type" json:"type"`
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
State string `msgpack:"state,omitempty" json:"state,omitempty"`
SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
}
// STTTranscription is the transcription result published by the STT module.
type STTTranscription struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Transcript string `msgpack:"transcript" json:"transcript"`
Sequence int `msgpack:"sequence" json:"sequence"`
IsPartial bool `msgpack:"is_partial" json:"is_partial"`
IsFinal bool `msgpack:"is_final" json:"is_final"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
HasVoiceActivity bool `msgpack:"has_voice_activity" json:"has_voice_activity"`
State string `msgpack:"state" json:"state"`
}
// STTInterrupt is published when the STT module detects a user interrupt.
type STTInterrupt struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Type string `msgpack:"type" json:"type"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
}
// ────────────────────────────────────────────────────────────────────────────
// Common / Error
// ────────────────────────────────────────────────────────────────────────────
// ErrorResponse is the standard error reply from any handler.
type ErrorResponse struct {
Error bool `msgpack:"error" json:"error"`
Message string `msgpack:"message" json:"message"`
Type string `msgpack:"type" json:"type"`
}
// Timestamp returns the current Unix timestamp (helper for message construction).
func Timestamp() int64 { func Timestamp() int64 {
return time.Now().Unix() return time.Now().Unix()
} }

View File

@@ -1,4 +1,4 @@
// Package natsutil provides a NATS/JetStream client with msgpack serialization. // Package natsutil provides a NATS/JetStream client with protobuf serialization.
package natsutil package natsutil
import ( import (
@@ -7,10 +7,10 @@ import (
"time" "time"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5" "google.golang.org/protobuf/proto"
) )
// Client wraps a NATS connection with msgpack helpers. // Client wraps a NATS connection with protobuf helpers.
type Client struct { type Client struct {
nc *nats.Conn nc *nats.Conn
js nats.JetStreamContext js nats.JetStreamContext
@@ -97,46 +97,34 @@ func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string
return nil return nil
} }
// Publish encodes data as msgpack and publishes to the subject. // Publish encodes data as protobuf and publishes to the subject.
func (c *Client) Publish(subject string, data any) error { func (c *Client) Publish(subject string, data proto.Message) error {
payload, err := msgpack.Marshal(data) payload, err := proto.Marshal(data)
if err != nil { if err != nil {
return fmt.Errorf("msgpack marshal: %w", err) return fmt.Errorf("proto marshal: %w", err)
} }
return c.nc.Publish(subject, payload) return c.nc.Publish(subject, payload)
} }
// Request sends a msgpack-encoded request and decodes the response into result. // PublishRaw publishes pre-encoded bytes to the subject.
func (c *Client) Request(subject string, data any, result any, timeout time.Duration) error { func (c *Client) PublishRaw(subject string, data []byte) error {
payload, err := msgpack.Marshal(data) return c.nc.Publish(subject, data)
}
// Request sends a protobuf-encoded request and decodes the response into result.
func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error {
payload, err := proto.Marshal(data)
if err != nil { if err != nil {
return fmt.Errorf("msgpack marshal: %w", err) return fmt.Errorf("proto marshal: %w", err)
} }
msg, err := c.nc.Request(subject, payload, timeout) msg, err := c.nc.Request(subject, payload, timeout)
if err != nil { if err != nil {
return fmt.Errorf("nats request: %w", err) return fmt.Errorf("nats request: %w", err)
} }
return msgpack.Unmarshal(msg.Data, result) return proto.Unmarshal(msg.Data, result)
} }
// DecodeMsgpack decodes msgpack-encoded NATS message data into dest. // Decode unmarshals protobuf bytes into dest.
func DecodeMsgpack(msg *nats.Msg, dest any) error { func Decode(data []byte, dest proto.Message) error {
return msgpack.Unmarshal(msg.Data, dest) return proto.Unmarshal(data, dest)
}
// Decode is a generic helper that unmarshals msgpack bytes into T.
// Usage: req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
func Decode[T any](data []byte) (T, error) {
var v T
err := msgpack.Unmarshal(data, &v)
return v, err
}
// DecodeMsgpackMap decodes msgpack data into a generic map.
func DecodeMsgpackMap(data []byte) (map[string]any, error) {
var m map[string]any
if err := msgpack.Unmarshal(data, &m); err != nil {
return nil, err
}
return m, nil
} }

View File

@@ -3,254 +3,212 @@ package natsutil
import ( import (
"testing" "testing"
"github.com/vmihailenco/msgpack/v5" "google.golang.org/protobuf/proto"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// DecodeMsgpackMap tests // Decode tests
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestDecodeMsgpackMap_Roundtrip(t *testing.T) { func TestDecode_ChatRequest_Roundtrip(t *testing.T) {
orig := map[string]any{ orig := &pb.ChatRequest{
"request_id": "req-001", RequestId: "req-001",
"user_id": "user-42", UserId: "user-42",
"premium": true, Premium: true,
"top_k": int64(10), // msgpack decodes ints as int64 TopK: 10,
} }
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
decoded, err := DecodeMsgpackMap(data) var decoded pb.ChatRequest
if err != nil { if err := Decode(data, &decoded); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if decoded["request_id"] != "req-001" { if decoded.GetRequestId() != "req-001" {
t.Errorf("request_id = %v", decoded["request_id"]) t.Errorf("RequestId = %v", decoded.GetRequestId())
} }
if decoded["premium"] != true { if decoded.GetUserId() != "user-42" {
t.Errorf("premium = %v", decoded["premium"]) t.Errorf("UserId = %v", decoded.GetUserId())
}
if decoded.GetPremium() != true {
t.Errorf("Premium = %v", decoded.GetPremium())
}
if decoded.GetTopK() != 10 {
t.Errorf("TopK = %v", decoded.GetTopK())
} }
} }
func TestDecodeMsgpackMap_Empty(t *testing.T) { func TestDecode_EmptyMessage(t *testing.T) {
data, _ := msgpack.Marshal(map[string]any{}) data, err := proto.Marshal(&pb.ChatRequest{})
m, err := DecodeMsgpackMap(data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(m) != 0 { var decoded pb.ChatRequest
t.Errorf("expected empty map, got %v", m) if err := Decode(data, &decoded); err != nil {
t.Fatal(err)
}
if decoded.GetRequestId() != "" {
t.Errorf("expected empty RequestId, got %q", decoded.GetRequestId())
} }
} }
func TestDecodeMsgpackMap_InvalidData(t *testing.T) { func TestDecode_InvalidData(t *testing.T) {
_, err := DecodeMsgpackMap([]byte{0xFF, 0xFE}) err := Decode([]byte{0xFF, 0xFE}, &pb.ChatRequest{})
if err == nil { if err == nil {
t.Error("expected error for invalid msgpack data") t.Error("expected error for invalid protobuf data")
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// DecodeMsgpack (typed struct) tests // Typed struct roundtrip tests
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
type testMessage struct { func TestDecode_VoiceResponse_Roundtrip(t *testing.T) {
RequestID string `msgpack:"request_id"` orig := &pb.VoiceResponse{
UserID string `msgpack:"user_id"` RequestId: "vr-001",
Count int `msgpack:"count"` Response: "The capital of France is Paris.",
Active bool `msgpack:"active"` Transcription: "What is the capital of France?",
}
func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
orig := testMessage{
RequestID: "req-typed-001",
UserID: "user-7",
Count: 42,
Active: true,
} }
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Simulate nats.Msg data decoding. var decoded pb.VoiceResponse
var decoded testMessage if err := Decode(data, &decoded); err != nil {
if err := msgpack.Unmarshal(data, &decoded); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if decoded.RequestID != orig.RequestID { if decoded.GetRequestId() != orig.GetRequestId() {
t.Errorf("RequestID = %q, want %q", decoded.RequestID, orig.RequestID) t.Errorf("RequestId = %q, want %q", decoded.GetRequestId(), orig.GetRequestId())
} }
if decoded.Count != orig.Count { if decoded.GetResponse() != orig.GetResponse() {
t.Errorf("Count = %d, want %d", decoded.Count, orig.Count) t.Errorf("Response = %q, want %q", decoded.GetResponse(), orig.GetResponse())
} }
if decoded.Active != orig.Active { if decoded.GetTranscription() != orig.GetTranscription() {
t.Errorf("Active = %v, want %v", decoded.Active, orig.Active) t.Errorf("Transcription = %q, want %q", decoded.GetTranscription(), orig.GetTranscription())
} }
} }
// TestTypedStructDecodesMapEncoding verifies that a typed struct can be func TestDecode_ErrorResponse_Roundtrip(t *testing.T) {
// decoded from data that was encoded as map[string]any (backwards compat). orig := &pb.ErrorResponse{
func TestTypedStructDecodesMapEncoding(t *testing.T) { Error: true,
// Encode as map (the old way). Message: "something broke",
mapData := map[string]any{ Type: "InternalError",
"request_id": "req-compat",
"user_id": "user-compat",
"count": int64(99),
"active": false,
} }
data, err := msgpack.Marshal(mapData) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Decode into typed struct (the new way). var decoded pb.ErrorResponse
var msg testMessage if err := Decode(data, &decoded); err != nil {
if err := msgpack.Unmarshal(data, &msg); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if msg.RequestID != "req-compat" { if !decoded.GetError() {
t.Errorf("RequestID = %q", msg.RequestID) t.Error("expected Error=true")
} }
if msg.Count != 99 { if decoded.GetMessage() != "something broke" {
t.Errorf("Count = %d, want 99", msg.Count) t.Errorf("Message = %q", decoded.GetMessage())
}
if decoded.GetType() != "InternalError" {
t.Errorf("Type = %q", decoded.GetType())
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// Binary data tests (audio []byte in msgpack) // Binary data tests (audio []byte in protobuf)
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
type audioMessage struct {
SessionID string `msgpack:"session_id"`
Audio []byte `msgpack:"audio"`
SampleRate int `msgpack:"sample_rate"`
}
func TestBinaryDataRoundtrip(t *testing.T) { func TestBinaryDataRoundtrip(t *testing.T) {
audio := make([]byte, 32768) audio := make([]byte, 32768)
for i := range audio { for i := range audio {
audio[i] = byte(i % 256) audio[i] = byte(i % 256)
} }
orig := audioMessage{ orig := &pb.TTSAudioChunk{
SessionID: "sess-audio-001", SessionId: "sess-audio-001",
Audio: audio, Audio: audio,
SampleRate: 24000, SampleRate: 24000,
} }
data, err := msgpack.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var decoded audioMessage var decoded pb.TTSAudioChunk
if err := msgpack.Unmarshal(data, &decoded); err != nil { if err := Decode(data, &decoded); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(decoded.Audio) != len(orig.Audio) { if len(decoded.GetAudio()) != len(orig.GetAudio()) {
t.Fatalf("audio len = %d, want %d", len(decoded.Audio), len(orig.Audio)) t.Fatalf("audio len = %d, want %d", len(decoded.GetAudio()), len(orig.GetAudio()))
} }
for i := range decoded.Audio { for i := range decoded.GetAudio() {
if decoded.Audio[i] != orig.Audio[i] { if decoded.GetAudio()[i] != orig.GetAudio()[i] {
t.Fatalf("audio[%d] = %d, want %d", i, decoded.Audio[i], orig.Audio[i]) t.Fatalf("audio[%d] = %d, want %d", i, decoded.GetAudio()[i], orig.GetAudio()[i])
} }
} }
} }
// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string. // TestProtoWireSize shows protobuf wire size for binary payloads.
func TestBinaryVsBase64Size(t *testing.T) { func TestProtoWireSize(t *testing.T) {
audio := make([]byte, 16384) audio := make([]byte, 16384)
// Old approach: base64 string in map. msg := &pb.TTSAudioChunk{
import_b64 := make([]byte, (len(audio)*4+2)/3) // approximate base64 size SessionId: "sess-1",
mapMsg := map[string]any{
"session_id": "sess-1",
"audio_b64": string(import_b64),
}
mapData, _ := msgpack.Marshal(mapMsg)
// New approach: raw bytes in struct.
structMsg := audioMessage{
SessionID: "sess-1",
Audio: audio, Audio: audio,
} }
structData, _ := msgpack.Marshal(structMsg) data, _ := proto.Marshal(msg)
t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)", t.Logf("TTSAudioChunk with 16KB audio: %d bytes on wire", len(data))
len(mapData), len(structData),
100*(1-float64(len(structData))/float64(len(mapData))))
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// Benchmarks // Benchmarks
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkEncodeMap(b *testing.B) { func BenchmarkEncode_ChatRequest(b *testing.B) {
data := map[string]any{ data := &pb.ChatRequest{
"request_id": "req-bench", RequestId: "req-bench",
"user_id": "user-bench", UserId: "user-bench",
"message": "What is the weather today?", Message: "What is the weather today?",
"premium": true, Premium: true,
"top_k": 10, TopK: 10,
} }
for b.Loop() { for b.Loop() {
_, _ = msgpack.Marshal(data) _, _ = proto.Marshal(data)
} }
} }
func BenchmarkEncodeStruct(b *testing.B) { func BenchmarkDecode_ChatRequest(b *testing.B) {
data := testMessage{ raw, _ := proto.Marshal(&pb.ChatRequest{
RequestID: "req-bench", RequestId: "req-bench",
UserID: "user-bench", UserId: "user-bench",
Count: 10, Message: "What is the weather today?",
Active: true, Premium: true,
} TopK: 10,
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkDecodeMap(b *testing.B) {
raw, _ := msgpack.Marshal(map[string]any{
"request_id": "req-bench",
"user_id": "user-bench",
"message": "What is the weather today?",
"premium": true,
"top_k": 10,
}) })
for b.Loop() { for b.Loop() {
var m map[string]any var m pb.ChatRequest
_ = msgpack.Unmarshal(raw, &m) _ = Decode(raw, &m)
} }
} }
func BenchmarkDecodeStruct(b *testing.B) { func BenchmarkDecode_Audio32KB(b *testing.B) {
raw, _ := msgpack.Marshal(testMessage{ raw, _ := proto.Marshal(&pb.TTSAudioChunk{
RequestID: "req-bench", SessionId: "s1",
UserID: "user-bench",
Count: 10,
Active: true,
})
for b.Loop() {
var m testMessage
_ = msgpack.Unmarshal(raw, &m)
}
}
func BenchmarkDecodeAudio32KB(b *testing.B) {
raw, _ := msgpack.Marshal(audioMessage{
SessionID: "s1",
Audio: make([]byte, 32768), Audio: make([]byte, 32768),
SampleRate: 24000, SampleRate: 24000,
}) })
for b.Loop() { for b.Loop() {
var m audioMessage var m pb.TTSAudioChunk
_ = msgpack.Unmarshal(raw, &m) _ = Decode(raw, &m)
} }
} }

View File

@@ -0,0 +1,257 @@
// Homelab AI service message contracts.
//
// This is the single source of truth for all NATS message types.
// Generated Go code lives in handler-base/gen/messagespb.
//
// Compatibility: field numbers are stable across versions — add new fields,
// never reuse or renumber existing ones.
// proto3 syntax. No field in this file is declared `optional`, so scalar
// fields have no presence tracking: an unset field and a field set to its
// zero value look the same to a reader.
syntax = "proto3";
// The version suffix is part of the proto package name.
package messages.v1;
// Go import path of the generated code (the gen/messagespb directory).
option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb";
// ─────────────────────────────────────────────────────────────────────────────
// Common
// ─────────────────────────────────────────────────────────────────────────────
// ErrorResponse is the standard error reply from any handler.
//
// NOTE(review): `error` is a plain proto3 bool, so "unset" and "false" are
// indistinguishable on the wire. Presumably handlers always set it to true
// when sending this message — confirm against the handlers.
message ErrorResponse {
bool error = 1; // error flag
string message = 2; // error text
string type = 3; // error category as a free-form string; the allowed values are not defined in this file
}
// ─────────────────────────────────────────────────────────────────────────────
// Chat (companions-frontend ↔ chat-handler)
// ─────────────────────────────────────────────────────────────────────────────
// LoginEvent is published when a user authenticates.
// Subject: ai.chat.user.{user_id}.login
//
// Fields 1-4 have the same names, types and numbers as GreetingRequest.
message LoginEvent {
string user_id = 1; // presumably the value substituted for {user_id} in the subject — confirm
string username = 2;
string nickname = 3;
bool premium = 4;
int64 timestamp = 5; // Unix seconds
}
// GreetingRequest asks the LLM to generate a personalised greeting.
// Subject: ai.chat.user.{user_id}.greeting.request
//
// Carries the same four identity fields (same numbers and types) as
// LoginEvent, without the timestamp.
message GreetingRequest {
string user_id = 1; // presumably the value substituted for {user_id} in the subject — confirm
string username = 2;
string nickname = 3;
bool premium = 4;
}
// GreetingResponse carries the generated greeting text.
// Subject: ai.chat.user.{user_id}.greeting.response
message GreetingResponse {
string user_id = 1; // presumably echoes GreetingRequest.user_id — confirm
string greeting = 2; // the generated greeting text
}
// ChatRequest is an incoming chat message routed via NATS.
// Subject: ai.chat.user.{user_id}.message
//
// The enable_* flags, top_k and collection are only declared here; what each
// one switches on is decided by the consuming handler, which is not part of
// this file. All of them default to their zero value when unset.
message ChatRequest {
string request_id = 1; // presumably the {request_id} used in the ChatResponse / ChatStreamChunk subjects — confirm
string user_id = 2;
string username = 3;
string message = 4;
string query = 5; // alternative to message (EffectiveQuery picks first non-empty)
bool premium = 6;
bool enable_rag = 7;
bool enable_reranker = 8;
bool enable_streaming = 9;
int32 top_k = 10; // presumably the number of retrieval results to use — confirm with the handler
string collection = 11;
bool enable_tts = 12;
string system_prompt = 13;
string response_subject = 14; // alternative reply subject; see the Subject line on ChatResponse
}
// ChatResponse is the full reply to a ChatRequest.
// Subject: ai.chat.response.{request_id} (or ChatRequest.response_subject)
//
// NOTE(review): `response` and `response_text` are both strings and this file
// does not say how they differ or which one consumers should read — confirm
// with the chat handler before relying on either.
message ChatResponse {
string user_id = 1;
string response = 2;
string response_text = 3;
bool used_rag = 4;
repeated string rag_sources = 5;
bool success = 6;
bytes audio = 7; // raw bytes; presumably only populated when ChatRequest.enable_tts was set — confirm
string error = 8; // presumably empty when success is true — confirm
}
// ChatStreamChunk is one piece of a streaming LLM response.
// Subject: ai.chat.response.stream.{request_id}
//
// NOTE(review): end-of-stream appears to be expressed twice, as type == "done"
// and as the `done` flag. This file does not say which one is authoritative —
// confirm that producers always set both consistently.
message ChatStreamChunk {
string request_id = 1;
string type = 2; // "chunk" | "done"
string content = 3;
bool done = 4;
int64 timestamp = 5; // unit not stated here; LoginEvent.timestamp is Unix seconds — TODO confirm this matches
}
// ─────────────────────────────────────────────────────────────────────────────
// Voice Assistant
// ─────────────────────────────────────────────────────────────────────────────
// VoiceRequest is an incoming voice-to-voice request.
// Subject: ai.voice.request
message VoiceRequest {
string request_id = 1; // presumably the {request_id} used in the VoiceResponse subject — confirm
bytes audio = 2; // raw bytes; the audio container/encoding is not specified in this file
string language = 3;
string collection = 4;
}
// DocumentSource is a single RAG search-result citation.
// Used as the element type of VoiceResponse.sources.
message DocumentSource {
string text = 1;
double score = 2; // scale and ordering (higher-is-better or not) are not defined in this file
}
// VoiceResponse is the reply to a VoiceRequest.
// Subject: ai.voice.response.{request_id}
message VoiceResponse {
string request_id = 1;
string response = 2;
bytes audio = 3; // raw bytes; encoding not specified in this file
string transcription = 4;
repeated DocumentSource sources = 5;
string error = 6; // presumably empty on success — confirm; there is no separate success flag on this message
}
// ─────────────────────────────────────────────────────────────────────────────
// TTS Module
// ─────────────────────────────────────────────────────────────────────────────
// TTSRequest is a text-to-speech synthesis request.
// Subject: ai.voice.tts.request.{session_id}
//
// The message itself has no session_id field; the session is identified only
// by the subject it is published on.
message TTSRequest {
string text = 1;
string speaker = 2;
string language = 3;
// NOTE(review): the name suggests base64 text carried in a string field,
// whereas the other audio payloads in this file are `bytes` — confirm the
// expected encoding with the TTS module.
string speaker_wav_b64 = 4;
bool stream = 5; // presumably selects TTSAudioChunk (streamed) over TTSFullResponse replies — confirm
}
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
// Subject: ai.voice.tts.audio.{session_id}
//
// NOTE(review): TTSFullResponse is documented with the same subject, but its
// field layout is different (field 2 is int32 chunk_index here and bytes audio
// there; field 4 is bytes audio here and int32 sample_rate there). A protobuf
// payload does not name its own type, so subscribers need an out-of-band way
// to know which message to decode — how they do so is not shown in this file.
message TTSAudioChunk {
string session_id = 1;
int32 chunk_index = 2; // whether indexing starts at 0 or 1 is not stated here
int32 total_chunks = 3;
bytes audio = 4; // raw bytes
bool is_last = 5;
int64 timestamp = 6; // unit not stated here — TODO confirm
int32 sample_rate = 7; // presumably Hz — confirm
}
// TTSFullResponse is a non-streamed TTS response (whole audio blob).
// Subject: ai.voice.tts.audio.{session_id}
//
// NOTE(review): TTSAudioChunk is documented with the same subject but uses a
// different field layout (field 2 is bytes audio here and int32 chunk_index
// there). Subscribers must know which of the two types to decode; how they
// tell them apart is not shown in this file.
message TTSFullResponse {
string session_id = 1;
bytes audio = 2; // raw bytes
int64 timestamp = 3; // unit not stated here — TODO confirm
int32 sample_rate = 4; // presumably Hz — confirm
}
// TTSStatus is a TTS processing status update.
// Subject: ai.voice.tts.status.{session_id}
message TTSStatus {
string session_id = 1;
string status = 2; // free-form string; the set of status values is not defined in this file
string message = 3;
int64 timestamp = 4; // unit not stated here — TODO confirm
}
// TTSVoiceInfo is summary info about a custom voice.
// Used as the element type of custom_voices in TTSVoiceListResponse and
// TTSVoiceRefreshResponse.
message TTSVoiceInfo {
string name = 1;
string language = 2;
string model_type = 3;
string created_at = 4; // a string, unlike the int64 timestamps elsewhere in this file; format not specified here
}
// TTSVoiceListResponse is the reply to a voice list request.
// Subject: ai.voice.tts.voices.list (request-reply)
//
// No request message for this subject is defined in this file.
message TTSVoiceListResponse {
string default_speaker = 1;
repeated TTSVoiceInfo custom_voices = 2;
int64 last_refresh = 3; // unit not stated here — TODO confirm
int64 timestamp = 4; // unit not stated here — TODO confirm
}
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
// Subject: ai.voice.tts.voices.refresh (request-reply)
//
// No request message for this subject is defined in this file.
message TTSVoiceRefreshResponse {
int32 count = 1; // presumably the number of entries in custom_voices — confirm
repeated TTSVoiceInfo custom_voices = 2;
int64 timestamp = 3; // unit not stated here — TODO confirm
}
// ─────────────────────────────────────────────────────────────────────────────
// STT Module
// ─────────────────────────────────────────────────────────────────────────────
// STTStreamMessage is any message on the ai.voice.stream.{session_id} subject.
//
// One message shape covers every event kind; `type` says which kind this is.
// The message has no session_id field — the session is carried by the subject.
message STTStreamMessage {
string type = 1; // "start" | "chunk" | "state_change" | "end"
bytes audio = 2; // raw bytes; presumably only set when type is "chunk" — confirm
string state = 3; // presumably only set when type is "state_change" — confirm
string speaker_id = 4;
}
// STTTranscription is the transcription result published by the STT module.
// Subject: ai.voice.transcription.{session_id}
//
// NOTE(review): STTInterrupt is documented with the same subject. Both
// messages put a string at field 2 and a varint at field 3, so the wire format
// alone is unlikely to reject a payload decoded as the wrong type — confirm
// how subscribers discriminate between the two.
message STTTranscription {
string session_id = 1;
string transcript = 2;
int32 sequence = 3;
bool is_partial = 4;
bool is_final = 5; // relationship to is_partial (mutually exclusive or not) is not stated here
int64 timestamp = 6; // unit not stated here — TODO confirm
string speaker_id = 7;
bool has_voice_activity = 8;
string state = 9; // free-form string; the allowed values are not defined in this file
}
// STTInterrupt is published when the STT module detects a user interrupt.
// Subject: ai.voice.transcription.{session_id}
//
// NOTE(review): STTTranscription is documented with the same subject and has
// a string at field 2 (transcript) and a varint at field 3 (sequence), just as
// this message does (type, timestamp). Confirm how subscribers decide which
// type to decode; a wrong choice is unlikely to surface as a decode error.
message STTInterrupt {
string session_id = 1;
string type = 2; // "interrupt"
int64 timestamp = 3; // unit not stated here — TODO confirm
string speaker_id = 4;
}
// ─────────────────────────────────────────────────────────────────────────────
// Pipeline Bridge
// ─────────────────────────────────────────────────────────────────────────────
// PipelineTrigger is the request to start a pipeline.
// Subject: ai.pipeline.trigger
message PipelineTrigger {
string request_id = 1; // presumably the {request_id} used in the PipelineStatus subject — confirm
string pipeline = 2; // pipeline identifier; presumably one of PipelineStatus.available_pipelines — confirm
// Protobuf Struct could be used here, but a simple string map covers
// all current use-cases and avoids a google/protobuf import.
// Consequence: every parameter value is a string, so any non-string
// parameter has to be encoded as text by the sender.
map<string, string> parameters = 3;
}
// PipelineStatus is the response / status update for a pipeline run.
// Subject: ai.pipeline.status.{request_id}
message PipelineStatus {
string request_id = 1;
string status = 2; // free-form string; the set of status values is not defined in this file
string run_id = 3;
string engine = 4;
string pipeline = 5;
string submitted_at = 6; // a string, unlike the int64 timestamps elsewhere in this file; format not specified here
string error = 7; // presumably empty unless the run failed — confirm
repeated string available_pipelines = 8; // when this list is populated is not stated here — TODO confirm with the pipeline bridge
}