diff --git a/clients/clients_test.go b/clients/clients_test.go index 9c24887..78a4dfc 100644 --- a/clients/clients_test.go +++ b/clients/clients_test.go @@ -563,7 +563,7 @@ func TestLLMClient_StreamGenerateManyTokens(t *testing.T) { var order []int result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) { var idx int - fmt.Sscanf(tok, "t%d ", &idx) + _, _ = fmt.Sscanf(tok, "t%d ", &idx) mu.Lock() order = append(order, idx) mu.Unlock() diff --git a/config/config.go b/config/config.go index bfb6ffa..cb64a8f 100644 --- a/config/config.go +++ b/config/config.go @@ -3,16 +3,16 @@ package config import ( -"context" -"log/slog" -"os" -"path/filepath" -"strconv" -"strings" -"sync" -"time" + "context" + "log/slog" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" -"github.com/fsnotify/fsnotify" + "github.com/fsnotify/fsnotify" ) // Settings holds base configuration for all handler services. @@ -20,249 +20,249 @@ import ( // updated at runtime via WatchSecrets(). All other fields are immutable // after Load() returns. type Settings struct { -// Service identification (immutable) -ServiceName string -ServiceVersion string -ServiceNamespace string -DeploymentEnv string + // Service identification (immutable) + ServiceName string + ServiceVersion string + ServiceNamespace string + DeploymentEnv string -// NATS configuration (immutable) -NATSURL string -NATSUser string -NATSPassword string -NATSQueueGroup string + // NATS configuration (immutable) + NATSURL string + NATSUser string + NATSPassword string + NATSQueueGroup string -// Redis/Valkey configuration (immutable) -RedisURL string -RedisPassword string + // Redis/Valkey configuration (immutable) + RedisURL string + RedisPassword string -// Milvus configuration (immutable) -MilvusHost string -MilvusPort int -MilvusCollection string + // Milvus configuration (immutable) + MilvusHost string + MilvusPort int + MilvusCollection string -// OpenTelemetry configuration (immutable) -OTELEnabled bool -OTELEndpoint string -OTELUseHTTP bool + // OpenTelemetry configuration (immutable) + OTELEnabled bool + OTELEndpoint string + OTELUseHTTP bool -// HyperDX configuration (immutable) -HyperDXEnabled bool -HyperDXAPIKey string -HyperDXEndpoint string + // HyperDX configuration (immutable) + HyperDXEnabled bool + HyperDXAPIKey string + HyperDXEndpoint string -// MLflow configuration (immutable) -MLflowTrackingURI string -MLflowExperimentName string -MLflowEnabled bool + // MLflow configuration (immutable) + MLflowTrackingURI string + MLflowExperimentName string + MLflowEnabled bool -// Health check configuration (immutable) -HealthPort int -HealthPath string -ReadyPath string + // Health check configuration (immutable) + HealthPort int + HealthPath string + ReadyPath string -// Timeouts (immutable) -HTTPTimeout time.Duration -NATSTimeout time.Duration + // Timeouts (immutable) + HTTPTimeout time.Duration + NATSTimeout time.Duration -// Hot-reloadable fields — access via getter methods. -mu sync.RWMutex -embeddingsURL string -rerankerURL string -llmURL string -ttsURL string -sttURL string + // Hot-reloadable fields — access via getter methods. + mu sync.RWMutex + embeddingsURL string + rerankerURL string + llmURL string + ttsURL string + sttURL string -// Secrets path for file-based hot reload (Kubernetes secret mounts) -SecretsPath string + // Secrets path for file-based hot reload (Kubernetes secret mounts) + SecretsPath string } // Load creates a Settings populated from environment variables with defaults. func Load() *Settings { -return &Settings{ -ServiceName: getEnv("SERVICE_NAME", "handler"), -ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), -ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), -DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), + return &Settings{ + ServiceName: getEnv("SERVICE_NAME", "handler"), + ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), + ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), + DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), -NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), -NATSUser: getEnv("NATS_USER", ""), -NATSPassword: getEnv("NATS_PASSWORD", ""), -NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), + NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), + NATSUser: getEnv("NATS_USER", ""), + NATSPassword: getEnv("NATS_PASSWORD", ""), + NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), -RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), -RedisPassword: getEnv("REDIS_PASSWORD", ""), + RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), + RedisPassword: getEnv("REDIS_PASSWORD", ""), -MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), -MilvusPort: getEnvInt("MILVUS_PORT", 19530), -MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), + MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), + MilvusPort: getEnvInt("MILVUS_PORT", 19530), + MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), -embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), -rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), -llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), -ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), -sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), + embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), + rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), + llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), + ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), + sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), -OTELEnabled: getEnvBool("OTEL_ENABLED", true), -OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), -OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), + OTELEnabled: getEnvBool("OTEL_ENABLED", true), + OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), + OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), -HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), -HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), -HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), + HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), + HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), + HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), -MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), -MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), -MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), + MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), + MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), + MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), -HealthPort: getEnvInt("HEALTH_PORT", 8080), -HealthPath: getEnv("HEALTH_PATH", "/health"), -ReadyPath: getEnv("READY_PATH", "/ready"), + HealthPort: getEnvInt("HEALTH_PORT", 8080), + HealthPath: getEnv("HEALTH_PATH", "/health"), + ReadyPath: getEnv("READY_PATH", "/ready"), -HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), -NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), + HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), + NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), -SecretsPath: getEnv("SECRETS_PATH", ""), -} + SecretsPath: getEnv("SECRETS_PATH", ""), + } } // EmbeddingsURL returns the current embeddings service URL (thread-safe). func (s *Settings) EmbeddingsURL() string { -s.mu.RLock() -defer s.mu.RUnlock() -return s.embeddingsURL + s.mu.RLock() + defer s.mu.RUnlock() + return s.embeddingsURL } // RerankerURL returns the current reranker service URL (thread-safe). func (s *Settings) RerankerURL() string { -s.mu.RLock() -defer s.mu.RUnlock() -return s.rerankerURL + s.mu.RLock() + defer s.mu.RUnlock() + return s.rerankerURL } // LLMURL returns the current LLM service URL (thread-safe). func (s *Settings) LLMURL() string { -s.mu.RLock() -defer s.mu.RUnlock() -return s.llmURL + s.mu.RLock() + defer s.mu.RUnlock() + return s.llmURL } // TTSURL returns the current TTS service URL (thread-safe). func (s *Settings) TTSURL() string { -s.mu.RLock() -defer s.mu.RUnlock() -return s.ttsURL + s.mu.RLock() + defer s.mu.RUnlock() + return s.ttsURL } // STTURL returns the current STT service URL (thread-safe). func (s *Settings) STTURL() string { -s.mu.RLock() -defer s.mu.RUnlock() -return s.sttURL + s.mu.RLock() + defer s.mu.RUnlock() + return s.sttURL } // WatchSecrets watches the SecretsPath directory for changes and reloads // hot-reloadable fields. Blocks until ctx is cancelled. func (s *Settings) WatchSecrets(ctx context.Context) { -if s.SecretsPath == "" { -return -} + if s.SecretsPath == "" { + return + } -watcher, err := fsnotify.NewWatcher() -if err != nil { -slog.Error("config: failed to create fsnotify watcher", "error", err) -return -} -defer func() { _ = watcher.Close() }() + watcher, err := fsnotify.NewWatcher() + if err != nil { + slog.Error("config: failed to create fsnotify watcher", "error", err) + return + } + defer func() { _ = watcher.Close() }() -if err := watcher.Add(s.SecretsPath); err != nil { -slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath) -return -} + if err := watcher.Add(s.SecretsPath); err != nil { + slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath) + return + } -slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath) + slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath) -for { -select { -case event, ok := <-watcher.Events: -if !ok { -return -} -if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) { -s.reloadFromSecrets() -} -case err, ok := <-watcher.Errors: -if !ok { -return -} -slog.Error("config: fsnotify error", "error", err) -case <-ctx.Done(): -return -} -} + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) { + s.reloadFromSecrets() + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + slog.Error("config: fsnotify error", "error", err) + case <-ctx.Done(): + return + } + } } // reloadFromSecrets reads hot-reloadable values from the secrets directory. func (s *Settings) reloadFromSecrets() { -s.mu.Lock() -defer s.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() -updated := 0 -reload := func(filename string, target *string) { -path := filepath.Join(s.SecretsPath, filename) -data, err := os.ReadFile(path) -if err != nil { -return -} -val := strings.TrimSpace(string(data)) -if val != "" && val != *target { -*target = val -updated++ -slog.Info("config: reloaded secret", "key", filename) -} -} + updated := 0 + reload := func(filename string, target *string) { + path := filepath.Join(s.SecretsPath, filename) + data, err := os.ReadFile(path) + if err != nil { + return + } + val := strings.TrimSpace(string(data)) + if val != "" && val != *target { + *target = val + updated++ + slog.Info("config: reloaded secret", "key", filename) + } + } -reload("embeddings-url", &s.embeddingsURL) -reload("reranker-url", &s.rerankerURL) -reload("llm-url", &s.llmURL) -reload("tts-url", &s.ttsURL) -reload("stt-url", &s.sttURL) + reload("embeddings-url", &s.embeddingsURL) + reload("reranker-url", &s.rerankerURL) + reload("llm-url", &s.llmURL) + reload("tts-url", &s.ttsURL) + reload("stt-url", &s.sttURL) -if updated > 0 { -slog.Info("config: secrets reloaded", "updated", updated) -} + if updated > 0 { + slog.Info("config: secrets reloaded", "updated", updated) + } } func getEnv(key, fallback string) string { -if v := os.Getenv(key); v != "" { -return v -} -return fallback + if v := os.Getenv(key); v != "" { + return v + } + return fallback } func getEnvInt(key string, fallback int) int { -if v := os.Getenv(key); v != "" { -if i, err := strconv.Atoi(v); err == nil { -return i -} -} -return fallback + if v := os.Getenv(key); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + return fallback } func getEnvBool(key string, fallback bool) bool { -if v := os.Getenv(key); v != "" { -if b, err := strconv.ParseBool(v); err == nil { -return b -} -} -return fallback + if v := os.Getenv(key); v != "" { + if b, err := strconv.ParseBool(v); err == nil { + return b + } + } + return fallback } func getEnvDuration(key string, fallback time.Duration) time.Duration { -if v := os.Getenv(key); v != "" { -if f, err := strconv.ParseFloat(v, 64); err == nil { -return time.Duration(f * float64(time.Second)) -} -} -return fallback + if v := os.Getenv(key); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return time.Duration(f * float64(time.Second)) + } + } + return fallback } diff --git a/config/config_test.go b/config/config_test.go index fa450a0..9a472d3 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,123 +1,123 @@ package config import ( -"os" -"path/filepath" -"testing" -"time" + "os" + "path/filepath" + "testing" + "time" ) func TestLoadDefaults(t *testing.T) { -s := Load() -if s.ServiceName != "handler" { -t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) -} -if s.HealthPort != 8080 { -t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) -} -if s.HTTPTimeout != 60*time.Second { -t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) -} + s := Load() + if s.ServiceName != "handler" { + t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) + } + if s.HealthPort != 8080 { + t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) + } + if s.HTTPTimeout != 60*time.Second { + t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) + } } func TestLoadFromEnv(t *testing.T) { -t.Setenv("SERVICE_NAME", "test-svc") -t.Setenv("HEALTH_PORT", "9090") -t.Setenv("OTEL_ENABLED", "false") + t.Setenv("SERVICE_NAME", "test-svc") + t.Setenv("HEALTH_PORT", "9090") + t.Setenv("OTEL_ENABLED", "false") -s := Load() -if s.ServiceName != "test-svc" { -t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) -} -if s.HealthPort != 9090 { -t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) -} -if s.OTELEnabled { -t.Error("expected OTELEnabled false") -} + s := Load() + if s.ServiceName != "test-svc" { + t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) + } + if s.HealthPort != 9090 { + t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) + } + if s.OTELEnabled { + t.Error("expected OTELEnabled false") + } } func TestURLGetters(t *testing.T) { -s := Load() -if s.EmbeddingsURL() == "" { -t.Error("EmbeddingsURL should have a default") -} -if s.RerankerURL() == "" { -t.Error("RerankerURL should have a default") -} -if s.LLMURL() == "" { -t.Error("LLMURL should have a default") -} -if s.TTSURL() == "" { -t.Error("TTSURL should have a default") -} -if s.STTURL() == "" { -t.Error("STTURL should have a default") -} + s := Load() + if s.EmbeddingsURL() == "" { + t.Error("EmbeddingsURL should have a default") + } + if s.RerankerURL() == "" { + t.Error("RerankerURL should have a default") + } + if s.LLMURL() == "" { + t.Error("LLMURL should have a default") + } + if s.TTSURL() == "" { + t.Error("TTSURL should have a default") + } + if s.STTURL() == "" { + t.Error("STTURL should have a default") + } } func TestURLGettersFromEnv(t *testing.T) { -t.Setenv("EMBEDDINGS_URL", "http://embed:8000") -t.Setenv("LLM_URL", "http://llm:9000") + t.Setenv("EMBEDDINGS_URL", "http://embed:8000") + t.Setenv("LLM_URL", "http://llm:9000") -s := Load() -if s.EmbeddingsURL() != "http://embed:8000" { -t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL()) -} -if s.LLMURL() != "http://llm:9000" { -t.Errorf("expected custom LLMURL, got %q", s.LLMURL()) -} + s := Load() + if s.EmbeddingsURL() != "http://embed:8000" { + t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL()) + } + if s.LLMURL() != "http://llm:9000" { + t.Errorf("expected custom LLMURL, got %q", s.LLMURL()) + } } func TestReloadFromSecrets(t *testing.T) { -dir := t.TempDir() + dir := t.TempDir() -// Write initial secret files -writeSecret(t, dir, "embeddings-url", "http://old-embed:8000") -writeSecret(t, dir, "llm-url", "http://old-llm:9000") + // Write initial secret files + writeSecret(t, dir, "embeddings-url", "http://old-embed:8000") + writeSecret(t, dir, "llm-url", "http://old-llm:9000") -s := Load() -s.SecretsPath = dir -s.reloadFromSecrets() + s := Load() + s.SecretsPath = dir + s.reloadFromSecrets() -if s.EmbeddingsURL() != "http://old-embed:8000" { -t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL()) -} -if s.LLMURL() != "http://old-llm:9000" { -t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL()) -} + if s.EmbeddingsURL() != "http://old-embed:8000" { + t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL()) + } + if s.LLMURL() != "http://old-llm:9000" { + t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL()) + } -// Simulate secret update -writeSecret(t, dir, "embeddings-url", "http://new-embed:8000") -s.reloadFromSecrets() + // Simulate secret update + writeSecret(t, dir, "embeddings-url", "http://new-embed:8000") + s.reloadFromSecrets() -if s.EmbeddingsURL() != "http://new-embed:8000" { -t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL()) -} -// LLM should remain unchanged -if s.LLMURL() != "http://old-llm:9000" { -t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL()) -} + if s.EmbeddingsURL() != "http://new-embed:8000" { + t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL()) + } + // LLM should remain unchanged + if s.LLMURL() != "http://old-llm:9000" { + t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL()) + } } func TestReloadFromSecretsNoPath(t *testing.T) { -s := Load() -s.SecretsPath = "" -// Should not panic -s.reloadFromSecrets() + s := Load() + s.SecretsPath = "" + // Should not panic + s.reloadFromSecrets() } func TestGetEnvDuration(t *testing.T) { -t.Setenv("TEST_DUR", "30") -d := getEnvDuration("TEST_DUR", 10*time.Second) -if d != 30*time.Second { -t.Errorf("expected 30s, got %v", d) -} + t.Setenv("TEST_DUR", "30") + d := getEnvDuration("TEST_DUR", 10*time.Second) + if d != 30*time.Second { + t.Errorf("expected 30s, got %v", d) + } } func writeSecret(t *testing.T, dir, name, value string) { -t.Helper() -if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil { -t.Fatal(err) -} + t.Helper() + if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil { + t.Fatal(err) + } } diff --git a/handler/handler.go b/handler/handler.go index 2ad8d05..a1117a4 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -2,21 +2,21 @@ package handler import ( -"context" -"fmt" -"log/slog" -"os" -"os/signal" -"syscall" + "context" + "fmt" + "log/slog" + "os" + "os/signal" + "syscall" -"github.com/nats-io/nats.go" -"google.golang.org/protobuf/proto" + "github.com/nats-io/nats.go" + "google.golang.org/protobuf/proto" -"git.daviestechlabs.io/daviestechlabs/handler-base/config" -pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" -"git.daviestechlabs.io/daviestechlabs/handler-base/health" -"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" -"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" + "git.daviestechlabs.io/daviestechlabs/handler-base/config" + pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" + "git.daviestechlabs.io/daviestechlabs/handler-base/health" + "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" + "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" ) // TypedMessageHandler processes the raw NATS message. @@ -32,36 +32,36 @@ type TeardownFunc func(ctx context.Context) error // Handler is the base service runner that wires NATS, health, and telemetry. type Handler struct { -Settings *config.Settings -NATS *natsutil.Client -Telemetry *telemetry.Provider -Subject string -QueueGroup string + Settings *config.Settings + NATS *natsutil.Client + Telemetry *telemetry.Provider + Subject string + QueueGroup string -onSetup SetupFunc -onTeardown TeardownFunc -onTypedMessage TypedMessageHandler -running bool + onSetup SetupFunc + onTeardown TeardownFunc + onTypedMessage TypedMessageHandler + running bool } // New creates a Handler for the given NATS subject. func New(subject string, settings *config.Settings) *Handler { -if settings == nil { -settings = config.Load() -} -queueGroup := settings.NATSQueueGroup + if settings == nil { + settings = config.Load() + } + queueGroup := settings.NATSQueueGroup -natsOpts := []nats.Option{} -if settings.NATSUser != "" && settings.NATSPassword != "" { -natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword)) -} + natsOpts := []nats.Option{} + if settings.NATSUser != "" && settings.NATSPassword != "" { + natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword)) + } -return &Handler{ -Settings: settings, -Subject: subject, -QueueGroup: queueGroup, -NATS: natsutil.New(settings.NATSURL, natsOpts...), -} + return &Handler{ + Settings: settings, + Subject: subject, + QueueGroup: queueGroup, + NATS: natsutil.New(settings.NATSURL, natsOpts...), + } } // OnSetup registers the setup callback. @@ -75,101 +75,101 @@ func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn // Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT. func (h *Handler) Run() error { -// Structured logging -slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))) -slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion) + // Structured logging + slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))) + slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion) -ctx, cancel := context.WithCancel(context.Background()) -defer cancel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -// Telemetry -tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{ -ServiceName: h.Settings.ServiceName, -ServiceVersion: h.Settings.ServiceVersion, -ServiceNamespace: h.Settings.ServiceNamespace, -DeploymentEnv: h.Settings.DeploymentEnv, -Enabled: h.Settings.OTELEnabled, -Endpoint: h.Settings.OTELEndpoint, -}) -if err != nil { -return fmt.Errorf("telemetry setup: %w", err) -} -defer shutdown(ctx) -h.Telemetry = tp + // Telemetry + tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{ + ServiceName: h.Settings.ServiceName, + ServiceVersion: h.Settings.ServiceVersion, + ServiceNamespace: h.Settings.ServiceNamespace, + DeploymentEnv: h.Settings.DeploymentEnv, + Enabled: h.Settings.OTELEnabled, + Endpoint: h.Settings.OTELEndpoint, + }) + if err != nil { + return fmt.Errorf("telemetry setup: %w", err) + } + defer shutdown(ctx) + h.Telemetry = tp -// Health server -healthSrv := health.New( -h.Settings.HealthPort, -h.Settings.HealthPath, -h.Settings.ReadyPath, -func() bool { return h.running && h.NATS.IsConnected() }, -) -healthSrv.Start() -defer healthSrv.Stop(ctx) + // Health server + healthSrv := health.New( + h.Settings.HealthPort, + h.Settings.HealthPath, + h.Settings.ReadyPath, + func() bool { return h.running && h.NATS.IsConnected() }, + ) + healthSrv.Start() + defer healthSrv.Stop(ctx) -// Connect to NATS -if err := h.NATS.Connect(); err != nil { -return fmt.Errorf("nats: %w", err) -} -defer h.NATS.Close() + // Connect to NATS + if err := h.NATS.Connect(); err != nil { + return fmt.Errorf("nats: %w", err) + } + defer h.NATS.Close() -// User setup -if h.onSetup != nil { -slog.Info("running service setup") -if err := h.onSetup(ctx); err != nil { -return fmt.Errorf("setup: %w", err) -} -} + // User setup + if h.onSetup != nil { + slog.Info("running service setup") + if err := h.onSetup(ctx); err != nil { + return fmt.Errorf("setup: %w", err) + } + } -// Subscribe -if h.onTypedMessage == nil { -return fmt.Errorf("no message handler registered") -} -if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil { -return fmt.Errorf("subscribe: %w", err) -} + // Subscribe + if h.onTypedMessage == nil { + return fmt.Errorf("no message handler registered") + } + if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil { + return fmt.Errorf("subscribe: %w", err) + } -h.running = true -slog.Info("handler ready", "subject", h.Subject) + h.running = true + slog.Info("handler ready", "subject", h.Subject) -// Wait for shutdown signal -sigCh := make(chan os.Signal, 1) -signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) -<-sigCh + // Wait for shutdown signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + <-sigCh -slog.Info("shutting down") -h.running = false + slog.Info("shutting down") + h.running = false -// Teardown -if h.onTeardown != nil { -if err := h.onTeardown(ctx); err != nil { -slog.Warn("teardown error", "error", err) -} -} + // Teardown + if h.onTeardown != nil { + if err := h.onTeardown(ctx); err != nil { + slog.Warn("teardown error", "error", err) + } + } -slog.Info("shutdown complete") -return nil + slog.Info("shutdown complete") + return nil } // wrapHandler creates a nats.MsgHandler that dispatches to the registered callback. func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler { -return func(msg *nats.Msg) { -response, err := h.onTypedMessage(ctx, msg) -if err != nil { -slog.Error("handler error", "subject", msg.Subject, "error", err) -if msg.Reply != "" { -_ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{ -Error: true, -Message: err.Error(), -Type: fmt.Sprintf("%T", err), -}) -} -return -} -if response != nil && msg.Reply != "" { -if err := h.NATS.Publish(msg.Reply, response); err != nil { -slog.Error("failed to publish reply", "error", err) -} -} -} + return func(msg *nats.Msg) { + response, err := h.onTypedMessage(ctx, msg) + if err != nil { + slog.Error("handler error", "subject", msg.Subject, "error", err) + if msg.Reply != "" { + _ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{ + Error: true, + Message: err.Error(), + Type: fmt.Sprintf("%T", err), + }) + } + return + } + if response != nil && msg.Reply != "" { + if err := h.NATS.Publish(msg.Reply, response); err != nil { + slog.Error("failed to publish reply", "error", err) + } + } + } } diff --git a/handler/handler_test.go b/handler/handler_test.go index 33fb694..90ccc37 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -1,15 +1,15 @@ package handler import ( -"context" -"testing" + "context" + "testing" -"github.com/nats-io/nats.go" -"google.golang.org/protobuf/proto" + "github.com/nats-io/nats.go" + "google.golang.org/protobuf/proto" -"git.daviestechlabs.io/daviestechlabs/handler-base/config" -pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" -"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" + "git.daviestechlabs.io/daviestechlabs/handler-base/config" + pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" + "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" ) // ──────────────────────────────────────────────────────────────────────────── @@ -17,75 +17,75 @@ pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" // ──────────────────────────────────────────────────────────────────────────── func TestNewHandler(t *testing.T) { -cfg := config.Load() -cfg.ServiceName = "test-handler" -cfg.NATSQueueGroup = "test-group" + cfg := config.Load() + cfg.ServiceName = "test-handler" + cfg.NATSQueueGroup = "test-group" -h := New("ai.test.subject", cfg) -if h.Subject != "ai.test.subject" { -t.Errorf("Subject = %q", h.Subject) -} -if h.QueueGroup != "test-group" { -t.Errorf("QueueGroup = %q", h.QueueGroup) -} -if h.Settings.ServiceName != "test-handler" { -t.Errorf("ServiceName = %q", h.Settings.ServiceName) -} + h := New("ai.test.subject", cfg) + if h.Subject != "ai.test.subject" { + t.Errorf("Subject = %q", h.Subject) + } + if h.QueueGroup != "test-group" { + t.Errorf("QueueGroup = %q", h.QueueGroup) + } + if h.Settings.ServiceName != "test-handler" { + t.Errorf("ServiceName = %q", h.Settings.ServiceName) + } } func TestNewHandlerNilSettings(t *testing.T) { -h := New("ai.test", nil) -if h.Settings == nil { -t.Fatal("Settings should be loaded automatically") -} -if h.Settings.ServiceName != "handler" { -t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName) -} + h := New("ai.test", nil) + if h.Settings == nil { + t.Fatal("Settings should be loaded automatically") + } + if h.Settings.ServiceName != "handler" { + t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName) + } } func TestCallbackRegistration(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -setupCalled := false -h.OnSetup(func(ctx context.Context) error { -setupCalled = true -return nil -}) + setupCalled := false + h.OnSetup(func(ctx context.Context) error { + setupCalled = true + return nil + }) -teardownCalled := false -h.OnTeardown(func(ctx context.Context) error { -teardownCalled = true -return nil -}) + teardownCalled := false + h.OnTeardown(func(ctx context.Context) error { + teardownCalled = true + return nil + }) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -return nil, nil -}) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + return nil, nil + }) -if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil { -t.Error("callbacks should not be nil after registration") -} + if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil { + t.Error("callbacks should not be nil after registration") + } -// Verify setup/teardown work when called directly. -_ = h.onSetup(context.Background()) -_ = h.onTeardown(context.Background()) -if !setupCalled || !teardownCalled { -t.Error("callbacks should have been invoked") -} + // Verify setup/teardown work when called directly. + _ = h.onSetup(context.Background()) + _ = h.onTeardown(context.Background()) + if !setupCalled || !teardownCalled { + t.Error("callbacks should have been invoked") + } } func TestTypedMessageRegistration(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { return &pb.ChatResponse{Response: "ok"}, nil -}) + }) -if h.onTypedMessage == nil { -t.Error("onTypedMessage should not be nil after registration") -} + if h.onTypedMessage == nil { + t.Error("onTypedMessage should not be nil after registration") + } } // ──────────────────────────────────────────────────────────────────────────── @@ -93,101 +93,101 @@ t.Error("onTypedMessage should not be nil after registration") // ──────────────────────────────────────────────────────────────────────────── func TestWrapHandler_ValidMessage(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -var receivedReq pb.ChatRequest -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -if err := natsutil.Decode(msg.Data, &receivedReq); err != nil { -return nil, err -} + var receivedReq pb.ChatRequest + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + if err := natsutil.Decode(msg.Data, &receivedReq); err != nil { + return nil, err + } return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil -}) + }) -// Encode a message the same way services would. -encoded, err := proto.Marshal(&pb.ChatRequest{ -RequestId: "test-001", -Message: "hello", -Premium: true, -}) -if err != nil { -t.Fatal(err) -} + // Encode a message the same way services would. + encoded, err := proto.Marshal(&pb.ChatRequest{ + RequestId: "test-001", + Message: "hello", + Premium: true, + }) + if err != nil { + t.Fatal(err) + } -// Call wrapHandler directly without NATS. -handler := h.wrapHandler(context.Background()) -handler(&nats.Msg{ -Subject: "ai.test.user.42.message", -Data: encoded, -}) + // Call wrapHandler directly without NATS. + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{ + Subject: "ai.test.user.42.message", + Data: encoded, + }) -if receivedReq.GetRequestId() != "test-001" { -t.Errorf("request_id = %v", receivedReq.GetRequestId()) -} -if receivedReq.GetPremium() != true { -t.Errorf("premium = %v", receivedReq.GetPremium()) -} + if receivedReq.GetRequestId() != "test-001" { + t.Errorf("request_id = %v", receivedReq.GetRequestId()) + } + if receivedReq.GetPremium() != true { + t.Errorf("premium = %v", receivedReq.GetPremium()) + } } func TestWrapHandler_InvalidMessage(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -handlerCalled := false -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -handlerCalled = true -var req pb.ChatRequest -if err := natsutil.Decode(msg.Data, &req); err != nil { -return nil, err -} -return &pb.ChatResponse{}, nil -}) + handlerCalled := false + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + handlerCalled = true + var req pb.ChatRequest + if err := natsutil.Decode(msg.Data, &req); err != nil { + return nil, err + } + return &pb.ChatResponse{}, nil + }) -handler := h.wrapHandler(context.Background()) -handler(&nats.Msg{ -Subject: "ai.test", -Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf -}) + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{ + Subject: "ai.test", + Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf + }) -// The handler IS called (wrapHandler doesn't pre-decode), but it should -// return an error from Decode. Either way no panic. -_ = handlerCalled + // The handler IS called (wrapHandler doesn't pre-decode), but it should + // return an error from Decode. Either way no panic. + _ = handlerCalled } func TestWrapHandler_HandlerError(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -return nil, context.DeadlineExceeded -}) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + return nil, context.DeadlineExceeded + }) -encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"}) -handler := h.wrapHandler(context.Background()) + encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"}) + handler := h.wrapHandler(context.Background()) -// Should not panic even when handler returns error. -handler(&nats.Msg{ -Subject: "ai.test", -Data: encoded, -}) + // Should not panic even when handler returns error. + handler(&nats.Msg{ + Subject: "ai.test", + Data: encoded, + }) } func TestWrapHandler_NilResponse(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -return nil, nil // fire-and-forget style -}) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + return nil, nil // fire-and-forget style + }) -encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"}) -handler := h.wrapHandler(context.Background()) + encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"}) + handler := h.wrapHandler(context.Background()) -// Should not panic with nil response and no reply subject. -handler(&nats.Msg{ -Subject: "ai.test", -Data: encoded, -}) + // Should not panic with nil response and no reply subject. + handler(&nats.Msg{ + Subject: "ai.test", + Data: encoded, + }) } // ──────────────────────────────────────────────────────────────────────────── @@ -195,59 +195,59 @@ Data: encoded, // ──────────────────────────────────────────────────────────────────────────── func TestWrapHandler_Typed(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -var received pb.ChatRequest -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -if err := natsutil.Decode(msg.Data, &received); err != nil { -return nil, err -} + var received pb.ChatRequest + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + if err := natsutil.Decode(msg.Data, &received); err != nil { + return nil, err + } return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil -}) + }) -encoded, _ := proto.Marshal(&pb.ChatRequest{ -RequestId: "typed-001", -Message: "hello typed", -}) + encoded, _ := proto.Marshal(&pb.ChatRequest{ + RequestId: "typed-001", + Message: "hello typed", + }) -handler := h.wrapHandler(context.Background()) -handler(&nats.Msg{Subject: "ai.test", Data: encoded}) + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{Subject: "ai.test", Data: encoded}) -if received.GetRequestId() != "typed-001" { -t.Errorf("RequestId = %q", received.GetRequestId()) -} -if received.GetMessage() != "hello typed" { -t.Errorf("Message = %q", received.GetMessage()) -} + if received.GetRequestId() != "typed-001" { + t.Errorf("RequestId = %q", received.GetRequestId()) + } + if received.GetMessage() != "hello typed" { + t.Errorf("Message = %q", received.GetMessage()) + } } func TestWrapHandler_TypedError(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -return nil, context.DeadlineExceeded -}) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + return nil, context.DeadlineExceeded + }) -encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"}) -handler := h.wrapHandler(context.Background()) + encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"}) + handler := h.wrapHandler(context.Background()) -// Should not panic. -handler(&nats.Msg{Subject: "ai.test", Data: encoded}) + // Should not panic. + handler(&nats.Msg{Subject: "ai.test", Data: encoded}) } func TestWrapHandler_TypedNilResponse(t *testing.T) { -cfg := config.Load() -h := New("ai.test", cfg) + cfg := config.Load() + h := New("ai.test", cfg) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -return nil, nil -}) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + return nil, nil + }) -encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"}) -handler := h.wrapHandler(context.Background()) -handler(&nats.Msg{Subject: "ai.test", Data: encoded}) + encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"}) + handler := h.wrapHandler(context.Background()) + handler(&nats.Msg{Subject: "ai.test", Data: encoded}) } // ──────────────────────────────────────────────────────────────────────────── @@ -255,25 +255,25 @@ handler(&nats.Msg{Subject: "ai.test", Data: encoded}) // ──────────────────────────────────────────────────────────────────────────── func BenchmarkWrapHandler(b *testing.B) { -cfg := config.Load() -h := New("ai.test", cfg) -h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { -var req pb.ChatRequest -_ = natsutil.Decode(msg.Data, &req) -return &pb.ChatResponse{Response: "ok"}, nil -}) + cfg := config.Load() + h := New("ai.test", cfg) + h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { + var req pb.ChatRequest + _ = natsutil.Decode(msg.Data, &req) + return &pb.ChatResponse{Response: "ok"}, nil + }) -encoded, _ := proto.Marshal(&pb.ChatRequest{ -RequestId: "bench-001", -Message: "What is the capital of France?", -Premium: true, -TopK: 10, -}) -handler := h.wrapHandler(context.Background()) -msg := &nats.Msg{Subject: "ai.test", Data: encoded} + encoded, _ := proto.Marshal(&pb.ChatRequest{ + RequestId: "bench-001", + Message: "What is the capital of France?", + Premium: true, + TopK: 10, + }) + handler := h.wrapHandler(context.Background()) + msg := &nats.Msg{Subject: "ai.test", Data: encoded} -b.ResetTimer() -for b.Loop() { -handler(msg) -} + b.ResetTimer() + for b.Loop() { + handler(msg) + } } diff --git a/messages/bench_test.go b/messages/bench_test.go index b4a3cd0..1907ff0 100644 --- a/messages/bench_test.go +++ b/messages/bench_test.go @@ -2,17 +2,17 @@ // // Run with: // -//go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt -//# optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt +// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt +// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt package messages import ( -"testing" -"time" + "testing" + "time" -"google.golang.org/protobuf/proto" + "google.golang.org/protobuf/proto" -pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" + pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" ) // ──────────────────────────────────────────────────────────────────────────── @@ -20,39 +20,39 @@ pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" // ──────────────────────────────────────────────────────────────────────────── func chatRequestProto() *pb.ChatRequest { -return &pb.ChatRequest{ -RequestId: "req-abc-123", -UserId: "user-42", -Message: "What is the capital of France?", -Premium: true, -EnableRag: true, -EnableReranker: true, -TopK: 10, -Collection: "documents", -SystemPrompt: "You are a helpful assistant.", -ResponseSubject: "ai.chat.response.req-abc-123", -} + return &pb.ChatRequest{ + RequestId: "req-abc-123", + UserId: "user-42", + Message: "What is the capital of France?", + Premium: true, + EnableRag: true, + EnableReranker: true, + TopK: 10, + Collection: "documents", + SystemPrompt: "You are a helpful assistant.", + ResponseSubject: "ai.chat.response.req-abc-123", + } } func voiceResponseProto() *pb.VoiceResponse { -return &pb.VoiceResponse{ -RequestId: "vr-001", -Response: "The capital of France is Paris.", -Audio: make([]byte, 16384), -Transcription: "What is the capital of France?", -} + return &pb.VoiceResponse{ + RequestId: "vr-001", + Response: "The capital of France is Paris.", + Audio: make([]byte, 16384), + Transcription: "What is the capital of France?", + } } func ttsChunkProto() *pb.TTSAudioChunk { -return &pb.TTSAudioChunk{ -SessionId: "tts-sess-99", -ChunkIndex: 3, -TotalChunks: 12, -Audio: make([]byte, 32768), -IsLast: false, -Timestamp: time.Now().Unix(), -SampleRate: 24000, -} + return &pb.TTSAudioChunk{ + SessionId: "tts-sess-99", + ChunkIndex: 3, + TotalChunks: 12, + Audio: make([]byte, 32768), + IsLast: false, + Timestamp: time.Now().Unix(), + SampleRate: 24000, + } } // ──────────────────────────────────────────────────────────────────────────── @@ -60,19 +60,19 @@ SampleRate: 24000, // ──────────────────────────────────────────────────────────────────────────── func TestWireSize(t *testing.T) { -tests := []struct { -name string -protoMsg proto.Message -}{ -{"ChatRequest", chatRequestProto()}, -{"VoiceResponse", voiceResponseProto()}, -{"TTSAudioChunk", ttsChunkProto()}, -} + tests := []struct { + name string + protoMsg proto.Message + }{ + {"ChatRequest", chatRequestProto()}, + {"VoiceResponse", voiceResponseProto()}, + {"TTSAudioChunk", ttsChunkProto()}, + } -for _, tt := range tests { -protoBytes, _ := proto.Marshal(tt.protoMsg) -t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes)) -} + for _, tt := range tests { + protoBytes, _ := proto.Marshal(tt.protoMsg) + t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes)) + } } // ──────────────────────────────────────────────────────────────────────────── @@ -80,27 +80,27 @@ t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes)) // ──────────────────────────────────────────────────────────────────────────── func BenchmarkEncode_ChatRequest(b *testing.B) { -data := chatRequestProto() -b.ResetTimer() -for b.Loop() { -_, _ = proto.Marshal(data) -} + data := chatRequestProto() + b.ResetTimer() + for b.Loop() { + _, _ = proto.Marshal(data) + } } func BenchmarkEncode_VoiceResponse(b *testing.B) { -data := voiceResponseProto() -b.ResetTimer() -for b.Loop() { -_, _ = proto.Marshal(data) -} + data := voiceResponseProto() + b.ResetTimer() + for b.Loop() { + _, _ = proto.Marshal(data) + } } func BenchmarkEncode_TTSChunk(b *testing.B) { -data := ttsChunkProto() -b.ResetTimer() -for b.Loop() { -_, _ = proto.Marshal(data) -} + data := ttsChunkProto() + b.ResetTimer() + for b.Loop() { + _, _ = proto.Marshal(data) + } } // ──────────────────────────────────────────────────────────────────────────── @@ -108,30 +108,30 @@ _, _ = proto.Marshal(data) // ──────────────────────────────────────────────────────────────────────────── func BenchmarkDecode_ChatRequest(b *testing.B) { -encoded, _ := proto.Marshal(chatRequestProto()) -b.ResetTimer() -for b.Loop() { -var m pb.ChatRequest -_ = proto.Unmarshal(encoded, &m) -} + encoded, _ := proto.Marshal(chatRequestProto()) + b.ResetTimer() + for b.Loop() { + var m pb.ChatRequest + _ = proto.Unmarshal(encoded, &m) + } } func BenchmarkDecode_VoiceResponse(b *testing.B) { -encoded, _ := proto.Marshal(voiceResponseProto()) -b.ResetTimer() -for b.Loop() { -var m pb.VoiceResponse -_ = proto.Unmarshal(encoded, &m) -} + encoded, _ := proto.Marshal(voiceResponseProto()) + b.ResetTimer() + for b.Loop() { + var m pb.VoiceResponse + _ = proto.Unmarshal(encoded, &m) + } } func BenchmarkDecode_TTSChunk(b *testing.B) { -encoded, _ := proto.Marshal(ttsChunkProto()) -b.ResetTimer() -for b.Loop() { -var m pb.TTSAudioChunk -_ = proto.Unmarshal(encoded, &m) -} + encoded, _ := proto.Marshal(ttsChunkProto()) + b.ResetTimer() + for b.Loop() { + var m pb.TTSAudioChunk + _ = proto.Unmarshal(encoded, &m) + } } // ──────────────────────────────────────────────────────────────────────────── @@ -139,13 +139,13 @@ _ = proto.Unmarshal(encoded, &m) // ──────────────────────────────────────────────────────────────────────────── func BenchmarkRoundtrip_ChatRequest(b *testing.B) { -data := chatRequestProto() -b.ResetTimer() -for b.Loop() { -enc, _ := proto.Marshal(data) -var dec pb.ChatRequest -_ = proto.Unmarshal(enc, &dec) -} + data := chatRequestProto() + b.ResetTimer() + for b.Loop() { + enc, _ := proto.Marshal(data) + var dec pb.ChatRequest + _ = proto.Unmarshal(enc, &dec) + } } // ──────────────────────────────────────────────────────────────────────────── @@ -153,160 +153,160 @@ _ = proto.Unmarshal(enc, &dec) // ──────────────────────────────────────────────────────────────────────────── func TestRoundtrip_ChatRequest(t *testing.T) { -orig := chatRequestProto() -data, err := proto.Marshal(orig) -if err != nil { -t.Fatal(err) -} -var dec pb.ChatRequest -if err := proto.Unmarshal(data, &dec); err != nil { -t.Fatal(err) -} -if dec.GetRequestId() != orig.GetRequestId() { -t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId()) -} -if dec.GetMessage() != orig.GetMessage() { -t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage()) -} -if dec.GetTopK() != orig.GetTopK() { -t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK()) -} -if dec.GetPremium() != orig.GetPremium() { -t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium()) -} -if EffectiveQuery(&dec) != orig.GetMessage() { -t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage()) -} + orig := chatRequestProto() + data, err := proto.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec pb.ChatRequest + if err := proto.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.GetRequestId() != orig.GetRequestId() { + t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId()) + } + if dec.GetMessage() != orig.GetMessage() { + t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage()) + } + if dec.GetTopK() != orig.GetTopK() { + t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK()) + } + if dec.GetPremium() != orig.GetPremium() { + t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium()) + } + if EffectiveQuery(&dec) != orig.GetMessage() { + t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage()) + } } func TestRoundtrip_VoiceResponse(t *testing.T) { -orig := voiceResponseProto() -data, err := proto.Marshal(orig) -if err != nil { -t.Fatal(err) -} -var dec pb.VoiceResponse -if err := proto.Unmarshal(data, &dec); err != nil { -t.Fatal(err) -} -if dec.GetRequestId() != orig.GetRequestId() { -t.Errorf("RequestId mismatch") -} -if len(dec.GetAudio()) != len(orig.GetAudio()) { -t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio())) -} -if dec.GetTranscription() != orig.GetTranscription() { -t.Errorf("Transcription mismatch") -} + orig := voiceResponseProto() + data, err := proto.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec pb.VoiceResponse + if err := proto.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.GetRequestId() != orig.GetRequestId() { + t.Errorf("RequestId mismatch") + } + if len(dec.GetAudio()) != len(orig.GetAudio()) { + t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio())) + } + if dec.GetTranscription() != orig.GetTranscription() { + t.Errorf("Transcription mismatch") + } } func TestRoundtrip_TTSAudioChunk(t *testing.T) { -orig := ttsChunkProto() -data, err := proto.Marshal(orig) -if err != nil { -t.Fatal(err) -} -var dec pb.TTSAudioChunk -if err := proto.Unmarshal(data, &dec); err != nil { -t.Fatal(err) -} -if dec.GetSessionId() != orig.GetSessionId() { -t.Errorf("SessionId mismatch") -} -if dec.GetChunkIndex() != orig.GetChunkIndex() { -t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex()) -} -if len(dec.GetAudio()) != len(orig.GetAudio()) { -t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio())) -} -if dec.GetSampleRate() != orig.GetSampleRate() { -t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate()) -} + orig := ttsChunkProto() + data, err := proto.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec pb.TTSAudioChunk + if err := proto.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.GetSessionId() != orig.GetSessionId() { + t.Errorf("SessionId mismatch") + } + if dec.GetChunkIndex() != orig.GetChunkIndex() { + t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex()) + } + if len(dec.GetAudio()) != len(orig.GetAudio()) { + t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio())) + } + if dec.GetSampleRate() != orig.GetSampleRate() { + t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate()) + } } func TestRoundtrip_PipelineTrigger(t *testing.T) { -orig := &pb.PipelineTrigger{ -RequestId: "pip-001", -Pipeline: "document-ingestion", -Parameters: map[string]string{"source": "s3://bucket/data"}, -} -data, err := proto.Marshal(orig) -if err != nil { -t.Fatal(err) -} -var dec pb.PipelineTrigger -if err := proto.Unmarshal(data, &dec); err != nil { -t.Fatal(err) -} -if dec.GetPipeline() != orig.GetPipeline() { -t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline()) -} -if dec.GetParameters()["source"] != orig.GetParameters()["source"] { -t.Errorf("Parameters[source] mismatch") -} + orig := &pb.PipelineTrigger{ + RequestId: "pip-001", + Pipeline: "document-ingestion", + Parameters: map[string]string{"source": "s3://bucket/data"}, + } + data, err := proto.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec pb.PipelineTrigger + if err := proto.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.GetPipeline() != orig.GetPipeline() { + t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline()) + } + if dec.GetParameters()["source"] != orig.GetParameters()["source"] { + t.Errorf("Parameters[source] mismatch") + } } func TestRoundtrip_STTTranscription(t *testing.T) { -orig := &pb.STTTranscription{ -SessionId: "stt-001", -Transcript: "hello world", -Sequence: 5, -IsPartial: false, -IsFinal: true, -Timestamp: time.Now().Unix(), -SpeakerId: "speaker-1", -HasVoiceActivity: true, -State: "listening", -} -data, err := proto.Marshal(orig) -if err != nil { -t.Fatal(err) -} -var dec pb.STTTranscription -if err := proto.Unmarshal(data, &dec); err != nil { -t.Fatal(err) -} -if dec.GetTranscript() != orig.GetTranscript() { -t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript()) -} -if dec.GetIsFinal() != orig.GetIsFinal() { -t.Error("IsFinal mismatch") -} + orig := &pb.STTTranscription{ + SessionId: "stt-001", + Transcript: "hello world", + Sequence: 5, + IsPartial: false, + IsFinal: true, + Timestamp: time.Now().Unix(), + SpeakerId: "speaker-1", + HasVoiceActivity: true, + State: "listening", + } + data, err := proto.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec pb.STTTranscription + if err := proto.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if dec.GetTranscript() != orig.GetTranscript() { + t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript()) + } + if dec.GetIsFinal() != orig.GetIsFinal() { + t.Error("IsFinal mismatch") + } } func TestRoundtrip_ErrorResponse(t *testing.T) { -orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"} -data, err := proto.Marshal(orig) -if err != nil { -t.Fatal(err) -} -var dec pb.ErrorResponse -if err := proto.Unmarshal(data, &dec); err != nil { -t.Fatal(err) -} -if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" { -t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec) -} + orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"} + data, err := proto.Marshal(orig) + if err != nil { + t.Fatal(err) + } + var dec pb.ErrorResponse + if err := proto.Unmarshal(data, &dec); err != nil { + t.Fatal(err) + } + if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" { + t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec) + } } func TestEffectiveQuery_MessageSet(t *testing.T) { -req := &pb.ChatRequest{Message: "hello", Query: "world"} -if got := EffectiveQuery(req); got != "hello" { -t.Errorf("EffectiveQuery() = %q, want %q", got, "hello") -} + req := &pb.ChatRequest{Message: "hello", Query: "world"} + if got := EffectiveQuery(req); got != "hello" { + t.Errorf("EffectiveQuery() = %q, want %q", got, "hello") + } } func TestEffectiveQuery_FallbackToQuery(t *testing.T) { -req := &pb.ChatRequest{Query: "world"} -if got := EffectiveQuery(req); got != "world" { -t.Errorf("EffectiveQuery() = %q, want %q", got, "world") -} + req := &pb.ChatRequest{Query: "world"} + if got := EffectiveQuery(req); got != "world" { + t.Errorf("EffectiveQuery() = %q, want %q", got, "world") + } } func TestTimestamp(t *testing.T) { -ts := Timestamp() -now := time.Now().Unix() -if ts < now-1 || ts > now+1 { -t.Errorf("Timestamp() = %d, expected ~%d", ts, now) -} + ts := Timestamp() + now := time.Now().Unix() + if ts < now-1 || ts > now+1 { + t.Errorf("Timestamp() = %d, expected ~%d", ts, now) + } } diff --git a/messages/messages.go b/messages/messages.go index 9b3ad37..02b7371 100644 --- a/messages/messages.go +++ b/messages/messages.go @@ -8,9 +8,9 @@ package messages import ( -"time" + "time" -pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" + pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" ) // ════════════════════════════════════════════════════════════════════════════ @@ -57,13 +57,13 @@ type PipelineStatus = pb.PipelineStatus // EffectiveQuery returns Message or falls back to Query. func EffectiveQuery(c *ChatRequest) string { -if c.GetMessage() != "" { -return c.GetMessage() -} -return c.GetQuery() + if c.GetMessage() != "" { + return c.GetMessage() + } + return c.GetQuery() } // Timestamp returns the current Unix timestamp. func Timestamp() int64 { -return time.Now().Unix() + return time.Now().Unix() } diff --git a/natsutil/natsutil.go b/natsutil/natsutil.go index 2f75714..43526ca 100644 --- a/natsutil/natsutil.go +++ b/natsutil/natsutil.go @@ -2,69 +2,69 @@ package natsutil import ( -"fmt" -"log/slog" -"time" + "fmt" + "log/slog" + "time" -"github.com/nats-io/nats.go" -"google.golang.org/protobuf/proto" + "github.com/nats-io/nats.go" + "google.golang.org/protobuf/proto" ) // Client wraps a NATS connection with protobuf helpers. type Client struct { -nc *nats.Conn -js nats.JetStreamContext -subs []*nats.Subscription -url string -opts []nats.Option + nc *nats.Conn + js nats.JetStreamContext + subs []*nats.Subscription + url string + opts []nats.Option } // New creates a NATS client configured to connect to the given URL. // Optional NATS options (e.g. credentials) can be appended. func New(url string, opts ...nats.Option) *Client { -defaults := []nats.Option{ -nats.ReconnectWait(2 * time.Second), -nats.MaxReconnects(-1), -nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { -slog.Warn("NATS disconnected", "error", err) -}), -nats.ReconnectHandler(func(_ *nats.Conn) { -slog.Info("NATS reconnected") -}), -} -return &Client{ -url: url, -opts: append(defaults, opts...), -} + defaults := []nats.Option{ + nats.ReconnectWait(2 * time.Second), + nats.MaxReconnects(-1), + nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { + slog.Warn("NATS disconnected", "error", err) + }), + nats.ReconnectHandler(func(_ *nats.Conn) { + slog.Info("NATS reconnected") + }), + } + return &Client{ + url: url, + opts: append(defaults, opts...), + } } // Connect establishes the NATS connection and JetStream context. func (c *Client) Connect() error { -nc, err := nats.Connect(c.url, c.opts...) -if err != nil { -return fmt.Errorf("nats connect: %w", err) -} -js, err := nc.JetStream() -if err != nil { -nc.Close() -return fmt.Errorf("jetstream: %w", err) -} -c.nc = nc -c.js = js -slog.Info("connected to NATS", "url", c.url) -return nil + nc, err := nats.Connect(c.url, c.opts...) + if err != nil { + return fmt.Errorf("nats connect: %w", err) + } + js, err := nc.JetStream() + if err != nil { + nc.Close() + return fmt.Errorf("jetstream: %w", err) + } + c.nc = nc + c.js = js + slog.Info("connected to NATS", "url", c.url) + return nil } // Close drains subscriptions and closes the connection. func (c *Client) Close() { -if c.nc == nil { -return -} -for _, sub := range c.subs { -_ = sub.Drain() -} -c.nc.Close() -slog.Info("NATS connection closed") + if c.nc == nil { + return + } + for _, sub := range c.subs { + _ = sub.Drain() + } + c.nc.Close() + slog.Info("NATS connection closed") } // Conn returns the underlying *nats.Conn. @@ -75,56 +75,56 @@ func (c *Client) JS() nats.JetStreamContext { return c.js } // IsConnected returns true if the NATS connection is active. func (c *Client) IsConnected() bool { -return c.nc != nil && c.nc.IsConnected() + return c.nc != nil && c.nc.IsConnected() } // Subscribe subscribes to a subject with an optional queue group. // The handler receives the raw *nats.Msg. func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error { -var sub *nats.Subscription -var err error -if queue != "" { -sub, err = c.nc.QueueSubscribe(subject, queue, handler) -slog.Info("subscribed", "subject", subject, "queue", queue) -} else { -sub, err = c.nc.Subscribe(subject, handler) -slog.Info("subscribed", "subject", subject) -} -if err != nil { -return fmt.Errorf("subscribe %s: %w", subject, err) -} -c.subs = append(c.subs, sub) -return nil + var sub *nats.Subscription + var err error + if queue != "" { + sub, err = c.nc.QueueSubscribe(subject, queue, handler) + slog.Info("subscribed", "subject", subject, "queue", queue) + } else { + sub, err = c.nc.Subscribe(subject, handler) + slog.Info("subscribed", "subject", subject) + } + if err != nil { + return fmt.Errorf("subscribe %s: %w", subject, err) + } + c.subs = append(c.subs, sub) + return nil } // Publish encodes data as protobuf and publishes to the subject. func (c *Client) Publish(subject string, data proto.Message) error { -payload, err := proto.Marshal(data) -if err != nil { -return fmt.Errorf("proto marshal: %w", err) -} -return c.nc.Publish(subject, payload) + payload, err := proto.Marshal(data) + if err != nil { + return fmt.Errorf("proto marshal: %w", err) + } + return c.nc.Publish(subject, payload) } // PublishRaw publishes pre-encoded bytes to the subject. func (c *Client) PublishRaw(subject string, data []byte) error { -return c.nc.Publish(subject, data) + return c.nc.Publish(subject, data) } // Request sends a protobuf-encoded request and decodes the response into result. func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error { -payload, err := proto.Marshal(data) -if err != nil { -return fmt.Errorf("proto marshal: %w", err) -} -msg, err := c.nc.Request(subject, payload, timeout) -if err != nil { -return fmt.Errorf("nats request: %w", err) -} -return proto.Unmarshal(msg.Data, result) + payload, err := proto.Marshal(data) + if err != nil { + return fmt.Errorf("proto marshal: %w", err) + } + msg, err := c.nc.Request(subject, payload, timeout) + if err != nil { + return fmt.Errorf("nats request: %w", err) + } + return proto.Unmarshal(msg.Data, result) } // Decode unmarshals protobuf bytes into dest. func Decode(data []byte, dest proto.Message) error { -return proto.Unmarshal(data, dest) + return proto.Unmarshal(data, dest) }