feat: Add handler-base library for NATS AI/ML services

- Handler base class with graceful shutdown and signal handling
- NATSClient with JetStream and msgpack serialization
- Pydantic Settings for environment configuration
- HealthServer for Kubernetes probes
- OpenTelemetry telemetry setup
- Service clients: STT, TTS, LLM, Embeddings, Reranker, Milvus
This commit is contained in:
2026-02-01 20:36:00 -05:00
parent 00df482412
commit 99c97b7973
17 changed files with 1932 additions and 1 deletions

27
handler_base/__init__.py Normal file
View File

@@ -0,0 +1,27 @@
"""
Handler Base - Shared utilities for AI/ML handler services.
Provides consistent patterns for:
- OpenTelemetry tracing and metrics
- NATS messaging
- Health checks
- Graceful shutdown
- Service client wrappers
"""
from handler_base.config import Settings
from handler_base.handler import Handler
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import setup_telemetry, get_tracer, get_meter
__all__ = [
"Handler",
"Settings",
"HealthServer",
"NATSClient",
"setup_telemetry",
"get_tracer",
"get_meter",
]
__version__ = "1.0.0"

View File

@@ -0,0 +1,18 @@
"""
Service client wrappers for AI/ML backends.
"""
from handler_base.clients.embeddings import EmbeddingsClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.llm import LLMClient
from handler_base.clients.tts import TTSClient
from handler_base.clients.stt import STTClient
from handler_base.clients.milvus import MilvusClient
__all__ = [
"EmbeddingsClient",
"RerankerClient",
"LLMClient",
"TTSClient",
"STTClient",
"MilvusClient",
]

View File

@@ -0,0 +1,91 @@
"""
Embeddings service client (Infinity/BGE).
"""
import logging
from typing import Optional
import httpx
from handler_base.config import EmbeddingsSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class EmbeddingsClient:
    """
    Async client for the embeddings service (Infinity serving BGE models).

    Usage:
        client = EmbeddingsClient()
        embeddings = await client.embed(["Hello world"])
    """

    def __init__(self, settings: Optional[EmbeddingsSettings] = None):
        self.settings = settings if settings is not None else EmbeddingsSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.embeddings_url,
            timeout=self.settings.http_timeout,
        )

    async def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        await self._client.aclose()

    async def embed(
        self,
        texts: list[str],
        model: Optional[str] = None,
    ) -> list[list[float]]:
        """
        Embed a batch of texts.

        Args:
            texts: Texts to embed
            model: Override for the configured model name

        Returns:
            One embedding vector per input text
        """
        chosen_model = model or self.settings.embeddings_model
        with create_span("embeddings.embed") as span:
            if span:
                span.set_attribute("embeddings.model", chosen_model)
                span.set_attribute("embeddings.batch_size", len(texts))
            resp = await self._client.post(
                "/embeddings",
                json={"input": texts, "model": chosen_model},
            )
            resp.raise_for_status()
            payload = resp.json()
            vectors = [item["embedding"] for item in payload.get("data", [])]
            if span:
                dims = len(vectors[0]) if vectors else 0
                span.set_attribute("embeddings.dimensions", dims)
            return vectors

    async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
        """
        Embed a single text.

        Args:
            text: Text to embed
            model: Override for the configured model name

        Returns:
            The embedding vector, or [] when the service returns no data
        """
        vectors = await self.embed([text], model)
        if not vectors:
            return []
        return vectors[0]

    async def health(self) -> bool:
        """Return True when GET /health answers with HTTP 200."""
        try:
            resp = await self._client.get("/health")
            return resp.status_code == 200
        except Exception:
            return False

192
handler_base/clients/llm.py Normal file
View File

@@ -0,0 +1,192 @@
"""
LLM service client (vLLM/OpenAI-compatible).
"""
import logging
from typing import Optional, AsyncIterator
import httpx
from handler_base.config import LLMSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class LLMClient:
    """
    Client for the LLM service (vLLM with OpenAI-compatible API).

    Usage:
        client = LLMClient()
        response = await client.generate("Hello, how are you?")

        # With context for RAG
        response = await client.generate(
            "What is the capital?",
            context="France is a country in Europe..."
        )

        # Streaming
        async for chunk in client.stream("Tell me a story"):
            print(chunk, end="")
    """

    def __init__(self, settings: Optional[LLMSettings] = None):
        self.settings = settings or LLMSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.llm_url,
            timeout=self.settings.http_timeout,
        )

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    async def generate(
        self,
        prompt: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[list[str]] = None,
    ) -> str:
        """
        Generate a response from the LLM.

        Args:
            prompt: User prompt/query
            context: Optional context for RAG
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate (defaults to settings)
            temperature: Sampling temperature (0.0 is honored; defaults to settings)
            top_p: Top-p sampling (0.0 is honored; defaults to settings)
            stop: Stop sequences

        Returns:
            Generated text response
        """
        with create_span("llm.generate") as span:
            messages = self._build_messages(prompt, context, system_prompt)
            if span:
                span.set_attribute("llm.model", self.settings.llm_model)
                span.set_attribute("llm.prompt_length", len(prompt))
                if context:
                    span.set_attribute("llm.context_length", len(context))
            # Explicit None checks instead of `or`: `or` would silently replace
            # legitimate falsy arguments (temperature=0.0, top_p=0.0,
            # max_tokens=0) with the configured defaults.
            payload = {
                "model": self.settings.llm_model,
                "messages": messages,
                "max_tokens": max_tokens if max_tokens is not None else self.settings.llm_max_tokens,
                "temperature": temperature if temperature is not None else self.settings.llm_temperature,
                "top_p": top_p if top_p is not None else self.settings.llm_top_p,
            }
            if stop:
                payload["stop"] = stop
            response = await self._client.post("/v1/chat/completions", json=payload)
            response.raise_for_status()
            result = response.json()
            content = result["choices"][0]["message"]["content"]
            if span:
                span.set_attribute("llm.response_length", len(content))
                usage = result.get("usage", {})
                span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
                span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
            return content

    async def stream(
        self,
        prompt: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> AsyncIterator[str]:
        """
        Stream a response from the LLM.

        Args:
            prompt: User prompt/query
            context: Optional context for RAG
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate (defaults to settings)
            temperature: Sampling temperature (0.0 is honored; defaults to settings)

        Yields:
            Text chunks as they're generated
        """
        import json  # hoisted out of the per-line loop below

        messages = self._build_messages(prompt, context, system_prompt)
        payload = {
            "model": self.settings.llm_model,
            "messages": messages,
            "max_tokens": max_tokens if max_tokens is not None else self.settings.llm_max_tokens,
            "temperature": temperature if temperature is not None else self.settings.llm_temperature,
            "stream": True,
        }
        async with self._client.stream(
            "POST", "/v1/chat/completions", json=payload
        ) as response:
            response.raise_for_status()
            # Server-sent events: payload lines are prefixed "data: " and the
            # stream ends with a literal "[DONE]" sentinel.
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    chunk = json.loads(data)
                    delta = chunk["choices"][0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content

    def _build_messages(
        self,
        prompt: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> list[dict]:
        """Build the chat-completions messages list for the API call."""
        messages = []
        # System prompt: caller-provided wins; otherwise inject a default RAG
        # system prompt only when context is supplied.
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        elif context:
            # Default RAG system prompt
            messages.append({
                "role": "system",
                "content": (
                    "You are a helpful assistant. Use the provided context to answer "
                    "the user's question. If the context doesn't contain relevant "
                    "information, say so."
                ),
            })
        # Fold the context into the user turn so the model sees it adjacent
        # to the question.
        if context:
            messages.append({
                "role": "user",
                "content": f"Context:\n{context}\n\nQuestion: {prompt}",
            })
        else:
            messages.append({"role": "user", "content": prompt})
        return messages

    async def health(self) -> bool:
        """Check if the LLM service is healthy."""
        try:
            response = await self._client.get("/health")
            return response.status_code == 200
        except Exception:
            return False

View File

@@ -0,0 +1,182 @@
"""
Milvus vector database client.
"""
import logging
from typing import Optional, Any
from pymilvus import connections, Collection, utility
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class MilvusClient:
    """
    Client for Milvus vector database.

    Methods are declared async to match the other service clients, but the
    underlying pymilvus calls are synchronous/blocking.

    Usage:
        client = MilvusClient()
        await client.connect()
        results = await client.search(embedding, limit=10)
    """

    def __init__(self, settings: Optional[Settings] = None):
        self.settings = settings or Settings()
        # Connection state flag; reported by health().
        self._connected = False
        self._collection: Optional[Collection] = None

    async def connect(self, collection_name: Optional[str] = None) -> None:
        """
        Connect to Milvus and load collection.

        Args:
            collection_name: Collection to use (defaults to settings)
        """
        collection_name = collection_name or self.settings.milvus_collection
        # Registers a process-wide pymilvus connection under alias "default".
        connections.connect(
            alias="default",
            host=self.settings.milvus_host,
            port=self.settings.milvus_port,
        )
        if utility.has_collection(collection_name):
            self._collection = Collection(collection_name)
            # Load the collection into memory so it can be searched.
            self._collection.load()
            logger.info(f"Connected to Milvus collection: {collection_name}")
        else:
            # NOTE(review): _connected is still set below even when the
            # collection is missing; search()/insert() will then raise
            # RuntimeError until connect() is retried with a valid name.
            logger.warning(f"Collection {collection_name} does not exist")
        self._connected = True

    async def close(self) -> None:
        """Release the loaded collection and disconnect from Milvus."""
        if self._collection:
            self._collection.release()
        connections.disconnect("default")
        self._connected = False
        logger.info("Disconnected from Milvus")

    async def search(
        self,
        embedding: list[float],
        limit: int = 10,
        output_fields: Optional[list[str]] = None,
        filter_expr: Optional[str] = None,
    ) -> list[dict]:
        """
        Search for similar vectors.

        Args:
            embedding: Query embedding vector
            limit: Maximum number of results
            output_fields: Fields to return (default: all)
            filter_expr: Optional filter expression

        Raises:
            RuntimeError: If connect() has not loaded a collection.

        Returns:
            List of results with 'id', 'distance', and requested fields
        """
        if not self._collection:
            raise RuntimeError("Not connected to collection")
        with create_span("milvus.search") as span:
            if span:
                span.set_attribute("milvus.collection", self._collection.name)
                span.set_attribute("milvus.limit", limit)
            # Searches the vector field named "embedding" using cosine metric.
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
            results = self._collection.search(
                data=[embedding],
                anns_field="embedding",
                param=search_params,
                limit=limit,
                output_fields=output_fields,
                expr=filter_expr,
            )
            # Convert to list of dicts
            hits = []
            for hit in results[0]:
                item = {
                    "id": hit.id,
                    "distance": hit.distance,
                    # NOTE(review): with the COSINE metric Milvus reports a
                    # similarity-like value in `distance`; confirm 1 - distance
                    # is the intended score conversion here.
                    "score": 1 - hit.distance,  # Convert distance to similarity
                }
                # Add output fields
                if output_fields:
                    for field in output_fields:
                        if hasattr(hit.entity, field):
                            item[field] = getattr(hit.entity, field)
                hits.append(item)
            if span:
                span.set_attribute("milvus.results", len(hits))
            return hits

    async def search_with_texts(
        self,
        embedding: list[float],
        limit: int = 10,
        text_field: str = "text",
        metadata_fields: Optional[list[str]] = None,
    ) -> list[dict]:
        """
        Search and return text content with metadata.

        Args:
            embedding: Query embedding
            limit: Maximum results
            text_field: Name of text field in collection
            metadata_fields: Additional metadata fields to return

        Returns:
            List of results with text and metadata
        """
        output_fields = [text_field]
        if metadata_fields:
            output_fields.extend(metadata_fields)
        return await self.search(embedding, limit, output_fields)

    async def insert(
        self,
        embeddings: list[list[float]],
        data: list[dict],
    ) -> list[int]:
        """
        Insert vectors with data into the collection.

        Args:
            embeddings: List of embedding vectors
            data: List of dicts with field values, one dict per embedding

        Raises:
            RuntimeError: If connect() has not loaded a collection.

        Returns:
            List of inserted IDs
        """
        if not self._collection:
            raise RuntimeError("Not connected to collection")
        with create_span("milvus.insert") as span:
            if span:
                span.set_attribute("milvus.collection", self._collection.name)
                span.set_attribute("milvus.count", len(embeddings))
            # Build insert data
            # NOTE(review): column-ordered insert — this assumes "embedding"
            # is the first non-id field and the remaining fields follow in
            # schema order; verify against the collection's schema definition.
            insert_data = [embeddings]
            for field in self._collection.schema.fields:
                if field.name not in ("id", "embedding"):
                    field_values = [d.get(field.name) for d in data]
                    insert_data.append(field_values)
            result = self._collection.insert(insert_data)
            # flush() persists the insert so it becomes visible to searches.
            self._collection.flush()
            return result.primary_keys

    def health(self) -> bool:
        """Check if connected to Milvus (synchronous, unlike peer clients)."""
        return self._connected and utility.get_connection_addr("default") is not None

View File

@@ -0,0 +1,120 @@
"""
Reranker service client (Infinity/BGE Reranker).
"""
import logging
from typing import Optional
import httpx
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class RerankerClient:
    """
    Client for the reranker service (Infinity with BGE Reranker).

    Usage:
        client = RerankerClient()
        reranked = await client.rerank("query", ["doc1", "doc2"])
    """

    def __init__(self, settings: Optional[Settings] = None):
        self.settings = settings if settings is not None else Settings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.reranker_url,
            timeout=self.settings.http_timeout,
        )

    async def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        await self._client.aclose()

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_k: Optional[int] = None,
    ) -> list[dict]:
        """
        Rerank documents by relevance to the query.

        Args:
            query: Query text
            documents: Documents to rerank
            top_k: Number of top results to return (default: all)

        Returns:
            Dicts carrying 'index', 'score', and 'document', in the order
            the service returned them (most relevant first).
        """
        with create_span("reranker.rerank") as span:
            if span:
                span.set_attribute("reranker.num_documents", len(documents))
                if top_k:
                    span.set_attribute("reranker.top_k", top_k)
            body = {
                "query": query,
                "documents": documents,
            }
            if top_k:
                body["top_n"] = top_k
            resp = await self._client.post("/rerank", json=body)
            resp.raise_for_status()
            raw_results = resp.json().get("results", [])
            # Attach the original document text to each ranked entry.
            enriched = []
            for entry in raw_results:
                idx = entry.get("index", 0)
                enriched.append({
                    "index": idx,
                    "score": entry.get("relevance_score", entry.get("score", 0)),
                    "document": documents[idx] if idx < len(documents) else "",
                })
            return enriched

    async def rerank_with_metadata(
        self,
        query: str,
        documents: list[dict],
        text_key: str = "text",
        top_k: Optional[int] = None,
    ) -> list[dict]:
        """
        Rerank dict-shaped documents, carrying their metadata into the results.

        Args:
            query: Query text
            documents: Dicts holding the text plus arbitrary metadata
            text_key: Key under which each dict stores its text
            top_k: Number of top results to return

        Returns:
            Reranked entries with a 'metadata' key holding every original
            field except the text itself.
        """
        texts = [doc.get(text_key, "") for doc in documents]
        ranked = await self.rerank(query, texts, top_k)
        # Re-attach each source document's non-text fields.
        for item in ranked:
            source_idx = item["index"]
            if source_idx < len(documents):
                source = documents[source_idx]
                item["metadata"] = {
                    key: value for key, value in source.items() if key != text_key
                }
        return ranked

    async def health(self) -> bool:
        """Return True when GET /health answers with HTTP 200."""
        try:
            resp = await self._client.get("/health")
            return resp.status_code == 200
        except Exception:
            return False

132
handler_base/clients/stt.py Normal file
View File

@@ -0,0 +1,132 @@
"""
STT service client (Whisper/faster-whisper).
"""
import io
import logging
from typing import Optional
import httpx
from handler_base.config import STTSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class STTClient:
    """
    Client for the STT service (Whisper/faster-whisper).

    Usage:
        client = STTClient()
        text = await client.transcribe(audio_bytes)
    """

    def __init__(self, settings: Optional[STTSettings] = None):
        self.settings = settings if settings is not None else STTSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.stt_url,
            timeout=180.0,  # Transcription can be slow
        )

    async def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        await self._client.aclose()

    async def transcribe(
        self,
        audio: bytes,
        language: Optional[str] = None,
        task: Optional[str] = None,
        response_format: str = "json",
    ) -> dict:
        """
        Transcribe audio to text.

        Args:
            audio: Audio bytes (WAV, MP3, etc.)
            language: Language code (None for auto-detect)
            task: "transcribe" or "translate"
            response_format: "json", "text", "srt", "vtt"

        Returns:
            Dict with 'text', 'language', and optional 'segments'
        """
        language = language or self.settings.stt_language
        task = task or self.settings.stt_task
        with create_span("stt.transcribe") as span:
            if span:
                span.set_attribute("stt.task", task)
                span.set_attribute("stt.audio_size", len(audio))
                if language:
                    span.set_attribute("stt.language", language)
            files = {"file": ("audio.wav", audio, "audio/wav")}
            form = {
                "response_format": response_format,
            }
            if language:
                form["language"] = language
            # The OpenAI-compatible API exposes translation and transcription
            # as distinct endpoints.
            endpoint = (
                "/v1/audio/translations"
                if task == "translate"
                else "/v1/audio/transcriptions"
            )
            resp = await self._client.post(endpoint, files=files, data=form)
            resp.raise_for_status()
            if response_format == "text":
                return {"text": resp.text}
            result = resp.json()
            if span:
                span.set_attribute("stt.result_length", len(result.get("text", "")))
                if result.get("language"):
                    span.set_attribute("stt.detected_language", result["language"])
            return result

    async def transcribe_file(
        self,
        file_path: str,
        language: Optional[str] = None,
        task: Optional[str] = None,
    ) -> dict:
        """
        Transcribe an audio file from disk.

        Args:
            file_path: Path to audio file
            language: Language code
            task: "transcribe" or "translate"

        Returns:
            Transcription result
        """
        with open(file_path, "rb") as audio_file:
            payload = audio_file.read()
        return await self.transcribe(payload, language, task)

    async def translate(self, audio: bytes) -> dict:
        """
        Translate audio to English.

        Args:
            audio: Audio bytes

        Returns:
            Translation result with 'text' key
        """
        return await self.transcribe(audio, task="translate")

    async def health(self) -> bool:
        """Return True when GET /health answers with HTTP 200."""
        try:
            resp = await self._client.get("/health")
            return resp.status_code == 200
        except Exception:
            return False

113
handler_base/clients/tts.py Normal file
View File

@@ -0,0 +1,113 @@
"""
TTS service client (Coqui XTTS).
"""
import io
import logging
from typing import Optional
import httpx
from handler_base.config import TTSSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class TTSClient:
    """
    Client for the TTS service (Coqui XTTS).

    Usage:
        client = TTSClient()
        audio_bytes = await client.synthesize("Hello world")
    """

    def __init__(self, settings: Optional[TTSSettings] = None):
        self.settings = settings or TTSSettings()
        self._client = httpx.AsyncClient(
            base_url=self.settings.tts_url,
            timeout=120.0,  # TTS can be slow
        )

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()

    async def synthesize(
        self,
        text: str,
        language: Optional[str] = None,
        speaker: Optional[str] = None,
    ) -> bytes:
        """
        Synthesize speech from text.

        Args:
            text: Text to synthesize
            language: Language code (e.g., "en", "es", "fr"); defaults to
                the configured tts_language
            speaker: Speaker ID or reference; defaults to the configured
                tts_speaker

        Returns:
            WAV audio bytes
        """
        language = language or self.settings.tts_language
        # Fall back to the configured default speaker. TTSSettings.tts_speaker
        # was previously declared but never consulted; its default is None, so
        # this is backward compatible.
        speaker = speaker or self.settings.tts_speaker
        with create_span("tts.synthesize") as span:
            if span:
                span.set_attribute("tts.language", language)
                span.set_attribute("tts.text_length", len(text))
            params = {
                "text": text,
                "language_id": language,
            }
            if speaker:
                params["speaker_id"] = speaker
            response = await self._client.get("/api/tts", params=params)
            response.raise_for_status()
            audio_bytes = response.content
            if span:
                span.set_attribute("tts.audio_size", len(audio_bytes))
            return audio_bytes

    async def synthesize_to_file(
        self,
        text: str,
        output_path: str,
        language: Optional[str] = None,
        speaker: Optional[str] = None,
    ) -> None:
        """
        Synthesize speech and save to a file.

        Args:
            text: Text to synthesize
            output_path: Path to save the audio file
            language: Language code
            speaker: Speaker ID
        """
        audio_bytes = await self.synthesize(text, language, speaker)
        with open(output_path, "wb") as f:
            f.write(audio_bytes)

    async def get_speakers(self) -> list[dict]:
        """Get available speakers/voices; returns [] on any failure."""
        try:
            response = await self._client.get("/api/speakers")
            response.raise_for_status()
            return response.json()
        except Exception:
            return []

    async def health(self) -> bool:
        """Check if the TTS service is healthy."""
        try:
            response = await self._client.get("/health")
            return response.status_code == 200
        except Exception:
            return False

99
handler_base/config.py Normal file
View File

@@ -0,0 +1,99 @@
"""
Configuration management using Pydantic Settings.
Environment variables are automatically loaded and validated.
"""
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """
    Base settings for all handler services.

    Fields are populated by pydantic-settings from the environment and an
    optional .env file; unknown environment variables are ignored.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Service identification
    service_name: str = "handler"
    service_version: str = "1.0.0"
    service_namespace: str = "ai-ml"
    deployment_env: str = "production"

    # NATS configuration (credentials optional; both user and password must
    # be set for authentication to be used)
    nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
    nats_user: Optional[str] = None
    nats_password: Optional[str] = None
    nats_queue_group: Optional[str] = None

    # Redis/Valkey configuration
    redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
    redis_password: Optional[str] = None

    # Milvus configuration
    milvus_host: str = "milvus.ai-ml.svc.cluster.local"
    milvus_port: int = 19530
    milvus_collection: str = "documents"

    # Service endpoints (in-cluster DNS names)
    embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
    reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
    llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
    tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
    stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"

    # OpenTelemetry configuration
    otel_enabled: bool = True
    otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
    otel_use_http: bool = False

    # HyperDX configuration
    hyperdx_enabled: bool = False
    hyperdx_api_key: Optional[str] = None
    hyperdx_endpoint: str = "https://in-otel.hyperdx.io"

    # MLflow configuration
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
    mlflow_experiment_name: Optional[str] = None
    mlflow_enabled: bool = True

    # Health check configuration (HTTP server for Kubernetes probes)
    health_port: int = 8080
    health_path: str = "/health"
    ready_path: str = "/ready"

    # Timeouts (seconds)
    http_timeout: float = 60.0
    nats_timeout: float = 30.0
class EmbeddingsSettings(Settings):
    """Settings for embeddings service client."""

    # Model name sent in the /embeddings request payload
    embeddings_model: str = "bge"
    # NOTE(review): batch size is declared but not currently read by
    # EmbeddingsClient — confirm whether batching is intended
    embeddings_batch_size: int = 32
class LLMSettings(Settings):
    """Settings for LLM service client."""

    llm_model: str = "default"    # model name for /v1/chat/completions
    llm_max_tokens: int = 2048    # default completion length cap
    llm_temperature: float = 0.7  # default sampling temperature
    llm_top_p: float = 0.9        # default nucleus-sampling cutoff
class TTSSettings(Settings):
    """Settings for TTS service client."""

    tts_language: str = "en"             # default synthesis language code
    tts_speaker: Optional[str] = None    # default speaker/voice identifier
class STTSettings(Settings):
    """Settings for STT service client."""

    stt_language: Optional[str] = None  # None lets the backend auto-detect
    stt_task: str = "transcribe"        # or "translate" (to English)

221
handler_base/handler.py Normal file
View File

@@ -0,0 +1,221 @@
"""
Base handler class for building NATS-based services.
"""
import asyncio
import logging
import signal
from abc import ABC, abstractmethod
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base.config import Settings
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import setup_telemetry, create_span
logger = logging.getLogger(__name__)
class Handler(ABC):
    """
    Base class for NATS message handlers.

    Subclass and implement:
    - setup(): Initialize your service clients
    - handle_message(): Process incoming messages
    - teardown(): Clean up resources (optional)

    Example:
        class MyHandler(Handler):
            async def setup(self):
                self.embeddings = EmbeddingsClient()

            async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
                result = await self.embeddings.embed(data["text"])
                return {"embedding": result}

        if __name__ == "__main__":
            MyHandler(subject="my.subject").run()
    """

    def __init__(
        self,
        subject: str,
        settings: Optional[Settings] = None,
        queue_group: Optional[str] = None,
    ):
        """
        Initialize the handler.

        Args:
            subject: NATS subject to subscribe to
            settings: Configuration settings
            queue_group: Optional queue group for load balancing
        """
        self.subject = subject
        self.settings = settings or Settings()
        self.queue_group = queue_group or self.settings.nats_queue_group
        self.nats = NATSClient(self.settings)
        self.health_server = HealthServer(self.settings, self._check_ready)
        self._running = False
        self._shutdown_event = asyncio.Event()

    @abstractmethod
    async def setup(self) -> None:
        """
        Initialize service clients and resources.

        Called once before starting to handle messages.
        Override this to set up your service-specific clients.
        """

    @abstractmethod
    async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
        """
        Handle an incoming message.

        Args:
            msg: Raw NATS message
            data: Decoded message data (msgpack unpacked)

        Returns:
            Optional response data. If returned and msg has a reply subject,
            the response will be sent automatically.
        """

    async def teardown(self) -> None:
        """
        Clean up resources.

        Called during graceful shutdown.
        Override to add custom cleanup logic.
        """

    async def _check_ready(self) -> bool:
        """Readiness probe: subscribed and holding a NATS connection."""
        return self._running and self.nats._nc is not None

    async def _message_handler(self, msg: Msg) -> None:
        """Internal message handler with tracing and error handling."""
        with create_span(f"handle.{self.subject}") as span:
            try:
                # Decode message
                data = NATSClient.decode_msgpack(msg)
                if span:
                    span.set_attribute("messaging.destination", msg.subject)
                    if isinstance(data, dict):
                        request_id = data.get("request_id", data.get("id"))
                        if request_id:
                            span.set_attribute("request.id", str(request_id))
                # Handle message
                response = await self.handle_message(msg, data)
                # Send response if applicable
                if response is not None and msg.reply:
                    await self.nats.publish(msg.reply, response)
            except Exception as e:
                logger.exception(f"Error handling message on {msg.subject}")
                if span:
                    span.set_attribute("error", True)
                    span.set_attribute("error.message", str(e))
                # Send an error response so request/reply callers fail fast
                # instead of timing out.
                if msg.reply:
                    error_response = {
                        "error": True,
                        "message": str(e),
                        "type": type(e).__name__,
                    }
                    await self.nats.publish(msg.reply, error_response)

    def _setup_signals(self) -> None:
        """Register SIGTERM/SIGINT handlers on the running event loop."""
        # Only ever called from inside the running loop (_run_with_signals),
        # so get_running_loop() is the correct, non-deprecated accessor
        # (asyncio.get_event_loop() is deprecated for this use since 3.10).
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, self._handle_signal, sig)

    def _handle_signal(self, sig: signal.Signals) -> None:
        """Handle shutdown signal by waking the main run loop."""
        logger.info(f"Received {sig.name}, initiating graceful shutdown...")
        self._shutdown_event.set()

    async def _run(self) -> None:
        """Main async run loop: connect, subscribe, wait, then shut down."""
        # Setup telemetry
        setup_telemetry(self.settings)
        # Start health server
        self.health_server.start()
        try:
            # Connect to NATS
            await self.nats.connect()
            # Run user setup
            logger.info("Running service setup...")
            await self.setup()
            # Subscribe to subject
            await self.nats.subscribe(
                self.subject,
                self._message_handler,
                queue=self.queue_group,
            )
            self._running = True
            logger.info(f"Handler ready, listening on {self.subject}")
            # Wait for shutdown signal
            await self._shutdown_event.wait()
        except Exception:
            logger.exception("Fatal error in handler")
            raise
        finally:
            self._running = False
            # Graceful shutdown: user teardown first (best-effort), then the
            # shared infrastructure.
            logger.info("Shutting down...")
            try:
                await self.teardown()
            except Exception as e:
                logger.warning(f"Error during teardown: {e}")
            await self.nats.close()
            self.health_server.stop()
            logger.info("Shutdown complete")

    def run(self) -> None:
        """
        Run the handler.

        This is the main entry point. It sets up signal handlers
        and runs the async event loop.
        """
        # Configure logging
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
        logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
        # Run the async loop
        asyncio.run(self._run_with_signals())

    async def _run_with_signals(self) -> None:
        """Install signal handlers inside the loop, then run the handler."""
        self._setup_signals()
        await self._run()

124
handler_base/health.py Normal file
View File

@@ -0,0 +1,124 @@
"""
HTTP health check server.
Provides /health and /ready endpoints for Kubernetes probes.
"""
import asyncio
import logging
from typing import Callable, Optional, Awaitable
from http.server import HTTPServer, BaseHTTPRequestHandler
import threading
import json
from handler_base.config import Settings
logger = logging.getLogger(__name__)
class HealthHandler(BaseHTTPRequestHandler):
    """HTTP request handler serving liveness/readiness probe endpoints."""

    # Class-level state, installed by HealthServer.start() and shared by
    # every request instance the HTTPServer spawns.
    ready_check: Optional[Callable[[], Awaitable[bool]]] = None
    health_path: str = "/health"
    ready_path: str = "/ready"

    def log_message(self, format, *args):
        """Silence the default per-request stderr logging."""
        pass

    def do_GET(self):
        """Route GET requests to the liveness or readiness endpoint."""
        if self.path == self.health_path:
            self._send_json(200, {"status": "healthy"})
        elif self.path == self.ready_path:
            self._handle_ready()
        else:
            self._respond_not_found()

    def _handle_ready(self):
        """Evaluate the readiness callback (if any) and report the result."""
        # Read via the class so we get the shared configuration rather than
        # an accidentally-bound method.
        check = HealthHandler.ready_check
        if check is None:
            self._send_json(200, {"status": "ready"})
            return
        try:
            # The callback is async; run it to completion on a throwaway loop
            # since this handler executes on a plain thread.
            loop = asyncio.new_event_loop()
            try:
                is_ready = loop.run_until_complete(check())
            finally:
                loop.close()
            if is_ready:
                self._send_json(200, {"status": "ready"})
            else:
                self._send_json(503, {"status": "not ready"})
        except Exception as e:
            logger.exception("Readiness check failed")
            self._send_json(503, {"status": "error", "message": str(e)})

    def _send_json(self, status: int, body: dict):
        """Write an HTTP response with a JSON body."""
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps(body).encode())

    def _respond_not_found(self):
        """Write an empty 404 response."""
        self.send_response(404)
        self.end_headers()
class HealthServer:
    """
    Background HTTP server for health checks.

    Usage:
        server = HealthServer(settings)
        server.start()
        # ... run your service ...
        server.stop()
    """

    def __init__(
        self,
        settings: Optional[Settings] = None,
        ready_check: Optional[Callable[[], Awaitable[bool]]] = None,
    ):
        """
        Args:
            settings: Configuration (probe port and paths)
            ready_check: Optional async callback returning readiness
        """
        self.settings = settings or Settings()
        self.ready_check = ready_check
        self._server: Optional[HTTPServer] = None
        self._thread: Optional[threading.Thread] = None

    def start(self) -> None:
        """Start the health check server in a background thread."""
        # Configure handler class (class-level state shared by all requests)
        HealthHandler.ready_check = self.ready_check
        HealthHandler.health_path = self.settings.health_path
        HealthHandler.ready_path = self.settings.ready_path
        # Create and start server
        self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
        self._thread.start()
        logger.info(
            f"Health server started on port {self.settings.health_port} "
            f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
        )

    def stop(self) -> None:
        """Stop the health check server and release its listening socket."""
        if self._server:
            self._server.shutdown()
            # shutdown() only stops the serve_forever() loop; server_close()
            # is still required to close the listening socket — without it
            # the socket (and its port) is leaked until process exit.
            self._server.server_close()
            self._server = None
            if self._thread:
                # serve_forever() has returned, so the thread exits promptly.
                self._thread.join(timeout=5.0)
            self._thread = None
            logger.info("Health server stopped")

184
handler_base/nats_client.py Normal file
View File

@@ -0,0 +1,184 @@
"""
NATS client wrapper with connection management and utilities.
"""
import asyncio
import logging
from typing import Any, Callable, Optional, Awaitable
import msgpack
import nats
from nats.aio.client import Client
from nats.aio.msg import Msg
from nats.js import JetStreamContext
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class NATSClient:
    """
    NATS client with automatic connection management.

    Supports:
    - Core NATS pub/sub
    - JetStream for persistence
    - Queue groups for load balancing
    - Msgpack serialization
    """

    def __init__(self, settings: Optional[Settings] = None):
        """
        Args:
            settings: Connection configuration. Loaded from the environment
                when omitted.
        """
        self.settings = settings or Settings()
        self._nc: Optional[Client] = None
        self._js: Optional[JetStreamContext] = None
        # Active subscriptions, tracked so close() can drain them.
        self._subscriptions: list = []

    @property
    def nc(self) -> Client:
        """Get the NATS client, raising if not connected."""
        if self._nc is None:
            raise RuntimeError("NATS client not connected. Call connect() first.")
        return self._nc

    @property
    def js(self) -> JetStreamContext:
        """Get JetStream context, raising if not connected."""
        if self._js is None:
            raise RuntimeError("JetStream not initialized. Call connect() first.")
        return self._js

    async def connect(self) -> None:
        """Connect to NATS server and initialize the JetStream context."""
        connect_opts = {
            "servers": self.settings.nats_url,
            "reconnect_time_wait": 2,
            "max_reconnect_attempts": -1,  # Infinite
        }
        # Credentials are optional; only pass them when both are configured.
        if self.settings.nats_user and self.settings.nats_password:
            connect_opts["user"] = self.settings.nats_user
            connect_opts["password"] = self.settings.nats_password
        logger.info(f"Connecting to NATS at {self.settings.nats_url}")
        self._nc = await nats.connect(**connect_opts)
        self._js = self._nc.jetstream()
        logger.info("Connected to NATS")

    async def close(self) -> None:
        """Close NATS connection gracefully."""
        if self._nc:
            # Drain subscriptions first
            for sub in self._subscriptions:
                try:
                    await sub.drain()
                except Exception as e:
                    logger.warning(f"Error draining subscription: {e}")
            # Fix: forget the drained handles, otherwise a later
            # connect()/close() cycle would try to drain stale
            # subscriptions again.
            self._subscriptions.clear()
            await self._nc.drain()
            await self._nc.close()
            self._nc = None
            self._js = None
            logger.info("NATS connection closed")

    async def subscribe(
        self,
        subject: str,
        handler: Callable[[Msg], Awaitable[None]],
        queue: Optional[str] = None,
    ):
        """
        Subscribe to a subject with a handler function.

        Args:
            subject: NATS subject to subscribe to
            handler: Async function to handle messages
            queue: Optional queue group for load balancing; defaults to the
                configured nats_queue_group.

        Returns:
            The subscription object (also tracked for draining on close()).
        """
        queue = queue or self.settings.nats_queue_group
        if queue:
            sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
            logger.info(f"Subscribed to {subject} (queue: {queue})")
        else:
            sub = await self.nc.subscribe(subject, cb=handler)
            logger.info(f"Subscribed to {subject}")
        self._subscriptions.append(sub)
        return sub

    @staticmethod
    def _serialize(data: Any, use_msgpack: bool) -> bytes:
        """Serialize data to bytes using msgpack (binary-safe) or JSON."""
        if use_msgpack:
            return msgpack.packb(data, use_bin_type=True)
        import json
        return json.dumps(data).encode()

    @staticmethod
    def _deserialize(payload: bytes, use_msgpack: bool) -> Any:
        """Inverse of _serialize(): decode msgpack or JSON bytes."""
        if use_msgpack:
            return msgpack.unpackb(payload, raw=False)
        import json
        return json.loads(payload.decode())

    async def publish(
        self,
        subject: str,
        data: Any,
        use_msgpack: bool = True,
    ) -> None:
        """
        Publish a message to a subject.

        Args:
            subject: NATS subject to publish to
            data: Data to publish (will be serialized)
            use_msgpack: Whether to use msgpack (True) or JSON (False)
        """
        with create_span("nats.publish") as span:
            if span:
                span.set_attribute("messaging.destination", subject)
            await self.nc.publish(subject, self._serialize(data, use_msgpack))

    async def request(
        self,
        subject: str,
        data: Any,
        timeout: Optional[float] = None,
        use_msgpack: bool = True,
    ) -> Any:
        """
        Send a request and wait for response.

        Args:
            subject: NATS subject to send request to
            data: Request data
            timeout: Response timeout in seconds; defaults to the configured
                nats_timeout.
            use_msgpack: Whether to use msgpack serialization

        Returns:
            Decoded response data
        """
        timeout = timeout or self.settings.nats_timeout
        with create_span("nats.request") as span:
            if span:
                span.set_attribute("messaging.destination", subject)
            payload = self._serialize(data, use_msgpack)
            response = await self.nc.request(subject, payload, timeout=timeout)
            return self._deserialize(response.data, use_msgpack)

    @staticmethod
    def decode_msgpack(msg: Msg) -> Any:
        """Decode a msgpack message."""
        return msgpack.unpackb(msg.data, raw=False)

    @staticmethod
    def decode_json(msg: Msg) -> Any:
        """Decode a JSON message."""
        import json
        return json.loads(msg.data.decode())

154
handler_base/telemetry.py Normal file
View File

@@ -0,0 +1,154 @@
"""
OpenTelemetry setup for tracing and metrics.
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
"""
import logging
import os
from typing import Optional, Tuple
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from handler_base.config import Settings
logger = logging.getLogger(__name__)
# Global telemetry state: populated once by setup_telemetry() and read by
# get_tracer(), get_meter(), and create_span().
_tracer: Optional[trace.Tracer] = None
_meter: Optional[metrics.Meter] = None
_initialized = False
def setup_telemetry(
    settings: Optional[Settings] = None,
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
    """
    Initialize OpenTelemetry tracing and metrics.

    Idempotent: after the first call, later calls return the cached
    (tracer, meter) pair without reconfiguring the providers.

    Args:
        settings: Configuration settings. If None, loads from environment.

    Returns:
        Tuple of (tracer, meter) or (None, None) if disabled.
    """
    global _tracer, _meter, _initialized

    if _initialized:
        return _tracer, _meter

    cfg = settings if settings is not None else Settings()

    if not cfg.otel_enabled:
        logger.info("OpenTelemetry disabled")
        _initialized = True
        return None, None

    # Resource describing this service instance.
    service_resource = Resource.create({
        SERVICE_NAME: cfg.service_name,
        SERVICE_VERSION: cfg.service_version,
        SERVICE_NAMESPACE: cfg.service_namespace,
        "deployment.environment": cfg.deployment_env,
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    # Pick the collector endpoint and transport.
    if cfg.hyperdx_enabled and cfg.hyperdx_api_key:
        # HyperDX uses HTTP with API key header
        endpoint = cfg.hyperdx_endpoint
        headers = {"authorization": cfg.hyperdx_api_key}
        use_http = True
        logger.info(f"Using HyperDX endpoint: {endpoint}")
    else:
        endpoint = cfg.otel_endpoint
        headers = None
        use_http = cfg.otel_use_http
        logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")

    # Tracing pipeline: exporter -> batch processor -> global provider.
    span_exporter = (
        OTLPSpanExporterHTTP(endpoint=f"{endpoint}/v1/traces", headers=headers)
        if use_http
        else OTLPSpanExporter(endpoint=endpoint)
    )
    tracer_provider = TracerProvider(resource=service_resource)
    tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(tracer_provider)

    # Metrics pipeline: exporter -> periodic reader (once a minute) -> provider.
    metric_exporter = (
        OTLPMetricExporterHTTP(endpoint=f"{endpoint}/v1/metrics", headers=headers)
        if use_http
        else OTLPMetricExporter(endpoint=endpoint)
    )
    reader = PeriodicExportingMetricReader(
        metric_exporter,
        export_interval_millis=60000,
    )
    metrics.set_meter_provider(
        MeterProvider(resource=service_resource, metric_readers=[reader])
    )

    # Auto-instrument common libraries used by handlers.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    # Named tracer/meter for this service.
    _tracer = trace.get_tracer(cfg.service_name, cfg.service_version)
    _meter = metrics.get_meter(cfg.service_name, cfg.service_version)

    logger.info(f"OpenTelemetry initialized for {cfg.service_name}")
    _initialized = True
    return _tracer, _meter
def get_tracer() -> Optional[trace.Tracer]:
    """Get the global tracer instance (None until setup_telemetry() has run)."""
    return _tracer
def get_meter() -> Optional[metrics.Meter]:
    """Get the global meter instance (None until setup_telemetry() has run)."""
    return _meter
def create_span(name: str, **kwargs):
    """
    Create a new span.

    Safe to call before (or without) telemetry setup: when no tracer is
    configured, a no-op context manager yielding None is returned, so the
    caller's `with` block still works.

    Usage:
        with create_span("my_operation") as span:
            span.set_attribute("key", "value")
            # do work
    """
    if _tracer is not None:
        return _tracer.start_as_current_span(name, **kwargs)
    # Telemetry disabled or not yet initialized: no-op context manager.
    from contextlib import nullcontext
    return nullcontext()