style: gofmt + fix errcheck lint warning
All checks were successful
CI / Test (push) Successful in 3m2s
CI / Lint (push) Successful in 3m7s
CI / Release (push) Successful in 1m55s
CI / Notify Downstream (stt-module) (push) Successful in 1s
CI / Notify Downstream (voice-assistant) (push) Successful in 1s
CI / Notify (push) Successful in 2s
CI / Notify Downstream (chat-handler) (push) Successful in 1s
CI / Notify Downstream (pipeline-bridge) (push) Successful in 1s
CI / Notify Downstream (tts-module) (push) Successful in 1s

This commit is contained in:
2026-02-21 15:35:37 -05:00
parent 13ef1df109
commit f1dd96a42b
8 changed files with 875 additions and 875 deletions

View File

@@ -563,7 +563,7 @@ func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
var order []int var order []int
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) { result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
var idx int var idx int
fmt.Sscanf(tok, "t%d ", &idx) _, _ = fmt.Sscanf(tok, "t%d ", &idx)
mu.Lock() mu.Lock()
order = append(order, idx) order = append(order, idx)
mu.Unlock() mu.Unlock()

View File

@@ -3,16 +3,16 @@
package config package config
import ( import (
"context" "context"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
) )
// Settings holds base configuration for all handler services. // Settings holds base configuration for all handler services.
@@ -20,249 +20,249 @@ import (
// updated at runtime via WatchSecrets(). All other fields are immutable // updated at runtime via WatchSecrets(). All other fields are immutable
// after Load() returns. // after Load() returns.
type Settings struct { type Settings struct {
// Service identification (immutable) // Service identification (immutable)
ServiceName string ServiceName string
ServiceVersion string ServiceVersion string
ServiceNamespace string ServiceNamespace string
DeploymentEnv string DeploymentEnv string
// NATS configuration (immutable) // NATS configuration (immutable)
NATSURL string NATSURL string
NATSUser string NATSUser string
NATSPassword string NATSPassword string
NATSQueueGroup string NATSQueueGroup string
// Redis/Valkey configuration (immutable) // Redis/Valkey configuration (immutable)
RedisURL string RedisURL string
RedisPassword string RedisPassword string
// Milvus configuration (immutable) // Milvus configuration (immutable)
MilvusHost string MilvusHost string
MilvusPort int MilvusPort int
MilvusCollection string MilvusCollection string
// OpenTelemetry configuration (immutable) // OpenTelemetry configuration (immutable)
OTELEnabled bool OTELEnabled bool
OTELEndpoint string OTELEndpoint string
OTELUseHTTP bool OTELUseHTTP bool
// HyperDX configuration (immutable) // HyperDX configuration (immutable)
HyperDXEnabled bool HyperDXEnabled bool
HyperDXAPIKey string HyperDXAPIKey string
HyperDXEndpoint string HyperDXEndpoint string
// MLflow configuration (immutable) // MLflow configuration (immutable)
MLflowTrackingURI string MLflowTrackingURI string
MLflowExperimentName string MLflowExperimentName string
MLflowEnabled bool MLflowEnabled bool
// Health check configuration (immutable) // Health check configuration (immutable)
HealthPort int HealthPort int
HealthPath string HealthPath string
ReadyPath string ReadyPath string
// Timeouts (immutable) // Timeouts (immutable)
HTTPTimeout time.Duration HTTPTimeout time.Duration
NATSTimeout time.Duration NATSTimeout time.Duration
// Hot-reloadable fields — access via getter methods. // Hot-reloadable fields — access via getter methods.
mu sync.RWMutex mu sync.RWMutex
embeddingsURL string embeddingsURL string
rerankerURL string rerankerURL string
llmURL string llmURL string
ttsURL string ttsURL string
sttURL string sttURL string
// Secrets path for file-based hot reload (Kubernetes secret mounts) // Secrets path for file-based hot reload (Kubernetes secret mounts)
SecretsPath string SecretsPath string
} }
// Load creates a Settings populated from environment variables with defaults. // Load creates a Settings populated from environment variables with defaults.
func Load() *Settings { func Load() *Settings {
return &Settings{ return &Settings{
ServiceName: getEnv("SERVICE_NAME", "handler"), ServiceName: getEnv("SERVICE_NAME", "handler"),
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"), ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"), ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"), DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"), NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
NATSUser: getEnv("NATS_USER", ""), NATSUser: getEnv("NATS_USER", ""),
NATSPassword: getEnv("NATS_PASSWORD", ""), NATSPassword: getEnv("NATS_PASSWORD", ""),
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""), NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"), RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
RedisPassword: getEnv("REDIS_PASSWORD", ""), RedisPassword: getEnv("REDIS_PASSWORD", ""),
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"), MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
MilvusPort: getEnvInt("MILVUS_PORT", 19530), MilvusPort: getEnvInt("MILVUS_PORT", 19530),
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"), MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"), embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"), rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"), llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"), ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"), sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
OTELEnabled: getEnvBool("OTEL_ENABLED", true), OTELEnabled: getEnvBool("OTEL_ENABLED", true),
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"), OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false), OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false), HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""), HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"), HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"), MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""), MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true), MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
HealthPort: getEnvInt("HEALTH_PORT", 8080), HealthPort: getEnvInt("HEALTH_PORT", 8080),
HealthPath: getEnv("HEALTH_PATH", "/health"), HealthPath: getEnv("HEALTH_PATH", "/health"),
ReadyPath: getEnv("READY_PATH", "/ready"), ReadyPath: getEnv("READY_PATH", "/ready"),
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second), HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second), NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
SecretsPath: getEnv("SECRETS_PATH", ""), SecretsPath: getEnv("SECRETS_PATH", ""),
} }
} }
// EmbeddingsURL returns the current embeddings service URL (thread-safe). // EmbeddingsURL returns the current embeddings service URL (thread-safe).
func (s *Settings) EmbeddingsURL() string { func (s *Settings) EmbeddingsURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.embeddingsURL return s.embeddingsURL
} }
// RerankerURL returns the current reranker service URL (thread-safe). // RerankerURL returns the current reranker service URL (thread-safe).
func (s *Settings) RerankerURL() string { func (s *Settings) RerankerURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.rerankerURL return s.rerankerURL
} }
// LLMURL returns the current LLM service URL (thread-safe). // LLMURL returns the current LLM service URL (thread-safe).
func (s *Settings) LLMURL() string { func (s *Settings) LLMURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.llmURL return s.llmURL
} }
// TTSURL returns the current TTS service URL (thread-safe). // TTSURL returns the current TTS service URL (thread-safe).
func (s *Settings) TTSURL() string { func (s *Settings) TTSURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.ttsURL return s.ttsURL
} }
// STTURL returns the current STT service URL (thread-safe). // STTURL returns the current STT service URL (thread-safe).
func (s *Settings) STTURL() string { func (s *Settings) STTURL() string {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return s.sttURL return s.sttURL
} }
// WatchSecrets watches the SecretsPath directory for changes and reloads // WatchSecrets watches the SecretsPath directory for changes and reloads
// hot-reloadable fields. Blocks until ctx is cancelled. // hot-reloadable fields. Blocks until ctx is cancelled.
func (s *Settings) WatchSecrets(ctx context.Context) { func (s *Settings) WatchSecrets(ctx context.Context) {
if s.SecretsPath == "" { if s.SecretsPath == "" {
return return
} }
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
slog.Error("config: failed to create fsnotify watcher", "error", err) slog.Error("config: failed to create fsnotify watcher", "error", err)
return return
} }
defer func() { _ = watcher.Close() }() defer func() { _ = watcher.Close() }()
if err := watcher.Add(s.SecretsPath); err != nil { if err := watcher.Add(s.SecretsPath); err != nil {
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath) slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
return return
} }
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath) slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
for { for {
select { select {
case event, ok := <-watcher.Events: case event, ok := <-watcher.Events:
if !ok { if !ok {
return return
} }
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) { if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
s.reloadFromSecrets() s.reloadFromSecrets()
} }
case err, ok := <-watcher.Errors: case err, ok := <-watcher.Errors:
if !ok { if !ok {
return return
} }
slog.Error("config: fsnotify error", "error", err) slog.Error("config: fsnotify error", "error", err)
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
// reloadFromSecrets reads hot-reloadable values from the secrets directory. // reloadFromSecrets reads hot-reloadable values from the secrets directory.
func (s *Settings) reloadFromSecrets() { func (s *Settings) reloadFromSecrets() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
updated := 0 updated := 0
reload := func(filename string, target *string) { reload := func(filename string, target *string) {
path := filepath.Join(s.SecretsPath, filename) path := filepath.Join(s.SecretsPath, filename)
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return return
} }
val := strings.TrimSpace(string(data)) val := strings.TrimSpace(string(data))
if val != "" && val != *target { if val != "" && val != *target {
*target = val *target = val
updated++ updated++
slog.Info("config: reloaded secret", "key", filename) slog.Info("config: reloaded secret", "key", filename)
} }
} }
reload("embeddings-url", &s.embeddingsURL) reload("embeddings-url", &s.embeddingsURL)
reload("reranker-url", &s.rerankerURL) reload("reranker-url", &s.rerankerURL)
reload("llm-url", &s.llmURL) reload("llm-url", &s.llmURL)
reload("tts-url", &s.ttsURL) reload("tts-url", &s.ttsURL)
reload("stt-url", &s.sttURL) reload("stt-url", &s.sttURL)
if updated > 0 { if updated > 0 {
slog.Info("config: secrets reloaded", "updated", updated) slog.Info("config: secrets reloaded", "updated", updated)
} }
} }
func getEnv(key, fallback string) string { func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v
} }
return fallback return fallback
} }
func getEnvInt(key string, fallback int) int { func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil { if i, err := strconv.Atoi(v); err == nil {
return i return i
} }
} }
return fallback return fallback
} }
func getEnvBool(key string, fallback bool) bool { func getEnvBool(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil { if b, err := strconv.ParseBool(v); err == nil {
return b return b
} }
} }
return fallback return fallback
} }
func getEnvDuration(key string, fallback time.Duration) time.Duration { func getEnvDuration(key string, fallback time.Duration) time.Duration {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil { if f, err := strconv.ParseFloat(v, 64); err == nil {
return time.Duration(f * float64(time.Second)) return time.Duration(f * float64(time.Second))
} }
} }
return fallback return fallback
} }

View File

@@ -1,123 +1,123 @@
package config package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
) )
func TestLoadDefaults(t *testing.T) { func TestLoadDefaults(t *testing.T) {
s := Load() s := Load()
if s.ServiceName != "handler" { if s.ServiceName != "handler" {
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName) t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
} }
if s.HealthPort != 8080 { if s.HealthPort != 8080 {
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort) t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
} }
if s.HTTPTimeout != 60*time.Second { if s.HTTPTimeout != 60*time.Second {
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout) t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
} }
} }
func TestLoadFromEnv(t *testing.T) { func TestLoadFromEnv(t *testing.T) {
t.Setenv("SERVICE_NAME", "test-svc") t.Setenv("SERVICE_NAME", "test-svc")
t.Setenv("HEALTH_PORT", "9090") t.Setenv("HEALTH_PORT", "9090")
t.Setenv("OTEL_ENABLED", "false") t.Setenv("OTEL_ENABLED", "false")
s := Load() s := Load()
if s.ServiceName != "test-svc" { if s.ServiceName != "test-svc" {
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName) t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
} }
if s.HealthPort != 9090 { if s.HealthPort != 9090 {
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort) t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
} }
if s.OTELEnabled { if s.OTELEnabled {
t.Error("expected OTELEnabled false") t.Error("expected OTELEnabled false")
} }
} }
func TestURLGetters(t *testing.T) { func TestURLGetters(t *testing.T) {
s := Load() s := Load()
if s.EmbeddingsURL() == "" { if s.EmbeddingsURL() == "" {
t.Error("EmbeddingsURL should have a default") t.Error("EmbeddingsURL should have a default")
} }
if s.RerankerURL() == "" { if s.RerankerURL() == "" {
t.Error("RerankerURL should have a default") t.Error("RerankerURL should have a default")
} }
if s.LLMURL() == "" { if s.LLMURL() == "" {
t.Error("LLMURL should have a default") t.Error("LLMURL should have a default")
} }
if s.TTSURL() == "" { if s.TTSURL() == "" {
t.Error("TTSURL should have a default") t.Error("TTSURL should have a default")
} }
if s.STTURL() == "" { if s.STTURL() == "" {
t.Error("STTURL should have a default") t.Error("STTURL should have a default")
} }
} }
func TestURLGettersFromEnv(t *testing.T) { func TestURLGettersFromEnv(t *testing.T) {
t.Setenv("EMBEDDINGS_URL", "http://embed:8000") t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
t.Setenv("LLM_URL", "http://llm:9000") t.Setenv("LLM_URL", "http://llm:9000")
s := Load() s := Load()
if s.EmbeddingsURL() != "http://embed:8000" { if s.EmbeddingsURL() != "http://embed:8000" {
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL()) t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
} }
if s.LLMURL() != "http://llm:9000" { if s.LLMURL() != "http://llm:9000" {
t.Errorf("expected custom LLMURL, got %q", s.LLMURL()) t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
} }
} }
func TestReloadFromSecrets(t *testing.T) { func TestReloadFromSecrets(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
// Write initial secret files // Write initial secret files
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000") writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
writeSecret(t, dir, "llm-url", "http://old-llm:9000") writeSecret(t, dir, "llm-url", "http://old-llm:9000")
s := Load() s := Load()
s.SecretsPath = dir s.SecretsPath = dir
s.reloadFromSecrets() s.reloadFromSecrets()
if s.EmbeddingsURL() != "http://old-embed:8000" { if s.EmbeddingsURL() != "http://old-embed:8000" {
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL()) t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
} }
if s.LLMURL() != "http://old-llm:9000" { if s.LLMURL() != "http://old-llm:9000" {
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL()) t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
} }
// Simulate secret update // Simulate secret update
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000") writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
s.reloadFromSecrets() s.reloadFromSecrets()
if s.EmbeddingsURL() != "http://new-embed:8000" { if s.EmbeddingsURL() != "http://new-embed:8000" {
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL()) t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
} }
// LLM should remain unchanged // LLM should remain unchanged
if s.LLMURL() != "http://old-llm:9000" { if s.LLMURL() != "http://old-llm:9000" {
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL()) t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
} }
} }
func TestReloadFromSecretsNoPath(t *testing.T) { func TestReloadFromSecretsNoPath(t *testing.T) {
s := Load() s := Load()
s.SecretsPath = "" s.SecretsPath = ""
// Should not panic // Should not panic
s.reloadFromSecrets() s.reloadFromSecrets()
} }
func TestGetEnvDuration(t *testing.T) { func TestGetEnvDuration(t *testing.T) {
t.Setenv("TEST_DUR", "30") t.Setenv("TEST_DUR", "30")
d := getEnvDuration("TEST_DUR", 10*time.Second) d := getEnvDuration("TEST_DUR", 10*time.Second)
if d != 30*time.Second { if d != 30*time.Second {
t.Errorf("expected 30s, got %v", d) t.Errorf("expected 30s, got %v", d)
} }
} }
func writeSecret(t *testing.T, dir, name, value string) { func writeSecret(t *testing.T, dir, name, value string) {
t.Helper() t.Helper()
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil { if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@@ -2,21 +2,21 @@
package handler package handler
import ( import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
"git.daviestechlabs.io/daviestechlabs/handler-base/health" "git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry" "git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
) )
// TypedMessageHandler processes the raw NATS message. // TypedMessageHandler processes the raw NATS message.
@@ -32,36 +32,36 @@ type TeardownFunc func(ctx context.Context) error
// Handler is the base service runner that wires NATS, health, and telemetry. // Handler is the base service runner that wires NATS, health, and telemetry.
type Handler struct { type Handler struct {
Settings *config.Settings Settings *config.Settings
NATS *natsutil.Client NATS *natsutil.Client
Telemetry *telemetry.Provider Telemetry *telemetry.Provider
Subject string Subject string
QueueGroup string QueueGroup string
onSetup SetupFunc onSetup SetupFunc
onTeardown TeardownFunc onTeardown TeardownFunc
onTypedMessage TypedMessageHandler onTypedMessage TypedMessageHandler
running bool running bool
} }
// New creates a Handler for the given NATS subject. // New creates a Handler for the given NATS subject.
func New(subject string, settings *config.Settings) *Handler { func New(subject string, settings *config.Settings) *Handler {
if settings == nil { if settings == nil {
settings = config.Load() settings = config.Load()
} }
queueGroup := settings.NATSQueueGroup queueGroup := settings.NATSQueueGroup
natsOpts := []nats.Option{} natsOpts := []nats.Option{}
if settings.NATSUser != "" && settings.NATSPassword != "" { if settings.NATSUser != "" && settings.NATSPassword != "" {
natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword)) natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword))
} }
return &Handler{ return &Handler{
Settings: settings, Settings: settings,
Subject: subject, Subject: subject,
QueueGroup: queueGroup, QueueGroup: queueGroup,
NATS: natsutil.New(settings.NATSURL, natsOpts...), NATS: natsutil.New(settings.NATSURL, natsOpts...),
} }
} }
// OnSetup registers the setup callback. // OnSetup registers the setup callback.
@@ -75,101 +75,101 @@ func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT. // Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
func (h *Handler) Run() error { func (h *Handler) Run() error {
// Structured logging // Structured logging
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))) slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion) slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
// Telemetry // Telemetry
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{ tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
ServiceName: h.Settings.ServiceName, ServiceName: h.Settings.ServiceName,
ServiceVersion: h.Settings.ServiceVersion, ServiceVersion: h.Settings.ServiceVersion,
ServiceNamespace: h.Settings.ServiceNamespace, ServiceNamespace: h.Settings.ServiceNamespace,
DeploymentEnv: h.Settings.DeploymentEnv, DeploymentEnv: h.Settings.DeploymentEnv,
Enabled: h.Settings.OTELEnabled, Enabled: h.Settings.OTELEnabled,
Endpoint: h.Settings.OTELEndpoint, Endpoint: h.Settings.OTELEndpoint,
}) })
if err != nil { if err != nil {
return fmt.Errorf("telemetry setup: %w", err) return fmt.Errorf("telemetry setup: %w", err)
} }
defer shutdown(ctx) defer shutdown(ctx)
h.Telemetry = tp h.Telemetry = tp
// Health server // Health server
healthSrv := health.New( healthSrv := health.New(
h.Settings.HealthPort, h.Settings.HealthPort,
h.Settings.HealthPath, h.Settings.HealthPath,
h.Settings.ReadyPath, h.Settings.ReadyPath,
func() bool { return h.running && h.NATS.IsConnected() }, func() bool { return h.running && h.NATS.IsConnected() },
) )
healthSrv.Start() healthSrv.Start()
defer healthSrv.Stop(ctx) defer healthSrv.Stop(ctx)
// Connect to NATS // Connect to NATS
if err := h.NATS.Connect(); err != nil { if err := h.NATS.Connect(); err != nil {
return fmt.Errorf("nats: %w", err) return fmt.Errorf("nats: %w", err)
} }
defer h.NATS.Close() defer h.NATS.Close()
// User setup // User setup
if h.onSetup != nil { if h.onSetup != nil {
slog.Info("running service setup") slog.Info("running service setup")
if err := h.onSetup(ctx); err != nil { if err := h.onSetup(ctx); err != nil {
return fmt.Errorf("setup: %w", err) return fmt.Errorf("setup: %w", err)
} }
} }
// Subscribe // Subscribe
if h.onTypedMessage == nil { if h.onTypedMessage == nil {
return fmt.Errorf("no message handler registered") return fmt.Errorf("no message handler registered")
} }
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil { if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
return fmt.Errorf("subscribe: %w", err) return fmt.Errorf("subscribe: %w", err)
} }
h.running = true h.running = true
slog.Info("handler ready", "subject", h.Subject) slog.Info("handler ready", "subject", h.Subject)
// Wait for shutdown signal // Wait for shutdown signal
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
<-sigCh <-sigCh
slog.Info("shutting down") slog.Info("shutting down")
h.running = false h.running = false
// Teardown // Teardown
if h.onTeardown != nil { if h.onTeardown != nil {
if err := h.onTeardown(ctx); err != nil { if err := h.onTeardown(ctx); err != nil {
slog.Warn("teardown error", "error", err) slog.Warn("teardown error", "error", err)
} }
} }
slog.Info("shutdown complete") slog.Info("shutdown complete")
return nil return nil
} }
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback. // wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler { func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
return func(msg *nats.Msg) { return func(msg *nats.Msg) {
response, err := h.onTypedMessage(ctx, msg) response, err := h.onTypedMessage(ctx, msg)
if err != nil { if err != nil {
slog.Error("handler error", "subject", msg.Subject, "error", err) slog.Error("handler error", "subject", msg.Subject, "error", err)
if msg.Reply != "" { if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{ _ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{
Error: true, Error: true,
Message: err.Error(), Message: err.Error(),
Type: fmt.Sprintf("%T", err), Type: fmt.Sprintf("%T", err),
}) })
} }
return return
} }
if response != nil && msg.Reply != "" { if response != nil && msg.Reply != "" {
if err := h.NATS.Publish(msg.Reply, response); err != nil { if err := h.NATS.Publish(msg.Reply, response); err != nil {
slog.Error("failed to publish reply", "error", err) slog.Error("failed to publish reply", "error", err)
} }
} }
} }
} }

View File

@@ -1,15 +1,15 @@
package handler package handler
import ( import (
"context" "context"
"testing" "testing"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"git.daviestechlabs.io/daviestechlabs/handler-base/config" "git.daviestechlabs.io/daviestechlabs/handler-base/config"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil" "git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -17,75 +17,75 @@ pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestNewHandler(t *testing.T) { func TestNewHandler(t *testing.T) {
cfg := config.Load() cfg := config.Load()
cfg.ServiceName = "test-handler" cfg.ServiceName = "test-handler"
cfg.NATSQueueGroup = "test-group" cfg.NATSQueueGroup = "test-group"
h := New("ai.test.subject", cfg) h := New("ai.test.subject", cfg)
if h.Subject != "ai.test.subject" { if h.Subject != "ai.test.subject" {
t.Errorf("Subject = %q", h.Subject) t.Errorf("Subject = %q", h.Subject)
} }
if h.QueueGroup != "test-group" { if h.QueueGroup != "test-group" {
t.Errorf("QueueGroup = %q", h.QueueGroup) t.Errorf("QueueGroup = %q", h.QueueGroup)
} }
if h.Settings.ServiceName != "test-handler" { if h.Settings.ServiceName != "test-handler" {
t.Errorf("ServiceName = %q", h.Settings.ServiceName) t.Errorf("ServiceName = %q", h.Settings.ServiceName)
} }
} }
func TestNewHandlerNilSettings(t *testing.T) { func TestNewHandlerNilSettings(t *testing.T) {
h := New("ai.test", nil) h := New("ai.test", nil)
if h.Settings == nil { if h.Settings == nil {
t.Fatal("Settings should be loaded automatically") t.Fatal("Settings should be loaded automatically")
} }
if h.Settings.ServiceName != "handler" { if h.Settings.ServiceName != "handler" {
t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName) t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName)
} }
} }
func TestCallbackRegistration(t *testing.T) { func TestCallbackRegistration(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
setupCalled := false setupCalled := false
h.OnSetup(func(ctx context.Context) error { h.OnSetup(func(ctx context.Context) error {
setupCalled = true setupCalled = true
return nil return nil
}) })
teardownCalled := false teardownCalled := false
h.OnTeardown(func(ctx context.Context) error { h.OnTeardown(func(ctx context.Context) error {
teardownCalled = true teardownCalled = true
return nil return nil
}) })
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, nil return nil, nil
}) })
if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil { if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil {
t.Error("callbacks should not be nil after registration") t.Error("callbacks should not be nil after registration")
} }
// Verify setup/teardown work when called directly. // Verify setup/teardown work when called directly.
_ = h.onSetup(context.Background()) _ = h.onSetup(context.Background())
_ = h.onTeardown(context.Background()) _ = h.onTeardown(context.Background())
if !setupCalled || !teardownCalled { if !setupCalled || !teardownCalled {
t.Error("callbacks should have been invoked") t.Error("callbacks should have been invoked")
} }
} }
func TestTypedMessageRegistration(t *testing.T) { func TestTypedMessageRegistration(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return &pb.ChatResponse{Response: "ok"}, nil return &pb.ChatResponse{Response: "ok"}, nil
}) })
if h.onTypedMessage == nil { if h.onTypedMessage == nil {
t.Error("onTypedMessage should not be nil after registration") t.Error("onTypedMessage should not be nil after registration")
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -93,101 +93,101 @@ t.Error("onTypedMessage should not be nil after registration")
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestWrapHandler_ValidMessage(t *testing.T) { func TestWrapHandler_ValidMessage(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
var receivedReq pb.ChatRequest var receivedReq pb.ChatRequest
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
if err := natsutil.Decode(msg.Data, &receivedReq); err != nil { if err := natsutil.Decode(msg.Data, &receivedReq); err != nil {
return nil, err return nil, err
} }
return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil
}) })
// Encode a message the same way services would. // Encode a message the same way services would.
encoded, err := proto.Marshal(&pb.ChatRequest{ encoded, err := proto.Marshal(&pb.ChatRequest{
RequestId: "test-001", RequestId: "test-001",
Message: "hello", Message: "hello",
Premium: true, Premium: true,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Call wrapHandler directly without NATS. // Call wrapHandler directly without NATS.
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{ handler(&nats.Msg{
Subject: "ai.test.user.42.message", Subject: "ai.test.user.42.message",
Data: encoded, Data: encoded,
}) })
if receivedReq.GetRequestId() != "test-001" { if receivedReq.GetRequestId() != "test-001" {
t.Errorf("request_id = %v", receivedReq.GetRequestId()) t.Errorf("request_id = %v", receivedReq.GetRequestId())
} }
if receivedReq.GetPremium() != true { if receivedReq.GetPremium() != true {
t.Errorf("premium = %v", receivedReq.GetPremium()) t.Errorf("premium = %v", receivedReq.GetPremium())
} }
} }
func TestWrapHandler_InvalidMessage(t *testing.T) { func TestWrapHandler_InvalidMessage(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
handlerCalled := false handlerCalled := false
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
handlerCalled = true handlerCalled = true
var req pb.ChatRequest var req pb.ChatRequest
if err := natsutil.Decode(msg.Data, &req); err != nil { if err := natsutil.Decode(msg.Data, &req); err != nil {
return nil, err return nil, err
} }
return &pb.ChatResponse{}, nil return &pb.ChatResponse{}, nil
}) })
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{ handler(&nats.Msg{
Subject: "ai.test", Subject: "ai.test",
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf
}) })
// The handler IS called (wrapHandler doesn't pre-decode), but it should // The handler IS called (wrapHandler doesn't pre-decode), but it should
// return an error from Decode. Either way no panic. // return an error from Decode. Either way no panic.
_ = handlerCalled _ = handlerCalled
} }
func TestWrapHandler_HandlerError(t *testing.T) { func TestWrapHandler_HandlerError(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, context.DeadlineExceeded return nil, context.DeadlineExceeded
}) })
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
// Should not panic even when handler returns error. // Should not panic even when handler returns error.
handler(&nats.Msg{ handler(&nats.Msg{
Subject: "ai.test", Subject: "ai.test",
Data: encoded, Data: encoded,
}) })
} }
func TestWrapHandler_NilResponse(t *testing.T) { func TestWrapHandler_NilResponse(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, nil // fire-and-forget style return nil, nil // fire-and-forget style
}) })
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
// Should not panic with nil response and no reply subject. // Should not panic with nil response and no reply subject.
handler(&nats.Msg{ handler(&nats.Msg{
Subject: "ai.test", Subject: "ai.test",
Data: encoded, Data: encoded,
}) })
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -195,59 +195,59 @@ Data: encoded,
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestWrapHandler_Typed(t *testing.T) { func TestWrapHandler_Typed(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
var received pb.ChatRequest var received pb.ChatRequest
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
if err := natsutil.Decode(msg.Data, &received); err != nil { if err := natsutil.Decode(msg.Data, &received); err != nil {
return nil, err return nil, err
} }
return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil
}) })
encoded, _ := proto.Marshal(&pb.ChatRequest{ encoded, _ := proto.Marshal(&pb.ChatRequest{
RequestId: "typed-001", RequestId: "typed-001",
Message: "hello typed", Message: "hello typed",
}) })
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{Subject: "ai.test", Data: encoded}) handler(&nats.Msg{Subject: "ai.test", Data: encoded})
if received.GetRequestId() != "typed-001" { if received.GetRequestId() != "typed-001" {
t.Errorf("RequestId = %q", received.GetRequestId()) t.Errorf("RequestId = %q", received.GetRequestId())
} }
if received.GetMessage() != "hello typed" { if received.GetMessage() != "hello typed" {
t.Errorf("Message = %q", received.GetMessage()) t.Errorf("Message = %q", received.GetMessage())
} }
} }
func TestWrapHandler_TypedError(t *testing.T) { func TestWrapHandler_TypedError(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, context.DeadlineExceeded return nil, context.DeadlineExceeded
}) })
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
// Should not panic. // Should not panic.
handler(&nats.Msg{Subject: "ai.test", Data: encoded}) handler(&nats.Msg{Subject: "ai.test", Data: encoded})
} }
func TestWrapHandler_TypedNilResponse(t *testing.T) { func TestWrapHandler_TypedNilResponse(t *testing.T) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
return nil, nil return nil, nil
}) })
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"}) encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"})
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
handler(&nats.Msg{Subject: "ai.test", Data: encoded}) handler(&nats.Msg{Subject: "ai.test", Data: encoded})
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -255,25 +255,25 @@ handler(&nats.Msg{Subject: "ai.test", Data: encoded})
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkWrapHandler(b *testing.B) { func BenchmarkWrapHandler(b *testing.B) {
cfg := config.Load() cfg := config.Load()
h := New("ai.test", cfg) h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) { h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
var req pb.ChatRequest var req pb.ChatRequest
_ = natsutil.Decode(msg.Data, &req) _ = natsutil.Decode(msg.Data, &req)
return &pb.ChatResponse{Response: "ok"}, nil return &pb.ChatResponse{Response: "ok"}, nil
}) })
encoded, _ := proto.Marshal(&pb.ChatRequest{ encoded, _ := proto.Marshal(&pb.ChatRequest{
RequestId: "bench-001", RequestId: "bench-001",
Message: "What is the capital of France?", Message: "What is the capital of France?",
Premium: true, Premium: true,
TopK: 10, TopK: 10,
}) })
handler := h.wrapHandler(context.Background()) handler := h.wrapHandler(context.Background())
msg := &nats.Msg{Subject: "ai.test", Data: encoded} msg := &nats.Msg{Subject: "ai.test", Data: encoded}
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
handler(msg) handler(msg)
} }
} }

View File

@@ -2,17 +2,17 @@
// //
// Run with: // Run with:
// //
//go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt // go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
//# optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt // # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
package messages package messages
import ( import (
"testing" "testing"
"time" "time"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
) )
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -20,39 +20,39 @@ pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func chatRequestProto() *pb.ChatRequest { func chatRequestProto() *pb.ChatRequest {
return &pb.ChatRequest{ return &pb.ChatRequest{
RequestId: "req-abc-123", RequestId: "req-abc-123",
UserId: "user-42", UserId: "user-42",
Message: "What is the capital of France?", Message: "What is the capital of France?",
Premium: true, Premium: true,
EnableRag: true, EnableRag: true,
EnableReranker: true, EnableReranker: true,
TopK: 10, TopK: 10,
Collection: "documents", Collection: "documents",
SystemPrompt: "You are a helpful assistant.", SystemPrompt: "You are a helpful assistant.",
ResponseSubject: "ai.chat.response.req-abc-123", ResponseSubject: "ai.chat.response.req-abc-123",
} }
} }
func voiceResponseProto() *pb.VoiceResponse { func voiceResponseProto() *pb.VoiceResponse {
return &pb.VoiceResponse{ return &pb.VoiceResponse{
RequestId: "vr-001", RequestId: "vr-001",
Response: "The capital of France is Paris.", Response: "The capital of France is Paris.",
Audio: make([]byte, 16384), Audio: make([]byte, 16384),
Transcription: "What is the capital of France?", Transcription: "What is the capital of France?",
} }
} }
func ttsChunkProto() *pb.TTSAudioChunk { func ttsChunkProto() *pb.TTSAudioChunk {
return &pb.TTSAudioChunk{ return &pb.TTSAudioChunk{
SessionId: "tts-sess-99", SessionId: "tts-sess-99",
ChunkIndex: 3, ChunkIndex: 3,
TotalChunks: 12, TotalChunks: 12,
Audio: make([]byte, 32768), Audio: make([]byte, 32768),
IsLast: false, IsLast: false,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
SampleRate: 24000, SampleRate: 24000,
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -60,19 +60,19 @@ SampleRate: 24000,
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestWireSize(t *testing.T) { func TestWireSize(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
protoMsg proto.Message protoMsg proto.Message
}{ }{
{"ChatRequest", chatRequestProto()}, {"ChatRequest", chatRequestProto()},
{"VoiceResponse", voiceResponseProto()}, {"VoiceResponse", voiceResponseProto()},
{"TTSAudioChunk", ttsChunkProto()}, {"TTSAudioChunk", ttsChunkProto()},
} }
for _, tt := range tests { for _, tt := range tests {
protoBytes, _ := proto.Marshal(tt.protoMsg) protoBytes, _ := proto.Marshal(tt.protoMsg)
t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes)) t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -80,27 +80,27 @@ t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkEncode_ChatRequest(b *testing.B) { func BenchmarkEncode_ChatRequest(b *testing.B) {
data := chatRequestProto() data := chatRequestProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
_, _ = proto.Marshal(data) _, _ = proto.Marshal(data)
} }
} }
func BenchmarkEncode_VoiceResponse(b *testing.B) { func BenchmarkEncode_VoiceResponse(b *testing.B) {
data := voiceResponseProto() data := voiceResponseProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
_, _ = proto.Marshal(data) _, _ = proto.Marshal(data)
} }
} }
func BenchmarkEncode_TTSChunk(b *testing.B) { func BenchmarkEncode_TTSChunk(b *testing.B) {
data := ttsChunkProto() data := ttsChunkProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
_, _ = proto.Marshal(data) _, _ = proto.Marshal(data)
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -108,30 +108,30 @@ _, _ = proto.Marshal(data)
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkDecode_ChatRequest(b *testing.B) { func BenchmarkDecode_ChatRequest(b *testing.B) {
encoded, _ := proto.Marshal(chatRequestProto()) encoded, _ := proto.Marshal(chatRequestProto())
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
var m pb.ChatRequest var m pb.ChatRequest
_ = proto.Unmarshal(encoded, &m) _ = proto.Unmarshal(encoded, &m)
} }
} }
func BenchmarkDecode_VoiceResponse(b *testing.B) { func BenchmarkDecode_VoiceResponse(b *testing.B) {
encoded, _ := proto.Marshal(voiceResponseProto()) encoded, _ := proto.Marshal(voiceResponseProto())
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
var m pb.VoiceResponse var m pb.VoiceResponse
_ = proto.Unmarshal(encoded, &m) _ = proto.Unmarshal(encoded, &m)
} }
} }
func BenchmarkDecode_TTSChunk(b *testing.B) { func BenchmarkDecode_TTSChunk(b *testing.B) {
encoded, _ := proto.Marshal(ttsChunkProto()) encoded, _ := proto.Marshal(ttsChunkProto())
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
var m pb.TTSAudioChunk var m pb.TTSAudioChunk
_ = proto.Unmarshal(encoded, &m) _ = proto.Unmarshal(encoded, &m)
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -139,13 +139,13 @@ _ = proto.Unmarshal(encoded, &m)
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func BenchmarkRoundtrip_ChatRequest(b *testing.B) { func BenchmarkRoundtrip_ChatRequest(b *testing.B) {
data := chatRequestProto() data := chatRequestProto()
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
enc, _ := proto.Marshal(data) enc, _ := proto.Marshal(data)
var dec pb.ChatRequest var dec pb.ChatRequest
_ = proto.Unmarshal(enc, &dec) _ = proto.Unmarshal(enc, &dec)
} }
} }
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -153,160 +153,160 @@ _ = proto.Unmarshal(enc, &dec)
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
func TestRoundtrip_ChatRequest(t *testing.T) { func TestRoundtrip_ChatRequest(t *testing.T) {
orig := chatRequestProto() orig := chatRequestProto()
data, err := proto.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec pb.ChatRequest var dec pb.ChatRequest
if err := proto.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.GetRequestId() != orig.GetRequestId() { if dec.GetRequestId() != orig.GetRequestId() {
t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId()) t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId())
} }
if dec.GetMessage() != orig.GetMessage() { if dec.GetMessage() != orig.GetMessage() {
t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage()) t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage())
} }
if dec.GetTopK() != orig.GetTopK() { if dec.GetTopK() != orig.GetTopK() {
t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK()) t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK())
} }
if dec.GetPremium() != orig.GetPremium() { if dec.GetPremium() != orig.GetPremium() {
t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium()) t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium())
} }
if EffectiveQuery(&dec) != orig.GetMessage() { if EffectiveQuery(&dec) != orig.GetMessage() {
t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage()) t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage())
} }
} }
func TestRoundtrip_VoiceResponse(t *testing.T) { func TestRoundtrip_VoiceResponse(t *testing.T) {
orig := voiceResponseProto() orig := voiceResponseProto()
data, err := proto.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec pb.VoiceResponse var dec pb.VoiceResponse
if err := proto.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.GetRequestId() != orig.GetRequestId() { if dec.GetRequestId() != orig.GetRequestId() {
t.Errorf("RequestId mismatch") t.Errorf("RequestId mismatch")
} }
if len(dec.GetAudio()) != len(orig.GetAudio()) { if len(dec.GetAudio()) != len(orig.GetAudio()) {
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio())) t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
} }
if dec.GetTranscription() != orig.GetTranscription() { if dec.GetTranscription() != orig.GetTranscription() {
t.Errorf("Transcription mismatch") t.Errorf("Transcription mismatch")
} }
} }
func TestRoundtrip_TTSAudioChunk(t *testing.T) { func TestRoundtrip_TTSAudioChunk(t *testing.T) {
orig := ttsChunkProto() orig := ttsChunkProto()
data, err := proto.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec pb.TTSAudioChunk var dec pb.TTSAudioChunk
if err := proto.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.GetSessionId() != orig.GetSessionId() { if dec.GetSessionId() != orig.GetSessionId() {
t.Errorf("SessionId mismatch") t.Errorf("SessionId mismatch")
} }
if dec.GetChunkIndex() != orig.GetChunkIndex() { if dec.GetChunkIndex() != orig.GetChunkIndex() {
t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex()) t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex())
} }
if len(dec.GetAudio()) != len(orig.GetAudio()) { if len(dec.GetAudio()) != len(orig.GetAudio()) {
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio())) t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
} }
if dec.GetSampleRate() != orig.GetSampleRate() { if dec.GetSampleRate() != orig.GetSampleRate() {
t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate()) t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate())
} }
} }
func TestRoundtrip_PipelineTrigger(t *testing.T) { func TestRoundtrip_PipelineTrigger(t *testing.T) {
orig := &pb.PipelineTrigger{ orig := &pb.PipelineTrigger{
RequestId: "pip-001", RequestId: "pip-001",
Pipeline: "document-ingestion", Pipeline: "document-ingestion",
Parameters: map[string]string{"source": "s3://bucket/data"}, Parameters: map[string]string{"source": "s3://bucket/data"},
} }
data, err := proto.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec pb.PipelineTrigger var dec pb.PipelineTrigger
if err := proto.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.GetPipeline() != orig.GetPipeline() { if dec.GetPipeline() != orig.GetPipeline() {
t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline()) t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline())
} }
if dec.GetParameters()["source"] != orig.GetParameters()["source"] { if dec.GetParameters()["source"] != orig.GetParameters()["source"] {
t.Errorf("Parameters[source] mismatch") t.Errorf("Parameters[source] mismatch")
} }
} }
func TestRoundtrip_STTTranscription(t *testing.T) { func TestRoundtrip_STTTranscription(t *testing.T) {
orig := &pb.STTTranscription{ orig := &pb.STTTranscription{
SessionId: "stt-001", SessionId: "stt-001",
Transcript: "hello world", Transcript: "hello world",
Sequence: 5, Sequence: 5,
IsPartial: false, IsPartial: false,
IsFinal: true, IsFinal: true,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
SpeakerId: "speaker-1", SpeakerId: "speaker-1",
HasVoiceActivity: true, HasVoiceActivity: true,
State: "listening", State: "listening",
} }
data, err := proto.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec pb.STTTranscription var dec pb.STTTranscription
if err := proto.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if dec.GetTranscript() != orig.GetTranscript() { if dec.GetTranscript() != orig.GetTranscript() {
t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript()) t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript())
} }
if dec.GetIsFinal() != orig.GetIsFinal() { if dec.GetIsFinal() != orig.GetIsFinal() {
t.Error("IsFinal mismatch") t.Error("IsFinal mismatch")
} }
} }
func TestRoundtrip_ErrorResponse(t *testing.T) { func TestRoundtrip_ErrorResponse(t *testing.T) {
orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"} orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
data, err := proto.Marshal(orig) data, err := proto.Marshal(orig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var dec pb.ErrorResponse var dec pb.ErrorResponse
if err := proto.Unmarshal(data, &dec); err != nil { if err := proto.Unmarshal(data, &dec); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" { if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" {
t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec) t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec)
} }
} }
func TestEffectiveQuery_MessageSet(t *testing.T) { func TestEffectiveQuery_MessageSet(t *testing.T) {
req := &pb.ChatRequest{Message: "hello", Query: "world"} req := &pb.ChatRequest{Message: "hello", Query: "world"}
if got := EffectiveQuery(req); got != "hello" { if got := EffectiveQuery(req); got != "hello" {
t.Errorf("EffectiveQuery() = %q, want %q", got, "hello") t.Errorf("EffectiveQuery() = %q, want %q", got, "hello")
} }
} }
func TestEffectiveQuery_FallbackToQuery(t *testing.T) { func TestEffectiveQuery_FallbackToQuery(t *testing.T) {
req := &pb.ChatRequest{Query: "world"} req := &pb.ChatRequest{Query: "world"}
if got := EffectiveQuery(req); got != "world" { if got := EffectiveQuery(req); got != "world" {
t.Errorf("EffectiveQuery() = %q, want %q", got, "world") t.Errorf("EffectiveQuery() = %q, want %q", got, "world")
} }
} }
func TestTimestamp(t *testing.T) { func TestTimestamp(t *testing.T) {
ts := Timestamp() ts := Timestamp()
now := time.Now().Unix() now := time.Now().Unix()
if ts < now-1 || ts > now+1 { if ts < now-1 || ts > now+1 {
t.Errorf("Timestamp() = %d, expected ~%d", ts, now) t.Errorf("Timestamp() = %d, expected ~%d", ts, now)
} }
} }

View File

@@ -8,9 +8,9 @@
package messages package messages
import ( import (
"time" "time"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb" pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
) )
// ════════════════════════════════════════════════════════════════════════════ // ════════════════════════════════════════════════════════════════════════════
@@ -57,13 +57,13 @@ type PipelineStatus = pb.PipelineStatus
// EffectiveQuery returns Message or falls back to Query. // EffectiveQuery returns Message or falls back to Query.
func EffectiveQuery(c *ChatRequest) string { func EffectiveQuery(c *ChatRequest) string {
if c.GetMessage() != "" { if c.GetMessage() != "" {
return c.GetMessage() return c.GetMessage()
} }
return c.GetQuery() return c.GetQuery()
} }
// Timestamp returns the current Unix timestamp. // Timestamp returns the current Unix timestamp.
func Timestamp() int64 { func Timestamp() int64 {
return time.Now().Unix() return time.Now().Unix()
} }

View File

@@ -2,69 +2,69 @@
package natsutil package natsutil
import ( import (
"fmt" "fmt"
"log/slog" "log/slog"
"time" "time"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
// Client wraps a NATS connection with protobuf helpers. // Client wraps a NATS connection with protobuf helpers.
type Client struct { type Client struct {
nc *nats.Conn nc *nats.Conn
js nats.JetStreamContext js nats.JetStreamContext
subs []*nats.Subscription subs []*nats.Subscription
url string url string
opts []nats.Option opts []nats.Option
} }
// New creates a NATS client configured to connect to the given URL. // New creates a NATS client configured to connect to the given URL.
// Optional NATS options (e.g. credentials) can be appended. // Optional NATS options (e.g. credentials) can be appended.
func New(url string, opts ...nats.Option) *Client { func New(url string, opts ...nats.Option) *Client {
defaults := []nats.Option{ defaults := []nats.Option{
nats.ReconnectWait(2 * time.Second), nats.ReconnectWait(2 * time.Second),
nats.MaxReconnects(-1), nats.MaxReconnects(-1),
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
slog.Warn("NATS disconnected", "error", err) slog.Warn("NATS disconnected", "error", err)
}), }),
nats.ReconnectHandler(func(_ *nats.Conn) { nats.ReconnectHandler(func(_ *nats.Conn) {
slog.Info("NATS reconnected") slog.Info("NATS reconnected")
}), }),
} }
return &Client{ return &Client{
url: url, url: url,
opts: append(defaults, opts...), opts: append(defaults, opts...),
} }
} }
// Connect establishes the NATS connection and JetStream context. // Connect establishes the NATS connection and JetStream context.
func (c *Client) Connect() error { func (c *Client) Connect() error {
nc, err := nats.Connect(c.url, c.opts...) nc, err := nats.Connect(c.url, c.opts...)
if err != nil { if err != nil {
return fmt.Errorf("nats connect: %w", err) return fmt.Errorf("nats connect: %w", err)
} }
js, err := nc.JetStream() js, err := nc.JetStream()
if err != nil { if err != nil {
nc.Close() nc.Close()
return fmt.Errorf("jetstream: %w", err) return fmt.Errorf("jetstream: %w", err)
} }
c.nc = nc c.nc = nc
c.js = js c.js = js
slog.Info("connected to NATS", "url", c.url) slog.Info("connected to NATS", "url", c.url)
return nil return nil
} }
// Close drains subscriptions and closes the connection. // Close drains subscriptions and closes the connection.
func (c *Client) Close() { func (c *Client) Close() {
if c.nc == nil { if c.nc == nil {
return return
} }
for _, sub := range c.subs { for _, sub := range c.subs {
_ = sub.Drain() _ = sub.Drain()
} }
c.nc.Close() c.nc.Close()
slog.Info("NATS connection closed") slog.Info("NATS connection closed")
} }
// Conn returns the underlying *nats.Conn. // Conn returns the underlying *nats.Conn.
@@ -75,56 +75,56 @@ func (c *Client) JS() nats.JetStreamContext { return c.js }
// IsConnected returns true if the NATS connection is active. // IsConnected returns true if the NATS connection is active.
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
return c.nc != nil && c.nc.IsConnected() return c.nc != nil && c.nc.IsConnected()
} }
// Subscribe subscribes to a subject with an optional queue group. // Subscribe subscribes to a subject with an optional queue group.
// The handler receives the raw *nats.Msg. // The handler receives the raw *nats.Msg.
func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error { func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error {
var sub *nats.Subscription var sub *nats.Subscription
var err error var err error
if queue != "" { if queue != "" {
sub, err = c.nc.QueueSubscribe(subject, queue, handler) sub, err = c.nc.QueueSubscribe(subject, queue, handler)
slog.Info("subscribed", "subject", subject, "queue", queue) slog.Info("subscribed", "subject", subject, "queue", queue)
} else { } else {
sub, err = c.nc.Subscribe(subject, handler) sub, err = c.nc.Subscribe(subject, handler)
slog.Info("subscribed", "subject", subject) slog.Info("subscribed", "subject", subject)
} }
if err != nil { if err != nil {
return fmt.Errorf("subscribe %s: %w", subject, err) return fmt.Errorf("subscribe %s: %w", subject, err)
} }
c.subs = append(c.subs, sub) c.subs = append(c.subs, sub)
return nil return nil
} }
// Publish encodes data as protobuf and publishes to the subject. // Publish encodes data as protobuf and publishes to the subject.
func (c *Client) Publish(subject string, data proto.Message) error { func (c *Client) Publish(subject string, data proto.Message) error {
payload, err := proto.Marshal(data) payload, err := proto.Marshal(data)
if err != nil { if err != nil {
return fmt.Errorf("proto marshal: %w", err) return fmt.Errorf("proto marshal: %w", err)
} }
return c.nc.Publish(subject, payload) return c.nc.Publish(subject, payload)
} }
// PublishRaw publishes pre-encoded bytes to the subject. // PublishRaw publishes pre-encoded bytes to the subject.
func (c *Client) PublishRaw(subject string, data []byte) error { func (c *Client) PublishRaw(subject string, data []byte) error {
return c.nc.Publish(subject, data) return c.nc.Publish(subject, data)
} }
// Request sends a protobuf-encoded request and decodes the response into result. // Request sends a protobuf-encoded request and decodes the response into result.
func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error { func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error {
payload, err := proto.Marshal(data) payload, err := proto.Marshal(data)
if err != nil { if err != nil {
return fmt.Errorf("proto marshal: %w", err) return fmt.Errorf("proto marshal: %w", err)
} }
msg, err := c.nc.Request(subject, payload, timeout) msg, err := c.nc.Request(subject, payload, timeout)
if err != nil { if err != nil {
return fmt.Errorf("nats request: %w", err) return fmt.Errorf("nats request: %w", err)
} }
return proto.Unmarshal(msg.Data, result) return proto.Unmarshal(msg.Data, result)
} }
// Decode unmarshals protobuf bytes into dest. // Decode unmarshals protobuf bytes into dest.
func Decode(data []byte, dest proto.Message) error { func Decode(data []byte, dest proto.Message) error {
return proto.Unmarshal(data, dest) return proto.Unmarshal(data, dest)
} }