Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fba7b62573 | |||
| 6fd0b9a265 | |||
| 8b6232141a | |||
| 9876cb9388 | |||
| 39673d31b8 | |||
| 81581337cd | |||
| ea9b3a8f2b | |||
| 35912d5844 | |||
| d321c9852b | |||
| 5eb2c43a5d |
@@ -17,23 +17,22 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up uv
|
- name: Set up Go
|
||||||
uses: astral-sh/setup-uv@v7
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
version: "latest"
|
go-version-file: go.mod
|
||||||
activate-environment: false
|
cache: true
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Run go vet
|
||||||
run: uv python install 3.13
|
run: go vet ./...
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install golangci-lint
|
||||||
run: uv sync --frozen --extra dev
|
run: |
|
||||||
|
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b "$(go env GOPATH)/bin"
|
||||||
|
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
- name: Run ruff check
|
- name: Run golangci-lint
|
||||||
run: uv run ruff check .
|
run: golangci-lint run ./...
|
||||||
|
|
||||||
- name: Run ruff format check
|
|
||||||
run: uv run ruff format --check .
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test
|
name: Test
|
||||||
@@ -42,26 +41,28 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up uv
|
- name: Set up Go
|
||||||
uses: astral-sh/setup-uv@v7
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
version: "latest"
|
go-version-file: go.mod
|
||||||
activate-environment: false
|
cache: true
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Verify dependencies
|
||||||
run: uv python install 3.13
|
run: go mod verify
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Build
|
||||||
run: uv sync --frozen --extra dev
|
run: go build -v ./...
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests
|
||||||
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term
|
run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
|
||||||
|
|
||||||
release:
|
release:
|
||||||
name: Release
|
name: Release
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [lint, test]
|
needs: [lint, test]
|
||||||
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.version.outputs.version }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -101,10 +102,32 @@ jobs:
|
|||||||
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
|
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
|
||||||
git push origin ${{ steps.version.outputs.version }}
|
git push origin ${{ steps.version.outputs.version }}
|
||||||
|
|
||||||
|
notify-downstream:
|
||||||
|
name: Notify Downstream
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [release]
|
||||||
|
if: needs.release.result == 'success'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
repo:
|
||||||
|
- chat-handler
|
||||||
|
- pipeline-bridge
|
||||||
|
- tts-module
|
||||||
|
- voice-assistant
|
||||||
|
- stt-module
|
||||||
|
steps:
|
||||||
|
- name: Trigger dependency update
|
||||||
|
run: |
|
||||||
|
curl -s -X POST \
|
||||||
|
-H "Authorization: token ${{ secrets.DISPATCH_TOKEN }}" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"event_type":"handler-base-release","client_payload":{"version":"${{ needs.release.outputs.version }}"}}' \
|
||||||
|
"${{ gitea.server_url }}/api/v1/repos/daviestechlabs/${{ matrix.repo }}/dispatches"
|
||||||
|
|
||||||
notify:
|
notify:
|
||||||
name: Notify
|
name: Notify
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [lint, test, release]
|
needs: [lint, test, release, notify-downstream]
|
||||||
if: always()
|
if: always()
|
||||||
steps:
|
steps:
|
||||||
- name: Notify on success
|
- name: Notify on success
|
||||||
|
|||||||
35
.gitignore
vendored
35
.gitignore
vendored
@@ -1,23 +1,20 @@
|
|||||||
# Python
|
# Go
|
||||||
__pycache__/
|
*.exe
|
||||||
*.py[cod]
|
*.dll
|
||||||
*$py.class
|
|
||||||
*.so
|
*.so
|
||||||
.Python
|
*.dylib
|
||||||
build/
|
*.test
|
||||||
develop-eggs/
|
*.out
|
||||||
dist/
|
vendor/
|
||||||
downloads/
|
|
||||||
eggs/
|
# IDE
|
||||||
.eggs/
|
.idea/
|
||||||
lib/
|
.vscode/
|
||||||
lib64/
|
*.swp
|
||||||
parts/
|
|
||||||
sdist/
|
# OS
|
||||||
var/
|
.DS_Store
|
||||||
wheels/
|
Thumbs.db
|
||||||
*.egg-info/
|
|
||||||
.installed.cfg
|
|
||||||
*.egg
|
*.egg
|
||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
# Pre-commit hooks for handler-base
|
|
||||||
# Install: pip install pre-commit && pre-commit install
|
|
||||||
# Run: pre-commit run --all-files
|
|
||||||
|
|
||||||
repos:
|
|
||||||
# Ruff - fast Python linter and formatter
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
||||||
rev: v0.4.4
|
|
||||||
hooks:
|
|
||||||
- id: ruff
|
|
||||||
args: [--fix, --exit-non-zero-on-fix]
|
|
||||||
- id: ruff-format
|
|
||||||
|
|
||||||
# Standard pre-commit hooks
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
||||||
rev: v4.6.0
|
|
||||||
hooks:
|
|
||||||
- id: trailing-whitespace
|
|
||||||
- id: end-of-file-fixer
|
|
||||||
- id: check-yaml
|
|
||||||
- id: check-added-large-files
|
|
||||||
args: [--maxkb=500]
|
|
||||||
- id: check-merge-conflict
|
|
||||||
- id: detect-private-key
|
|
||||||
|
|
||||||
# Type checking (optional - uncomment when ready)
|
|
||||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
|
||||||
# rev: v1.10.0
|
|
||||||
# hooks:
|
|
||||||
# - id: mypy
|
|
||||||
# additional_dependencies: [types-all]
|
|
||||||
# args: [--ignore-missing-imports]
|
|
||||||
58
Dockerfile
58
Dockerfile
@@ -1,58 +0,0 @@
|
|||||||
# Handler Base Image
|
|
||||||
#
|
|
||||||
# Provides a pre-built base with all common dependencies for handler services.
|
|
||||||
# Services extend this image and add their specific code.
|
|
||||||
#
|
|
||||||
# Build:
|
|
||||||
# docker build -t ghcr.io/billy-davies-2/handler-base:latest .
|
|
||||||
#
|
|
||||||
# Usage in child Dockerfile:
|
|
||||||
# FROM ghcr.io/billy-davies-2/handler-base:latest
|
|
||||||
# COPY my_handler.py .
|
|
||||||
# CMD ["python", "my_handler.py"]
|
|
||||||
|
|
||||||
FROM python:3.13-slim AS base
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install uv for fast, reliable package management
|
|
||||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
|
||||||
|
|
||||||
# Install system dependencies
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Copy and install handler-base package
|
|
||||||
COPY pyproject.toml README.md ./
|
|
||||||
COPY handler_base/ ./handler_base/
|
|
||||||
|
|
||||||
# Install the package with all dependencies
|
|
||||||
RUN uv pip install --system --no-cache .
|
|
||||||
|
|
||||||
# Set environment variables
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
|
||||||
|
|
||||||
# Default health check
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
|
||||||
CMD curl -f http://localhost:8080/health || exit 1
|
|
||||||
|
|
||||||
# Default command (override in child images)
|
|
||||||
CMD ["python", "-c", "print('handler-base ready')"]
|
|
||||||
|
|
||||||
|
|
||||||
# Audio variant with soundfile, librosa, webrtcvad
|
|
||||||
FROM base AS audio
|
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
ffmpeg \
|
|
||||||
libsndfile1 \
|
|
||||||
gcc \
|
|
||||||
python3-dev \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
RUN uv pip install --system --no-cache \
|
|
||||||
soundfile>=0.12.0 \
|
|
||||||
librosa>=0.10.0 \
|
|
||||||
webrtcvad>=2.0.10
|
|
||||||
129
README.md
129
README.md
@@ -1,109 +1,44 @@
|
|||||||
# Handler Base
|
# handler-base
|
||||||
|
|
||||||
Shared base library for building NATS-based AI/ML handler services.
|
Go module providing shared infrastructure for NATS-based handler services.
|
||||||
|
|
||||||
## Installation
|
## Packages
|
||||||
|
|
||||||
```bash
|
| Package | Purpose |
|
||||||
pip install handler-base
|
|---------|---------|
|
||||||
```
|
| `config` | Environment-based configuration via struct fields |
|
||||||
|
| `health` | HTTP health/readiness server for Kubernetes probes |
|
||||||
|
| `natsutil` | NATS/JetStream client with msgpack serialization |
|
||||||
|
| `telemetry` | OpenTelemetry tracing and metrics setup |
|
||||||
|
| `clients` | HTTP clients for LLM, embeddings, reranker, STT, TTS |
|
||||||
|
| `handler` | Base Handler runner wiring NATS + health + telemetry |
|
||||||
|
|
||||||
Or from Gitea:
|
## Usage
|
||||||
```bash
|
|
||||||
pip install git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git
|
|
||||||
```
|
|
||||||
|
|
||||||
## Quick Start
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
```python
|
import (
|
||||||
from handler_base import Handler, Settings
|
"context"
|
||||||
from nats.aio.msg import Msg
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/handler"
|
||||||
class MyHandler(Handler):
|
"github.com/nats-io/nats.go"
|
||||||
async def setup(self):
|
|
||||||
# Initialize your clients
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def handle_message(self, msg: Msg, data: dict):
|
|
||||||
# Process the message
|
|
||||||
result = {"processed": True}
|
|
||||||
return result
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
MyHandler(subject="my.subject").run()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- **Handler base class** - NATS subscription, graceful shutdown, signal handling
|
|
||||||
- **NATSClient** - Connection management, JetStream, msgpack serialization
|
|
||||||
- **Settings** - Pydantic-based configuration from environment
|
|
||||||
- **HealthServer** - Kubernetes liveness/readiness probes
|
|
||||||
- **Telemetry** - OpenTelemetry tracing and metrics
|
|
||||||
- **Service clients** - HTTP wrappers for AI services
|
|
||||||
|
|
||||||
## Service Clients
|
|
||||||
|
|
||||||
```python
|
|
||||||
from handler_base.clients import (
|
|
||||||
STTClient, # Whisper speech-to-text
|
|
||||||
TTSClient, # XTTS text-to-speech
|
|
||||||
LLMClient, # vLLM chat completions
|
|
||||||
EmbeddingsClient, # BGE embeddings
|
|
||||||
RerankerClient, # BGE reranker
|
|
||||||
MilvusClient, # Vector database
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := config.Load()
|
||||||
|
cfg.ServiceName = "my-service"
|
||||||
|
|
||||||
|
h := handler.New("my.subject", cfg)
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
})
|
||||||
|
h.Run()
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Testing
|
||||||
|
|
||||||
All settings via environment variables:
|
|
||||||
|
|
||||||
| Variable | Default | Description |
|
|
||||||
|----------|---------|-------------|
|
|
||||||
| `NATS_URL` | `nats://localhost:4222` | NATS server URL |
|
|
||||||
| `NATS_USER` | - | NATS username |
|
|
||||||
| `NATS_PASSWORD` | - | NATS password |
|
|
||||||
| `NATS_QUEUE_GROUP` | - | Queue group for load balancing |
|
|
||||||
| `HEALTH_PORT` | `8080` | Health check server port |
|
|
||||||
| `OTEL_ENABLED` | `true` | Enable OpenTelemetry |
|
|
||||||
| `OTEL_EXPORTER_OTLP_ENDPOINT` | `http://localhost:4317` | OTLP endpoint |
|
|
||||||
| `OTEL_SERVICE_NAME` | `handler` | Service name for traces |
|
|
||||||
|
|
||||||
## Docker
|
|
||||||
|
|
||||||
```dockerfile
|
|
||||||
FROM ghcr.io/daviestechlabs/handler-base:latest
|
|
||||||
|
|
||||||
COPY my_handler.py /app/
|
|
||||||
CMD ["python", "/app/my_handler.py"]
|
|
||||||
```
|
|
||||||
|
|
||||||
Or build with audio support:
|
|
||||||
```bash
|
```bash
|
||||||
docker build --build-arg INSTALL_AUDIO=true -t my-handler .
|
go test ./...
|
||||||
```
|
```
|
||||||
|
|
||||||
## Module Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
handler_base/
|
|
||||||
├── __init__.py # Public API exports
|
|
||||||
├── handler.py # Base Handler class
|
|
||||||
├── nats_client.py # NATS connection wrapper
|
|
||||||
├── config.py # Pydantic Settings
|
|
||||||
├── health.py # Health check server
|
|
||||||
├── telemetry.py # OpenTelemetry setup
|
|
||||||
└── clients/
|
|
||||||
├── embeddings.py
|
|
||||||
├── llm.py
|
|
||||||
├── milvus.py
|
|
||||||
├── reranker.py
|
|
||||||
├── stt.py
|
|
||||||
└── tts.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Related
|
|
||||||
|
|
||||||
- [voice-assistant](https://git.daviestechlabs.io/daviestechlabs/voice-assistant) - Voice pipeline using handler-base
|
|
||||||
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
|
|
||||||
|
|||||||
436
clients/clients.go
Normal file
436
clients/clients.go
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
// Package clients provides HTTP client wrappers for AI/ML backend services.
|
||||||
|
//
|
||||||
|
// All clients share a single [http.Transport] for connection pooling across
|
||||||
|
// the process. Request and response bodies are serialized through pooled
|
||||||
|
// [bytes.Buffer]s to reduce GC pressure.
|
||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ─── Shared transport & buffer pool ─────────────────────────────────────────
|
||||||
|
|
||||||
|
// SharedTransport is the process-wide HTTP transport used by every service
|
||||||
|
// client. Tweak pool sizes here rather than creating per-client transports.
|
||||||
|
var SharedTransport = &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
DisableCompression: true, // in-cluster traffic; skip gzip overhead
|
||||||
|
}
|
||||||
|
|
||||||
|
// bufPool recycles *bytes.Buffer to avoid per-request allocations.
|
||||||
|
var bufPool = sync.Pool{
|
||||||
|
New: func() any { return new(bytes.Buffer) },
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBuf() *bytes.Buffer {
|
||||||
|
buf := bufPool.Get().(*bytes.Buffer)
|
||||||
|
buf.Reset()
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func putBuf(buf *bytes.Buffer) {
|
||||||
|
if buf.Cap() > 1<<20 { // don't cache buffers > 1 MiB
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bufPool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── httpClient base ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// httpClient is the shared base for all service clients.
|
||||||
|
type httpClient struct {
|
||||||
|
client *http.Client
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPClient(baseURL string, timeout time.Duration) *httpClient {
|
||||||
|
return &httpClient{
|
||||||
|
client: &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
Transport: SharedTransport,
|
||||||
|
},
|
||||||
|
baseURL: baseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) postJSON(ctx context.Context, path string, body any) ([]byte, error) {
|
||||||
|
buf := getBuf()
|
||||||
|
defer putBuf(buf)
|
||||||
|
if err := json.NewEncoder(buf).Encode(body); err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal: %w", err)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
return h.do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) get(ctx context.Context, path string, params url.Values) ([]byte, error) {
|
||||||
|
u := h.baseURL + path
|
||||||
|
if len(params) > 0 {
|
||||||
|
u += "?" + params.Encode()
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return h.do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) getRaw(ctx context.Context, path string, params url.Values) ([]byte, error) {
|
||||||
|
return h.get(ctx, path, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) postMultipart(ctx context.Context, path string, fieldName string, fileName string, fileData []byte, fields map[string]string) ([]byte, error) {
|
||||||
|
buf := getBuf()
|
||||||
|
defer putBuf(buf)
|
||||||
|
w := multipart.NewWriter(buf)
|
||||||
|
part, err := w.CreateFormFile(fieldName, fileName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := part.Write(fileData); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for k, v := range fields {
|
||||||
|
_ = w.WriteField(k, v)
|
||||||
|
}
|
||||||
|
_ = w.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, h.baseURL+path, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
return h.do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) do(req *http.Request) ([]byte, error) {
|
||||||
|
resp, err := h.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("http %s %s: %w", req.Method, req.URL.Path, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
buf := getBuf()
|
||||||
|
defer putBuf(buf)
|
||||||
|
if _, err := io.Copy(buf, resp.Body); err != nil {
|
||||||
|
return nil, fmt.Errorf("read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a copy so the pooled buffer can be safely recycled.
|
||||||
|
body := make([]byte, buf.Len())
|
||||||
|
copy(body, buf.Bytes())
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("http %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpClient) healthCheck(ctx context.Context) bool {
|
||||||
|
data, err := h.get(ctx, "/health", nil)
|
||||||
|
_ = data
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Embeddings Client ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// EmbeddingsClient calls the embeddings service (Infinity/BGE).
|
||||||
|
type EmbeddingsClient struct {
|
||||||
|
*httpClient
|
||||||
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEmbeddingsClient creates an embeddings client.
|
||||||
|
func NewEmbeddingsClient(baseURL string, timeout time.Duration, model string) *EmbeddingsClient {
|
||||||
|
if model == "" {
|
||||||
|
model = "bge"
|
||||||
|
}
|
||||||
|
return &EmbeddingsClient{httpClient: newHTTPClient(baseURL, timeout), Model: model}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embed generates embeddings for a list of texts.
|
||||||
|
func (c *EmbeddingsClient) Embed(ctx context.Context, texts []string) ([][]float64, error) {
|
||||||
|
body, err := c.postJSON(ctx, "/embeddings", map[string]any{
|
||||||
|
"input": texts,
|
||||||
|
"model": c.Model,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var resp struct {
|
||||||
|
Data []struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := make([][]float64, len(resp.Data))
|
||||||
|
for i, d := range resp.Data {
|
||||||
|
result[i] = d.Embedding
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedSingle generates an embedding for a single text.
|
||||||
|
func (c *EmbeddingsClient) EmbedSingle(ctx context.Context, text string) ([]float64, error) {
|
||||||
|
results, err := c.Embed(ctx, []string{text})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty embedding result")
|
||||||
|
}
|
||||||
|
return results[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Health checks if the embeddings service is healthy.
|
||||||
|
func (c *EmbeddingsClient) Health(ctx context.Context) bool {
|
||||||
|
return c.healthCheck(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Reranker Client ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// RerankerClient calls the reranker service (BGE Reranker).
|
||||||
|
type RerankerClient struct {
|
||||||
|
*httpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRerankerClient creates a reranker client.
|
||||||
|
func NewRerankerClient(baseURL string, timeout time.Duration) *RerankerClient {
|
||||||
|
return &RerankerClient{httpClient: newHTTPClient(baseURL, timeout)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RerankResult represents a reranked document.
|
||||||
|
type RerankResult struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
Document string `json:"document"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rerank reranks documents by relevance to the query.
|
||||||
|
func (c *RerankerClient) Rerank(ctx context.Context, query string, documents []string, topK int) ([]RerankResult, error) {
|
||||||
|
payload := map[string]any{
|
||||||
|
"query": query,
|
||||||
|
"documents": documents,
|
||||||
|
}
|
||||||
|
if topK > 0 {
|
||||||
|
payload["top_n"] = topK
|
||||||
|
}
|
||||||
|
body, err := c.postJSON(ctx, "/rerank", payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var resp struct {
|
||||||
|
Results []struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
RelevanceScore float64 `json:"relevance_score"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
} `json:"results"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results := make([]RerankResult, len(resp.Results))
|
||||||
|
for i, r := range resp.Results {
|
||||||
|
score := r.RelevanceScore
|
||||||
|
if score == 0 {
|
||||||
|
score = r.Score
|
||||||
|
}
|
||||||
|
doc := ""
|
||||||
|
if r.Index < len(documents) {
|
||||||
|
doc = documents[r.Index]
|
||||||
|
}
|
||||||
|
results[i] = RerankResult{Index: r.Index, Score: score, Document: doc}
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── LLM Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// LLMClient calls the vLLM-compatible LLM service.
|
||||||
|
type LLMClient struct {
|
||||||
|
*httpClient
|
||||||
|
Model string
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float64
|
||||||
|
TopP float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLLMClient creates an LLM client.
|
||||||
|
func NewLLMClient(baseURL string, timeout time.Duration) *LLMClient {
|
||||||
|
return &LLMClient{
|
||||||
|
httpClient: newHTTPClient(baseURL, timeout),
|
||||||
|
Model: "default",
|
||||||
|
MaxTokens: 2048,
|
||||||
|
Temperature: 0.7,
|
||||||
|
TopP: 0.9,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage is an OpenAI-compatible message.
|
||||||
|
type ChatMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate sends a chat completion request and returns the response text.
|
||||||
|
func (c *LLMClient) Generate(ctx context.Context, prompt string, context_ string, systemPrompt string) (string, error) {
|
||||||
|
messages := buildMessages(prompt, context_, systemPrompt)
|
||||||
|
payload := map[string]any{
|
||||||
|
"model": c.Model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": c.MaxTokens,
|
||||||
|
"temperature": c.Temperature,
|
||||||
|
"top_p": c.TopP,
|
||||||
|
}
|
||||||
|
body, err := c.postJSON(ctx, "/v1/chat/completions", payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var resp struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
return "", fmt.Errorf("no choices in LLM response")
|
||||||
|
}
|
||||||
|
return resp.Choices[0].Message.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMessages(prompt, ctx, systemPrompt string) []ChatMessage {
|
||||||
|
var msgs []ChatMessage
|
||||||
|
if systemPrompt != "" {
|
||||||
|
msgs = append(msgs, ChatMessage{Role: "system", Content: systemPrompt})
|
||||||
|
} else if ctx != "" {
|
||||||
|
msgs = append(msgs, ChatMessage{Role: "system", Content: "You are a helpful assistant. Use the provided context to answer the user's question. If the context doesn't contain relevant information, say so."})
|
||||||
|
}
|
||||||
|
if ctx != "" {
|
||||||
|
msgs = append(msgs, ChatMessage{Role: "user", Content: fmt.Sprintf("Context:\n%s\n\nQuestion: %s", ctx, prompt)})
|
||||||
|
} else {
|
||||||
|
msgs = append(msgs, ChatMessage{Role: "user", Content: prompt})
|
||||||
|
}
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── TTS Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// TTSClient calls the TTS service (Coqui XTTS).
|
||||||
|
type TTSClient struct {
|
||||||
|
*httpClient
|
||||||
|
Language string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTTSClient creates a TTS client.
|
||||||
|
func NewTTSClient(baseURL string, timeout time.Duration, language string) *TTSClient {
|
||||||
|
if language == "" {
|
||||||
|
language = "en"
|
||||||
|
}
|
||||||
|
return &TTSClient{httpClient: newHTTPClient(baseURL, timeout), Language: language}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Synthesize generates audio bytes from text.
|
||||||
|
func (c *TTSClient) Synthesize(ctx context.Context, text, language, speaker string) ([]byte, error) {
|
||||||
|
if language == "" {
|
||||||
|
language = c.Language
|
||||||
|
}
|
||||||
|
params := url.Values{
|
||||||
|
"text": {text},
|
||||||
|
"language_id": {language},
|
||||||
|
}
|
||||||
|
if speaker != "" {
|
||||||
|
params.Set("speaker_id", speaker)
|
||||||
|
}
|
||||||
|
return c.getRaw(ctx, "/api/tts", params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── STT Client ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// STTClient calls the Whisper STT service.
|
||||||
|
type STTClient struct {
|
||||||
|
*httpClient
|
||||||
|
Language string
|
||||||
|
Task string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSTTClient creates an STT client.
|
||||||
|
func NewSTTClient(baseURL string, timeout time.Duration) *STTClient {
|
||||||
|
return &STTClient{httpClient: newHTTPClient(baseURL, timeout), Task: "transcribe"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TranscribeResult holds transcription output.
|
||||||
|
type TranscribeResult struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Language string `json:"language,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transcribe sends audio to Whisper and returns the transcription.
|
||||||
|
func (c *STTClient) Transcribe(ctx context.Context, audio []byte, language string) (*TranscribeResult, error) {
|
||||||
|
if language == "" {
|
||||||
|
language = c.Language
|
||||||
|
}
|
||||||
|
fields := map[string]string{
|
||||||
|
"response_format": "json",
|
||||||
|
}
|
||||||
|
if language != "" {
|
||||||
|
fields["language"] = language
|
||||||
|
}
|
||||||
|
endpoint := "/v1/audio/transcriptions"
|
||||||
|
if c.Task == "translate" {
|
||||||
|
endpoint = "/v1/audio/translations"
|
||||||
|
}
|
||||||
|
body, err := c.postMultipart(ctx, endpoint, "file", "audio.wav", audio, fields)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var result TranscribeResult
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Milvus Client ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// MilvusClient provides vector search via the Milvus HTTP/gRPC API.
|
||||||
|
// For the Go port we use the Milvus Go SDK.
|
||||||
|
type MilvusClient struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Collection string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMilvusClient creates a Milvus client.
|
||||||
|
func NewMilvusClient(host string, port int, collection string) *MilvusClient {
|
||||||
|
return &MilvusClient{Host: host, Port: port, Collection: collection}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchResult holds a single vector search hit.
|
||||||
|
type SearchResult struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Distance float64 `json:"distance"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
Fields map[string]any `json:"fields,omitempty"`
|
||||||
|
}
|
||||||
506
clients/clients_test.go
Normal file
506
clients/clients_test.go
Normal file
@@ -0,0 +1,506 @@
|
|||||||
|
package clients
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Shared infrastructure tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestSharedTransport(t *testing.T) {
|
||||||
|
// All clients created via newHTTPClient should share the same transport.
|
||||||
|
c1 := newHTTPClient("http://a:8000", 10*time.Second)
|
||||||
|
c2 := newHTTPClient("http://b:9000", 30*time.Second)
|
||||||
|
|
||||||
|
if c1.client.Transport != c2.client.Transport {
|
||||||
|
t.Error("clients should share the same http.Transport")
|
||||||
|
}
|
||||||
|
if c1.client.Transport != SharedTransport {
|
||||||
|
t.Error("transport should be the package-level SharedTransport")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferPoolGetPut(t *testing.T) {
|
||||||
|
buf := getBuf()
|
||||||
|
if buf == nil {
|
||||||
|
t.Fatal("getBuf returned nil")
|
||||||
|
}
|
||||||
|
if buf.Len() != 0 {
|
||||||
|
t.Error("getBuf should return a reset buffer")
|
||||||
|
}
|
||||||
|
buf.WriteString("hello")
|
||||||
|
putBuf(buf)
|
||||||
|
|
||||||
|
// On re-get, buffer should be reset.
|
||||||
|
buf2 := getBuf()
|
||||||
|
if buf2.Len() != 0 {
|
||||||
|
t.Error("re-acquired buffer should be reset")
|
||||||
|
}
|
||||||
|
putBuf(buf2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferPoolOversizedDiscarded(t *testing.T) {
|
||||||
|
buf := getBuf()
|
||||||
|
// Grow beyond 1 MB threshold.
|
||||||
|
buf.Write(make([]byte, 2<<20))
|
||||||
|
putBuf(buf) // should silently discard
|
||||||
|
|
||||||
|
// Pool should still work — we get a fresh one.
|
||||||
|
buf2 := getBuf()
|
||||||
|
if buf2.Len() != 0 {
|
||||||
|
t.Error("should get a fresh buffer")
|
||||||
|
}
|
||||||
|
putBuf(buf2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBufferPoolConcurrency(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := range 100 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := getBuf()
|
||||||
|
buf.WriteString(strings.Repeat("x", n))
|
||||||
|
putBuf(buf)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Embeddings client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_Embed(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/embeddings" {
|
||||||
|
t.Errorf("path = %q, want /embeddings", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("method = %s, want POST", r.Method)
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
input, _ := req["input"].([]any)
|
||||||
|
if len(input) != 2 {
|
||||||
|
t.Errorf("input len = %d, want 2", len(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": []map[string]any{
|
||||||
|
{"embedding": []float64{0.1, 0.2, 0.3}},
|
||||||
|
{"embedding": []float64{0.4, 0.5, 0.6}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "bge")
|
||||||
|
results, err := c.Embed(context.Background(), []string{"hello", "world"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("len(results) = %d, want 2", len(results))
|
||||||
|
}
|
||||||
|
if results[0][0] != 0.1 {
|
||||||
|
t.Errorf("results[0][0] = %f, want 0.1", results[0][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_EmbedSingle(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": []map[string]any{
|
||||||
|
{"embedding": []float64{1.0, 2.0}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
vec, err := c.EmbedSingle(context.Background(), "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(vec) != 2 || vec[0] != 1.0 {
|
||||||
|
t.Errorf("vec = %v", vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_EmbedEmpty(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"data": []any{}})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
_, err := c.EmbedSingle(context.Background(), "test")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty embedding")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsClient_Health(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/health" {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(404)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
if !c.Health(context.Background()) {
|
||||||
|
t.Error("expected healthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Reranker client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestRerankerClient_Rerank(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if req["query"] != "test query" {
|
||||||
|
t.Errorf("query = %v", req["query"])
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"results": []map[string]any{
|
||||||
|
{"index": 1, "relevance_score": 0.95},
|
||||||
|
{"index": 0, "relevance_score": 0.80},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewRerankerClient(ts.URL, 5*time.Second)
|
||||||
|
docs := []string{"Paris is great", "France is in Europe"}
|
||||||
|
results, err := c.Rerank(context.Background(), "test query", docs, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Fatalf("len = %d", len(results))
|
||||||
|
}
|
||||||
|
if results[0].Score != 0.95 {
|
||||||
|
t.Errorf("score = %f, want 0.95", results[0].Score)
|
||||||
|
}
|
||||||
|
if results[0].Document != "France is in Europe" {
|
||||||
|
t.Errorf("document = %q", results[0].Document)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRerankerClient_RerankFallbackScore(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"results": []map[string]any{
|
||||||
|
{"index": 0, "score": 0.77, "relevance_score": 0}, // some APIs only set score
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewRerankerClient(ts.URL, 5*time.Second)
|
||||||
|
results, err := c.Rerank(context.Background(), "q", []string{"doc1"}, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if results[0].Score != 0.77 {
|
||||||
|
t.Errorf("fallback score = %f, want 0.77", results[0].Score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// LLM client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestLLMClient_Generate(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/chat/completions" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
msgs, _ := req["messages"].([]any)
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
t.Error("no messages in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": "Paris is the capital of France."}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.Generate(context.Background(), "capital of France?", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "Paris is the capital of France." {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_GenerateWithContext(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
msgs, _ := req["messages"].([]any)
|
||||||
|
// Should have system + user message
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Errorf("expected 2 messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": "answer with context"}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.Generate(context.Background(), "question", "some context", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != "answer with context" {
|
||||||
|
t.Errorf("result = %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLLMClient_GenerateNoChoices(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
_, err := c.Generate(context.Background(), "q", "", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty choices")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// TTS client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestTTSClient_Synthesize(t *testing.T) {
|
||||||
|
expected := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/tts" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.URL.Query().Get("text") != "hello world" {
|
||||||
|
t.Errorf("text = %q", r.URL.Query().Get("text"))
|
||||||
|
}
|
||||||
|
_, _ = w.Write(expected)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewTTSClient(ts.URL, 5*time.Second, "en")
|
||||||
|
audio, err := c.Synthesize(context.Background(), "hello world", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(audio, expected) {
|
||||||
|
t.Errorf("audio = %x, want %x", audio, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTTSClient_SynthesizeWithSpeaker(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("speaker_id") != "alice" {
|
||||||
|
t.Errorf("speaker_id = %q", r.URL.Query().Get("speaker_id"))
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte{0x01})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewTTSClient(ts.URL, 5*time.Second, "en")
|
||||||
|
_, err := c.Synthesize(context.Background(), "hi", "en", "alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// STT client
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestSTTClient_Transcribe(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/audio/transcriptions" {
|
||||||
|
t.Errorf("path = %q", r.URL.Path)
|
||||||
|
}
|
||||||
|
ct := r.Header.Get("Content-Type")
|
||||||
|
if !strings.Contains(ct, "multipart/form-data") {
|
||||||
|
t.Errorf("content-type = %q", ct)
|
||||||
|
}
|
||||||
|
// Verify the audio file is present.
|
||||||
|
file, _, err := r.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
data, _ := io.ReadAll(file)
|
||||||
|
if len(data) != 100 {
|
||||||
|
t.Errorf("file size = %d, want 100", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"text": "hello world"})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewSTTClient(ts.URL, 5*time.Second)
|
||||||
|
result, err := c.Transcribe(context.Background(), make([]byte, 100), "en")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result.Text != "hello world" {
|
||||||
|
t.Errorf("text = %q", result.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSTTClient_TranscribeTranslate(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/audio/translations" {
|
||||||
|
t.Errorf("path = %q, want /v1/audio/translations", r.URL.Path)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"text": "translated"})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewSTTClient(ts.URL, 5*time.Second)
|
||||||
|
c.Task = "translate"
|
||||||
|
result, err := c.Transcribe(context.Background(), []byte{0x01}, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result.Text != "translated" {
|
||||||
|
t.Errorf("text = %q", result.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// HTTP error handling
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestHTTPError4xx(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(422)
|
||||||
|
_, _ = w.Write([]byte(`{"error": "bad input"}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewEmbeddingsClient(ts.URL, 5*time.Second, "")
|
||||||
|
_, err := c.Embed(context.Background(), []string{"test"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 422")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "422") {
|
||||||
|
t.Errorf("error should contain status code: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPError5xx(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
_, _ = w.Write([]byte("internal server error"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewLLMClient(ts.URL, 5*time.Second)
|
||||||
|
_, err := c.Generate(context.Background(), "q", "", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 500")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// buildMessages helper
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestBuildMessages(t *testing.T) {
|
||||||
|
// No context, no system prompt → just user message
|
||||||
|
msgs := buildMessages("hello", "", "")
|
||||||
|
if len(msgs) != 1 || msgs[0].Role != "user" {
|
||||||
|
t.Errorf("expected 1 user msg, got %+v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With system prompt
|
||||||
|
msgs = buildMessages("hello", "", "You are helpful")
|
||||||
|
if len(msgs) != 2 || msgs[0].Role != "system" || msgs[0].Content != "You are helpful" {
|
||||||
|
t.Errorf("expected system+user, got %+v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With context, no system prompt → auto system prompt
|
||||||
|
msgs = buildMessages("question", "some context", "")
|
||||||
|
if len(msgs) != 2 || msgs[0].Role != "system" {
|
||||||
|
t.Errorf("expected auto system+user, got %+v", msgs)
|
||||||
|
}
|
||||||
|
if !strings.Contains(msgs[1].Content, "Context:") {
|
||||||
|
t.Error("user message should contain context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Benchmarks: pooled buffer vs direct allocation
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkPostJSON(b *testing.B) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.Copy(io.Discard, r.Body)
|
||||||
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := newHTTPClient(ts.URL, 10*time.Second)
|
||||||
|
ctx := context.Background()
|
||||||
|
payload := map[string]any{
|
||||||
|
"text": strings.Repeat("x", 1024),
|
||||||
|
"count": 42,
|
||||||
|
"enabled": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = c.postJSON(ctx, "/test", payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBufferPool(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
buf := getBuf()
|
||||||
|
buf.WriteString(strings.Repeat("x", 4096))
|
||||||
|
putBuf(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBufferPoolParallel(b *testing.B) {
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
buf := getBuf()
|
||||||
|
buf.WriteString(strings.Repeat("x", 4096))
|
||||||
|
putBuf(buf)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
268
config/config.go
Normal file
268
config/config.go
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
// Package config provides environment-based configuration for handler services
|
||||||
|
// with optional live reload of secrets and service endpoints.
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Settings holds base configuration for all handler services.
|
||||||
|
// Fields in the "hot-reload" section are protected by a RWMutex and can be
|
||||||
|
// updated at runtime via WatchSecrets(). All other fields are immutable
|
||||||
|
// after Load() returns.
|
||||||
|
type Settings struct {
|
||||||
|
// Service identification (immutable)
|
||||||
|
ServiceName string
|
||||||
|
ServiceVersion string
|
||||||
|
ServiceNamespace string
|
||||||
|
DeploymentEnv string
|
||||||
|
|
||||||
|
// NATS configuration (immutable)
|
||||||
|
NATSURL string
|
||||||
|
NATSUser string
|
||||||
|
NATSPassword string
|
||||||
|
NATSQueueGroup string
|
||||||
|
|
||||||
|
// Redis/Valkey configuration (immutable)
|
||||||
|
RedisURL string
|
||||||
|
RedisPassword string
|
||||||
|
|
||||||
|
// Milvus configuration (immutable)
|
||||||
|
MilvusHost string
|
||||||
|
MilvusPort int
|
||||||
|
MilvusCollection string
|
||||||
|
|
||||||
|
// OpenTelemetry configuration (immutable)
|
||||||
|
OTELEnabled bool
|
||||||
|
OTELEndpoint string
|
||||||
|
OTELUseHTTP bool
|
||||||
|
|
||||||
|
// HyperDX configuration (immutable)
|
||||||
|
HyperDXEnabled bool
|
||||||
|
HyperDXAPIKey string
|
||||||
|
HyperDXEndpoint string
|
||||||
|
|
||||||
|
// MLflow configuration (immutable)
|
||||||
|
MLflowTrackingURI string
|
||||||
|
MLflowExperimentName string
|
||||||
|
MLflowEnabled bool
|
||||||
|
|
||||||
|
// Health check configuration (immutable)
|
||||||
|
HealthPort int
|
||||||
|
HealthPath string
|
||||||
|
ReadyPath string
|
||||||
|
|
||||||
|
// Timeouts (immutable)
|
||||||
|
HTTPTimeout time.Duration
|
||||||
|
NATSTimeout time.Duration
|
||||||
|
|
||||||
|
// Hot-reloadable fields — access via getter methods.
|
||||||
|
mu sync.RWMutex
|
||||||
|
embeddingsURL string
|
||||||
|
rerankerURL string
|
||||||
|
llmURL string
|
||||||
|
ttsURL string
|
||||||
|
sttURL string
|
||||||
|
|
||||||
|
// Secrets path for file-based hot reload (Kubernetes secret mounts)
|
||||||
|
SecretsPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load creates a Settings populated from environment variables with defaults.
|
||||||
|
func Load() *Settings {
|
||||||
|
return &Settings{
|
||||||
|
ServiceName: getEnv("SERVICE_NAME", "handler"),
|
||||||
|
ServiceVersion: getEnv("SERVICE_VERSION", "1.0.0"),
|
||||||
|
ServiceNamespace: getEnv("SERVICE_NAMESPACE", "ai-ml"),
|
||||||
|
DeploymentEnv: getEnv("DEPLOYMENT_ENV", "production"),
|
||||||
|
|
||||||
|
NATSURL: getEnv("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222"),
|
||||||
|
NATSUser: getEnv("NATS_USER", ""),
|
||||||
|
NATSPassword: getEnv("NATS_PASSWORD", ""),
|
||||||
|
NATSQueueGroup: getEnv("NATS_QUEUE_GROUP", ""),
|
||||||
|
|
||||||
|
RedisURL: getEnv("REDIS_URL", "redis://valkey.ai-ml.svc.cluster.local:6379"),
|
||||||
|
RedisPassword: getEnv("REDIS_PASSWORD", ""),
|
||||||
|
|
||||||
|
MilvusHost: getEnv("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local"),
|
||||||
|
MilvusPort: getEnvInt("MILVUS_PORT", 19530),
|
||||||
|
MilvusCollection: getEnv("MILVUS_COLLECTION", "documents"),
|
||||||
|
|
||||||
|
embeddingsURL: getEnv("EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
rerankerURL: getEnv("RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
llmURL: getEnv("LLM_URL", "http://vllm-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
ttsURL: getEnv("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
sttURL: getEnv("STT_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"),
|
||||||
|
|
||||||
|
OTELEnabled: getEnvBool("OTEL_ENABLED", true),
|
||||||
|
OTELEndpoint: getEnv("OTEL_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317"),
|
||||||
|
OTELUseHTTP: getEnvBool("OTEL_USE_HTTP", false),
|
||||||
|
|
||||||
|
HyperDXEnabled: getEnvBool("HYPERDX_ENABLED", false),
|
||||||
|
HyperDXAPIKey: getEnv("HYPERDX_API_KEY", ""),
|
||||||
|
HyperDXEndpoint: getEnv("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io"),
|
||||||
|
|
||||||
|
MLflowTrackingURI: getEnv("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"),
|
||||||
|
MLflowExperimentName: getEnv("MLFLOW_EXPERIMENT_NAME", ""),
|
||||||
|
MLflowEnabled: getEnvBool("MLFLOW_ENABLED", true),
|
||||||
|
|
||||||
|
HealthPort: getEnvInt("HEALTH_PORT", 8080),
|
||||||
|
HealthPath: getEnv("HEALTH_PATH", "/health"),
|
||||||
|
ReadyPath: getEnv("READY_PATH", "/ready"),
|
||||||
|
|
||||||
|
HTTPTimeout: getEnvDuration("HTTP_TIMEOUT", 60*time.Second),
|
||||||
|
NATSTimeout: getEnvDuration("NATS_TIMEOUT", 30*time.Second),
|
||||||
|
|
||||||
|
SecretsPath: getEnv("SECRETS_PATH", ""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingsURL returns the current embeddings service URL (thread-safe).
|
||||||
|
func (s *Settings) EmbeddingsURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.embeddingsURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// RerankerURL returns the current reranker service URL (thread-safe).
|
||||||
|
func (s *Settings) RerankerURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.rerankerURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// LLMURL returns the current LLM service URL (thread-safe).
|
||||||
|
func (s *Settings) LLMURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.llmURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSURL returns the current TTS service URL (thread-safe).
|
||||||
|
func (s *Settings) TTSURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.ttsURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTURL returns the current STT service URL (thread-safe).
|
||||||
|
func (s *Settings) STTURL() string {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.sttURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchSecrets watches the SecretsPath directory for changes and reloads
|
||||||
|
// hot-reloadable fields. Blocks until ctx is cancelled.
|
||||||
|
func (s *Settings) WatchSecrets(ctx context.Context) {
|
||||||
|
if s.SecretsPath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("config: failed to create fsnotify watcher", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = watcher.Close() }()
|
||||||
|
|
||||||
|
if err := watcher.Add(s.SecretsPath); err != nil {
|
||||||
|
slog.Error("config: failed to watch secrets path", "error", err, "path", s.SecretsPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("config: watching secrets for hot reload", "path", s.SecretsPath)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
}
|
||||||
|
case err, ok := <-watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Error("config: fsnotify error", "error", err)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadFromSecrets reads hot-reloadable values from the secrets directory.
|
||||||
|
func (s *Settings) reloadFromSecrets() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
updated := 0
|
||||||
|
reload := func(filename string, target *string) {
|
||||||
|
path := filepath.Join(s.SecretsPath, filename)
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val := strings.TrimSpace(string(data))
|
||||||
|
if val != "" && val != *target {
|
||||||
|
*target = val
|
||||||
|
updated++
|
||||||
|
slog.Info("config: reloaded secret", "key", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reload("embeddings-url", &s.embeddingsURL)
|
||||||
|
reload("reranker-url", &s.rerankerURL)
|
||||||
|
reload("llm-url", &s.llmURL)
|
||||||
|
reload("tts-url", &s.ttsURL)
|
||||||
|
reload("stt-url", &s.sttURL)
|
||||||
|
|
||||||
|
if updated > 0 {
|
||||||
|
slog.Info("config: secrets reloaded", "updated", updated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnv(key, fallback string) string {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnvInt(key string, fallback int) int {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
if i, err := strconv.Atoi(v); err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnvBool(key string, fallback bool) bool {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
if b, err := strconv.ParseBool(v); err == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnvDuration(key string, fallback time.Duration) time.Duration {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||||
|
return time.Duration(f * float64(time.Second))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
123
config/config_test.go
Normal file
123
config/config_test.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadDefaults(t *testing.T) {
|
||||||
|
s := Load()
|
||||||
|
if s.ServiceName != "handler" {
|
||||||
|
t.Errorf("expected default ServiceName 'handler', got %q", s.ServiceName)
|
||||||
|
}
|
||||||
|
if s.HealthPort != 8080 {
|
||||||
|
t.Errorf("expected default HealthPort 8080, got %d", s.HealthPort)
|
||||||
|
}
|
||||||
|
if s.HTTPTimeout != 60*time.Second {
|
||||||
|
t.Errorf("expected default HTTPTimeout 60s, got %v", s.HTTPTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadFromEnv(t *testing.T) {
|
||||||
|
t.Setenv("SERVICE_NAME", "test-svc")
|
||||||
|
t.Setenv("HEALTH_PORT", "9090")
|
||||||
|
t.Setenv("OTEL_ENABLED", "false")
|
||||||
|
|
||||||
|
s := Load()
|
||||||
|
if s.ServiceName != "test-svc" {
|
||||||
|
t.Errorf("expected ServiceName 'test-svc', got %q", s.ServiceName)
|
||||||
|
}
|
||||||
|
if s.HealthPort != 9090 {
|
||||||
|
t.Errorf("expected HealthPort 9090, got %d", s.HealthPort)
|
||||||
|
}
|
||||||
|
if s.OTELEnabled {
|
||||||
|
t.Error("expected OTELEnabled false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURLGetters(t *testing.T) {
|
||||||
|
s := Load()
|
||||||
|
if s.EmbeddingsURL() == "" {
|
||||||
|
t.Error("EmbeddingsURL should have a default")
|
||||||
|
}
|
||||||
|
if s.RerankerURL() == "" {
|
||||||
|
t.Error("RerankerURL should have a default")
|
||||||
|
}
|
||||||
|
if s.LLMURL() == "" {
|
||||||
|
t.Error("LLMURL should have a default")
|
||||||
|
}
|
||||||
|
if s.TTSURL() == "" {
|
||||||
|
t.Error("TTSURL should have a default")
|
||||||
|
}
|
||||||
|
if s.STTURL() == "" {
|
||||||
|
t.Error("STTURL should have a default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURLGettersFromEnv(t *testing.T) {
|
||||||
|
t.Setenv("EMBEDDINGS_URL", "http://embed:8000")
|
||||||
|
t.Setenv("LLM_URL", "http://llm:9000")
|
||||||
|
|
||||||
|
s := Load()
|
||||||
|
if s.EmbeddingsURL() != "http://embed:8000" {
|
||||||
|
t.Errorf("expected custom EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
|
}
|
||||||
|
if s.LLMURL() != "http://llm:9000" {
|
||||||
|
t.Errorf("expected custom LLMURL, got %q", s.LLMURL())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadFromSecrets(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Write initial secret files
|
||||||
|
writeSecret(t, dir, "embeddings-url", "http://old-embed:8000")
|
||||||
|
writeSecret(t, dir, "llm-url", "http://old-llm:9000")
|
||||||
|
|
||||||
|
s := Load()
|
||||||
|
s.SecretsPath = dir
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
|
if s.EmbeddingsURL() != "http://old-embed:8000" {
|
||||||
|
t.Errorf("expected reloaded EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
|
}
|
||||||
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
|
t.Errorf("expected reloaded LLMURL, got %q", s.LLMURL())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate secret update
|
||||||
|
writeSecret(t, dir, "embeddings-url", "http://new-embed:8000")
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
|
||||||
|
if s.EmbeddingsURL() != "http://new-embed:8000" {
|
||||||
|
t.Errorf("expected updated EmbeddingsURL, got %q", s.EmbeddingsURL())
|
||||||
|
}
|
||||||
|
// LLM should remain unchanged
|
||||||
|
if s.LLMURL() != "http://old-llm:9000" {
|
||||||
|
t.Errorf("expected unchanged LLMURL, got %q", s.LLMURL())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReloadFromSecretsNoPath(t *testing.T) {
|
||||||
|
s := Load()
|
||||||
|
s.SecretsPath = ""
|
||||||
|
// Should not panic
|
||||||
|
s.reloadFromSecrets()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetEnvDuration(t *testing.T) {
|
||||||
|
t.Setenv("TEST_DUR", "30")
|
||||||
|
d := getEnvDuration("TEST_DUR", 10*time.Second)
|
||||||
|
if d != 30*time.Second {
|
||||||
|
t.Errorf("expected 30s, got %v", d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSecret(t *testing.T, dir, name, value string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, name), []byte(value), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
40
go.mod
Normal file
40
go.mod
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
module git.daviestechlabs.io/daviestechlabs/handler-base
|
||||||
|
|
||||||
|
go 1.25.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/nats-io/nats.go v1.48.0
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||||
|
go.opentelemetry.io/otel v1.40.0
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
|
||||||
|
go.opentelemetry.io/otel/metric v1.40.0
|
||||||
|
go.opentelemetry.io/otel/sdk v1.40.0
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.40.0
|
||||||
|
go.opentelemetry.io/otel/trace v1.40.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect
|
||||||
|
github.com/klauspost/compress v1.18.0 // indirect
|
||||||
|
github.com/nats-io/nkeys v0.4.11 // indirect
|
||||||
|
github.com/nats-io/nuid v1.0.1 // indirect
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect
|
||||||
|
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||||
|
golang.org/x/crypto v0.47.0 // indirect
|
||||||
|
golang.org/x/net v0.49.0 // indirect
|
||||||
|
golang.org/x/sys v0.40.0 // indirect
|
||||||
|
golang.org/x/text v0.33.0 // indirect
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect
|
||||||
|
google.golang.org/grpc v1.78.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
|
)
|
||||||
79
go.sum
Normal file
79
go.sum
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII=
|
||||||
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
|
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
|
||||||
|
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
|
||||||
|
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
|
||||||
|
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
|
||||||
|
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||||
|
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||||
|
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||||
|
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||||
|
go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
|
||||||
|
go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 h1:NOyNnS19BF2SUDApbOKbDtWZ0IK7b8FJ2uAGdIWOGb0=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0/go.mod h1:VL6EgVikRLcJa9ftukrHu/ZkkhFBSo1lzvdBC9CF1ss=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs=
|
||||||
|
go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
|
||||||
|
go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
|
||||||
|
go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
|
||||||
|
go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
|
||||||
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
|
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||||
|
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||||
|
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||||
|
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
|
||||||
|
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||||
|
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||||
|
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||||
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||||
|
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||||
|
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||||
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
231
handler/handler.go
Normal file
231
handler/handler.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
// Package handler provides the base Handler pattern for NATS message-driven services.
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/health"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/natsutil"
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/telemetry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageHandler is the callback for processing decoded NATS messages.
|
||||||
|
// data is the msgpack-decoded map. Return a response map (or nil for no reply).
|
||||||
|
type MessageHandler func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error)
|
||||||
|
|
||||||
|
// TypedMessageHandler processes the raw NATS message without pre-decoding to
|
||||||
|
// map[string]any. Services unmarshal msg.Data into their own typed structs,
|
||||||
|
// avoiding the double-decode overhead. Return any msgpack-serialisable value
|
||||||
|
// (a typed struct, map, or nil for no reply).
|
||||||
|
type TypedMessageHandler func(ctx context.Context, msg *nats.Msg) (any, error)
|
||||||
|
|
||||||
|
// SetupFunc is called once before the handler starts processing messages.
|
||||||
|
type SetupFunc func(ctx context.Context) error
|
||||||
|
|
||||||
|
// TeardownFunc is called during graceful shutdown.
|
||||||
|
type TeardownFunc func(ctx context.Context) error
|
||||||
|
|
||||||
|
// Handler is the base service runner that wires NATS, health, and telemetry.
|
||||||
|
type Handler struct {
|
||||||
|
Settings *config.Settings
|
||||||
|
NATS *natsutil.Client
|
||||||
|
Telemetry *telemetry.Provider
|
||||||
|
Subject string
|
||||||
|
QueueGroup string
|
||||||
|
|
||||||
|
onSetup SetupFunc
|
||||||
|
onTeardown TeardownFunc
|
||||||
|
onMessage MessageHandler
|
||||||
|
onTypedMessage TypedMessageHandler
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a Handler for the given NATS subject.
|
||||||
|
func New(subject string, settings *config.Settings) *Handler {
|
||||||
|
if settings == nil {
|
||||||
|
settings = config.Load()
|
||||||
|
}
|
||||||
|
queueGroup := settings.NATSQueueGroup
|
||||||
|
|
||||||
|
natsOpts := []nats.Option{}
|
||||||
|
if settings.NATSUser != "" && settings.NATSPassword != "" {
|
||||||
|
natsOpts = append(natsOpts, nats.UserInfo(settings.NATSUser, settings.NATSPassword))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Handler{
|
||||||
|
Settings: settings,
|
||||||
|
Subject: subject,
|
||||||
|
QueueGroup: queueGroup,
|
||||||
|
NATS: natsutil.New(settings.NATSURL, natsOpts...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnSetup registers the setup callback.
|
||||||
|
func (h *Handler) OnSetup(fn SetupFunc) { h.onSetup = fn }
|
||||||
|
|
||||||
|
// OnTeardown registers the teardown callback.
|
||||||
|
func (h *Handler) OnTeardown(fn TeardownFunc) { h.onTeardown = fn }
|
||||||
|
|
||||||
|
// OnMessage registers the message handler callback.
|
||||||
|
func (h *Handler) OnMessage(fn MessageHandler) { h.onMessage = fn }
|
||||||
|
|
||||||
|
// OnTypedMessage registers a typed message handler. It replaces OnMessage —
|
||||||
|
// wrapHandler will skip the map[string]any decode and let the callback
|
||||||
|
// unmarshal msg.Data directly.
|
||||||
|
func (h *Handler) OnTypedMessage(fn TypedMessageHandler) { h.onTypedMessage = fn }
|
||||||
|
|
||||||
|
// Run starts the handler: telemetry, health server, NATS subscription, and blocks until SIGTERM/SIGINT.
|
||||||
|
func (h *Handler) Run() error {
|
||||||
|
// Structured logging
|
||||||
|
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
||||||
|
slog.Info("starting service", "name", h.Settings.ServiceName, "version", h.Settings.ServiceVersion)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Telemetry
|
||||||
|
tp, shutdown, err := telemetry.Setup(ctx, telemetry.Config{
|
||||||
|
ServiceName: h.Settings.ServiceName,
|
||||||
|
ServiceVersion: h.Settings.ServiceVersion,
|
||||||
|
ServiceNamespace: h.Settings.ServiceNamespace,
|
||||||
|
DeploymentEnv: h.Settings.DeploymentEnv,
|
||||||
|
Enabled: h.Settings.OTELEnabled,
|
||||||
|
Endpoint: h.Settings.OTELEndpoint,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("telemetry setup: %w", err)
|
||||||
|
}
|
||||||
|
defer shutdown(ctx)
|
||||||
|
h.Telemetry = tp
|
||||||
|
|
||||||
|
// Health server
|
||||||
|
healthSrv := health.New(
|
||||||
|
h.Settings.HealthPort,
|
||||||
|
h.Settings.HealthPath,
|
||||||
|
h.Settings.ReadyPath,
|
||||||
|
func() bool { return h.running && h.NATS.IsConnected() },
|
||||||
|
)
|
||||||
|
healthSrv.Start()
|
||||||
|
defer healthSrv.Stop(ctx)
|
||||||
|
|
||||||
|
// Connect to NATS
|
||||||
|
if err := h.NATS.Connect(); err != nil {
|
||||||
|
return fmt.Errorf("nats: %w", err)
|
||||||
|
}
|
||||||
|
defer h.NATS.Close()
|
||||||
|
|
||||||
|
// User setup
|
||||||
|
if h.onSetup != nil {
|
||||||
|
slog.Info("running service setup")
|
||||||
|
if err := h.onSetup(ctx); err != nil {
|
||||||
|
return fmt.Errorf("setup: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
if h.onMessage == nil && h.onTypedMessage == nil {
|
||||||
|
return fmt.Errorf("no message handler registered")
|
||||||
|
}
|
||||||
|
if err := h.NATS.Subscribe(h.Subject, h.wrapHandler(ctx), h.QueueGroup); err != nil {
|
||||||
|
return fmt.Errorf("subscribe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.running = true
|
||||||
|
slog.Info("handler ready", "subject", h.Subject)
|
||||||
|
|
||||||
|
// Wait for shutdown signal
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
<-sigCh
|
||||||
|
|
||||||
|
slog.Info("shutting down")
|
||||||
|
h.running = false
|
||||||
|
|
||||||
|
// Teardown
|
||||||
|
if h.onTeardown != nil {
|
||||||
|
if err := h.onTeardown(ctx); err != nil {
|
||||||
|
slog.Warn("teardown error", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("shutdown complete")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapHandler creates a nats.MsgHandler that dispatches to the registered callback.
|
||||||
|
// If OnTypedMessage was used, msg.Data is passed directly without map decode.
|
||||||
|
// If OnMessage was used, msg.Data is decoded to map[string]any first.
|
||||||
|
func (h *Handler) wrapHandler(ctx context.Context) nats.MsgHandler {
|
||||||
|
if h.onTypedMessage != nil {
|
||||||
|
return h.wrapTypedHandler(ctx)
|
||||||
|
}
|
||||||
|
return h.wrapMapHandler(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapTypedHandler dispatches to the TypedMessageHandler (no map decode).
|
||||||
|
func (h *Handler) wrapTypedHandler(ctx context.Context) nats.MsgHandler {
|
||||||
|
return func(msg *nats.Msg) {
|
||||||
|
response, err := h.onTypedMessage(ctx, msg)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
||||||
|
if msg.Reply != "" {
|
||||||
|
_ = h.NATS.Publish(msg.Reply, map[string]any{
|
||||||
|
"error": true,
|
||||||
|
"message": err.Error(),
|
||||||
|
"type": fmt.Sprintf("%T", err),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if response != nil && msg.Reply != "" {
|
||||||
|
if err := h.NATS.Publish(msg.Reply, response); err != nil {
|
||||||
|
slog.Error("failed to publish reply", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapMapHandler dispatches to the legacy MessageHandler (decodes to map first).
|
||||||
|
func (h *Handler) wrapMapHandler(ctx context.Context) nats.MsgHandler {
|
||||||
|
return func(msg *nats.Msg) {
|
||||||
|
data, err := natsutil.DecodeMsgpackMap(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to decode message", "subject", msg.Subject, "error", err)
|
||||||
|
if msg.Reply != "" {
|
||||||
|
_ = h.NATS.Publish(msg.Reply, map[string]any{
|
||||||
|
"error": true,
|
||||||
|
"message": err.Error(),
|
||||||
|
"type": "DecodeError",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := h.onMessage(ctx, msg, data)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("handler error", "subject", msg.Subject, "error", err)
|
||||||
|
if msg.Reply != "" {
|
||||||
|
_ = h.NATS.Publish(msg.Reply, map[string]any{
|
||||||
|
"error": true,
|
||||||
|
"message": err.Error(),
|
||||||
|
"type": fmt.Sprintf("%T", err),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if response != nil && msg.Reply != "" {
|
||||||
|
if err := h.NATS.Publish(msg.Reply, response); err != nil {
|
||||||
|
slog.Error("failed to publish reply", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
311
handler/handler_test.go
Normal file
311
handler/handler_test.go
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
|
||||||
|
"git.daviestechlabs.io/daviestechlabs/handler-base/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Handler construction tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestNewHandler(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
cfg.ServiceName = "test-handler"
|
||||||
|
cfg.NATSQueueGroup = "test-group"
|
||||||
|
|
||||||
|
h := New("ai.test.subject", cfg)
|
||||||
|
if h.Subject != "ai.test.subject" {
|
||||||
|
t.Errorf("Subject = %q", h.Subject)
|
||||||
|
}
|
||||||
|
if h.QueueGroup != "test-group" {
|
||||||
|
t.Errorf("QueueGroup = %q", h.QueueGroup)
|
||||||
|
}
|
||||||
|
if h.Settings.ServiceName != "test-handler" {
|
||||||
|
t.Errorf("ServiceName = %q", h.Settings.ServiceName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHandlerNilSettings(t *testing.T) {
|
||||||
|
h := New("ai.test", nil)
|
||||||
|
if h.Settings == nil {
|
||||||
|
t.Fatal("Settings should be loaded automatically")
|
||||||
|
}
|
||||||
|
if h.Settings.ServiceName != "handler" {
|
||||||
|
t.Errorf("ServiceName = %q, want default", h.Settings.ServiceName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackRegistration(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
setupCalled := false
|
||||||
|
h.OnSetup(func(ctx context.Context) error {
|
||||||
|
setupCalled = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
teardownCalled := false
|
||||||
|
h.OnTeardown(func(ctx context.Context) error {
|
||||||
|
teardownCalled = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if h.onSetup == nil || h.onTeardown == nil || h.onMessage == nil {
|
||||||
|
t.Error("callbacks should not be nil after registration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify setup/teardown work when called directly.
|
||||||
|
_ = h.onSetup(context.Background())
|
||||||
|
_ = h.onTeardown(context.Background())
|
||||||
|
if !setupCalled || !teardownCalled {
|
||||||
|
t.Error("callbacks should have been invoked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypedMessageRegistration(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if h.onTypedMessage == nil {
|
||||||
|
t.Error("onTypedMessage should not be nil after registration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// wrapHandler dispatch tests (unit test the message decode + dispatch logic)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWrapHandler_ValidMessage(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
var receivedData map[string]any
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
receivedData = data
|
||||||
|
return map[string]any{"status": "ok"}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Encode a message the same way services would.
|
||||||
|
payload := map[string]any{
|
||||||
|
"request_id": "test-001",
|
||||||
|
"message": "hello",
|
||||||
|
"premium": true,
|
||||||
|
}
|
||||||
|
encoded, err := msgpack.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call wrapHandler directly without NATS.
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test.user.42.message",
|
||||||
|
Data: encoded,
|
||||||
|
})
|
||||||
|
|
||||||
|
if receivedData == nil {
|
||||||
|
t.Fatal("handler was not called")
|
||||||
|
}
|
||||||
|
if receivedData["request_id"] != "test-001" {
|
||||||
|
t.Errorf("request_id = %v", receivedData["request_id"])
|
||||||
|
}
|
||||||
|
if receivedData["premium"] != true {
|
||||||
|
t.Errorf("premium = %v", receivedData["premium"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapHandler_InvalidMsgpack(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
handlerCalled = true
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test",
|
||||||
|
Data: []byte{0xFF, 0xFE, 0xFD}, // invalid msgpack
|
||||||
|
})
|
||||||
|
|
||||||
|
if handlerCalled {
|
||||||
|
t.Error("handler should not be called for invalid msgpack")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapHandler_HandlerError(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
|
// Should not panic even when handler returns error.
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test",
|
||||||
|
Data: encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapHandler_NilResponse(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return nil, nil // fire-and-forget style
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
|
// Should not panic with nil response and no reply subject.
|
||||||
|
handler(&nats.Msg{
|
||||||
|
Subject: "ai.test",
|
||||||
|
Data: encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// wrapHandler dispatch tests — typed handler path
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWrapTypedHandler_ValidMessage(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
type testReq struct {
|
||||||
|
RequestID string `msgpack:"request_id"`
|
||||||
|
Message string `msgpack:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var received testReq
|
||||||
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
|
if err := msgpack.Unmarshal(msg.Data, &received); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return map[string]any{"status": "ok"}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{
|
||||||
|
"request_id": "typed-001",
|
||||||
|
"message": "hello typed",
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
|
|
||||||
|
if received.RequestID != "typed-001" {
|
||||||
|
t.Errorf("RequestID = %q", received.RequestID)
|
||||||
|
}
|
||||||
|
if received.Message != "hello typed" {
|
||||||
|
t.Errorf("Message = %q", received.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapTypedHandler_Error(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
|
return nil, context.DeadlineExceeded
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{"key": "val"})
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
|
||||||
|
// Should not panic.
|
||||||
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapTypedHandler_NilResponse(t *testing.T) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
|
||||||
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
encoded, _ := msgpack.Marshal(map[string]any{"x": 1})
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
handler(&nats.Msg{Subject: "ai.test", Data: encoded})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Benchmark: message decode + dispatch overhead
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkWrapHandler(b *testing.B) {
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
h.OnMessage(func(ctx context.Context, msg *nats.Msg, data map[string]any) (map[string]any, error) {
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"request_id": "bench-001",
|
||||||
|
"message": "What is the capital of France?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
}
|
||||||
|
encoded, _ := msgpack.Marshal(payload)
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
handler(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWrapTypedHandler(b *testing.B) {
|
||||||
|
type benchReq struct {
|
||||||
|
RequestID string `msgpack:"request_id"`
|
||||||
|
Message string `msgpack:"message"`
|
||||||
|
Premium bool `msgpack:"premium"`
|
||||||
|
TopK int `msgpack:"top_k"`
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.Load()
|
||||||
|
h := New("ai.test", cfg)
|
||||||
|
h.OnTypedMessage(func(ctx context.Context, msg *nats.Msg) (any, error) {
|
||||||
|
var req benchReq
|
||||||
|
_ = msgpack.Unmarshal(msg.Data, &req)
|
||||||
|
return map[string]any{"ok": true}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"request_id": "bench-001",
|
||||||
|
"message": "What is the capital of France?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
}
|
||||||
|
encoded, _ := msgpack.Marshal(payload)
|
||||||
|
handler := h.wrapHandler(context.Background())
|
||||||
|
msg := &nats.Msg{Subject: "ai.test", Data: encoded}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
handler(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
"""
|
|
||||||
Handler Base - Shared utilities for AI/ML handler services.
|
|
||||||
|
|
||||||
Provides consistent patterns for:
|
|
||||||
- OpenTelemetry tracing and metrics
|
|
||||||
- NATS messaging
|
|
||||||
- Health checks
|
|
||||||
- Graceful shutdown
|
|
||||||
- Service client wrappers
|
|
||||||
"""
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
from handler_base.handler import Handler
|
|
||||||
from handler_base.health import HealthServer
|
|
||||||
from handler_base.nats_client import NATSClient
|
|
||||||
from handler_base.telemetry import get_meter, get_tracer, setup_telemetry
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Handler",
|
|
||||||
"Settings",
|
|
||||||
"HealthServer",
|
|
||||||
"NATSClient",
|
|
||||||
"setup_telemetry",
|
|
||||||
"get_tracer",
|
|
||||||
"get_meter",
|
|
||||||
]
|
|
||||||
|
|
||||||
__version__ = "1.0.0"
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
"""
|
|
||||||
Service client wrappers for AI/ML backends.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from handler_base.clients.embeddings import EmbeddingsClient
|
|
||||||
from handler_base.clients.llm import LLMClient
|
|
||||||
from handler_base.clients.milvus import MilvusClient
|
|
||||||
from handler_base.clients.reranker import RerankerClient
|
|
||||||
from handler_base.clients.stt import STTClient
|
|
||||||
from handler_base.clients.tts import TTSClient
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"EmbeddingsClient",
|
|
||||||
"RerankerClient",
|
|
||||||
"LLMClient",
|
|
||||||
"TTSClient",
|
|
||||||
"STTClient",
|
|
||||||
"MilvusClient",
|
|
||||||
]
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
"""
|
|
||||||
Embeddings service client (Infinity/BGE).
|
|
||||||
|
|
||||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from handler_base.config import EmbeddingsSettings
|
|
||||||
from handler_base.ray_utils import get_ray_handle
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsClient:
|
|
||||||
"""
|
|
||||||
Client for the embeddings service (Infinity with BGE models).
|
|
||||||
|
|
||||||
When running inside Ray, automatically uses Ray handles for faster
|
|
||||||
internal communication. Falls back to HTTP for external calls.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
client = EmbeddingsClient()
|
|
||||||
embeddings = await client.embed(["Hello world"])
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Ray Serve deployment configuration
|
|
||||||
RAY_DEPLOYMENT_NAME = "EmbeddingsDeployment"
|
|
||||||
RAY_APP_NAME = "embeddings"
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
|
|
||||||
self.settings = settings or EmbeddingsSettings()
|
|
||||||
self._client = httpx.AsyncClient(
|
|
||||||
base_url=self.settings.embeddings_url,
|
|
||||||
timeout=self.settings.http_timeout,
|
|
||||||
)
|
|
||||||
self._ray_handle: Optional[Any] = None
|
|
||||||
self._ray_checked = False
|
|
||||||
|
|
||||||
def _get_ray_handle(self) -> Optional[Any]:
|
|
||||||
"""Get Ray handle, checking only once."""
|
|
||||||
if not self._ray_checked:
|
|
||||||
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
|
|
||||||
self._ray_checked = True
|
|
||||||
return self._ray_handle
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the HTTP client."""
|
|
||||||
await self._client.aclose()
|
|
||||||
|
|
||||||
async def embed(
|
|
||||||
self,
|
|
||||||
texts: list[str],
|
|
||||||
model: Optional[str] = None,
|
|
||||||
) -> list[list[float]]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of texts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to embed
|
|
||||||
model: Model name (defaults to settings)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embedding vectors
|
|
||||||
"""
|
|
||||||
model = model or self.settings.embeddings_model
|
|
||||||
|
|
||||||
with create_span("embeddings.embed") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("embeddings.model", model)
|
|
||||||
span.set_attribute("embeddings.batch_size", len(texts))
|
|
||||||
|
|
||||||
# Try Ray handle first (faster internal path)
|
|
||||||
handle = self._get_ray_handle()
|
|
||||||
if handle:
|
|
||||||
try:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("embeddings.transport", "ray")
|
|
||||||
result = await handle.embed.remote(texts, model)
|
|
||||||
if span and result:
|
|
||||||
span.set_attribute("embeddings.dimensions", len(result[0]) if result else 0)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
|
||||||
|
|
||||||
# HTTP fallback
|
|
||||||
if span:
|
|
||||||
span.set_attribute("embeddings.transport", "http")
|
|
||||||
|
|
||||||
response = await self._client.post(
|
|
||||||
"/embeddings",
|
|
||||||
json={"input": texts, "model": model},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
embeddings = [d["embedding"] for d in result.get("data", [])]
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
|
|
||||||
"""
|
|
||||||
Generate embedding for a single text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to embed
|
|
||||||
model: Model name (defaults to settings)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embedding vector
|
|
||||||
"""
|
|
||||||
embeddings = await self.embed([text], model)
|
|
||||||
return embeddings[0] if embeddings else []
|
|
||||||
|
|
||||||
async def health(self) -> bool:
|
|
||||||
"""Check if the embeddings service is healthy."""
|
|
||||||
try:
|
|
||||||
response = await self._client.get("/health")
|
|
||||||
return response.status_code == 200
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
@@ -1,233 +0,0 @@
|
|||||||
"""
|
|
||||||
LLM service client (vLLM/OpenAI-compatible).
|
|
||||||
|
|
||||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, AsyncIterator, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from handler_base.config import LLMSettings
|
|
||||||
from handler_base.ray_utils import get_ray_handle
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
|
||||||
"""
|
|
||||||
Client for the LLM service (vLLM with OpenAI-compatible API).
|
|
||||||
|
|
||||||
When running inside Ray, automatically uses Ray handles for faster
|
|
||||||
internal communication. Falls back to HTTP for external calls.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
client = LLMClient()
|
|
||||||
response = await client.generate("Hello, how are you?")
|
|
||||||
|
|
||||||
# With context for RAG
|
|
||||||
response = await client.generate(
|
|
||||||
"What is the capital?",
|
|
||||||
context="France is a country in Europe..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Streaming
|
|
||||||
async for chunk in client.stream("Tell me a story"):
|
|
||||||
print(chunk, end="")
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Ray Serve deployment configuration
|
|
||||||
RAY_DEPLOYMENT_NAME = "VLLMDeployment"
|
|
||||||
RAY_APP_NAME = "llm"
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[LLMSettings] = None):
|
|
||||||
self.settings = settings or LLMSettings()
|
|
||||||
self._client = httpx.AsyncClient(
|
|
||||||
base_url=self.settings.llm_url,
|
|
||||||
timeout=self.settings.http_timeout,
|
|
||||||
)
|
|
||||||
self._ray_handle: Optional[Any] = None
|
|
||||||
self._ray_checked = False
|
|
||||||
|
|
||||||
def _get_ray_handle(self) -> Optional[Any]:
|
|
||||||
"""Get Ray handle, checking only once."""
|
|
||||||
if not self._ray_checked:
|
|
||||||
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
|
|
||||||
self._ray_checked = True
|
|
||||||
return self._ray_handle
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the HTTP client."""
|
|
||||||
await self._client.aclose()
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
context: Optional[str] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate a response from the LLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: User prompt/query
|
|
||||||
context: Optional context for RAG
|
|
||||||
system_prompt: Optional system prompt
|
|
||||||
max_tokens: Maximum tokens to generate
|
|
||||||
temperature: Sampling temperature
|
|
||||||
top_p: Top-p sampling
|
|
||||||
stop: Stop sequences
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generated text response
|
|
||||||
"""
|
|
||||||
with create_span("llm.generate") as span:
|
|
||||||
messages = self._build_messages(prompt, context, system_prompt)
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("llm.model", self.settings.llm_model)
|
|
||||||
span.set_attribute("llm.prompt_length", len(prompt))
|
|
||||||
if context:
|
|
||||||
span.set_attribute("llm.context_length", len(context))
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": self.settings.llm_model,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": max_tokens or self.settings.llm_max_tokens,
|
|
||||||
"temperature": temperature or self.settings.llm_temperature,
|
|
||||||
"top_p": top_p or self.settings.llm_top_p,
|
|
||||||
}
|
|
||||||
if stop:
|
|
||||||
payload["stop"] = stop
|
|
||||||
|
|
||||||
# Try Ray handle first (faster internal path)
|
|
||||||
handle = self._get_ray_handle()
|
|
||||||
if handle:
|
|
||||||
try:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("llm.transport", "ray")
|
|
||||||
result = await handle.remote(payload)
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
if span:
|
|
||||||
span.set_attribute("llm.response_length", len(content))
|
|
||||||
return content
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
|
||||||
|
|
||||||
# HTTP fallback
|
|
||||||
if span:
|
|
||||||
span.set_attribute("llm.transport", "http")
|
|
||||||
|
|
||||||
response = await self._client.post("/v1/chat/completions", json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
content = result["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("llm.response_length", len(content))
|
|
||||||
usage = result.get("usage", {})
|
|
||||||
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
|
|
||||||
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def stream(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
context: Optional[str] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
) -> AsyncIterator[str]:
|
|
||||||
"""
|
|
||||||
Stream a response from the LLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: User prompt/query
|
|
||||||
context: Optional context for RAG
|
|
||||||
system_prompt: Optional system prompt
|
|
||||||
max_tokens: Maximum tokens to generate
|
|
||||||
temperature: Sampling temperature
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Text chunks as they're generated
|
|
||||||
"""
|
|
||||||
messages = self._build_messages(prompt, context, system_prompt)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": self.settings.llm_model,
|
|
||||||
"messages": messages,
|
|
||||||
"max_tokens": max_tokens or self.settings.llm_max_tokens,
|
|
||||||
"temperature": temperature or self.settings.llm_temperature,
|
|
||||||
"stream": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if line.startswith("data: "):
|
|
||||||
data = line[6:]
|
|
||||||
if data == "[DONE]":
|
|
||||||
break
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
chunk = json.loads(data)
|
|
||||||
delta = chunk["choices"][0].get("delta", {})
|
|
||||||
content = delta.get("content", "")
|
|
||||||
if content:
|
|
||||||
yield content
|
|
||||||
|
|
||||||
def _build_messages(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
context: Optional[str] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Build the messages list for the API call."""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# System prompt
|
|
||||||
if system_prompt:
|
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
|
||||||
elif context:
|
|
||||||
# Default RAG system prompt
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": (
|
|
||||||
"You are a helpful assistant. Use the provided context to answer "
|
|
||||||
"the user's question. If the context doesn't contain relevant "
|
|
||||||
"information, say so."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add context as a separate message if provided
|
|
||||||
if context:
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
messages.append({"role": "user", "content": prompt})
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
async def health(self) -> bool:
|
|
||||||
"""Check if the LLM service is healthy."""
|
|
||||||
try:
|
|
||||||
response = await self._client.get("/health")
|
|
||||||
return response.status_code == 200
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
@@ -1,183 +0,0 @@
|
|||||||
"""
|
|
||||||
Milvus vector database client.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pymilvus import Collection, connections, utility
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MilvusClient:
|
|
||||||
"""
|
|
||||||
Client for Milvus vector database.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
client = MilvusClient()
|
|
||||||
await client.connect()
|
|
||||||
results = await client.search(embedding, limit=10)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[Settings] = None):
|
|
||||||
self.settings = settings or Settings()
|
|
||||||
self._connected = False
|
|
||||||
self._collection: Optional[Collection] = None
|
|
||||||
|
|
||||||
async def connect(self, collection_name: Optional[str] = None) -> None:
|
|
||||||
"""
|
|
||||||
Connect to Milvus and load collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
collection_name: Collection to use (defaults to settings)
|
|
||||||
"""
|
|
||||||
collection_name = collection_name or self.settings.milvus_collection
|
|
||||||
|
|
||||||
connections.connect(
|
|
||||||
alias="default",
|
|
||||||
host=self.settings.milvus_host,
|
|
||||||
port=self.settings.milvus_port,
|
|
||||||
)
|
|
||||||
|
|
||||||
if utility.has_collection(collection_name):
|
|
||||||
self._collection = Collection(collection_name)
|
|
||||||
self._collection.load()
|
|
||||||
logger.info(f"Connected to Milvus collection: {collection_name}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"Collection {collection_name} does not exist")
|
|
||||||
|
|
||||||
self._connected = True
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close Milvus connection."""
|
|
||||||
if self._collection:
|
|
||||||
self._collection.release()
|
|
||||||
connections.disconnect("default")
|
|
||||||
self._connected = False
|
|
||||||
logger.info("Disconnected from Milvus")
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
embedding: list[float],
|
|
||||||
limit: int = 10,
|
|
||||||
output_fields: Optional[list[str]] = None,
|
|
||||||
filter_expr: Optional[str] = None,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Search for similar vectors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding: Query embedding vector
|
|
||||||
limit: Maximum number of results
|
|
||||||
output_fields: Fields to return (default: all)
|
|
||||||
filter_expr: Optional filter expression
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of results with 'id', 'distance', and requested fields
|
|
||||||
"""
|
|
||||||
if not self._collection:
|
|
||||||
raise RuntimeError("Not connected to collection")
|
|
||||||
|
|
||||||
with create_span("milvus.search") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("milvus.collection", self._collection.name)
|
|
||||||
span.set_attribute("milvus.limit", limit)
|
|
||||||
|
|
||||||
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
|
|
||||||
|
|
||||||
results = self._collection.search(
|
|
||||||
data=[embedding],
|
|
||||||
anns_field="embedding",
|
|
||||||
param=search_params,
|
|
||||||
limit=limit,
|
|
||||||
output_fields=output_fields,
|
|
||||||
expr=filter_expr,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to list of dicts
|
|
||||||
hits = []
|
|
||||||
for hit in results[0]:
|
|
||||||
item = {
|
|
||||||
"id": hit.id,
|
|
||||||
"distance": hit.distance,
|
|
||||||
"score": 1 - hit.distance, # Convert distance to similarity
|
|
||||||
}
|
|
||||||
# Add output fields
|
|
||||||
if output_fields:
|
|
||||||
for field in output_fields:
|
|
||||||
if hasattr(hit.entity, field):
|
|
||||||
item[field] = getattr(hit.entity, field)
|
|
||||||
hits.append(item)
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("milvus.results", len(hits))
|
|
||||||
|
|
||||||
return hits
|
|
||||||
|
|
||||||
async def search_with_texts(
|
|
||||||
self,
|
|
||||||
embedding: list[float],
|
|
||||||
limit: int = 10,
|
|
||||||
text_field: str = "text",
|
|
||||||
metadata_fields: Optional[list[str]] = None,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Search and return text content with metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding: Query embedding
|
|
||||||
limit: Maximum results
|
|
||||||
text_field: Name of text field in collection
|
|
||||||
metadata_fields: Additional metadata fields to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of results with text and metadata
|
|
||||||
"""
|
|
||||||
output_fields = [text_field]
|
|
||||||
if metadata_fields:
|
|
||||||
output_fields.extend(metadata_fields)
|
|
||||||
|
|
||||||
return await self.search(embedding, limit, output_fields)
|
|
||||||
|
|
||||||
async def insert(
|
|
||||||
self,
|
|
||||||
embeddings: list[list[float]],
|
|
||||||
data: list[dict],
|
|
||||||
) -> list[int]:
|
|
||||||
"""
|
|
||||||
Insert vectors with data into the collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embeddings: List of embedding vectors
|
|
||||||
data: List of dicts with field values
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of inserted IDs
|
|
||||||
"""
|
|
||||||
if not self._collection:
|
|
||||||
raise RuntimeError("Not connected to collection")
|
|
||||||
|
|
||||||
with create_span("milvus.insert") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("milvus.collection", self._collection.name)
|
|
||||||
span.set_attribute("milvus.count", len(embeddings))
|
|
||||||
|
|
||||||
# Build insert data
|
|
||||||
insert_data = [embeddings]
|
|
||||||
for field in self._collection.schema.fields:
|
|
||||||
if field.name not in ("id", "embedding"):
|
|
||||||
field_values = [d.get(field.name) for d in data]
|
|
||||||
insert_data.append(field_values)
|
|
||||||
|
|
||||||
result = self._collection.insert(insert_data)
|
|
||||||
self._collection.flush()
|
|
||||||
|
|
||||||
return result.primary_keys
|
|
||||||
|
|
||||||
def health(self) -> bool:
|
|
||||||
"""Check if connected to Milvus."""
|
|
||||||
return self._connected and utility.get_connection_addr("default") is not None
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
"""
|
|
||||||
Reranker service client (Infinity/BGE Reranker).
|
|
||||||
|
|
||||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
from handler_base.ray_utils import get_ray_handle
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RerankerClient:
|
|
||||||
"""
|
|
||||||
Client for the reranker service (Infinity with BGE Reranker).
|
|
||||||
|
|
||||||
When running inside Ray, automatically uses Ray handles for faster
|
|
||||||
internal communication. Falls back to HTTP for external calls.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
client = RerankerClient()
|
|
||||||
reranked = await client.rerank("query", ["doc1", "doc2"])
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Ray Serve deployment configuration
|
|
||||||
RAY_DEPLOYMENT_NAME = "RerankerDeployment"
|
|
||||||
RAY_APP_NAME = "reranker"
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[Settings] = None):
|
|
||||||
self.settings = settings or Settings()
|
|
||||||
self._client = httpx.AsyncClient(
|
|
||||||
base_url=self.settings.reranker_url,
|
|
||||||
timeout=self.settings.http_timeout,
|
|
||||||
)
|
|
||||||
self._ray_handle: Optional[Any] = None
|
|
||||||
self._ray_checked = False
|
|
||||||
|
|
||||||
def _get_ray_handle(self) -> Optional[Any]:
|
|
||||||
"""Get Ray handle, checking only once."""
|
|
||||||
if not self._ray_checked:
|
|
||||||
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
|
|
||||||
self._ray_checked = True
|
|
||||||
return self._ray_handle
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the HTTP client."""
|
|
||||||
await self._client.aclose()
|
|
||||||
|
|
||||||
async def rerank(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
documents: list[str],
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Rerank documents based on relevance to query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Query text
|
|
||||||
documents: List of documents to rerank
|
|
||||||
top_k: Number of top results to return (default: all)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with 'index', 'score', and 'document' keys,
|
|
||||||
sorted by relevance score descending.
|
|
||||||
"""
|
|
||||||
with create_span("reranker.rerank") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("reranker.num_documents", len(documents))
|
|
||||||
if top_k:
|
|
||||||
span.set_attribute("reranker.top_k", top_k)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"query": query,
|
|
||||||
"documents": documents,
|
|
||||||
}
|
|
||||||
if top_k:
|
|
||||||
payload["top_n"] = top_k
|
|
||||||
|
|
||||||
# Try Ray handle first (faster internal path)
|
|
||||||
handle = self._get_ray_handle()
|
|
||||||
if handle:
|
|
||||||
try:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("reranker.transport", "ray")
|
|
||||||
results = await handle.rerank.remote(query, documents, top_k)
|
|
||||||
# Enrich with original documents
|
|
||||||
enriched = []
|
|
||||||
for r in results:
|
|
||||||
idx = r.get("index", 0)
|
|
||||||
enriched.append(
|
|
||||||
{
|
|
||||||
"index": idx,
|
|
||||||
"score": r.get("relevance_score", r.get("score", 0)),
|
|
||||||
"document": documents[idx] if idx < len(documents) else "",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return enriched
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
|
||||||
|
|
||||||
# HTTP fallback
|
|
||||||
if span:
|
|
||||||
span.set_attribute("reranker.transport", "http")
|
|
||||||
|
|
||||||
response = await self._client.post("/rerank", json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
results = result.get("results", [])
|
|
||||||
|
|
||||||
# Enrich with original documents
|
|
||||||
enriched = []
|
|
||||||
for r in results:
|
|
||||||
idx = r.get("index", 0)
|
|
||||||
enriched.append(
|
|
||||||
{
|
|
||||||
"index": idx,
|
|
||||||
"score": r.get("relevance_score", r.get("score", 0)),
|
|
||||||
"document": documents[idx] if idx < len(documents) else "",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return enriched
|
|
||||||
|
|
||||||
async def rerank_with_metadata(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
documents: list[dict],
|
|
||||||
text_key: str = "text",
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Rerank documents with metadata, preserving metadata in results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Query text
|
|
||||||
documents: List of dicts with text and metadata
|
|
||||||
text_key: Key containing text in each document dict
|
|
||||||
top_k: Number of top results to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Reranked documents with original metadata preserved.
|
|
||||||
"""
|
|
||||||
texts = [d.get(text_key, "") for d in documents]
|
|
||||||
reranked = await self.rerank(query, texts, top_k)
|
|
||||||
|
|
||||||
# Merge back metadata
|
|
||||||
for r in reranked:
|
|
||||||
idx = r["index"]
|
|
||||||
if idx < len(documents):
|
|
||||||
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
|
|
||||||
|
|
||||||
return reranked
|
|
||||||
|
|
||||||
async def health(self) -> bool:
|
|
||||||
"""Check if the reranker service is healthy."""
|
|
||||||
try:
|
|
||||||
response = await self._client.get("/health")
|
|
||||||
return response.status_code == 200
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
"""
|
|
||||||
STT service client (Whisper/faster-whisper).
|
|
||||||
|
|
||||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from handler_base.config import STTSettings
|
|
||||||
from handler_base.ray_utils import get_ray_handle
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class STTClient:
|
|
||||||
"""
|
|
||||||
Client for the STT service (Whisper/faster-whisper).
|
|
||||||
|
|
||||||
When running inside Ray, automatically uses Ray handles for faster
|
|
||||||
internal communication. Falls back to HTTP for external calls.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
client = STTClient()
|
|
||||||
text = await client.transcribe(audio_bytes)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Ray Serve deployment configuration
|
|
||||||
RAY_DEPLOYMENT_NAME = "WhisperDeployment"
|
|
||||||
RAY_APP_NAME = "whisper"
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[STTSettings] = None):
|
|
||||||
self.settings = settings or STTSettings()
|
|
||||||
self._client = httpx.AsyncClient(
|
|
||||||
base_url=self.settings.stt_url,
|
|
||||||
timeout=180.0, # Transcription can be slow
|
|
||||||
)
|
|
||||||
self._ray_handle: Optional[Any] = None
|
|
||||||
self._ray_checked = False
|
|
||||||
|
|
||||||
def _get_ray_handle(self) -> Optional[Any]:
|
|
||||||
"""Get Ray handle, checking only once."""
|
|
||||||
if not self._ray_checked:
|
|
||||||
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
|
|
||||||
self._ray_checked = True
|
|
||||||
return self._ray_handle
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the HTTP client."""
|
|
||||||
await self._client.aclose()
|
|
||||||
|
|
||||||
async def transcribe(
|
|
||||||
self,
|
|
||||||
audio: bytes,
|
|
||||||
language: Optional[str] = None,
|
|
||||||
task: Optional[str] = None,
|
|
||||||
response_format: str = "json",
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Transcribe audio to text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: Audio bytes (WAV, MP3, etc.)
|
|
||||||
language: Language code (None for auto-detect)
|
|
||||||
task: "transcribe" or "translate"
|
|
||||||
response_format: "json", "text", "srt", "vtt"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'text', 'language', and optional 'segments'
|
|
||||||
"""
|
|
||||||
language = language or self.settings.stt_language
|
|
||||||
task = task or self.settings.stt_task
|
|
||||||
|
|
||||||
with create_span("stt.transcribe") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("stt.task", task)
|
|
||||||
span.set_attribute("stt.audio_size", len(audio))
|
|
||||||
if language:
|
|
||||||
span.set_attribute("stt.language", language)
|
|
||||||
|
|
||||||
# Try Ray handle first (faster internal path)
|
|
||||||
handle = self._get_ray_handle()
|
|
||||||
if handle:
|
|
||||||
try:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("stt.transport", "ray")
|
|
||||||
result = await handle.transcribe.remote(audio, language, task)
|
|
||||||
if span:
|
|
||||||
span.set_attribute("stt.result_length", len(result.get("text", "")))
|
|
||||||
if result.get("language"):
|
|
||||||
span.set_attribute("stt.detected_language", result["language"])
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
|
||||||
|
|
||||||
# HTTP fallback
|
|
||||||
if span:
|
|
||||||
span.set_attribute("stt.transport", "http")
|
|
||||||
|
|
||||||
files = {"file": ("audio.wav", audio, "audio/wav")}
|
|
||||||
data = {
|
|
||||||
"response_format": response_format,
|
|
||||||
}
|
|
||||||
if language:
|
|
||||||
data["language"] = language
|
|
||||||
|
|
||||||
# Choose endpoint based on task
|
|
||||||
if task == "translate":
|
|
||||||
endpoint = "/v1/audio/translations"
|
|
||||||
else:
|
|
||||||
endpoint = "/v1/audio/transcriptions"
|
|
||||||
|
|
||||||
response = await self._client.post(endpoint, files=files, data=data)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
if response_format == "text":
|
|
||||||
return {"text": response.text}
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("stt.result_length", len(result.get("text", "")))
|
|
||||||
if result.get("language"):
|
|
||||||
span.set_attribute("stt.detected_language", result["language"])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def transcribe_file(
|
|
||||||
self,
|
|
||||||
file_path: str,
|
|
||||||
language: Optional[str] = None,
|
|
||||||
task: Optional[str] = None,
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Transcribe an audio file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to audio file
|
|
||||||
language: Language code
|
|
||||||
task: "transcribe" or "translate"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Transcription result
|
|
||||||
"""
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
audio = f.read()
|
|
||||||
return await self.transcribe(audio, language, task)
|
|
||||||
|
|
||||||
async def translate(self, audio: bytes) -> dict:
|
|
||||||
"""
|
|
||||||
Translate audio to English.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: Audio bytes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Translation result with 'text' key
|
|
||||||
"""
|
|
||||||
return await self.transcribe(audio, task="translate")
|
|
||||||
|
|
||||||
async def health(self) -> bool:
|
|
||||||
"""Check if the STT service is healthy."""
|
|
||||||
try:
|
|
||||||
response = await self._client.get("/health")
|
|
||||||
return response.status_code == 200
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
@@ -1,149 +0,0 @@
|
|||||||
"""
|
|
||||||
TTS service client (Coqui XTTS).
|
|
||||||
|
|
||||||
Supports both HTTP (external) and Ray handles (internal) for optimal performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from handler_base.config import TTSSettings
|
|
||||||
from handler_base.ray_utils import get_ray_handle
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TTSClient:
|
|
||||||
"""
|
|
||||||
Client for the TTS service (Coqui XTTS).
|
|
||||||
|
|
||||||
When running inside Ray, automatically uses Ray handles for faster
|
|
||||||
internal communication. Falls back to HTTP for external calls.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
client = TTSClient()
|
|
||||||
audio_bytes = await client.synthesize("Hello world")
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Ray Serve deployment configuration
|
|
||||||
RAY_DEPLOYMENT_NAME = "TTSDeployment"
|
|
||||||
RAY_APP_NAME = "tts"
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[TTSSettings] = None):
|
|
||||||
self.settings = settings or TTSSettings()
|
|
||||||
self._client = httpx.AsyncClient(
|
|
||||||
base_url=self.settings.tts_url,
|
|
||||||
timeout=120.0, # TTS can be slow
|
|
||||||
)
|
|
||||||
self._ray_handle: Optional[Any] = None
|
|
||||||
self._ray_checked = False
|
|
||||||
|
|
||||||
def _get_ray_handle(self) -> Optional[Any]:
|
|
||||||
"""Get Ray handle, checking only once."""
|
|
||||||
if not self._ray_checked:
|
|
||||||
self._ray_handle = get_ray_handle(self.RAY_DEPLOYMENT_NAME, self.RAY_APP_NAME)
|
|
||||||
self._ray_checked = True
|
|
||||||
return self._ray_handle
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the HTTP client."""
|
|
||||||
await self._client.aclose()
|
|
||||||
|
|
||||||
async def synthesize(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
language: Optional[str] = None,
|
|
||||||
speaker: Optional[str] = None,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Synthesize speech from text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to synthesize
|
|
||||||
language: Language code (e.g., "en", "es", "fr")
|
|
||||||
speaker: Speaker ID or reference
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
WAV audio bytes
|
|
||||||
"""
|
|
||||||
language = language or self.settings.tts_language
|
|
||||||
|
|
||||||
with create_span("tts.synthesize") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("tts.language", language)
|
|
||||||
span.set_attribute("tts.text_length", len(text))
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"text": text,
|
|
||||||
"language_id": language,
|
|
||||||
}
|
|
||||||
if speaker:
|
|
||||||
params["speaker_id"] = speaker
|
|
||||||
|
|
||||||
# Try Ray handle first (faster internal path)
|
|
||||||
handle = self._get_ray_handle()
|
|
||||||
if handle:
|
|
||||||
try:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("tts.transport", "ray")
|
|
||||||
audio_bytes = await handle.synthesize.remote(text, language, speaker)
|
|
||||||
if span:
|
|
||||||
span.set_attribute("tts.audio_size", len(audio_bytes))
|
|
||||||
return audio_bytes
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Ray handle failed, falling back to HTTP: {e}")
|
|
||||||
|
|
||||||
# HTTP fallback
|
|
||||||
if span:
|
|
||||||
span.set_attribute("tts.transport", "http")
|
|
||||||
|
|
||||||
response = await self._client.get("/api/tts", params=params)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
audio_bytes = response.content
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("tts.audio_size", len(audio_bytes))
|
|
||||||
|
|
||||||
return audio_bytes
|
|
||||||
|
|
||||||
async def synthesize_to_file(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
output_path: str,
|
|
||||||
language: Optional[str] = None,
|
|
||||||
speaker: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Synthesize speech and save to a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to synthesize
|
|
||||||
output_path: Path to save the audio file
|
|
||||||
language: Language code
|
|
||||||
speaker: Speaker ID
|
|
||||||
"""
|
|
||||||
audio_bytes = await self.synthesize(text, language, speaker)
|
|
||||||
|
|
||||||
with open(output_path, "wb") as f:
|
|
||||||
f.write(audio_bytes)
|
|
||||||
|
|
||||||
async def get_speakers(self) -> list[dict]:
|
|
||||||
"""Get available speakers/voices."""
|
|
||||||
try:
|
|
||||||
response = await self._client.get("/api/speakers")
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def health(self) -> bool:
|
|
||||||
"""Check if the TTS service is healthy."""
|
|
||||||
try:
|
|
||||||
response = await self._client.get("/health")
|
|
||||||
return response.status_code == 200
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
"""
|
|
||||||
Configuration management using Pydantic Settings.
|
|
||||||
|
|
||||||
Environment variables are automatically loaded and validated.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
"""Base settings for all handler services."""
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
|
||||||
env_file=".env",
|
|
||||||
env_file_encoding="utf-8",
|
|
||||||
extra="ignore",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Service identification
|
|
||||||
service_name: str = "handler"
|
|
||||||
service_version: str = "1.0.0"
|
|
||||||
service_namespace: str = "ai-ml"
|
|
||||||
deployment_env: str = "production"
|
|
||||||
|
|
||||||
# NATS configuration
|
|
||||||
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
|
|
||||||
nats_user: Optional[str] = None
|
|
||||||
nats_password: Optional[str] = None
|
|
||||||
nats_queue_group: Optional[str] = None
|
|
||||||
|
|
||||||
# Redis/Valkey configuration
|
|
||||||
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
|
|
||||||
redis_password: Optional[str] = None
|
|
||||||
|
|
||||||
# Milvus configuration
|
|
||||||
milvus_host: str = "milvus.ai-ml.svc.cluster.local"
|
|
||||||
milvus_port: int = 19530
|
|
||||||
milvus_collection: str = "documents"
|
|
||||||
|
|
||||||
# Service endpoints
|
|
||||||
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
|
||||||
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
|
|
||||||
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
|
|
||||||
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
|
|
||||||
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
|
|
||||||
|
|
||||||
# OpenTelemetry configuration
|
|
||||||
otel_enabled: bool = True
|
|
||||||
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
|
|
||||||
otel_use_http: bool = False
|
|
||||||
|
|
||||||
# HyperDX configuration
|
|
||||||
hyperdx_enabled: bool = False
|
|
||||||
hyperdx_api_key: Optional[str] = None
|
|
||||||
hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
|
|
||||||
|
|
||||||
# MLflow configuration
|
|
||||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
|
|
||||||
mlflow_experiment_name: Optional[str] = None
|
|
||||||
mlflow_enabled: bool = True
|
|
||||||
|
|
||||||
# Health check configuration
|
|
||||||
health_port: int = 8080
|
|
||||||
health_path: str = "/health"
|
|
||||||
ready_path: str = "/ready"
|
|
||||||
|
|
||||||
# Timeouts (seconds)
|
|
||||||
http_timeout: float = 60.0
|
|
||||||
nats_timeout: float = 30.0
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsSettings(Settings):
|
|
||||||
"""Settings for embeddings service client."""
|
|
||||||
|
|
||||||
embeddings_model: str = "bge"
|
|
||||||
embeddings_batch_size: int = 32
|
|
||||||
|
|
||||||
|
|
||||||
class LLMSettings(Settings):
|
|
||||||
"""Settings for LLM service client."""
|
|
||||||
|
|
||||||
llm_model: str = "default"
|
|
||||||
llm_max_tokens: int = 2048
|
|
||||||
llm_temperature: float = 0.7
|
|
||||||
llm_top_p: float = 0.9
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSettings(Settings):
|
|
||||||
"""Settings for TTS service client."""
|
|
||||||
|
|
||||||
tts_language: str = "en"
|
|
||||||
tts_speaker: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class STTSettings(Settings):
|
|
||||||
"""Settings for STT service client."""
|
|
||||||
|
|
||||||
stt_language: Optional[str] = None # Auto-detect
|
|
||||||
stt_task: str = "transcribe" # or "translate"
|
|
||||||
@@ -1,222 +0,0 @@
|
|||||||
"""
|
|
||||||
Base handler class for building NATS-based services.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import signal
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from nats.aio.msg import Msg
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
from handler_base.health import HealthServer
|
|
||||||
from handler_base.nats_client import NATSClient
|
|
||||||
from handler_base.telemetry import create_span, setup_telemetry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(ABC):
|
|
||||||
"""
|
|
||||||
Base class for NATS message handlers.
|
|
||||||
|
|
||||||
Subclass and implement:
|
|
||||||
- setup(): Initialize your service clients
|
|
||||||
- handle_message(): Process incoming messages
|
|
||||||
- teardown(): Clean up resources (optional)
|
|
||||||
|
|
||||||
Example:
|
|
||||||
class MyHandler(Handler):
|
|
||||||
async def setup(self):
|
|
||||||
self.embeddings = EmbeddingsClient()
|
|
||||||
|
|
||||||
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
|
|
||||||
result = await self.embeddings.embed(data["text"])
|
|
||||||
return {"embedding": result}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
MyHandler(subject="my.subject").run()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
subject: str,
|
|
||||||
settings: Optional[Settings] = None,
|
|
||||||
queue_group: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the handler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subject: NATS subject to subscribe to
|
|
||||||
settings: Configuration settings
|
|
||||||
queue_group: Optional queue group for load balancing
|
|
||||||
"""
|
|
||||||
self.subject = subject
|
|
||||||
self.settings = settings or Settings()
|
|
||||||
self.queue_group = queue_group or self.settings.nats_queue_group
|
|
||||||
|
|
||||||
self.nats = NATSClient(self.settings)
|
|
||||||
self.health_server = HealthServer(self.settings, self._check_ready)
|
|
||||||
|
|
||||||
self._running = False
|
|
||||||
self._shutdown_event = asyncio.Event()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def setup(self) -> None:
|
|
||||||
"""
|
|
||||||
Initialize service clients and resources.
|
|
||||||
|
|
||||||
Called once before starting to handle messages.
|
|
||||||
Override this to set up your service-specific clients.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
|
|
||||||
"""
|
|
||||||
Handle an incoming message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg: Raw NATS message
|
|
||||||
data: Decoded message data (msgpack unpacked)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional response data. If returned and msg has a reply subject,
|
|
||||||
the response will be sent automatically.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def teardown(self) -> None:
|
|
||||||
"""
|
|
||||||
Clean up resources.
|
|
||||||
|
|
||||||
Called during graceful shutdown.
|
|
||||||
Override to add custom cleanup logic.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def _check_ready(self) -> bool:
|
|
||||||
"""Check if the service is ready to handle requests."""
|
|
||||||
return self._running and self.nats._nc is not None
|
|
||||||
|
|
||||||
async def _message_handler(self, msg: Msg) -> None:
|
|
||||||
"""Internal message handler with tracing and error handling."""
|
|
||||||
with create_span(f"handle.{self.subject}") as span:
|
|
||||||
try:
|
|
||||||
# Decode message
|
|
||||||
data = NATSClient.decode_msgpack(msg)
|
|
||||||
|
|
||||||
if span:
|
|
||||||
span.set_attribute("messaging.destination", msg.subject)
|
|
||||||
if isinstance(data, dict):
|
|
||||||
request_id = data.get("request_id", data.get("id"))
|
|
||||||
if request_id:
|
|
||||||
span.set_attribute("request.id", str(request_id))
|
|
||||||
|
|
||||||
# Handle message
|
|
||||||
response = await self.handle_message(msg, data)
|
|
||||||
|
|
||||||
# Send response if applicable
|
|
||||||
if response is not None and msg.reply:
|
|
||||||
await self.nats.publish(msg.reply, response)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Error handling message on {msg.subject}")
|
|
||||||
if span:
|
|
||||||
span.set_attribute("error", True)
|
|
||||||
span.set_attribute("error.message", str(e))
|
|
||||||
|
|
||||||
# Send error response if reply expected
|
|
||||||
if msg.reply:
|
|
||||||
error_response = {
|
|
||||||
"error": True,
|
|
||||||
"message": str(e),
|
|
||||||
"type": type(e).__name__,
|
|
||||||
}
|
|
||||||
await self.nats.publish(msg.reply, error_response)
|
|
||||||
|
|
||||||
def _setup_signals(self) -> None:
|
|
||||||
"""Set up signal handlers for graceful shutdown."""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
|
||||||
loop.add_signal_handler(sig, self._handle_signal, sig)
|
|
||||||
|
|
||||||
def _handle_signal(self, sig: signal.Signals) -> None:
|
|
||||||
"""Handle shutdown signal."""
|
|
||||||
logger.info(f"Received {sig.name}, initiating graceful shutdown...")
|
|
||||||
self._shutdown_event.set()
|
|
||||||
|
|
||||||
async def _run(self) -> None:
|
|
||||||
"""Main async run loop."""
|
|
||||||
# Setup telemetry
|
|
||||||
setup_telemetry(self.settings)
|
|
||||||
|
|
||||||
# Start health server
|
|
||||||
self.health_server.start()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Connect to NATS
|
|
||||||
await self.nats.connect()
|
|
||||||
|
|
||||||
# Run user setup
|
|
||||||
logger.info("Running service setup...")
|
|
||||||
await self.setup()
|
|
||||||
|
|
||||||
# Subscribe to subject
|
|
||||||
await self.nats.subscribe(
|
|
||||||
self.subject,
|
|
||||||
self._message_handler,
|
|
||||||
queue=self.queue_group,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._running = True
|
|
||||||
logger.info(f"Handler ready, listening on {self.subject}")
|
|
||||||
|
|
||||||
# Wait for shutdown signal
|
|
||||||
await self._shutdown_event.wait()
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Fatal error in handler")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
# Graceful shutdown
|
|
||||||
logger.info("Shutting down...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.teardown()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error during teardown: {e}")
|
|
||||||
|
|
||||||
await self.nats.close()
|
|
||||||
self.health_server.stop()
|
|
||||||
|
|
||||||
logger.info("Shutdown complete")
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
"""
|
|
||||||
Run the handler.
|
|
||||||
|
|
||||||
This is the main entry point. It sets up signal handlers
|
|
||||||
and runs the async event loop.
|
|
||||||
"""
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
|
|
||||||
|
|
||||||
# Run the async loop
|
|
||||||
asyncio.run(self._run_with_signals())
|
|
||||||
|
|
||||||
async def _run_with_signals(self) -> None:
|
|
||||||
"""Run with signal handling."""
|
|
||||||
self._setup_signals()
|
|
||||||
await self._run()
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""
|
|
||||||
HTTP health check server.
|
|
||||||
|
|
||||||
Provides /health and /ready endpoints for Kubernetes probes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
||||||
from typing import Awaitable, Callable, Optional
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HealthHandler(BaseHTTPRequestHandler):
|
|
||||||
"""HTTP request handler for health checks."""
|
|
||||||
|
|
||||||
# Class-level state
|
|
||||||
ready_check: Optional[Callable[[], Awaitable[bool]]] = None
|
|
||||||
health_path: str = "/health"
|
|
||||||
ready_path: str = "/ready"
|
|
||||||
|
|
||||||
def log_message(self, format, *args):
|
|
||||||
"""Suppress default logging."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def do_GET(self):
|
|
||||||
"""Handle GET requests for health/ready endpoints."""
|
|
||||||
if self.path == self.health_path:
|
|
||||||
self._respond_ok({"status": "healthy"})
|
|
||||||
elif self.path == self.ready_path:
|
|
||||||
self._handle_ready()
|
|
||||||
else:
|
|
||||||
self._respond_not_found()
|
|
||||||
|
|
||||||
def _handle_ready(self):
|
|
||||||
"""Check readiness and respond."""
|
|
||||||
# Access via class to avoid method binding issues
|
|
||||||
ready_check = HealthHandler.ready_check
|
|
||||||
if ready_check is None:
|
|
||||||
self._respond_ok({"status": "ready"})
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run the async check in a new event loop
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
is_ready = loop.run_until_complete(ready_check())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
if is_ready:
|
|
||||||
self._respond_ok({"status": "ready"})
|
|
||||||
else:
|
|
||||||
self._respond_unavailable({"status": "not ready"})
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Readiness check failed")
|
|
||||||
self._respond_unavailable({"status": "error", "message": str(e)})
|
|
||||||
|
|
||||||
def _respond_ok(self, data: dict):
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_header("Content-Type", "application/json")
|
|
||||||
self.end_headers()
|
|
||||||
self.wfile.write(json.dumps(data).encode())
|
|
||||||
|
|
||||||
def _respond_unavailable(self, data: dict):
|
|
||||||
self.send_response(503)
|
|
||||||
self.send_header("Content-Type", "application/json")
|
|
||||||
self.end_headers()
|
|
||||||
self.wfile.write(json.dumps(data).encode())
|
|
||||||
|
|
||||||
def _respond_not_found(self):
|
|
||||||
self.send_response(404)
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
|
|
||||||
class HealthServer:
|
|
||||||
"""
|
|
||||||
Background HTTP server for health checks.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
server = HealthServer(settings)
|
|
||||||
server.start()
|
|
||||||
# ... run your service ...
|
|
||||||
server.stop()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
settings: Optional[Settings] = None,
|
|
||||||
ready_check: Optional[Callable[[], Awaitable[bool]]] = None,
|
|
||||||
):
|
|
||||||
self.settings = settings or Settings()
|
|
||||||
self.ready_check = ready_check
|
|
||||||
self._server: Optional[HTTPServer] = None
|
|
||||||
self._thread: Optional[threading.Thread] = None
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
"""Start the health check server in a background thread."""
|
|
||||||
# Configure handler class
|
|
||||||
HealthHandler.ready_check = self.ready_check
|
|
||||||
HealthHandler.health_path = self.settings.health_path
|
|
||||||
HealthHandler.ready_path = self.settings.ready_path
|
|
||||||
|
|
||||||
# Create and start server
|
|
||||||
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
|
|
||||||
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Health server started on port {self.settings.health_port} "
|
|
||||||
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
|
|
||||||
)
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the health check server."""
|
|
||||||
if self._server:
|
|
||||||
self._server.shutdown()
|
|
||||||
self._server = None
|
|
||||||
self._thread = None
|
|
||||||
logger.info("Health server stopped")
|
|
||||||
@@ -1,188 +0,0 @@
|
|||||||
"""
|
|
||||||
NATS client wrapper with connection management and utilities.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
|
||||||
|
|
||||||
import msgpack
|
|
||||||
import nats
|
|
||||||
from nats.aio.client import Client
|
|
||||||
from nats.aio.msg import Msg
|
|
||||||
from nats.js import JetStreamContext
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class NATSClient:
|
|
||||||
"""
|
|
||||||
NATS client with automatic connection management.
|
|
||||||
|
|
||||||
Supports:
|
|
||||||
- Core NATS pub/sub
|
|
||||||
- JetStream for persistence
|
|
||||||
- Queue groups for load balancing
|
|
||||||
- Msgpack serialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, settings: Optional[Settings] = None):
|
|
||||||
self.settings = settings or Settings()
|
|
||||||
self._nc: Optional[Client] = None
|
|
||||||
self._js: Optional[JetStreamContext] = None
|
|
||||||
self._subscriptions: list = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nc(self) -> Client:
|
|
||||||
"""Get the NATS client, raising if not connected."""
|
|
||||||
if self._nc is None:
|
|
||||||
raise RuntimeError("NATS client not connected. Call connect() first.")
|
|
||||||
return self._nc
|
|
||||||
|
|
||||||
@property
|
|
||||||
def js(self) -> JetStreamContext:
|
|
||||||
"""Get JetStream context, raising if not connected."""
|
|
||||||
if self._js is None:
|
|
||||||
raise RuntimeError("JetStream not initialized. Call connect() first.")
|
|
||||||
return self._js
|
|
||||||
|
|
||||||
async def connect(self) -> None:
|
|
||||||
"""Connect to NATS server."""
|
|
||||||
connect_opts = {
|
|
||||||
"servers": self.settings.nats_url,
|
|
||||||
"reconnect_time_wait": 2,
|
|
||||||
"max_reconnect_attempts": -1, # Infinite
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.settings.nats_user and self.settings.nats_password:
|
|
||||||
connect_opts["user"] = self.settings.nats_user
|
|
||||||
connect_opts["password"] = self.settings.nats_password
|
|
||||||
|
|
||||||
logger.info(f"Connecting to NATS at {self.settings.nats_url}")
|
|
||||||
self._nc = await nats.connect(**connect_opts)
|
|
||||||
self._js = self._nc.jetstream()
|
|
||||||
logger.info("Connected to NATS")
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close NATS connection gracefully."""
|
|
||||||
if self._nc:
|
|
||||||
# Drain subscriptions first
|
|
||||||
for sub in self._subscriptions:
|
|
||||||
try:
|
|
||||||
await sub.drain()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error draining subscription: {e}")
|
|
||||||
|
|
||||||
await self._nc.drain()
|
|
||||||
await self._nc.close()
|
|
||||||
self._nc = None
|
|
||||||
self._js = None
|
|
||||||
logger.info("NATS connection closed")
|
|
||||||
|
|
||||||
async def subscribe(
|
|
||||||
self,
|
|
||||||
subject: str,
|
|
||||||
handler: Callable[[Msg], Awaitable[None]],
|
|
||||||
queue: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Subscribe to a subject with a handler function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subject: NATS subject to subscribe to
|
|
||||||
handler: Async function to handle messages
|
|
||||||
queue: Optional queue group for load balancing
|
|
||||||
"""
|
|
||||||
queue = queue or self.settings.nats_queue_group
|
|
||||||
|
|
||||||
if queue:
|
|
||||||
sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
|
|
||||||
logger.info(f"Subscribed to {subject} (queue: {queue})")
|
|
||||||
else:
|
|
||||||
sub = await self.nc.subscribe(subject, cb=handler)
|
|
||||||
logger.info(f"Subscribed to {subject}")
|
|
||||||
|
|
||||||
self._subscriptions.append(sub)
|
|
||||||
return sub
|
|
||||||
|
|
||||||
async def publish(
|
|
||||||
self,
|
|
||||||
subject: str,
|
|
||||||
data: Any,
|
|
||||||
use_msgpack: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Publish a message to a subject.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subject: NATS subject to publish to
|
|
||||||
data: Data to publish (will be serialized)
|
|
||||||
use_msgpack: Whether to use msgpack (True) or JSON (False)
|
|
||||||
"""
|
|
||||||
with create_span("nats.publish") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("messaging.destination", subject)
|
|
||||||
|
|
||||||
if use_msgpack:
|
|
||||||
payload = msgpack.packb(data, use_bin_type=True)
|
|
||||||
else:
|
|
||||||
import json
|
|
||||||
|
|
||||||
payload = json.dumps(data).encode()
|
|
||||||
|
|
||||||
await self.nc.publish(subject, payload)
|
|
||||||
|
|
||||||
async def request(
|
|
||||||
self,
|
|
||||||
subject: str,
|
|
||||||
data: Any,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
use_msgpack: bool = True,
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Send a request and wait for response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subject: NATS subject to send request to
|
|
||||||
data: Request data
|
|
||||||
timeout: Response timeout in seconds
|
|
||||||
use_msgpack: Whether to use msgpack serialization
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decoded response data
|
|
||||||
"""
|
|
||||||
timeout = timeout or self.settings.nats_timeout
|
|
||||||
|
|
||||||
with create_span("nats.request") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("messaging.destination", subject)
|
|
||||||
|
|
||||||
if use_msgpack:
|
|
||||||
payload = msgpack.packb(data, use_bin_type=True)
|
|
||||||
else:
|
|
||||||
import json
|
|
||||||
|
|
||||||
payload = json.dumps(data).encode()
|
|
||||||
|
|
||||||
response = await self.nc.request(subject, payload, timeout=timeout)
|
|
||||||
|
|
||||||
if use_msgpack:
|
|
||||||
return msgpack.unpackb(response.data, raw=False)
|
|
||||||
else:
|
|
||||||
import json
|
|
||||||
|
|
||||||
return json.loads(response.data.decode())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode_msgpack(msg: Msg) -> Any:
|
|
||||||
"""Decode a msgpack message."""
|
|
||||||
return msgpack.unpackb(msg.data, raw=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode_json(msg: Msg) -> Any:
|
|
||||||
"""Decode a JSON message."""
|
|
||||||
import json
|
|
||||||
|
|
||||||
return json.loads(msg.data.decode())
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
"""
|
|
||||||
Ray integration utilities for handler-base clients.
|
|
||||||
|
|
||||||
When running inside a Ray cluster, clients can use Ray Serve handles
|
|
||||||
for faster internal communication (gRPC instead of HTTP).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Ray handle cache to avoid repeated lookups
|
|
||||||
_ray_handles: dict[str, Any] = {}
|
|
||||||
_ray_available: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
def is_ray_available() -> bool:
|
|
||||||
"""Check if we're running inside a Ray cluster."""
|
|
||||||
global _ray_available
|
|
||||||
if _ray_available is not None:
|
|
||||||
return _ray_available
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ray
|
|
||||||
|
|
||||||
_ray_available = ray.is_initialized()
|
|
||||||
if _ray_available:
|
|
||||||
logger.info("Ray detected - will use Ray handles for internal calls")
|
|
||||||
return _ray_available
|
|
||||||
except ImportError:
|
|
||||||
_ray_available = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_ray_handle(deployment_name: str, app_name: str) -> Optional[Any]:
|
|
||||||
"""
|
|
||||||
Get a Ray Serve deployment handle for internal calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployment_name: Name of the Ray Serve deployment
|
|
||||||
app_name: Name of the Ray Serve application
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DeploymentHandle if available, None otherwise
|
|
||||||
"""
|
|
||||||
if not is_ray_available():
|
|
||||||
return None
|
|
||||||
|
|
||||||
cache_key = f"{app_name}/{deployment_name}"
|
|
||||||
if cache_key in _ray_handles:
|
|
||||||
return _ray_handles[cache_key]
|
|
||||||
|
|
||||||
try:
|
|
||||||
from ray import serve
|
|
||||||
|
|
||||||
handle = serve.get_deployment_handle(deployment_name, app_name=app_name)
|
|
||||||
_ray_handles[cache_key] = handle
|
|
||||||
logger.debug(f"Got Ray handle for {cache_key}")
|
|
||||||
return handle
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Could not get Ray handle for {cache_key}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def clear_ray_handles() -> None:
|
|
||||||
"""Clear cached Ray handles (useful for testing)."""
|
|
||||||
global _ray_handles, _ray_available
|
|
||||||
_ray_handles.clear()
|
|
||||||
_ray_available = None
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
"""
|
|
||||||
OpenTelemetry setup for tracing and metrics.
|
|
||||||
|
|
||||||
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from opentelemetry import metrics, trace
|
|
||||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
|
||||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
|
||||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
|
||||||
OTLPMetricExporter as OTLPMetricExporterHTTP,
|
|
||||||
)
|
|
||||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
|
||||||
OTLPSpanExporter as OTLPSpanExporterHTTP,
|
|
||||||
)
|
|
||||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
|
||||||
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
|
||||||
from opentelemetry.sdk.metrics import MeterProvider
|
|
||||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
|
||||||
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
|
|
||||||
from opentelemetry.sdk.trace import TracerProvider
|
|
||||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
|
||||||
|
|
||||||
from handler_base.config import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Global references
|
|
||||||
_tracer: Optional[trace.Tracer] = None
|
|
||||||
_meter: Optional[metrics.Meter] = None
|
|
||||||
_initialized = False
|
|
||||||
|
|
||||||
|
|
||||||
def setup_telemetry(
|
|
||||||
settings: Optional[Settings] = None,
|
|
||||||
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
|
|
||||||
"""
|
|
||||||
Initialize OpenTelemetry tracing and metrics.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
settings: Configuration settings. If None, loads from environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (tracer, meter) or (None, None) if disabled.
|
|
||||||
"""
|
|
||||||
global _tracer, _meter, _initialized
|
|
||||||
|
|
||||||
if _initialized:
|
|
||||||
return _tracer, _meter
|
|
||||||
|
|
||||||
if settings is None:
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
if not settings.otel_enabled:
|
|
||||||
logger.info("OpenTelemetry disabled")
|
|
||||||
_initialized = True
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# Create resource with service information
|
|
||||||
resource = Resource.create(
|
|
||||||
{
|
|
||||||
SERVICE_NAME: settings.service_name,
|
|
||||||
SERVICE_VERSION: settings.service_version,
|
|
||||||
SERVICE_NAMESPACE: settings.service_namespace,
|
|
||||||
"deployment.environment": settings.deployment_env,
|
|
||||||
"host.name": os.environ.get("HOSTNAME", "unknown"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine endpoint and exporter type
|
|
||||||
if settings.hyperdx_enabled and settings.hyperdx_api_key:
|
|
||||||
# HyperDX uses HTTP with API key header
|
|
||||||
endpoint = settings.hyperdx_endpoint
|
|
||||||
headers = {"authorization": settings.hyperdx_api_key}
|
|
||||||
use_http = True
|
|
||||||
logger.info(f"Using HyperDX endpoint: {endpoint}")
|
|
||||||
else:
|
|
||||||
endpoint = settings.otel_endpoint
|
|
||||||
headers = None
|
|
||||||
use_http = settings.otel_use_http
|
|
||||||
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
|
|
||||||
|
|
||||||
# Setup tracing
|
|
||||||
if use_http:
|
|
||||||
trace_exporter = OTLPSpanExporterHTTP(
|
|
||||||
endpoint=f"{endpoint}/v1/traces",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
trace_exporter = OTLPSpanExporter(
|
|
||||||
endpoint=endpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
tracer_provider = TracerProvider(resource=resource)
|
|
||||||
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
|
|
||||||
trace.set_tracer_provider(tracer_provider)
|
|
||||||
|
|
||||||
# Setup metrics
|
|
||||||
if use_http:
|
|
||||||
metric_exporter = OTLPMetricExporterHTTP(
|
|
||||||
endpoint=f"{endpoint}/v1/metrics",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
metric_exporter = OTLPMetricExporter(
|
|
||||||
endpoint=endpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
metric_reader = PeriodicExportingMetricReader(
|
|
||||||
metric_exporter,
|
|
||||||
export_interval_millis=60000,
|
|
||||||
)
|
|
||||||
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
|
||||||
metrics.set_meter_provider(meter_provider)
|
|
||||||
|
|
||||||
# Instrument libraries
|
|
||||||
HTTPXClientInstrumentor().instrument()
|
|
||||||
LoggingInstrumentor().instrument(set_logging_format=True)
|
|
||||||
|
|
||||||
# Create tracer and meter for this service
|
|
||||||
_tracer = trace.get_tracer(settings.service_name, settings.service_version)
|
|
||||||
_meter = metrics.get_meter(settings.service_name, settings.service_version)
|
|
||||||
|
|
||||||
logger.info(f"OpenTelemetry initialized for {settings.service_name}")
|
|
||||||
_initialized = True
|
|
||||||
|
|
||||||
return _tracer, _meter
|
|
||||||
|
|
||||||
|
|
||||||
def get_tracer() -> Optional[trace.Tracer]:
|
|
||||||
"""Get the global tracer instance."""
|
|
||||||
return _tracer
|
|
||||||
|
|
||||||
|
|
||||||
def get_meter() -> Optional[metrics.Meter]:
|
|
||||||
"""Get the global meter instance."""
|
|
||||||
return _meter
|
|
||||||
|
|
||||||
|
|
||||||
def create_span(name: str, **kwargs):
|
|
||||||
"""
|
|
||||||
Create a new span.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
with create_span("my_operation") as span:
|
|
||||||
span.set_attribute("key", "value")
|
|
||||||
# do work
|
|
||||||
"""
|
|
||||||
if _tracer is None:
|
|
||||||
# Return a no-op context manager
|
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
return nullcontext()
|
|
||||||
return _tracer.start_as_current_span(name, **kwargs)
|
|
||||||
84
health/health.go
Normal file
84
health/health.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
// Package health provides an HTTP server for Kubernetes liveness and readiness probes.
|
||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadyFunc is called to determine if the service is ready. Return true if ready.
|
||||||
|
type ReadyFunc func() bool
|
||||||
|
|
||||||
|
// Server serves /health and /ready endpoints.
|
||||||
|
type Server struct {
|
||||||
|
port int
|
||||||
|
healthPath string
|
||||||
|
readyPath string
|
||||||
|
readyCheck ReadyFunc
|
||||||
|
srv *http.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a health server on the given port.
|
||||||
|
func New(port int, healthPath, readyPath string, readyCheck ReadyFunc) *Server {
|
||||||
|
s := &Server{
|
||||||
|
port: port,
|
||||||
|
healthPath: healthPath,
|
||||||
|
readyPath: readyPath,
|
||||||
|
readyCheck: readyCheck,
|
||||||
|
}
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc(healthPath, s.handleHealth)
|
||||||
|
mux.HandleFunc(readyPath, s.handleReady)
|
||||||
|
s.srv = &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", port),
|
||||||
|
Handler: mux,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins serving in the background. Call Stop to shut down.
|
||||||
|
func (s *Server) Start() {
|
||||||
|
ln, err := net.Listen("tcp", s.srv.Addr)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("health server listen failed", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Info("health server started", "port", s.port, "health", s.healthPath, "ready", s.readyPath)
|
||||||
|
go func() {
|
||||||
|
if err := s.srv.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||||
|
slog.Error("health server error", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the server.
|
||||||
|
func (s *Server) Stop(ctx context.Context) {
|
||||||
|
if s.srv != nil {
|
||||||
|
_ = s.srv.Shutdown(ctx)
|
||||||
|
slog.Info("health server stopped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": "healthy"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleReady(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
if s.readyCheck != nil && !s.readyCheck() {
|
||||||
|
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"status": "not ready"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": "ready"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, data any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(data)
|
||||||
|
}
|
||||||
77
health/health_test.go
Normal file
77
health/health_test.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthEndpoint(t *testing.T) {
|
||||||
|
srv := New(18080, "/health", "/ready", nil)
|
||||||
|
srv.Start()
|
||||||
|
defer srv.Stop(context.Background())
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
resp, err := http.Get("http://localhost:18080/health")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("health request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
var data map[string]string
|
||||||
|
_ = json.Unmarshal(body, &data)
|
||||||
|
if data["status"] != "healthy" {
|
||||||
|
t.Errorf("expected status 'healthy', got %q", data["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadyEndpointDefault(t *testing.T) {
|
||||||
|
srv := New(18081, "/health", "/ready", nil)
|
||||||
|
srv.Start()
|
||||||
|
defer srv.Stop(context.Background())
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
resp, err := http.Get("http://localhost:18081/ready")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ready request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadyEndpointNotReady(t *testing.T) {
|
||||||
|
ready := false
|
||||||
|
srv := New(18082, "/health", "/ready", func() bool { return ready })
|
||||||
|
srv.Start()
|
||||||
|
defer srv.Stop(context.Background())
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
resp, err := http.Get("http://localhost:18082/ready")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ready request failed: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if resp.StatusCode != 503 {
|
||||||
|
t.Errorf("expected 503 when not ready, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
ready = true
|
||||||
|
resp2, err := http.Get("http://localhost:18082/ready")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ready request failed: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp2.Body.Close()
|
||||||
|
if resp2.StatusCode != 200 {
|
||||||
|
t.Errorf("expected 200 when ready, got %d", resp2.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
515
messages/bench_test.go
Normal file
515
messages/bench_test.go
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
// Package messages benchmarks compare three serialization strategies:
|
||||||
|
//
|
||||||
|
// 1. msgpack map[string]any — the old approach (dynamic, no types)
|
||||||
|
// 2. msgpack typed struct — the new approach (compile-time safe, short keys)
|
||||||
|
// 3. protobuf — optional future migration
|
||||||
|
//
|
||||||
|
// Run with:
|
||||||
|
//
|
||||||
|
// go test -bench=. -benchmem -count=5 ./messages/... | tee bench.txt
|
||||||
|
// # optional: go install golang.org/x/perf/cmd/benchstat@latest && benchstat bench.txt
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
pb "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Test fixtures — equivalent data across all three encodings
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// chatRequestMap is the legacy map[string]any representation.
|
||||||
|
func chatRequestMap() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"request_id": "req-abc-123",
|
||||||
|
"user_id": "user-42",
|
||||||
|
"message": "What is the capital of France?",
|
||||||
|
"query": "",
|
||||||
|
"premium": true,
|
||||||
|
"enable_rag": true,
|
||||||
|
"enable_reranker": true,
|
||||||
|
"enable_streaming": false,
|
||||||
|
"top_k": 10,
|
||||||
|
"collection": "documents",
|
||||||
|
"enable_tts": false,
|
||||||
|
"system_prompt": "You are a helpful assistant.",
|
||||||
|
"response_subject": "ai.chat.response.req-abc-123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatRequestStruct is the typed struct representation.
|
||||||
|
func chatRequestStruct() ChatRequest {
|
||||||
|
return ChatRequest{
|
||||||
|
RequestID: "req-abc-123",
|
||||||
|
UserID: "user-42",
|
||||||
|
Message: "What is the capital of France?",
|
||||||
|
Premium: true,
|
||||||
|
EnableRAG: true,
|
||||||
|
EnableReranker: true,
|
||||||
|
TopK: 10,
|
||||||
|
Collection: "documents",
|
||||||
|
SystemPrompt: "You are a helpful assistant.",
|
||||||
|
ResponseSubject: "ai.chat.response.req-abc-123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatRequestProto is the protobuf representation.
|
||||||
|
func chatRequestProto() *pb.ChatRequest {
|
||||||
|
return &pb.ChatRequest{
|
||||||
|
RequestId: "req-abc-123",
|
||||||
|
UserId: "user-42",
|
||||||
|
Message: "What is the capital of France?",
|
||||||
|
Premium: true,
|
||||||
|
EnableRag: true,
|
||||||
|
EnableReranker: true,
|
||||||
|
TopK: 10,
|
||||||
|
Collection: "documents",
|
||||||
|
SystemPrompt: "You are a helpful assistant.",
|
||||||
|
ResponseSubject: "ai.chat.response.req-abc-123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// voiceResponseMap is a voice response with a 16 KB audio payload.
|
||||||
|
func voiceResponseMap() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"request_id": "vr-001",
|
||||||
|
"response": "The capital of France is Paris.",
|
||||||
|
"audio": make([]byte, 16384),
|
||||||
|
"transcription": "What is the capital of France?",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func voiceResponseStruct() VoiceResponse {
|
||||||
|
return VoiceResponse{
|
||||||
|
RequestID: "vr-001",
|
||||||
|
Response: "The capital of France is Paris.",
|
||||||
|
Audio: make([]byte, 16384),
|
||||||
|
Transcription: "What is the capital of France?",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func voiceResponseProto() *pb.VoiceResponse {
|
||||||
|
return &pb.VoiceResponse{
|
||||||
|
RequestId: "vr-001",
|
||||||
|
Response: "The capital of France is Paris.",
|
||||||
|
Audio: make([]byte, 16384),
|
||||||
|
Transcription: "What is the capital of France?",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ttsChunkMap simulates a streaming audio chunk (~32 KB).
|
||||||
|
func ttsChunkMap() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"session_id": "tts-sess-99",
|
||||||
|
"chunk_index": 3,
|
||||||
|
"total_chunks": 12,
|
||||||
|
"audio_b64": string(make([]byte, 32768)), // old: base64 string
|
||||||
|
"is_last": false,
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
"sample_rate": 24000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ttsChunkStruct() TTSAudioChunk {
|
||||||
|
return TTSAudioChunk{
|
||||||
|
SessionID: "tts-sess-99",
|
||||||
|
ChunkIndex: 3,
|
||||||
|
TotalChunks: 12,
|
||||||
|
Audio: make([]byte, 32768), // new: raw bytes
|
||||||
|
IsLast: false,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ttsChunkProto() *pb.TTSAudioChunk {
|
||||||
|
return &pb.TTSAudioChunk{
|
||||||
|
SessionId: "tts-sess-99",
|
||||||
|
ChunkIndex: 3,
|
||||||
|
TotalChunks: 12,
|
||||||
|
Audio: make([]byte, 32768),
|
||||||
|
IsLast: false,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Wire-size comparison (run once, printed by TestWireSize)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWireSize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mapData any
|
||||||
|
structVal any
|
||||||
|
protoMsg proto.Message
|
||||||
|
}{
|
||||||
|
{"ChatRequest", chatRequestMap(), chatRequestStruct(), chatRequestProto()},
|
||||||
|
{"VoiceResponse", voiceResponseMap(), voiceResponseStruct(), voiceResponseProto()},
|
||||||
|
{"TTSAudioChunk", ttsChunkMap(), ttsChunkStruct(), ttsChunkProto()},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
mapBytes, _ := msgpack.Marshal(tt.mapData)
|
||||||
|
structBytes, _ := msgpack.Marshal(tt.structVal)
|
||||||
|
protoBytes, _ := proto.Marshal(tt.protoMsg)
|
||||||
|
|
||||||
|
t.Logf("%-16s map=%5d B struct=%5d B proto=%5d B (struct saves %.0f%%, proto saves %.0f%%)",
|
||||||
|
tt.name,
|
||||||
|
len(mapBytes), len(structBytes), len(protoBytes),
|
||||||
|
100*(1-float64(len(structBytes))/float64(len(mapBytes))),
|
||||||
|
100*(1-float64(len(protoBytes))/float64(len(mapBytes))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Encode benchmarks
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkEncode_ChatRequest_MsgpackMap(b *testing.B) {
|
||||||
|
data := chatRequestMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_ChatRequest_MsgpackStruct(b *testing.B) {
|
||||||
|
data := chatRequestStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_ChatRequest_Protobuf(b *testing.B) {
|
||||||
|
data := chatRequestProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = proto.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_VoiceResponse_MsgpackMap(b *testing.B) {
|
||||||
|
data := voiceResponseMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_VoiceResponse_MsgpackStruct(b *testing.B) {
|
||||||
|
data := voiceResponseStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_VoiceResponse_Protobuf(b *testing.B) {
|
||||||
|
data := voiceResponseProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = proto.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_TTSChunk_MsgpackMap(b *testing.B) {
|
||||||
|
data := ttsChunkMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_TTSChunk_MsgpackStruct(b *testing.B) {
|
||||||
|
data := ttsChunkStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncode_TTSChunk_Protobuf(b *testing.B) {
|
||||||
|
data := ttsChunkProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = proto.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Decode benchmarks
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkDecode_ChatRequest_MsgpackMap(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(chatRequestMap())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
_ = msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_ChatRequest_MsgpackStruct(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(chatRequestStruct())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m ChatRequest
|
||||||
|
_ = msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_ChatRequest_Protobuf(b *testing.B) {
|
||||||
|
encoded, _ := proto.Marshal(chatRequestProto())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m pb.ChatRequest
|
||||||
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_VoiceResponse_MsgpackMap(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(voiceResponseMap())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
_ = msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_VoiceResponse_MsgpackStruct(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(voiceResponseStruct())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m VoiceResponse
|
||||||
|
_ = msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_VoiceResponse_Protobuf(b *testing.B) {
|
||||||
|
encoded, _ := proto.Marshal(voiceResponseProto())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m pb.VoiceResponse
|
||||||
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_TTSChunk_MsgpackMap(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(ttsChunkMap())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
_ = msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_TTSChunk_MsgpackStruct(b *testing.B) {
|
||||||
|
encoded, _ := msgpack.Marshal(ttsChunkStruct())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m TTSAudioChunk
|
||||||
|
_ = msgpack.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecode_TTSChunk_Protobuf(b *testing.B) {
|
||||||
|
encoded, _ := proto.Marshal(ttsChunkProto())
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
var m pb.TTSAudioChunk
|
||||||
|
_ = proto.Unmarshal(encoded, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Roundtrip benchmarks (encode + decode)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkRoundtrip_ChatRequest_MsgpackMap(b *testing.B) {
|
||||||
|
data := chatRequestMap()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
enc, _ := msgpack.Marshal(data)
|
||||||
|
var dec map[string]any
|
||||||
|
_ = msgpack.Unmarshal(enc, &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRoundtrip_ChatRequest_MsgpackStruct(b *testing.B) {
|
||||||
|
data := chatRequestStruct()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
enc, _ := msgpack.Marshal(data)
|
||||||
|
var dec ChatRequest
|
||||||
|
_ = msgpack.Unmarshal(enc, &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRoundtrip_ChatRequest_Protobuf(b *testing.B) {
|
||||||
|
data := chatRequestProto()
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
enc, _ := proto.Marshal(data)
|
||||||
|
var dec pb.ChatRequest
|
||||||
|
_ = proto.Unmarshal(enc, &dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Typed struct unit tests — verify roundtrip correctness
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestRoundtrip_ChatRequest(t *testing.T) {
|
||||||
|
orig := chatRequestStruct()
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec ChatRequest
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.RequestID != orig.RequestID {
|
||||||
|
t.Errorf("RequestID = %q, want %q", dec.RequestID, orig.RequestID)
|
||||||
|
}
|
||||||
|
if dec.Message != orig.Message {
|
||||||
|
t.Errorf("Message = %q, want %q", dec.Message, orig.Message)
|
||||||
|
}
|
||||||
|
if dec.TopK != orig.TopK {
|
||||||
|
t.Errorf("TopK = %d, want %d", dec.TopK, orig.TopK)
|
||||||
|
}
|
||||||
|
if dec.Premium != orig.Premium {
|
||||||
|
t.Errorf("Premium = %v, want %v", dec.Premium, orig.Premium)
|
||||||
|
}
|
||||||
|
if dec.EffectiveQuery() != orig.Message {
|
||||||
|
t.Errorf("EffectiveQuery() = %q, want %q", dec.EffectiveQuery(), orig.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_VoiceResponse(t *testing.T) {
|
||||||
|
orig := voiceResponseStruct()
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec VoiceResponse
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.RequestID != orig.RequestID {
|
||||||
|
t.Errorf("RequestID mismatch")
|
||||||
|
}
|
||||||
|
if len(dec.Audio) != len(orig.Audio) {
|
||||||
|
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
|
||||||
|
}
|
||||||
|
if dec.Transcription != orig.Transcription {
|
||||||
|
t.Errorf("Transcription mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_TTSAudioChunk(t *testing.T) {
|
||||||
|
orig := ttsChunkStruct()
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec TTSAudioChunk
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.SessionID != orig.SessionID {
|
||||||
|
t.Errorf("SessionID mismatch")
|
||||||
|
}
|
||||||
|
if dec.ChunkIndex != orig.ChunkIndex {
|
||||||
|
t.Errorf("ChunkIndex = %d, want %d", dec.ChunkIndex, orig.ChunkIndex)
|
||||||
|
}
|
||||||
|
if len(dec.Audio) != len(orig.Audio) {
|
||||||
|
t.Errorf("Audio len = %d, want %d", len(dec.Audio), len(orig.Audio))
|
||||||
|
}
|
||||||
|
if dec.SampleRate != orig.SampleRate {
|
||||||
|
t.Errorf("SampleRate = %d, want %d", dec.SampleRate, orig.SampleRate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_PipelineTrigger(t *testing.T) {
|
||||||
|
orig := PipelineTrigger{
|
||||||
|
RequestID: "pip-001",
|
||||||
|
Pipeline: "document-ingestion",
|
||||||
|
Parameters: map[string]any{"source": "s3://bucket/data"},
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec PipelineTrigger
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.Pipeline != orig.Pipeline {
|
||||||
|
t.Errorf("Pipeline = %q, want %q", dec.Pipeline, orig.Pipeline)
|
||||||
|
}
|
||||||
|
if dec.Parameters["source"] != orig.Parameters["source"] {
|
||||||
|
t.Errorf("Parameters[source] mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_STTTranscription(t *testing.T) {
|
||||||
|
orig := STTTranscription{
|
||||||
|
SessionID: "stt-001",
|
||||||
|
Transcript: "hello world",
|
||||||
|
Sequence: 5,
|
||||||
|
IsPartial: false,
|
||||||
|
IsFinal: true,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
SpeakerID: "speaker-1",
|
||||||
|
HasVoiceActivity: true,
|
||||||
|
State: "listening",
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec STTTranscription
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if dec.Transcript != orig.Transcript {
|
||||||
|
t.Errorf("Transcript = %q, want %q", dec.Transcript, orig.Transcript)
|
||||||
|
}
|
||||||
|
if dec.IsFinal != orig.IsFinal {
|
||||||
|
t.Error("IsFinal mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundtrip_ErrorResponse(t *testing.T) {
|
||||||
|
orig := ErrorResponse{Error: true, Message: "something broke", Type: "InternalError"}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var dec ErrorResponse
|
||||||
|
if err := msgpack.Unmarshal(data, &dec); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !dec.Error || dec.Message != "something broke" || dec.Type != "InternalError" {
|
||||||
|
t.Errorf("ErrorResponse roundtrip mismatch: %+v", dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimestamp(t *testing.T) {
|
||||||
|
ts := Timestamp()
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if ts < now-1 || ts > now+1 {
|
||||||
|
t.Errorf("Timestamp() = %d, expected ~%d", ts, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
224
messages/messages.go
Normal file
224
messages/messages.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
// Package messages defines typed NATS message structs for all services.
|
||||||
|
//
|
||||||
|
// Using typed structs with short msgpack field tags instead of map[string]any
|
||||||
|
// provides compile-time safety, smaller wire size (integer-like short keys vs
|
||||||
|
// full string keys), and faster encode/decode by avoiding interface{} boxing.
|
||||||
|
//
|
||||||
|
// Audio data uses raw []byte instead of base64-encoded strings — msgpack
|
||||||
|
// supports binary natively, eliminating the 33% base64 overhead.
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Pipeline Bridge
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// PipelineTrigger is the request to start a pipeline.
|
||||||
|
type PipelineTrigger struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Pipeline string `msgpack:"pipeline" json:"pipeline"`
|
||||||
|
Parameters map[string]any `msgpack:"parameters,omitempty" json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipelineStatus is the response / status update for a pipeline run.
|
||||||
|
type PipelineStatus struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Status string `msgpack:"status" json:"status"`
|
||||||
|
RunID string `msgpack:"run_id,omitempty" json:"run_id,omitempty"`
|
||||||
|
Engine string `msgpack:"engine,omitempty" json:"engine,omitempty"`
|
||||||
|
Pipeline string `msgpack:"pipeline,omitempty" json:"pipeline,omitempty"`
|
||||||
|
SubmittedAt string `msgpack:"submitted_at,omitempty" json:"submitted_at,omitempty"`
|
||||||
|
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
||||||
|
AvailablePipelines []string `msgpack:"available_pipelines,omitempty" json:"available_pipelines,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Chat Handler
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// ChatRequest is an incoming chat message.
|
||||||
|
type ChatRequest struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
UserID string `msgpack:"user_id" json:"user_id"`
|
||||||
|
Message string `msgpack:"message" json:"message"`
|
||||||
|
Query string `msgpack:"query,omitempty" json:"query,omitempty"`
|
||||||
|
Premium bool `msgpack:"premium,omitempty" json:"premium,omitempty"`
|
||||||
|
EnableRAG bool `msgpack:"enable_rag,omitempty" json:"enable_rag,omitempty"`
|
||||||
|
EnableReranker bool `msgpack:"enable_reranker,omitempty" json:"enable_reranker,omitempty"`
|
||||||
|
EnableStreaming bool `msgpack:"enable_streaming,omitempty" json:"enable_streaming,omitempty"`
|
||||||
|
TopK int `msgpack:"top_k,omitempty" json:"top_k,omitempty"`
|
||||||
|
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
|
||||||
|
EnableTTS bool `msgpack:"enable_tts,omitempty" json:"enable_tts,omitempty"`
|
||||||
|
SystemPrompt string `msgpack:"system_prompt,omitempty" json:"system_prompt,omitempty"`
|
||||||
|
ResponseSubject string `msgpack:"response_subject,omitempty" json:"response_subject,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveQuery returns Message or falls back to Query.
|
||||||
|
func (c *ChatRequest) EffectiveQuery() string {
|
||||||
|
if c.Message != "" {
|
||||||
|
return c.Message
|
||||||
|
}
|
||||||
|
return c.Query
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatResponse is the full reply to a chat request.
|
||||||
|
type ChatResponse struct {
|
||||||
|
UserID string `msgpack:"user_id" json:"user_id"`
|
||||||
|
Response string `msgpack:"response" json:"response"`
|
||||||
|
ResponseText string `msgpack:"response_text" json:"response_text"`
|
||||||
|
UsedRAG bool `msgpack:"used_rag" json:"used_rag"`
|
||||||
|
RAGSources []string `msgpack:"rag_sources,omitempty" json:"rag_sources,omitempty"`
|
||||||
|
Success bool `msgpack:"success" json:"success"`
|
||||||
|
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
|
||||||
|
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatStreamChunk is a single streaming chunk from an LLM response.
|
||||||
|
type ChatStreamChunk struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
Content string `msgpack:"content" json:"content"`
|
||||||
|
Done bool `msgpack:"done" json:"done"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Voice Assistant
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// VoiceRequest is an incoming voice-to-voice request.
|
||||||
|
type VoiceRequest struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
|
||||||
|
Collection string `msgpack:"collection,omitempty" json:"collection,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoiceResponse is the reply to a voice request.
|
||||||
|
type VoiceResponse struct {
|
||||||
|
RequestID string `msgpack:"request_id" json:"request_id"`
|
||||||
|
Response string `msgpack:"response" json:"response"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
Transcription string `msgpack:"transcription,omitempty" json:"transcription,omitempty"`
|
||||||
|
Sources []DocumentSource `msgpack:"sources,omitempty" json:"sources,omitempty"`
|
||||||
|
Error string `msgpack:"error,omitempty" json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DocumentSource is a RAG search result source.
|
||||||
|
type DocumentSource struct {
|
||||||
|
Text string `msgpack:"text" json:"text"`
|
||||||
|
Score float64 `msgpack:"score" json:"score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// TTS Module
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// TTSRequest is a text-to-speech synthesis request.
|
||||||
|
type TTSRequest struct {
|
||||||
|
Text string `msgpack:"text" json:"text"`
|
||||||
|
Speaker string `msgpack:"speaker,omitempty" json:"speaker,omitempty"`
|
||||||
|
Language string `msgpack:"language,omitempty" json:"language,omitempty"`
|
||||||
|
SpeakerWavB64 string `msgpack:"speaker_wav_b64,omitempty" json:"speaker_wav_b64,omitempty"`
|
||||||
|
Stream bool `msgpack:"stream,omitempty" json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSAudioChunk is a streamed audio chunk from TTS synthesis.
|
||||||
|
type TTSAudioChunk struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
ChunkIndex int `msgpack:"chunk_index" json:"chunk_index"`
|
||||||
|
TotalChunks int `msgpack:"total_chunks" json:"total_chunks"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
IsLast bool `msgpack:"is_last" json:"is_last"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSFullResponse is a non-streamed TTS response (whole audio).
|
||||||
|
type TTSFullResponse struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Audio []byte `msgpack:"audio" json:"audio"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SampleRate int `msgpack:"sample_rate" json:"sample_rate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSStatus is a TTS processing status update.
|
||||||
|
type TTSStatus struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Status string `msgpack:"status" json:"status"`
|
||||||
|
Message string `msgpack:"message" json:"message"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceListResponse is the reply to a voice list request.
|
||||||
|
type TTSVoiceListResponse struct {
|
||||||
|
DefaultSpeaker string `msgpack:"default_speaker" json:"default_speaker"`
|
||||||
|
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
|
||||||
|
LastRefresh int64 `msgpack:"last_refresh" json:"last_refresh"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceInfo is summary info about a custom voice.
|
||||||
|
type TTSVoiceInfo struct {
|
||||||
|
Name string `msgpack:"name" json:"name"`
|
||||||
|
Language string `msgpack:"language" json:"language"`
|
||||||
|
ModelType string `msgpack:"model_type" json:"model_type"`
|
||||||
|
CreatedAt string `msgpack:"created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTSVoiceRefreshResponse is the reply to a voice refresh request.
|
||||||
|
type TTSVoiceRefreshResponse struct {
|
||||||
|
Count int `msgpack:"count" json:"count"`
|
||||||
|
CustomVoices []TTSVoiceInfo `msgpack:"custom_voices" json:"custom_voices"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// STT Module
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// STTStreamMessage is any message on the ai.voice.stream.{session} subject.
|
||||||
|
type STTStreamMessage struct {
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
Audio []byte `msgpack:"audio,omitempty" json:"audio,omitempty"`
|
||||||
|
State string `msgpack:"state,omitempty" json:"state,omitempty"`
|
||||||
|
SpeakerID string `msgpack:"speaker_id,omitempty" json:"speaker_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTTranscription is the transcription result published by the STT module.
|
||||||
|
type STTTranscription struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Transcript string `msgpack:"transcript" json:"transcript"`
|
||||||
|
Sequence int `msgpack:"sequence" json:"sequence"`
|
||||||
|
IsPartial bool `msgpack:"is_partial" json:"is_partial"`
|
||||||
|
IsFinal bool `msgpack:"is_final" json:"is_final"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
|
||||||
|
HasVoiceActivity bool `msgpack:"has_voice_activity" json:"has_voice_activity"`
|
||||||
|
State string `msgpack:"state" json:"state"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// STTInterrupt is published when the STT module detects a user interrupt.
|
||||||
|
type STTInterrupt struct {
|
||||||
|
SessionID string `msgpack:"session_id" json:"session_id"`
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
Timestamp int64 `msgpack:"timestamp" json:"timestamp"`
|
||||||
|
SpeakerID string `msgpack:"speaker_id" json:"speaker_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Common / Error
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// ErrorResponse is the standard error reply from any handler.
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error bool `msgpack:"error" json:"error"`
|
||||||
|
Message string `msgpack:"message" json:"message"`
|
||||||
|
Type string `msgpack:"type" json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timestamp returns the current Unix timestamp (helper for message construction).
|
||||||
|
func Timestamp() int64 {
|
||||||
|
return time.Now().Unix()
|
||||||
|
}
|
||||||
1738
messages/proto/messages.pb.go
Normal file
1738
messages/proto/messages.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
174
messages/proto/messages.proto
Normal file
174
messages/proto/messages.proto
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package messages;
|
||||||
|
|
||||||
|
option go_package = "git.daviestechlabs.io/daviestechlabs/handler-base/messages/proto";
|
||||||
|
|
||||||
|
// ── Pipeline Bridge ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message PipelineTrigger {
|
||||||
|
string request_id = 1;
|
||||||
|
string pipeline = 2;
|
||||||
|
map<string, string> parameters = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PipelineStatus {
|
||||||
|
string request_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
string run_id = 3;
|
||||||
|
string engine = 4;
|
||||||
|
string pipeline = 5;
|
||||||
|
string submitted_at = 6;
|
||||||
|
string error = 7;
|
||||||
|
repeated string available_pipelines = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Chat Handler ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message ChatRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
string user_id = 2;
|
||||||
|
string message = 3;
|
||||||
|
string query = 4;
|
||||||
|
bool premium = 5;
|
||||||
|
bool enable_rag = 6;
|
||||||
|
bool enable_reranker = 7;
|
||||||
|
bool enable_streaming = 8;
|
||||||
|
int32 top_k = 9;
|
||||||
|
string collection = 10;
|
||||||
|
bool enable_tts = 11;
|
||||||
|
string system_prompt = 12;
|
||||||
|
string response_subject = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChatResponse {
|
||||||
|
string user_id = 1;
|
||||||
|
string response = 2;
|
||||||
|
string response_text = 3;
|
||||||
|
bool used_rag = 4;
|
||||||
|
repeated string rag_sources = 5;
|
||||||
|
bool success = 6;
|
||||||
|
bytes audio = 7;
|
||||||
|
string error = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ChatStreamChunk {
|
||||||
|
string request_id = 1;
|
||||||
|
string type = 2;
|
||||||
|
string content = 3;
|
||||||
|
bool done = 4;
|
||||||
|
int64 timestamp = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Voice Assistant ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message VoiceRequest {
|
||||||
|
string request_id = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
string language = 3;
|
||||||
|
string collection = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message VoiceResponse {
|
||||||
|
string request_id = 1;
|
||||||
|
string response = 2;
|
||||||
|
bytes audio = 3;
|
||||||
|
string transcription = 4;
|
||||||
|
repeated DocumentSource sources = 5;
|
||||||
|
string error = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message DocumentSource {
|
||||||
|
string text = 1;
|
||||||
|
double score = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── TTS Module ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message TTSRequest {
|
||||||
|
string text = 1;
|
||||||
|
string speaker = 2;
|
||||||
|
string language = 3;
|
||||||
|
string speaker_wav_b64 = 4;
|
||||||
|
bool stream = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSAudioChunk {
|
||||||
|
string session_id = 1;
|
||||||
|
int32 chunk_index = 2;
|
||||||
|
int32 total_chunks = 3;
|
||||||
|
bytes audio = 4;
|
||||||
|
bool is_last = 5;
|
||||||
|
int64 timestamp = 6;
|
||||||
|
int32 sample_rate = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSFullResponse {
|
||||||
|
string session_id = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
int32 sample_rate = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSStatus {
|
||||||
|
string session_id = 1;
|
||||||
|
string status = 2;
|
||||||
|
string message = 3;
|
||||||
|
int64 timestamp = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSVoiceInfo {
|
||||||
|
string name = 1;
|
||||||
|
string language = 2;
|
||||||
|
string model_type = 3;
|
||||||
|
string created_at = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSVoiceListResponse {
|
||||||
|
string default_speaker = 1;
|
||||||
|
repeated TTSVoiceInfo custom_voices = 2;
|
||||||
|
int64 last_refresh = 3;
|
||||||
|
int64 timestamp = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TTSVoiceRefreshResponse {
|
||||||
|
int32 count = 1;
|
||||||
|
repeated TTSVoiceInfo custom_voices = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── STT Module ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message STTStreamMessage {
|
||||||
|
string type = 1;
|
||||||
|
bytes audio = 2;
|
||||||
|
string state = 3;
|
||||||
|
string speaker_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message STTTranscription {
|
||||||
|
string session_id = 1;
|
||||||
|
string transcript = 2;
|
||||||
|
int32 sequence = 3;
|
||||||
|
bool is_partial = 4;
|
||||||
|
bool is_final = 5;
|
||||||
|
int64 timestamp = 6;
|
||||||
|
string speaker_id = 7;
|
||||||
|
bool has_voice_activity = 8;
|
||||||
|
string state = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
message STTInterrupt {
|
||||||
|
string session_id = 1;
|
||||||
|
string type = 2;
|
||||||
|
int64 timestamp = 3;
|
||||||
|
string speaker_id = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Common ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
message ErrorResponse {
|
||||||
|
bool error = 1;
|
||||||
|
string message = 2;
|
||||||
|
string type = 3;
|
||||||
|
}
|
||||||
142
natsutil/natsutil.go
Normal file
142
natsutil/natsutil.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
// Package natsutil provides a NATS/JetStream client with msgpack serialization.
|
||||||
|
package natsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client wraps a NATS connection with msgpack helpers.
|
||||||
|
type Client struct {
|
||||||
|
nc *nats.Conn
|
||||||
|
js nats.JetStreamContext
|
||||||
|
subs []*nats.Subscription
|
||||||
|
url string
|
||||||
|
opts []nats.Option
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a NATS client configured to connect to the given URL.
|
||||||
|
// Optional NATS options (e.g. credentials) can be appended.
|
||||||
|
func New(url string, opts ...nats.Option) *Client {
|
||||||
|
defaults := []nats.Option{
|
||||||
|
nats.ReconnectWait(2 * time.Second),
|
||||||
|
nats.MaxReconnects(-1),
|
||||||
|
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
|
||||||
|
slog.Warn("NATS disconnected", "error", err)
|
||||||
|
}),
|
||||||
|
nats.ReconnectHandler(func(_ *nats.Conn) {
|
||||||
|
slog.Info("NATS reconnected")
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
url: url,
|
||||||
|
opts: append(defaults, opts...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes the NATS connection and JetStream context.
|
||||||
|
func (c *Client) Connect() error {
|
||||||
|
nc, err := nats.Connect(c.url, c.opts...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nats connect: %w", err)
|
||||||
|
}
|
||||||
|
js, err := nc.JetStream()
|
||||||
|
if err != nil {
|
||||||
|
nc.Close()
|
||||||
|
return fmt.Errorf("jetstream: %w", err)
|
||||||
|
}
|
||||||
|
c.nc = nc
|
||||||
|
c.js = js
|
||||||
|
slog.Info("connected to NATS", "url", c.url)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close drains subscriptions and closes the connection.
|
||||||
|
func (c *Client) Close() {
|
||||||
|
if c.nc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, sub := range c.subs {
|
||||||
|
_ = sub.Drain()
|
||||||
|
}
|
||||||
|
c.nc.Close()
|
||||||
|
slog.Info("NATS connection closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn returns the underlying *nats.Conn.
|
||||||
|
func (c *Client) Conn() *nats.Conn { return c.nc }
|
||||||
|
|
||||||
|
// JS returns the JetStream context.
|
||||||
|
func (c *Client) JS() nats.JetStreamContext { return c.js }
|
||||||
|
|
||||||
|
// IsConnected returns true if the NATS connection is active.
|
||||||
|
func (c *Client) IsConnected() bool {
|
||||||
|
return c.nc != nil && c.nc.IsConnected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to a subject with an optional queue group.
|
||||||
|
// The handler receives the raw *nats.Msg.
|
||||||
|
func (c *Client) Subscribe(subject string, handler nats.MsgHandler, queue string) error {
|
||||||
|
var sub *nats.Subscription
|
||||||
|
var err error
|
||||||
|
if queue != "" {
|
||||||
|
sub, err = c.nc.QueueSubscribe(subject, queue, handler)
|
||||||
|
slog.Info("subscribed", "subject", subject, "queue", queue)
|
||||||
|
} else {
|
||||||
|
sub, err = c.nc.Subscribe(subject, handler)
|
||||||
|
slog.Info("subscribed", "subject", subject)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("subscribe %s: %w", subject, err)
|
||||||
|
}
|
||||||
|
c.subs = append(c.subs, sub)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish encodes data as msgpack and publishes to the subject.
|
||||||
|
func (c *Client) Publish(subject string, data any) error {
|
||||||
|
payload, err := msgpack.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("msgpack marshal: %w", err)
|
||||||
|
}
|
||||||
|
return c.nc.Publish(subject, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request sends a msgpack-encoded request and decodes the response into result.
|
||||||
|
func (c *Client) Request(subject string, data any, result any, timeout time.Duration) error {
|
||||||
|
payload, err := msgpack.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("msgpack marshal: %w", err)
|
||||||
|
}
|
||||||
|
msg, err := c.nc.Request(subject, payload, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("nats request: %w", err)
|
||||||
|
}
|
||||||
|
return msgpack.Unmarshal(msg.Data, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeMsgpack decodes msgpack-encoded NATS message data into dest.
|
||||||
|
func DecodeMsgpack(msg *nats.Msg, dest any) error {
|
||||||
|
return msgpack.Unmarshal(msg.Data, dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode is a generic helper that unmarshals msgpack bytes into T.
|
||||||
|
// Usage: req, err := natsutil.Decode[messages.ChatRequest](msg.Data)
|
||||||
|
func Decode[T any](data []byte) (T, error) {
|
||||||
|
var v T
|
||||||
|
err := msgpack.Unmarshal(data, &v)
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeMsgpackMap decodes msgpack data into a generic map.
|
||||||
|
func DecodeMsgpackMap(data []byte) (map[string]any, error) {
|
||||||
|
var m map[string]any
|
||||||
|
if err := msgpack.Unmarshal(data, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
256
natsutil/natsutil_test.go
Normal file
256
natsutil/natsutil_test.go
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
package natsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// DecodeMsgpackMap tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestDecodeMsgpackMap_Roundtrip(t *testing.T) {
|
||||||
|
orig := map[string]any{
|
||||||
|
"request_id": "req-001",
|
||||||
|
"user_id": "user-42",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": int64(10), // msgpack decodes ints as int64
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := DecodeMsgpackMap(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded["request_id"] != "req-001" {
|
||||||
|
t.Errorf("request_id = %v", decoded["request_id"])
|
||||||
|
}
|
||||||
|
if decoded["premium"] != true {
|
||||||
|
t.Errorf("premium = %v", decoded["premium"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeMsgpackMap_Empty(t *testing.T) {
|
||||||
|
data, _ := msgpack.Marshal(map[string]any{})
|
||||||
|
m, err := DecodeMsgpackMap(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(m) != 0 {
|
||||||
|
t.Errorf("expected empty map, got %v", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeMsgpackMap_InvalidData(t *testing.T) {
|
||||||
|
_, err := DecodeMsgpackMap([]byte{0xFF, 0xFE})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid msgpack data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// DecodeMsgpack (typed struct) tests
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type testMessage struct {
|
||||||
|
RequestID string `msgpack:"request_id"`
|
||||||
|
UserID string `msgpack:"user_id"`
|
||||||
|
Count int `msgpack:"count"`
|
||||||
|
Active bool `msgpack:"active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeMsgpackTyped_Roundtrip(t *testing.T) {
|
||||||
|
orig := testMessage{
|
||||||
|
RequestID: "req-typed-001",
|
||||||
|
UserID: "user-7",
|
||||||
|
Count: 42,
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate nats.Msg data decoding.
|
||||||
|
var decoded testMessage
|
||||||
|
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded.RequestID != orig.RequestID {
|
||||||
|
t.Errorf("RequestID = %q, want %q", decoded.RequestID, orig.RequestID)
|
||||||
|
}
|
||||||
|
if decoded.Count != orig.Count {
|
||||||
|
t.Errorf("Count = %d, want %d", decoded.Count, orig.Count)
|
||||||
|
}
|
||||||
|
if decoded.Active != orig.Active {
|
||||||
|
t.Errorf("Active = %v, want %v", decoded.Active, orig.Active)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTypedStructDecodesMapEncoding verifies that a typed struct can be
|
||||||
|
// decoded from data that was encoded as map[string]any (backwards compat).
|
||||||
|
func TestTypedStructDecodesMapEncoding(t *testing.T) {
|
||||||
|
// Encode as map (the old way).
|
||||||
|
mapData := map[string]any{
|
||||||
|
"request_id": "req-compat",
|
||||||
|
"user_id": "user-compat",
|
||||||
|
"count": int64(99),
|
||||||
|
"active": false,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(mapData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode into typed struct (the new way).
|
||||||
|
var msg testMessage
|
||||||
|
if err := msgpack.Unmarshal(data, &msg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.RequestID != "req-compat" {
|
||||||
|
t.Errorf("RequestID = %q", msg.RequestID)
|
||||||
|
}
|
||||||
|
if msg.Count != 99 {
|
||||||
|
t.Errorf("Count = %d, want 99", msg.Count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Binary data tests (audio []byte in msgpack)
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
type audioMessage struct {
|
||||||
|
SessionID string `msgpack:"session_id"`
|
||||||
|
Audio []byte `msgpack:"audio"`
|
||||||
|
SampleRate int `msgpack:"sample_rate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBinaryDataRoundtrip(t *testing.T) {
|
||||||
|
audio := make([]byte, 32768)
|
||||||
|
for i := range audio {
|
||||||
|
audio[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
orig := audioMessage{
|
||||||
|
SessionID: "sess-audio-001",
|
||||||
|
Audio: audio,
|
||||||
|
SampleRate: 24000,
|
||||||
|
}
|
||||||
|
data, err := msgpack.Marshal(orig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded audioMessage
|
||||||
|
if err := msgpack.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(decoded.Audio) != len(orig.Audio) {
|
||||||
|
t.Fatalf("audio len = %d, want %d", len(decoded.Audio), len(orig.Audio))
|
||||||
|
}
|
||||||
|
for i := range decoded.Audio {
|
||||||
|
if decoded.Audio[i] != orig.Audio[i] {
|
||||||
|
t.Fatalf("audio[%d] = %d, want %d", i, decoded.Audio[i], orig.Audio[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBinaryVsBase64Size shows the wire-size win of raw bytes vs base64 string.
|
||||||
|
func TestBinaryVsBase64Size(t *testing.T) {
|
||||||
|
audio := make([]byte, 16384)
|
||||||
|
|
||||||
|
// Old approach: base64 string in map.
|
||||||
|
import_b64 := make([]byte, (len(audio)*4+2)/3) // approximate base64 size
|
||||||
|
mapMsg := map[string]any{
|
||||||
|
"session_id": "sess-1",
|
||||||
|
"audio_b64": string(import_b64),
|
||||||
|
}
|
||||||
|
mapData, _ := msgpack.Marshal(mapMsg)
|
||||||
|
|
||||||
|
// New approach: raw bytes in struct.
|
||||||
|
structMsg := audioMessage{
|
||||||
|
SessionID: "sess-1",
|
||||||
|
Audio: audio,
|
||||||
|
}
|
||||||
|
structData, _ := msgpack.Marshal(structMsg)
|
||||||
|
|
||||||
|
t.Logf("base64-in-map: %d bytes, raw-bytes-in-struct: %d bytes (%.0f%% smaller)",
|
||||||
|
len(mapData), len(structData),
|
||||||
|
100*(1-float64(len(structData))/float64(len(mapData))))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Benchmarks
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func BenchmarkEncodeMap(b *testing.B) {
|
||||||
|
data := map[string]any{
|
||||||
|
"request_id": "req-bench",
|
||||||
|
"user_id": "user-bench",
|
||||||
|
"message": "What is the weather today?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
}
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncodeStruct(b *testing.B) {
|
||||||
|
data := testMessage{
|
||||||
|
RequestID: "req-bench",
|
||||||
|
UserID: "user-bench",
|
||||||
|
Count: 10,
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
for b.Loop() {
|
||||||
|
_, _ = msgpack.Marshal(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeMap(b *testing.B) {
|
||||||
|
raw, _ := msgpack.Marshal(map[string]any{
|
||||||
|
"request_id": "req-bench",
|
||||||
|
"user_id": "user-bench",
|
||||||
|
"message": "What is the weather today?",
|
||||||
|
"premium": true,
|
||||||
|
"top_k": 10,
|
||||||
|
})
|
||||||
|
for b.Loop() {
|
||||||
|
var m map[string]any
|
||||||
|
_ = msgpack.Unmarshal(raw, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeStruct(b *testing.B) {
|
||||||
|
raw, _ := msgpack.Marshal(testMessage{
|
||||||
|
RequestID: "req-bench",
|
||||||
|
UserID: "user-bench",
|
||||||
|
Count: 10,
|
||||||
|
Active: true,
|
||||||
|
})
|
||||||
|
for b.Loop() {
|
||||||
|
var m testMessage
|
||||||
|
_ = msgpack.Unmarshal(raw, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDecodeAudio32KB(b *testing.B) {
|
||||||
|
raw, _ := msgpack.Marshal(audioMessage{
|
||||||
|
SessionID: "s1",
|
||||||
|
Audio: make([]byte, 32768),
|
||||||
|
SampleRate: 24000,
|
||||||
|
})
|
||||||
|
for b.Loop() {
|
||||||
|
var m audioMessage
|
||||||
|
_ = msgpack.Unmarshal(raw, &m)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "handler-base"
|
|
||||||
version = "1.0.0"
|
|
||||||
description = "Shared base library for AI/ML handler services"
|
|
||||||
readme = "README.md"
|
|
||||||
requires-python = ">=3.11"
|
|
||||||
license = { text = "MIT" }
|
|
||||||
authors = [{ name = "Davies Tech Labs" }]
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
# Async & messaging
|
|
||||||
"nats-py>=2.7.0",
|
|
||||||
"httpx>=0.27.0",
|
|
||||||
"msgpack>=1.0.0",
|
|
||||||
|
|
||||||
# Data stores
|
|
||||||
"pymilvus>=2.4.0",
|
|
||||||
"redis>=5.0.0",
|
|
||||||
|
|
||||||
# Observability
|
|
||||||
"opentelemetry-api>=1.20.0",
|
|
||||||
"opentelemetry-sdk>=1.20.0",
|
|
||||||
"opentelemetry-exporter-otlp-proto-grpc>=1.20.0",
|
|
||||||
"opentelemetry-exporter-otlp-proto-http>=1.20.0",
|
|
||||||
"opentelemetry-instrumentation-httpx>=0.44b0",
|
|
||||||
"opentelemetry-instrumentation-logging>=0.44b0",
|
|
||||||
|
|
||||||
# MLflow
|
|
||||||
"mlflow>=2.10.0",
|
|
||||||
"psycopg2-binary>=2.9.0",
|
|
||||||
|
|
||||||
# Utilities
|
|
||||||
"numpy>=1.26.0",
|
|
||||||
"pydantic>=2.5.0",
|
|
||||||
"pydantic-settings>=2.1.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
audio = [
|
|
||||||
"soundfile>=0.12.0",
|
|
||||||
"librosa>=0.10.0",
|
|
||||||
"webrtcvad>=2.0.10",
|
|
||||||
]
|
|
||||||
dev = [
|
|
||||||
"pytest>=8.0.0",
|
|
||||||
"pytest-asyncio>=0.23.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
"ruff>=0.4.0",
|
|
||||||
"pre-commit>=3.7.0",
|
|
||||||
]
|
|
||||||
ray = [
|
|
||||||
"ray[serve]>=2.9.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["hatchling"]
|
|
||||||
build-backend = "hatchling.build"
|
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
|
||||||
packages = ["handler_base"]
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 100
|
|
||||||
target-version = "py311"
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = ["E", "F", "I", "W"]
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
asyncio_mode = "auto"
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
[pytest]
|
|
||||||
testpaths = tests
|
|
||||||
python_files = test_*.py
|
|
||||||
python_classes = Test*
|
|
||||||
python_functions = test_*
|
|
||||||
asyncio_mode = auto
|
|
||||||
addopts = -v --tb=short
|
|
||||||
122
telemetry/telemetry.go
Normal file
122
telemetry/telemetry.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
// Package telemetry provides OpenTelemetry tracing and metrics setup.
|
||||||
|
package telemetry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
|
||||||
|
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
"go.opentelemetry.io/otel/propagation"
|
||||||
|
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||||
|
"go.opentelemetry.io/otel/sdk/resource"
|
||||||
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||||
|
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds the telemetry configuration parameters.
|
||||||
|
type Config struct {
|
||||||
|
ServiceName string
|
||||||
|
ServiceVersion string
|
||||||
|
ServiceNamespace string
|
||||||
|
DeploymentEnv string
|
||||||
|
Enabled bool
|
||||||
|
Endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider holds the initialized tracer and meter providers.
|
||||||
|
type Provider struct {
|
||||||
|
TracerProvider *sdktrace.TracerProvider
|
||||||
|
MeterProvider *sdkmetric.MeterProvider
|
||||||
|
Tracer trace.Tracer
|
||||||
|
Meter metric.Meter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup initialises OpenTelemetry tracing and metrics.
|
||||||
|
// Returns a Provider and a shutdown function.
|
||||||
|
func Setup(ctx context.Context, cfg Config) (*Provider, func(context.Context), error) {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
slog.Info("OpenTelemetry disabled")
|
||||||
|
return &Provider{
|
||||||
|
Tracer: otel.Tracer(cfg.ServiceName),
|
||||||
|
Meter: otel.Meter(cfg.ServiceName),
|
||||||
|
}, func(context.Context) {}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname, _ := os.Hostname()
|
||||||
|
res, err := resource.New(ctx,
|
||||||
|
resource.WithAttributes(
|
||||||
|
semconv.ServiceNameKey.String(cfg.ServiceName),
|
||||||
|
semconv.ServiceVersionKey.String(cfg.ServiceVersion),
|
||||||
|
semconv.ServiceNamespaceKey.String(cfg.ServiceNamespace),
|
||||||
|
attribute.String("deployment.environment", cfg.DeploymentEnv),
|
||||||
|
attribute.String("host.name", hostname),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace exporter (gRPC)
|
||||||
|
traceExp, err := otlptracegrpc.New(ctx,
|
||||||
|
otlptracegrpc.WithEndpoint(stripScheme(cfg.Endpoint)),
|
||||||
|
otlptracegrpc.WithInsecure(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tp := sdktrace.NewTracerProvider(
|
||||||
|
sdktrace.WithBatcher(traceExp),
|
||||||
|
sdktrace.WithResource(res),
|
||||||
|
)
|
||||||
|
otel.SetTracerProvider(tp)
|
||||||
|
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
|
||||||
|
propagation.TraceContext{},
|
||||||
|
propagation.Baggage{},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Metric exporter (gRPC)
|
||||||
|
metricExp, err := otlpmetricgrpc.New(ctx,
|
||||||
|
otlpmetricgrpc.WithEndpoint(stripScheme(cfg.Endpoint)),
|
||||||
|
otlpmetricgrpc.WithInsecure(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := sdkmetric.NewMeterProvider(
|
||||||
|
sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExp)),
|
||||||
|
sdkmetric.WithResource(res),
|
||||||
|
)
|
||||||
|
otel.SetMeterProvider(mp)
|
||||||
|
|
||||||
|
slog.Info("OpenTelemetry initialized", "service", cfg.ServiceName, "endpoint", cfg.Endpoint)
|
||||||
|
|
||||||
|
shutdown := func(ctx context.Context) {
|
||||||
|
_ = tp.Shutdown(ctx)
|
||||||
|
_ = mp.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Provider{
|
||||||
|
TracerProvider: tp,
|
||||||
|
MeterProvider: mp,
|
||||||
|
Tracer: tp.Tracer(cfg.ServiceName, trace.WithInstrumentationVersion(cfg.ServiceVersion)),
|
||||||
|
Meter: mp.Meter(cfg.ServiceName),
|
||||||
|
}, shutdown, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripScheme removes http:// or https:// from an endpoint for gRPC dialers.
|
||||||
|
func stripScheme(endpoint string) string {
|
||||||
|
for _, prefix := range []string{"https://", "http://"} {
|
||||||
|
if len(endpoint) > len(prefix) && endpoint[:len(prefix)] == prefix {
|
||||||
|
return endpoint[len(prefix):]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
"""
|
|
||||||
Pytest configuration and fixtures.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# Set test environment variables before importing handler_base
|
|
||||||
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
|
|
||||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
|
|
||||||
os.environ.setdefault("MILVUS_HOST", "localhost")
|
|
||||||
os.environ.setdefault("OTEL_ENABLED", "false")
|
|
||||||
os.environ.setdefault("MLFLOW_ENABLED", "false")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def event_loop():
|
|
||||||
"""Create event loop for async tests."""
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
yield loop
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def settings():
|
|
||||||
"""Create test settings."""
|
|
||||||
from handler_base.config import Settings
|
|
||||||
|
|
||||||
return Settings(
|
|
||||||
service_name="test-service",
|
|
||||||
service_version="1.0.0-test",
|
|
||||||
otel_enabled=False,
|
|
||||||
mlflow_enabled=False,
|
|
||||||
nats_url="nats://localhost:4222",
|
|
||||||
redis_url="redis://localhost:6379",
|
|
||||||
milvus_host="localhost",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_httpx_client():
|
|
||||||
"""Create a mock httpx AsyncClient."""
|
|
||||||
client = AsyncMock()
|
|
||||||
client.post = AsyncMock()
|
|
||||||
client.get = AsyncMock()
|
|
||||||
client.aclose = AsyncMock()
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_nats_message():
|
|
||||||
"""Create a mock NATS message."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.subject = "test.subject"
|
|
||||||
msg.reply = "test.reply"
|
|
||||||
msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
|
|
||||||
return msg
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_embedding():
|
|
||||||
"""Sample embedding vector."""
|
|
||||||
return [0.1] * 1024
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_documents():
|
|
||||||
"""Sample documents for testing."""
|
|
||||||
return [
|
|
||||||
{"text": "Python is a programming language.", "source": "doc1"},
|
|
||||||
{"text": "Machine learning is a subset of AI.", "source": "doc2"},
|
|
||||||
{"text": "Deep learning uses neural networks.", "source": "doc3"},
|
|
||||||
]
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# Unit tests package
|
|
||||||
@@ -1,150 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for service clients.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingsClient:
|
|
||||||
"""Tests for EmbeddingsClient."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def embeddings_client(self, mock_httpx_client):
|
|
||||||
"""Create an EmbeddingsClient with mocked HTTP."""
|
|
||||||
from handler_base.clients.embeddings import EmbeddingsClient
|
|
||||||
|
|
||||||
client = EmbeddingsClient()
|
|
||||||
client._client = mock_httpx_client
|
|
||||||
return client
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
|
|
||||||
"""Test embedding a single text."""
|
|
||||||
# Setup mock response
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
|
||||||
mock_httpx_client.post.return_value = mock_response
|
|
||||||
|
|
||||||
result = await embeddings_client.embed_single("Hello world")
|
|
||||||
|
|
||||||
assert result == sample_embedding
|
|
||||||
mock_httpx_client.post.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
|
|
||||||
"""Test embedding multiple texts."""
|
|
||||||
texts = ["Hello", "World"]
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"data": [
|
|
||||||
{"embedding": sample_embedding, "index": 0},
|
|
||||||
{"embedding": sample_embedding, "index": 1},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
|
||||||
mock_httpx_client.post.return_value = mock_response
|
|
||||||
|
|
||||||
result = await embeddings_client.embed(texts)
|
|
||||||
|
|
||||||
assert len(result) == 2
|
|
||||||
assert all(len(e) == len(sample_embedding) for e in result)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_health_check(self, embeddings_client, mock_httpx_client):
|
|
||||||
"""Test health check."""
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_httpx_client.get.return_value = mock_response
|
|
||||||
|
|
||||||
result = await embeddings_client.health()
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestRerankerClient:
|
|
||||||
"""Tests for RerankerClient."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def reranker_client(self, mock_httpx_client):
|
|
||||||
"""Create a RerankerClient with mocked HTTP."""
|
|
||||||
from handler_base.clients.reranker import RerankerClient
|
|
||||||
|
|
||||||
client = RerankerClient()
|
|
||||||
client._client = mock_httpx_client
|
|
||||||
return client
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
|
|
||||||
"""Test reranking documents."""
|
|
||||||
texts = [d["text"] for d in sample_documents]
|
|
||||||
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"results": [
|
|
||||||
{"index": 1, "relevance_score": 0.95},
|
|
||||||
{"index": 0, "relevance_score": 0.80},
|
|
||||||
{"index": 2, "relevance_score": 0.65},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
|
||||||
mock_httpx_client.post.return_value = mock_response
|
|
||||||
|
|
||||||
result = await reranker_client.rerank("What is ML?", texts)
|
|
||||||
|
|
||||||
assert len(result) == 3
|
|
||||||
assert result[0]["score"] == 0.95
|
|
||||||
assert result[0]["index"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMClient:
|
|
||||||
"""Tests for LLMClient."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def llm_client(self, mock_httpx_client):
|
|
||||||
"""Create an LLMClient with mocked HTTP."""
|
|
||||||
from handler_base.clients.llm import LLMClient
|
|
||||||
|
|
||||||
client = LLMClient()
|
|
||||||
client._client = mock_httpx_client
|
|
||||||
return client
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_generate(self, llm_client, mock_httpx_client):
|
|
||||||
"""Test generating a response."""
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
|
|
||||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
|
||||||
mock_httpx_client.post.return_value = mock_response
|
|
||||||
|
|
||||||
result = await llm_client.generate("Hello")
|
|
||||||
|
|
||||||
assert result == "Hello! I'm an AI assistant."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_generate_with_context(self, llm_client, mock_httpx_client):
|
|
||||||
"""Test generating with RAG context."""
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"choices": [{"message": {"content": "Based on the context..."}}],
|
|
||||||
"usage": {},
|
|
||||||
}
|
|
||||||
mock_response.raise_for_status = MagicMock()
|
|
||||||
mock_httpx_client.post.return_value = mock_response
|
|
||||||
|
|
||||||
result = await llm_client.generate(
|
|
||||||
"What is Python?", context="Python is a programming language."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "Based on the context" in result
|
|
||||||
|
|
||||||
# Verify context was included in the request
|
|
||||||
call_args = mock_httpx_client.post.call_args
|
|
||||||
messages = call_args.kwargs["json"]["messages"]
|
|
||||||
assert any("Context:" in m["content"] for m in messages if m["role"] == "user")
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for handler_base.config module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class TestSettings:
|
|
||||||
"""Tests for Settings configuration."""
|
|
||||||
|
|
||||||
def test_default_settings(self, settings):
|
|
||||||
"""Test that default settings are loaded correctly."""
|
|
||||||
assert settings.service_name == "test-service"
|
|
||||||
assert settings.service_version == "1.0.0-test"
|
|
||||||
assert settings.otel_enabled is False
|
|
||||||
|
|
||||||
def test_settings_from_env(self, monkeypatch):
|
|
||||||
"""Test that settings can be loaded from environment variables."""
|
|
||||||
monkeypatch.setenv("SERVICE_NAME", "env-service")
|
|
||||||
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
|
|
||||||
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
|
|
||||||
|
|
||||||
# Need to reimport to pick up env changes
|
|
||||||
from handler_base.config import Settings
|
|
||||||
|
|
||||||
s = Settings()
|
|
||||||
|
|
||||||
assert s.service_name == "env-service"
|
|
||||||
assert s.service_version == "2.0.0"
|
|
||||||
assert s.nats_url == "nats://custom:4222"
|
|
||||||
|
|
||||||
def test_embeddings_settings(self):
|
|
||||||
"""Test EmbeddingsSettings extends base correctly."""
|
|
||||||
from handler_base.config import EmbeddingsSettings
|
|
||||||
|
|
||||||
s = EmbeddingsSettings()
|
|
||||||
assert hasattr(s, "embeddings_model")
|
|
||||||
assert hasattr(s, "embeddings_batch_size")
|
|
||||||
assert s.embeddings_model == "bge"
|
|
||||||
|
|
||||||
def test_llm_settings(self):
|
|
||||||
"""Test LLMSettings has expected defaults."""
|
|
||||||
from handler_base.config import LLMSettings
|
|
||||||
|
|
||||||
s = LLMSettings()
|
|
||||||
assert s.llm_max_tokens == 2048
|
|
||||||
assert s.llm_temperature == 0.7
|
|
||||||
assert 0 <= s.llm_top_p <= 1
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for handler_base.health module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from http.client import HTTPConnection
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
class TestHealthServer:
|
|
||||||
"""Tests for HealthServer."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def health_server(self, settings):
|
|
||||||
"""Create a HealthServer instance."""
|
|
||||||
from handler_base.health import HealthServer
|
|
||||||
|
|
||||||
# Use a random high port to avoid conflicts
|
|
||||||
settings.health_port = 18080
|
|
||||||
return HealthServer(settings)
|
|
||||||
|
|
||||||
def test_start_stop(self, health_server):
|
|
||||||
"""Test starting and stopping the health server."""
|
|
||||||
health_server.start()
|
|
||||||
time.sleep(0.1) # Give server time to start
|
|
||||||
|
|
||||||
# Verify server is running
|
|
||||||
assert health_server._server is not None
|
|
||||||
assert health_server._thread is not None
|
|
||||||
assert health_server._thread.is_alive()
|
|
||||||
|
|
||||||
health_server.stop()
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
assert health_server._server is None
|
|
||||||
|
|
||||||
def test_health_endpoint(self, health_server):
|
|
||||||
"""Test the /health endpoint."""
|
|
||||||
health_server.start()
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
|
||||||
conn.request("GET", "/health")
|
|
||||||
response = conn.getresponse()
|
|
||||||
|
|
||||||
assert response.status == 200
|
|
||||||
data = json.loads(response.read().decode())
|
|
||||||
assert data["status"] == "healthy"
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
health_server.stop()
|
|
||||||
|
|
||||||
def test_ready_endpoint_default(self, health_server):
|
|
||||||
"""Test the /ready endpoint with no custom check."""
|
|
||||||
health_server.start()
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
|
||||||
conn.request("GET", "/ready")
|
|
||||||
response = conn.getresponse()
|
|
||||||
|
|
||||||
assert response.status == 200
|
|
||||||
data = json.loads(response.read().decode())
|
|
||||||
assert data["status"] == "ready"
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
health_server.stop()
|
|
||||||
|
|
||||||
def test_ready_endpoint_with_check(self, settings):
|
|
||||||
"""Test /ready endpoint with custom readiness check."""
|
|
||||||
from handler_base.health import HealthServer
|
|
||||||
|
|
||||||
ready_flag = [False] # Use list to allow mutation in closure
|
|
||||||
|
|
||||||
async def check_ready():
|
|
||||||
return ready_flag[0]
|
|
||||||
|
|
||||||
settings.health_port = 18081
|
|
||||||
server = HealthServer(settings, ready_check=check_ready)
|
|
||||||
server.start()
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn = HTTPConnection("localhost", 18081, timeout=5)
|
|
||||||
|
|
||||||
# Should be not ready initially
|
|
||||||
conn.request("GET", "/ready")
|
|
||||||
response = conn.getresponse()
|
|
||||||
response.read() # Consume response body
|
|
||||||
assert response.status == 503
|
|
||||||
|
|
||||||
# Mark as ready
|
|
||||||
ready_flag[0] = True
|
|
||||||
|
|
||||||
# Need new connection after consuming response
|
|
||||||
conn.close()
|
|
||||||
conn = HTTPConnection("localhost", 18081, timeout=5)
|
|
||||||
conn.request("GET", "/ready")
|
|
||||||
response = conn.getresponse()
|
|
||||||
assert response.status == 200
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
def test_404_for_unknown_path(self, health_server):
|
|
||||||
"""Test that unknown paths return 404."""
|
|
||||||
health_server.start()
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
|
||||||
conn.request("GET", "/unknown")
|
|
||||||
response = conn.getresponse()
|
|
||||||
|
|
||||||
assert response.status == 404
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
health_server.stop()
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for handler_base.nats_client module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import msgpack
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
class TestNATSClient:
|
|
||||||
"""Tests for NATSClient."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def nats_client(self, settings):
|
|
||||||
"""Create a NATSClient instance."""
|
|
||||||
from handler_base.nats_client import NATSClient
|
|
||||||
|
|
||||||
return NATSClient(settings)
|
|
||||||
|
|
||||||
def test_init(self, nats_client, settings):
|
|
||||||
"""Test NATSClient initialization."""
|
|
||||||
assert nats_client.settings == settings
|
|
||||||
assert nats_client._nc is None
|
|
||||||
assert nats_client._js is None
|
|
||||||
|
|
||||||
def test_decode_msgpack(self, nats_client):
|
|
||||||
"""Test msgpack decoding."""
|
|
||||||
data = {"query": "hello", "request_id": "123"}
|
|
||||||
encoded = msgpack.packb(data, use_bin_type=True)
|
|
||||||
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.data = encoded
|
|
||||||
|
|
||||||
result = nats_client.decode_msgpack(msg)
|
|
||||||
assert result == data
|
|
||||||
|
|
||||||
def test_decode_json(self, nats_client):
|
|
||||||
"""Test JSON decoding."""
|
|
||||||
import json
|
|
||||||
|
|
||||||
data = {"query": "hello"}
|
|
||||||
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.data = json.dumps(data).encode()
|
|
||||||
|
|
||||||
result = nats_client.decode_json(msg)
|
|
||||||
assert result == data
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_connect(self, nats_client):
|
|
||||||
"""Test NATS connection."""
|
|
||||||
with patch("handler_base.nats_client.nats") as mock_nats:
|
|
||||||
mock_nc = AsyncMock()
|
|
||||||
mock_js = MagicMock()
|
|
||||||
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
|
|
||||||
mock_nats.connect = AsyncMock(return_value=mock_nc)
|
|
||||||
|
|
||||||
await nats_client.connect()
|
|
||||||
|
|
||||||
assert nats_client._nc == mock_nc
|
|
||||||
assert nats_client._js == mock_js
|
|
||||||
mock_nats.connect.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_publish(self, nats_client):
|
|
||||||
"""Test publishing a message."""
|
|
||||||
mock_nc = AsyncMock()
|
|
||||||
nats_client._nc = mock_nc
|
|
||||||
|
|
||||||
data = {"key": "value"}
|
|
||||||
await nats_client.publish("test.subject", data)
|
|
||||||
|
|
||||||
mock_nc.publish.assert_called_once()
|
|
||||||
call_args = mock_nc.publish.call_args
|
|
||||||
assert call_args.args[0] == "test.subject"
|
|
||||||
|
|
||||||
# Verify msgpack encoding
|
|
||||||
decoded = msgpack.unpackb(call_args.args[1], raw=False)
|
|
||||||
assert decoded == data
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_subscribe(self, nats_client):
|
|
||||||
"""Test subscribing to a subject."""
|
|
||||||
mock_nc = AsyncMock()
|
|
||||||
mock_sub = MagicMock()
|
|
||||||
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
|
|
||||||
nats_client._nc = mock_nc
|
|
||||||
|
|
||||||
handler = AsyncMock()
|
|
||||||
await nats_client.subscribe("test.subject", handler, queue="test-queue")
|
|
||||||
|
|
||||||
mock_nc.subscribe.assert_called_once()
|
|
||||||
call_kwargs = mock_nc.subscribe.call_args.kwargs
|
|
||||||
assert call_kwargs["queue"] == "test-queue"
|
|
||||||
Reference in New Issue
Block a user