style: gofmt + fix errcheck lint warning
This commit is contained in:
@@ -563,7 +563,7 @@ func TestLLMClient_StreamGenerateManyTokens(t *testing.T) {
|
|||||||
var order []int
|
var order []int
|
||||||
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
result, err := c.StreamGenerate(context.Background(), "q", "", "", func(tok string) {
|
||||||
var idx int
|
var idx int
|
||||||
fmt.Sscanf(tok, "t%d ", &idx)
|
_, _ = fmt.Sscanf(tok, "t%d ", &idx)
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
order = append(order, idx)
|
order = append(order, idx)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
|
|||||||
364
config/config.go
364
config/config.go
@@ -3,16 +3,16 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Settings holds base configuration for all handler services.
|
// Settings holds base configuration for all handler services.
|
||||||
@@ -20,249 +20,249 @@ import (
|
|||||||
// updated at runtime via WatchSecrets(). All other fields are immutable
|
// updated at runtime via WatchSecrets(). All other fields are immutable
|
||||||
// after Load() returns.
|
// after Load() returns.
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
// Service identification (immutable)
|
// Service identification (immutable)
|
||||||
ServiceName string
|
ServiceName string
|
||||||
ServiceVersion string
|
ServiceVersion string
|
||||||
ServiceNamespace string
|
ServiceNamespace string
|
||||||
DeploymentEnv string
|
DeploymentEnv string
|
||||||
|
|
||||||
// NATS configuration (immutable)
|
// NATS configuration (immutable)
|
||||||
NATSURL string
|
NATSURL string
|
||||||
NATSUser string
|
NATSUser string
|
||||||
NATSPassword string
|
NATSPassword string
|
||||||
NATSQueueGroup string
|
NATSQueueGroup string
|
||||||
|
|
||||||
// Redis/Valkey configuration (immutable)
|
// Redis/Valkey configuration (immutable)
|
||||||
RedisURL string
|
RedisURL string
|
||||||
RedisPassword string
|
RedisPassword string
|
||||||
|
|
||||||
// Milvus configuration (immutable)
|
// Milvus configuration (immutable)
|
||||||
MilvusHost string
|
MilvusHost string
|
||||||
MilvusPort int
|
MilvusPort int
|
||||||
MilvusCollection string
|
MilvusCollection string
|
||||||
|
|
||||||
// OpenTelemetry configuration (immutable)
|
// OpenTelemetry configuration (immutable)
|
||||||
OTELEnabled bool
|
OTELEnabled bool
|
||||||
OTELEndpoint string
|
OTELEndpoint string
|
||||||
OTELUseHTTP bool
|
OTELUseHTTP bool
|
||||||
|
|
||||||
// HyperDX configuration (immutable)
|
// HyperDX configuration (immutable)
|
||||||
HyperDXEnabled bool
|
HyperDXEnabled bool
|
||||||
HyperDXAPIKey string
|
HyperDXAPIKey string
|
||||||
HyperDXEndpoint string
|
HyperDXEndpoint string
|
||||||
|
|
||||||
// MLflow configuration (immutable)
|
// MLflow configuration (immutable)
|
||||||
MLflowTrackingURI string
|
MLflowTrackingURI string
|
||||||
MLflowExperimentName string
|
MLflowExperimentName string
|
||||||
MLflowEnabled bool
|
MLflowEnabled bool
|
||||||
|
|
||||||
// Health check configuration (immutable)
|
// Health check configuration (immutable)
|
||||||
HealthPort int
|
HealthPort int
|
||||||
HealthPath string
|
HealthPath string
|
||||||
ReadyPath string
|
ReadyPath string
|
||||||
|
|
||||||
// Timeouts (immutable)
|
// Timeouts (immutable)
|
||||||
HTTPTimeout time.Duration
|
HTTPTimeout time.Duration
|
||||||
NATSTimeout time.Duration
|
NATSTimeout time.Duration
|
||||||
|
|
||||||
// Hot-reloadable fields — access via getter methods.
|
// Hot-reloadable fields — access via getter methods.
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
embeddingsURL string
|
embeddingsURL string
|
||||||
rerankerURL string
|
rerankerURL string
|
||||||
llmURL string
|
llmURL string
|
||||||
ttsURL string
|
ttsURL string
|
||||||
sttURL string
|
sttURL string
|
||||||
|
|
||||||
// Secrets path for file-based hot reload (Kubernetes secret mounts)
|
// Secrets path for file-based hot reload (Kubernetes secret mounts)
|
||||||
SecretsPath string
|
SecretsPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load creates a Settings populated from environment variables with defaults.
|
// Load creates a Settings populated from environment variables with defaults.
|
||||||
func Load() *Settings {
|
func Load() *Settings {
|
||||||
return &Settings{
|
return &Settings{
|
||||||
ServiceName: getEnv("SERVICE_NAME", "handler"),
|
ServiceName: getEnv("SERVICE_NAME", "handler"),
|
||||||
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
|
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
|
||||||
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
|
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
|
||||||
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
|
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
|
||||||
|
|
||||||
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
|
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
|
||||||
NATSUser: getEnv("NATS_USER", ""),
|
NATSUser: getEnv("NATS_USER", ""),
|
||||||
NATSPassword: getEnv("NATS_PASSWORD", ""),
|
NATSPassword: getEnv("NATS_PASSWORD", ""),
|
||||||
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
|
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
|
||||||
|
|
||||||
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
|
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
|
||||||
RedisPassword: getEnv("REDIS_PASSWORD", ""),
|
RedisPassword: getEnv("REDIS_PASSWORD", ""),
|
||||||
|
|
||||||
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
|
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
|
||||||
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
||||||
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
||||||
|
|
||||||
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
||||||
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
||||||
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
||||||
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
||||||
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
|
||||||
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
||||||
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
||||||
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
|
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
|
||||||
|
|
||||||
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
|
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
|
||||||
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
|
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
|
||||||
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
|
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
|
||||||
|
|
||||||
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
|
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
|
||||||
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
|
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
|
||||||
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
|
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
|
||||||
|
|
||||||
HealthPort: getEnvInt("HEALTH_PORT", 8080),
|
HealthPort: getEnvInt("HEALTH_PORT", 8080),
|
||||||
HealthPath: getEnv("HEALTH_PATH", "/health"),
|
HealthPath: getEnv("HEALTH_PATH", "/health"),
|
||||||
ReadyPath: getEnv("READY_PATH", "/ready"),
|
ReadyPath: getEnv("READY_PATH", "/ready"),
|
||||||
|
|
||||||
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
||||||
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
||||||
|
|
||||||
SecretsPath: getEnv("SECRETS_PATH", ""),
|
SecretsPath: getEnv("SECRETS_PATH", ""),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
|
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
|
||||||
func (s *Settings) EmbeddingsURL() string {
|
func (s *Settings) EmbeddingsURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.embeddingsURL
|
return s.embeddingsURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// RerankerURL returns the current reranker service URL (thread-safe).
|
// RerankerURL returns the current reranker service URL (thread-safe).
|
||||||
func (s *Settings) RerankerURL() string {
|
func (s *Settings) RerankerURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.rerankerURL
|
return s.rerankerURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMURL returns the current LLM service URL (thread-safe).
|
// LLMURL returns the current LLM service URL (thread-safe).
|
||||||
func (s *Settings) LLMURL() string {
|
func (s *Settings) LLMURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.llmURL
|
return s.llmURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// TTSURL returns the current TTS service URL (thread-safe).
|
// TTSURL returns the current TTS service URL (thread-safe).
|
||||||
func (s *Settings) TTSURL() string {
|
func (s *Settings) TTSURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.ttsURL
|
return s.ttsURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// STTURL returns the current STT service URL (thread-safe).
|
// STTURL returns the current STT service URL (thread-safe).
|
||||||
func (s *Settings) STTURL() string {
|
func (s *Settings) STTURL() string {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
return s.sttURL
|
return s.sttURL
|
||||||
}
|
}
|
||||||
|
|
||||||
// WatchSecrets watches the SecretsPath directory for changes and reloads
|
// WatchSecrets watches the SecretsPath directory for changes and reloads
|
||||||
// hot-reloadable fields. Blocks until ctx is cancelled.
|
// hot-reloadable fields. Blocks until ctx is cancelled.
|
||||||
func (s *Settings) WatchSecrets(ctx context.Context) {
|
func (s *Settings) WatchSecrets(ctx context.Context) {
|
||||||
if s.SecretsPath == "" {
|
if s.SecretsPath == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
watcher, err := fsnotify.NewWatcher()
|
watcher, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("config: failed to create fsnotify watcher", "error", err)
|
slog.Error("config: failed to create fsnotify watcher", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = watcher.Close() }()
|
defer func() { _ = watcher.Close() }()
|
||||||
|
|
||||||
if err := watcher.Add(s.SecretsPath); err != nil {
|
if err := watcher.Add(s.SecretsPath); err != nil {
|
||||||
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
|
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
|
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case event, ok := <-watcher.Events:
|
case event, ok := <-watcher.Events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
}
|
}
|
||||||
case err, ok := <-watcher.Errors:
|
case err, ok := <-watcher.Errors:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
slog.Error("config: fsnotify error", "error", err)
|
slog.Error("config: fsnotify error", "error", err)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
|
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
|
||||||
func (s *Settings) reloadFromSecrets() {
|
func (s *Settings) reloadFromSecrets() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
updated := 0
|
updated := 0
|
||||||
reload := func(filename string, target *string) {
|
reload := func(filename string, target *string) {
|
||||||
path := filepath.Join(s.SecretsPath, filename)
|
path := filepath.Join(s.SecretsPath, filename)
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
val := strings.TrimSpace(string(data))
|
val := strings.TrimSpace(string(data))
|
||||||
if val != "" && val != *target {
|
if val != "" && val != *target {
|
||||||
*target = val
|
*target = val
|
||||||
updated++
|
updated++
|
||||||
slog.Info("config: reloaded secret", "key", filename)
|
slog.Info("config: reloaded secret", "key", filename)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
reload("embeddings-url", &s.embeddingsURL)
|
reload("embeddings-url", &s.embeddingsURL)
|
||||||
reload("reranker-url", &s.rerankerURL)
|
reload("reranker-url", &s.rerankerURL)
|
||||||
reload("llm-url", &s.llmURL)
|
reload("llm-url", &s.llmURL)
|
||||||
reload("tts-url", &s.ttsURL)
|
reload("tts-url", &s.ttsURL)
|
||||||
reload("stt-url", &s.sttURL)
|
reload("stt-url", &s.sttURL)
|
||||||
|
|
||||||
if updated > 0 {
|
if updated > 0 {
|
||||||
slog.Info("config: secrets reloaded", "updated", updated)
|
slog.Info("config: secrets reloaded", "updated", updated)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnv(key, fallback string) string {
|
func getEnv(key, fallback string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnvInt(key string, fallback int) int {
|
func getEnvInt(key string, fallback int) int {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
if i, err := strconv.Atoi(v); err == nil {
|
if i, err := strconv.Atoi(v); err == nil {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnvBool(key string, fallback bool) bool {
|
func getEnvBool(key string, fallback bool) bool {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
if b, err := strconv.ParseBool(v); err == nil {
|
if b, err := strconv.ParseBool(v); err == nil {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnvDuration(key string, fallback time.Duration) time.Duration {
|
func getEnvDuration(key string, fallback time.Duration) time.Duration {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||||
return time.Duration(f * float64(time.Second))
|
return time.Duration(f * float64(time.Second))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,123 +1,123 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoadDefaults(t *testing.T) {
|
func TestLoadDefaults(t *testing.T) {
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.ServiceName != "handler" {
|
if s.ServiceName != "handler" {
|
||||||
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
|
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
|
||||||
}
|
}
|
||||||
if s.HealthPort != 8080 {
|
if s.HealthPort != 8080 {
|
||||||
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
|
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
|
||||||
}
|
}
|
||||||
if s.HTTPTimeout != 60*time.Second {
|
if s.HTTPTimeout != 60*time.Second {
|
||||||
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
|
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadFromEnv(t *testing.T) {
|
func TestLoadFromEnv(t *testing.T) {
|
||||||
t.Setenv("SERVICE_NAME", "test-svc")
|
t.Setenv("SERVICE_NAME", "test-svc")
|
||||||
t.Setenv("HEALTH_PORT", "9090")
|
t.Setenv("HEALTH_PORT", "9090")
|
||||||
t.Setenv("OTEL_ENABLED", "false")
|
t.Setenv("OTEL_ENABLED", "false")
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.ServiceName != "test-svc" {
|
if s.ServiceName != "test-svc" {
|
||||||
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
|
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
|
||||||
}
|
}
|
||||||
if s.HealthPort != 9090 {
|
if s.HealthPort != 9090 {
|
||||||
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
|
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
|
||||||
}
|
}
|
||||||
if s.OTELEnabled {
|
if s.OTELEnabled {
|
||||||
t.Error("expected OTELEnabled false")
|
t.Error("expected OTELEnabled false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestURLGetters(t *testing.T) {
|
func TestURLGetters(t *testing.T) {
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.EmbeddingsURL() == "" {
|
if s.EmbeddingsURL() == "" {
|
||||||
t.Error("EmbeddingsURL should have a default")
|
t.Error("EmbeddingsURL should have a default")
|
||||||
}
|
}
|
||||||
if s.RerankerURL() == "" {
|
if s.RerankerURL() == "" {
|
||||||
t.Error("RerankerURL should have a default")
|
t.Error("RerankerURL should have a default")
|
||||||
}
|
}
|
||||||
if s.LLMURL() == "" {
|
if s.LLMURL() == "" {
|
||||||
t.Error("LLMURL should have a default")
|
t.Error("LLMURL should have a default")
|
||||||
}
|
}
|
||||||
if s.TTSURL() == "" {
|
if s.TTSURL() == "" {
|
||||||
t.Error("TTSURL should have a default")
|
t.Error("TTSURL should have a default")
|
||||||
}
|
}
|
||||||
if s.STTURL() == "" {
|
if s.STTURL() == "" {
|
||||||
t.Error("STTURL should have a default")
|
t.Error("STTURL should have a default")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestURLGettersFromEnv(t *testing.T) {
|
func TestURLGettersFromEnv(t *testing.T) {
|
||||||
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
|
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
|
||||||
t.Setenv("LLM_URL", "http://llm:9000")
|
t.Setenv("LLM_URL", "http://llm:9000")
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
if s.EmbeddingsURL() != "http://embed:8000" {
|
if s.EmbeddingsURL() != "http://embed:8000" {
|
||||||
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
|
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
}
|
}
|
||||||
if s.LLMURL() != "http://llm:9000" {
|
if s.LLMURL() != "http://llm:9000" {
|
||||||
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
|
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadFromSecrets(t *testing.T) {
|
func TestReloadFromSecrets(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
// Write initial secret files
|
// Write initial secret files
|
||||||
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
|
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
|
||||||
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
|
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
|
||||||
|
|
||||||
s := Load()
|
s := Load()
|
||||||
s.SecretsPath = dir
|
s.SecretsPath = dir
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
if s.EmbeddingsURL() != "http://old-embed:8000" {
|
if s.EmbeddingsURL() != "http://old-embed:8000" {
|
||||||
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
|
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
}
|
}
|
||||||
if s.LLMURL() != "http://old-llm:9000" {
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
|
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate secret update
|
// Simulate secret update
|
||||||
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
|
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
if s.EmbeddingsURL() != "http://new-embed:8000" {
|
if s.EmbeddingsURL() != "http://new-embed:8000" {
|
||||||
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
|
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
}
|
}
|
||||||
// LLM should remain unchanged
|
// LLM should remain unchanged
|
||||||
if s.LLMURL() != "http://old-llm:9000" {
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
|
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReloadFromSecretsNoPath(t *testing.T) {
|
func TestReloadFromSecretsNoPath(t *testing.T) {
|
||||||
s := Load()
|
s := Load()
|
||||||
s.SecretsPath = ""
|
s.SecretsPath = ""
|
||||||
// Should not panic
|
// Should not panic
|
||||||
s.reloadFromSecrets()
|
s.reloadFromSecrets()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetEnvDuration(t *testing.T) {
|
func TestGetEnvDuration(t *testing.T) {
|
||||||
t.Setenv("TEST_DUR", "30")
|
t.Setenv("TEST_DUR", "30")
|
||||||
d := getEnvDuration("TEST_DUR", 10*time.Second)
|
d := getEnvDuration("TEST_DUR", 10*time.Second)
|
||||||
if d != 30*time.Second {
|
if d != 30*time.Second {
|
||||||
t.Errorf("expected 30s, got %v", d)
|
t.Errorf("expected 30s, got %v", d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeSecret(t *testing.T, dir, name, value string) {
|
func writeSecret(t *testing.T, dir, name, value string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
|
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,21 +2,21 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TypedMessageHandler processes the raw NATS message.
|
// TypedMessageHandler processes the raw NATS message.
|
||||||
@@ -32,36 +32,36 @@ type TeardownFunc func(ctx context.Context) error
|
|||||||
|
|
||||||
// Handler is the base service runner that wires NATS, health, and telemetry.
|
// Handler is the base service runner that wires NATS, health, and telemetry.
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
Settings *config.Settings
|
Settings *config.Settings
|
||||||
NATS *natsutil.Client
|
NATS *natsutil.Client
|
||||||
Telemetry *telemetry.Provider
|
Telemetry *telemetry.Provider
|
||||||
Subject string
|
Subject string
|
||||||
QueueGroup string
|
QueueGroup string
|
||||||
|
|
||||||
onSetup SetupFunc
|
onSetup SetupFunc
|
||||||
onTeardown TeardownFunc
|
onTeardown TeardownFunc
|
||||||
onTypedMessage TypedMessageHandler
|
onTypedMessage TypedMessageHandler
|
||||||
running bool
|
running bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a Handler for the given NATS subject.
|
// New creates a Handler for the given NATS subject.
|
||||||
func New(subject string, settings *config.Settings) *Handler {
|
func New(subject string, settings *config.Settings) *Handler {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
settings = config.Load()
|
settings = config.Load()
|
||||||
}
|
}
|
||||||
queueGroup := settings.NATSQueueGroup
|
queueGroup := settings.NATSQueueGroup
|
||||||
|
|
||||||
natsOpts := []nats.Option{}
|
natsOpts := []nats.Option{}
|
||||||
if settings.NATSUser != "" && settings.NATSPassword != "" {
|
if settings.NATSUser != "" && settings.NATSPassword != "" {
|
||||||
natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword))
|
natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Handler{
|
return &Handler{
|
||||||
Settings: settings,
|
Settings: settings,
|
||||||
Subject: subject,
|
Subject: subject,
|
||||||
QueueGroup: queueGroup,
|
QueueGroup: queueGroup,
|
||||||
NATS: natsutil.New(settings.NATSURL, natsOpts...),
|
NATS: natsutil.New(settings.NATSURL, natsOpts...),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnSetup registers the setup callback.
|
// OnSetup registers the setup callback.
|
||||||
@@ -75,101 +75,101 @@ func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn
|
|||||||
|
|
||||||
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
|
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
|
||||||
func (h *Handler) Run() error {
|
func (h *Handler) Run() error {
|
||||||
// Structured logging
|
// Structured logging
|
||||||
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
||||||
slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion)
|
slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Telemetry
|
// Telemetry
|
||||||
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
||||||
ServiceName: h.Settings.ServiceName,
|
ServiceName: h.Settings.ServiceName,
|
||||||
ServiceVersion: h.Settings.ServiceVersion,
|
ServiceVersion: h.Settings.ServiceVersion,
|
||||||
ServiceNamespace: h.Settings.ServiceNamespace,
|
ServiceNamespace: h.Settings.ServiceNamespace,
|
||||||
DeploymentEnv: h.Settings.DeploymentEnv,
|
DeploymentEnv: h.Settings.DeploymentEnv,
|
||||||
Enabled: h.Settings.OTELEnabled,
|
Enabled: h.Settings.OTELEnabled,
|
||||||
Endpoint: h.Settings.OTELEndpoint,
|
Endpoint: h.Settings.OTELEndpoint,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("telemetry setup: %w", err)
|
return fmt.Errorf("telemetry setup: %w", err)
|
||||||
}
|
}
|
||||||
defer shutdown(ctx)
|
defer shutdown(ctx)
|
||||||
h.Telemetry = tp
|
h.Telemetry = tp
|
||||||
|
|
||||||
// Health server
|
// Health server
|
||||||
healthSrv := health.New(
|
healthSrv := health.New(
|
||||||
h.Settings.HealthPort,
|
h.Settings.HealthPort,
|
||||||
h.Settings.HealthPath,
|
h.Settings.HealthPath,
|
||||||
h.Settings.ReadyPath,
|
h.Settings.ReadyPath,
|
||||||
func() bool { return h.running && h.NATS.IsConnected() },
|
func() bool { return h.running && h.NATS.IsConnected() },
|
||||||
)
|
)
|
||||||
healthSrv.Start()
|
healthSrv.Start()
|
||||||
defer healthSrv.Stop(ctx)
|
defer healthSrv.Stop(ctx)
|
||||||
|
|
||||||
// Connect to NATS
|
// Connect to NATS
|
||||||
if err := h.NATS.Connect(); err != nil {
|
if err := h.NATS.Connect(); err != nil {
|
||||||
return fmt.Errorf("nats: %w", err)
|
return fmt.Errorf("nats: %w", err)
|
||||||
}
|
}
|
||||||
defer h.NATS.Close()
|
defer h.NATS.Close()
|
||||||
|
|
||||||
// User setup
|
// User setup
|
||||||
if h.onSetup != nil {
|
if h.onSetup != nil {
|
||||||
slog.Info("running service setup")
|
slog.Info("running service setup")
|
||||||
if err := h.onSetup(ctx); err != nil {
|
if err := h.onSetup(ctx); err != nil {
|
||||||
return fmt.Errorf("setup: %w", err)
|
return fmt.Errorf("setup: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe
|
// Subscribe
|
||||||
if h.onTypedMessage == nil {
|
if h.onTypedMessage == nil {
|
||||||
return fmt.Errorf("no message handler registered")
|
return fmt.Errorf("no message handler registered")
|
||||||
}
|
}
|
||||||
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
|
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
|
||||||
return fmt.Errorf("subscribe: %w", err)
|
return fmt.Errorf("subscribe: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
h.running = true
|
h.running = true
|
||||||
slog.Info("handler ready", "subject", h.Subject)
|
slog.Info("handler ready", "subject", h.Subject)
|
||||||
|
|
||||||
// Wait for shutdown signal
|
// Wait for shutdown signal
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||||
<-sigCh
|
<-sigCh
|
||||||
|
|
||||||
slog.Info("shutting down")
|
slog.Info("shutting down")
|
||||||
h.running = false
|
h.running = false
|
||||||
|
|
||||||
// Teardown
|
// Teardown
|
||||||
if h.onTeardown != nil {
|
if h.onTeardown != nil {
|
||||||
if err := h.onTeardown(ctx); err != nil {
|
if err := h.onTeardown(ctx); err != nil {
|
||||||
slog.Warn("teardown error", "error", err)
|
slog.Warn("teardown error", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("shutdown complete")
|
slog.Info("shutdown complete")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
|
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
|
||||||
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
|
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
|
||||||
return func(msg *nats.Msg) {
|
return func(msg *nats.Msg) {
|
||||||
response, err := h.onTypedMessage(ctx, msg)
|
response, err := h.onTypedMessage(ctx, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
||||||
if msg.Reply != "" {
|
if msg.Reply != "" {
|
||||||
_ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{
|
_ = h.NATS.Publish(msg.Reply, &pb.ErrorResponse{
|
||||||
Error: true,
|
Error: true,
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: fmt.Sprintf("%T", err),
|
Type: fmt.Sprintf("%T", err),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if response != nil && msg.Reply != "" {
|
if response != nil && msg.Reply != "" {
|
||||||
if err := h.NATS.Publish(msg.Reply, response); err != nil {
|
if err := h.NATS.Publish(msg.Reply, response); err != nil {
|
||||||
slog.Error("failed to publish reply", "error", err)
|
slog.Error("failed to publish reply", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -17,75 +17,75 @@ pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestNewHandler(t *testing.T) {
|
func TestNewHandler(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
cfg.ServiceName = "test-handler"
|
cfg.ServiceName = "test-handler"
|
||||||
cfg.NATSQueueGroup = "test-group"
|
cfg.NATSQueueGroup = "test-group"
|
||||||
|
|
||||||
h := New("ai.test.subject", cfg)
|
h := New("ai.test.subject", cfg)
|
||||||
if h.Subject != "ai.test.subject" {
|
if h.Subject != "ai.test.subject" {
|
||||||
t.Errorf("Subject = %q", h.Subject)
|
t.Errorf("Subject = %q", h.Subject)
|
||||||
}
|
}
|
||||||
if h.QueueGroup != "test-group" {
|
if h.QueueGroup != "test-group" {
|
||||||
t.Errorf("QueueGroup = %q", h.QueueGroup)
|
t.Errorf("QueueGroup = %q", h.QueueGroup)
|
||||||
}
|
}
|
||||||
if h.Settings.ServiceName != "test-handler" {
|
if h.Settings.ServiceName != "test-handler" {
|
||||||
t.Errorf("ServiceName = %q", h.Settings.ServiceName)
|
t.Errorf("ServiceName = %q", h.Settings.ServiceName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewHandlerNilSettings(t *testing.T) {
|
func TestNewHandlerNilSettings(t *testing.T) {
|
||||||
h := New("ai.test", nil)
|
h := New("ai.test", nil)
|
||||||
if h.Settings == nil {
|
if h.Settings == nil {
|
||||||
t.Fatal("Settings should be loaded automatically")
|
t.Fatal("Settings should be loaded automatically")
|
||||||
}
|
}
|
||||||
if h.Settings.ServiceName != "handler" {
|
if h.Settings.ServiceName != "handler" {
|
||||||
t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName)
|
t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCallbackRegistration(t *testing.T) {
|
func TestCallbackRegistration(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
setupCalled := false
|
setupCalled := false
|
||||||
h.OnSetup(func(ctx context.Context) error {
|
h.OnSetup(func(ctx context.Context) error {
|
||||||
setupCalled = true
|
setupCalled = true
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
teardownCalled := false
|
teardownCalled := false
|
||||||
h.OnTeardown(func(ctx context.Context) error {
|
h.OnTeardown(func(ctx context.Context) error {
|
||||||
teardownCalled = true
|
teardownCalled = true
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil {
|
if h.onSetup == nil || h.onTeardown == nil || h.onTypedMessage == nil {
|
||||||
t.Error("callbacks should not be nil after registration")
|
t.Error("callbacks should not be nil after registration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify setup/teardown work when called directly.
|
// Verify setup/teardown work when called directly.
|
||||||
_ = h.onSetup(context.Background())
|
_ = h.onSetup(context.Background())
|
||||||
_ = h.onTeardown(context.Background())
|
_ = h.onTeardown(context.Background())
|
||||||
if !setupCalled || !teardownCalled {
|
if !setupCalled || !teardownCalled {
|
||||||
t.Error("callbacks should have been invoked")
|
t.Error("callbacks should have been invoked")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTypedMessageRegistration(t *testing.T) {
|
func TestTypedMessageRegistration(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return &pb.ChatResponse{Response: "ok"}, nil
|
return &pb.ChatResponse{Response: "ok"}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if h.onTypedMessage == nil {
|
if h.onTypedMessage == nil {
|
||||||
t.Error("onTypedMessage should not be nil after registration")
|
t.Error("onTypedMessage should not be nil after registration")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -93,101 +93,101 @@ t.Error("onTypedMessage should not be nil after registration")
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestWrapHandler_ValidMessage(t *testing.T) {
|
func TestWrapHandler_ValidMessage(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
var receivedReq pb.ChatRequest
|
var receivedReq pb.ChatRequest
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
if err := natsutil.Decode(msg.Data, &receivedReq); err != nil {
|
if err := natsutil.Decode(msg.Data, &receivedReq); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil
|
return &pb.ChatResponse{Response: "ok", UserId: receivedReq.GetUserId()}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
// Encode a message the same way services would.
|
// Encode a message the same way services would.
|
||||||
encoded, err := proto.Marshal(&pb.ChatRequest{
|
encoded, err := proto.Marshal(&pb.ChatRequest{
|
||||||
RequestId: "test-001",
|
RequestId: "test-001",
|
||||||
Message: "hello",
|
Message: "hello",
|
||||||
Premium: true,
|
Premium: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call wrapHandler directly without NATS.
|
// Call wrapHandler directly without NATS.
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{
|
handler(&nats.Msg{
|
||||||
Subject: "ai.test.user.42.message",
|
Subject: "ai.test.user.42.message",
|
||||||
Data: encoded,
|
Data: encoded,
|
||||||
})
|
})
|
||||||
|
|
||||||
if receivedReq.GetRequestId() != "test-001" {
|
if receivedReq.GetRequestId() != "test-001" {
|
||||||
t.Errorf("request_id = %v", receivedReq.GetRequestId())
|
t.Errorf("request_id = %v", receivedReq.GetRequestId())
|
||||||
}
|
}
|
||||||
if receivedReq.GetPremium() != true {
|
if receivedReq.GetPremium() != true {
|
||||||
t.Errorf("premium = %v", receivedReq.GetPremium())
|
t.Errorf("premium = %v", receivedReq.GetPremium())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_InvalidMessage(t *testing.T) {
|
func TestWrapHandler_InvalidMessage(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
handlerCalled = true
|
handlerCalled = true
|
||||||
var req pb.ChatRequest
|
var req pb.ChatRequest
|
||||||
if err := natsutil.Decode(msg.Data, &req); err != nil {
|
if err := natsutil.Decode(msg.Data, &req); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &pb.ChatResponse{}, nil
|
return &pb.ChatResponse{}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{
|
handler(&nats.Msg{
|
||||||
Subject: "ai.test",
|
Subject: "ai.test",
|
||||||
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf
|
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid protobuf
|
||||||
})
|
})
|
||||||
|
|
||||||
// The handler IS called (wrapHandler doesn't pre-decode), but it should
|
// The handler IS called (wrapHandler doesn't pre-decode), but it should
|
||||||
// return an error from Decode. Either way no panic.
|
// return an error from Decode. Either way no panic.
|
||||||
_ = handlerCalled
|
_ = handlerCalled
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_HandlerError(t *testing.T) {
|
func TestWrapHandler_HandlerError(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, context.DeadlineExceeded
|
return nil, context.DeadlineExceeded
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err-test"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
// Should not panic even when handler returns error.
|
// Should not panic even when handler returns error.
|
||||||
handler(&nats.Msg{
|
handler(&nats.Msg{
|
||||||
Subject: "ai.test",
|
Subject: "ai.test",
|
||||||
Data: encoded,
|
Data: encoded,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_NilResponse(t *testing.T) {
|
func TestWrapHandler_NilResponse(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, nil // fire-and-forget style
|
return nil, nil // fire-and-forget style
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil-resp"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
// Should not panic with nil response and no reply subject.
|
// Should not panic with nil response and no reply subject.
|
||||||
handler(&nats.Msg{
|
handler(&nats.Msg{
|
||||||
Subject: "ai.test",
|
Subject: "ai.test",
|
||||||
Data: encoded,
|
Data: encoded,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -195,59 +195,59 @@ Data: encoded,
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestWrapHandler_Typed(t *testing.T) {
|
func TestWrapHandler_Typed(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
var received pb.ChatRequest
|
var received pb.ChatRequest
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
if err := natsutil.Decode(msg.Data, &received); err != nil {
|
if err := natsutil.Decode(msg.Data, &received); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil
|
return &pb.ChatResponse{UserId: received.GetUserId(), Response: "ok"}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := proto.Marshal(&pb.ChatRequest{
|
encoded, _ := proto.Marshal(&pb.ChatRequest{
|
||||||
RequestId: "typed-001",
|
RequestId: "typed-001",
|
||||||
Message: "hello typed",
|
Message: "hello typed",
|
||||||
})
|
})
|
||||||
|
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
|
|
||||||
if received.GetRequestId() != "typed-001" {
|
if received.GetRequestId() != "typed-001" {
|
||||||
t.Errorf("RequestId = %q", received.GetRequestId())
|
t.Errorf("RequestId = %q", received.GetRequestId())
|
||||||
}
|
}
|
||||||
if received.GetMessage() != "hello typed" {
|
if received.GetMessage() != "hello typed" {
|
||||||
t.Errorf("Message = %q", received.GetMessage())
|
t.Errorf("Message = %q", received.GetMessage())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_TypedError(t *testing.T) {
|
func TestWrapHandler_TypedError(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, context.DeadlineExceeded
|
return nil, context.DeadlineExceeded
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "err"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
// Should not panic.
|
// Should not panic.
|
||||||
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapHandler_TypedNilResponse(t *testing.T) {
|
func TestWrapHandler_TypedNilResponse(t *testing.T) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"})
|
encoded, _ := proto.Marshal(&pb.ChatRequest{RequestId: "nil"})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -255,25 +255,25 @@ handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkWrapHandler(b *testing.B) {
|
func BenchmarkWrapHandler(b *testing.B) {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
h := New("ai.test", cfg)
|
h := New("ai.test", cfg)
|
||||||
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (proto.Message, error) {
|
||||||
var req pb.ChatRequest
|
var req pb.ChatRequest
|
||||||
_ = natsutil.Decode(msg.Data, &req)
|
_ = natsutil.Decode(msg.Data, &req)
|
||||||
return &pb.ChatResponse{Response: "ok"}, nil
|
return &pb.ChatResponse{Response: "ok"}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
encoded, _ := proto.Marshal(&pb.ChatRequest{
|
encoded, _ := proto.Marshal(&pb.ChatRequest{
|
||||||
RequestId: "bench-001",
|
RequestId: "bench-001",
|
||||||
Message: "What is the capital of France?",
|
Message: "What is the capital of France?",
|
||||||
Premium: true,
|
Premium: true,
|
||||||
TopK: 10,
|
TopK: 10,
|
||||||
})
|
})
|
||||||
handler := h.wrapHandler(context.Background())
|
handler := h.wrapHandler(context.Background())
|
||||||
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
handler(msg)
|
handler(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,17 +2,17 @@
|
|||||||
//
|
//
|
||||||
// Run with:
|
// Run with:
|
||||||
//
|
//
|
||||||
//go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
|
// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
|
||||||
//# optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
|
// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
|
||||||
package messages
|
package messages
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -20,39 +20,39 @@ pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func chatRequestProto() *pb.ChatRequest {
|
func chatRequestProto() *pb.ChatRequest {
|
||||||
return &pb.ChatRequest{
|
return &pb.ChatRequest{
|
||||||
RequestId: "req-abc-123",
|
RequestId: "req-abc-123",
|
||||||
UserId: "user-42",
|
UserId: "user-42",
|
||||||
Message: "What is the capital of France?",
|
Message: "What is the capital of France?",
|
||||||
Premium: true,
|
Premium: true,
|
||||||
EnableRag: true,
|
EnableRag: true,
|
||||||
EnableReranker: true,
|
EnableReranker: true,
|
||||||
TopK: 10,
|
TopK: 10,
|
||||||
Collection: "documents",
|
Collection: "documents",
|
||||||
SystemPrompt: "You are a helpful assistant.",
|
SystemPrompt: "You are a helpful assistant.",
|
||||||
ResponseSubject: "ai.chat.response.req-abc-123",
|
ResponseSubject: "ai.chat.response.req-abc-123",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func voiceResponseProto() *pb.VoiceResponse {
|
func voiceResponseProto() *pb.VoiceResponse {
|
||||||
return &pb.VoiceResponse{
|
return &pb.VoiceResponse{
|
||||||
RequestId: "vr-001",
|
RequestId: "vr-001",
|
||||||
Response: "The capital of France is Paris.",
|
Response: "The capital of France is Paris.",
|
||||||
Audio: make([]byte, 16384),
|
Audio: make([]byte, 16384),
|
||||||
Transcription: "What is the capital of France?",
|
Transcription: "What is the capital of France?",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ttsChunkProto() *pb.TTSAudioChunk {
|
func ttsChunkProto() *pb.TTSAudioChunk {
|
||||||
return &pb.TTSAudioChunk{
|
return &pb.TTSAudioChunk{
|
||||||
SessionId: "tts-sess-99",
|
SessionId: "tts-sess-99",
|
||||||
ChunkIndex: 3,
|
ChunkIndex: 3,
|
||||||
TotalChunks: 12,
|
TotalChunks: 12,
|
||||||
Audio: make([]byte, 32768),
|
Audio: make([]byte, 32768),
|
||||||
IsLast: false,
|
IsLast: false,
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
SampleRate: 24000,
|
SampleRate: 24000,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -60,19 +60,19 @@ SampleRate: 24000,
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestWireSize(t *testing.T) {
|
func TestWireSize(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
protoMsg proto.Message
|
protoMsg proto.Message
|
||||||
}{
|
}{
|
||||||
{"ChatRequest", chatRequestProto()},
|
{"ChatRequest", chatRequestProto()},
|
||||||
{"VoiceResponse", voiceResponseProto()},
|
{"VoiceResponse", voiceResponseProto()},
|
||||||
{"TTSAudioChunk", ttsChunkProto()},
|
{"TTSAudioChunk", ttsChunkProto()},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
protoBytes, _ := proto.Marshal(tt.protoMsg)
|
protoBytes, _ := proto.Marshal(tt.protoMsg)
|
||||||
t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
|
t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -80,27 +80,27 @@ t.Logf("%-16s proto=%5d B", tt.name, len(protoBytes))
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkEncode_ChatRequest(b *testing.B) {
|
func BenchmarkEncode_ChatRequest(b *testing.B) {
|
||||||
data := chatRequestProto()
|
data := chatRequestProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_, _ = proto.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkEncode_VoiceResponse(b *testing.B) {
|
func BenchmarkEncode_VoiceResponse(b *testing.B) {
|
||||||
data := voiceResponseProto()
|
data := voiceResponseProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_, _ = proto.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkEncode_TTSChunk(b *testing.B) {
|
func BenchmarkEncode_TTSChunk(b *testing.B) {
|
||||||
data := ttsChunkProto()
|
data := ttsChunkProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_, _ = proto.Marshal(data)
|
_, _ = proto.Marshal(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -108,30 +108,30 @@ _, _ = proto.Marshal(data)
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkDecode_ChatRequest(b *testing.B) {
|
func BenchmarkDecode_ChatRequest(b *testing.B) {
|
||||||
encoded, _ := proto.Marshal(chatRequestProto())
|
encoded, _ := proto.Marshal(chatRequestProto())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m pb.ChatRequest
|
var m pb.ChatRequest
|
||||||
_ = proto.Unmarshal(encoded, &m)
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkDecode_VoiceResponse(b *testing.B) {
|
func BenchmarkDecode_VoiceResponse(b *testing.B) {
|
||||||
encoded, _ := proto.Marshal(voiceResponseProto())
|
encoded, _ := proto.Marshal(voiceResponseProto())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m pb.VoiceResponse
|
var m pb.VoiceResponse
|
||||||
_ = proto.Unmarshal(encoded, &m)
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkDecode_TTSChunk(b *testing.B) {
|
func BenchmarkDecode_TTSChunk(b *testing.B) {
|
||||||
encoded, _ := proto.Marshal(ttsChunkProto())
|
encoded, _ := proto.Marshal(ttsChunkProto())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
var m pb.TTSAudioChunk
|
var m pb.TTSAudioChunk
|
||||||
_ = proto.Unmarshal(encoded, &m)
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -139,13 +139,13 @@ _ = proto.Unmarshal(encoded, &m)
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func BenchmarkRoundtrip_ChatRequest(b *testing.B) {
|
func BenchmarkRoundtrip_ChatRequest(b *testing.B) {
|
||||||
data := chatRequestProto()
|
data := chatRequestProto()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
enc, _ := proto.Marshal(data)
|
enc, _ := proto.Marshal(data)
|
||||||
var dec pb.ChatRequest
|
var dec pb.ChatRequest
|
||||||
_ = proto.Unmarshal(enc, &dec)
|
_ = proto.Unmarshal(enc, &dec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -153,160 +153,160 @@ _ = proto.Unmarshal(enc, &dec)
|
|||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
func TestRoundtrip_ChatRequest(t *testing.T) {
|
func TestRoundtrip_ChatRequest(t *testing.T) {
|
||||||
orig := chatRequestProto()
|
orig := chatRequestProto()
|
||||||
data, err := proto.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec pb.ChatRequest
|
var dec pb.ChatRequest
|
||||||
if err := proto.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.GetRequestId() != orig.GetRequestId() {
|
if dec.GetRequestId() != orig.GetRequestId() {
|
||||||
t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId())
|
t.Errorf("RequestId = %q, want %q", dec.GetRequestId(), orig.GetRequestId())
|
||||||
}
|
}
|
||||||
if dec.GetMessage() != orig.GetMessage() {
|
if dec.GetMessage() != orig.GetMessage() {
|
||||||
t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage())
|
t.Errorf("Message = %q, want %q", dec.GetMessage(), orig.GetMessage())
|
||||||
}
|
}
|
||||||
if dec.GetTopK() != orig.GetTopK() {
|
if dec.GetTopK() != orig.GetTopK() {
|
||||||
t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK())
|
t.Errorf("TopK = %d, want %d", dec.GetTopK(), orig.GetTopK())
|
||||||
}
|
}
|
||||||
if dec.GetPremium() != orig.GetPremium() {
|
if dec.GetPremium() != orig.GetPremium() {
|
||||||
t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium())
|
t.Errorf("Premium = %v, want %v", dec.GetPremium(), orig.GetPremium())
|
||||||
}
|
}
|
||||||
if EffectiveQuery(&dec) != orig.GetMessage() {
|
if EffectiveQuery(&dec) != orig.GetMessage() {
|
||||||
t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage())
|
t.Errorf("EffectiveQuery() = %q, want %q", EffectiveQuery(&dec), orig.GetMessage())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_VoiceResponse(t *testing.T) {
|
func TestRoundtrip_VoiceResponse(t *testing.T) {
|
||||||
orig := voiceResponseProto()
|
orig := voiceResponseProto()
|
||||||
data, err := proto.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec pb.VoiceResponse
|
var dec pb.VoiceResponse
|
||||||
if err := proto.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.GetRequestId() != orig.GetRequestId() {
|
if dec.GetRequestId() != orig.GetRequestId() {
|
||||||
t.Errorf("RequestId mismatch")
|
t.Errorf("RequestId mismatch")
|
||||||
}
|
}
|
||||||
if len(dec.GetAudio()) != len(orig.GetAudio()) {
|
if len(dec.GetAudio()) != len(orig.GetAudio()) {
|
||||||
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
|
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
|
||||||
}
|
}
|
||||||
if dec.GetTranscription() != orig.GetTranscription() {
|
if dec.GetTranscription() != orig.GetTranscription() {
|
||||||
t.Errorf("Transcription mismatch")
|
t.Errorf("Transcription mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
|
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
|
||||||
orig := ttsChunkProto()
|
orig := ttsChunkProto()
|
||||||
data, err := proto.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec pb.TTSAudioChunk
|
var dec pb.TTSAudioChunk
|
||||||
if err := proto.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.GetSessionId() != orig.GetSessionId() {
|
if dec.GetSessionId() != orig.GetSessionId() {
|
||||||
t.Errorf("SessionId mismatch")
|
t.Errorf("SessionId mismatch")
|
||||||
}
|
}
|
||||||
if dec.GetChunkIndex() != orig.GetChunkIndex() {
|
if dec.GetChunkIndex() != orig.GetChunkIndex() {
|
||||||
t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex())
|
t.Errorf("ChunkIndex = %d, want %d", dec.GetChunkIndex(), orig.GetChunkIndex())
|
||||||
}
|
}
|
||||||
if len(dec.GetAudio()) != len(orig.GetAudio()) {
|
if len(dec.GetAudio()) != len(orig.GetAudio()) {
|
||||||
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
|
t.Errorf("Audio len = %d, want %d", len(dec.GetAudio()), len(orig.GetAudio()))
|
||||||
}
|
}
|
||||||
if dec.GetSampleRate() != orig.GetSampleRate() {
|
if dec.GetSampleRate() != orig.GetSampleRate() {
|
||||||
t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate())
|
t.Errorf("SampleRate = %d, want %d", dec.GetSampleRate(), orig.GetSampleRate())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_PipelineTrigger(t *testing.T) {
|
func TestRoundtrip_PipelineTrigger(t *testing.T) {
|
||||||
orig := &pb.PipelineTrigger{
|
orig := &pb.PipelineTrigger{
|
||||||
RequestId: "pip-001",
|
RequestId: "pip-001",
|
||||||
Pipeline: "document-ingestion",
|
Pipeline: "document-ingestion",
|
||||||
Parameters: map[string]string{"source": "s3://bucket/data"},
|
Parameters: map[string]string{"source": "s3://bucket/data"},
|
||||||
}
|
}
|
||||||
data, err := proto.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec pb.PipelineTrigger
|
var dec pb.PipelineTrigger
|
||||||
if err := proto.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.GetPipeline() != orig.GetPipeline() {
|
if dec.GetPipeline() != orig.GetPipeline() {
|
||||||
t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline())
|
t.Errorf("Pipeline = %q, want %q", dec.GetPipeline(), orig.GetPipeline())
|
||||||
}
|
}
|
||||||
if dec.GetParameters()["source"] != orig.GetParameters()["source"] {
|
if dec.GetParameters()["source"] != orig.GetParameters()["source"] {
|
||||||
t.Errorf("Parameters[source] mismatch")
|
t.Errorf("Parameters[source] mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_STTTranscription(t *testing.T) {
|
func TestRoundtrip_STTTranscription(t *testing.T) {
|
||||||
orig := &pb.STTTranscription{
|
orig := &pb.STTTranscription{
|
||||||
SessionId: "stt-001",
|
SessionId: "stt-001",
|
||||||
Transcript: "hello world",
|
Transcript: "hello world",
|
||||||
Sequence: 5,
|
Sequence: 5,
|
||||||
IsPartial: false,
|
IsPartial: false,
|
||||||
IsFinal: true,
|
IsFinal: true,
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
SpeakerId: "speaker-1",
|
SpeakerId: "speaker-1",
|
||||||
HasVoiceActivity: true,
|
HasVoiceActivity: true,
|
||||||
State: "listening",
|
State: "listening",
|
||||||
}
|
}
|
||||||
data, err := proto.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec pb.STTTranscription
|
var dec pb.STTTranscription
|
||||||
if err := proto.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if dec.GetTranscript() != orig.GetTranscript() {
|
if dec.GetTranscript() != orig.GetTranscript() {
|
||||||
t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript())
|
t.Errorf("Transcript = %q, want %q", dec.GetTranscript(), orig.GetTranscript())
|
||||||
}
|
}
|
||||||
if dec.GetIsFinal() != orig.GetIsFinal() {
|
if dec.GetIsFinal() != orig.GetIsFinal() {
|
||||||
t.Error("IsFinal mismatch")
|
t.Error("IsFinal mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoundtrip_ErrorResponse(t *testing.T) {
|
func TestRoundtrip_ErrorResponse(t *testing.T) {
|
||||||
orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
|
orig := &pb.ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
|
||||||
data, err := proto.Marshal(orig)
|
data, err := proto.Marshal(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
var dec pb.ErrorResponse
|
var dec pb.ErrorResponse
|
||||||
if err := proto.Unmarshal(data, &dec); err != nil {
|
if err := proto.Unmarshal(data, &dec); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" {
|
if !dec.GetError() || dec.GetMessage() != "something broke" || dec.GetType() != "InternalError" {
|
||||||
t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec)
|
t.Errorf("ErrorResponse roundtrip mismatch: %+v", &dec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEffectiveQuery_MessageSet(t *testing.T) {
|
func TestEffectiveQuery_MessageSet(t *testing.T) {
|
||||||
req := &pb.ChatRequest{Message: "hello", Query: "world"}
|
req := &pb.ChatRequest{Message: "hello", Query: "world"}
|
||||||
if got := EffectiveQuery(req); got != "hello" {
|
if got := EffectiveQuery(req); got != "hello" {
|
||||||
t.Errorf("EffectiveQuery() = %q, want %q", got, "hello")
|
t.Errorf("EffectiveQuery() = %q, want %q", got, "hello")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEffectiveQuery_FallbackToQuery(t *testing.T) {
|
func TestEffectiveQuery_FallbackToQuery(t *testing.T) {
|
||||||
req := &pb.ChatRequest{Query: "world"}
|
req := &pb.ChatRequest{Query: "world"}
|
||||||
if got := EffectiveQuery(req); got != "world" {
|
if got := EffectiveQuery(req); got != "world" {
|
||||||
t.Errorf("EffectiveQuery() = %q, want %q", got, "world")
|
t.Errorf("EffectiveQuery() = %q, want %q", got, "world")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimestamp(t *testing.T) {
|
func TestTimestamp(t *testing.T) {
|
||||||
ts := Timestamp()
|
ts := Timestamp()
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
if ts < now-1 || ts > now+1 {
|
if ts < now-1 || ts > now+1 {
|
||||||
t.Errorf("Timestamp() = %d, expected ~%d", ts, now)
|
t.Errorf("Timestamp() = %d, expected ~%d", ts, now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,9 +8,9 @@
|
|||||||
package messages
|
package messages
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/gen/messagespb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ════════════════════════════════════════════════════════════════════════════
|
// ════════════════════════════════════════════════════════════════════════════
|
||||||
@@ -57,13 +57,13 @@ type PipelineStatus = pb.PipelineStatus
|
|||||||
|
|
||||||
// EffectiveQuery returns Message or falls back to Query.
|
// EffectiveQuery returns Message or falls back to Query.
|
||||||
func EffectiveQuery(c *ChatRequest) string {
|
func EffectiveQuery(c *ChatRequest) string {
|
||||||
if c.GetMessage() != "" {
|
if c.GetMessage() != "" {
|
||||||
return c.GetMessage()
|
return c.GetMessage()
|
||||||
}
|
}
|
||||||
return c.GetQuery()
|
return c.GetQuery()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timestamp returns the current Unix timestamp.
|
// Timestamp returns the current Unix timestamp.
|
||||||
func Timestamp() int64 {
|
func Timestamp() int64 {
|
||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,69 +2,69 @@
|
|||||||
package natsutil
|
package natsutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client wraps a NATS connection with protobuf helpers.
|
// Client wraps a NATS connection with protobuf helpers.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
nc *nats.Conn
|
nc *nats.Conn
|
||||||
js nats.JetStreamContext
|
js nats.JetStreamContext
|
||||||
subs []*nats.Subscription
|
subs []*nats.Subscription
|
||||||
url string
|
url string
|
||||||
opts []nats.Option
|
opts []nats.Option
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a NATS client configured to connect to the given URL.
|
// New creates a NATS client configured to connect to the given URL.
|
||||||
// Optional NATS options (e.g. credentials) can be appended.
|
// Optional NATS options (e.g. credentials) can be appended.
|
||||||
func New(url string, opts ...nats.Option) *Client {
|
func New(url string, opts ...nats.Option) *Client {
|
||||||
defaults := []nats.Option{
|
defaults := []nats.Option{
|
||||||
nats.ReconnectWait(2 * time.Second),
|
nats.ReconnectWait(2 * time.Second),
|
||||||
nats.MaxReconnects(-1),
|
nats.MaxReconnects(-1),
|
||||||
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
||||||
slog.Warn("NATS disconnected", "error", err)
|
slog.Warn("NATS disconnected", "error", err)
|
||||||
}),
|
}),
|
||||||
nats.ReconnectHandler(func(_ *nats.Conn) {
|
nats.ReconnectHandler(func(_ *nats.Conn) {
|
||||||
slog.Info("NATS reconnected")
|
slog.Info("NATS reconnected")
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
return &Client{
|
return &Client{
|
||||||
url: url,
|
url: url,
|
||||||
opts: append(defaults, opts...),
|
opts: append(defaults, opts...),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes the NATS connection and JetStream context.
|
// Connect establishes the NATS connection and JetStream context.
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
nc, err := nats.Connect(c.url, c.opts...)
|
nc, err := nats.Connect(c.url, c.opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nats connect: %w", err)
|
return fmt.Errorf("nats connect: %w", err)
|
||||||
}
|
}
|
||||||
js, err := nc.JetStream()
|
js, err := nc.JetStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nc.Close()
|
nc.Close()
|
||||||
return fmt.Errorf("jetstream: %w", err)
|
return fmt.Errorf("jetstream: %w", err)
|
||||||
}
|
}
|
||||||
c.nc = nc
|
c.nc = nc
|
||||||
c.js = js
|
c.js = js
|
||||||
slog.Info("connected to NATS", "url", c.url)
|
slog.Info("connected to NATS", "url", c.url)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close drains subscriptions and closes the connection.
|
// Close drains subscriptions and closes the connection.
|
||||||
func (c *Client) Close() {
|
func (c *Client) Close() {
|
||||||
if c.nc == nil {
|
if c.nc == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, sub := range c.subs {
|
for _, sub := range c.subs {
|
||||||
_ = sub.Drain()
|
_ = sub.Drain()
|
||||||
}
|
}
|
||||||
c.nc.Close()
|
c.nc.Close()
|
||||||
slog.Info("NATS connection closed")
|
slog.Info("NATS connection closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn returns the underlying *nats.Conn.
|
// Conn returns the underlying *nats.Conn.
|
||||||
@@ -75,56 +75,56 @@ func (c *Client) JS() nats.JetStreamContext { return c.js }
|
|||||||
|
|
||||||
// IsConnected returns true if the NATS connection is active.
|
// IsConnected returns true if the NATS connection is active.
|
||||||
func (c *Client) IsConnected() bool {
|
func (c *Client) IsConnected() bool {
|
||||||
return c.nc != nil && c.nc.IsConnected()
|
return c.nc != nil && c.nc.IsConnected()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe subscribes to a subject with an optional queue group.
|
// Subscribe subscribes to a subject with an optional queue group.
|
||||||
// The handler receives the raw *nats.Msg.
|
// The handler receives the raw *nats.Msg.
|
||||||
func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error {
|
func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error {
|
||||||
var sub *nats.Subscription
|
var sub *nats.Subscription
|
||||||
var err error
|
var err error
|
||||||
if queue != "" {
|
if queue != "" {
|
||||||
sub, err = c.nc.QueueSubscribe(subject, queue, handler)
|
sub, err = c.nc.QueueSubscribe(subject, queue, handler)
|
||||||
slog.Info("subscribed", "subject", subject, "queue", queue)
|
slog.Info("subscribed", "subject", subject, "queue", queue)
|
||||||
} else {
|
} else {
|
||||||
sub, err = c.nc.Subscribe(subject, handler)
|
sub, err = c.nc.Subscribe(subject, handler)
|
||||||
slog.Info("subscribed", "subject", subject)
|
slog.Info("subscribed", "subject", subject)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("subscribe %s: %w", subject, err)
|
return fmt.Errorf("subscribe %s: %w", subject, err)
|
||||||
}
|
}
|
||||||
c.subs = append(c.subs, sub)
|
c.subs = append(c.subs, sub)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Publish encodes data as protobuf and publishes to the subject.
|
// Publish encodes data as protobuf and publishes to the subject.
|
||||||
func (c *Client) Publish(subject string, data proto.Message) error {
|
func (c *Client) Publish(subject string, data proto.Message) error {
|
||||||
payload, err := proto.Marshal(data)
|
payload, err := proto.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("proto marshal: %w", err)
|
return fmt.Errorf("proto marshal: %w", err)
|
||||||
}
|
}
|
||||||
return c.nc.Publish(subject, payload)
|
return c.nc.Publish(subject, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PublishRaw publishes pre-encoded bytes to the subject.
|
// PublishRaw publishes pre-encoded bytes to the subject.
|
||||||
func (c *Client) PublishRaw(subject string, data []byte) error {
|
func (c *Client) PublishRaw(subject string, data []byte) error {
|
||||||
return c.nc.Publish(subject, data)
|
return c.nc.Publish(subject, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request sends a protobuf-encoded request and decodes the response into result.
|
// Request sends a protobuf-encoded request and decodes the response into result.
|
||||||
func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error {
|
func (c *Client) Request(subject string, data proto.Message, result proto.Message, timeout time.Duration) error {
|
||||||
payload, err := proto.Marshal(data)
|
payload, err := proto.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("proto marshal: %w", err)
|
return fmt.Errorf("proto marshal: %w", err)
|
||||||
}
|
}
|
||||||
msg, err := c.nc.Request(subject, payload, timeout)
|
msg, err := c.nc.Request(subject, payload, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nats request: %w", err)
|
return fmt.Errorf("nats request: %w", err)
|
||||||
}
|
}
|
||||||
return proto.Unmarshal(msg.Data, result)
|
return proto.Unmarshal(msg.Data, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode unmarshals protobuf bytes into dest.
|
// Decode unmarshals protobuf bytes into dest.
|
||||||
func Decode(data []byte, dest proto.Message) error {
|
func Decode(data []byte, dest proto.Message) error {
|
||||||
return proto.Unmarshal(data, dest)
|
return proto.Unmarshal(data, dest)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user