feat: Add handler-base library for NATS AI/ML services

- Handler base class with graceful shutdown and signal handling
- NATSClient with JetStream and msgpack serialization
- Pydantic Settings for environment configuration
- HealthServer for Kubernetes probes
- OpenTelemetry telemetry setup
- Service clients: STT, TTS, LLM, Embeddings, Reranker, Milvus
This commit is contained in:
2026-02-01 20:36:00 -05:00
parent 00df482412
commit 99c97b7973
17 changed files with 1932 additions and 1 deletions

View File

@@ -0,0 +1,18 @@
"""
Service client wrappers for AI/ML backends.
"""
from handler_base.clients.embeddings import EmbeddingsClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.llm import LLMClient
from handler_base.clients.tts import TTSClient
from handler_base.clients.stt import STTClient
from handler_base.clients.milvus import MilvusClient
__all__ = [
"EmbeddingsClient",
"RerankerClient",
"LLMClient",
"TTSClient",
"STTClient",
"MilvusClient",
]

View File

@@ -0,0 +1,91 @@
"""
Embeddings service client (Infinity/BGE).
"""
import logging
from typing import Optional
import httpx
from handler_base.config import EmbeddingsSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class EmbeddingsClient:
"""
Client for the embeddings service (Infinity with BGE models).
Usage:
client = EmbeddingsClient()
embeddings = await client.embed(["Hello world"])
"""
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
self.settings = settings or EmbeddingsSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.embeddings_url,
timeout=self.settings.http_timeout,
)
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def embed(
self,
texts: list[str],
model: Optional[str] = None,
) -> list[list[float]]:
"""
Generate embeddings for a list of texts.
Args:
texts: List of texts to embed
model: Model name (defaults to settings)
Returns:
List of embedding vectors
"""
model = model or self.settings.embeddings_model
with create_span("embeddings.embed") as span:
if span:
span.set_attribute("embeddings.model", model)
span.set_attribute("embeddings.batch_size", len(texts))
response = await self._client.post(
"/embeddings",
json={"input": texts, "model": model},
)
response.raise_for_status()
result = response.json()
embeddings = [d["embedding"] for d in result.get("data", [])]
if span:
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
return embeddings
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
"""
Generate embedding for a single text.
Args:
text: Text to embed
model: Model name (defaults to settings)
Returns:
Embedding vector
"""
embeddings = await self.embed([text], model)
return embeddings[0] if embeddings else []
async def health(self) -> bool:
"""Check if the embeddings service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

192
handler_base/clients/llm.py Normal file
View File

@@ -0,0 +1,192 @@
"""
LLM service client (vLLM/OpenAI-compatible).
"""
import logging
from typing import Optional, AsyncIterator
import httpx
from handler_base.config import LLMSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class LLMClient:
"""
Client for the LLM service (vLLM with OpenAI-compatible API).
Usage:
client = LLMClient()
response = await client.generate("Hello, how are you?")
# With context for RAG
response = await client.generate(
"What is the capital?",
context="France is a country in Europe..."
)
# Streaming
async for chunk in client.stream("Tell me a story"):
print(chunk, end="")
"""
def __init__(self, settings: Optional[LLMSettings] = None):
self.settings = settings or LLMSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.llm_url,
timeout=self.settings.http_timeout,
)
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def generate(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
stop: Optional[list[str]] = None,
) -> str:
"""
Generate a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
Returns:
Generated text response
"""
with create_span("llm.generate") as span:
messages = self._build_messages(prompt, context, system_prompt)
if span:
span.set_attribute("llm.model", self.settings.llm_model)
span.set_attribute("llm.prompt_length", len(prompt))
if context:
span.set_attribute("llm.context_length", len(context))
payload = {
"model": self.settings.llm_model,
"messages": messages,
"max_tokens": max_tokens or self.settings.llm_max_tokens,
"temperature": temperature or self.settings.llm_temperature,
"top_p": top_p or self.settings.llm_top_p,
}
if stop:
payload["stop"] = stop
response = await self._client.post("/v1/chat/completions", json=payload)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
if span:
span.set_attribute("llm.response_length", len(content))
usage = result.get("usage", {})
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
return content
async def stream(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> AsyncIterator[str]:
"""
Stream a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
Yields:
Text chunks as they're generated
"""
messages = self._build_messages(prompt, context, system_prompt)
payload = {
"model": self.settings.llm_model,
"messages": messages,
"max_tokens": max_tokens or self.settings.llm_max_tokens,
"temperature": temperature or self.settings.llm_temperature,
"stream": True,
}
async with self._client.stream(
"POST", "/v1/chat/completions", json=payload
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
import json
chunk = json.loads(data)
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
def _build_messages(
self,
prompt: str,
context: Optional[str] = None,
system_prompt: Optional[str] = None,
) -> list[dict]:
"""Build the messages list for the API call."""
messages = []
# System prompt
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
elif context:
# Default RAG system prompt
messages.append({
"role": "system",
"content": (
"You are a helpful assistant. Use the provided context to answer "
"the user's question. If the context doesn't contain relevant "
"information, say so."
),
})
# Add context as a separate message if provided
if context:
messages.append({
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
})
else:
messages.append({"role": "user", "content": prompt})
return messages
async def health(self) -> bool:
"""Check if the LLM service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

View File

@@ -0,0 +1,182 @@
"""
Milvus vector database client.
"""
import logging
from typing import Optional, Any
from pymilvus import connections, Collection, utility
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class MilvusClient:
"""
Client for Milvus vector database.
Usage:
client = MilvusClient()
await client.connect()
results = await client.search(embedding, limit=10)
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._connected = False
self._collection: Optional[Collection] = None
async def connect(self, collection_name: Optional[str] = None) -> None:
"""
Connect to Milvus and load collection.
Args:
collection_name: Collection to use (defaults to settings)
"""
collection_name = collection_name or self.settings.milvus_collection
connections.connect(
alias="default",
host=self.settings.milvus_host,
port=self.settings.milvus_port,
)
if utility.has_collection(collection_name):
self._collection = Collection(collection_name)
self._collection.load()
logger.info(f"Connected to Milvus collection: {collection_name}")
else:
logger.warning(f"Collection {collection_name} does not exist")
self._connected = True
async def close(self) -> None:
"""Close Milvus connection."""
if self._collection:
self._collection.release()
connections.disconnect("default")
self._connected = False
logger.info("Disconnected from Milvus")
async def search(
self,
embedding: list[float],
limit: int = 10,
output_fields: Optional[list[str]] = None,
filter_expr: Optional[str] = None,
) -> list[dict]:
"""
Search for similar vectors.
Args:
embedding: Query embedding vector
limit: Maximum number of results
output_fields: Fields to return (default: all)
filter_expr: Optional filter expression
Returns:
List of results with 'id', 'distance', and requested fields
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.search") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.limit", limit)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
results = self._collection.search(
data=[embedding],
anns_field="embedding",
param=search_params,
limit=limit,
output_fields=output_fields,
expr=filter_expr,
)
# Convert to list of dicts
hits = []
for hit in results[0]:
item = {
"id": hit.id,
"distance": hit.distance,
"score": 1 - hit.distance, # Convert distance to similarity
}
# Add output fields
if output_fields:
for field in output_fields:
if hasattr(hit.entity, field):
item[field] = getattr(hit.entity, field)
hits.append(item)
if span:
span.set_attribute("milvus.results", len(hits))
return hits
async def search_with_texts(
self,
embedding: list[float],
limit: int = 10,
text_field: str = "text",
metadata_fields: Optional[list[str]] = None,
) -> list[dict]:
"""
Search and return text content with metadata.
Args:
embedding: Query embedding
limit: Maximum results
text_field: Name of text field in collection
metadata_fields: Additional metadata fields to return
Returns:
List of results with text and metadata
"""
output_fields = [text_field]
if metadata_fields:
output_fields.extend(metadata_fields)
return await self.search(embedding, limit, output_fields)
async def insert(
self,
embeddings: list[list[float]],
data: list[dict],
) -> list[int]:
"""
Insert vectors with data into the collection.
Args:
embeddings: List of embedding vectors
data: List of dicts with field values
Returns:
List of inserted IDs
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.insert") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.count", len(embeddings))
# Build insert data
insert_data = [embeddings]
for field in self._collection.schema.fields:
if field.name not in ("id", "embedding"):
field_values = [d.get(field.name) for d in data]
insert_data.append(field_values)
result = self._collection.insert(insert_data)
self._collection.flush()
return result.primary_keys
def health(self) -> bool:
"""Check if connected to Milvus."""
return self._connected and utility.get_connection_addr("default") is not None

View File

@@ -0,0 +1,120 @@
"""
Reranker service client (Infinity/BGE Reranker).
"""
import logging
from typing import Optional
import httpx
from handler_base.config import Settings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class RerankerClient:
"""
Client for the reranker service (Infinity with BGE Reranker).
Usage:
client = RerankerClient()
reranked = await client.rerank("query", ["doc1", "doc2"])
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._client = httpx.AsyncClient(
base_url=self.settings.reranker_url,
timeout=self.settings.http_timeout,
)
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def rerank(
self,
query: str,
documents: list[str],
top_k: Optional[int] = None,
) -> list[dict]:
"""
Rerank documents based on relevance to query.
Args:
query: Query text
documents: List of documents to rerank
top_k: Number of top results to return (default: all)
Returns:
List of dicts with 'index', 'score', and 'document' keys,
sorted by relevance score descending.
"""
with create_span("reranker.rerank") as span:
if span:
span.set_attribute("reranker.num_documents", len(documents))
if top_k:
span.set_attribute("reranker.top_k", top_k)
payload = {
"query": query,
"documents": documents,
}
if top_k:
payload["top_n"] = top_k
response = await self._client.post("/rerank", json=payload)
response.raise_for_status()
result = response.json()
results = result.get("results", [])
# Enrich with original documents
enriched = []
for r in results:
idx = r.get("index", 0)
enriched.append({
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
})
return enriched
async def rerank_with_metadata(
self,
query: str,
documents: list[dict],
text_key: str = "text",
top_k: Optional[int] = None,
) -> list[dict]:
"""
Rerank documents with metadata, preserving metadata in results.
Args:
query: Query text
documents: List of dicts with text and metadata
text_key: Key containing text in each document dict
top_k: Number of top results to return
Returns:
Reranked documents with original metadata preserved.
"""
texts = [d.get(text_key, "") for d in documents]
reranked = await self.rerank(query, texts, top_k)
# Merge back metadata
for r in reranked:
idx = r["index"]
if idx < len(documents):
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
return reranked
async def health(self) -> bool:
"""Check if the reranker service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

132
handler_base/clients/stt.py Normal file
View File

@@ -0,0 +1,132 @@
"""
STT service client (Whisper/faster-whisper).
"""
import io
import logging
from typing import Optional
import httpx
from handler_base.config import STTSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class STTClient:
"""
Client for the STT service (Whisper/faster-whisper).
Usage:
client = STTClient()
text = await client.transcribe(audio_bytes)
"""
def __init__(self, settings: Optional[STTSettings] = None):
self.settings = settings or STTSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.stt_url,
timeout=180.0, # Transcription can be slow
)
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def transcribe(
self,
audio: bytes,
language: Optional[str] = None,
task: Optional[str] = None,
response_format: str = "json",
) -> dict:
"""
Transcribe audio to text.
Args:
audio: Audio bytes (WAV, MP3, etc.)
language: Language code (None for auto-detect)
task: "transcribe" or "translate"
response_format: "json", "text", "srt", "vtt"
Returns:
Dict with 'text', 'language', and optional 'segments'
"""
language = language or self.settings.stt_language
task = task or self.settings.stt_task
with create_span("stt.transcribe") as span:
if span:
span.set_attribute("stt.task", task)
span.set_attribute("stt.audio_size", len(audio))
if language:
span.set_attribute("stt.language", language)
files = {"file": ("audio.wav", audio, "audio/wav")}
data = {
"response_format": response_format,
}
if language:
data["language"] = language
# Choose endpoint based on task
if task == "translate":
endpoint = "/v1/audio/translations"
else:
endpoint = "/v1/audio/transcriptions"
response = await self._client.post(endpoint, files=files, data=data)
response.raise_for_status()
if response_format == "text":
return {"text": response.text}
result = response.json()
if span:
span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"):
span.set_attribute("stt.detected_language", result["language"])
return result
async def transcribe_file(
self,
file_path: str,
language: Optional[str] = None,
task: Optional[str] = None,
) -> dict:
"""
Transcribe an audio file.
Args:
file_path: Path to audio file
language: Language code
task: "transcribe" or "translate"
Returns:
Transcription result
"""
with open(file_path, "rb") as f:
audio = f.read()
return await self.transcribe(audio, language, task)
async def translate(self, audio: bytes) -> dict:
"""
Translate audio to English.
Args:
audio: Audio bytes
Returns:
Translation result with 'text' key
"""
return await self.transcribe(audio, task="translate")
async def health(self) -> bool:
"""Check if the STT service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False

113
handler_base/clients/tts.py Normal file
View File

@@ -0,0 +1,113 @@
"""
TTS service client (Coqui XTTS).
"""
import io
import logging
from typing import Optional
import httpx
from handler_base.config import TTSSettings
from handler_base.telemetry import create_span
logger = logging.getLogger(__name__)
class TTSClient:
"""
Client for the TTS service (Coqui XTTS).
Usage:
client = TTSClient()
audio_bytes = await client.synthesize("Hello world")
"""
def __init__(self, settings: Optional[TTSSettings] = None):
self.settings = settings or TTSSettings()
self._client = httpx.AsyncClient(
base_url=self.settings.tts_url,
timeout=120.0, # TTS can be slow
)
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def synthesize(
self,
text: str,
language: Optional[str] = None,
speaker: Optional[str] = None,
) -> bytes:
"""
Synthesize speech from text.
Args:
text: Text to synthesize
language: Language code (e.g., "en", "es", "fr")
speaker: Speaker ID or reference
Returns:
WAV audio bytes
"""
language = language or self.settings.tts_language
with create_span("tts.synthesize") as span:
if span:
span.set_attribute("tts.language", language)
span.set_attribute("tts.text_length", len(text))
params = {
"text": text,
"language_id": language,
}
if speaker:
params["speaker_id"] = speaker
response = await self._client.get("/api/tts", params=params)
response.raise_for_status()
audio_bytes = response.content
if span:
span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes
async def synthesize_to_file(
self,
text: str,
output_path: str,
language: Optional[str] = None,
speaker: Optional[str] = None,
) -> None:
"""
Synthesize speech and save to a file.
Args:
text: Text to synthesize
output_path: Path to save the audio file
language: Language code
speaker: Speaker ID
"""
audio_bytes = await self.synthesize(text, language, speaker)
with open(output_path, "wb") as f:
f.write(audio_bytes)
async def get_speakers(self) -> list[dict]:
"""Get available speakers/voices."""
try:
response = await self._client.get("/api/speakers")
response.raise_for_status()
return response.json()
except Exception:
return []
async def health(self) -> bool:
"""Check if the TTS service is healthy."""
try:
response = await self._client.get("/health")
return response.status_code == 200
except Exception:
return False