11 Commits

Author SHA1 Message Date
fba7b62573 fix: rename GITEA_TOKEN to DISPATCH_TOKEN to avoid built-in prefix
All checks were successful
CI / Lint (push) Successful in 2m52s
CI / Test (push) Successful in 2m46s
CI / Release (push) Successful in 52s
CI / Notify Downstream (chat-handler) (push) Successful in 2s
CI / Notify Downstream (pipeline-bridge) (push) Successful in 2s
CI / Notify Downstream (stt-module) (push) Successful in 2s
CI / Notify Downstream (tts-module) (push) Successful in 2s
CI / Notify Downstream (voice-assistant) (push) Successful in 2s
CI / Notify (push) Successful in 2s
2026-02-20 09:10:13 -05:00
6fd0b9a265 feat: add downstream dependency cascade on release
Some checks failed
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release (push) Has been cancelled
CI / Notify Downstream (chat-handler) (push) Has been cancelled
CI / Notify Downstream (pipeline-bridge) (push) Has been cancelled
CI / Notify Downstream (voice-assistant) (push) Has been cancelled
CI / Notify (push) Has been cancelled
CI / Notify Downstream (stt-module) (push) Has been cancelled
CI / Notify Downstream (tts-module) (push) Has been cancelled
After a successful release tag, notify 5 downstream repos via
Gitea repository_dispatch so they auto-update handler-base.
2026-02-20 09:05:46 -05:00
8b6232141a ci: verify Go CI pipeline
Some checks failed
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release (push) Has been cancelled
CI / Notify (push) Has been cancelled
2026-02-20 09:01:44 -05:00
9876cb9388 fix: replace Python CI workflow with Go CI
All checks were successful
CI / Release (push) Successful in 58s
CI / Notify (push) Successful in 1s
CI / Lint (push) Successful in 3m24s
CI / Test (push) Successful in 3m14s
- Replace uv/ruff/pytest with Go setup, golangci-lint, go test
- Library-only: lint + test + release (tag) + notify, no Docker build
2026-02-20 08:49:29 -05:00
39673d31b8 fix: resolve golangci-lint errcheck warnings
Some checks failed
CI / Lint (push) Failing after 59s
CI / Test (push) Failing after 1m39s
CI / Release (push) Has been cancelled
CI / Notify (push) Has been cancelled
- Add error checks for unchecked return values (errcheck)
- Remove unused struct fields (unused)
- Fix gofmt formatting issues
2026-02-20 08:45:19 -05:00
81581337cd Merge pull request 'feature/go-handler-refactor' (#1) from feature/go-handler-refactor into main
Some checks failed
CI / Lint (push) Failing after 1m18s
CI / Test (push) Failing after 1m19s
CI / Release (push) Has been skipped
CI / Notify (push) Successful in 1s
Reviewed-on: #1
2026-02-20 12:33:18 +00:00
ea9b3a8f2b feat: add TypedMessageHandler + generic Decode[T] helper
Some checks failed
CI / Lint (pull_request) Failing after 1m25s
CI / Test (pull_request) Failing after 1m25s
CI / Release (pull_request) Has been skipped
CI / Notify (pull_request) Successful in 1s
- handler: add OnTypedMessage() for typed NATS message callbacks
  Avoids double-decode (msgpack→map→typed) by skipping map step
- handler: refactor wrapHandler into wrapTypedHandler + wrapMapHandler
- natsutil: add generic Decode[T](data) for direct msgpack→struct decode
- tests: add typed handler tests + benchmark (11 tests pass)
2026-02-20 07:10:33 -05:00
35912d5844 feat: add e2e tests, perf benchmarks, and infrastructure improvements
- messages/bench_test.go: serialization benchmarks (msgpack map vs struct vs protobuf)
- clients/clients_test.go: HTTP client tests with pooling verification (20 tests)
- natsutil/natsutil_test.go: encode/decode roundtrip + binary data tests
- handler/handler_test.go: handler dispatch tests + benchmark
- config/config.go: live reload via fsnotify + RWMutex getter methods
- clients/clients.go: SharedTransport + sync.Pool buffer pooling
- messages/messages.go: typed structs with msgpack+json tags
- messages/proto/: protobuf schema + generated code

Benchmark baseline (ChatRequest roundtrip):
  MsgpackMap:    2949 ns/op, 36 allocs
  MsgpackStruct: 2030 ns/op, 13 allocs (31% faster, 64% fewer allocs)
  Protobuf:       793 ns/op,  8 allocs (73% faster, 78% fewer allocs)
2026-02-20 06:44:37 -05:00
d321c9852b refactor: rewrite handler-base as Go module
Replace Python handler-base library with Go module providing:
- config: environment-based configuration
- health: HTTP health/readiness server for k8s probes
- natsutil: NATS/JetStream client with msgpack serialization
- telemetry: OpenTelemetry tracing and metrics setup
- clients: HTTP clients for LLM, embeddings, reranker, STT, TTS
- handler: base Handler runner wiring NATS + health + telemetry

Implements ADR-0061 Phase 1.
2026-02-19 17:16:17 -05:00
5eb2c43a5d fix: replace astral-sh/setup-uv action with shell install
All checks were successful
CI / Lint (push) Successful in 1m34s
CI / Test (push) Successful in 1m59s
CI / Release (push) Successful in 4s
CI / Notify (push) Successful in 1s
The JS-based GitHub Action doesn't work on Gitea's act runner.
Use curl installer + GITHUB_PATH instead.
2026-02-13 19:40:55 -05:00
cb91015964 chore: add Renovate config for automated dependency updates
All checks were successful
CI / Lint (push) Successful in 1m35s
CI / Release (push) Successful in 5s
CI / Notify (push) Successful in 1s
CI / Test (push) Successful in 1m39s
Ref: ADR-0057
2026-02-13 15:33:54 -05:00
47 changed files with 5430 additions and 6997 deletions

View File

@@ -17,23 +17,22 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up uv - name: Set up Go
uses: astral-sh/setup-uv@v7 uses: actions/setup-go@v5
with: with:
version: "latest" go-version-file: go.mod
activate-environment: false cache: true
- name: Set up Python - name: Run go vet
run: uv python install 3.13 run: go vet ./...
- name: Install dependencies - name: Install golangci-lint
run: uv sync --frozen --extra dev run: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b "$(go env GOPATH)/bin"
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
- name: Run ruff check - name: Run golangci-lint
run: uv run ruff check . run: golangci-lint run ./...
- name: Run ruff format check
run: uv run ruff format --check .
test: test:
name: Test name: Test
@@ -42,26 +41,28 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up uv - name: Set up Go
uses: astral-sh/setup-uv@v7 uses: actions/setup-go@v5
with: with:
version: "latest" go-version-file: go.mod
activate-environment: false cache: true
- name: Set up Python - name: Verify dependencies
run: uv python install 3.13 run: go mod verify
- name: Install dependencies - name: Build
run: uv sync --frozen --extra dev run: go build -v ./...
- name: Run tests with coverage - name: Run tests
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
release: release:
name: Release name: Release
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [lint, test] needs: [lint, test]
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push' if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
outputs:
version: ${{ steps.version.outputs.version }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -75,7 +76,7 @@ jobs:
LATEST=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0") LATEST=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
VERSION=${LATEST#v} VERSION=${LATEST#v}
IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION" IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION"
# Check commit message for keywords # Check commit message for keywords
MSG="${{ gitea.event.head_commit.message }}" MSG="${{ gitea.event.head_commit.message }}"
if echo "$MSG" | grep -qiE "^major:|BREAKING CHANGE"; then if echo "$MSG" | grep -qiE "^major:|BREAKING CHANGE"; then
@@ -88,7 +89,7 @@ jobs:
PATCH=$((PATCH + 1)) PATCH=$((PATCH + 1))
BUMP="patch" BUMP="patch"
fi fi
NEW_VERSION="v${MAJOR}.${MINOR}.${PATCH}" NEW_VERSION="v${MAJOR}.${MINOR}.${PATCH}"
echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT
echo "bump=$BUMP" >> $GITHUB_OUTPUT echo "bump=$BUMP" >> $GITHUB_OUTPUT
@@ -101,10 +102,32 @@ jobs:
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}" git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
git push origin ${{ steps.version.outputs.version }} git push origin ${{ steps.version.outputs.version }}
notify-downstream:
name: Notify Downstream
runs-on: ubuntu-latest
needs: [release]
if: needs.release.result == 'success'
strategy:
matrix:
repo:
- chat-handler
- pipeline-bridge
- tts-module
- voice-assistant
- stt-module
steps:
- name: Trigger dependency update
run: |
curl -s -X POST \
-H "Authorization: token ${{ secrets.DISPATCH_TOKEN }}" \
-H "Content-Type: application/json" \
-d '{"event_type":"handler-base-release","client_payload":{"version":"${{ needs.release.outputs.version }}"}}' \
"${{ gitea.server_url }}/api/v1/repos/daviestechlabs/${{ matrix.repo }}/dispatches"
notify: notify:
name: Notify name: Notify
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [lint, test, release] needs: [lint, test, release, notify-downstream]
if: always() if: always()
steps: steps:
- name: Notify on success - name: Notify on success

35
.gitignore vendored
View File

@@ -1,23 +1,20 @@
# Python # Go
__pycache__/ *.exe
*.py[cod] *.dll
*$py.class
*.so *.so
.Python *.dylib
build/ *.test
develop-eggs/ *.out
dist/ vendor/
downloads/
eggs/ # IDE
.eggs/ .idea/
lib/ .vscode/
lib64/ *.swp
parts/
sdist/ # OS
var/ .DS_Store
wheels/ Thumbs.db
*.egg-info/
.installed.cfg
*.egg *.egg
# Virtual environments # Virtual environments

View File

@@ -1,32 +0,0 @@
# Pre-commit hooks for handler-base
# Install: pip install pre-commit && pre-commit install
# Run: pre-commit run --all-files
repos:
# Ruff - fast Python linter and formatter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
# Standard pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
args: [--maxkb=500]
- id: check-merge-conflict
- id: detect-private-key
# Type checking (optional - uncomment when ready)
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.10.0
# hooks:
# - id: mypy
# additional_dependencies: [types-all]
# args: [--ignore-missing-imports]

View File

@@ -1,58 +0,0 @@
# Handler Base Image
#
# Provides a pre-built base with all common dependencies for handler services.
# Services extend this image and add their specific code.
#
# Build:
# docker build -t ghcr.io/billy-davies-2/handler-base:latest .
#
# Usage in child Dockerfile:
# FROM ghcr.io/billy-davies-2/handler-base:latest
# COPY my_handler.py .
# CMD ["python", "my_handler.py"]
FROM python:3.13-slim AS base
WORKDIR /app
# Install uv for fast, reliable package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy and install handler-base package
COPY pyproject.toml README.md ./
COPY handler_base/ ./handler_base/
# Install the package with all dependencies
RUN uv pip install --system --no-cache .
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Default health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://localhost:8080/health || exit 1
# Default command (override in child images)
CMD ["python", "-c", "print('handler-base ready')"]
# Audio variant with soundfile, librosa, webrtcvad
FROM base AS audio
RUN apt-get update && apt-get install -y --no-install-recommends \
ffmpeg \
libsndfile1 \
gcc \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
RUN uv pip install --system --no-cache \
soundfile>=0.12.0 \
librosa>=0.10.0 \
webrtcvad>=2.0.10

129
README.md
View File

@@ -1,109 +1,44 @@
# Handler Base # handler-base
Shared base library for building NATS-based AI/ML handler services. Go module providing shared infrastructure for NATS-based handler services.
## Installation ## Packages
```bash | Package | Purpose |
pip install handler-base |---------|---------|
``` | `config` | Environment-based configuration via struct fields |
| `health` | HTTP health/readiness server for Kubernetes probes |
| `natsutil` | NATS/JetStream client with msgpack serialization |
| `telemetry` | OpenTelemetry tracing and metrics setup |
| `clients` | HTTP clients for LLM, embeddings, reranker, STT, TTS |
| `handler` | Base Handler runner wiring NATS + health + telemetry |
Or from Gitea: ## Usage
```bash
pip install git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git
```
## Quick Start ```go
package main
```python import (
from handler_base import Handler, Settings "context"
from nats.aio.msg import Msg "git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
class MyHandler(Handler): "github.com/nats-io/nats.go"
async def setup(self):
# Initialize your clients
pass
async def handle_message(self, msg: Msg, data: dict):
# Process the message
result = {"processed": True}
return result
if __name__ == "__main__":
MyHandler(subject="my.subject").run()
```
## Features
- **Handler base class** - NATS subscription, graceful shutdown, signal handling
- **NATSClient** - Connection management, JetStream, msgpack serialization
- **Settings** - Pydantic-based configuration from environment
- **HealthServer** - Kubernetes liveness/readiness probes
- **Telemetry** - OpenTelemetry tracing and metrics
- **Service clients** - HTTP wrappers for AI services
## Service Clients
```python
from handler_base.clients import (
STTClient, # Whisper speech-to-text
TTSClient, # XTTS text-to-speech
LLMClient, # vLLM chat completions
EmbeddingsClient, # BGE embeddings
RerankerClient, # BGE reranker
MilvusClient, # Vector database
) )
func main() {
cfg := config.Load()
cfg.ServiceName = "my-service"
h := handler.New("my.subject", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
return map[string]any{"ok": true}, nil
})
h.Run()
}
``` ```
## Configuration ## Testing
All settings via environment variables:
| Variable | Default | Description |
|----------|---------|-------------|
| `NATS_URL` | `nats://localhost:4222` | NATS server URL |
| `NATS_USER` | - | NATS username |
| `NATS_PASSWORD` | - | NATS password |
| `NATS_QUEUE_GROUP` | - | Queue group for load balancing |
| `HEALTH_PORT` | `8080` | Health check server port |
| `OTEL_ENABLED` | `true` | Enable OpenTelemetry |
| `OTEL_EXPORTER_OTLP_ENDPOINT` | `http://localhost:4317` | OTLP endpoint |
| `OTEL_SERVICE_NAME` | `handler` | Service name for traces |
## Docker
```dockerfile
FROM ghcr.io/daviestechlabs/handler-base:latest
COPY my_handler.py /app/
CMD ["python", "/app/my_handler.py"]
```
Or build with audio support:
```bash ```bash
docker build --build-arg INSTALL_AUDIO=true -t my-handler . go test ./...
``` ```
## Module Structure
```
handler_base/
├── __init__.py # Public API exports
├── handler.py # Base Handler class
├── nats_client.py # NATS connection wrapper
├── config.py # Pydantic Settings
├── health.py # Health check server
├── telemetry.py # OpenTelemetry setup
└── clients/
├── embeddings.py
├── llm.py
├── milvus.py
├── reranker.py
├── stt.py
└── tts.py
```
## Related
- [voice-assistant](https://git.daviestechlabs.io/daviestechlabs/voice-assistant) - Voice pipeline using handler-base
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs

436
clients/clients.go Normal file
View File

@@ -0,0 +1,436 @@
// Package clients provides HTTP client wrappers for AI/ML backend services.
//
// All clients share a single [http.Transport] for connection pooling across
// the process. Request and response bodies are serialized through pooled
// [bytes.Buffer]s to reduce GC pressure.
package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"sync"
"time"
)
// ─── Shared transport & buffer pool ─────────────────────────────────────────
// SharedTransport is the process-wide HTTP transport used by every service
// client. Tweak pool sizes here rather than creating per-client transports.
var SharedTransport = &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DisableCompression: true, // in-cluster traffic; skip gzip overhead
}
// bufPool recycles *bytes.Buffer to avoid per-request allocations.
var bufPool = sync.Pool{
New: func() any { return new(bytes.Buffer) },
}
func getBuf() *bytes.Buffer {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
return buf
}
func putBuf(buf *bytes.Buffer) {
if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
return
}
bufPool.Put(buf)
}
// ─── httpClient base ────────────────────────────────────────────────────────
// httpClient is the shared base for all service clients.
type httpClient struct {
client *http.Client
baseURL string
}
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
return &httpClient{
client: &http.Client{
Timeout: timeout,
Transport: SharedTransport,
},
baseURL: baseURL,
}
}
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
buf := getBuf()
defer putBuf(buf)
if err := json.NewEncoder(buf).Encode(body); err != nil {
return nil, fmt.Errorf("marshal: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
return h.do(req)
}
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
u := h.baseURL + path
if len(params) > 0 {
u += "?" + params.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return nil, err
}
return h.do(req)
}
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
return h.get(ctx, path, params)
}
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
buf := getBuf()
defer putBuf(buf)
w := multipart.NewWriter(buf)
part, err := w.CreateFormFile(fieldName, fileName)
if err != nil {
return nil, err
}
if _, err := part.Write(fileData); err != nil {
return nil, err
}
for k, v := range fields {
_ = w.WriteField(k, v)
}
_ = w.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", w.FormDataContentType())
return h.do(req)
}
func (h *httpClient) do(req *http.Request) ([]byte, error) {
resp, err := h.client.Do(req)
if err != nil {
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
}
defer func() { _ = resp.Body.Close() }()
buf := getBuf()
defer putBuf(buf)
if _, err := io.Copy(buf, resp.Body); err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
// Return a copy so the pooled buffer can be safely recycled.
body := make([]byte, buf.Len())
copy(body, buf.Bytes())
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
}
return body, nil
}
func (h *httpClient) healthCheck(ctx context.Context) bool {
data, err := h.get(ctx, "/health", nil)
_ = data
return err == nil
}
// ─── Embeddings Client ──────────────────────────────────────────────────────
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
type EmbeddingsClient struct {
*httpClient
Model string
}
// NewEmbeddingsClient creates an embeddings client.
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
if model == "" {
model = "bge"
}
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
}
// Embed generates embeddings for a list of texts.
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
"input": texts,
"model": c.Model,
})
if err != nil {
return nil, err
}
var resp struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
result := make([][]float64, len(resp.Data))
for i, d := range resp.Data {
result[i] = d.Embedding
}
return result, nil
}
// EmbedSingle generates an embedding for a single text.
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
results, err := c.Embed(ctx, []string{text})
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, fmt.Errorf("empty embedding result")
}
return results[0], nil
}
// Health checks if the embeddings service is healthy.
func (c *EmbeddingsClient) Health(ctx context.Context) bool {
return c.healthCheck(ctx)
}
// ─── Reranker Client ────────────────────────────────────────────────────────
// RerankerClient calls the reranker service (BGE Reranker).
type RerankerClient struct {
*httpClient
}
// NewRerankerClient creates a reranker client.
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
}
// RerankResult represents a reranked document.
type RerankResult struct {
Index int `json:"index"`
Score float64 `json:"score"`
Document string `json:"document"`
}
// Rerank reranks documents by relevance to the query.
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
payload := map[string]any{
"query": query,
"documents": documents,
}
if topK > 0 {
payload["top_n"] = topK
}
body, err := c.postJSON(ctx, "/rerank", payload)
if err != nil {
return nil, err
}
var resp struct {
Results []struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
results := make([]RerankResult, len(resp.Results))
for i, r := range resp.Results {
score := r.RelevanceScore
if score == 0 {
score = r.Score
}
doc := ""
if r.Index < len(documents) {
doc = documents[r.Index]
}
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
}
return results, nil
}
// ─── LLM Client ─────────────────────────────────────────────────────────────
// LLMClient calls the vLLM-compatible LLM service.
type LLMClient struct {
*httpClient
Model string
MaxTokens int
Temperature float64
TopP float64
}
// NewLLMClient creates an LLM client.
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
return &LLMClient{
httpClient: newHTTPClient(baseURL, timeout),
Model: "default",
MaxTokens: 2048,
Temperature: 0.7,
TopP: 0.9,
}
}
// ChatMessage is an OpenAI-compatible message.
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// Generate sends a chat completion request and returns the response text.
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
messages := buildMessages(prompt, context_, systemPrompt)
payload := map[string]any{
"model": c.Model,
"messages": messages,
"max_tokens": c.MaxTokens,
"temperature": c.Temperature,
"top_p": c.TopP,
}
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
if err != nil {
return "", err
}
var resp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in LLM response")
}
return resp.Choices[0].Message.Content, nil
}
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
var msgs []ChatMessage
if systemPrompt != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
} else if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
}
if ctx != "" {
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
} else {
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
}
return msgs
}
// ─── TTS Client ─────────────────────────────────────────────────────────────
// TTSClient calls the TTS service (Coqui XTTS).
type TTSClient struct {
*httpClient
Language string
}
// NewTTSClient creates a TTS client.
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
if language == "" {
language = "en"
}
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
}
// Synthesize generates audio bytes from text.
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
if language == "" {
language = c.Language
}
params := url.Values{
"text": {text},
"language_id": {language},
}
if speaker != "" {
params.Set("speaker_id", speaker)
}
return c.getRaw(ctx, "/api/tts", params)
}
// ─── STT Client ─────────────────────────────────────────────────────────────
// STTClient calls the Whisper STT service.
type STTClient struct {
*httpClient
Language string
Task string
}
// NewSTTClient creates an STT client.
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
}
// TranscribeResult holds transcription output.
type TranscribeResult struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
}
// Transcribe sends audio to Whisper and returns the transcription.
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
if language == "" {
language = c.Language
}
fields := map[string]string{
"response_format": "json",
}
if language != "" {
fields["language"] = language
}
endpoint := "/v1/audio/transcriptions"
if c.Task == "translate" {
endpoint = "/v1/audio/translations"
}
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
if err != nil {
return nil, err
}
var result TranscribeResult
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return &result, nil
}
// ─── Milvus Client ──────────────────────────────────────────────────────────
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
// For the Go port we use the Milvus Go SDK.
type MilvusClient struct {
Host string
Port int
Collection string
}
// NewMilvusClient creates a Milvus client.
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
return &MilvusClient{Host: host, Port: port, Collection: collection}
}
// SearchResult holds a single vector search hit.
type SearchResult struct {
ID int64 `json:"id"`
Distance float64 `json:"distance"`
Score float64 `json:"score"`
Fields map[string]any `json:"fields,omitempty"`
}

506
clients/clients_test.go Normal file
View File

@@ -0,0 +1,506 @@
package clients
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// ────────────────────────────────────────────────────────────────────────────
// Shared infrastructure tests
// ────────────────────────────────────────────────────────────────────────────
func TestSharedTransport(t *testing.T) {
// All clients created via newHTTPClient should share the same transport.
c1 := newHTTPClient("http://a:8000", 10*time.Second)
c2 := newHTTPClient("http://b:9000", 30*time.Second)
if c1.client.Transport != c2.client.Transport {
t.Error("clients should share the same http.Transport")
}
if c1.client.Transport != SharedTransport {
t.Error("transport should be the package-level SharedTransport")
}
}
func TestBufferPoolGetPut(t *testing.T) {
buf := getBuf()
if buf == nil {
t.Fatal("getBuf returned nil")
}
if buf.Len() != 0 {
t.Error("getBuf should return a reset buffer")
}
buf.WriteString("hello")
putBuf(buf)
// On re-get, buffer should be reset.
buf2 := getBuf()
if buf2.Len() != 0 {
t.Error("re-acquired buffer should be reset")
}
putBuf(buf2)
}
func TestBufferPoolOversizedDiscarded(t *testing.T) {
buf := getBuf()
// Grow beyond 1 MB threshold.
buf.Write(make([]byte, 2<<20))
putBuf(buf) // should silently discard
// Pool should still work — we get a fresh one.
buf2 := getBuf()
if buf2.Len() != 0 {
t.Error("should get a fresh buffer")
}
putBuf(buf2)
}
func TestBufferPoolConcurrency(t *testing.T) {
var wg sync.WaitGroup
for i := range 100 {
wg.Add(1)
go func(n int) {
defer wg.Done()
buf := getBuf()
buf.WriteString(strings.Repeat("x", n))
putBuf(buf)
}(i)
}
wg.Wait()
}
// ────────────────────────────────────────────────────────────────────────────
// Embeddings client
// ────────────────────────────────────────────────────────────────────────────
func TestEmbeddingsClient_Embed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/embeddings" {
t.Errorf("path = %q, want /embeddings", r.URL.Path)
}
if r.Method != http.MethodPost {
t.Errorf("method = %s, want POST", r.Method)
}
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
input, _ := req["input"].([]any)
if len(input) != 2 {
t.Errorf("input len = %d, want 2", len(input))
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"data": []map[string]any{
{"embedding": []float64{0.1, 0.2, 0.3}},
{"embedding": []float64{0.4, 0.5, 0.6}},
},
})
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge")
results, err := c.Embed(context.Background(), []string{"hello", "world"})
if err != nil {
t.Fatal(err)
}
if len(results) != 2 {
t.Fatalf("len(results) = %d, want 2", len(results))
}
if results[0][0] != 0.1 {
t.Errorf("results[0][0] = %f, want 0.1", results[0][0])
}
}
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"data": []map[string]any{
{"embedding": []float64{1.0, 2.0}},
},
})
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
vec, err := c.EmbedSingle(context.Background(), "test")
if err != nil {
t.Fatal(err)
}
if len(vec) != 2 || vec[0] != 1.0 {
t.Errorf("vec = %v", vec)
}
}
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
_, err := c.EmbedSingle(context.Background(), "test")
if err == nil {
t.Error("expected error for empty embedding")
}
}
func TestEmbeddingsClient_Health(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/health" {
w.WriteHeader(200)
return
}
w.WriteHeader(404)
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
if !c.Health(context.Background()) {
t.Error("expected healthy")
}
}
// ────────────────────────────────────────────────────────────────────────────
// Reranker client
// ────────────────────────────────────────────────────────────────────────────
func TestRerankerClient_Rerank(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
if req["query"] != "test query" {
t.Errorf("query = %v", req["query"])
}
_ = json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"index": 1, "relevance_score": 0.95},
{"index": 0, "relevance_score": 0.80},
},
})
}))
defer ts.Close()
c := NewRerankerClient(ts.URL, 5*time.Second)
docs := []string{"Paris is great", "France is in Europe"}
results, err := c.Rerank(context.Background(), "test query", docs, 2)
if err != nil {
t.Fatal(err)
}
if len(results) != 2 {
t.Fatalf("len = %d", len(results))
}
if results[0].Score != 0.95 {
t.Errorf("score = %f, want 0.95", results[0].Score)
}
if results[0].Document != "France is in Europe" {
t.Errorf("document = %q", results[0].Document)
}
}
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
},
})
}))
defer ts.Close()
c := NewRerankerClient(ts.URL, 5*time.Second)
results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0)
if err != nil {
t.Fatal(err)
}
if results[0].Score != 0.77 {
t.Errorf("fallback score = %f, want 0.77", results[0].Score)
}
}
// ────────────────────────────────────────────────────────────────────────────
// LLM client
// ────────────────────────────────────────────────────────────────────────────
func TestLLMClient_Generate(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/chat/completions" {
t.Errorf("path = %q", r.URL.Path)
}
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
msgs, _ := req["messages"].([]any)
if len(msgs) == 0 {
t.Error("no messages in request")
}
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "Paris is the capital of France."}},
},
})
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.Generate(context.Background(), "capital of France?", "", "")
if err != nil {
t.Fatal(err)
}
if result != "Paris is the capital of France." {
t.Errorf("result = %q", result)
}
}
func TestLLMClient_GenerateWithContext(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
msgs, _ := req["messages"].([]any)
// Should have system + user message
if len(msgs) != 2 {
t.Errorf("expected 2 messages, got %d", len(msgs))
}
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "answer with context"}},
},
})
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
result, err := c.Generate(context.Background(), "question", "some context", "")
if err != nil {
t.Fatal(err)
}
if result != "answer with context" {
t.Errorf("result = %q", result)
}
}
func TestLLMClient_GenerateNoChoices(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
_, err := c.Generate(context.Background(), "q", "", "")
if err == nil {
t.Error("expected error for empty choices")
}
}
// ────────────────────────────────────────────────────────────────────────────
// TTS client
// ────────────────────────────────────────────────────────────────────────────
func TestTTSClient_Synthesize(t *testing.T) {
expected := []byte{0xDE, 0xAD, 0xBE, 0xEF}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tts" {
t.Errorf("path = %q", r.URL.Path)
}
if r.URL.Query().Get("text") != "hello world" {
t.Errorf("text = %q", r.URL.Query().Get("text"))
}
_, _ = w.Write(expected)
}))
defer ts.Close()
c := NewTTSClient(ts.URL, 5*time.Second, "en")
audio, err := c.Synthesize(context.Background(), "hello world", "", "")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(audio, expected) {
t.Errorf("audio = %x, want %x", audio, expected)
}
}
func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("speaker_id") != "alice" {
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
}
_, _ = w.Write([]byte{0x01})
}))
defer ts.Close()
c := NewTTSClient(ts.URL, 5*time.Second, "en")
_, err := c.Synthesize(context.Background(), "hi", "en", "alice")
if err != nil {
t.Fatal(err)
}
}
// ────────────────────────────────────────────────────────────────────────────
// STT client
// ────────────────────────────────────────────────────────────────────────────
func TestSTTClient_Transcribe(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/transcriptions" {
t.Errorf("path = %q", r.URL.Path)
}
ct := r.Header.Get("Content-Type")
if !strings.Contains(ct, "multipart/form-data") {
t.Errorf("content-type = %q", ct)
}
// Verify the audio file is present.
file, _, err := r.FormFile("file")
if err != nil {
t.Fatal(err)
}
data, _ := io.ReadAll(file)
if len(data) != 100 {
t.Errorf("file size = %d, want 100", len(data))
}
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
}))
defer ts.Close()
c := NewSTTClient(ts.URL, 5*time.Second)
result, err := c.Transcribe(context.Background(), make([]byte, 100), "en")
if err != nil {
t.Fatal(err)
}
if result.Text != "hello world" {
t.Errorf("text = %q", result.Text)
}
}
func TestSTTClient_TranscribeTranslate(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/translations" {
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
}
_ = json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
}))
defer ts.Close()
c := NewSTTClient(ts.URL, 5*time.Second)
c.Task = "translate"
result, err := c.Transcribe(context.Background(), []byte{0x01}, "")
if err != nil {
t.Fatal(err)
}
if result.Text != "translated" {
t.Errorf("text = %q", result.Text)
}
}
// ────────────────────────────────────────────────────────────────────────────
// HTTP error handling
// ────────────────────────────────────────────────────────────────────────────
func TestHTTPError4xx(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
_, _ = w.Write([]byte(`{"error": "bad input"}`))
}))
defer ts.Close()
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
_, err := c.Embed(context.Background(), []string{"test"})
if err == nil {
t.Fatal("expected error for 422")
}
if !strings.Contains(err.Error(), "422") {
t.Errorf("error should contain status code: %v", err)
}
}
func TestHTTPError5xx(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
_, _ = w.Write([]byte("internal server error"))
}))
defer ts.Close()
c := NewLLMClient(ts.URL, 5*time.Second)
_, err := c.Generate(context.Background(), "q", "", "")
if err == nil {
t.Fatal("expected error for 500")
}
}
// ────────────────────────────────────────────────────────────────────────────
// buildMessages helper
// ────────────────────────────────────────────────────────────────────────────
func TestBuildMessages(t *testing.T) {
// No context, no system prompt → just user message
msgs := buildMessages("hello", "", "")
if len(msgs) != 1 || msgs[0].Role != "user" {
t.Errorf("expected 1 user msg, got %+v", msgs)
}
// With system prompt
msgs = buildMessages("hello", "", "You are helpful")
if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" {
t.Errorf("expected system+user, got %+v", msgs)
}
// With context, no system prompt → auto system prompt
msgs = buildMessages("question", "some context", "")
if len(msgs) != 2 || msgs[0].Role != "system" {
t.Errorf("expected auto system+user, got %+v", msgs)
}
if !strings.Contains(msgs[1].Content, "Context:") {
t.Error("user message should contain context")
}
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmarks: pooled buffer vs direct allocation
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkPostJSON(b *testing.B) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer ts.Close()
c := newHTTPClient(ts.URL, 10*time.Second)
ctx := context.Background()
payload := map[string]any{
"text": strings.Repeat("x", 1024),
"count": 42,
"enabled": true,
}
b.ResetTimer()
for b.Loop() {
_, _ = c.postJSON(ctx, "/test", payload)
}
}
func BenchmarkBufferPool(b *testing.B) {
b.ResetTimer()
for b.Loop() {
buf := getBuf()
buf.WriteString(strings.Repeat("x", 4096))
putBuf(buf)
}
}
func BenchmarkBufferPoolParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
buf := getBuf()
buf.WriteString(strings.Repeat("x", 4096))
putBuf(buf)
}
})
}

268
config/config.go Normal file
View File

@@ -0,0 +1,268 @@
// Package config provides environment-based configuration for handler services
// with optional live reload of secrets and service endpoints.
package config
import (
"context"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
// Settings holds base configuration for all handler services.
// Fields in the "hot-reload" section are protected by a RWMutex and can be
// updated at runtime via WatchSecrets(). All other fields are immutable
// after Load() returns.
type Settings struct {
// Service identification (immutable)
ServiceName string
ServiceVersion string
ServiceNamespace string
DeploymentEnv string
// NATS configuration (immutable)
NATSURL string
NATSUser string
NATSPassword string
NATSQueueGroup string
// Redis/Valkey configuration (immutable)
RedisURL string
RedisPassword string
// Milvus configuration (immutable)
MilvusHost string
MilvusPort int
MilvusCollection string
// OpenTelemetry configuration (immutable)
OTELEnabled bool
OTELEndpoint string
OTELUseHTTP bool
// HyperDX configuration (immutable)
HyperDXEnabled bool
HyperDXAPIKey string
HyperDXEndpoint string
// MLflow configuration (immutable)
MLflowTrackingURI string
MLflowExperimentName string
MLflowEnabled bool
// Health check configuration (immutable)
HealthPort int
HealthPath string
ReadyPath string
// Timeouts (immutable)
HTTPTimeout time.Duration
NATSTimeout time.Duration
// Hot-reloadable fields — access via getter methods.
mu sync.RWMutex
embeddingsURL string
rerankerURL string
llmURL string
ttsURL string
sttURL string
// Secrets path for file-based hot reload (Kubernetes secret mounts)
SecretsPath string
}
// Load creates a Settings populated from environment variables with defaults.
func Load() *Settings {
return &Settings{
ServiceName: getEnv("SERVICE_NAME", "handler"),
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
NATSUser: getEnv("NATS_USER", ""),
NATSPassword: getEnv("NATS_PASSWORD", ""),
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
RedisPassword: getEnv("REDIS_PASSWORD", ""),
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
HealthPort: getEnvInt("HEALTH_PORT", 8080),
HealthPath: getEnv("HEALTH_PATH", "/health"),
ReadyPath: getEnv("READY_PATH", "/ready"),
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
SecretsPath: getEnv("SECRETS_PATH", ""),
}
}
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
func (s *Settings) EmbeddingsURL() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.embeddingsURL
}
// RerankerURL returns the current reranker service URL (thread-safe).
func (s *Settings) RerankerURL() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.rerankerURL
}
// LLMURL returns the current LLM service URL (thread-safe).
func (s *Settings) LLMURL() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.llmURL
}
// TTSURL returns the current TTS service URL (thread-safe).
func (s *Settings) TTSURL() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ttsURL
}
// STTURL returns the current STT service URL (thread-safe).
func (s *Settings) STTURL() string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.sttURL
}
// WatchSecrets watches the SecretsPath directory for changes and reloads
// hot-reloadable fields. Blocks until ctx is cancelled.
func (s *Settings) WatchSecrets(ctx context.Context) {
if s.SecretsPath == "" {
return
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
slog.Error("config: failed to create fsnotify watcher", "error", err)
return
}
defer func() { _ = watcher.Close() }()
if err := watcher.Add(s.SecretsPath); err != nil {
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
return
}
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
s.reloadFromSecrets()
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
slog.Error("config: fsnotify error", "error", err)
case <-ctx.Done():
return
}
}
}
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
func (s *Settings) reloadFromSecrets() {
s.mu.Lock()
defer s.mu.Unlock()
updated := 0
reload := func(filename string, target *string) {
path := filepath.Join(s.SecretsPath, filename)
data, err := os.ReadFile(path)
if err != nil {
return
}
val := strings.TrimSpace(string(data))
if val != "" && val != *target {
*target = val
updated++
slog.Info("config: reloaded secret", "key", filename)
}
}
reload("embeddings-url", &s.embeddingsURL)
reload("reranker-url", &s.rerankerURL)
reload("llm-url", &s.llmURL)
reload("tts-url", &s.ttsURL)
reload("stt-url", &s.sttURL)
if updated > 0 {
slog.Info("config: secrets reloaded", "updated", updated)
}
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func getEnvInt(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
return fallback
}
func getEnvBool(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil {
return b
}
}
return fallback
}
func getEnvDuration(key string, fallback time.Duration) time.Duration {
if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return time.Duration(f * float64(time.Second))
}
}
return fallback
}

123
config/config_test.go Normal file
View File

@@ -0,0 +1,123 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestLoadDefaults(t *testing.T) {
s := Load()
if s.ServiceName != "handler" {
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
}
if s.HealthPort != 8080 {
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
}
if s.HTTPTimeout != 60*time.Second {
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
}
}
func TestLoadFromEnv(t *testing.T) {
t.Setenv("SERVICE_NAME", "test-svc")
t.Setenv("HEALTH_PORT", "9090")
t.Setenv("OTEL_ENABLED", "false")
s := Load()
if s.ServiceName != "test-svc" {
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
}
if s.HealthPort != 9090 {
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
}
if s.OTELEnabled {
t.Error("expected OTELEnabled false")
}
}
func TestURLGetters(t *testing.T) {
s := Load()
if s.EmbeddingsURL() == "" {
t.Error("EmbeddingsURL should have a default")
}
if s.RerankerURL() == "" {
t.Error("RerankerURL should have a default")
}
if s.LLMURL() == "" {
t.Error("LLMURL should have a default")
}
if s.TTSURL() == "" {
t.Error("TTSURL should have a default")
}
if s.STTURL() == "" {
t.Error("STTURL should have a default")
}
}
func TestURLGettersFromEnv(t *testing.T) {
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
t.Setenv("LLM_URL", "http://llm:9000")
s := Load()
if s.EmbeddingsURL() != "http://embed:8000" {
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
}
if s.LLMURL() != "http://llm:9000" {
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
}
}
func TestReloadFromSecrets(t *testing.T) {
dir := t.TempDir()
// Write initial secret files
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
s := Load()
s.SecretsPath = dir
s.reloadFromSecrets()
if s.EmbeddingsURL() != "http://old-embed:8000" {
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
}
if s.LLMURL() != "http://old-llm:9000" {
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
}
// Simulate secret update
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
s.reloadFromSecrets()
if s.EmbeddingsURL() != "http://new-embed:8000" {
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
}
// LLM should remain unchanged
if s.LLMURL() != "http://old-llm:9000" {
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
}
}
func TestReloadFromSecretsNoPath(t *testing.T) {
s := Load()
s.SecretsPath = ""
// Should not panic
s.reloadFromSecrets()
}
func TestGetEnvDuration(t *testing.T) {
t.Setenv("TEST_DUR", "30")
d := getEnvDuration("TEST_DUR", 10*time.Second)
if d != 30*time.Second {
t.Errorf("expected 30s, got %v", d)
}
}
func writeSecret(t *testing.T, dir, name, value string) {
t.Helper()
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
t.Fatal(err)
}
}

40
go.mod Normal file
View File

@@ -0,0 +1,40 @@
module git.daviestechlabs.io/daviestechlabs/handler-base
go 1.25.1
require (
github.com/nats-io/nats.go v1.48.0
github.com/vmihailenco/msgpack/v5 v5.4.1
go.opentelemetry.io/otel v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
go.opentelemetry.io/otel/metric v1.40.0
go.opentelemetry.io/otel/sdk v1.40.0
go.opentelemetry.io/otel/sdk/metric v1.40.0
go.opentelemetry.io/otel/trace v1.40.0
)
require (
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/nats-io/nkeys v0.4.11 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
golang.org/x/crypto v0.47.0 // indirect
golang.org/x/net v0.49.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/grpc v1.78.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)

79
go.sum Normal file
View File

@@ -0,0 +1,79 @@
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 h1:NOyNnS19BF2SUDApbOKbDtWZ0IK7b8FJ2uAGdIWOGb0=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0/go.mod h1:VL6EgVikRLcJa9ftukrHu/ZkkhFBSo1lzvdBC9CF1ss=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs=
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

231
handler/handler.go Normal file
View File

@@ -0,0 +1,231 @@
// Package handler provides the base Handler pattern for NATS message-driven services.
package handler
import (
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"syscall"
"github.com/nats-io/nats.go"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
)
// MessageHandler is the callback for processing decoded NATS messages.
// data is the msgpack-decoded map. Return a response map (or nil for no reply).
type MessageHandler func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error)
// TypedMessageHandler processes the raw NATS message without pre-decoding to
// map[string]any. Services unmarshal msg.Data into their own typed structs,
// avoiding the double-decode overhead. Return any msgpack-serialisable value
// (a typed struct, map, or nil for no reply).
type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (any, error)
// SetupFunc is called once before the handler starts processing messages.
type SetupFunc func(ctx context.Context) error
// TeardownFunc is called during graceful shutdown.
type TeardownFunc func(ctx context.Context) error
// Handler is the base service runner that wires NATS, health, and telemetry.
type Handler struct {
Settings *config.Settings
NATS *natsutil.Client
Telemetry *telemetry.Provider
Subject string
QueueGroup string
onSetup SetupFunc
onTeardown TeardownFunc
onMessage MessageHandler
onTypedMessage TypedMessageHandler
running bool
}
// New creates a Handler for the given NATS subject.
func New(subject string, settings *config.Settings) *Handler {
if settings == nil {
settings = config.Load()
}
queueGroup := settings.NATSQueueGroup
natsOpts := []nats.Option{}
if settings.NATSUser != "" && settings.NATSPassword != "" {
natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword))
}
return &Handler{
Settings: settings,
Subject: subject,
QueueGroup: queueGroup,
NATS: natsutil.New(settings.NATSURL, natsOpts...),
}
}
// OnSetup registers the setup callback.
func (h *Handler) OnSetup(fn SetupFunc) { h.onSetup = fn }
// OnTeardown registers the teardown callback.
func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn }
// OnMessage registers the message handler callback.
func (h *Handler) OnMessage(fn MessageHandler) { h.onMessage = fn }
// OnTypedMessage registers a typed message handler. It replaces OnMessage —
// wrapHandler will skip the map[string]any decode and let the callback
// unmarshal msg.Data directly.
func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn }
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
func (h *Handler) Run() error {
// Structured logging
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Telemetry
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
ServiceName: h.Settings.ServiceName,
ServiceVersion: h.Settings.ServiceVersion,
ServiceNamespace: h.Settings.ServiceNamespace,
DeploymentEnv: h.Settings.DeploymentEnv,
Enabled: h.Settings.OTELEnabled,
Endpoint: h.Settings.OTELEndpoint,
})
if err != nil {
return fmt.Errorf("telemetry setup: %w", err)
}
defer shutdown(ctx)
h.Telemetry = tp
// Health server
healthSrv := health.New(
h.Settings.HealthPort,
h.Settings.HealthPath,
h.Settings.ReadyPath,
func() bool { return h.running && h.NATS.IsConnected() },
)
healthSrv.Start()
defer healthSrv.Stop(ctx)
// Connect to NATS
if err := h.NATS.Connect(); err != nil {
return fmt.Errorf("nats: %w", err)
}
defer h.NATS.Close()
// User setup
if h.onSetup != nil {
slog.Info("running service setup")
if err := h.onSetup(ctx); err != nil {
return fmt.Errorf("setup: %w", err)
}
}
// Subscribe
if h.onMessage == nil && h.onTypedMessage == nil {
return fmt.Errorf("no message handler registered")
}
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
return fmt.Errorf("subscribe: %w", err)
}
h.running = true
slog.Info("handler ready", "subject", h.Subject)
// Wait for shutdown signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
<-sigCh
slog.Info("shutting down")
h.running = false
// Teardown
if h.onTeardown != nil {
if err := h.onTeardown(ctx); err != nil {
slog.Warn("teardown error", "error", err)
}
}
slog.Info("shutdown complete")
return nil
}
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
// If OnTypedMessage was used, msg.Data is passed directly without map decode.
// If OnMessage was used, msg.Data is decoded to map[string]any first.
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
if h.onTypedMessage != nil {
return h.wrapTypedHandler(ctx)
}
return h.wrapMapHandler(ctx)
}
// wrapTypedHandler dispatches to the TypedMessageHandler (no map decode).
func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler {
return func(msg *nats.Msg) {
response, err := h.onTypedMessage(ctx, msg)
if err != nil {
slog.Error("handler error", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": fmt.Sprintf("%T", err),
})
}
return
}
if response != nil && msg.Reply != "" {
if err := h.NATS.Publish(msg.Reply, response); err != nil {
slog.Error("failed to publish reply", "error", err)
}
}
}
}
// wrapMapHandler dispatches to the legacy MessageHandler (decodes to map first).
func (h *Handler) wrapMapHandler(ctx context.Context) nats.MsgHandler {
return func(msg *nats.Msg) {
data, err := natsutil.DecodeMsgpackMap(msg.Data)
if err != nil {
slog.Error("failed to decode message", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": "DecodeError",
})
}
return
}
response, err := h.onMessage(ctx, msg, data)
if err != nil {
slog.Error("handler error", "subject", msg.Subject, "error", err)
if msg.Reply != "" {
_ = h.NATS.Publish(msg.Reply, map[string]any{
"error": true,
"message": err.Error(),
"type": fmt.Sprintf("%T", err),
})
}
return
}
if response != nil && msg.Reply != "" {
if err := h.NATS.Publish(msg.Reply, response); err != nil {
slog.Error("failed to publish reply", "error", err)
}
}
}
}

311
handler/handler_test.go Normal file
View File

@@ -0,0 +1,311 @@
package handler
import (
"context"
"testing"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
)
// ────────────────────────────────────────────────────────────────────────────
// Handler construction tests
// ────────────────────────────────────────────────────────────────────────────
func TestNewHandler(t *testing.T) {
cfg := config.Load()
cfg.ServiceName = "test-handler"
cfg.NATSQueueGroup = "test-group"
h := New("ai.test.subject", cfg)
if h.Subject != "ai.test.subject" {
t.Errorf("Subject = %q", h.Subject)
}
if h.QueueGroup != "test-group" {
t.Errorf("QueueGroup = %q", h.QueueGroup)
}
if h.Settings.ServiceName != "test-handler" {
t.Errorf("ServiceName = %q", h.Settings.ServiceName)
}
}
func TestNewHandlerNilSettings(t *testing.T) {
h := New("ai.test", nil)
if h.Settings == nil {
t.Fatal("Settings should be loaded automatically")
}
if h.Settings.ServiceName != "handler" {
t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName)
}
}
func TestCallbackRegistration(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
setupCalled := false
h.OnSetup(func(ctx context.Context) error {
setupCalled = true
return nil
})
teardownCalled := false
h.OnTeardown(func(ctx context.Context) error {
teardownCalled = true
return nil
})
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
return nil, nil
})
if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil {
t.Error("callbacks should not be nil after registration")
}
// Verify setup/teardown work when called directly.
_ = h.onSetup(context.Background())
_ = h.onTeardown(context.Background())
if !setupCalled || !teardownCalled {
t.Error("callbacks should have been invoked")
}
}
func TestTypedMessageRegistration(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
return map[string]any{"ok": true}, nil
})
if h.onTypedMessage == nil {
t.Error("onTypedMessage should not be nil after registration")
}
}
// ────────────────────────────────────────────────────────────────────────────
// wrapHandler dispatch tests (unit test the message decode + dispatch logic)
// ────────────────────────────────────────────────────────────────────────────
func TestWrapHandler_ValidMessage(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
var receivedData map[string]any
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
receivedData = data
return map[string]any{"status": "ok"}, nil
})
// Encode a message the same way services would.
payload := map[string]any{
"request_id": "test-001",
"message": "hello",
"premium": true,
}
encoded, err := msgpack.Marshal(payload)
if err != nil {
t.Fatal(err)
}
// Call wrapHandler directly without NATS.
handler := h.wrapHandler(context.Background())
handler(&nats.Msg{
Subject: "ai.test.user.42.message",
Data: encoded,
})
if receivedData == nil {
t.Fatal("handler was not called")
}
if receivedData["request_id"] != "test-001" {
t.Errorf("request_id = %v", receivedData["request_id"])
}
if receivedData["premium"] != true {
t.Errorf("premium = %v", receivedData["premium"])
}
}
func TestWrapHandler_InvalidMsgpack(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
handlerCalled := false
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
handlerCalled = true
return nil, nil
})
handler := h.wrapHandler(context.Background())
handler(&nats.Msg{
Subject: "ai.test",
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack
})
if handlerCalled {
t.Error("handler should not be called for invalid msgpack")
}
}
func TestWrapHandler_HandlerError(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
return nil, context.DeadlineExceeded
})
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
handler := h.wrapHandler(context.Background())
// Should not panic even when handler returns error.
handler(&nats.Msg{
Subject: "ai.test",
Data: encoded,
})
}
func TestWrapHandler_NilResponse(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
return nil, nil // fire-and-forget style
})
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
handler := h.wrapHandler(context.Background())
// Should not panic with nil response and no reply subject.
handler(&nats.Msg{
Subject: "ai.test",
Data: encoded,
})
}
// ────────────────────────────────────────────────────────────────────────────
// wrapHandler dispatch tests — typed handler path
// ────────────────────────────────────────────────────────────────────────────
func TestWrapTypedHandler_ValidMessage(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
type testReq struct {
RequestID string `msgpack:"request_id"`
Message string `msgpack:"message"`
}
var received testReq
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
if err := msgpack.Unmarshal(msg.Data, &received); err != nil {
return nil, err
}
return map[string]any{"status": "ok"}, nil
})
encoded, _ := msgpack.Marshal(map[string]any{
"request_id": "typed-001",
"message": "hello typed",
})
handler := h.wrapHandler(context.Background())
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
if received.RequestID != "typed-001" {
t.Errorf("RequestID = %q", received.RequestID)
}
if received.Message != "hello typed" {
t.Errorf("Message = %q", received.Message)
}
}
func TestWrapTypedHandler_Error(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
return nil, context.DeadlineExceeded
})
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
handler := h.wrapHandler(context.Background())
// Should not panic.
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
}
func TestWrapTypedHandler_NilResponse(t *testing.T) {
cfg := config.Load()
h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
return nil, nil
})
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
handler := h.wrapHandler(context.Background())
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmark: message decode + dispatch overhead
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkWrapHandler(b *testing.B) {
cfg := config.Load()
h := New("ai.test", cfg)
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
return map[string]any{"ok": true}, nil
})
payload := map[string]any{
"request_id": "bench-001",
"message": "What is the capital of France?",
"premium": true,
"top_k": 10,
}
encoded, _ := msgpack.Marshal(payload)
handler := h.wrapHandler(context.Background())
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
b.ResetTimer()
for b.Loop() {
handler(msg)
}
}
func BenchmarkWrapTypedHandler(b *testing.B) {
type benchReq struct {
RequestID string `msgpack:"request_id"`
Message string `msgpack:"message"`
Premium bool `msgpack:"premium"`
TopK int `msgpack:"top_k"`
}
cfg := config.Load()
h := New("ai.test", cfg)
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
var req benchReq
_ = msgpack.Unmarshal(msg.Data, &req)
return map[string]any{"ok": true}, nil
})
payload := map[string]any{
"request_id": "bench-001",
"message": "What is the capital of France?",
"premium": true,
"top_k": 10,
}
encoded, _ := msgpack.Marshal(payload)
handler := h.wrapHandler(context.Background())
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
b.ResetTimer()
for b.Loop() {
handler(msg)
}
}

View File

@@ -1,28 +0,0 @@
"""
Handler Base - Shared utilities for AI/ML handler services.
Provides consistent patterns for:
- OpenTelemetry tracing and metrics
- NATS messaging
- Health checks
- Graceful shutdown
- Service client wrappers
"""
from handler_base.config import Settings
from handler_base.handler import Handler
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import get_meter, get_tracer, setup_telemetry
__all__ = [
"Handler",
"Settings",
"HealthServer",
"NATSClient",
"setup_telemetry",
"get_tracer",
"get_meter",
]
__version__ = "1.0.0"

View File

@@ -1,19 +0,0 @@
"""
Service client wrappers for AI/ML backends.
"""
from handler_base.clients.embeddings import EmbeddingsClient
from handler_base.clients.llm import LLMClient
from handler_base.clients.milvus import MilvusClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.stt import STTClient
from handler_base.clients.tts import TTSClient
__all__ = [
"EmbeddingsClient",
"RerankerClient",
"LLMClient",
"TTSClient",
"STTClient",
"MilvusClient",
]

View File

@@ -1,128 +0,0 @@
"""
Embeddings service client (Infinity/BGE).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import EmbeddingsSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class EmbeddingsClient:
"""
Client for the embeddings service (Infinity with BGE models).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = EmbeddingsClient()
embeddings = await client.embed(["Hello world"])
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "EmbeddingsDeployment"
RAY_APP_NAME = "embeddings"
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
self.settings = settings or EmbeddingsSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.embeddings_url,
timeout=self.settings.http_timeout,
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def embed(
self,
texts: list[str],
model: Optional[str] = None,
) -> list[list[float]]:
"""
Generate embeddings for a list of texts.
Args:
texts: List of texts to embed
model: Model name (defaults to settings)
Returns:
List of embedding vectors
"""
model = model or self.settings.embeddings_model
with create_span("embeddings.embed") as span:
if span:
span.set_attribute("embeddings.model", model)
span.set_attribute("embeddings.batch_size", len(texts))
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("embeddings.transport", "ray")
result = await handle.embed.remote(texts, model)
if span and result:
span.set_attribute("embeddings.dimensions", len(result[0]) if result else 0)
return result
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("embeddings.transport", "http")
response = await self._client.post(
"/embeddings",
json={"input": texts, "model": model},
)
response.raise_for_status()
result = response.json()
embeddings = [d["embedding"] for d in result.get("data", [])]
if span:
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
return embeddings
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
"""
Generate embedding for a single text.
Args:
text: Text to embed
model: Model name (defaults to settings)
Returns:
Embedding vector
"""
embeddings = await self.embed([text], model)
return embeddings[0] if embeddings else []
async def health(self) -> bool:
"""Check if the embeddings service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -1,233 +0,0 @@
"""
LLM service client (vLLM/OpenAI-compatible).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, AsyncIterator, Optional
import httpx
from handler_base.config import LLMSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class LLMClient:
"""
Client for the LLM service (vLLM with OpenAI-compatible API).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = LLMClient()
response = await client.generate("Hello, how are you?")
# With context for RAG
response = await client.generate(
"What is the capital?",
context="France is a country in Europe..."
)
# Streaming
async for chunk in client.stream("Tell me a story"):
print(chunk, end="")
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "VLLMDeployment"
RAY_APP_NAME = "llm"
def __init__(self, settings: Optional[LLMSettings] = None):
self.settings = settings or LLMSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.llm_url,
timeout=self.settings.http_timeout,
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def generate(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[list[str]] = None,
) -> str:
"""
Generate a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
Returns:
Generated text response
"""
with create_span("llm.generate") as span:
messages = self._build_messages(prompt, context, system_prompt)
if span:
span.set_attribute("llm.model", self.settings.llm_model)
span.set_attribute("llm.prompt_length", len(prompt))
if context:
span.set_attribute("llm.context_length", len(context))
payload = {
"model": self.settings.llm_model,
"messages": messages,
"max_tokens": max_tokens or self.settings.llm_max_tokens,
"temperature": temperature or self.settings.llm_temperature,
"top_p": top_p or self.settings.llm_top_p,
}
if stop:
payload["stop"] = stop
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("llm.transport", "ray")
result = await handle.remote(payload)
content = result["choices"][0]["message"]["content"]
if span:
span.set_attribute("llm.response_length", len(content))
return content
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("llm.transport", "http")
response = await self._client.post("/v1/chat/completions", json=payload)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
if span:
span.set_attribute("llm.response_length", len(content))
usage = result.get("usage", {})
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
return content
async def stream(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> AsyncIterator[str]:
"""
Stream a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
Yields:
Text chunks as they're generated
"""
messages = self._build_messages(prompt, context, system_prompt)
payload = {
"model": self.settings.llm_model,
"messages": messages,
"max_tokens": max_tokens or self.settings.llm_max_tokens,
"temperature": temperature or self.settings.llm_temperature,
"stream": True,
}
async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
import json
chunk = json.loads(data)
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
def _build_messages(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
) -> list[dict]:
"""Build the messages list for the API call."""
messages = []
# System prompt
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
elif context:
# Default RAG system prompt
messages.append(
{
"role": "system",
"content": (
"You are a helpful assistant. Use the provided context to answer "
"the user's question. If the context doesn't contain relevant "
"information, say so."
),
}
)
# Add context as a separate message if provided
if context:
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
}
)
else:
messages.append({"role": "user", "content": prompt})
return messages
async def health(self) -> bool:
"""Check if the LLM service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -1,183 +0,0 @@
"""
Milvus vector database client.
"""
import logging
from typing import Optional
from pymilvus import Collection, connections, utility
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class MilvusClient:
"""
Client for Milvus vector database.
Usage:
client = MilvusClient()
await client.connect()
results = await client.search(embedding, limit=10)
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._connected = False
self._collection: Optional[Collection] = None
async def connect(self, collection_name: Optional[str] = None) -> None:
"""
Connect to Milvus and load collection.
Args:
collection_name: Collection to use (defaults to settings)
"""
collection_name = collection_name or self.settings.milvus_collection
connections.connect(
alias="default",
host=self.settings.milvus_host,
port=self.settings.milvus_port,
)
if utility.has_collection(collection_name):
self._collection = Collection(collection_name)
self._collection.load()
logger.info(f"Connected to Milvus collection: {collection_name}")
else:
logger.warning(f"Collection {collection_name} does not exist")
self._connected = True
async def close(self) -> None:
"""Close Milvus connection."""
if self._collection:
self._collection.release()
connections.disconnect("default")
self._connected = False
logger.info("Disconnected from Milvus")
async def search(
self,
embedding: list[float],
limit: int = 10,
output_fields: Optional[list[str]] = None,
filter_expr: Optional[str] = None,
) -> list[dict]:
"""
Search for similar vectors.
Args:
embedding: Query embedding vector
limit: Maximum number of results
output_fields: Fields to return (default: all)
filter_expr: Optional filter expression
Returns:
List of results with 'id', 'distance', and requested fields
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.search") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.limit", limit)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
results = self._collection.search(
data=[embedding],
anns_field="embedding",
param=search_params,
limit=limit,
output_fields=output_fields,
expr=filter_expr,
)
# Convert to list of dicts
hits = []
for hit in results[0]:
item = {
"id": hit.id,
"distance": hit.distance,
"score": 1 - hit.distance, # Convert distance to similarity
}
# Add output fields
if output_fields:
for field in output_fields:
if hasattr(hit.entity, field):
item[field] = getattr(hit.entity, field)
hits.append(item)
if span:
span.set_attribute("milvus.results", len(hits))
return hits
async def search_with_texts(
self,
embedding: list[float],
limit: int = 10,
text_field: str = "text",
metadata_fields: Optional[list[str]] = None,
) -> list[dict]:
"""
Search and return text content with metadata.
Args:
embedding: Query embedding
limit: Maximum results
text_field: Name of text field in collection
metadata_fields: Additional metadata fields to return
Returns:
List of results with text and metadata
"""
output_fields = [text_field]
if metadata_fields:
output_fields.extend(metadata_fields)
return await self.search(embedding, limit, output_fields)
async def insert(
self,
embeddings: list[list[float]],
data: list[dict],
) -> list[int]:
"""
Insert vectors with data into the collection.
Args:
embeddings: List of embedding vectors
data: List of dicts with field values
Returns:
List of inserted IDs
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.insert") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.count", len(embeddings))
# Build insert data
insert_data = [embeddings]
for field in self._collection.schema.fields:
if field.name not in ("id", "embedding"):
field_values = [d.get(field.name) for d in data]
insert_data.append(field_values)
result = self._collection.insert(insert_data)
self._collection.flush()
return result.primary_keys
def health(self) -> bool:
"""Check if connected to Milvus."""
return self._connected and utility.get_connection_addr("default") is not None

View File

@@ -1,168 +0,0 @@
"""
Reranker service client (Infinity/BGE Reranker).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import Settings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class RerankerClient:
"""
Client for the reranker service (Infinity with BGE Reranker).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = RerankerClient()
reranked = await client.rerank("query", ["doc1", "doc2"])
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "RerankerDeployment"
RAY_APP_NAME = "reranker"
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._client = httpx.AsyncClient(
base_url=self.settings.reranker_url,
timeout=self.settings.http_timeout,
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def rerank(
self,
query: str,
documents: list[str],
top_k: Optional[int] = None,
) -> list[dict]:
"""
Rerank documents based on relevance to query.
Args:
query: Query text
documents: List of documents to rerank
top_k: Number of top results to return (default: all)
Returns:
List of dicts with 'index', 'score', and 'document' keys,
sorted by relevance score descending.
"""
with create_span("reranker.rerank") as span:
if span:
span.set_attribute("reranker.num_documents", len(documents))
if top_k:
span.set_attribute("reranker.top_k", top_k)
payload = {
"query": query,
"documents": documents,
}
if top_k:
payload["top_n"] = top_k
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("reranker.transport", "ray")
results = await handle.rerank.remote(query, documents, top_k)
# Enrich with original documents
enriched = []
for r in results:
idx = r.get("index", 0)
enriched.append(
{
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
}
)
return enriched
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("reranker.transport", "http")
response = await self._client.post("/rerank", json=payload)
response.raise_for_status()
result = response.json()
results = result.get("results", [])
# Enrich with original documents
enriched = []
for r in results:
idx = r.get("index", 0)
enriched.append(
{
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
}
)
return enriched
async def rerank_with_metadata(
self,
query: str,
documents: list[dict],
text_key: str = "text",
top_k: Optional[int] = None,
) -> list[dict]:
"""
Rerank documents with metadata, preserving metadata in results.
Args:
query: Query text
documents: List of dicts with text and metadata
text_key: Key containing text in each document dict
top_k: Number of top results to return
Returns:
Reranked documents with original metadata preserved.
"""
texts = [d.get(text_key, "") for d in documents]
reranked = await self.rerank(query, texts, top_k)
# Merge back metadata
for r in reranked:
idx = r["index"]
if idx < len(documents):
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
return reranked
async def health(self) -> bool:
"""Check if the reranker service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -1,170 +0,0 @@
"""
STT service client (Whisper/faster-whisper).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import STTSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class STTClient:
"""
Client for the STT service (Whisper/faster-whisper).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = STTClient()
text = await client.transcribe(audio_bytes)
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "WhisperDeployment"
RAY_APP_NAME = "whisper"
def __init__(self, settings: Optional[STTSettings] = None):
self.settings = settings or STTSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.stt_url,
timeout=180.0, # Transcription can be slow
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def transcribe(
self,
audio: bytes,
language: Optional[str] = None,
task: Optional[str] = None,
response_format: str = "json",
) -> dict:
"""
Transcribe audio to text.
Args:
audio: Audio bytes (WAV, MP3, etc.)
language: Language code (None for auto-detect)
task: "transcribe" or "translate"
response_format: "json", "text", "srt", "vtt"
Returns:
Dict with 'text', 'language', and optional 'segments'
"""
language = language or self.settings.stt_language
task = task or self.settings.stt_task
with create_span("stt.transcribe") as span:
if span:
span.set_attribute("stt.task", task)
span.set_attribute("stt.audio_size", len(audio))
if language:
span.set_attribute("stt.language", language)
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("stt.transport", "ray")
result = await handle.transcribe.remote(audio, language, task)
if span:
span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"):
span.set_attribute("stt.detected_language", result["language"])
return result
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("stt.transport", "http")
files = {"file": ("audio.wav", audio, "audio/wav")}
data = {
"response_format": response_format,
}
if language:
data["language"] = language
# Choose endpoint based on task
if task == "translate":
endpoint = "/v1/audio/translations"
else:
endpoint = "/v1/audio/transcriptions"
response = await self._client.post(endpoint, files=files, data=data)
response.raise_for_status()
if response_format == "text":
return {"text": response.text}
result = response.json()
if span:
span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"):
span.set_attribute("stt.detected_language", result["language"])
return result
async def transcribe_file(
self,
file_path: str,
language: Optional[str] = None,
task: Optional[str] = None,
) -> dict:
"""
Transcribe an audio file.
Args:
file_path: Path to audio file
language: Language code
task: "transcribe" or "translate"
Returns:
Transcription result
"""
with open(file_path, "rb") as f:
audio = f.read()
return await self.transcribe(audio, language, task)
async def translate(self, audio: bytes) -> dict:
"""
Translate audio to English.
Args:
audio: Audio bytes
Returns:
Translation result with 'text' key
"""
return await self.transcribe(audio, task="translate")
async def health(self) -> bool:
"""Check if the STT service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -1,149 +0,0 @@
"""
TTS service client (Coqui XTTS).
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
"""
import logging
from typing import Any, Optional
import httpx
from handler_base.config import TTSSettings
from handler_base.ray_utils import get_ray_handle
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class TTSClient:
"""
Client for the TTS service (Coqui XTTS).
When running inside Ray, automatically uses Ray handles for faster
internal communication. Falls back to HTTP for external calls.
Usage:
client = TTSClient()
audio_bytes = await client.synthesize("Hello world")
"""
# Ray Serve deployment configuration
RAY_DEPLOYMENT_NAME = "TTSDeployment"
RAY_APP_NAME = "tts"
def __init__(self, settings: Optional[TTSSettings] = None):
self.settings = settings or TTSSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.tts_url,
timeout=120.0, # TTS can be slow
)
self._ray_handle: Optional[Any] = None
self._ray_checked = False
def _get_ray_handle(self) -> Optional[Any]:
"""Get Ray handle, checking only once."""
if not self._ray_checked:
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
self._ray_checked = True
return self._ray_handle
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def synthesize(
self,
text: str,
language: Optional[str] = None,
speaker: Optional[str] = None,
) -> bytes:
"""
Synthesize speech from text.
Args:
text: Text to synthesize
language: Language code (e.g., "en", "es", "fr")
speaker: Speaker ID or reference
Returns:
WAV audio bytes
"""
language = language or self.settings.tts_language
with create_span("tts.synthesize") as span:
if span:
span.set_attribute("tts.language", language)
span.set_attribute("tts.text_length", len(text))
params = {
"text": text,
"language_id": language,
}
if speaker:
params["speaker_id"] = speaker
# Try Ray handle first (faster internal path)
handle = self._get_ray_handle()
if handle:
try:
if span:
span.set_attribute("tts.transport", "ray")
audio_bytes = await handle.synthesize.remote(text, language, speaker)
if span:
span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes
except Exception as e:
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
# HTTP fallback
if span:
span.set_attribute("tts.transport", "http")
response = await self._client.get("/api/tts", params=params)
response.raise_for_status()
audio_bytes = response.content
if span:
span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes
async def synthesize_to_file(
self,
text: str,
output_path: str,
language: Optional[str] = None,
speaker: Optional[str] = None,
) -> None:
"""
Synthesize speech and save to a file.
Args:
text: Text to synthesize
output_path: Path to save the audio file
language: Language code
speaker: Speaker ID
"""
audio_bytes = await self.synthesize(text, language, speaker)
with open(output_path, "wb") as f:
f.write(audio_bytes)
async def get_speakers(self) -> list[dict]:
"""Get available speakers/voices."""
try:
response = await self._client.get("/api/speakers")
response.raise_for_status()
return response.json()
except Exception:
return []
async def health(self) -> bool:
"""Check if the TTS service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -1,101 +0,0 @@
"""
Configuration management using Pydantic Settings.
Environment variables are automatically loaded and validated.
"""
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Base settings for all handler services."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# Service identification
service_name: str = "handler"
service_version: str = "1.0.0"
service_namespace: str = "ai-ml"
deployment_env: str = "production"
# NATS configuration
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
nats_user: Optional[str] = None
nats_password: Optional[str] = None
nats_queue_group: Optional[str] = None
# Redis/Valkey configuration
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
redis_password: Optional[str] = None
# Milvus configuration
milvus_host: str = "milvus.ai-ml.svc.cluster.local"
milvus_port: int = 19530
milvus_collection: str = "documents"
# Service endpoints
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
# OpenTelemetry configuration
otel_enabled: bool = True
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
otel_use_http: bool = False
# HyperDX configuration
hyperdx_enabled: bool = False
hyperdx_api_key: Optional[str] = None
hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
# MLflow configuration
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
mlflow_experiment_name: Optional[str] = None
mlflow_enabled: bool = True
# Health check configuration
health_port: int = 8080
health_path: str = "/health"
ready_path: str = "/ready"
# Timeouts (seconds)
http_timeout: float = 60.0
nats_timeout: float = 30.0
class EmbeddingsSettings(Settings):
"""Settings for embeddings service client."""
embeddings_model: str = "bge"
embeddings_batch_size: int = 32
class LLMSettings(Settings):
"""Settings for LLM service client."""
llm_model: str = "default"
llm_max_tokens: int = 2048
llm_temperature: float = 0.7
llm_top_p: float = 0.9
class TTSSettings(Settings):
"""Settings for TTS service client."""
tts_language: str = "en"
tts_speaker: Optional[str] = None
class STTSettings(Settings):
"""Settings for STT service client."""
stt_language: Optional[str] = None # Auto-detect
stt_task: str = "transcribe" # or "translate"

View File

@@ -1,222 +0,0 @@
"""
Base handler class for building NATS-based services.
"""
import asyncio
import logging
import signal
from abc import ABC, abstractmethod
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base.config import Settings
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import create_span, setup_telemetry
logger = logging.getLogger(__name__)
class Handler(ABC):
"""
Base class for NATS message handlers.
Subclass and implement:
- setup(): Initialize your service clients
- handle_message(): Process incoming messages
- teardown(): Clean up resources (optional)
Example:
class MyHandler(Handler):
async def setup(self):
self.embeddings = EmbeddingsClient()
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
result = await self.embeddings.embed(data["text"])
return {"embedding": result}
if __name__ == "__main__":
MyHandler(subject="my.subject").run()
"""
def __init__(
self,
subject: str,
settings: Optional[Settings] = None,
queue_group: Optional[str] = None,
):
"""
Initialize the handler.
Args:
subject: NATS subject to subscribe to
settings: Configuration settings
queue_group: Optional queue group for load balancing
"""
self.subject = subject
self.settings = settings or Settings()
self.queue_group = queue_group or self.settings.nats_queue_group
self.nats = NATSClient(self.settings)
self.health_server = HealthServer(self.settings, self._check_ready)
self._running = False
self._shutdown_event = asyncio.Event()
@abstractmethod
async def setup(self) -> None:
"""
Initialize service clients and resources.
Called once before starting to handle messages.
Override this to set up your service-specific clients.
"""
pass
@abstractmethod
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
"""
Handle an incoming message.
Args:
msg: Raw NATS message
data: Decoded message data (msgpack unpacked)
Returns:
Optional response data. If returned and msg has a reply subject,
the response will be sent automatically.
"""
pass
async def teardown(self) -> None:
"""
Clean up resources.
Called during graceful shutdown.
Override to add custom cleanup logic.
"""
pass
async def _check_ready(self) -> bool:
"""Check if the service is ready to handle requests."""
return self._running and self.nats._nc is not None
async def _message_handler(self, msg: Msg) -> None:
"""Internal message handler with tracing and error handling."""
with create_span(f"handle.{self.subject}") as span:
try:
# Decode message
data = NATSClient.decode_msgpack(msg)
if span:
span.set_attribute("messaging.destination", msg.subject)
if isinstance(data, dict):
request_id = data.get("request_id", data.get("id"))
if request_id:
span.set_attribute("request.id", str(request_id))
# Handle message
response = await self.handle_message(msg, data)
# Send response if applicable
if response is not None and msg.reply:
await self.nats.publish(msg.reply, response)
except Exception as e:
logger.exception(f"Error handling message on {msg.subject}")
if span:
span.set_attribute("error", True)
span.set_attribute("error.message", str(e))
# Send error response if reply expected
if msg.reply:
error_response = {
"error": True,
"message": str(e),
"type": type(e).__name__,
}
await self.nats.publish(msg.reply, error_response)
def _setup_signals(self) -> None:
"""Set up signal handlers for graceful shutdown."""
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, self._handle_signal, sig)
def _handle_signal(self, sig: signal.Signals) -> None:
"""Handle shutdown signal."""
logger.info(f"Received {sig.name}, initiating graceful shutdown...")
self._shutdown_event.set()
async def _run(self) -> None:
"""Main async run loop."""
# Setup telemetry
setup_telemetry(self.settings)
# Start health server
self.health_server.start()
try:
# Connect to NATS
await self.nats.connect()
# Run user setup
logger.info("Running service setup...")
await self.setup()
# Subscribe to subject
await self.nats.subscribe(
self.subject,
self._message_handler,
queue=self.queue_group,
)
self._running = True
logger.info(f"Handler ready, listening on {self.subject}")
# Wait for shutdown signal
await self._shutdown_event.wait()
except Exception:
logger.exception("Fatal error in handler")
raise
finally:
self._running = False
# Graceful shutdown
logger.info("Shutting down...")
try:
await self.teardown()
except Exception as e:
logger.warning(f"Error during teardown: {e}")
await self.nats.close()
self.health_server.stop()
logger.info("Shutdown complete")
def run(self) -> None:
"""
Run the handler.
This is the main entry point. It sets up signal handlers
and runs the async event loop.
"""
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
# Run the async loop
asyncio.run(self._run_with_signals())
async def _run_with_signals(self) -> None:
"""Run with signal handling."""
self._setup_signals()
await self._run()

View File

@@ -1,125 +0,0 @@
"""
HTTP health check server.
Provides /health and /ready endpoints for Kubernetes probes.
"""
import asyncio
import json
import logging
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Awaitable, Callable, Optional
from handler_base.config import Settings
logger = logging.getLogger(__name__)
class HealthHandler(BaseHTTPRequestHandler):
"""HTTP request handler for health checks."""
# Class-level state
ready_check: Optional[Callable[[], Awaitable[bool]]] = None
health_path: str = "/health"
ready_path: str = "/ready"
def log_message(self, format, *args):
"""Suppress default logging."""
pass
def do_GET(self):
"""Handle GET requests for health/ready endpoints."""
if self.path == self.health_path:
self._respond_ok({"status": "healthy"})
elif self.path == self.ready_path:
self._handle_ready()
else:
self._respond_not_found()
def _handle_ready(self):
"""Check readiness and respond."""
# Access via class to avoid method binding issues
ready_check = HealthHandler.ready_check
if ready_check is None:
self._respond_ok({"status": "ready"})
return
try:
# Run the async check in a new event loop
loop = asyncio.new_event_loop()
try:
is_ready = loop.run_until_complete(ready_check())
finally:
loop.close()
if is_ready:
self._respond_ok({"status": "ready"})
else:
self._respond_unavailable({"status": "not ready"})
except Exception as e:
logger.exception("Readiness check failed")
self._respond_unavailable({"status": "error", "message": str(e)})
def _respond_ok(self, data: dict):
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
def _respond_unavailable(self, data: dict):
self.send_response(503)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(data).encode())
def _respond_not_found(self):
self.send_response(404)
self.end_headers()
class HealthServer:
"""
Background HTTP server for health checks.
Usage:
server = HealthServer(settings)
server.start()
# ... run your service ...
server.stop()
"""
def __init__(
self,
settings: Optional[Settings] = None,
ready_check: Optional[Callable[[], Awaitable[bool]]] = None,
):
self.settings = settings or Settings()
self.ready_check = ready_check
self._server: Optional[HTTPServer] = None
self._thread: Optional[threading.Thread] = None
def start(self) -> None:
"""Start the health check server in a background thread."""
# Configure handler class
HealthHandler.ready_check = self.ready_check
HealthHandler.health_path = self.settings.health_path
HealthHandler.ready_path = self.settings.ready_path
# Create and start server
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
self._thread.start()
logger.info(
f"Health server started on port {self.settings.health_port} "
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
)
def stop(self) -> None:
"""Stop the health check server."""
if self._server:
self._server.shutdown()
self._server = None
self._thread = None
logger.info("Health server stopped")

View File

@@ -1,188 +0,0 @@
"""
NATS client wrapper with connection management and utilities.
"""
import logging
from typing import Any, Awaitable, Callable, Optional
import msgpack
import nats
from nats.aio.client import Client
from nats.aio.msg import Msg
from nats.js import JetStreamContext
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class NATSClient:
"""
NATS client with automatic connection management.
Supports:
- Core NATS pub/sub
- JetStream for persistence
- Queue groups for load balancing
- Msgpack serialization
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._nc: Optional[Client] = None
self._js: Optional[JetStreamContext] = None
self._subscriptions: list = []
@property
def nc(self) -> Client:
"""Get the NATS client, raising if not connected."""
if self._nc is None:
raise RuntimeError("NATS client not connected. Call connect() first.")
return self._nc
@property
def js(self) -> JetStreamContext:
"""Get JetStream context, raising if not connected."""
if self._js is None:
raise RuntimeError("JetStream not initialized. Call connect() first.")
return self._js
async def connect(self) -> None:
"""Connect to NATS server."""
connect_opts = {
"servers": self.settings.nats_url,
"reconnect_time_wait": 2,
"max_reconnect_attempts": -1, # Infinite
}
if self.settings.nats_user and self.settings.nats_password:
connect_opts["user"] = self.settings.nats_user
connect_opts["password"] = self.settings.nats_password
logger.info(f"Connecting to NATS at {self.settings.nats_url}")
self._nc = await nats.connect(**connect_opts)
self._js = self._nc.jetstream()
logger.info("Connected to NATS")
async def close(self) -> None:
"""Close NATS connection gracefully."""
if self._nc:
# Drain subscriptions first
for sub in self._subscriptions:
try:
await sub.drain()
except Exception as e:
logger.warning(f"Error draining subscription: {e}")
await self._nc.drain()
await self._nc.close()
self._nc = None
self._js = None
logger.info("NATS connection closed")
async def subscribe(
self,
subject: str,
handler: Callable[[Msg], Awaitable[None]],
queue: Optional[str] = None,
):
"""
Subscribe to a subject with a handler function.
Args:
subject: NATS subject to subscribe to
handler: Async function to handle messages
queue: Optional queue group for load balancing
"""
queue = queue or self.settings.nats_queue_group
if queue:
sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
logger.info(f"Subscribed to {subject} (queue: {queue})")
else:
sub = await self.nc.subscribe(subject, cb=handler)
logger.info(f"Subscribed to {subject}")
self._subscriptions.append(sub)
return sub
async def publish(
self,
subject: str,
data: Any,
use_msgpack: bool = True,
) -> None:
"""
Publish a message to a subject.
Args:
subject: NATS subject to publish to
data: Data to publish (will be serialized)
use_msgpack: Whether to use msgpack (True) or JSON (False)
"""
with create_span("nats.publish") as span:
if span:
span.set_attribute("messaging.destination", subject)
if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True)
else:
import json
payload = json.dumps(data).encode()
await self.nc.publish(subject, payload)
async def request(
self,
subject: str,
data: Any,
timeout: Optional[float] = None,
use_msgpack: bool = True,
) -> Any:
"""
Send a request and wait for response.
Args:
subject: NATS subject to send request to
data: Request data
timeout: Response timeout in seconds
use_msgpack: Whether to use msgpack serialization
Returns:
Decoded response data
"""
timeout = timeout or self.settings.nats_timeout
with create_span("nats.request") as span:
if span:
span.set_attribute("messaging.destination", subject)
if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True)
else:
import json
payload = json.dumps(data).encode()
response = await self.nc.request(subject, payload, timeout=timeout)
if use_msgpack:
return msgpack.unpackb(response.data, raw=False)
else:
import json
return json.loads(response.data.decode())
@staticmethod
def decode_msgpack(msg: Msg) -> Any:
"""Decode a msgpack message."""
return msgpack.unpackb(msg.data, raw=False)
@staticmethod
def decode_json(msg: Msg) -> Any:
"""Decode a JSON message."""
import json
return json.loads(msg.data.decode())

View File

View File

@@ -1,70 +0,0 @@
"""
Ray integration utilities for handler-base clients.
When running inside a Ray cluster, clients can use Ray Serve handles
for faster internal communication (gRPC instead of HTTP).
"""
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
# Ray handle cache to avoid repeated lookups
_ray_handles: dict[str, Any] = {}
_ray_available: Optional[bool] = None
def is_ray_available() -> bool:
"""Check if we're running inside a Ray cluster."""
global _ray_available
if _ray_available is not None:
return _ray_available
try:
import ray
_ray_available = ray.is_initialized()
if _ray_available:
logger.info("Ray detected - will use Ray handles for internal calls")
return _ray_available
except ImportError:
_ray_available = False
return False
def get_ray_handle(deployment_name: str, app_name: str) -> Optional[Any]:
"""
Get a Ray Serve deployment handle for internal calls.
Args:
deployment_name: Name of the Ray Serve deployment
app_name: Name of the Ray Serve application
Returns:
DeploymentHandle if available, None otherwise
"""
if not is_ray_available():
return None
cache_key = f"{app_name}/{deployment_name}"
if cache_key in _ray_handles:
return _ray_handles[cache_key]
try:
from ray import serve
handle = serve.get_deployment_handle(deployment_name, app_name=app_name)
_ray_handles[cache_key] = handle
logger.debug(f"Got Ray handle for {cache_key}")
return handle
except Exception as e:
logger.debug(f"Could not get Ray handle for {cache_key}: {e}")
return None
def clear_ray_handles() -> None:
"""Clear cached Ray handles (useful for testing)."""
global _ray_handles, _ray_available
_ray_handles.clear()
_ray_available = None

View File

@@ -1,158 +0,0 @@
"""
OpenTelemetry setup for tracing and metrics.
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
"""
import logging
import os
from typing import Optional, Tuple
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from handler_base.config import Settings
logger = logging.getLogger(__name__)
# Global references
_tracer: Optional[trace.Tracer] = None
_meter: Optional[metrics.Meter] = None
_initialized = False
def setup_telemetry(
settings: Optional[Settings] = None,
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
"""
Initialize OpenTelemetry tracing and metrics.
Args:
settings: Configuration settings. If None, loads from environment.
Returns:
Tuple of (tracer, meter) or (None, None) if disabled.
"""
global _tracer, _meter, _initialized
if _initialized:
return _tracer, _meter
if settings is None:
settings = Settings()
if not settings.otel_enabled:
logger.info("OpenTelemetry disabled")
_initialized = True
return None, None
# Create resource with service information
resource = Resource.create(
{
SERVICE_NAME: settings.service_name,
SERVICE_VERSION: settings.service_version,
SERVICE_NAMESPACE: settings.service_namespace,
"deployment.environment": settings.deployment_env,
"host.name": os.environ.get("HOSTNAME", "unknown"),
}
)
# Determine endpoint and exporter type
if settings.hyperdx_enabled and settings.hyperdx_api_key:
# HyperDX uses HTTP with API key header
endpoint = settings.hyperdx_endpoint
headers = {"authorization": settings.hyperdx_api_key}
use_http = True
logger.info(f"Using HyperDX endpoint: {endpoint}")
else:
endpoint = settings.otel_endpoint
headers = None
use_http = settings.otel_use_http
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
# Setup tracing
if use_http:
trace_exporter = OTLPSpanExporterHTTP(
endpoint=f"{endpoint}/v1/traces",
headers=headers,
)
else:
trace_exporter = OTLPSpanExporter(
endpoint=endpoint,
)
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
trace.set_tracer_provider(tracer_provider)
# Setup metrics
if use_http:
metric_exporter = OTLPMetricExporterHTTP(
endpoint=f"{endpoint}/v1/metrics",
headers=headers,
)
else:
metric_exporter = OTLPMetricExporter(
endpoint=endpoint,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter,
export_interval_millis=60000,
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
# Instrument libraries
HTTPXClientInstrumentor().instrument()
LoggingInstrumentor().instrument(set_logging_format=True)
# Create tracer and meter for this service
_tracer = trace.get_tracer(settings.service_name, settings.service_version)
_meter = metrics.get_meter(settings.service_name, settings.service_version)
logger.info(f"OpenTelemetry initialized for {settings.service_name}")
_initialized = True
return _tracer, _meter
def get_tracer() -> Optional[trace.Tracer]:
"""Get the global tracer instance."""
return _tracer
def get_meter() -> Optional[metrics.Meter]:
"""Get the global meter instance."""
return _meter
def create_span(name: str, **kwargs):
"""
Create a new span.
Usage:
with create_span("my_operation") as span:
span.set_attribute("key", "value")
# do work
"""
if _tracer is None:
# Return a no-op context manager
from contextlib import nullcontext
return nullcontext()
return _tracer.start_as_current_span(name, **kwargs)

84
health/health.go Normal file
View File

@@ -0,0 +1,84 @@
// Package health provides an HTTP server for Kubernetes liveness and readiness probes.
package health
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"time"
)
// ReadyFunc is called to determine if the service is ready. Return true if ready.
type ReadyFunc func() bool
// Server serves /health and /ready endpoints.
type Server struct {
port int
healthPath string
readyPath string
readyCheck ReadyFunc
srv *http.Server
}
// New creates a health server on the given port.
func New(port int, healthPath, readyPath string, readyCheck ReadyFunc) *Server {
s := &Server{
port: port,
healthPath: healthPath,
readyPath: readyPath,
readyCheck: readyCheck,
}
mux := http.NewServeMux()
mux.HandleFunc(healthPath, s.handleHealth)
mux.HandleFunc(readyPath, s.handleReady)
s.srv = &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
return s
}
// Start begins serving in the background. Call Stop to shut down.
func (s *Server) Start() {
ln, err := net.Listen("tcp", s.srv.Addr)
if err != nil {
slog.Error("health server listen failed", "error", err)
return
}
slog.Info("health server started", "port", s.port, "health", s.healthPath, "ready", s.readyPath)
go func() {
if err := s.srv.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("health server error", "error", err)
}
}()
}
// Stop gracefully shuts down the server.
func (s *Server) Stop(ctx context.Context) {
if s.srv != nil {
_ = s.srv.Shutdown(ctx)
slog.Info("health server stopped")
}
}
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "healthy"})
}
func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request) {
if s.readyCheck != nil && !s.readyCheck() {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"status": "not ready"})
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ready"})
}
func writeJSON(w http.ResponseWriter, status int, data any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(data)
}

77
health/health_test.go Normal file
View File

@@ -0,0 +1,77 @@
package health
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
)
func TestHealthEndpoint(t *testing.T) {
srv := New(18080, "/health", "/ready", nil)
srv.Start()
defer srv.Stop(context.Background())
time.Sleep(50 * time.Millisecond)
resp, err := http.Get("http://localhost:18080/health")
if err != nil {
t.Fatalf("health request failed: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
var data map[string]string
_ = json.Unmarshal(body, &data)
if data["status"] != "healthy" {
t.Errorf("expected status 'healthy', got %q", data["status"])
}
}
func TestReadyEndpointDefault(t *testing.T) {
srv := New(18081, "/health", "/ready", nil)
srv.Start()
defer srv.Stop(context.Background())
time.Sleep(50 * time.Millisecond)
resp, err := http.Get("http://localhost:18081/ready")
if err != nil {
t.Fatalf("ready request failed: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 {
t.Errorf("expected 200, got %d", resp.StatusCode)
}
}
func TestReadyEndpointNotReady(t *testing.T) {
ready := false
srv := New(18082, "/health", "/ready", func() bool { return ready })
srv.Start()
defer srv.Stop(context.Background())
time.Sleep(50 * time.Millisecond)
resp, err := http.Get("http://localhost:18082/ready")
if err != nil {
t.Fatalf("ready request failed: %v", err)
}
_ = resp.Body.Close()
if resp.StatusCode != 503 {
t.Errorf("expected 503 when not ready, got %d", resp.StatusCode)
}
ready = true
resp2, err := http.Get("http://localhost:18082/ready")
if err != nil {
t.Fatalf("ready request failed: %v", err)
}
_ = resp2.Body.Close()
if resp2.StatusCode != 200 {
t.Errorf("expected 200 when ready, got %d", resp2.StatusCode)
}
}

515
messages/bench_test.go Normal file
View File

@@ -0,0 +1,515 @@
// Package messages benchmarks compare three serialization strategies:
//
// 1. msgpack map[string]any — the old approach (dynamic, no types)
// 2. msgpack typed struct — the new approach (compile-time safe, short keys)
// 3. protobuf — optional future migration
//
// Run with:
//
// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
package messages
import (
"testing"
"time"
"github.com/vmihailenco/msgpack/v5"
"google.golang.org/protobuf/proto"
pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto"
)
// ────────────────────────────────────────────────────────────────────────────
// Test fixtures — equivalent data across all three encodings
// ────────────────────────────────────────────────────────────────────────────
// chatRequestMap is the legacy map[string]any representation.
func chatRequestMap() map[string]any {
return map[string]any{
"request_id": "req-abc-123",
"user_id": "user-42",
"message": "What is the capital of France?",
"query": "",
"premium": true,
"enable_rag": true,
"enable_reranker": true,
"enable_streaming": false,
"top_k": 10,
"collection": "documents",
"enable_tts": false,
"system_prompt": "You are a helpful assistant.",
"response_subject": "ai.chat.response.req-abc-123",
}
}
// chatRequestStruct is the typed struct representation.
func chatRequestStruct() ChatRequest {
return ChatRequest{
RequestID: "req-abc-123",
UserID: "user-42",
Message: "What is the capital of France?",
Premium: true,
EnableRAG: true,
EnableReranker: true,
TopK: 10,
Collection: "documents",
SystemPrompt: "You are a helpful assistant.",
ResponseSubject: "ai.chat.response.req-abc-123",
}
}
// chatRequestProto is the protobuf representation.
func chatRequestProto() *pb.ChatRequest {
return &pb.ChatRequest{
RequestId: "req-abc-123",
UserId: "user-42",
Message: "What is the capital of France?",
Premium: true,
EnableRag: true,
EnableReranker: true,
TopK: 10,
Collection: "documents",
SystemPrompt: "You are a helpful assistant.",
ResponseSubject: "ai.chat.response.req-abc-123",
}
}
// voiceResponseMap is a voice response with a 16 KB audio payload.
func voiceResponseMap() map[string]any {
return map[string]any{
"request_id": "vr-001",
"response": "The capital of France is Paris.",
"audio": make([]byte, 16384),
"transcription": "What is the capital of France?",
}
}
func voiceResponseStruct() VoiceResponse {
return VoiceResponse{
RequestID: "vr-001",
Response: "The capital of France is Paris.",
Audio: make([]byte, 16384),
Transcription: "What is the capital of France?",
}
}
func voiceResponseProto() *pb.VoiceResponse {
return &pb.VoiceResponse{
RequestId: "vr-001",
Response: "The capital of France is Paris.",
Audio: make([]byte, 16384),
Transcription: "What is the capital of France?",
}
}
// ttsChunkMap simulates a streaming audio chunk (~32 KB).
func ttsChunkMap() map[string]any {
return map[string]any{
"session_id": "tts-sess-99",
"chunk_index": 3,
"total_chunks": 12,
"audio_b64": string(make([]byte, 32768)), // old: base64 string
"is_last": false,
"timestamp": time.Now().Unix(),
"sample_rate": 24000,
}
}
func ttsChunkStruct() TTSAudioChunk {
return TTSAudioChunk{
SessionID: "tts-sess-99",
ChunkIndex: 3,
TotalChunks: 12,
Audio: make([]byte, 32768), // new: raw bytes
IsLast: false,
Timestamp: time.Now().Unix(),
SampleRate: 24000,
}
}
func ttsChunkProto() *pb.TTSAudioChunk {
return &pb.TTSAudioChunk{
SessionId: "tts-sess-99",
ChunkIndex: 3,
TotalChunks: 12,
Audio: make([]byte, 32768),
IsLast: false,
Timestamp: time.Now().Unix(),
SampleRate: 24000,
}
}
// ────────────────────────────────────────────────────────────────────────────
// Wire-size comparison (run once, printed by TestWireSize)
// ────────────────────────────────────────────────────────────────────────────
func TestWireSize(t *testing.T) {
tests := []struct {
name string
mapData any
structVal any
protoMsg proto.Message
}{
{"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()},
{"VoiceResponse", voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()},
{"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()},
}
for _, tt := range tests {
mapBytes, _ := msgpack.Marshal(tt.mapData)
structBytes, _ := msgpack.Marshal(tt.structVal)
protoBytes, _ := proto.Marshal(tt.protoMsg)
t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)",
tt.name,
len(mapBytes), len(structBytes), len(protoBytes),
100*(1-float64(len(structBytes))/float64(len(mapBytes))),
100*(1-float64(len(protoBytes))/float64(len(mapBytes))),
)
}
}
// ────────────────────────────────────────────────────────────────────────────
// Encode benchmarks
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) {
data := chatRequestMap()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) {
data := chatRequestStruct()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
data := chatRequestProto()
b.ResetTimer()
for b.Loop() {
_, _ = proto.Marshal(data)
}
}
func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) {
data := voiceResponseMap()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) {
data := voiceResponseStruct()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
data := voiceResponseProto()
b.ResetTimer()
for b.Loop() {
_, _ = proto.Marshal(data)
}
}
func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) {
data := ttsChunkMap()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) {
data := ttsChunkStruct()
b.ResetTimer()
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
data := ttsChunkProto()
b.ResetTimer()
for b.Loop() {
_, _ = proto.Marshal(data)
}
}
// ────────────────────────────────────────────────────────────────────────────
// Decode benchmarks
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) {
encoded, _ := msgpack.Marshal(chatRequestMap())
b.ResetTimer()
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) {
encoded, _ := msgpack.Marshal(chatRequestStruct())
b.ResetTimer()
for b.Loop() {
var m ChatRequest
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
encoded, _ := proto.Marshal(chatRequestProto())
b.ResetTimer()
for b.Loop() {
var m pb.ChatRequest
_ = proto.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) {
encoded, _ := msgpack.Marshal(voiceResponseMap())
b.ResetTimer()
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) {
encoded, _ := msgpack.Marshal(voiceResponseStruct())
b.ResetTimer()
for b.Loop() {
var m VoiceResponse
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
encoded, _ := proto.Marshal(voiceResponseProto())
b.ResetTimer()
for b.Loop() {
var m pb.VoiceResponse
_ = proto.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) {
encoded, _ := msgpack.Marshal(ttsChunkMap())
b.ResetTimer()
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) {
encoded, _ := msgpack.Marshal(ttsChunkStruct())
b.ResetTimer()
for b.Loop() {
var m TTSAudioChunk
_ = msgpack.Unmarshal(encoded, &m)
}
}
func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
encoded, _ := proto.Marshal(ttsChunkProto())
b.ResetTimer()
for b.Loop() {
var m pb.TTSAudioChunk
_ = proto.Unmarshal(encoded, &m)
}
}
// ────────────────────────────────────────────────────────────────────────────
// Roundtrip benchmarks (encode + decode)
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) {
data := chatRequestMap()
b.ResetTimer()
for b.Loop() {
enc, _ := msgpack.Marshal(data)
var dec map[string]any
_ = msgpack.Unmarshal(enc, &dec)
}
}
func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) {
data := chatRequestStruct()
b.ResetTimer()
for b.Loop() {
enc, _ := msgpack.Marshal(data)
var dec ChatRequest
_ = msgpack.Unmarshal(enc, &dec)
}
}
func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
data := chatRequestProto()
b.ResetTimer()
for b.Loop() {
enc, _ := proto.Marshal(data)
var dec pb.ChatRequest
_ = proto.Unmarshal(enc, &dec)
}
}
// ────────────────────────────────────────────────────────────────────────────
// Typed struct unit tests — verify roundtrip correctness
// ────────────────────────────────────────────────────────────────────────────
func TestRoundtrip_ChatRequest(t *testing.T) {
orig := chatRequestStruct()
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var dec ChatRequest
if err := msgpack.Unmarshal(data, &dec); err != nil {
t.Fatal(err)
}
if dec.RequestID != orig.RequestID {
t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID)
}
if dec.Message != orig.Message {
t.Errorf("Message = %q, want %q", dec.Message, orig.Message)
}
if dec.TopK != orig.TopK {
t.Errorf("TopK = %d, want %d", dec.TopK, orig.TopK)
}
if dec.Premium != orig.Premium {
t.Errorf("Premium = %v, want %v", dec.Premium, orig.Premium)
}
if dec.EffectiveQuery() != orig.Message {
t.Errorf("EffectiveQuery() = %q, want %q", dec.EffectiveQuery(), orig.Message)
}
}
func TestRoundtrip_VoiceResponse(t *testing.T) {
orig := voiceResponseStruct()
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var dec VoiceResponse
if err := msgpack.Unmarshal(data, &dec); err != nil {
t.Fatal(err)
}
if dec.RequestID != orig.RequestID {
t.Errorf("RequestID mismatch")
}
if len(dec.Audio) != len(orig.Audio) {
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
}
if dec.Transcription != orig.Transcription {
t.Errorf("Transcription mismatch")
}
}
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
orig := ttsChunkStruct()
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var dec TTSAudioChunk
if err := msgpack.Unmarshal(data, &dec); err != nil {
t.Fatal(err)
}
if dec.SessionID != orig.SessionID {
t.Errorf("SessionID mismatch")
}
if dec.ChunkIndex != orig.ChunkIndex {
t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex)
}
if len(dec.Audio) != len(orig.Audio) {
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
}
if dec.SampleRate != orig.SampleRate {
t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate)
}
}
func TestRoundtrip_PipelineTrigger(t *testing.T) {
orig := PipelineTrigger{
RequestID: "pip-001",
Pipeline: "document-ingestion",
Parameters: map[string]any{"source": "s3://bucket/data"},
}
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var dec PipelineTrigger
if err := msgpack.Unmarshal(data, &dec); err != nil {
t.Fatal(err)
}
if dec.Pipeline != orig.Pipeline {
t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline)
}
if dec.Parameters["source"] != orig.Parameters["source"] {
t.Errorf("Parameters[source] mismatch")
}
}
func TestRoundtrip_STTTranscription(t *testing.T) {
orig := STTTranscription{
SessionID: "stt-001",
Transcript: "hello world",
Sequence: 5,
IsPartial: false,
IsFinal: true,
Timestamp: time.Now().Unix(),
SpeakerID: "speaker-1",
HasVoiceActivity: true,
State: "listening",
}
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var dec STTTranscription
if err := msgpack.Unmarshal(data, &dec); err != nil {
t.Fatal(err)
}
if dec.Transcript != orig.Transcript {
t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript)
}
if dec.IsFinal != orig.IsFinal {
t.Error("IsFinal mismatch")
}
}
func TestRoundtrip_ErrorResponse(t *testing.T) {
orig := ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var dec ErrorResponse
if err := msgpack.Unmarshal(data, &dec); err != nil {
t.Fatal(err)
}
if !dec.Error || dec.Message != "something broke" || dec.Type != "InternalError" {
t.Errorf("ErrorResponse roundtrip mismatch: %+v", dec)
}
}
func TestTimestamp(t *testing.T) {
ts := Timestamp()
now := time.Now().Unix()
if ts < now-1 || ts > now+1 {
t.Errorf("Timestamp() = %d, expected ~%d", ts, now)
}
}

224
messages/messages.go Normal file
View File

@@ -0,0 +1,224 @@
// Package messages defines typed NATS message structs for all services.
//
// Using typed structs with short msgpack field tags instead of map[string]any
// provides compile-time safety, smaller wire size (integer-like short keys vs
// full string keys), and faster encode/decode by avoiding interface{} boxing.
//
// Audio data uses raw []byte instead of base64-encoded strings — msgpack
// supports binary natively, eliminating the 33% base64 overhead.
package messages
import "time"
// ────────────────────────────────────────────────────────────────────────────
// Pipeline Bridge
// ────────────────────────────────────────────────────────────────────────────
// PipelineTrigger is the request to start a pipeline.
type PipelineTrigger struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Pipeline string `msgpack:"pipeline" json:"pipeline"`
Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"`
}
// PipelineStatus is the response / status update for a pipeline run.
type PipelineStatus struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Status string `msgpack:"status" json:"status"`
RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"`
Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"`
Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"`
SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"`
}
// ────────────────────────────────────────────────────────────────────────────
// Chat Handler
// ────────────────────────────────────────────────────────────────────────────
// ChatRequest is an incoming chat message.
type ChatRequest struct {
RequestID string `msgpack:"request_id" json:"request_id"`
UserID string `msgpack:"user_id" json:"user_id"`
Message string `msgpack:"message" json:"message"`
Query string `msgpack:"query,omitempty" json:"query,omitempty"`
Premium bool `msgpack:"premium,omitempty" json:"premium,omitempty"`
EnableRAG bool `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"`
EnableReranker bool `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"`
EnableStreaming bool `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"`
TopK int `msgpack:"top_k,omitempty" json:"top_k,omitempty"`
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
EnableTTS bool `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"`
SystemPrompt string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"`
}
// EffectiveQuery returns Message or falls back to Query.
func (c *ChatRequest) EffectiveQuery() string {
if c.Message != "" {
return c.Message
}
return c.Query
}
// ChatResponse is the full reply to a chat request.
type ChatResponse struct {
UserID string `msgpack:"user_id" json:"user_id"`
Response string `msgpack:"response" json:"response"`
ResponseText string `msgpack:"response_text" json:"response_text"`
UsedRAG bool `msgpack:"used_rag" json:"used_rag"`
RAGSources []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
Success bool `msgpack:"success" json:"success"`
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
}
// ChatStreamChunk is a single streaming chunk from an LLM response.
type ChatStreamChunk struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Type string `msgpack:"type" json:"type"`
Content string `msgpack:"content" json:"content"`
Done bool `msgpack:"done" json:"done"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// ────────────────────────────────────────────────────────────────────────────
// Voice Assistant
// ────────────────────────────────────────────────────────────────────────────
// VoiceRequest is an incoming voice-to-voice request.
type VoiceRequest struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Audio []byte `msgpack:"audio" json:"audio"`
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
}
// VoiceResponse is the reply to a voice request.
type VoiceResponse struct {
RequestID string `msgpack:"request_id" json:"request_id"`
Response string `msgpack:"response" json:"response"`
Audio []byte `msgpack:"audio" json:"audio"`
Transcription string `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
Sources []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
}
// DocumentSource is a RAG search result source.
type DocumentSource struct {
Text string `msgpack:"text" json:"text"`
Score float64 `msgpack:"score" json:"score"`
}
// ────────────────────────────────────────────────────────────────────────────
// TTS Module
// ────────────────────────────────────────────────────────────────────────────
// TTSRequest is a text-to-speech synthesis request.
type TTSRequest struct {
Text string `msgpack:"text" json:"text"`
Speaker string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
Stream bool `msgpack:"stream,omitempty" json:"stream,omitempty"`
}
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
type TTSAudioChunk struct {
SessionID string `msgpack:"session_id" json:"session_id"`
ChunkIndex int `msgpack:"chunk_index" json:"chunk_index"`
TotalChunks int `msgpack:"total_chunks" json:"total_chunks"`
Audio []byte `msgpack:"audio" json:"audio"`
IsLast bool `msgpack:"is_last" json:"is_last"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
}
// TTSFullResponse is a non-streamed TTS response (whole audio).
type TTSFullResponse struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Audio []byte `msgpack:"audio" json:"audio"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
}
// TTSStatus is a TTS processing status update.
type TTSStatus struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Status string `msgpack:"status" json:"status"`
Message string `msgpack:"message" json:"message"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// TTSVoiceListResponse is the reply to a voice list request.
type TTSVoiceListResponse struct {
DefaultSpeaker string `msgpack:"default_speaker" json:"default_speaker"`
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
LastRefresh int64 `msgpack:"last_refresh" json:"last_refresh"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// TTSVoiceInfo is summary info about a custom voice.
type TTSVoiceInfo struct {
Name string `msgpack:"name" json:"name"`
Language string `msgpack:"language" json:"language"`
ModelType string `msgpack:"model_type" json:"model_type"`
CreatedAt string `msgpack:"created_at" json:"created_at"`
}
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
type TTSVoiceRefreshResponse struct {
Count int `msgpack:"count" json:"count"`
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
}
// ────────────────────────────────────────────────────────────────────────────
// STT Module
// ────────────────────────────────────────────────────────────────────────────
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
type STTStreamMessage struct {
Type string `msgpack:"type" json:"type"`
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
State string `msgpack:"state,omitempty" json:"state,omitempty"`
SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
}
// STTTranscription is the transcription result published by the STT module.
type STTTranscription struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Transcript string `msgpack:"transcript" json:"transcript"`
Sequence int `msgpack:"sequence" json:"sequence"`
IsPartial bool `msgpack:"is_partial" json:"is_partial"`
IsFinal bool `msgpack:"is_final" json:"is_final"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
HasVoiceActivity bool `msgpack:"has_voice_activity" json:"has_voice_activity"`
State string `msgpack:"state" json:"state"`
}
// STTInterrupt is published when the STT module detects a user interrupt.
type STTInterrupt struct {
SessionID string `msgpack:"session_id" json:"session_id"`
Type string `msgpack:"type" json:"type"`
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
}
// ────────────────────────────────────────────────────────────────────────────
// Common / Error
// ────────────────────────────────────────────────────────────────────────────
// ErrorResponse is the standard error reply from any handler.
type ErrorResponse struct {
Error bool `msgpack:"error" json:"error"`
Message string `msgpack:"message" json:"message"`
Type string `msgpack:"type" json:"type"`
}
// Timestamp returns the current Unix timestamp (helper for message construction).
func Timestamp() int64 {
return time.Now().Unix()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,174 @@
syntax = "proto3";
package messages;
option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto";
// ── Pipeline Bridge ─────────────────────────────────────────────────────────
message PipelineTrigger {
string request_id = 1;
string pipeline = 2;
map<string, string> parameters = 3;
}
message PipelineStatus {
string request_id = 1;
string status = 2;
string run_id = 3;
string engine = 4;
string pipeline = 5;
string submitted_at = 6;
string error = 7;
repeated string available_pipelines = 8;
}
// ── Chat Handler ────────────────────────────────────────────────────────────
message ChatRequest {
string request_id = 1;
string user_id = 2;
string message = 3;
string query = 4;
bool premium = 5;
bool enable_rag = 6;
bool enable_reranker = 7;
bool enable_streaming = 8;
int32 top_k = 9;
string collection = 10;
bool enable_tts = 11;
string system_prompt = 12;
string response_subject = 13;
}
message ChatResponse {
string user_id = 1;
string response = 2;
string response_text = 3;
bool used_rag = 4;
repeated string rag_sources = 5;
bool success = 6;
bytes audio = 7;
string error = 8;
}
message ChatStreamChunk {
string request_id = 1;
string type = 2;
string content = 3;
bool done = 4;
int64 timestamp = 5;
}
// ── Voice Assistant ─────────────────────────────────────────────────────────
message VoiceRequest {
string request_id = 1;
bytes audio = 2;
string language = 3;
string collection = 4;
}
message VoiceResponse {
string request_id = 1;
string response = 2;
bytes audio = 3;
string transcription = 4;
repeated DocumentSource sources = 5;
string error = 6;
}
message DocumentSource {
string text = 1;
double score = 2;
}
// ── TTS Module ──────────────────────────────────────────────────────────────
message TTSRequest {
string text = 1;
string speaker = 2;
string language = 3;
string speaker_wav_b64 = 4;
bool stream = 5;
}
message TTSAudioChunk {
string session_id = 1;
int32 chunk_index = 2;
int32 total_chunks = 3;
bytes audio = 4;
bool is_last = 5;
int64 timestamp = 6;
int32 sample_rate = 7;
}
message TTSFullResponse {
string session_id = 1;
bytes audio = 2;
int64 timestamp = 3;
int32 sample_rate = 4;
}
message TTSStatus {
string session_id = 1;
string status = 2;
string message = 3;
int64 timestamp = 4;
}
message TTSVoiceInfo {
string name = 1;
string language = 2;
string model_type = 3;
string created_at = 4;
}
message TTSVoiceListResponse {
string default_speaker = 1;
repeated TTSVoiceInfo custom_voices = 2;
int64 last_refresh = 3;
int64 timestamp = 4;
}
message TTSVoiceRefreshResponse {
int32 count = 1;
repeated TTSVoiceInfo custom_voices = 2;
int64 timestamp = 3;
}
// ── STT Module ──────────────────────────────────────────────────────────────
message STTStreamMessage {
string type = 1;
bytes audio = 2;
string state = 3;
string speaker_id = 4;
}
message STTTranscription {
string session_id = 1;
string transcript = 2;
int32 sequence = 3;
bool is_partial = 4;
bool is_final = 5;
int64 timestamp = 6;
string speaker_id = 7;
bool has_voice_activity = 8;
string state = 9;
}
message STTInterrupt {
string session_id = 1;
string type = 2;
int64 timestamp = 3;
string speaker_id = 4;
}
// ── Common ──────────────────────────────────────────────────────────────────
message ErrorResponse {
bool error = 1;
string message = 2;
string type = 3;
}

142
natsutil/natsutil.go Normal file
View File

@@ -0,0 +1,142 @@
// Package natsutil provides a NATS/JetStream client with msgpack serialization.
package natsutil
import (
"fmt"
"log/slog"
"time"
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
)
// Client wraps a NATS connection with msgpack helpers.
type Client struct {
nc *nats.Conn
js nats.JetStreamContext
subs []*nats.Subscription
url string
opts []nats.Option
}
// New creates a NATS client configured to connect to the given URL.
// Optional NATS options (e.g. credentials) can be appended.
func New(url string, opts ...nats.Option) *Client {
defaults := []nats.Option{
nats.ReconnectWait(2 * time.Second),
nats.MaxReconnects(-1),
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
slog.Warn("NATS disconnected", "error", err)
}),
nats.ReconnectHandler(func(_ *nats.Conn) {
slog.Info("NATS reconnected")
}),
}
return &Client{
url: url,
opts: append(defaults, opts...),
}
}
// Connect establishes the NATS connection and JetStream context.
func (c *Client) Connect() error {
nc, err := nats.Connect(c.url, c.opts...)
if err != nil {
return fmt.Errorf("nats connect: %w", err)
}
js, err := nc.JetStream()
if err != nil {
nc.Close()
return fmt.Errorf("jetstream: %w", err)
}
c.nc = nc
c.js = js
slog.Info("connected to NATS", "url", c.url)
return nil
}
// Close drains subscriptions and closes the connection.
func (c *Client) Close() {
if c.nc == nil {
return
}
for _, sub := range c.subs {
_ = sub.Drain()
}
c.nc.Close()
slog.Info("NATS connection closed")
}
// Conn returns the underlying *nats.Conn.
func (c *Client) Conn() *nats.Conn { return c.nc }
// JS returns the JetStream context.
func (c *Client) JS() nats.JetStreamContext { return c.js }
// IsConnected returns true if the NATS connection is active.
func (c *Client) IsConnected() bool {
return c.nc != nil && c.nc.IsConnected()
}
// Subscribe subscribes to a subject with an optional queue group.
// The handler receives the raw *nats.Msg.
func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error {
var sub *nats.Subscription
var err error
if queue != "" {
sub, err = c.nc.QueueSubscribe(subject, queue, handler)
slog.Info("subscribed", "subject", subject, "queue", queue)
} else {
sub, err = c.nc.Subscribe(subject, handler)
slog.Info("subscribed", "subject", subject)
}
if err != nil {
return fmt.Errorf("subscribe %s: %w", subject, err)
}
c.subs = append(c.subs, sub)
return nil
}
// Publish encodes data as msgpack and publishes to the subject.
func (c *Client) Publish(subject string, data any) error {
payload, err := msgpack.Marshal(data)
if err != nil {
return fmt.Errorf("msgpack marshal: %w", err)
}
return c.nc.Publish(subject, payload)
}
// Request sends a msgpack-encoded request and decodes the response into result.
func (c *Client) Request(subject string, data any, result any, timeout time.Duration) error {
payload, err := msgpack.Marshal(data)
if err != nil {
return fmt.Errorf("msgpack marshal: %w", err)
}
msg, err := c.nc.Request(subject, payload, timeout)
if err != nil {
return fmt.Errorf("nats request: %w", err)
}
return msgpack.Unmarshal(msg.Data, result)
}
// DecodeMsgpack decodes msgpack-encoded NATS message data into dest.
func DecodeMsgpack(msg *nats.Msg, dest any) error {
return msgpack.Unmarshal(msg.Data, dest)
}
// Decode is a generic helper that unmarshals msgpack bytes into T.
// Usage: req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
func Decode[T any](data []byte) (T, error) {
var v T
err := msgpack.Unmarshal(data, &v)
return v, err
}
// DecodeMsgpackMap decodes msgpack data into a generic map.
func DecodeMsgpackMap(data []byte) (map[string]any, error) {
var m map[string]any
if err := msgpack.Unmarshal(data, &m); err != nil {
return nil, err
}
return m, nil
}

256
natsutil/natsutil_test.go Normal file
View File

@@ -0,0 +1,256 @@
package natsutil
import (
"testing"
"github.com/vmihailenco/msgpack/v5"
)
// ────────────────────────────────────────────────────────────────────────────
// DecodeMsgpackMap tests
// ────────────────────────────────────────────────────────────────────────────
func TestDecodeMsgpackMap_Roundtrip(t *testing.T) {
orig := map[string]any{
"request_id": "req-001",
"user_id": "user-42",
"premium": true,
"top_k": int64(10), // msgpack decodes ints as int64
}
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
decoded, err := DecodeMsgpackMap(data)
if err != nil {
t.Fatal(err)
}
if decoded["request_id"] != "req-001" {
t.Errorf("request_id = %v", decoded["request_id"])
}
if decoded["premium"] != true {
t.Errorf("premium = %v", decoded["premium"])
}
}
func TestDecodeMsgpackMap_Empty(t *testing.T) {
data, _ := msgpack.Marshal(map[string]any{})
m, err := DecodeMsgpackMap(data)
if err != nil {
t.Fatal(err)
}
if len(m) != 0 {
t.Errorf("expected empty map, got %v", m)
}
}
func TestDecodeMsgpackMap_InvalidData(t *testing.T) {
_, err := DecodeMsgpackMap([]byte{0xFF, 0xFE})
if err == nil {
t.Error("expected error for invalid msgpack data")
}
}
// ────────────────────────────────────────────────────────────────────────────
// DecodeMsgpack (typed struct) tests
// ────────────────────────────────────────────────────────────────────────────
type testMessage struct {
RequestID string `msgpack:"request_id"`
UserID string `msgpack:"user_id"`
Count int `msgpack:"count"`
Active bool `msgpack:"active"`
}
func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
orig := testMessage{
RequestID: "req-typed-001",
UserID: "user-7",
Count: 42,
Active: true,
}
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
// Simulate nats.Msg data decoding.
var decoded testMessage
if err := msgpack.Unmarshal(data, &decoded); err != nil {
t.Fatal(err)
}
if decoded.RequestID != orig.RequestID {
t.Errorf("RequestID = %q, want %q", decoded.RequestID, orig.RequestID)
}
if decoded.Count != orig.Count {
t.Errorf("Count = %d, want %d", decoded.Count, orig.Count)
}
if decoded.Active != orig.Active {
t.Errorf("Active = %v, want %v", decoded.Active, orig.Active)
}
}
// TestTypedStructDecodesMapEncoding verifies that a typed struct can be
// decoded from data that was encoded as map[string]any (backwards compat).
func TestTypedStructDecodesMapEncoding(t *testing.T) {
// Encode as map (the old way).
mapData := map[string]any{
"request_id": "req-compat",
"user_id": "user-compat",
"count": int64(99),
"active": false,
}
data, err := msgpack.Marshal(mapData)
if err != nil {
t.Fatal(err)
}
// Decode into typed struct (the new way).
var msg testMessage
if err := msgpack.Unmarshal(data, &msg); err != nil {
t.Fatal(err)
}
if msg.RequestID != "req-compat" {
t.Errorf("RequestID = %q", msg.RequestID)
}
if msg.Count != 99 {
t.Errorf("Count = %d, want 99", msg.Count)
}
}
// ────────────────────────────────────────────────────────────────────────────
// Binary data tests (audio []byte in msgpack)
// ────────────────────────────────────────────────────────────────────────────
type audioMessage struct {
SessionID string `msgpack:"session_id"`
Audio []byte `msgpack:"audio"`
SampleRate int `msgpack:"sample_rate"`
}
func TestBinaryDataRoundtrip(t *testing.T) {
audio := make([]byte, 32768)
for i := range audio {
audio[i] = byte(i % 256)
}
orig := audioMessage{
SessionID: "sess-audio-001",
Audio: audio,
SampleRate: 24000,
}
data, err := msgpack.Marshal(orig)
if err != nil {
t.Fatal(err)
}
var decoded audioMessage
if err := msgpack.Unmarshal(data, &decoded); err != nil {
t.Fatal(err)
}
if len(decoded.Audio) != len(orig.Audio) {
t.Fatalf("audio len = %d, want %d", len(decoded.Audio), len(orig.Audio))
}
for i := range decoded.Audio {
if decoded.Audio[i] != orig.Audio[i] {
t.Fatalf("audio[%d] = %d, want %d", i, decoded.Audio[i], orig.Audio[i])
}
}
}
// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string.
func TestBinaryVsBase64Size(t *testing.T) {
audio := make([]byte, 16384)
// Old approach: base64 string in map.
import_b64 := make([]byte, (len(audio)*4+2)/3) // approximate base64 size
mapMsg := map[string]any{
"session_id": "sess-1",
"audio_b64": string(import_b64),
}
mapData, _ := msgpack.Marshal(mapMsg)
// New approach: raw bytes in struct.
structMsg := audioMessage{
SessionID: "sess-1",
Audio: audio,
}
structData, _ := msgpack.Marshal(structMsg)
t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)",
len(mapData), len(structData),
100*(1-float64(len(structData))/float64(len(mapData))))
}
// ────────────────────────────────────────────────────────────────────────────
// Benchmarks
// ────────────────────────────────────────────────────────────────────────────
func BenchmarkEncodeMap(b *testing.B) {
data := map[string]any{
"request_id": "req-bench",
"user_id": "user-bench",
"message": "What is the weather today?",
"premium": true,
"top_k": 10,
}
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkEncodeStruct(b *testing.B) {
data := testMessage{
RequestID: "req-bench",
UserID: "user-bench",
Count: 10,
Active: true,
}
for b.Loop() {
_, _ = msgpack.Marshal(data)
}
}
func BenchmarkDecodeMap(b *testing.B) {
raw, _ := msgpack.Marshal(map[string]any{
"request_id": "req-bench",
"user_id": "user-bench",
"message": "What is the weather today?",
"premium": true,
"top_k": 10,
})
for b.Loop() {
var m map[string]any
_ = msgpack.Unmarshal(raw, &m)
}
}
func BenchmarkDecodeStruct(b *testing.B) {
raw, _ := msgpack.Marshal(testMessage{
RequestID: "req-bench",
UserID: "user-bench",
Count: 10,
Active: true,
})
for b.Loop() {
var m testMessage
_ = msgpack.Unmarshal(raw, &m)
}
}
func BenchmarkDecodeAudio32KB(b *testing.B) {
raw, _ := msgpack.Marshal(audioMessage{
SessionID: "s1",
Audio: make([]byte, 32768),
SampleRate: 24000,
})
for b.Loop() {
var m audioMessage
_ = msgpack.Unmarshal(raw, &m)
}
}

View File

@@ -1,70 +0,0 @@
[project]
name = "handler-base"
version = "1.0.0"
description = "Shared base library for AI/ML handler services"
readme = "README.md"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Davies Tech Labs" }]
dependencies = [
# Async & messaging
"nats-py>=2.7.0",
"httpx>=0.27.0",
"msgpack>=1.0.0",
# Data stores
"pymilvus>=2.4.0",
"redis>=5.0.0",
# Observability
"opentelemetry-api>=1.20.0",
"opentelemetry-sdk>=1.20.0",
"opentelemetry-exporter-otlp-proto-grpc>=1.20.0",
"opentelemetry-exporter-otlp-proto-http>=1.20.0",
"opentelemetry-instrumentation-httpx>=0.44b0",
"opentelemetry-instrumentation-logging>=0.44b0",
# MLflow
"mlflow>=2.10.0",
"psycopg2-binary>=2.9.0",
# Utilities
"numpy>=1.26.0",
"pydantic>=2.5.0",
"pydantic-settings>=2.1.0",
]
[project.optional-dependencies]
audio = [
"soundfile>=0.12.0",
"librosa>=0.10.0",
"webrtcvad>=2.0.10",
]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"pytest-cov>=4.0.0",
"ruff>=0.4.0",
"pre-commit>=3.7.0",
]
ray = [
"ray[serve]>=2.9.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["handler_base"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "I", "W"]
[tool.pytest.ini_options]
asyncio_mode = "auto"

View File

@@ -1,7 +0,0 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
asyncio_mode = auto
addopts = -v --tb=short

7
renovate.json Normal file
View File

@@ -0,0 +1,7 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"local>daviestechlabs/renovate-config",
"local>daviestechlabs/renovate-config:python"
]
}

122
telemetry/telemetry.go Normal file
View File

@@ -0,0 +1,122 @@
// Package telemetry provides OpenTelemetry tracing and metrics setup.
package telemetry
import (
"context"
"log/slog"
"os"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/propagation"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
)
// Config holds the telemetry configuration parameters.
type Config struct {
ServiceName string
ServiceVersion string
ServiceNamespace string
DeploymentEnv string
Enabled bool
Endpoint string
}
// Provider holds the initialized tracer and meter providers.
type Provider struct {
TracerProvider *sdktrace.TracerProvider
MeterProvider *sdkmetric.MeterProvider
Tracer trace.Tracer
Meter metric.Meter
}
// Setup initialises OpenTelemetry tracing and metrics.
// Returns a Provider and a shutdown function.
func Setup(ctx context.Context, cfg Config) (*Provider, func(context.Context), error) {
if !cfg.Enabled {
slog.Info("OpenTelemetry disabled")
return &Provider{
Tracer: otel.Tracer(cfg.ServiceName),
Meter: otel.Meter(cfg.ServiceName),
}, func(context.Context) {}, nil
}
hostname, _ := os.Hostname()
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceNameKey.String(cfg.ServiceName),
semconv.ServiceVersionKey.String(cfg.ServiceVersion),
semconv.ServiceNamespaceKey.String(cfg.ServiceNamespace),
attribute.String("deployment.environment", cfg.DeploymentEnv),
attribute.String("host.name", hostname),
),
)
if err != nil {
return nil, nil, err
}
// Trace exporter (gRPC)
traceExp, err := otlptracegrpc.New(ctx,
otlptracegrpc.WithEndpoint(stripScheme(cfg.Endpoint)),
otlptracegrpc.WithInsecure(),
)
if err != nil {
return nil, nil, err
}
tp := sdktrace.NewTracerProvider(
sdktrace.WithBatcher(traceExp),
sdktrace.WithResource(res),
)
otel.SetTracerProvider(tp)
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
))
// Metric exporter (gRPC)
metricExp, err := otlpmetricgrpc.New(ctx,
otlpmetricgrpc.WithEndpoint(stripScheme(cfg.Endpoint)),
otlpmetricgrpc.WithInsecure(),
)
if err != nil {
return nil, nil, err
}
mp := sdkmetric.NewMeterProvider(
sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExp)),
sdkmetric.WithResource(res),
)
otel.SetMeterProvider(mp)
slog.Info("OpenTelemetry initialized", "service", cfg.ServiceName, "endpoint", cfg.Endpoint)
shutdown := func(ctx context.Context) {
_ = tp.Shutdown(ctx)
_ = mp.Shutdown(ctx)
}
return &Provider{
TracerProvider: tp,
MeterProvider: mp,
Tracer: tp.Tracer(cfg.ServiceName, trace.WithInstrumentationVersion(cfg.ServiceVersion)),
Meter: mp.Meter(cfg.ServiceName),
}, shutdown, nil
}
// stripScheme removes http:// or https:// from an endpoint for gRPC dialers.
func stripScheme(endpoint string) string {
for _, prefix := range []string{"https://", "http://"} {
if len(endpoint) > len(prefix) && endpoint[:len(prefix)] == prefix {
return endpoint[len(prefix):]
}
}
return endpoint
}

View File

@@ -1,76 +0,0 @@
"""
Pytest configuration and fixtures.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
# Set test environment variables before importing handler_base
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
os.environ.setdefault("MILVUS_HOST", "localhost")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def settings():
"""Create test settings."""
from handler_base.config import Settings
return Settings(
service_name="test-service",
service_version="1.0.0-test",
otel_enabled=False,
mlflow_enabled=False,
nats_url="nats://localhost:4222",
redis_url="redis://localhost:6379",
milvus_host="localhost",
)
@pytest.fixture
def mock_httpx_client():
"""Create a mock httpx AsyncClient."""
client = AsyncMock()
client.post = AsyncMock()
client.get = AsyncMock()
client.aclose = AsyncMock()
return client
@pytest.fixture
def mock_nats_message():
"""Create a mock NATS message."""
msg = MagicMock()
msg.subject = "test.subject"
msg.reply = "test.reply"
msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
return msg
@pytest.fixture
def sample_embedding():
"""Sample embedding vector."""
return [0.1] * 1024
@pytest.fixture
def sample_documents():
"""Sample documents for testing."""
return [
{"text": "Python is a programming language.", "source": "doc1"},
{"text": "Machine learning is a subset of AI.", "source": "doc2"},
{"text": "Deep learning uses neural networks.", "source": "doc3"},
]

View File

@@ -1 +0,0 @@
# Unit tests package

View File

@@ -1,150 +0,0 @@
"""
Unit tests for service clients.
"""
from unittest.mock import MagicMock
import pytest
class TestEmbeddingsClient:
"""Tests for EmbeddingsClient."""
@pytest.fixture
def embeddings_client(self, mock_httpx_client):
"""Create an EmbeddingsClient with mocked HTTP."""
from handler_base.clients.embeddings import EmbeddingsClient
client = EmbeddingsClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding a single text."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed_single("Hello world")
assert result == sample_embedding
mock_httpx_client.post.assert_called_once()
@pytest.mark.asyncio
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding multiple texts."""
texts = ["Hello", "World"]
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [
{"embedding": sample_embedding, "index": 0},
{"embedding": sample_embedding, "index": 1},
]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed(texts)
assert len(result) == 2
assert all(len(e) == len(sample_embedding) for e in result)
@pytest.mark.asyncio
async def test_health_check(self, embeddings_client, mock_httpx_client):
"""Test health check."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_httpx_client.get.return_value = mock_response
result = await embeddings_client.health()
assert result is True
class TestRerankerClient:
"""Tests for RerankerClient."""
@pytest.fixture
def reranker_client(self, mock_httpx_client):
"""Create a RerankerClient with mocked HTTP."""
from handler_base.clients.reranker import RerankerClient
client = RerankerClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
"""Test reranking documents."""
texts = [d["text"] for d in sample_documents]
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [
{"index": 1, "relevance_score": 0.95},
{"index": 0, "relevance_score": 0.80},
{"index": 2, "relevance_score": 0.65},
]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await reranker_client.rerank("What is ML?", texts)
assert len(result) == 3
assert result[0]["score"] == 0.95
assert result[0]["index"] == 1
class TestLLMClient:
"""Tests for LLMClient."""
@pytest.fixture
def llm_client(self, mock_httpx_client):
"""Create an LLMClient with mocked HTTP."""
from handler_base.clients.llm import LLMClient
client = LLMClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_generate(self, llm_client, mock_httpx_client):
"""Test generating a response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate("Hello")
assert result == "Hello! I'm an AI assistant."
@pytest.mark.asyncio
async def test_generate_with_context(self, llm_client, mock_httpx_client):
"""Test generating with RAG context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Based on the context..."}}],
"usage": {},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate(
"What is Python?", context="Python is a programming language."
)
assert "Based on the context" in result
# Verify context was included in the request
call_args = mock_httpx_client.post.call_args
messages = call_args.kwargs["json"]["messages"]
assert any("Context:" in m["content"] for m in messages if m["role"] == "user")

View File

@@ -1,46 +0,0 @@
"""
Unit tests for handler_base.config module.
"""
class TestSettings:
"""Tests for Settings configuration."""
def test_default_settings(self, settings):
"""Test that default settings are loaded correctly."""
assert settings.service_name == "test-service"
assert settings.service_version == "1.0.0-test"
assert settings.otel_enabled is False
def test_settings_from_env(self, monkeypatch):
"""Test that settings can be loaded from environment variables."""
monkeypatch.setenv("SERVICE_NAME", "env-service")
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
# Need to reimport to pick up env changes
from handler_base.config import Settings
s = Settings()
assert s.service_name == "env-service"
assert s.service_version == "2.0.0"
assert s.nats_url == "nats://custom:4222"
def test_embeddings_settings(self):
"""Test EmbeddingsSettings extends base correctly."""
from handler_base.config import EmbeddingsSettings
s = EmbeddingsSettings()
assert hasattr(s, "embeddings_model")
assert hasattr(s, "embeddings_batch_size")
assert s.embeddings_model == "bge"
def test_llm_settings(self):
"""Test LLMSettings has expected defaults."""
from handler_base.config import LLMSettings
s = LLMSettings()
assert s.llm_max_tokens == 2048
assert s.llm_temperature == 0.7
assert 0 <= s.llm_top_p <= 1

View File

@@ -1,122 +0,0 @@
"""
Unit tests for handler_base.health module.
"""
import json
import time
from http.client import HTTPConnection
import pytest
class TestHealthServer:
"""Tests for HealthServer."""
@pytest.fixture
def health_server(self, settings):
"""Create a HealthServer instance."""
from handler_base.health import HealthServer
# Use a random high port to avoid conflicts
settings.health_port = 18080
return HealthServer(settings)
def test_start_stop(self, health_server):
"""Test starting and stopping the health server."""
health_server.start()
time.sleep(0.1) # Give server time to start
# Verify server is running
assert health_server._server is not None
assert health_server._thread is not None
assert health_server._thread.is_alive()
health_server.stop()
time.sleep(0.1)
assert health_server._server is None
def test_health_endpoint(self, health_server):
"""Test the /health endpoint."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/health")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "healthy"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_default(self, health_server):
"""Test the /ready endpoint with no custom check."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "ready"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_with_check(self, settings):
"""Test /ready endpoint with custom readiness check."""
from handler_base.health import HealthServer
ready_flag = [False] # Use list to allow mutation in closure
async def check_ready():
return ready_flag[0]
settings.health_port = 18081
server = HealthServer(settings, ready_check=check_ready)
server.start()
time.sleep(0.2)
try:
conn = HTTPConnection("localhost", 18081, timeout=5)
# Should be not ready initially
conn.request("GET", "/ready")
response = conn.getresponse()
response.read() # Consume response body
assert response.status == 503
# Mark as ready
ready_flag[0] = True
# Need new connection after consuming response
conn.close()
conn = HTTPConnection("localhost", 18081, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
finally:
conn.close()
server.stop()
def test_404_for_unknown_path(self, health_server):
"""Test that unknown paths return 404."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/unknown")
response = conn.getresponse()
assert response.status == 404
finally:
conn.close()
health_server.stop()

View File

@@ -1,95 +0,0 @@
"""
Unit tests for handler_base.nats_client module.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import msgpack
import pytest
class TestNATSClient:
"""Tests for NATSClient."""
@pytest.fixture
def nats_client(self, settings):
"""Create a NATSClient instance."""
from handler_base.nats_client import NATSClient
return NATSClient(settings)
def test_init(self, nats_client, settings):
"""Test NATSClient initialization."""
assert nats_client.settings == settings
assert nats_client._nc is None
assert nats_client._js is None
def test_decode_msgpack(self, nats_client):
"""Test msgpack decoding."""
data = {"query": "hello", "request_id": "123"}
encoded = msgpack.packb(data, use_bin_type=True)
msg = MagicMock()
msg.data = encoded
result = nats_client.decode_msgpack(msg)
assert result == data
def test_decode_json(self, nats_client):
"""Test JSON decoding."""
import json
data = {"query": "hello"}
msg = MagicMock()
msg.data = json.dumps(data).encode()
result = nats_client.decode_json(msg)
assert result == data
@pytest.mark.asyncio
async def test_connect(self, nats_client):
"""Test NATS connection."""
with patch("handler_base.nats_client.nats") as mock_nats:
mock_nc = AsyncMock()
mock_js = MagicMock()
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
mock_nats.connect = AsyncMock(return_value=mock_nc)
await nats_client.connect()
assert nats_client._nc == mock_nc
assert nats_client._js == mock_js
mock_nats.connect.assert_called_once()
@pytest.mark.asyncio
async def test_publish(self, nats_client):
"""Test publishing a message."""
mock_nc = AsyncMock()
nats_client._nc = mock_nc
data = {"key": "value"}
await nats_client.publish("test.subject", data)
mock_nc.publish.assert_called_once()
call_args = mock_nc.publish.call_args
assert call_args.args[0] == "test.subject"
# Verify msgpack encoding
decoded = msgpack.unpackb(call_args.args[1], raw=False)
assert decoded == data
@pytest.mark.asyncio
async def test_subscribe(self, nats_client):
"""Test subscribing to a subject."""
mock_nc = AsyncMock()
mock_sub = MagicMock()
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
nats_client._nc = mock_nc
handler = AsyncMock()
await nats_client.subscribe("test.subject", handler, queue="test-queue")
mock_nc.subscribe.assert_called_once()
call_kwargs = mock_nc.subscribe.call_args.kwargs
assert call_kwargs["queue"] == "test-queue"

4256
uv.lock generated

File diff suppressed because it is too large Load Diff