refactor: rewrite handler-base as Go module

Replace Python handler-base library with Go module providing:
- config: environment-based configuration
- health: HTTP health/readiness server for k8s probes
- natsutil: NATS/JetStream client with msgpack serialization
- telemetry: OpenTelemetry tracing and metrics setup
- clients: HTTP clients for LLM, embeddings, reranker, STT, TTS
- handler: base Handler runner wiring NATS + health + telemetry

Implements ADR-0061 Phase 1.
2026-02-19 17:16:17 -05:00
parent 5eb2c43a5d
commit d321c9852b
38 changed files with 1345 additions and 6971 deletions

.gitignore

@@ -1,23 +1,20 @@
# Python
__pycache__/
*.py[cod]
*$py.class
# Go
*.exe
*.dll
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.dylib
*.test
*.out
vendor/
# IDE
.idea/
.vscode/
*.swp
# OS
.DS_Store
Thumbs.db
*.egg
# Virtual environments

.pre-commit-config.yaml

@@ -1,32 +0,0 @@
# Pre-commit hooks for handler-base
# Install: pip install pre-commit && pre-commit install
# Run: pre-commit run --all-files
repos:
# Ruff - fast Python linter and formatter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
# Standard pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
args: [--maxkb=500]
- id: check-merge-conflict
- id: detect-private-key
# Type checking (optional - uncomment when ready)
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.10.0
# hooks:
# - id: mypy
# additional_dependencies: [types-all]
# args: [--ignore-missing-imports]

Dockerfile

@@ -1,58 +0,0 @@
# Handler Base Image
#
# Provides a pre-built base with all common dependencies for handler services.
# Services extend this image and add their specific code.
#
# Build:
# docker build -t ghcr.io/billy-davies-2/handler-base:latest .
#
# Usage in child Dockerfile:
# FROM ghcr.io/billy-davies-2/handler-base:latest
# COPY my_handler.py .
# CMD ["python", "my_handler.py"]
FROM python:3.13-slim AS base
WORKDIR /app
# Install uv for fast, reliable package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy and install handler-base package
COPY pyproject.toml README.md ./
COPY handler_base/ ./handler_base/
# Install the package with all dependencies
RUN uv pip install --system --no-cache .
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Default health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:8080/health || exit 1
# Default command (override in child images)
CMD ["python", "-c", "print('handler-base ready')"]
# Audio variant with soundfile, librosa, webrtcvad
FROM base AS audio
RUN apt-get update && apt-get install -y --no-install-recommends \
ffmpeg \
libsndfile1 \
gcc \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
RUN uv pip install --system --no-cache \
soundfile>=0.12.0 \
librosa>=0.10.0 \
webrtcvad>=2.0.10

README.md

@@ -1,109 +1,44 @@
# Handler Base
# handler-base
Shared base library for building NATS-based AI/ML handler services.
Go module providing shared infrastructure for NATS-based handler services.
## Installation
## Packages
```bash
pip install handler-base
```
| Package | Purpose |
|---------|---------|
| `config` | Environment-based configuration via struct fields |
| `health` | HTTP health/readiness server for Kubernetes probes |
| `natsutil` | NATS/JetStream client with msgpack serialization |
| `telemetry` | OpenTelemetry tracing and metrics setup |
| `clients` | HTTP clients for LLM, embeddings, reranker, STT, TTS |
| `handler` | Base Handler runner wiring NATS + health + telemetry |
Or from Gitea:
```bash
pip install git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git
```
## Usage
## Quick Start
```go
package main
```python
from handler_base import Handler, Settings
from nats.aio.msg import Msg
class MyHandler(Handler):
async def setup(self):
# Initialize your clients
pass
async def handle_message(self, msg: Msg, data: dict):
# Process the message
result = {"processed": True}
return result
if __name__ == "__main__":
MyHandler(subject="my.subject").run()
```
## Features
- **Handler base class** - NATS subscription, graceful shutdown, signal handling
- **NATSClient** - Connection management, JetStream, msgpack serialization
- **Settings** - Pydantic-based configuration from environment
- **HealthServer** - Kubernetes liveness/readiness probes
- **Telemetry** - OpenTelemetry tracing and metrics
- **Service clients** - HTTP wrappers for AI services
## Service Clients
```python
from handler_base.clients import (
STTClient, # Whisper speech-to-text
TTSClient, # XTTS text-to-speech
LLMClient, # vLLM chat completions
EmbeddingsClient, # BGE embeddings
RerankerClient, # BGE reranker
MilvusClient, # Vector database
import (
"context"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
"github.com/nats-io/nats.go"
)
func main() {
cfg := config.Load()
cfg.ServiceName = "my-service"
h := handler.New("my.subject", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
return map[string]any{"ok": true}, nil
})
h.Run()
}
```
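
The handler also exposes `OnSetup` and `OnTeardown` hooks that run once before message processing starts and during graceful shutdown. A minimal sketch extending the Quick Start `main` above (the `emb` client and its failure message are illustrative; `cfg` and `h` are as defined there):

```go
emb := clients.NewEmbeddingsClient(cfg.EmbeddingsURL, cfg.HTTPTimeout, "bge")

h.OnSetup(func(ctx context.Context) error {
    // Fail fast if a required backend is unreachable.
    if !emb.Health(ctx) {
        return fmt.Errorf("embeddings service unreachable at %s", cfg.EmbeddingsURL)
    }
    return nil
})
h.OnTeardown(func(ctx context.Context) error {
    // Release anything acquired in OnSetup.
    return nil
})
```

This additionally imports `fmt` and the module's `clients` package.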
## Configuration
## Testing
All settings via environment variables:
| Variable | Default | Description |
|----------|---------|-------------|
| `NATS_URL` | `nats://localhost:4222` | NATS server URL |
| `NATS_USER` | - | NATS username |
| `NATS_PASSWORD` | - | NATS password |
| `NATS_QUEUE_GROUP` | - | Queue group for load balancing |
| `HEALTH_PORT` | `8080` | Health check server port |
| `OTEL_ENABLED` | `true` | Enable OpenTelemetry |
| `OTEL_EXPORTER_OTLP_ENDPOINT` | `http://localhost:4317` | OTLP endpoint |
| `OTEL_SERVICE_NAME` | `handler` | Service name for traces |
## Docker
```dockerfile
FROM ghcr.io/daviestechlabs/handler-base:latest
COPY my_handler.py /app/
CMD ["python", "/app/my_handler.py"]
```
Or build with audio support:
```bash
docker build --build-arg INSTALL_AUDIO=true -t my-handler .
go test ./...
```
## Module Structure
```
handler_base/
├── __init__.py # Public API exports
├── handler.py # Base Handler class
├── nats_client.py # NATS connection wrapper
├── config.py # Pydantic Settings
├── health.py # Health check server
├── telemetry.py # OpenTelemetry setup
└── clients/
├── embeddings.py
├── llm.py
├── milvus.py
├── reranker.py
├── stt.py
└── tts.py
```
## Related
- [voice-assistant](https://git.daviestechlabs.io/daviestechlabs/voice-assistant) - Voice pipeline using handler-base
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs

clients/clients.go

@@ -0,0 +1,389 @@
// Package clients provides HTTP client wrappers for AI/ML backend services.
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"time"
)
// httpClient is a shared interface for all service clients.
type httpClient struct {
client *http.Client
baseURL string
}
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
return &httpClient{
client: &http.Client{Timeout: timeout},
baseURL: baseURL,
}
}
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, bytes.NewReader(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return h.do(req)
}
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
u := h.baseURL + path
if len(params) > 0 {
u += "?" + params.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, err
}
return h.do(req)
}
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
return h.get(ctx, path, params)
}
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
var buf bytes.Buffer
w := multipart.NewWriter(&buf)
part, err := w.CreateFormFile(fieldName, fileName)
if err != nil {
return nil, err
}
if _, err := part.Write(fileData); err != nil {
return nil, err
}
for k, v := range fields {
_ = w.WriteField(k, v)
}
_ = w.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, &buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", w.FormDataContentType())
return h.do(req)
}
func (h *httpClient) do(req *http.Request) ([]byte, error) {
resp, err := h.client.Do(req)
if err != nil {
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
}
return body, nil
}
func (h *httpClient) healthCheck(ctx context.Context) bool {
	_, err := h.get(ctx, "/health", nil)
	return err == nil
}
// --- Embeddings Client ---
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
type EmbeddingsClient struct {
*httpClient
Model string
}
// NewEmbeddingsClient creates an embeddings client.
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
if model == "" {
model = "bge"
}
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
}
// Embed generates embeddings for a list of texts.
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
"input": texts,
"model": c.Model,
})
if err != nil {
return nil, err
}
var resp struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
result := make([][]float64, len(resp.Data))
for i, d := range resp.Data {
result[i] = d.Embedding
}
return result, nil
}
// EmbedSingle generates an embedding for a single text.
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
results, err := c.Embed(ctx, []string{text})
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, fmt.Errorf("empty embedding result")
}
return results[0], nil
}
// Health checks if the embeddings service is healthy.
func (c *EmbeddingsClient) Health(ctx context.Context) bool {
return c.healthCheck(ctx)
}
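// Example usage (illustrative endpoint; in practice the URL and timeout
// come from config.Settings):
//
//	ctx := context.Background()
//	emb := NewEmbeddingsClient("http://embeddings:80", 30*time.Second, "bge")
//	vecs, err := emb.Embed(ctx, []string{"hello", "world"})
//	// vecs[0] is the embedding vector for "hello"; err carries HTTP failures.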
// --- Reranker Client ---
// RerankerClient calls the reranker service (BGE Reranker).
type RerankerClient struct {
*httpClient
}
// NewRerankerClient creates a reranker client.
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
}
// RerankResult represents a reranked document.
type RerankResult struct {
Index int `json:"index"`
Score float64 `json:"score"`
Document string `json:"document"`
}
// Rerank reranks documents by relevance to the query.
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
payload := map[string]any{
"query": query,
"documents": documents,
}
if topK > 0 {
payload["top_n"] = topK
}
body, err := c.postJSON(ctx, "/rerank", payload)
if err != nil {
return nil, err
}
var resp struct {
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
results := make([]RerankResult, len(resp.Results))
for i, r := range resp.Results {
score := r.RelevanceScore
if score == 0 {
score = r.Score
}
doc := ""
if r.Index < len(documents) {
doc = documents[r.Index]
}
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
}
return results, nil
}
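// Example usage (illustrative endpoint): rerank retrieved passages and keep
// the top three.
//
//	rr := NewRerankerClient("http://reranker:80", 30*time.Second)
//	hits, err := rr.Rerank(ctx, "capital of France", passages, 3)
//	// Each hit carries the original document text resolved from its index;
//	// ordering is whatever the service returns.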
// --- LLM Client ---
// LLMClient calls the vLLM-compatible LLM service.
type LLMClient struct {
*httpClient
Model string
MaxTokens int
Temperature float64
TopP float64
}
// NewLLMClient creates an LLM client.
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
return &LLMClient{
httpClient: newHTTPClient(baseURL, timeout),
Model: "default",
MaxTokens: 2048,
Temperature: 0.7,
TopP: 0.9,
}
}
// ChatMessage is an OpenAI-compatible message.
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// Generate sends a chat completion request and returns the response text.
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
messages := buildMessages(prompt, context_, systemPrompt)
payload := map[string]any{
"model": c.Model,
"messages": messages,
"max_tokens": c.MaxTokens,
"temperature": c.Temperature,
"top_p": c.TopP,
}
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
if err != nil {
return "", err
}
var resp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in LLM response")
}
return resp.Choices[0].Message.Content, nil
}
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
var msgs []ChatMessage
if systemPrompt != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
} else if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
}
if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
} else {
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
}
return msgs
}
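// Example usage (illustrative endpoint and model name); the exported fields
// can be tuned after construction:
//
//	llm := NewLLMClient("http://vllm:80", 120*time.Second)
//	llm.Model = "my-model" // assumption: whatever name the server exposes
//	answer, err := llm.Generate(ctx,
//	    "What is the capital of France?",
//	    "France is a country in Europe. Its capital is Paris.",
//	    "") // empty system prompt selects the default RAG prompt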
// --- TTS Client ---
// TTSClient calls the TTS service (Coqui XTTS).
type TTSClient struct {
*httpClient
Language string
}
// NewTTSClient creates a TTS client.
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
if language == "" {
language = "en"
}
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
}
// Synthesize generates audio bytes from text.
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
if language == "" {
language = c.Language
}
params := url.Values{
"text": {text},
"language_id": {language},
}
if speaker != "" {
params.Set("speaker_id", speaker)
}
return c.getRaw(ctx, "/api/tts", params)
}
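// Example usage (illustrative endpoint): synthesize English speech; an empty
// language argument falls back to the client default.
//
//	tts := NewTTSClient("http://tts:80", 120*time.Second, "en")
//	wav, err := tts.Synthesize(ctx, "Hello world", "", "")
//	// wav holds the raw audio bytes returned by the service.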
// --- STT Client ---
// STTClient calls the Whisper STT service.
type STTClient struct {
*httpClient
Language string
Task string
}
// NewSTTClient creates an STT client.
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
}
// TranscribeResult holds transcription output.
type TranscribeResult struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
}
// Transcribe sends audio to Whisper and returns the transcription.
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
if language == "" {
language = c.Language
}
fields := map[string]string{
"response_format": "json",
}
if language != "" {
fields["language"] = language
}
endpoint := "/v1/audio/transcriptions"
if c.Task == "translate" {
endpoint = "/v1/audio/translations"
}
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
if err != nil {
return nil, err
}
var result TranscribeResult
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return &result, nil
}
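// Example usage (illustrative endpoint): transcribe audio read from disk;
// an empty language lets the service auto-detect.
//
//	stt := NewSTTClient("http://whisper:80", 180*time.Second)
//	audio, _ := os.ReadFile("sample.wav")
//	res, err := stt.Transcribe(ctx, audio, "")
//	// res.Text holds the transcription.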
// --- Milvus Client ---
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
// For the Go port we use the Milvus Go SDK.
type MilvusClient struct {
Host string
Port int
Collection string
connected bool
}
// NewMilvusClient creates a Milvus client.
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
return &MilvusClient{Host: host, Port: port, Collection: collection}
}
// SearchResult holds a single vector search hit.
type SearchResult struct {
ID int64 `json:"id"`
Distance float64 `json:"distance"`
Score float64 `json:"score"`
Fields map[string]any `json:"fields,omitempty"`
}

config/config.go

@@ -0,0 +1,145 @@
// Package config provides environment-based configuration for handler services.
package config
import (
"os"
"strconv"
"time"
)
// Settings holds base configuration for all handler services.
// Values are loaded from environment variables with sensible defaults.
type Settings struct {
// Service identification
ServiceName string
ServiceVersion string
ServiceNamespace string
DeploymentEnv string
// NATS configuration
NATSURL string
NATSUser string
NATSPassword string
NATSQueueGroup string
// Redis/Valkey configuration
RedisURL string
RedisPassword string
// Milvus configuration
MilvusHost string
MilvusPort int
MilvusCollection string
// Service endpoints
EmbeddingsURL string
RerankerURL string
LLMURL string
TTSURL string
STTURL string
// OpenTelemetry configuration
OTELEnabled bool
OTELEndpoint string
OTELUseHTTP bool
// HyperDX configuration
HyperDXEnabled bool
HyperDXAPIKey string
HyperDXEndpoint string
// MLflow configuration
MLflowTrackingURI string
MLflowExperimentName string
MLflowEnabled bool
// Health check configuration
HealthPort int
HealthPath string
ReadyPath string
// Timeouts
HTTPTimeout time.Duration
NATSTimeout time.Duration
}
// Load creates a Settings populated from environment variables with defaults.
func Load() *Settings {
return &Settings{
ServiceName: getEnv("SERVICE_NAME", "handler"),
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
NATSUser: getEnv("NATS_USER", ""),
NATSPassword: getEnv("NATS_PASSWORD", ""),
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
RedisPassword: getEnv("REDIS_PASSWORD", ""),
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
EmbeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
RerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
LLMURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
TTSURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
STTURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
HealthPort: getEnvInt("HEALTH_PORT", 8080),
HealthPath: getEnv("HEALTH_PATH", "/health"),
ReadyPath: getEnv("READY_PATH", "/ready"),
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
}
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
return fallback
}
func getEnvBool(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil {
return b
}
}
return fallback
}
func getEnvDuration(key string, fallback time.Duration) time.Duration {
if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return time.Duration(f * float64(time.Second))
}
}
return fallback
}
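// Note: duration variables such as HTTP_TIMEOUT and NATS_TIMEOUT are parsed
// as seconds (fractional values allowed), e.g. HTTP_TIMEOUT=90 or
// HTTP_TIMEOUT=0.5, not Go duration strings like "90s".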

config/config_test.go

@@ -0,0 +1,42 @@
package config
import (
"os"
"testing"
"time"
)
func TestLoadDefaults(t *testing.T) {
s := Load()
if s.ServiceName != "handler" {
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
}
if s.HealthPort != 8080 {
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
}
if s.HTTPTimeout != 60*time.Second {
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
}
}
func TestLoadFromEnv(t *testing.T) {
os.Setenv("SERVICE_NAME", "test-svc")
os.Setenv("HEALTH_PORT", "9090")
os.Setenv("OTEL_ENABLED", "false")
defer func() {
os.Unsetenv("SERVICE_NAME")
os.Unsetenv("HEALTH_PORT")
os.Unsetenv("OTEL_ENABLED")
}()
s := Load()
if s.ServiceName != "test-svc" {
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
}
if s.HealthPort != 9090 {
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
}
if s.OTELEnabled {
t.Error("expected OTELEnabled false")
}
}

go.mod

@@ -0,0 +1,39 @@
module git.daviestechlabs.io/daviestechlabs/handler-base
go 1.25.1
require (
github.com/nats-io/nats.go v1.48.0
github.com/vmihailenco/msgpack/v5 v5.4.1
go.opentelemetry.io/otel v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
go.opentelemetry.io/otel/metric v1.40.0
go.opentelemetry.io/otel/sdk v1.40.0
go.opentelemetry.io/otel/sdk/metric v1.40.0
go.opentelemetry.io/otel/trace v1.40.0
)
require (
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/nats-io/nkeys v0.4.11 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/net v0.49.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/grpc v1.78.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

go.sum

@@ -0,0 +1,77 @@
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 h1:NOyNnS19BF2SUDApbOKbDtWZ0IK7b8FJ2uAGdIWOGb0=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0/go.mod h1:VL6EgVikRLcJa9ftukrHu/ZkkhFBSo1lzvdBC9CF1ss=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs=
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

handler/handler.go

@@ -0,0 +1,186 @@
// Package handler provides the base Handler pattern for NATS message-driven services.
package handler
import (
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"syscall"
"github.com/nats-io/nats.go"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
)
// MessageHandler is the callback for processing decoded NATS messages.
// data is the msgpack-decoded map. Return a response map (or nil for no reply).
type MessageHandler func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error)
// SetupFunc is called once before the handler starts processing messages.
type SetupFunc func(ctx context.Context) error
// TeardownFunc is called during graceful shutdown.
type TeardownFunc func(ctx context.Context) error
// Handler is the base service runner that wires NATS, health, and telemetry.
type Handler struct {
Settings *config.Settings
NATS *natsutil.Client
Telemetry *telemetry.Provider
Subject string
QueueGroup string
onSetup SetupFunc
onTeardown TeardownFunc
onMessage MessageHandler
running bool
}
// New creates a Handler for the given NATS subject.
func New(subject string, settings *config.Settings) *Handler {
if settings == nil {
settings = config.Load()
}
queueGroup := settings.NATSQueueGroup
natsOpts := []nats.Option{}
if settings.NATSUser != "" && settings.NATSPassword != "" {
natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword))
}
return &Handler{
Settings: settings,
Subject: subject,
QueueGroup: queueGroup,
NATS: natsutil.New(settings.NATSURL, natsOpts...),
}
}
// OnSetup registers the setup callback.
func (h *Handler) OnSetup(fn SetupFunc) { h.onSetup = fn }
// OnTeardown registers the teardown callback.
func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn }
// OnMessage registers the message handler callback.
func (h *Handler) OnMessage(fn MessageHandler) { h.onMessage = fn }
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
func (h *Handler) Run() error {
// Structured logging
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Telemetry
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
ServiceName: h.Settings.ServiceName,
ServiceVersion: h.Settings.ServiceVersion,
ServiceNamespace: h.Settings.ServiceNamespace,
DeploymentEnv: h.Settings.DeploymentEnv,
Enabled: h.Settings.OTELEnabled,
Endpoint: h.Settings.OTELEndpoint,
})
if err != nil {
return fmt.Errorf("telemetry setup: %w", err)
}
defer shutdown(ctx)
h.Telemetry = tp
// Health server
healthSrv := health.New(
h.Settings.HealthPort,
h.Settings.HealthPath,
h.Settings.ReadyPath,
func() bool { return h.running && h.NATS.IsConnected() },
)
healthSrv.Start()
defer healthSrv.Stop(ctx)
// Connect to NATS
if err := h.NATS.Connect(); err != nil {
return fmt.Errorf("nats: %w", err)
}
defer h.NATS.Close()
// User setup
if h.onSetup != nil {
slog.Info("running service setup")
if err := h.onSetup(ctx); err != nil {
return fmt.Errorf("setup: %w", err)
}
}
// Subscribe
if h.onMessage == nil {
return fmt.Errorf("no message handler registered")
}
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
return fmt.Errorf("subscribe: %w", err)
}
h.running = true
slog.Info("handler ready", "subject", h.Subject)
// Wait for shutdown signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
<-sigCh
slog.Info("shutting down")
h.running = false
// Teardown
if h.onTeardown != nil {
if err := h.onTeardown(ctx); err != nil {
slog.Warn("teardown error", "error", err)
}
}
slog.Info("shutdown complete")
return nil
}
// wrapHandler creates a nats.MsgHandler that decodes msgpack and dispatches to the user handler.
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
return func(msg *nats.Msg) {
data, err := natsutil.DecodeMsgpackMap(msg.Data)
if err != nil {
slog.Error("failed to decode message", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": "DecodeError",
})
}
return
}
response, err := h.onMessage(ctx, msg, data)
if err != nil {
slog.Error("handler error", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": fmt.Sprintf("%T", err),
})
}
return
}
if response != nil && msg.Reply != "" {
if err := h.NATS.Publish(msg.Reply, response); err != nil {
slog.Error("failed to publish reply", "error", err)
}
}
}
}
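// Example: calling a handler from a peer service over NATS request-reply.
// A sketch assuming replies are msgpack-encoded maps, matching the decoding
// in wrapHandler; the subject and URL are illustrative, and msgpack is
// github.com/vmihailenco/msgpack/v5.
//
//	nc, _ := nats.Connect("nats://localhost:4222")
//	defer nc.Drain()
//	payload, _ := msgpack.Marshal(map[string]any{"text": "hello"})
//	msg, err := nc.Request("my.subject", payload, 5*time.Second)
//	if err == nil {
//	    var reply map[string]any
//	    _ = msgpack.Unmarshal(msg.Data, &reply)
//	    // reply["error"] == true signals a handler-side failure
//	}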

handler_base/__init__.py

@@ -1,28 +0,0 @@
"""
Handler Base - Shared utilities for AI/ML handler services.
Provides consistent patterns for:
- OpenTelemetry tracing and metrics
- NATS messaging
- Health checks
- Graceful shutdown
- Service client wrappers
"""
from handler_base.config import Settings
from handler_base.handler import Handler
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import get_meter, get_tracer, setup_telemetry
__all__ = [
"Handler",
"Settings",
"HealthServer",
"NATSClient",
"setup_telemetry",
"get_tracer",
"get_meter",
]
__version__ = "1.0.0"

handler_base/clients/__init__.py

@@ -1,19 +0,0 @@
"""
Service client wrappers for AI/ML backends.
"""
from handler_base.clients.embeddings import EmbeddingsClient
from handler_base.clients.llm import LLMClient
from handler_base.clients.milvus import MilvusClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.stt import STTClient
from handler_base.clients.tts import TTSClient
__all__ = [
"EmbeddingsClient",
"RerankerClient",
"LLMClient",
"TTSClient",
"STTClient",
"MilvusClient",
]

handler_base/clients/embeddings.py

@@ -1,128 +0,0 @@
"""
Embeddings service client (Infinity/BGE).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import EmbeddingsSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class EmbeddingsClient:
"""
Client for the embeddings service (Infinity with BGE models).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = EmbeddingsClient()
embeddings = await client.embed(["Hello world"])
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "EmbeddingsDeployment"
RAY_APP_NAME = "embeddings"
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
self.settings = settings or EmbeddingsSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.embeddings_url,
timeout=self.settings.http_timeout,
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def embed(
self,
texts: list[str],
model: Optional[str] = None,
) -> list[list[float]]:
"""
Generate embeddings for a list of texts.
Args:
texts: List of texts to embed
model: Model name (defaults to settings)
Returns:
List of embedding vectors
"""
model = model or self.settings.embeddings_model
with create_span("embeddings.embed") as span:
if span:
span.set_attribute("embeddings.model", model)
span.set_attribute("embeddings.batch_size", len(texts))
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("embeddings.transport", "ray")
result = await handle.embed.remote(texts, model)
if span and result:
span.set_attribute("embeddings.dimensions", len(result[0]) if result else 0)
return result
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("embeddings.transport", "http")
response = await self._client.post(
"/embeddings",
json={"input": texts, "model": model},
)
response.raise_for_status()
result = response.json()
embeddings = [d["embedding"] for d in result.get("data", [])]
if span:
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
return embeddings
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
"""
Generate embedding for a single text.
Args:
text: Text to embed
model: Model name (defaults to settings)
Returns:
Embedding vector
"""
embeddings = await self.embed([text], model)
return embeddings[0] if embeddings else []
async def health(self) -> bool:
"""Check if the embeddings service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

handler_base/clients/llm.py

@@ -1,233 +0,0 @@
"""
LLM service client (vLLM/OpenAI-compatible).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, AsyncIterator, Optional
import httpx
from handler_base.config import LLMSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class LLMClient:
"""
Client for the LLM service (vLLM with OpenAI-compatible API).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = LLMClient()
response = await client.generate("Hello, how are you?")
# With context for RAG
response = await client.generate(
"What is the capital?",
context="France is a country in Europe..."
)
# Streaming
async for chunk in client.stream("Tell me a story"):
print(chunk, end="")
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "VLLMDeployment"
RAY_APP_NAME = "llm"
def __init__(self, settings: Optional[LLMSettings] = None):
self.settings = settings or LLMSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.llm_url,
timeout=self.settings.http_timeout,
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def generate(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[list[str]] = None,
) -> str:
"""
Generate a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
Returns:
Generated text response
"""
with create_span("llm.generate") as span:
messages = self._build_messages(prompt, context, system_prompt)
if span:
span.set_attribute("llm.model", self.settings.llm_model)
span.set_attribute("llm.prompt_length", len(prompt))
if context:
span.set_attribute("llm.context_length", len(context))
payload = {
"model": self.settings.llm_model,
"messages": messages,
"max_tokens": max_tokens or self.settings.llm_max_tokens,
"temperature": temperature or self.settings.llm_temperature,
"top_p": top_p or self.settings.llm_top_p,
}
if stop:
payload["stop"] = stop
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("llm.transport", "ray")
result = await handle.remote(payload)
content = result["choices"][0]["message"]["content"]
if span:
span.set_attribute("llm.response_length", len(content))
return content
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("llm.transport", "http")
response = await self._client.post("/v1/chat/completions", json=payload)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
if span:
span.set_attribute("llm.response_length", len(content))
usage = result.get("usage", {})
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
return content
async def stream(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> AsyncIterator[str]:
"""
Stream a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
Yields:
Text chunks as they're generated
"""
messages = self._build_messages(prompt, context, system_prompt)
payload = {
"model": self.settings.llm_model,
"messages": messages,
"max_tokens": max_tokens or self.settings.llm_max_tokens,
"temperature": temperature or self.settings.llm_temperature,
"stream": True,
}
async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
import json
chunk = json.loads(data)
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
def _build_messages(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
) -> list[dict]:
"""Build the messages list for the API call."""
messages = []
# System prompt
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
elif context:
# Default RAG system prompt
messages.append(
{
"role": "system",
"content": (
"You are a helpful assistant. Use the provided context to answer "
"the user's question. If the context doesn't contain relevant "
"information, say so."
),
}
)
# Add context as a separate message if provided
if context:
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
}
)
else:
messages.append({"role": "user", "content": prompt})
return messages
async def health(self) -> bool:
"""Check if the LLM service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

handler_base/clients/milvus.py

@@ -1,183 +0,0 @@
"""
Milvus vector database client.
"""
import logging
from typing import Optional
from pymilvus import Collection, connections, utility
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class MilvusClient:
"""
Client for Milvus vector database.
Usage:
client = MilvusClient()
await client.connect()
results = await client.search(embedding, limit=10)
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._connected = False
self._collection: Optional[Collection] = None
async def connect(self, collection_name: Optional[str] = None) -> None:
"""
Connect to Milvus and load collection.
Args:
collection_name: Collection to use (defaults to settings)
"""
collection_name = collection_name or self.settings.milvus_collection
connections.connect(
alias="default",
host=self.settings.milvus_host,
port=self.settings.milvus_port,
)
if utility.has_collection(collection_name):
self._collection = Collection(collection_name)
self._collection.load()
logger.info(f"Connected to Milvus collection: {collection_name}")
else:
logger.warning(f"Collection {collection_name} does not exist")
self._connected = True
async def close(self) -> None:
"""Close Milvus connection."""
if self._collection:
self._collection.release()
connections.disconnect("default")
self._connected = False
logger.info("Disconnected from Milvus")
async def search(
self,
embedding: list[float],
limit: int = 10,
output_fields: Optional[list[str]] = None,
filter_expr: Optional[str] = None,
) -> list[dict]:
"""
Search for similar vectors.
Args:
embedding: Query embedding vector
limit: Maximum number of results
output_fields: Fields to return (default: all)
filter_expr: Optional filter expression
Returns:
List of results with 'id', 'distance', and requested fields
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.search") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.limit", limit)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
results = self._collection.search(
data=[embedding],
anns_field="embedding",
param=search_params,
limit=limit,
output_fields=output_fields,
expr=filter_expr,
)
# Convert to list of dicts
hits = []
for hit in results[0]:
item = {
"id": hit.id,
"distance": hit.distance,
"score": 1 - hit.distance, # Convert distance to similarity
}
# Add output fields
if output_fields:
for field in output_fields:
if hasattr(hit.entity, field):
item[field] = getattr(hit.entity, field)
hits.append(item)
if span:
span.set_attribute("milvus.results", len(hits))
return hits
async def search_with_texts(
self,
embedding: list[float],
limit: int = 10,
text_field: str = "text",
metadata_fields: Optional[list[str]] = None,
) -> list[dict]:
"""
Search and return text content with metadata.
Args:
embedding: Query embedding
limit: Maximum results
text_field: Name of text field in collection
metadata_fields: Additional metadata fields to return
Returns:
List of results with text and metadata
"""
output_fields = [text_field]
if metadata_fields:
output_fields.extend(metadata_fields)
return await self.search(embedding, limit, output_fields)
async def insert(
self,
embeddings: list[list[float]],
data: list[dict],
) -> list[int]:
"""
Insert vectors with data into the collection.
Args:
embeddings: List of embedding vectors
data: List of dicts with field values
Returns:
List of inserted IDs
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.insert") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.count", len(embeddings))
# Build insert data
insert_data = [embeddings]
for field in self._collection.schema.fields:
if field.name not in ("id", "embedding"):
field_values = [d.get(field.name) for d in data]
insert_data.append(field_values)
result = self._collection.insert(insert_data)
self._collection.flush()
return result.primary_keys
def health(self) -> bool:
"""Check if connected to Milvus."""
return self._connected and utility.get_connection_addr("default") is not None

handler_base/clients/reranker.py

@@ -1,168 +0,0 @@
"""
Reranker service client (Infinity/BGE Reranker).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import Settings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class RerankerClient:
"""
Client for the reranker service (Infinity with BGE Reranker).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = RerankerClient()
reranked = await client.rerank("query", ["doc1", "doc2"])
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "RerankerDeployment"
RAY_APP_NAME = "reranker"
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._client = httpx.AsyncClient(
base_url=self.settings.reranker_url,
timeout=self.settings.http_timeout,
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def rerank(
self,
query: str,
documents: list[str],
top_k: Optional[int] = None,
) -> list[dict]:
"""
Rerank documents based on relevance to query.
Args:
query: Query text
documents: List of documents to rerank
top_k: Number of top results to return (default: all)
Returns:
List of dicts with 'index', 'score', and 'document' keys,
sorted by relevance score descending.
"""
with create_span("reranker.rerank") as span:
if span:
span.set_attribute("reranker.num_documents", len(documents))
if top_k:
span.set_attribute("reranker.top_k", top_k)
payload = {
"query": query,
"documents": documents,
}
if top_k:
payload["top_n"] = top_k
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("reranker.transport", "ray")
results = await handle.rerank.remote(query, documents, top_k)
# Enrich with original documents
enriched = []
for r in results:
idx = r.get("index", 0)
enriched.append(
{
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
}
)
return enriched
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("reranker.transport", "http")
response = await self._client.post("/rerank", json=payload)
response.raise_for_status()
result = response.json()
results = result.get("results", [])
# Enrich with original documents
enriched = []
for r in results:
idx = r.get("index", 0)
enriched.append(
{
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
}
)
return enriched
async def rerank_with_metadata(
self,
query: str,
documents: list[dict],
text_key: str = "text",
top_k: Optional[int] = None,
) -> list[dict]:
"""
Rerank documents with metadata, preserving metadata in results.
Args:
query: Query text
documents: List of dicts with text and metadata
text_key: Key containing text in each document dict
top_k: Number of top results to return
Returns:
Reranked documents with original metadata preserved.
"""
texts = [d.get(text_key, "") for d in documents]
reranked = await self.rerank(query, texts, top_k)
# Merge back metadata
for r in reranked:
idx = r["index"]
if idx < len(documents):
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
return reranked
async def health(self) -> bool:
"""Check if the reranker service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

handler_base/clients/stt.py

@@ -1,170 +0,0 @@
"""
STT service client (Whisper/faster-whisper).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import STTSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class STTClient:
"""
Client for the STT service (Whisper/faster-whisper).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = STTClient()
text = await client.transcribe(audio_bytes)
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "WhisperDeployment"
RAY_APP_NAME = "whisper"
def __init__(self, settings: Optional[STTSettings] = None):
self.settings = settings or STTSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.stt_url,
timeout=180.0, # Transcription can be slow
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def transcribe(
self,
audio: bytes,
language: Optional[str] = None,
task: Optional[str] = None,
response_format: str = "json",
) -> dict:
"""
Transcribe audio to text.
Args:
audio: Audio bytes (WAV, MP3, etc.)
language: Language code (None for auto-detect)
task: "transcribe" or "translate"
response_format: "json", "text", "srt", "vtt"
Returns:
Dict with 'text', 'language', and optional 'segments'
"""
language = language or self.settings.stt_language
task = task or self.settings.stt_task
with create_span("stt.transcribe") as span:
if span:
span.set_attribute("stt.task", task)
span.set_attribute("stt.audio_size", len(audio))
if language:
span.set_attribute("stt.language", language)
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("stt.transport", "ray")
result = await handle.transcribe.remote(audio, language, task)
if span:
span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"):
span.set_attribute("stt.detected_language", result["language"])
return result
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("stt.transport", "http")
files = {"file": ("audio.wav", audio, "audio/wav")}
data = {
"response_format": response_format,
}
if language:
data["language"] = language
# Choose endpoint based on task
if task == "translate":
endpoint = "/v1/audio/translations"
else:
endpoint = "/v1/audio/transcriptions"
response = await self._client.post(endpoint, files=files, data=data)
response.raise_for_status()
if response_format == "text":
return {"text": response.text}
result = response.json()
if span:
span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"):
span.set_attribute("stt.detected_language", result["language"])
return result
async def transcribe_file(
self,
file_path: str,
language: Optional[str] = None,
task: Optional[str] = None,
) -> dict:
"""
Transcribe an audio file.
Args:
file_path: Path to audio file
language: Language code
task: "transcribe" or "translate"
Returns:
Transcription result
"""
with open(file_path, "rb") as f:
audio = f.read()
return await self.transcribe(audio, language, task)
async def translate(self, audio: bytes) -> dict:
"""
Translate audio to English.
Args:
audio: Audio bytes
Returns:
Translation result with 'text' key
"""
return await self.transcribe(audio, task="translate")
async def health(self) -> bool:
"""Check if the STT service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -1,149 +0,0 @@
"""
TTS service client (Coqui XTTS).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import TTSSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class TTSClient:
"""
Client for the TTS service (Coqui XTTS).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = TTSClient()
audio_bytes = await client.synthesize("Hello world")
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "TTSDeployment"
RAY_APP_NAME = "tts"
def __init__(self, settings: Optional[TTSSettings] = None):
self.settings = settings or TTSSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.tts_url,
timeout=120.0, # TTS can be slow
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def synthesize(
self,
text: str,
language: Optional[str] = None,
speaker: Optional[str] = None,
) -> bytes:
"""
Synthesize speech from text.
Args:
text: Text to synthesize
language: Language code (e.g., "en", "es", "fr")
speaker: Speaker ID or reference
Returns:
WAV audio bytes
"""
language = language or self.settings.tts_language
with create_span("tts.synthesize") as span:
if span:
span.set_attribute("tts.language", language)
span.set_attribute("tts.text_length", len(text))
params = {
"text": text,
"language_id": language,
}
if speaker:
params["speaker_id"] = speaker
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("tts.transport", "ray")
audio_bytes = await handle.synthesize.remote(text, language, speaker)
if span:
span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("tts.transport", "http")
response = await self._client.get("/api/tts", params=params)
response.raise_for_status()
audio_bytes = response.content
if span:
span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes
async def synthesize_to_file(
self,
text: str,
output_path: str,
language: Optional[str] = None,
speaker: Optional[str] = None,
) -> None:
"""
Synthesize speech and save to a file.
Args:
text: Text to synthesize
output_path: Path to save the audio file
language: Language code
speaker: Speaker ID
"""
audio_bytes = await self.synthesize(text, language, speaker)
with open(output_path, "wb") as f:
f.write(audio_bytes)
async def get_speakers(self) -> list[dict]:
"""Get available speakers/voices."""
try:
response = await self._client.get("/api/speakers")
response.raise_for_status()
return response.json()
except Exception:
return []
async def health(self) -> bool:
"""Check if the TTS service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

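The Go clients that replace these Python HTTP wrappers are not shown in this hunk. As a rough sketch of how the TTS call above maps onto net/http (the package name, endpoint, and query parameters are copied from the Python client and should be treated as assumptions):

// Sketch only: a Go counterpart to the Python TTSClient above.
// Endpoint and parameter names mirror the Python code, not a confirmed API.
package clients

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"
)

// TTSClient calls the TTS service over HTTP.
type TTSClient struct {
	baseURL string
	http    *http.Client
}

// NewTTSClient creates a client with a generous timeout, since TTS can be slow.
func NewTTSClient(baseURL string) *TTSClient {
	return &TTSClient{baseURL: baseURL, http: &http.Client{Timeout: 120 * time.Second}}
}

// Synthesize returns WAV audio bytes for the given text.
func (c *TTSClient) Synthesize(ctx context.Context, text, language string) ([]byte, error) {
	q := url.Values{"text": {text}, "language_id": {language}}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/tts?"+q.Encode(), nil)
	if err != nil {
		return nil, err
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("tts: unexpected status %d", resp.StatusCode)
	}
	return io.ReadAll(resp.Body)
}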
View File

@@ -1,101 +0,0 @@
"""
Configuration management using Pydantic Settings.
Environment variables are automatically loaded and validated.
"""
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Base settings for all handler services."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# Service identification
service_name: str = "handler"
service_version: str = "1.0.0"
service_namespace: str = "ai-ml"
deployment_env: str = "production"
# NATS configuration
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
nats_user: Optional[str] = None
nats_password: Optional[str] = None
nats_queue_group: Optional[str] = None
# Redis/Valkey configuration
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
redis_password: Optional[str] = None
# Milvus configuration
milvus_host: str = "milvus.ai-ml.svc.cluster.local"
milvus_port: int = 19530
milvus_collection: str = "documents"
# Service endpoints
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
# OpenTelemetry configuration
otel_enabled: bool = True
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
otel_use_http: bool = False
# HyperDX configuration
hyperdx_enabled: bool = False
hyperdx_api_key: Optional[str] = None
hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
# MLflow configuration
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
mlflow_experiment_name: Optional[str] = None
mlflow_enabled: bool = True
# Health check configuration
health_port: int = 8080
health_path: str = "/health"
ready_path: str = "/ready"
# Timeouts (seconds)
http_timeout: float = 60.0
nats_timeout: float = 30.0
class EmbeddingsSettings(Settings):
"""Settings for embeddings service client."""
embeddings_model: str = "bge"
embeddings_batch_size: int = 32
class LLMSettings(Settings):
"""Settings for LLM service client."""
llm_model: str = "default"
llm_max_tokens: int = 2048
llm_temperature: float = 0.7
llm_top_p: float = 0.9
class TTSSettings(Settings):
"""Settings for TTS service client."""
tts_language: str = "en"
tts_speaker: Optional[str] = None
class STTSettings(Settings):
"""Settings for STT service client."""
stt_language: Optional[str] = None # Auto-detect
stt_task: str = "transcribe" # or "translate"

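The Go config package that replaces this module is not part of this hunk. A minimal sketch of the environment-based equivalent, with variable names and defaults lifted from the Python Settings above (the exact field set is an assumption):

// Sketch only: environment-based configuration mirroring a few of the
// Python defaults above; the real config package likely exposes more fields.
package config

import "os"

// Settings holds shared configuration for handler services.
type Settings struct {
	ServiceName  string
	NATSURL      string
	HealthPort   string
	OTELEndpoint string
}

// getenv returns the value of key, or def when the variable is unset.
func getenv(key, def string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return def
}

// Load reads settings from the environment, falling back to defaults.
func Load() Settings {
	return Settings{
		ServiceName:  getenv("SERVICE_NAME", "handler"),
		NATSURL:      getenv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
		HealthPort:   getenv("HEALTH_PORT", "8080"),
		OTELEndpoint: getenv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
	}
}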
View File

@@ -1,222 +0,0 @@
"""
Base handler class for building NATS-based services.
"""
import asyncio
import logging
import signal
from abc import ABC, abstractmethod
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base.config import Settings
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import create_span, setup_telemetry
logger = logging.getLogger(__name__)
class Handler(ABC):
"""
Base class for NATS message handlers.
Subclass and implement:
- setup(): Initialize your service clients
- handle_message(): Process incoming messages
- teardown(): Clean up resources (optional)
Example:
class MyHandler(Handler):
async def setup(self):
self.embeddings = EmbeddingsClient()
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
result = await self.embeddings.embed(data["text"])
return {"embedding": result}
if __name__ == "__main__":
MyHandler(subject="my.subject").run()
"""
def __init__(
self,
subject: str,
settings: Optional[Settings] = None,
queue_group: Optional[str] = None,
):
"""
Initialize the handler.
Args:
subject: NATS subject to subscribe to
settings: Configuration settings
queue_group: Optional queue group for load balancing
"""
self.subject = subject
self.settings = settings or Settings()
self.queue_group = queue_group or self.settings.nats_queue_group
self.nats = NATSClient(self.settings)
self.health_server = HealthServer(self.settings, self._check_ready)
self._running = False
self._shutdown_event = asyncio.Event()
@abstractmethod
async def setup(self) -> None:
"""
Initialize service clients and resources.
Called once before starting to handle messages.
Override this to set up your service-specific clients.
"""
pass
@abstractmethod
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
"""
Handle an incoming message.
Args:
msg: Raw NATS message
data: Decoded message data (msgpack unpacked)
Returns:
Optional response data. If returned and msg has a reply subject,
the response will be sent automatically.
"""
pass
async def teardown(self) -> None:
"""
Clean up resources.
Called during graceful shutdown.
Override to add custom cleanup logic.
"""
pass
async def _check_ready(self) -> bool:
"""Check if the service is ready to handle requests."""
return self._running and self.nats._nc is not None
async def _message_handler(self, msg: Msg) -> None:
"""Internal message handler with tracing and error handling."""
with create_span(f"handle.{self.subject}") as span:
try:
# Decode message
data = NATSClient.decode_msgpack(msg)
if span:
span.set_attribute("messaging.destination", msg.subject)
if isinstance(data, dict):
request_id = data.get("request_id", data.get("id"))
if request_id:
span.set_attribute("request.id", str(request_id))
# Handle message
response = await self.handle_message(msg, data)
# Send response if applicable
if response is not None and msg.reply:
await self.nats.publish(msg.reply, response)
except Exception as e:
logger.exception(f"Error handling message on {msg.subject}")
if span:
span.set_attribute("error", True)
span.set_attribute("error.message", str(e))
# Send error response if reply expected
if msg.reply:
error_response = {
"error": True,
"message": str(e),
"type": type(e).__name__,
}
await self.nats.publish(msg.reply, error_response)
def _setup_signals(self) -> None:
"""Set up signal handlers for graceful shutdown."""
loop = asyncio.get_running_loop()  # get_event_loop() is deprecated inside coroutines
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, self._handle_signal, sig)
def _handle_signal(self, sig: signal.Signals) -> None:
"""Handle shutdown signal."""
logger.info(f"Received {sig.name}, initiating graceful shutdown...")
self._shutdown_event.set()
async def _run(self) -> None:
"""Main async run loop."""
# Setup telemetry
setup_telemetry(self.settings)
# Start health server
self.health_server.start()
try:
# Connect to NATS
await self.nats.connect()
# Run user setup
logger.info("Running service setup...")
await self.setup()
# Subscribe to subject
await self.nats.subscribe(
self.subject,
self._message_handler,
queue=self.queue_group,
)
self._running = True
logger.info(f"Handler ready, listening on {self.subject}")
# Wait for shutdown signal
await self._shutdown_event.wait()
except Exception:
logger.exception("Fatal error in handler")
raise
finally:
self._running = False
# Graceful shutdown
logger.info("Shutting down...")
try:
await self.teardown()
except Exception as e:
logger.warning(f"Error during teardown: {e}")
await self.nats.close()
self.health_server.stop()
logger.info("Shutdown complete")
def run(self) -> None:
"""
Run the handler.
This is the main entry point. It sets up signal handlers
and runs the async event loop.
"""
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
# Run the async loop
asyncio.run(self._run_with_signals())
async def _run_with_signals(self) -> None:
"""Run with signal handling."""
self._setup_signals()
await self._run()

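The Go handler runner that replaces this class lives elsewhere in the commit; as a sketch of how the same lifecycle (setup, subscribe, reply, readiness, graceful shutdown) maps onto the new packages (import paths, subject, and queue name are illustrative, not the real module path):

// Sketch only: the Python Handler lifecycle above, wired with the new Go
// packages. Import paths are placeholders for the actual module path.
package main

import (
	"context"
	"log/slog"
	"os"
	"os/signal"
	"syscall"

	"github.com/nats-io/nats.go"

	"example.com/handler-base/health"   // illustrative import path
	"example.com/handler-base/natsutil" // illustrative import path
)

func main() {
	nc := natsutil.New("nats://localhost:4222")
	// Readiness mirrors _check_ready: ready once NATS is connected.
	srv := health.New(8080, "/health", "/ready", nc.IsConnected)
	srv.Start()
	defer srv.Stop(context.Background())

	if err := nc.Connect(); err != nil {
		slog.Error("NATS connect failed", "error", err)
		os.Exit(1)
	}
	defer nc.Close()

	// Equivalent of handle_message: decode msgpack, reply when requested.
	_ = nc.Subscribe("my.subject", func(m *nats.Msg) {
		var req map[string]any
		if err := natsutil.DecodeMsgpack(m, &req); err != nil {
			slog.Error("decode failed", "error", err)
			return
		}
		if m.Reply != "" {
			_ = nc.Publish(m.Reply, map[string]any{"ok": true})
		}
	}, "my-queue")

	// Equivalent of _setup_signals: block until SIGTERM or SIGINT.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt)
	defer stop()
	<-ctx.Done()
	slog.Info("shutting down")
}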
View File

@@ -1,125 +0,0 @@
"""
HTTP health check server.
Provides /health and /ready endpoints for Kubernetes probes.
"""
import asyncio
import json
import logging
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Awaitable, Callable, Optional
from handler_base.config import Settings
logger = logging.getLogger(__name__)
class HealthHandler(BaseHTTPRequestHandler):
"""HTTP request handler for health checks."""
# Class-level state
ready_check: Optional[Callable[[], Awaitable[bool]]] = None
health_path: str = "/health"
ready_path: str = "/ready"
def log_message(self, format, *args):
"""Suppress default logging."""
pass
def do_GET(self):
"""Handle GET requests for health/ready endpoints."""
if self.path == self.health_path:
self._respond_ok({"status": "healthy"})
elif self.path == self.ready_path:
self._handle_ready()
else:
self._respond_not_found()
def _handle_ready(self):
"""Check readiness and respond."""
# Access via class to avoid method binding issues
ready_check = HealthHandler.ready_check
if ready_check is None:
self._respond_ok({"status": "ready"})
return
try:
# Run the async check in a new event loop
loop = asyncio.new_event_loop()
try:
is_ready = loop.run_until_complete(ready_check())
finally:
loop.close()
if is_ready:
self._respond_ok({"status": "ready"})
else:
self._respond_unavailable({"status": "not ready"})
except Exception as e:
logger.exception("Readiness check failed")
self._respond_unavailable({"status": "error", "message": str(e)})
def _respond_ok(self, data: dict):
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
def _respond_unavailable(self, data: dict):
self.send_response(503)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
def _respond_not_found(self):
self.send_response(404)
self.end_headers()
class HealthServer:
"""
Background HTTP server for health checks.
Usage:
server = HealthServer(settings)
server.start()
# ... run your service ...
server.stop()
"""
def __init__(
self,
settings: Optional[Settings] = None,
ready_check: Optional[Callable[[], Awaitable[bool]]] = None,
):
self.settings = settings or Settings()
self.ready_check = ready_check
self._server: Optional[HTTPServer] = None
self._thread: Optional[threading.Thread] = None
def start(self) -> None:
"""Start the health check server in a background thread."""
# Configure handler class
HealthHandler.ready_check = self.ready_check
HealthHandler.health_path = self.settings.health_path
HealthHandler.ready_path = self.settings.ready_path
# Create and start server
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
self._thread.start()
logger.info(
f"Health server started on port {self.settings.health_port} "
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
)
def stop(self) -> None:
"""Stop the health check server."""
if self._server:
self._server.shutdown()
self._server = None
self._thread = None
logger.info("Health server stopped")

View File

@@ -1,188 +0,0 @@
"""
NATS client wrapper with connection management and utilities.
"""
import logging
from typing import Any, Awaitable, Callable, Optional
import msgpack
import nats
from nats.aio.client import Client
from nats.aio.msg import Msg
from nats.js import JetStreamContext
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class NATSClient:
"""
NATS client with automatic connection management.
Supports:
- Core NATS pub/sub
- JetStream for persistence
- Queue groups for load balancing
- Msgpack serialization
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._nc: Optional[Client] = None
self._js: Optional[JetStreamContext] = None
self._subscriptions: list = []
@property
def nc(self) -> Client:
"""Get the NATS client, raising if not connected."""
if self._nc is None:
raise RuntimeError("NATS client not connected. Call connect() first.")
return self._nc
@property
def js(self) -> JetStreamContext:
"""Get JetStream context, raising if not connected."""
if self._js is None:
raise RuntimeError("JetStream not initialized. Call connect() first.")
return self._js
async def connect(self) -> None:
"""Connect to NATS server."""
connect_opts = {
"servers": self.settings.nats_url,
"reconnect_time_wait": 2,
"max_reconnect_attempts": -1, # Infinite
}
if self.settings.nats_user and self.settings.nats_password:
connect_opts["user"] = self.settings.nats_user
connect_opts["password"] = self.settings.nats_password
logger.info(f"Connecting to NATS at {self.settings.nats_url}")
self._nc = await nats.connect(**connect_opts)
self._js = self._nc.jetstream()
logger.info("Connected to NATS")
async def close(self) -> None:
"""Close NATS connection gracefully."""
if self._nc:
# Drain subscriptions first
for sub in self._subscriptions:
try:
await sub.drain()
except Exception as e:
logger.warning(f"Error draining subscription: {e}")
await self._nc.drain()
await self._nc.close()
self._nc = None
self._js = None
logger.info("NATS connection closed")
async def subscribe(
self,
subject: str,
handler: Callable[[Msg], Awaitable[None]],
queue: Optional[str] = None,
):
"""
Subscribe to a subject with a handler function.
Args:
subject: NATS subject to subscribe to
handler: Async function to handle messages
queue: Optional queue group for load balancing
"""
queue = queue or self.settings.nats_queue_group
if queue:
sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
logger.info(f"Subscribed to {subject} (queue: {queue})")
else:
sub = await self.nc.subscribe(subject, cb=handler)
logger.info(f"Subscribed to {subject}")
self._subscriptions.append(sub)
return sub
async def publish(
self,
subject: str,
data: Any,
use_msgpack: bool = True,
) -> None:
"""
Publish a message to a subject.
Args:
subject: NATS subject to publish to
data: Data to publish (will be serialized)
use_msgpack: Whether to use msgpack (True) or JSON (False)
"""
with create_span("nats.publish") as span:
if span:
span.set_attribute("messaging.destination", subject)
if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True)
else:
import json
payload = json.dumps(data).encode()
await self.nc.publish(subject, payload)
async def request(
self,
subject: str,
data: Any,
timeout: Optional[float] = None,
use_msgpack: bool = True,
) -> Any:
"""
Send a request and wait for response.
Args:
subject: NATS subject to send request to
data: Request data
timeout: Response timeout in seconds
use_msgpack: Whether to use msgpack serialization
Returns:
Decoded response data
"""
timeout = timeout or self.settings.nats_timeout
with create_span("nats.request") as span:
if span:
span.set_attribute("messaging.destination", subject)
if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True)
else:
import json
payload = json.dumps(data).encode()
response = await self.nc.request(subject, payload, timeout=timeout)
if use_msgpack:
return msgpack.unpackb(response.data, raw=False)
else:
import json
return json.loads(response.data.decode())
@staticmethod
def decode_msgpack(msg: Msg) -> Any:
"""Decode a msgpack message."""
return msgpack.unpackb(msg.data, raw=False)
@staticmethod
def decode_json(msg: Msg) -> Any:
"""Decode a JSON message."""
import json
return json.loads(msg.data.decode())

View File

@@ -1,70 +0,0 @@
"""
Ray integration utilities for handler-base clients.
When running inside a Ray cluster, clients can use Ray Serve handles
for faster internal communication (gRPC instead of HTTP).
"""
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
# Ray handle cache to avoid repeated lookups
_ray_handles: dict[str, Any] = {}
_ray_available: Optional[bool] = None
def is_ray_available() -> bool:
"""Check if we're running inside a Ray cluster."""
global _ray_available
if _ray_available is not None:
return _ray_available
try:
import ray
_ray_available = ray.is_initialized()
if _ray_available:
logger.info("Ray detected - will use Ray handles for internal calls")
return _ray_available
except ImportError:
_ray_available = False
return False
def get_ray_handle(deployment_name: str, app_name: str) -> Optional[Any]:
"""
Get a Ray Serve deployment handle for internal calls.
Args:
deployment_name: Name of the Ray Serve deployment
app_name: Name of the Ray Serve application
Returns:
DeploymentHandle if available, None otherwise
"""
if not is_ray_available():
return None
cache_key = f"{app_name}/{deployment_name}"
if cache_key in _ray_handles:
return _ray_handles[cache_key]
try:
from ray import serve
handle = serve.get_deployment_handle(deployment_name, app_name=app_name)
_ray_handles[cache_key] = handle
logger.debug(f"Got Ray handle for {cache_key}")
return handle
except Exception as e:
logger.debug(f"Could not get Ray handle for {cache_key}: {e}")
return None
def clear_ray_handles() -> None:
"""Clear cached Ray handles (useful for testing)."""
global _ray_handles, _ray_available
_ray_handles.clear()
_ray_available = None

View File

@@ -1,158 +0,0 @@
"""
OpenTelemetry setup for tracing and metrics.
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
"""
import logging
import os
from typing import Optional, Tuple
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from handler_base.config import Settings
logger = logging.getLogger(__name__)
# Global references
_tracer: Optional[trace.Tracer] = None
_meter: Optional[metrics.Meter] = None
_initialized = False
def setup_telemetry(
settings: Optional[Settings] = None,
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
"""
Initialize OpenTelemetry tracing and metrics.
Args:
settings: Configuration settings. If None, loads from environment.
Returns:
Tuple of (tracer, meter) or (None, None) if disabled.
"""
global _tracer, _meter, _initialized
if _initialized:
return _tracer, _meter
if settings is None:
settings = Settings()
if not settings.otel_enabled:
logger.info("OpenTelemetry disabled")
_initialized = True
return None, None
# Create resource with service information
resource = Resource.create(
{
SERVICE_NAME: settings.service_name,
SERVICE_VERSION: settings.service_version,
SERVICE_NAMESPACE: settings.service_namespace,
"deployment.environment": settings.deployment_env,
"host.name": os.environ.get("HOSTNAME", "unknown"),
}
)
# Determine endpoint and exporter type
if settings.hyperdx_enabled and settings.hyperdx_api_key:
# HyperDX uses HTTP with API key header
endpoint = settings.hyperdx_endpoint
headers = {"authorization": settings.hyperdx_api_key}
use_http = True
logger.info(f"Using HyperDX endpoint: {endpoint}")
else:
endpoint = settings.otel_endpoint
headers = None
use_http = settings.otel_use_http
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
# Setup tracing
if use_http:
trace_exporter = OTLPSpanExporterHTTP(
endpoint=f"{endpoint}/v1/traces",
headers=headers,
)
else:
trace_exporter = OTLPSpanExporter(
endpoint=endpoint,
)
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
trace.set_tracer_provider(tracer_provider)
# Setup metrics
if use_http:
metric_exporter = OTLPMetricExporterHTTP(
endpoint=f"{endpoint}/v1/metrics",
headers=headers,
)
else:
metric_exporter = OTLPMetricExporter(
endpoint=endpoint,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter,
export_interval_millis=60000,
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
# Instrument libraries
HTTPXClientInstrumentor().instrument()
LoggingInstrumentor().instrument(set_logging_format=True)
# Create tracer and meter for this service
_tracer = trace.get_tracer(settings.service_name, settings.service_version)
_meter = metrics.get_meter(settings.service_name, settings.service_version)
logger.info(f"OpenTelemetry initialized for {settings.service_name}")
_initialized = True
return _tracer, _meter
def get_tracer() -> Optional[trace.Tracer]:
"""Get the global tracer instance."""
return _tracer
def get_meter() -> Optional[metrics.Meter]:
"""Get the global meter instance."""
return _meter
def create_span(name: str, **kwargs):
"""
Create a new span.
Usage:
with create_span("my_operation") as span:
span.set_attribute("key", "value")
# do work
"""
if _tracer is None:
# Return a no-op context manager
from contextlib import nullcontext
return nullcontext()
return _tracer.start_as_current_span(name, **kwargs)

86
health/health.go Normal file
View File

@@ -0,0 +1,86 @@
// Package health provides an HTTP server for Kubernetes liveness and readiness probes.
package health
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"time"
)
// ReadyFunc is called to determine if the service is ready. Return true if ready.
type ReadyFunc func() bool
// Server serves /health and /ready endpoints.
type Server struct {
port int
healthPath string
readyPath string
readyCheck ReadyFunc
srv *http.Server
}
// New creates a health server on the given port.
func New(port int, healthPath, readyPath string, readyCheck ReadyFunc) *Server {
s := &Server{
port: port,
healthPath: healthPath,
readyPath: readyPath,
readyCheck: readyCheck,
}
mux := http.NewServeMux()
mux.HandleFunc(healthPath, s.handleHealth)
mux.HandleFunc(readyPath, s.handleReady)
s.srv = &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
return s
}
// Start begins serving in the background. Call Stop to shut down.
func (s *Server) Start() {
ln, err := net.Listen("tcp", s.srv.Addr)
if err != nil {
slog.Error("health server listen failed", "error", err)
return
}
slog.Info("health server started", "port", s.port, "health", s.healthPath, "ready", s.readyPath)
go func() {
if err := s.srv.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("health server error", "error", err)
}
}()
}
// Stop gracefully shuts down the server.
func (s *Server) Stop(ctx context.Context) {
if s.srv != nil {
_ = s.srv.Shutdown(ctx)
slog.Info("health server stopped")
}
}
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "healthy"})
}
func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request) {
if s.readyCheck != nil && !s.readyCheck() {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"status": "not ready"})
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ready"})
}
func writeJSON(w http.ResponseWriter, status int, data any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(data)
}

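A usage sketch (port, paths, and the import path are illustrative): readiness is typically driven by an atomic flag or a connection check.

// Sketch only: wiring the health server into a service main.
package main

import (
	"context"
	"sync/atomic"

	"example.com/handler-base/health" // illustrative import path
)

func main() {
	var ready atomic.Bool
	srv := health.New(8080, "/health", "/ready", ready.Load)
	srv.Start()
	defer srv.Stop(context.Background())

	// ... connect NATS, warm models, etc., then flip readiness:
	ready.Store(true)

	select {} // block; a real service would wait for a shutdown signal
}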
77
health/health_test.go Normal file
View File

@@ -0,0 +1,77 @@
package health
import (
"context"
"encoding/json"
"io"
"net/http"
"sync/atomic"
"testing"
"time"
)
func TestHealthEndpoint(t *testing.T) {
srv := New(18080, "/health", "/ready", nil)
srv.Start()
defer srv.Stop(context.Background())
time.Sleep(50 * time.Millisecond)
resp, err := http.Get("http://localhost:18080/health")
if err != nil {
t.Fatalf("health request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
var data map[string]string
_ = json.Unmarshal(body, &data)
if data["status"] != "healthy" {
t.Errorf("expected status 'healthy', got %q", data["status"])
}
}
func TestReadyEndpointDefault(t *testing.T) {
srv := New(18081, "/health", "/ready", nil)
srv.Start()
defer srv.Stop(context.Background())
time.Sleep(50 * time.Millisecond)
resp, err := http.Get("http://localhost:18081/ready")
if err != nil {
t.Fatalf("ready request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
}
func TestReadyEndpointNotReady(t *testing.T) {
var ready atomic.Bool // atomic: the handler goroutine reads this while the test writes it
srv := New(18082, "/health", "/ready", ready.Load)
srv.Start()
defer srv.Stop(context.Background())
time.Sleep(50 * time.Millisecond)
resp, err := http.Get("http://localhost:18082/ready")
if err != nil {
t.Fatalf("ready request failed: %v", err)
}
resp.Body.Close()
if resp.StatusCode != 503 {
t.Errorf("expected 503 when not ready, got %d", resp.StatusCode)
}
ready.Store(true)
resp2, err := http.Get("http://localhost:18082/ready")
if err != nil {
t.Fatalf("ready request failed: %v", err)
}
resp2.Body.Close()
if resp2.StatusCode != 200 {
t.Errorf("expected 200 when ready, got %d", resp2.StatusCode)
}
}

134
natsutil/natsutil.go Normal file
View File

@@ -0,0 +1,134 @@
// Package natsutil provides a NATS/JetStream client with msgpack serialization.
package natsutil
import (
"fmt"
"log/slog"
"time"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
)
// Client wraps a NATS connection with msgpack helpers.
type Client struct {
nc *nats.Conn
js nats.JetStreamContext
subs []*nats.Subscription
url string
opts []nats.Option
}
// New creates a NATS client configured to connect to the given URL.
// Optional NATS options (e.g. credentials) can be appended.
func New(url string, opts ...nats.Option) *Client {
defaults := []nats.Option{
nats.ReconnectWait(2 * time.Second),
nats.MaxReconnects(-1),
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
slog.Warn("NATS disconnected", "error", err)
}),
nats.ReconnectHandler(func(_ *nats.Conn) {
slog.Info("NATS reconnected")
}),
}
return &Client{
url: url,
opts: append(defaults, opts...),
}
}
// Connect establishes the NATS connection and JetStream context.
func (c *Client) Connect() error {
nc, err := nats.Connect(c.url, c.opts...)
if err != nil {
return fmt.Errorf("nats connect: %w", err)
}
js, err := nc.JetStream()
if err != nil {
nc.Close()
return fmt.Errorf("jetstream: %w", err)
}
c.nc = nc
c.js = js
slog.Info("connected to NATS", "url", c.url)
return nil
}
// Close drains subscriptions and closes the connection.
func (c *Client) Close() {
if c.nc == nil {
return
}
for _, sub := range c.subs {
_ = sub.Drain()
}
c.nc.Close()
slog.Info("NATS connection closed")
}
// Conn returns the underlying *nats.Conn.
func (c *Client) Conn() *nats.Conn { return c.nc }
// JS returns the JetStream context.
func (c *Client) JS() nats.JetStreamContext { return c.js }
// IsConnected returns true if the NATS connection is active.
func (c *Client) IsConnected() bool {
return c.nc != nil && c.nc.IsConnected()
}
// Subscribe subscribes to a subject with an optional queue group.
// The handler receives the raw *nats.Msg.
func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error {
var sub *nats.Subscription
var err error
if queue != "" {
sub, err = c.nc.QueueSubscribe(subject, queue, handler)
slog.Info("subscribed", "subject", subject, "queue", queue)
} else {
sub, err = c.nc.Subscribe(subject, handler)
slog.Info("subscribed", "subject", subject)
}
if err != nil {
return fmt.Errorf("subscribe %s: %w", subject, err)
}
c.subs = append(c.subs, sub)
return nil
}
// Publish encodes data as msgpack and publishes to the subject.
func (c *Client) Publish(subject string, data any) error {
payload, err := msgpack.Marshal(data)
if err != nil {
return fmt.Errorf("msgpack marshal: %w", err)
}
return c.nc.Publish(subject, payload)
}
// Request sends a msgpack-encoded request and decodes the response into result.
func (c *Client) Request(subject string, data any, result any, timeout time.Duration) error {
payload, err := msgpack.Marshal(data)
if err != nil {
return fmt.Errorf("msgpack marshal: %w", err)
}
msg, err := c.nc.Request(subject, payload, timeout)
if err != nil {
return fmt.Errorf("nats request: %w", err)
}
return msgpack.Unmarshal(msg.Data, result)
}
// DecodeMsgpack decodes msgpack-encoded NATS message data into dest.
func DecodeMsgpack(msg *nats.Msg, dest any) error {
return msgpack.Unmarshal(msg.Data, dest)
}
// DecodeMsgpackMap decodes msgpack data into a generic map.
func DecodeMsgpackMap(data []byte) (map[string]any, error) {
var m map[string]any
if err := msgpack.Unmarshal(data, &m); err != nil {
return nil, err
}
return m, nil
}

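A request/reply sketch (URL, subject, queue name, and the import path are illustrative):

// Sketch only: msgpack request/reply through the natsutil client.
package main

import (
	"log"
	"time"

	"github.com/nats-io/nats.go"

	"example.com/handler-base/natsutil" // illustrative import path
)

func main() {
	nc := natsutil.New("nats://localhost:4222")
	if err := nc.Connect(); err != nil {
		log.Fatal(err)
	}
	defer nc.Close()

	// Queue-subscribed echo handler: decode msgpack and reply in kind.
	err := nc.Subscribe("demo.echo", func(m *nats.Msg) {
		data, err := natsutil.DecodeMsgpackMap(m.Data)
		if err != nil || m.Reply == "" {
			return
		}
		_ = nc.Publish(m.Reply, data)
	}, "demo-workers")
	if err != nil {
		log.Fatal(err)
	}

	// Round trip: msgpack out, msgpack back into a typed destination.
	var resp map[string]any
	if err := nc.Request("demo.echo", map[string]any{"hello": "world"}, &resp, 2*time.Second); err != nil {
		log.Fatal(err)
	}
	log.Printf("reply: %v", resp)
}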
View File

@@ -1,70 +0,0 @@
[project]
name = "handler-base"
version = "1.0.0"
description = "Shared base library for AI/ML handler services"
readme = "README.md"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Davies Tech Labs" }]
dependencies = [
# Async & messaging
"nats-py>=2.7.0",
"httpx>=0.27.0",
"msgpack>=1.0.0",
# Data stores
"pymilvus>=2.4.0",
"redis>=5.0.0",
# Observability
"opentelemetry-api>=1.20.0",
"opentelemetry-sdk>=1.20.0",
"opentelemetry-exporter-otlp-proto-grpc>=1.20.0",
"opentelemetry-exporter-otlp-proto-http>=1.20.0",
"opentelemetry-instrumentation-httpx>=0.44b0",
"opentelemetry-instrumentation-logging>=0.44b0",
# MLflow
"mlflow>=2.10.0",
"psycopg2-binary>=2.9.0",
# Utilities
"numpy>=1.26.0",
"pydantic>=2.5.0",
"pydantic-settings>=2.1.0",
]
[project.optional-dependencies]
audio = [
"soundfile>=0.12.0",
"librosa>=0.10.0",
"webrtcvad>=2.0.10",
]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"pytest-cov>=4.0.0",
"ruff>=0.4.0",
"pre-commit>=3.7.0",
]
ray = [
"ray[serve]>=2.9.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["handler_base"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "I", "W"]
[tool.pytest.ini_options]
asyncio_mode = "auto"

View File

@@ -1,7 +0,0 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
addopts = -v --tb=short

122
telemetry/telemetry.go Normal file
View File

@@ -0,0 +1,122 @@
// Package telemetry provides OpenTelemetry tracing and metrics setup.
package telemetry
import (
"context"
"log/slog"
"os"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/propagation"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
)
// Config holds the telemetry configuration parameters.
type Config struct {
ServiceName string
ServiceVersion string
ServiceNamespace string
DeploymentEnv string
Enabled bool
Endpoint string
}
// Provider holds the initialized tracer and meter providers.
type Provider struct {
TracerProvider *sdktrace.TracerProvider
MeterProvider *sdkmetric.MeterProvider
Tracer trace.Tracer
Meter metric.Meter
}
// Setup initializes OpenTelemetry tracing and metrics.
// Returns a Provider and a shutdown function.
func Setup(ctx context.Context, cfg Config) (*Provider, func(context.Context), error) {
if !cfg.Enabled {
slog.Info("OpenTelemetry disabled")
return &Provider{
Tracer: otel.Tracer(cfg.ServiceName),
Meter: otel.Meter(cfg.ServiceName),
}, func(context.Context) {}, nil
}
hostname, _ := os.Hostname()
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceNameKey.String(cfg.ServiceName),
semconv.ServiceVersionKey.String(cfg.ServiceVersion),
semconv.ServiceNamespaceKey.String(cfg.ServiceNamespace),
attribute.String("deployment.environment", cfg.DeploymentEnv),
attribute.String("host.name", hostname),
),
)
if err != nil {
return nil, nil, err
}
// Trace exporter (gRPC)
traceExp, err := otlptracegrpc.New(ctx,
otlptracegrpc.WithEndpoint(stripScheme(cfg.Endpoint)),
otlptracegrpc.WithInsecure(),
)
if err != nil {
return nil, nil, err
}
tp := sdktrace.NewTracerProvider(
sdktrace.WithBatcher(traceExp),
sdktrace.WithResource(res),
)
otel.SetTracerProvider(tp)
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
))
// Metric exporter (gRPC)
metricExp, err := otlpmetricgrpc.New(ctx,
otlpmetricgrpc.WithEndpoint(stripScheme(cfg.Endpoint)),
otlpmetricgrpc.WithInsecure(),
)
if err != nil {
return nil, nil, err
}
mp := sdkmetric.NewMeterProvider(
sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExp)),
sdkmetric.WithResource(res),
)
otel.SetMeterProvider(mp)
slog.Info("OpenTelemetry initialized", "service", cfg.ServiceName, "endpoint", cfg.Endpoint)
shutdown := func(ctx context.Context) {
_ = tp.Shutdown(ctx)
_ = mp.Shutdown(ctx)
}
return &Provider{
TracerProvider: tp,
MeterProvider: mp,
Tracer: tp.Tracer(cfg.ServiceName, trace.WithInstrumentationVersion(cfg.ServiceVersion)),
Meter: mp.Meter(cfg.ServiceName),
}, shutdown, nil
}
// stripScheme removes http:// or https:// from an endpoint for gRPC dialers.
func stripScheme(endpoint string) string {
for _, prefix := range []string{"https://", "http://"} {
if len(endpoint) > len(prefix) && endpoint[:len(prefix)] == prefix {
return endpoint[len(prefix):]
}
}
return endpoint
}

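A setup sketch (service name, endpoint, and the import path are illustrative):

// Sketch only: initializing telemetry and recording a span.
package main

import (
	"context"
	"log"

	"go.opentelemetry.io/otel/attribute"

	"example.com/handler-base/telemetry" // illustrative import path
)

func main() {
	ctx := context.Background()
	prov, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
		ServiceName:    "demo-handler",
		ServiceVersion: "1.0.0",
		Enabled:        true,
		Endpoint:       "http://localhost:4317", // scheme is stripped for gRPC
	})
	if err != nil {
		log.Fatal(err)
	}
	defer shutdown(ctx)

	// When Enabled is false, Setup returns no-op tracer/meter, so call
	// sites stay the same either way.
	ctx, span := prov.Tracer.Start(ctx, "demo.operation")
	span.SetAttributes(attribute.String("request.id", "abc123"))
	span.End()
}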
View File

@@ -1,76 +0,0 @@
"""
Pytest configuration and fixtures.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
# Set test environment variables before importing handler_base
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
os.environ.setdefault("MILVUS_HOST", "localhost")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def settings():
"""Create test settings."""
from handler_base.config import Settings
return Settings(
service_name="test-service",
service_version="1.0.0-test",
otel_enabled=False,
mlflow_enabled=False,
nats_url="nats://localhost:4222",
redis_url="redis://localhost:6379",
milvus_host="localhost",
)
@pytest.fixture
def mock_httpx_client():
"""Create a mock httpx AsyncClient."""
client = AsyncMock()
client.post = AsyncMock()
client.get = AsyncMock()
client.aclose = AsyncMock()
return client
@pytest.fixture
def mock_nats_message():
"""Create a mock NATS message."""
msg = MagicMock()
msg.subject = "test.subject"
msg.reply = "test.reply"
msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
return msg
@pytest.fixture
def sample_embedding():
"""Sample embedding vector."""
return [0.1] * 1024
@pytest.fixture
def sample_documents():
"""Sample documents for testing."""
return [
{"text": "Python is a programming language.", "source": "doc1"},
{"text": "Machine learning is a subset of AI.", "source": "doc2"},
{"text": "Deep learning uses neural networks.", "source": "doc3"},
]

View File

@@ -1 +0,0 @@
# Unit tests package

View File

@@ -1,150 +0,0 @@
"""
Unit tests for service clients.
"""
from unittest.mock import MagicMock
import pytest
class TestEmbeddingsClient:
"""Tests for EmbeddingsClient."""
@pytest.fixture
def embeddings_client(self, mock_httpx_client):
"""Create an EmbeddingsClient with mocked HTTP."""
from handler_base.clients.embeddings import EmbeddingsClient
client = EmbeddingsClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding a single text."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed_single("Hello world")
assert result == sample_embedding
mock_httpx_client.post.assert_called_once()
@pytest.mark.asyncio
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding multiple texts."""
texts = ["Hello", "World"]
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [
{"embedding": sample_embedding, "index": 0},
{"embedding": sample_embedding, "index": 1},
]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed(texts)
assert len(result) == 2
assert all(len(e) == len(sample_embedding) for e in result)
@pytest.mark.asyncio
async def test_health_check(self, embeddings_client, mock_httpx_client):
"""Test health check."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_httpx_client.get.return_value = mock_response
result = await embeddings_client.health()
assert result is True
class TestRerankerClient:
"""Tests for RerankerClient."""
@pytest.fixture
def reranker_client(self, mock_httpx_client):
"""Create a RerankerClient with mocked HTTP."""
from handler_base.clients.reranker import RerankerClient
client = RerankerClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
"""Test reranking documents."""
texts = [d["text"] for d in sample_documents]
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [
{"index": 1, "relevance_score": 0.95},
{"index": 0, "relevance_score": 0.80},
{"index": 2, "relevance_score": 0.65},
]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await reranker_client.rerank("What is ML?", texts)
assert len(result) == 3
assert result[0]["score"] == 0.95
assert result[0]["index"] == 1
class TestLLMClient:
"""Tests for LLMClient."""
@pytest.fixture
def llm_client(self, mock_httpx_client):
"""Create an LLMClient with mocked HTTP."""
from handler_base.clients.llm import LLMClient
client = LLMClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_generate(self, llm_client, mock_httpx_client):
"""Test generating a response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate("Hello")
assert result == "Hello! I'm an AI assistant."
@pytest.mark.asyncio
async def test_generate_with_context(self, llm_client, mock_httpx_client):
"""Test generating with RAG context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Based on the context..."}}],
"usage": {},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate(
"What is Python?", context="Python is a programming language."
)
assert "Based on the context" in result
# Verify context was included in the request
call_args = mock_httpx_client.post.call_args
messages = call_args.kwargs["json"]["messages"]
assert any("Context:" in m["content"] for m in messages if m["role"] == "user")

View File

@@ -1,46 +0,0 @@
"""
Unit tests for handler_base.config module.
"""
class TestSettings:
"""Tests for Settings configuration."""
def test_default_settings(self, settings):
"""Test that default settings are loaded correctly."""
assert settings.service_name == "test-service"
assert settings.service_version == "1.0.0-test"
assert settings.otel_enabled is False
def test_settings_from_env(self, monkeypatch):
"""Test that settings can be loaded from environment variables."""
monkeypatch.setenv("SERVICE_NAME", "env-service")
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
# Import inside the test so Settings() is constructed after monkeypatch sets the env
from handler_base.config import Settings
s = Settings()
assert s.service_name == "env-service"
assert s.service_version == "2.0.0"
assert s.nats_url == "nats://custom:4222"
def test_embeddings_settings(self):
"""Test EmbeddingsSettings extends base correctly."""
from handler_base.config import EmbeddingsSettings
s = EmbeddingsSettings()
assert hasattr(s, "embeddings_model")
assert hasattr(s, "embeddings_batch_size")
assert s.embeddings_model == "bge"
def test_llm_settings(self):
"""Test LLMSettings has expected defaults."""
from handler_base.config import LLMSettings
s = LLMSettings()
assert s.llm_max_tokens == 2048
assert s.llm_temperature == 0.7
assert 0 <= s.llm_top_p <= 1

View File

@@ -1,122 +0,0 @@
"""
Unit tests for handler_base.health module.
"""
import json
import time
from http.client import HTTPConnection
import pytest
class TestHealthServer:
"""Tests for HealthServer."""
@pytest.fixture
def health_server(self, settings):
"""Create a HealthServer instance."""
from handler_base.health import HealthServer
# Use a random high port to avoid conflicts
settings.health_port = 18080
return HealthServer(settings)
def test_start_stop(self, health_server):
"""Test starting and stopping the health server."""
health_server.start()
time.sleep(0.1) # Give server time to start
# Verify server is running
assert health_server._server is not None
assert health_server._thread is not None
assert health_server._thread.is_alive()
health_server.stop()
time.sleep(0.1)
assert health_server._server is None
def test_health_endpoint(self, health_server):
"""Test the /health endpoint."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/health")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "healthy"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_default(self, health_server):
"""Test the /ready endpoint with no custom check."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "ready"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_with_check(self, settings):
"""Test /ready endpoint with custom readiness check."""
from handler_base.health import HealthServer
ready_flag = [False] # Use list to allow mutation in closure
async def check_ready():
return ready_flag[0]
settings.health_port = 18081
server = HealthServer(settings, ready_check=check_ready)
server.start()
time.sleep(0.2)
try:
conn = HTTPConnection("localhost", 18081, timeout=5)
# Should be not ready initially
conn.request("GET", "/ready")
response = conn.getresponse()
response.read() # Consume response body
assert response.status == 503
# Mark as ready
ready_flag[0] = True
# Need new connection after consuming response
conn.close()
conn = HTTPConnection("localhost", 18081, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
finally:
conn.close()
server.stop()
def test_404_for_unknown_path(self, health_server):
"""Test that unknown paths return 404."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/unknown")
response = conn.getresponse()
assert response.status == 404
finally:
conn.close()
health_server.stop()

View File

@@ -1,95 +0,0 @@
"""
Unit tests for handler_base.nats_client module.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import msgpack
import pytest
class TestNATSClient:
"""Tests for NATSClient."""
@pytest.fixture
def nats_client(self, settings):
"""Create a NATSClient instance."""
from handler_base.nats_client import NATSClient
return NATSClient(settings)
def test_init(self, nats_client, settings):
"""Test NATSClient initialization."""
assert nats_client.settings == settings
assert nats_client._nc is None
assert nats_client._js is None
def test_decode_msgpack(self, nats_client):
"""Test msgpack decoding."""
data = {"query": "hello", "request_id": "123"}
encoded = msgpack.packb(data, use_bin_type=True)
msg = MagicMock()
msg.data = encoded
result = nats_client.decode_msgpack(msg)
assert result == data
def test_decode_json(self, nats_client):
"""Test JSON decoding."""
import json
data = {"query": "hello"}
msg = MagicMock()
msg.data = json.dumps(data).encode()
result = nats_client.decode_json(msg)
assert result == data
@pytest.mark.asyncio
async def test_connect(self, nats_client):
"""Test NATS connection."""
with patch("handler_base.nats_client.nats") as mock_nats:
mock_nc = AsyncMock()
mock_js = MagicMock()
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
mock_nats.connect = AsyncMock(return_value=mock_nc)
await nats_client.connect()
assert nats_client._nc == mock_nc
assert nats_client._js == mock_js
mock_nats.connect.assert_called_once()
@pytest.mark.asyncio
async def test_publish(self, nats_client):
"""Test publishing a message."""
mock_nc = AsyncMock()
nats_client._nc = mock_nc
data = {"key": "value"}
await nats_client.publish("test.subject", data)
mock_nc.publish.assert_called_once()
call_args = mock_nc.publish.call_args
assert call_args.args[0] == "test.subject"
# Verify msgpack encoding
decoded = msgpack.unpackb(call_args.args[1], raw=False)
assert decoded == data
@pytest.mark.asyncio
async def test_subscribe(self, nats_client):
"""Test subscribing to a subject."""
mock_nc = AsyncMock()
mock_sub = MagicMock()
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
nats_client._nc = mock_nc
handler = AsyncMock()
await nats_client.subscribe("test.subject", handler, queue="test-queue")
mock_nc.subscribe.assert_called_once()
call_kwargs = mock_nc.subscribe.call_args.kwargs
assert call_kwargs["queue"] == "test-queue"

4256
uv.lock generated

File diff suppressed because it is too large