fix: auto-fix ruff linting errors and remove unsupported upload-artifact

This commit is contained in:
2026-02-02 08:34:00 -05:00
parent 7b30ff6a05
commit 8e266cd488
19 changed files with 414 additions and 400 deletions

View File

@@ -57,12 +57,6 @@ jobs:
- name: Run tests with coverage - name: Run tests with coverage
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term
- name: Upload coverage artifact
uses: actions/upload-artifact@v4
with:
name: coverage
path: coverage.xml
release: release:
name: Release name: Release
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -8,11 +8,12 @@ Provides consistent patterns for:
- Graceful shutdown - Graceful shutdown
- Service client wrappers - Service client wrappers
""" """
from handler_base.config import Settings from handler_base.config import Settings
from handler_base.handler import Handler from handler_base.handler import Handler
from handler_base.health import HealthServer from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient from handler_base.nats_client import NATSClient
from handler_base.telemetry import setup_telemetry, get_tracer, get_meter from handler_base.telemetry import get_meter, get_tracer, setup_telemetry
__all__ = [ __all__ = [
"Handler", "Handler",

View File

@@ -1,12 +1,13 @@
""" """
Service client wrappers for AI/ML backends. Service client wrappers for AI/ML backends.
""" """
from handler_base.clients.embeddings import EmbeddingsClient from handler_base.clients.embeddings import EmbeddingsClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.llm import LLMClient from handler_base.clients.llm import LLMClient
from handler_base.clients.tts import TTSClient
from handler_base.clients.stt import STTClient
from handler_base.clients.milvus import MilvusClient from handler_base.clients.milvus import MilvusClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.stt import STTClient
from handler_base.clients.tts import TTSClient
__all__ = [ __all__ = [
"EmbeddingsClient", "EmbeddingsClient",

View File

@@ -1,6 +1,7 @@
""" """
Embeddings service client (Infinity/BGE). Embeddings service client (Infinity/BGE).
""" """
import logging import logging
from typing import Optional from typing import Optional
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
class EmbeddingsClient: class EmbeddingsClient:
""" """
Client for the embeddings service (Infinity with BGE models). Client for the embeddings service (Infinity with BGE models).
Usage: Usage:
client = EmbeddingsClient() client = EmbeddingsClient()
embeddings = await client.embed(["Hello world"]) embeddings = await client.embed(["Hello world"])
""" """
def __init__(self, settings: Optional[EmbeddingsSettings] = None): def __init__(self, settings: Optional[EmbeddingsSettings] = None):
self.settings = settings or EmbeddingsSettings() self.settings = settings or EmbeddingsSettings()
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
base_url=self.settings.embeddings_url, base_url=self.settings.embeddings_url,
timeout=self.settings.http_timeout, timeout=self.settings.http_timeout,
) )
async def close(self) -> None: async def close(self) -> None:
"""Close the HTTP client.""" """Close the HTTP client."""
await self._client.aclose() await self._client.aclose()
async def embed( async def embed(
self, self,
texts: list[str], texts: list[str],
@@ -39,49 +40,49 @@ class EmbeddingsClient:
) -> list[list[float]]: ) -> list[list[float]]:
""" """
Generate embeddings for a list of texts. Generate embeddings for a list of texts.
Args: Args:
texts: List of texts to embed texts: List of texts to embed
model: Model name (defaults to settings) model: Model name (defaults to settings)
Returns: Returns:
List of embedding vectors List of embedding vectors
""" """
model = model or self.settings.embeddings_model model = model or self.settings.embeddings_model
with create_span("embeddings.embed") as span: with create_span("embeddings.embed") as span:
if span: if span:
span.set_attribute("embeddings.model", model) span.set_attribute("embeddings.model", model)
span.set_attribute("embeddings.batch_size", len(texts)) span.set_attribute("embeddings.batch_size", len(texts))
response = await self._client.post( response = await self._client.post(
"/embeddings", "/embeddings",
json={"input": texts, "model": model}, json={"input": texts, "model": model},
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
embeddings = [d["embedding"] for d in result.get("data", [])] embeddings = [d["embedding"] for d in result.get("data", [])]
if span: if span:
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0) span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
return embeddings return embeddings
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]: async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
""" """
Generate embedding for a single text. Generate embedding for a single text.
Args: Args:
text: Text to embed text: Text to embed
model: Model name (defaults to settings) model: Model name (defaults to settings)
Returns: Returns:
Embedding vector Embedding vector
""" """
embeddings = await self.embed([text], model) embeddings = await self.embed([text], model)
return embeddings[0] if embeddings else [] return embeddings[0] if embeddings else []
async def health(self) -> bool: async def health(self) -> bool:
"""Check if the embeddings service is healthy.""" """Check if the embeddings service is healthy."""
try: try:

View File

@@ -1,8 +1,9 @@
""" """
LLM service client (vLLM/OpenAI-compatible). LLM service client (vLLM/OpenAI-compatible).
""" """
import logging import logging
from typing import Optional, AsyncIterator from typing import AsyncIterator, Optional
import httpx import httpx
@@ -15,33 +16,33 @@ logger = logging.getLogger(__name__)
class LLMClient: class LLMClient:
""" """
Client for the LLM service (vLLM with OpenAI-compatible API). Client for the LLM service (vLLM with OpenAI-compatible API).
Usage: Usage:
client = LLMClient() client = LLMClient()
response = await client.generate("Hello, how are you?") response = await client.generate("Hello, how are you?")
# With context for RAG # With context for RAG
response = await client.generate( response = await client.generate(
"What is the capital?", "What is the capital?",
context="France is a country in Europe..." context="France is a country in Europe..."
) )
# Streaming # Streaming
async for chunk in client.stream("Tell me a story"): async for chunk in client.stream("Tell me a story"):
print(chunk, end="") print(chunk, end="")
""" """
def __init__(self, settings: Optional[LLMSettings] = None): def __init__(self, settings: Optional[LLMSettings] = None):
self.settings = settings or LLMSettings() self.settings = settings or LLMSettings()
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
base_url=self.settings.llm_url, base_url=self.settings.llm_url,
timeout=self.settings.http_timeout, timeout=self.settings.http_timeout,
) )
async def close(self) -> None: async def close(self) -> None:
"""Close the HTTP client.""" """Close the HTTP client."""
await self._client.aclose() await self._client.aclose()
async def generate( async def generate(
self, self,
prompt: str, prompt: str,
@@ -54,7 +55,7 @@ class LLMClient:
) -> str: ) -> str:
""" """
Generate a response from the LLM. Generate a response from the LLM.
Args: Args:
prompt: User prompt/query prompt: User prompt/query
context: Optional context for RAG context: Optional context for RAG
@@ -63,19 +64,19 @@ class LLMClient:
temperature: Sampling temperature temperature: Sampling temperature
top_p: Top-p sampling top_p: Top-p sampling
stop: Stop sequences stop: Stop sequences
Returns: Returns:
Generated text response Generated text response
""" """
with create_span("llm.generate") as span: with create_span("llm.generate") as span:
messages = self._build_messages(prompt, context, system_prompt) messages = self._build_messages(prompt, context, system_prompt)
if span: if span:
span.set_attribute("llm.model", self.settings.llm_model) span.set_attribute("llm.model", self.settings.llm_model)
span.set_attribute("llm.prompt_length", len(prompt)) span.set_attribute("llm.prompt_length", len(prompt))
if context: if context:
span.set_attribute("llm.context_length", len(context)) span.set_attribute("llm.context_length", len(context))
payload = { payload = {
"model": self.settings.llm_model, "model": self.settings.llm_model,
"messages": messages, "messages": messages,
@@ -85,21 +86,21 @@ class LLMClient:
} }
if stop: if stop:
payload["stop"] = stop payload["stop"] = stop
response = await self._client.post("/v1/chat/completions", json=payload) response = await self._client.post("/v1/chat/completions", json=payload)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
content = result["choices"][0]["message"]["content"] content = result["choices"][0]["message"]["content"]
if span: if span:
span.set_attribute("llm.response_length", len(content)) span.set_attribute("llm.response_length", len(content))
usage = result.get("usage", {}) usage = result.get("usage", {})
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0)) span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0)) span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
return content return content
async def stream( async def stream(
self, self,
prompt: str, prompt: str,
@@ -110,19 +111,19 @@ class LLMClient:
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
""" """
Stream a response from the LLM. Stream a response from the LLM.
Args: Args:
prompt: User prompt/query prompt: User prompt/query
context: Optional context for RAG context: Optional context for RAG
system_prompt: Optional system prompt system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate max_tokens: Maximum tokens to generate
temperature: Sampling temperature temperature: Sampling temperature
Yields: Yields:
Text chunks as they're generated Text chunks as they're generated
""" """
messages = self._build_messages(prompt, context, system_prompt) messages = self._build_messages(prompt, context, system_prompt)
payload = { payload = {
"model": self.settings.llm_model, "model": self.settings.llm_model,
"messages": messages, "messages": messages,
@@ -130,25 +131,24 @@ class LLMClient:
"temperature": temperature or self.settings.llm_temperature, "temperature": temperature or self.settings.llm_temperature,
"stream": True, "stream": True,
} }
async with self._client.stream( async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
"POST", "/v1/chat/completions", json=payload
) as response:
response.raise_for_status() response.raise_for_status()
async for line in response.aiter_lines(): async for line in response.aiter_lines():
if line.startswith("data: "): if line.startswith("data: "):
data = line[6:] data = line[6:]
if data == "[DONE]": if data == "[DONE]":
break break
import json import json
chunk = json.loads(data) chunk = json.loads(data)
delta = chunk["choices"][0].get("delta", {}) delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "") content = delta.get("content", "")
if content: if content:
yield content yield content
def _build_messages( def _build_messages(
self, self,
prompt: str, prompt: str,
@@ -157,32 +157,36 @@ class LLMClient:
) -> list[dict]: ) -> list[dict]:
"""Build the messages list for the API call.""" """Build the messages list for the API call."""
messages = [] messages = []
# System prompt # System prompt
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
elif context: elif context:
# Default RAG system prompt # Default RAG system prompt
messages.append({ messages.append(
"role": "system", {
"content": ( "role": "system",
"You are a helpful assistant. Use the provided context to answer " "content": (
"the user's question. If the context doesn't contain relevant " "You are a helpful assistant. Use the provided context to answer "
"information, say so." "the user's question. If the context doesn't contain relevant "
), "information, say so."
}) ),
}
)
# Add context as a separate message if provided # Add context as a separate message if provided
if context: if context:
messages.append({ messages.append(
"role": "user", {
"content": f"Context:\n{context}\n\nQuestion: {prompt}", "role": "user",
}) "content": f"Context:\n{context}\n\nQuestion: {prompt}",
}
)
else: else:
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
return messages return messages
async def health(self) -> bool: async def health(self) -> bool:
"""Check if the LLM service is healthy.""" """Check if the LLM service is healthy."""
try: try:

View File

@@ -1,10 +1,11 @@
""" """
Milvus vector database client. Milvus vector database client.
""" """
import logging
from typing import Optional, Any
from pymilvus import connections, Collection, utility import logging
from typing import Optional
from pymilvus import Collection, connections, utility
from handler_base.config import Settings from handler_base.config import Settings
from handler_base.telemetry import create_span from handler_base.telemetry import create_span
@@ -15,42 +16,42 @@ logger = logging.getLogger(__name__)
class MilvusClient: class MilvusClient:
""" """
Client for Milvus vector database. Client for Milvus vector database.
Usage: Usage:
client = MilvusClient() client = MilvusClient()
await client.connect() await client.connect()
results = await client.search(embedding, limit=10) results = await client.search(embedding, limit=10)
""" """
def __init__(self, settings: Optional[Settings] = None): def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings() self.settings = settings or Settings()
self._connected = False self._connected = False
self._collection: Optional[Collection] = None self._collection: Optional[Collection] = None
async def connect(self, collection_name: Optional[str] = None) -> None: async def connect(self, collection_name: Optional[str] = None) -> None:
""" """
Connect to Milvus and load collection. Connect to Milvus and load collection.
Args: Args:
collection_name: Collection to use (defaults to settings) collection_name: Collection to use (defaults to settings)
""" """
collection_name = collection_name or self.settings.milvus_collection collection_name = collection_name or self.settings.milvus_collection
connections.connect( connections.connect(
alias="default", alias="default",
host=self.settings.milvus_host, host=self.settings.milvus_host,
port=self.settings.milvus_port, port=self.settings.milvus_port,
) )
if utility.has_collection(collection_name): if utility.has_collection(collection_name):
self._collection = Collection(collection_name) self._collection = Collection(collection_name)
self._collection.load() self._collection.load()
logger.info(f"Connected to Milvus collection: {collection_name}") logger.info(f"Connected to Milvus collection: {collection_name}")
else: else:
logger.warning(f"Collection {collection_name} does not exist") logger.warning(f"Collection {collection_name} does not exist")
self._connected = True self._connected = True
async def close(self) -> None: async def close(self) -> None:
"""Close Milvus connection.""" """Close Milvus connection."""
if self._collection: if self._collection:
@@ -58,7 +59,7 @@ class MilvusClient:
connections.disconnect("default") connections.disconnect("default")
self._connected = False self._connected = False
logger.info("Disconnected from Milvus") logger.info("Disconnected from Milvus")
async def search( async def search(
self, self,
embedding: list[float], embedding: list[float],
@@ -68,26 +69,26 @@ class MilvusClient:
) -> list[dict]: ) -> list[dict]:
""" """
Search for similar vectors. Search for similar vectors.
Args: Args:
embedding: Query embedding vector embedding: Query embedding vector
limit: Maximum number of results limit: Maximum number of results
output_fields: Fields to return (default: all) output_fields: Fields to return (default: all)
filter_expr: Optional filter expression filter_expr: Optional filter expression
Returns: Returns:
List of results with 'id', 'distance', and requested fields List of results with 'id', 'distance', and requested fields
""" """
if not self._collection: if not self._collection:
raise RuntimeError("Not connected to collection") raise RuntimeError("Not connected to collection")
with create_span("milvus.search") as span: with create_span("milvus.search") as span:
if span: if span:
span.set_attribute("milvus.collection", self._collection.name) span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.limit", limit) span.set_attribute("milvus.limit", limit)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
results = self._collection.search( results = self._collection.search(
data=[embedding], data=[embedding],
anns_field="embedding", anns_field="embedding",
@@ -96,7 +97,7 @@ class MilvusClient:
output_fields=output_fields, output_fields=output_fields,
expr=filter_expr, expr=filter_expr,
) )
# Convert to list of dicts # Convert to list of dicts
hits = [] hits = []
for hit in results[0]: for hit in results[0]:
@@ -111,12 +112,12 @@ class MilvusClient:
if hasattr(hit.entity, field): if hasattr(hit.entity, field):
item[field] = getattr(hit.entity, field) item[field] = getattr(hit.entity, field)
hits.append(item) hits.append(item)
if span: if span:
span.set_attribute("milvus.results", len(hits)) span.set_attribute("milvus.results", len(hits))
return hits return hits
async def search_with_texts( async def search_with_texts(
self, self,
embedding: list[float], embedding: list[float],
@@ -126,22 +127,22 @@ class MilvusClient:
) -> list[dict]: ) -> list[dict]:
""" """
Search and return text content with metadata. Search and return text content with metadata.
Args: Args:
embedding: Query embedding embedding: Query embedding
limit: Maximum results limit: Maximum results
text_field: Name of text field in collection text_field: Name of text field in collection
metadata_fields: Additional metadata fields to return metadata_fields: Additional metadata fields to return
Returns: Returns:
List of results with text and metadata List of results with text and metadata
""" """
output_fields = [text_field] output_fields = [text_field]
if metadata_fields: if metadata_fields:
output_fields.extend(metadata_fields) output_fields.extend(metadata_fields)
return await self.search(embedding, limit, output_fields) return await self.search(embedding, limit, output_fields)
async def insert( async def insert(
self, self,
embeddings: list[list[float]], embeddings: list[list[float]],
@@ -149,34 +150,34 @@ class MilvusClient:
) -> list[int]: ) -> list[int]:
""" """
Insert vectors with data into the collection. Insert vectors with data into the collection.
Args: Args:
embeddings: List of embedding vectors embeddings: List of embedding vectors
data: List of dicts with field values data: List of dicts with field values
Returns: Returns:
List of inserted IDs List of inserted IDs
""" """
if not self._collection: if not self._collection:
raise RuntimeError("Not connected to collection") raise RuntimeError("Not connected to collection")
with create_span("milvus.insert") as span: with create_span("milvus.insert") as span:
if span: if span:
span.set_attribute("milvus.collection", self._collection.name) span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.count", len(embeddings)) span.set_attribute("milvus.count", len(embeddings))
# Build insert data # Build insert data
insert_data = [embeddings] insert_data = [embeddings]
for field in self._collection.schema.fields: for field in self._collection.schema.fields:
if field.name not in ("id", "embedding"): if field.name not in ("id", "embedding"):
field_values = [d.get(field.name) for d in data] field_values = [d.get(field.name) for d in data]
insert_data.append(field_values) insert_data.append(field_values)
result = self._collection.insert(insert_data) result = self._collection.insert(insert_data)
self._collection.flush() self._collection.flush()
return result.primary_keys return result.primary_keys
def health(self) -> bool: def health(self) -> bool:
"""Check if connected to Milvus.""" """Check if connected to Milvus."""
return self._connected and utility.get_connection_addr("default") is not None return self._connected and utility.get_connection_addr("default") is not None

View File

@@ -1,6 +1,7 @@
""" """
Reranker service client (Infinity/BGE Reranker). Reranker service client (Infinity/BGE Reranker).
""" """
import logging import logging
from typing import Optional from typing import Optional
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
class RerankerClient: class RerankerClient:
""" """
Client for the reranker service (Infinity with BGE Reranker). Client for the reranker service (Infinity with BGE Reranker).
Usage: Usage:
client = RerankerClient() client = RerankerClient()
reranked = await client.rerank("query", ["doc1", "doc2"]) reranked = await client.rerank("query", ["doc1", "doc2"])
""" """
def __init__(self, settings: Optional[Settings] = None): def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings() self.settings = settings or Settings()
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
base_url=self.settings.reranker_url, base_url=self.settings.reranker_url,
timeout=self.settings.http_timeout, timeout=self.settings.http_timeout,
) )
async def close(self) -> None: async def close(self) -> None:
"""Close the HTTP client.""" """Close the HTTP client."""
await self._client.aclose() await self._client.aclose()
async def rerank( async def rerank(
self, self,
query: str, query: str,
@@ -40,12 +41,12 @@ class RerankerClient:
) -> list[dict]: ) -> list[dict]:
""" """
Rerank documents based on relevance to query. Rerank documents based on relevance to query.
Args: Args:
query: Query text query: Query text
documents: List of documents to rerank documents: List of documents to rerank
top_k: Number of top results to return (default: all) top_k: Number of top results to return (default: all)
Returns: Returns:
List of dicts with 'index', 'score', and 'document' keys, List of dicts with 'index', 'score', and 'document' keys,
sorted by relevance score descending. sorted by relevance score descending.
@@ -55,32 +56,34 @@ class RerankerClient:
span.set_attribute("reranker.num_documents", len(documents)) span.set_attribute("reranker.num_documents", len(documents))
if top_k: if top_k:
span.set_attribute("reranker.top_k", top_k) span.set_attribute("reranker.top_k", top_k)
payload = { payload = {
"query": query, "query": query,
"documents": documents, "documents": documents,
} }
if top_k: if top_k:
payload["top_n"] = top_k payload["top_n"] = top_k
response = await self._client.post("/rerank", json=payload) response = await self._client.post("/rerank", json=payload)
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
results = result.get("results", []) results = result.get("results", [])
# Enrich with original documents # Enrich with original documents
enriched = [] enriched = []
for r in results: for r in results:
idx = r.get("index", 0) idx = r.get("index", 0)
enriched.append({ enriched.append(
"index": idx, {
"score": r.get("relevance_score", r.get("score", 0)), "index": idx,
"document": documents[idx] if idx < len(documents) else "", "score": r.get("relevance_score", r.get("score", 0)),
}) "document": documents[idx] if idx < len(documents) else "",
}
)
return enriched return enriched
async def rerank_with_metadata( async def rerank_with_metadata(
self, self,
query: str, query: str,
@@ -90,27 +93,27 @@ class RerankerClient:
) -> list[dict]: ) -> list[dict]:
""" """
Rerank documents with metadata, preserving metadata in results. Rerank documents with metadata, preserving metadata in results.
Args: Args:
query: Query text query: Query text
documents: List of dicts with text and metadata documents: List of dicts with text and metadata
text_key: Key containing text in each document dict text_key: Key containing text in each document dict
top_k: Number of top results to return top_k: Number of top results to return
Returns: Returns:
Reranked documents with original metadata preserved. Reranked documents with original metadata preserved.
""" """
texts = [d.get(text_key, "") for d in documents] texts = [d.get(text_key, "") for d in documents]
reranked = await self.rerank(query, texts, top_k) reranked = await self.rerank(query, texts, top_k)
# Merge back metadata # Merge back metadata
for r in reranked: for r in reranked:
idx = r["index"] idx = r["index"]
if idx < len(documents): if idx < len(documents):
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key} r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
return reranked return reranked
async def health(self) -> bool: async def health(self) -> bool:
"""Check if the reranker service is healthy.""" """Check if the reranker service is healthy."""
try: try:

View File

@@ -1,7 +1,7 @@
""" """
STT service client (Whisper/faster-whisper). STT service client (Whisper/faster-whisper).
""" """
import io
import logging import logging
from typing import Optional from typing import Optional
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
class STTClient: class STTClient:
""" """
Client for the STT service (Whisper/faster-whisper). Client for the STT service (Whisper/faster-whisper).
Usage: Usage:
client = STTClient() client = STTClient()
text = await client.transcribe(audio_bytes) text = await client.transcribe(audio_bytes)
""" """
def __init__(self, settings: Optional[STTSettings] = None): def __init__(self, settings: Optional[STTSettings] = None):
self.settings = settings or STTSettings() self.settings = settings or STTSettings()
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
base_url=self.settings.stt_url, base_url=self.settings.stt_url,
timeout=180.0, # Transcription can be slow timeout=180.0, # Transcription can be slow
) )
async def close(self) -> None: async def close(self) -> None:
"""Close the HTTP client.""" """Close the HTTP client."""
await self._client.aclose() await self._client.aclose()
async def transcribe( async def transcribe(
self, self,
audio: bytes, audio: bytes,
@@ -42,54 +42,54 @@ class STTClient:
) -> dict: ) -> dict:
""" """
Transcribe audio to text. Transcribe audio to text.
Args: Args:
audio: Audio bytes (WAV, MP3, etc.) audio: Audio bytes (WAV, MP3, etc.)
language: Language code (None for auto-detect) language: Language code (None for auto-detect)
task: "transcribe" or "translate" task: "transcribe" or "translate"
response_format: "json", "text", "srt", "vtt" response_format: "json", "text", "srt", "vtt"
Returns: Returns:
Dict with 'text', 'language', and optional 'segments' Dict with 'text', 'language', and optional 'segments'
""" """
language = language or self.settings.stt_language language = language or self.settings.stt_language
task = task or self.settings.stt_task task = task or self.settings.stt_task
with create_span("stt.transcribe") as span: with create_span("stt.transcribe") as span:
if span: if span:
span.set_attribute("stt.task", task) span.set_attribute("stt.task", task)
span.set_attribute("stt.audio_size", len(audio)) span.set_attribute("stt.audio_size", len(audio))
if language: if language:
span.set_attribute("stt.language", language) span.set_attribute("stt.language", language)
files = {"file": ("audio.wav", audio, "audio/wav")} files = {"file": ("audio.wav", audio, "audio/wav")}
data = { data = {
"response_format": response_format, "response_format": response_format,
} }
if language: if language:
data["language"] = language data["language"] = language
# Choose endpoint based on task # Choose endpoint based on task
if task == "translate": if task == "translate":
endpoint = "/v1/audio/translations" endpoint = "/v1/audio/translations"
else: else:
endpoint = "/v1/audio/transcriptions" endpoint = "/v1/audio/transcriptions"
response = await self._client.post(endpoint, files=files, data=data) response = await self._client.post(endpoint, files=files, data=data)
response.raise_for_status() response.raise_for_status()
if response_format == "text": if response_format == "text":
return {"text": response.text} return {"text": response.text}
result = response.json() result = response.json()
if span: if span:
span.set_attribute("stt.result_length", len(result.get("text", ""))) span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"): if result.get("language"):
span.set_attribute("stt.detected_language", result["language"]) span.set_attribute("stt.detected_language", result["language"])
return result return result
async def transcribe_file( async def transcribe_file(
self, self,
file_path: str, file_path: str,
@@ -98,31 +98,31 @@ class STTClient:
) -> dict: ) -> dict:
""" """
Transcribe an audio file. Transcribe an audio file.
Args: Args:
file_path: Path to audio file file_path: Path to audio file
language: Language code language: Language code
task: "transcribe" or "translate" task: "transcribe" or "translate"
Returns: Returns:
Transcription result Transcription result
""" """
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
audio = f.read() audio = f.read()
return await self.transcribe(audio, language, task) return await self.transcribe(audio, language, task)
async def translate(self, audio: bytes) -> dict: async def translate(self, audio: bytes) -> dict:
""" """
Translate audio to English. Translate audio to English.
Args: Args:
audio: Audio bytes audio: Audio bytes
Returns: Returns:
Translation result with 'text' key Translation result with 'text' key
""" """
return await self.transcribe(audio, task="translate") return await self.transcribe(audio, task="translate")
async def health(self) -> bool: async def health(self) -> bool:
"""Check if the STT service is healthy.""" """Check if the STT service is healthy."""
try: try:

View File

@@ -1,7 +1,7 @@
""" """
TTS service client (Coqui XTTS). TTS service client (Coqui XTTS).
""" """
import io
import logging import logging
from typing import Optional from typing import Optional
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
class TTSClient: class TTSClient:
""" """
Client for the TTS service (Coqui XTTS). Client for the TTS service (Coqui XTTS).
Usage: Usage:
client = TTSClient() client = TTSClient()
audio_bytes = await client.synthesize("Hello world") audio_bytes = await client.synthesize("Hello world")
""" """
def __init__(self, settings: Optional[TTSSettings] = None): def __init__(self, settings: Optional[TTSSettings] = None):
self.settings = settings or TTSSettings() self.settings = settings or TTSSettings()
self._client = httpx.AsyncClient( self._client = httpx.AsyncClient(
base_url=self.settings.tts_url, base_url=self.settings.tts_url,
timeout=120.0, # TTS can be slow timeout=120.0, # TTS can be slow
) )
async def close(self) -> None: async def close(self) -> None:
"""Close the HTTP client.""" """Close the HTTP client."""
await self._client.aclose() await self._client.aclose()
async def synthesize( async def synthesize(
self, self,
text: str, text: str,
@@ -41,39 +41,39 @@ class TTSClient:
) -> bytes: ) -> bytes:
""" """
Synthesize speech from text. Synthesize speech from text.
Args: Args:
text: Text to synthesize text: Text to synthesize
language: Language code (e.g., "en", "es", "fr") language: Language code (e.g., "en", "es", "fr")
speaker: Speaker ID or reference speaker: Speaker ID or reference
Returns: Returns:
WAV audio bytes WAV audio bytes
""" """
language = language or self.settings.tts_language language = language or self.settings.tts_language
with create_span("tts.synthesize") as span: with create_span("tts.synthesize") as span:
if span: if span:
span.set_attribute("tts.language", language) span.set_attribute("tts.language", language)
span.set_attribute("tts.text_length", len(text)) span.set_attribute("tts.text_length", len(text))
params = { params = {
"text": text, "text": text,
"language_id": language, "language_id": language,
} }
if speaker: if speaker:
params["speaker_id"] = speaker params["speaker_id"] = speaker
response = await self._client.get("/api/tts", params=params) response = await self._client.get("/api/tts", params=params)
response.raise_for_status() response.raise_for_status()
audio_bytes = response.content audio_bytes = response.content
if span: if span:
span.set_attribute("tts.audio_size", len(audio_bytes)) span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes return audio_bytes
async def synthesize_to_file( async def synthesize_to_file(
self, self,
text: str, text: str,
@@ -83,7 +83,7 @@ class TTSClient:
) -> None: ) -> None:
""" """
Synthesize speech and save to a file. Synthesize speech and save to a file.
Args: Args:
text: Text to synthesize text: Text to synthesize
output_path: Path to save the audio file output_path: Path to save the audio file
@@ -91,10 +91,10 @@ class TTSClient:
speaker: Speaker ID speaker: Speaker ID
""" """
audio_bytes = await self.synthesize(text, language, speaker) audio_bytes = await self.synthesize(text, language, speaker)
with open(output_path, "wb") as f: with open(output_path, "wb") as f:
f.write(audio_bytes) f.write(audio_bytes)
async def get_speakers(self) -> list[dict]: async def get_speakers(self) -> list[dict]:
"""Get available speakers/voices.""" """Get available speakers/voices."""
try: try:
@@ -103,7 +103,7 @@ class TTSClient:
return response.json() return response.json()
except Exception: except Exception:
return [] return []
async def health(self) -> bool: async def health(self) -> bool:
"""Check if the TTS service is healthy.""" """Check if the TTS service is healthy."""
try: try:

View File

@@ -3,67 +3,69 @@ Configuration management using Pydantic Settings.
Environment variables are automatically loaded and validated. Environment variables are automatically loaded and validated.
""" """
from typing import Optional from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
"""Base settings for all handler services.""" """Base settings for all handler services."""
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
extra="ignore", extra="ignore",
) )
# Service identification # Service identification
service_name: str = "handler" service_name: str = "handler"
service_version: str = "1.0.0" service_version: str = "1.0.0"
service_namespace: str = "ai-ml" service_namespace: str = "ai-ml"
deployment_env: str = "production" deployment_env: str = "production"
# NATS configuration # NATS configuration
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222" nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
nats_user: Optional[str] = None nats_user: Optional[str] = None
nats_password: Optional[str] = None nats_password: Optional[str] = None
nats_queue_group: Optional[str] = None nats_queue_group: Optional[str] = None
# Redis/Valkey configuration # Redis/Valkey configuration
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379" redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
redis_password: Optional[str] = None redis_password: Optional[str] = None
# Milvus configuration # Milvus configuration
milvus_host: str = "milvus.ai-ml.svc.cluster.local" milvus_host: str = "milvus.ai-ml.svc.cluster.local"
milvus_port: int = 19530 milvus_port: int = 19530
milvus_collection: str = "documents" milvus_collection: str = "documents"
# Service endpoints # Service endpoints
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local" embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local" reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local" llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local" tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local" stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
# OpenTelemetry configuration # OpenTelemetry configuration
otel_enabled: bool = True otel_enabled: bool = True
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317" otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
otel_use_http: bool = False otel_use_http: bool = False
# HyperDX configuration # HyperDX configuration
hyperdx_enabled: bool = False hyperdx_enabled: bool = False
hyperdx_api_key: Optional[str] = None hyperdx_api_key: Optional[str] = None
hyperdx_endpoint: str = "https://in-otel.hyperdx.io" hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
# MLflow configuration # MLflow configuration
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80" mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
mlflow_experiment_name: Optional[str] = None mlflow_experiment_name: Optional[str] = None
mlflow_enabled: bool = True mlflow_enabled: bool = True
# Health check configuration # Health check configuration
health_port: int = 8080 health_port: int = 8080
health_path: str = "/health" health_path: str = "/health"
ready_path: str = "/ready" ready_path: str = "/ready"
# Timeouts (seconds) # Timeouts (seconds)
http_timeout: float = 60.0 http_timeout: float = 60.0
nats_timeout: float = 30.0 nats_timeout: float = 30.0
@@ -71,14 +73,14 @@ class Settings(BaseSettings):
class EmbeddingsSettings(Settings): class EmbeddingsSettings(Settings):
"""Settings for embeddings service client.""" """Settings for embeddings service client."""
embeddings_model: str = "bge" embeddings_model: str = "bge"
embeddings_batch_size: int = 32 embeddings_batch_size: int = 32
class LLMSettings(Settings): class LLMSettings(Settings):
"""Settings for LLM service client.""" """Settings for LLM service client."""
llm_model: str = "default" llm_model: str = "default"
llm_max_tokens: int = 2048 llm_max_tokens: int = 2048
llm_temperature: float = 0.7 llm_temperature: float = 0.7
@@ -87,13 +89,13 @@ class LLMSettings(Settings):
class TTSSettings(Settings): class TTSSettings(Settings):
"""Settings for TTS service client.""" """Settings for TTS service client."""
tts_language: str = "en" tts_language: str = "en"
tts_speaker: Optional[str] = None tts_speaker: Optional[str] = None
class STTSettings(Settings): class STTSettings(Settings):
"""Settings for STT service client.""" """Settings for STT service client."""
stt_language: Optional[str] = None # Auto-detect stt_language: Optional[str] = None # Auto-detect
stt_task: str = "transcribe" # or "translate" stt_task: str = "transcribe" # or "translate"

View File

@@ -1,6 +1,7 @@
""" """
Base handler class for building NATS-based services. Base handler class for building NATS-based services.
""" """
import asyncio import asyncio
import logging import logging
import signal import signal
@@ -12,7 +13,7 @@ from nats.aio.msg import Msg
from handler_base.config import Settings from handler_base.config import Settings
from handler_base.health import HealthServer from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient from handler_base.nats_client import NATSClient
from handler_base.telemetry import setup_telemetry, create_span from handler_base.telemetry import create_span, setup_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,25 +21,25 @@ logger = logging.getLogger(__name__)
class Handler(ABC): class Handler(ABC):
""" """
Base class for NATS message handlers. Base class for NATS message handlers.
Subclass and implement: Subclass and implement:
- setup(): Initialize your service clients - setup(): Initialize your service clients
- handle_message(): Process incoming messages - handle_message(): Process incoming messages
- teardown(): Clean up resources (optional) - teardown(): Clean up resources (optional)
Example: Example:
class MyHandler(Handler): class MyHandler(Handler):
async def setup(self): async def setup(self):
self.embeddings = EmbeddingsClient() self.embeddings = EmbeddingsClient()
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]: async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
result = await self.embeddings.embed(data["text"]) result = await self.embeddings.embed(data["text"])
return {"embedding": result} return {"embedding": result}
if __name__ == "__main__": if __name__ == "__main__":
MyHandler(subject="my.subject").run() MyHandler(subject="my.subject").run()
""" """
def __init__( def __init__(
self, self,
subject: str, subject: str,
@@ -47,7 +48,7 @@ class Handler(ABC):
): ):
""" """
Initialize the handler. Initialize the handler.
Args: Args:
subject: NATS subject to subscribe to subject: NATS subject to subscribe to
settings: Configuration settings settings: Configuration settings
@@ -56,78 +57,78 @@ class Handler(ABC):
self.subject = subject self.subject = subject
self.settings = settings or Settings() self.settings = settings or Settings()
self.queue_group = queue_group or self.settings.nats_queue_group self.queue_group = queue_group or self.settings.nats_queue_group
self.nats = NATSClient(self.settings) self.nats = NATSClient(self.settings)
self.health_server = HealthServer(self.settings, self._check_ready) self.health_server = HealthServer(self.settings, self._check_ready)
self._running = False self._running = False
self._shutdown_event = asyncio.Event() self._shutdown_event = asyncio.Event()
@abstractmethod @abstractmethod
async def setup(self) -> None: async def setup(self) -> None:
""" """
Initialize service clients and resources. Initialize service clients and resources.
Called once before starting to handle messages. Called once before starting to handle messages.
Override this to set up your service-specific clients. Override this to set up your service-specific clients.
""" """
pass pass
@abstractmethod @abstractmethod
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]: async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
""" """
Handle an incoming message. Handle an incoming message.
Args: Args:
msg: Raw NATS message msg: Raw NATS message
data: Decoded message data (msgpack unpacked) data: Decoded message data (msgpack unpacked)
Returns: Returns:
Optional response data. If returned and msg has a reply subject, Optional response data. If returned and msg has a reply subject,
the response will be sent automatically. the response will be sent automatically.
""" """
pass pass
async def teardown(self) -> None: async def teardown(self) -> None:
""" """
Clean up resources. Clean up resources.
Called during graceful shutdown. Called during graceful shutdown.
Override to add custom cleanup logic. Override to add custom cleanup logic.
""" """
pass pass
async def _check_ready(self) -> bool: async def _check_ready(self) -> bool:
"""Check if the service is ready to handle requests.""" """Check if the service is ready to handle requests."""
return self._running and self.nats._nc is not None return self._running and self.nats._nc is not None
async def _message_handler(self, msg: Msg) -> None: async def _message_handler(self, msg: Msg) -> None:
"""Internal message handler with tracing and error handling.""" """Internal message handler with tracing and error handling."""
with create_span(f"handle.{self.subject}") as span: with create_span(f"handle.{self.subject}") as span:
try: try:
# Decode message # Decode message
data = NATSClient.decode_msgpack(msg) data = NATSClient.decode_msgpack(msg)
if span: if span:
span.set_attribute("messaging.destination", msg.subject) span.set_attribute("messaging.destination", msg.subject)
if isinstance(data, dict): if isinstance(data, dict):
request_id = data.get("request_id", data.get("id")) request_id = data.get("request_id", data.get("id"))
if request_id: if request_id:
span.set_attribute("request.id", str(request_id)) span.set_attribute("request.id", str(request_id))
# Handle message # Handle message
response = await self.handle_message(msg, data) response = await self.handle_message(msg, data)
# Send response if applicable # Send response if applicable
if response is not None and msg.reply: if response is not None and msg.reply:
await self.nats.publish(msg.reply, response) await self.nats.publish(msg.reply, response)
except Exception as e: except Exception as e:
logger.exception(f"Error handling message on {msg.subject}") logger.exception(f"Error handling message on {msg.subject}")
if span: if span:
span.set_attribute("error", True) span.set_attribute("error", True)
span.set_attribute("error.message", str(e)) span.set_attribute("error.message", str(e))
# Send error response if reply expected # Send error response if reply expected
if msg.reply: if msg.reply:
error_response = { error_response = {
@@ -136,71 +137,71 @@ class Handler(ABC):
"type": type(e).__name__, "type": type(e).__name__,
} }
await self.nats.publish(msg.reply, error_response) await self.nats.publish(msg.reply, error_response)
def _setup_signals(self) -> None: def _setup_signals(self) -> None:
"""Set up signal handlers for graceful shutdown.""" """Set up signal handlers for graceful shutdown."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, self._handle_signal, sig) loop.add_signal_handler(sig, self._handle_signal, sig)
def _handle_signal(self, sig: signal.Signals) -> None: def _handle_signal(self, sig: signal.Signals) -> None:
"""Handle shutdown signal.""" """Handle shutdown signal."""
logger.info(f"Received {sig.name}, initiating graceful shutdown...") logger.info(f"Received {sig.name}, initiating graceful shutdown...")
self._shutdown_event.set() self._shutdown_event.set()
async def _run(self) -> None: async def _run(self) -> None:
"""Main async run loop.""" """Main async run loop."""
# Setup telemetry # Setup telemetry
setup_telemetry(self.settings) setup_telemetry(self.settings)
# Start health server # Start health server
self.health_server.start() self.health_server.start()
try: try:
# Connect to NATS # Connect to NATS
await self.nats.connect() await self.nats.connect()
# Run user setup # Run user setup
logger.info("Running service setup...") logger.info("Running service setup...")
await self.setup() await self.setup()
# Subscribe to subject # Subscribe to subject
await self.nats.subscribe( await self.nats.subscribe(
self.subject, self.subject,
self._message_handler, self._message_handler,
queue=self.queue_group, queue=self.queue_group,
) )
self._running = True self._running = True
logger.info(f"Handler ready, listening on {self.subject}") logger.info(f"Handler ready, listening on {self.subject}")
# Wait for shutdown signal # Wait for shutdown signal
await self._shutdown_event.wait() await self._shutdown_event.wait()
except Exception as e: except Exception:
logger.exception("Fatal error in handler") logger.exception("Fatal error in handler")
raise raise
finally: finally:
self._running = False self._running = False
# Graceful shutdown # Graceful shutdown
logger.info("Shutting down...") logger.info("Shutting down...")
try: try:
await self.teardown() await self.teardown()
except Exception as e: except Exception as e:
logger.warning(f"Error during teardown: {e}") logger.warning(f"Error during teardown: {e}")
await self.nats.close() await self.nats.close()
self.health_server.stop() self.health_server.stop()
logger.info("Shutdown complete") logger.info("Shutdown complete")
def run(self) -> None: def run(self) -> None:
""" """
Run the handler. Run the handler.
This is the main entry point. It sets up signal handlers This is the main entry point. It sets up signal handlers
and runs the async event loop. and runs the async event loop.
""" """
@@ -209,12 +210,12 @@ class Handler(ABC):
level=logging.INFO, level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
) )
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}") logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
# Run the async loop # Run the async loop
asyncio.run(self._run_with_signals()) asyncio.run(self._run_with_signals())
async def _run_with_signals(self) -> None: async def _run_with_signals(self) -> None:
"""Run with signal handling.""" """Run with signal handling."""
self._setup_signals() self._setup_signals()

View File

@@ -3,12 +3,13 @@ HTTP health check server.
Provides /health and /ready endpoints for Kubernetes probes. Provides /health and /ready endpoints for Kubernetes probes.
""" """
import asyncio import asyncio
import logging
from typing import Callable, Optional, Awaitable
from http.server import HTTPServer, BaseHTTPRequestHandler
import threading
import json import json
import logging
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Awaitable, Callable, Optional
from handler_base.config import Settings from handler_base.config import Settings
@@ -17,16 +18,16 @@ logger = logging.getLogger(__name__)
class HealthHandler(BaseHTTPRequestHandler): class HealthHandler(BaseHTTPRequestHandler):
"""HTTP request handler for health checks.""" """HTTP request handler for health checks."""
# Class-level state # Class-level state
ready_check: Optional[Callable[[], Awaitable[bool]]] = None ready_check: Optional[Callable[[], Awaitable[bool]]] = None
health_path: str = "/health" health_path: str = "/health"
ready_path: str = "/ready" ready_path: str = "/ready"
def log_message(self, format, *args): def log_message(self, format, *args):
"""Suppress default logging.""" """Suppress default logging."""
pass pass
def do_GET(self): def do_GET(self):
"""Handle GET requests for health/ready endpoints.""" """Handle GET requests for health/ready endpoints."""
if self.path == self.health_path: if self.path == self.health_path:
@@ -35,7 +36,7 @@ class HealthHandler(BaseHTTPRequestHandler):
self._handle_ready() self._handle_ready()
else: else:
self._respond_not_found() self._respond_not_found()
def _handle_ready(self): def _handle_ready(self):
"""Check readiness and respond.""" """Check readiness and respond."""
# Access via class to avoid method binding issues # Access via class to avoid method binding issues
@@ -43,7 +44,7 @@ class HealthHandler(BaseHTTPRequestHandler):
if ready_check is None: if ready_check is None:
self._respond_ok({"status": "ready"}) self._respond_ok({"status": "ready"})
return return
try: try:
# Run the async check in a new event loop # Run the async check in a new event loop
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
@@ -51,7 +52,7 @@ class HealthHandler(BaseHTTPRequestHandler):
is_ready = loop.run_until_complete(ready_check()) is_ready = loop.run_until_complete(ready_check())
finally: finally:
loop.close() loop.close()
if is_ready: if is_ready:
self._respond_ok({"status": "ready"}) self._respond_ok({"status": "ready"})
else: else:
@@ -59,19 +60,19 @@ class HealthHandler(BaseHTTPRequestHandler):
except Exception as e: except Exception as e:
logger.exception("Readiness check failed") logger.exception("Readiness check failed")
self._respond_unavailable({"status": "error", "message": str(e)}) self._respond_unavailable({"status": "error", "message": str(e)})
def _respond_ok(self, data: dict): def _respond_ok(self, data: dict):
self.send_response(200) self.send_response(200)
self.send_header("Content-Type", "application/json") self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
self.wfile.write(json.dumps(data).encode()) self.wfile.write(json.dumps(data).encode())
def _respond_unavailable(self, data: dict): def _respond_unavailable(self, data: dict):
self.send_response(503) self.send_response(503)
self.send_header("Content-Type", "application/json") self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
self.wfile.write(json.dumps(data).encode()) self.wfile.write(json.dumps(data).encode())
def _respond_not_found(self): def _respond_not_found(self):
self.send_response(404) self.send_response(404)
self.end_headers() self.end_headers()
@@ -80,14 +81,14 @@ class HealthHandler(BaseHTTPRequestHandler):
class HealthServer: class HealthServer:
""" """
Background HTTP server for health checks. Background HTTP server for health checks.
Usage: Usage:
server = HealthServer(settings) server = HealthServer(settings)
server.start() server.start()
# ... run your service ... # ... run your service ...
server.stop() server.stop()
""" """
def __init__( def __init__(
self, self,
settings: Optional[Settings] = None, settings: Optional[Settings] = None,
@@ -97,24 +98,24 @@ class HealthServer:
self.ready_check = ready_check self.ready_check = ready_check
self._server: Optional[HTTPServer] = None self._server: Optional[HTTPServer] = None
self._thread: Optional[threading.Thread] = None self._thread: Optional[threading.Thread] = None
def start(self) -> None: def start(self) -> None:
"""Start the health check server in a background thread.""" """Start the health check server in a background thread."""
# Configure handler class # Configure handler class
HealthHandler.ready_check = self.ready_check HealthHandler.ready_check = self.ready_check
HealthHandler.health_path = self.settings.health_path HealthHandler.health_path = self.settings.health_path
HealthHandler.ready_path = self.settings.ready_path HealthHandler.ready_path = self.settings.ready_path
# Create and start server # Create and start server
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler) self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
self._thread.start() self._thread.start()
logger.info( logger.info(
f"Health server started on port {self.settings.health_port} " f"Health server started on port {self.settings.health_port} "
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})" f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
) )
def stop(self) -> None: def stop(self) -> None:
"""Stop the health check server.""" """Stop the health check server."""
if self._server: if self._server:

View File

@@ -1,9 +1,9 @@
""" """
NATS client wrapper with connection management and utilities. NATS client wrapper with connection management and utilities.
""" """
import asyncio
import logging import logging
from typing import Any, Callable, Optional, Awaitable from typing import Any, Awaitable, Callable, Optional
import msgpack import msgpack
import nats import nats
@@ -20,34 +20,34 @@ logger = logging.getLogger(__name__)
class NATSClient: class NATSClient:
""" """
NATS client with automatic connection management. NATS client with automatic connection management.
Supports: Supports:
- Core NATS pub/sub - Core NATS pub/sub
- JetStream for persistence - JetStream for persistence
- Queue groups for load balancing - Queue groups for load balancing
- Msgpack serialization - Msgpack serialization
""" """
def __init__(self, settings: Optional[Settings] = None): def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings() self.settings = settings or Settings()
self._nc: Optional[Client] = None self._nc: Optional[Client] = None
self._js: Optional[JetStreamContext] = None self._js: Optional[JetStreamContext] = None
self._subscriptions: list = [] self._subscriptions: list = []
@property @property
def nc(self) -> Client: def nc(self) -> Client:
"""Get the NATS client, raising if not connected.""" """Get the NATS client, raising if not connected."""
if self._nc is None: if self._nc is None:
raise RuntimeError("NATS client not connected. Call connect() first.") raise RuntimeError("NATS client not connected. Call connect() first.")
return self._nc return self._nc
@property @property
def js(self) -> JetStreamContext: def js(self) -> JetStreamContext:
"""Get JetStream context, raising if not connected.""" """Get JetStream context, raising if not connected."""
if self._js is None: if self._js is None:
raise RuntimeError("JetStream not initialized. Call connect() first.") raise RuntimeError("JetStream not initialized. Call connect() first.")
return self._js return self._js
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to NATS server.""" """Connect to NATS server."""
connect_opts = { connect_opts = {
@@ -55,16 +55,16 @@ class NATSClient:
"reconnect_time_wait": 2, "reconnect_time_wait": 2,
"max_reconnect_attempts": -1, # Infinite "max_reconnect_attempts": -1, # Infinite
} }
if self.settings.nats_user and self.settings.nats_password: if self.settings.nats_user and self.settings.nats_password:
connect_opts["user"] = self.settings.nats_user connect_opts["user"] = self.settings.nats_user
connect_opts["password"] = self.settings.nats_password connect_opts["password"] = self.settings.nats_password
logger.info(f"Connecting to NATS at {self.settings.nats_url}") logger.info(f"Connecting to NATS at {self.settings.nats_url}")
self._nc = await nats.connect(**connect_opts) self._nc = await nats.connect(**connect_opts)
self._js = self._nc.jetstream() self._js = self._nc.jetstream()
logger.info("Connected to NATS") logger.info("Connected to NATS")
async def close(self) -> None: async def close(self) -> None:
"""Close NATS connection gracefully.""" """Close NATS connection gracefully."""
if self._nc: if self._nc:
@@ -74,13 +74,13 @@ class NATSClient:
await sub.drain() await sub.drain()
except Exception as e: except Exception as e:
logger.warning(f"Error draining subscription: {e}") logger.warning(f"Error draining subscription: {e}")
await self._nc.drain() await self._nc.drain()
await self._nc.close() await self._nc.close()
self._nc = None self._nc = None
self._js = None self._js = None
logger.info("NATS connection closed") logger.info("NATS connection closed")
async def subscribe( async def subscribe(
self, self,
subject: str, subject: str,
@@ -89,24 +89,24 @@ class NATSClient:
): ):
""" """
Subscribe to a subject with a handler function. Subscribe to a subject with a handler function.
Args: Args:
subject: NATS subject to subscribe to subject: NATS subject to subscribe to
handler: Async function to handle messages handler: Async function to handle messages
queue: Optional queue group for load balancing queue: Optional queue group for load balancing
""" """
queue = queue or self.settings.nats_queue_group queue = queue or self.settings.nats_queue_group
if queue: if queue:
sub = await self.nc.subscribe(subject, queue=queue, cb=handler) sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
logger.info(f"Subscribed to {subject} (queue: {queue})") logger.info(f"Subscribed to {subject} (queue: {queue})")
else: else:
sub = await self.nc.subscribe(subject, cb=handler) sub = await self.nc.subscribe(subject, cb=handler)
logger.info(f"Subscribed to {subject}") logger.info(f"Subscribed to {subject}")
self._subscriptions.append(sub) self._subscriptions.append(sub)
return sub return sub
async def publish( async def publish(
self, self,
subject: str, subject: str,
@@ -115,7 +115,7 @@ class NATSClient:
) -> None: ) -> None:
""" """
Publish a message to a subject. Publish a message to a subject.
Args: Args:
subject: NATS subject to publish to subject: NATS subject to publish to
data: Data to publish (will be serialized) data: Data to publish (will be serialized)
@@ -124,15 +124,16 @@ class NATSClient:
with create_span("nats.publish") as span: with create_span("nats.publish") as span:
if span: if span:
span.set_attribute("messaging.destination", subject) span.set_attribute("messaging.destination", subject)
if use_msgpack: if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True) payload = msgpack.packb(data, use_bin_type=True)
else: else:
import json import json
payload = json.dumps(data).encode() payload = json.dumps(data).encode()
await self.nc.publish(subject, payload) await self.nc.publish(subject, payload)
async def request( async def request(
self, self,
subject: str, subject: str,
@@ -142,43 +143,46 @@ class NATSClient:
) -> Any: ) -> Any:
""" """
Send a request and wait for response. Send a request and wait for response.
Args: Args:
subject: NATS subject to send request to subject: NATS subject to send request to
data: Request data data: Request data
timeout: Response timeout in seconds timeout: Response timeout in seconds
use_msgpack: Whether to use msgpack serialization use_msgpack: Whether to use msgpack serialization
Returns: Returns:
Decoded response data Decoded response data
""" """
timeout = timeout or self.settings.nats_timeout timeout = timeout or self.settings.nats_timeout
with create_span("nats.request") as span: with create_span("nats.request") as span:
if span: if span:
span.set_attribute("messaging.destination", subject) span.set_attribute("messaging.destination", subject)
if use_msgpack: if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True) payload = msgpack.packb(data, use_bin_type=True)
else: else:
import json import json
payload = json.dumps(data).encode() payload = json.dumps(data).encode()
response = await self.nc.request(subject, payload, timeout=timeout) response = await self.nc.request(subject, payload, timeout=timeout)
if use_msgpack: if use_msgpack:
return msgpack.unpackb(response.data, raw=False) return msgpack.unpackb(response.data, raw=False)
else: else:
import json import json
return json.loads(response.data.decode()) return json.loads(response.data.decode())
@staticmethod @staticmethod
def decode_msgpack(msg: Msg) -> Any: def decode_msgpack(msg: Msg) -> Any:
"""Decode a msgpack message.""" """Decode a msgpack message."""
return msgpack.unpackb(msg.data, raw=False) return msgpack.unpackb(msg.data, raw=False)
@staticmethod @staticmethod
def decode_json(msg: Msg) -> Any: def decode_json(msg: Msg) -> Any:
"""Decode a JSON message.""" """Decode a JSON message."""
import json import json
return json.loads(msg.data.decode()) return json.loads(msg.data.decode())

View File

@@ -3,26 +3,27 @@ OpenTelemetry setup for tracing and metrics.
Supports both gRPC and HTTP exporters, with optional HyperDX integration. Supports both gRPC and HTTP exporters, with optional HyperDX integration.
""" """
import logging import logging
import os import os
from typing import Optional, Tuple from typing import Optional, Tuple
from opentelemetry import trace, metrics from opentelemetry import metrics, trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP, OTLPMetricExporter as OTLPMetricExporterHTTP,
) )
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from handler_base.config import Settings from handler_base.config import Settings
@@ -39,35 +40,37 @@ def setup_telemetry(
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]: ) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
""" """
Initialize OpenTelemetry tracing and metrics. Initialize OpenTelemetry tracing and metrics.
Args: Args:
settings: Configuration settings. If None, loads from environment. settings: Configuration settings. If None, loads from environment.
Returns: Returns:
Tuple of (tracer, meter) or (None, None) if disabled. Tuple of (tracer, meter) or (None, None) if disabled.
""" """
global _tracer, _meter, _initialized global _tracer, _meter, _initialized
if _initialized: if _initialized:
return _tracer, _meter return _tracer, _meter
if settings is None: if settings is None:
settings = Settings() settings = Settings()
if not settings.otel_enabled: if not settings.otel_enabled:
logger.info("OpenTelemetry disabled") logger.info("OpenTelemetry disabled")
_initialized = True _initialized = True
return None, None return None, None
# Create resource with service information # Create resource with service information
resource = Resource.create({ resource = Resource.create(
SERVICE_NAME: settings.service_name, {
SERVICE_VERSION: settings.service_version, SERVICE_NAME: settings.service_name,
SERVICE_NAMESPACE: settings.service_namespace, SERVICE_VERSION: settings.service_version,
"deployment.environment": settings.deployment_env, SERVICE_NAMESPACE: settings.service_namespace,
"host.name": os.environ.get("HOSTNAME", "unknown"), "deployment.environment": settings.deployment_env,
}) "host.name": os.environ.get("HOSTNAME", "unknown"),
}
)
# Determine endpoint and exporter type # Determine endpoint and exporter type
if settings.hyperdx_enabled and settings.hyperdx_api_key: if settings.hyperdx_enabled and settings.hyperdx_api_key:
# HyperDX uses HTTP with API key header # HyperDX uses HTTP with API key header
@@ -80,7 +83,7 @@ def setup_telemetry(
headers = None headers = None
use_http = settings.otel_use_http use_http = settings.otel_use_http
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})") logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
# Setup tracing # Setup tracing
if use_http: if use_http:
trace_exporter = OTLPSpanExporterHTTP( trace_exporter = OTLPSpanExporterHTTP(
@@ -91,11 +94,11 @@ def setup_telemetry(
trace_exporter = OTLPSpanExporter( trace_exporter = OTLPSpanExporter(
endpoint=endpoint, endpoint=endpoint,
) )
tracer_provider = TracerProvider(resource=resource) tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
trace.set_tracer_provider(tracer_provider) trace.set_tracer_provider(tracer_provider)
# Setup metrics # Setup metrics
if use_http: if use_http:
metric_exporter = OTLPMetricExporterHTTP( metric_exporter = OTLPMetricExporterHTTP(
@@ -106,25 +109,25 @@ def setup_telemetry(
metric_exporter = OTLPMetricExporter( metric_exporter = OTLPMetricExporter(
endpoint=endpoint, endpoint=endpoint,
) )
metric_reader = PeriodicExportingMetricReader( metric_reader = PeriodicExportingMetricReader(
metric_exporter, metric_exporter,
export_interval_millis=60000, export_interval_millis=60000,
) )
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider) metrics.set_meter_provider(meter_provider)
# Instrument libraries # Instrument libraries
HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().instrument()
LoggingInstrumentor().instrument(set_logging_format=True) LoggingInstrumentor().instrument(set_logging_format=True)
# Create tracer and meter for this service # Create tracer and meter for this service
_tracer = trace.get_tracer(settings.service_name, settings.service_version) _tracer = trace.get_tracer(settings.service_name, settings.service_version)
_meter = metrics.get_meter(settings.service_name, settings.service_version) _meter = metrics.get_meter(settings.service_name, settings.service_version)
logger.info(f"OpenTelemetry initialized for {settings.service_name}") logger.info(f"OpenTelemetry initialized for {settings.service_name}")
_initialized = True _initialized = True
return _tracer, _meter return _tracer, _meter
@@ -141,7 +144,7 @@ def get_meter() -> Optional[metrics.Meter]:
def create_span(name: str, **kwargs): def create_span(name: str, **kwargs):
""" """
Create a new span. Create a new span.
Usage: Usage:
with create_span("my_operation") as span: with create_span("my_operation") as span:
span.set_attribute("key", "value") span.set_attribute("key", "value")
@@ -150,5 +153,6 @@ def create_span(name: str, **kwargs):
if _tracer is None: if _tracer is None:
# Return a no-op context manager # Return a no-op context manager
from contextlib import nullcontext from contextlib import nullcontext
return nullcontext() return nullcontext()
return _tracer.start_as_current_span(name, **kwargs) return _tracer.start_as_current_span(name, **kwargs)

View File

@@ -1,14 +1,13 @@
""" """
Pytest configuration and fixtures. Pytest configuration and fixtures.
""" """
import asyncio import asyncio
import os import os
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
# Set test environment variables before importing handler_base # Set test environment variables before importing handler_base
os.environ.setdefault("NATS_URL", "nats://localhost:4222") os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379") os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
@@ -29,6 +28,7 @@ def event_loop():
def settings(): def settings():
"""Create test settings.""" """Create test settings."""
from handler_base.config import Settings from handler_base.config import Settings
return Settings( return Settings(
service_name="test-service", service_name="test-service",
service_version="1.0.0-test", service_version="1.0.0-test",
@@ -56,7 +56,7 @@ def mock_nats_message():
msg = MagicMock() msg = MagicMock()
msg.subject = "test.subject" msg.subject = "test.subject"
msg.reply = "test.reply" msg.reply = "test.reply"
msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
return msg return msg

View File

@@ -1,44 +1,43 @@
""" """
Unit tests for service clients. Unit tests for service clients.
""" """
import json
from unittest.mock import MagicMock
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch
class TestEmbeddingsClient: class TestEmbeddingsClient:
"""Tests for EmbeddingsClient.""" """Tests for EmbeddingsClient."""
@pytest.fixture @pytest.fixture
def embeddings_client(self, mock_httpx_client): def embeddings_client(self, mock_httpx_client):
"""Create an EmbeddingsClient with mocked HTTP.""" """Create an EmbeddingsClient with mocked HTTP."""
from handler_base.clients.embeddings import EmbeddingsClient from handler_base.clients.embeddings import EmbeddingsClient
client = EmbeddingsClient() client = EmbeddingsClient()
client._client = mock_httpx_client client._client = mock_httpx_client
return client return client
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding): async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding a single text.""" """Test embedding a single text."""
# Setup mock response # Setup mock response
mock_response = MagicMock() mock_response = MagicMock()
mock_response.json.return_value = { mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
"data": [{"embedding": sample_embedding, "index": 0}]
}
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed_single("Hello world") result = await embeddings_client.embed_single("Hello world")
assert result == sample_embedding assert result == sample_embedding
mock_httpx_client.post.assert_called_once() mock_httpx_client.post.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding): async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding multiple texts.""" """Test embedding multiple texts."""
texts = ["Hello", "World"] texts = ["Hello", "World"]
mock_response = MagicMock() mock_response = MagicMock()
mock_response.json.return_value = { mock_response.json.return_value = {
"data": [ "data": [
@@ -48,41 +47,41 @@ class TestEmbeddingsClient:
} }
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed(texts) result = await embeddings_client.embed(texts)
assert len(result) == 2 assert len(result) == 2
assert all(len(e) == len(sample_embedding) for e in result) assert all(len(e) == len(sample_embedding) for e in result)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_check(self, embeddings_client, mock_httpx_client): async def test_health_check(self, embeddings_client, mock_httpx_client):
"""Test health check.""" """Test health check."""
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_httpx_client.get.return_value = mock_response mock_httpx_client.get.return_value = mock_response
result = await embeddings_client.health() result = await embeddings_client.health()
assert result is True assert result is True
class TestRerankerClient: class TestRerankerClient:
"""Tests for RerankerClient.""" """Tests for RerankerClient."""
@pytest.fixture @pytest.fixture
def reranker_client(self, mock_httpx_client): def reranker_client(self, mock_httpx_client):
"""Create a RerankerClient with mocked HTTP.""" """Create a RerankerClient with mocked HTTP."""
from handler_base.clients.reranker import RerankerClient from handler_base.clients.reranker import RerankerClient
client = RerankerClient() client = RerankerClient()
client._client = mock_httpx_client client._client = mock_httpx_client
return client return client
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents): async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
"""Test reranking documents.""" """Test reranking documents."""
texts = [d["text"] for d in sample_documents] texts = [d["text"] for d in sample_documents]
mock_response = MagicMock() mock_response = MagicMock()
mock_response.json.return_value = { mock_response.json.return_value = {
"results": [ "results": [
@@ -93,9 +92,9 @@ class TestRerankerClient:
} }
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response mock_httpx_client.post.return_value = mock_response
result = await reranker_client.rerank("What is ML?", texts) result = await reranker_client.rerank("What is ML?", texts)
assert len(result) == 3 assert len(result) == 3
assert result[0]["score"] == 0.95 assert result[0]["score"] == 0.95
assert result[0]["index"] == 1 assert result[0]["index"] == 1
@@ -103,53 +102,48 @@ class TestRerankerClient:
class TestLLMClient: class TestLLMClient:
"""Tests for LLMClient.""" """Tests for LLMClient."""
@pytest.fixture @pytest.fixture
def llm_client(self, mock_httpx_client): def llm_client(self, mock_httpx_client):
"""Create an LLMClient with mocked HTTP.""" """Create an LLMClient with mocked HTTP."""
from handler_base.clients.llm import LLMClient from handler_base.clients.llm import LLMClient
client = LLMClient() client = LLMClient()
client._client = mock_httpx_client client._client = mock_httpx_client
return client return client
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate(self, llm_client, mock_httpx_client): async def test_generate(self, llm_client, mock_httpx_client):
"""Test generating a response.""" """Test generating a response."""
mock_response = MagicMock() mock_response = MagicMock()
mock_response.json.return_value = { mock_response.json.return_value = {
"choices": [ "choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
{"message": {"content": "Hello! I'm an AI assistant."}} "usage": {"prompt_tokens": 10, "completion_tokens": 20},
],
"usage": {"prompt_tokens": 10, "completion_tokens": 20}
} }
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate("Hello") result = await llm_client.generate("Hello")
assert result == "Hello! I'm an AI assistant." assert result == "Hello! I'm an AI assistant."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_with_context(self, llm_client, mock_httpx_client): async def test_generate_with_context(self, llm_client, mock_httpx_client):
"""Test generating with RAG context.""" """Test generating with RAG context."""
mock_response = MagicMock() mock_response = MagicMock()
mock_response.json.return_value = { mock_response.json.return_value = {
"choices": [ "choices": [{"message": {"content": "Based on the context..."}}],
{"message": {"content": "Based on the context..."}} "usage": {},
],
"usage": {}
} }
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate( result = await llm_client.generate(
"What is Python?", "What is Python?", context="Python is a programming language."
context="Python is a programming language."
) )
assert "Based on the context" in result assert "Based on the context" in result
# Verify context was included in the request # Verify context was included in the request
call_args = mock_httpx_client.post.call_args call_args = mock_httpx_client.post.call_args
messages = call_args.kwargs["json"]["messages"] messages = call_args.kwargs["json"]["messages"]

View File

@@ -1,46 +1,45 @@
""" """
Unit tests for handler_base.config module. Unit tests for handler_base.config module.
""" """
import os
import pytest
class TestSettings: class TestSettings:
"""Tests for Settings configuration.""" """Tests for Settings configuration."""
def test_default_settings(self, settings): def test_default_settings(self, settings):
"""Test that default settings are loaded correctly.""" """Test that default settings are loaded correctly."""
assert settings.service_name == "test-service" assert settings.service_name == "test-service"
assert settings.service_version == "1.0.0-test" assert settings.service_version == "1.0.0-test"
assert settings.otel_enabled is False assert settings.otel_enabled is False
def test_settings_from_env(self, monkeypatch): def test_settings_from_env(self, monkeypatch):
"""Test that settings can be loaded from environment variables.""" """Test that settings can be loaded from environment variables."""
monkeypatch.setenv("SERVICE_NAME", "env-service") monkeypatch.setenv("SERVICE_NAME", "env-service")
monkeypatch.setenv("SERVICE_VERSION", "2.0.0") monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
monkeypatch.setenv("NATS_URL", "nats://custom:4222") monkeypatch.setenv("NATS_URL", "nats://custom:4222")
# Need to reimport to pick up env changes # Need to reimport to pick up env changes
from handler_base.config import Settings from handler_base.config import Settings
s = Settings() s = Settings()
assert s.service_name == "env-service" assert s.service_name == "env-service"
assert s.service_version == "2.0.0" assert s.service_version == "2.0.0"
assert s.nats_url == "nats://custom:4222" assert s.nats_url == "nats://custom:4222"
def test_embeddings_settings(self): def test_embeddings_settings(self):
"""Test EmbeddingsSettings extends base correctly.""" """Test EmbeddingsSettings extends base correctly."""
from handler_base.config import EmbeddingsSettings from handler_base.config import EmbeddingsSettings
s = EmbeddingsSettings() s = EmbeddingsSettings()
assert hasattr(s, "embeddings_model") assert hasattr(s, "embeddings_model")
assert hasattr(s, "embeddings_batch_size") assert hasattr(s, "embeddings_batch_size")
assert s.embeddings_model == "bge" assert s.embeddings_model == "bge"
def test_llm_settings(self): def test_llm_settings(self):
"""Test LLMSettings has expected defaults.""" """Test LLMSettings has expected defaults."""
from handler_base.config import LLMSettings from handler_base.config import LLMSettings
s = LLMSettings() s = LLMSettings()
assert s.llm_max_tokens == 2048 assert s.llm_max_tokens == 2048
assert s.llm_temperature == 0.7 assert s.llm_temperature == 0.7

View File

@@ -1,101 +1,101 @@
""" """
Unit tests for handler_base.health module. Unit tests for handler_base.health module.
""" """
import pytest
import json import json
import threading
import time import time
from http.client import HTTPConnection from http.client import HTTPConnection
from unittest.mock import AsyncMock
import pytest
class TestHealthServer: class TestHealthServer:
"""Tests for HealthServer.""" """Tests for HealthServer."""
@pytest.fixture @pytest.fixture
def health_server(self, settings): def health_server(self, settings):
"""Create a HealthServer instance.""" """Create a HealthServer instance."""
from handler_base.health import HealthServer from handler_base.health import HealthServer
# Use a random high port to avoid conflicts # Use a random high port to avoid conflicts
settings.health_port = 18080 settings.health_port = 18080
return HealthServer(settings) return HealthServer(settings)
def test_start_stop(self, health_server): def test_start_stop(self, health_server):
"""Test starting and stopping the health server.""" """Test starting and stopping the health server."""
health_server.start() health_server.start()
time.sleep(0.1) # Give server time to start time.sleep(0.1) # Give server time to start
# Verify server is running # Verify server is running
assert health_server._server is not None assert health_server._server is not None
assert health_server._thread is not None assert health_server._thread is not None
assert health_server._thread.is_alive() assert health_server._thread.is_alive()
health_server.stop() health_server.stop()
time.sleep(0.1) time.sleep(0.1)
assert health_server._server is None assert health_server._server is None
def test_health_endpoint(self, health_server): def test_health_endpoint(self, health_server):
"""Test the /health endpoint.""" """Test the /health endpoint."""
health_server.start() health_server.start()
time.sleep(0.1) time.sleep(0.1)
try: try:
conn = HTTPConnection("localhost", 18080, timeout=5) conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/health") conn.request("GET", "/health")
response = conn.getresponse() response = conn.getresponse()
assert response.status == 200 assert response.status == 200
data = json.loads(response.read().decode()) data = json.loads(response.read().decode())
assert data["status"] == "healthy" assert data["status"] == "healthy"
finally: finally:
conn.close() conn.close()
health_server.stop() health_server.stop()
def test_ready_endpoint_default(self, health_server): def test_ready_endpoint_default(self, health_server):
"""Test the /ready endpoint with no custom check.""" """Test the /ready endpoint with no custom check."""
health_server.start() health_server.start()
time.sleep(0.1) time.sleep(0.1)
try: try:
conn = HTTPConnection("localhost", 18080, timeout=5) conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/ready") conn.request("GET", "/ready")
response = conn.getresponse() response = conn.getresponse()
assert response.status == 200 assert response.status == 200
data = json.loads(response.read().decode()) data = json.loads(response.read().decode())
assert data["status"] == "ready" assert data["status"] == "ready"
finally: finally:
conn.close() conn.close()
health_server.stop() health_server.stop()
def test_ready_endpoint_with_check(self, settings): def test_ready_endpoint_with_check(self, settings):
"""Test /ready endpoint with custom readiness check.""" """Test /ready endpoint with custom readiness check."""
from handler_base.health import HealthServer from handler_base.health import HealthServer
ready_flag = [False] # Use list to allow mutation in closure ready_flag = [False] # Use list to allow mutation in closure
async def check_ready(): async def check_ready():
return ready_flag[0] return ready_flag[0]
settings.health_port = 18081 settings.health_port = 18081
server = HealthServer(settings, ready_check=check_ready) server = HealthServer(settings, ready_check=check_ready)
server.start() server.start()
time.sleep(0.2) time.sleep(0.2)
try: try:
conn = HTTPConnection("localhost", 18081, timeout=5) conn = HTTPConnection("localhost", 18081, timeout=5)
# Should be not ready initially # Should be not ready initially
conn.request("GET", "/ready") conn.request("GET", "/ready")
response = conn.getresponse() response = conn.getresponse()
response.read() # Consume response body response.read() # Consume response body
assert response.status == 503 assert response.status == 503
# Mark as ready # Mark as ready
ready_flag[0] = True ready_flag[0] = True
# Need new connection after consuming response # Need new connection after consuming response
conn.close() conn.close()
conn = HTTPConnection("localhost", 18081, timeout=5) conn = HTTPConnection("localhost", 18081, timeout=5)
@@ -105,17 +105,17 @@ class TestHealthServer:
finally: finally:
conn.close() conn.close()
server.stop() server.stop()
def test_404_for_unknown_path(self, health_server): def test_404_for_unknown_path(self, health_server):
"""Test that unknown paths return 404.""" """Test that unknown paths return 404."""
health_server.start() health_server.start()
time.sleep(0.1) time.sleep(0.1)
try: try:
conn = HTTPConnection("localhost", 18080, timeout=5) conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/unknown") conn.request("GET", "/unknown")
response = conn.getresponse() response = conn.getresponse()
assert response.status == 404 assert response.status == 404
finally: finally:
conn.close() conn.close()

View File

@@ -1,48 +1,52 @@
""" """
Unit tests for handler_base.nats_client module. Unit tests for handler_base.nats_client module.
""" """
import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import msgpack import msgpack
import pytest
class TestNATSClient: class TestNATSClient:
"""Tests for NATSClient.""" """Tests for NATSClient."""
@pytest.fixture @pytest.fixture
def nats_client(self, settings): def nats_client(self, settings):
"""Create a NATSClient instance.""" """Create a NATSClient instance."""
from handler_base.nats_client import NATSClient from handler_base.nats_client import NATSClient
return NATSClient(settings) return NATSClient(settings)
def test_init(self, nats_client, settings): def test_init(self, nats_client, settings):
"""Test NATSClient initialization.""" """Test NATSClient initialization."""
assert nats_client.settings == settings assert nats_client.settings == settings
assert nats_client._nc is None assert nats_client._nc is None
assert nats_client._js is None assert nats_client._js is None
def test_decode_msgpack(self, nats_client): def test_decode_msgpack(self, nats_client):
"""Test msgpack decoding.""" """Test msgpack decoding."""
data = {"query": "hello", "request_id": "123"} data = {"query": "hello", "request_id": "123"}
encoded = msgpack.packb(data, use_bin_type=True) encoded = msgpack.packb(data, use_bin_type=True)
msg = MagicMock() msg = MagicMock()
msg.data = encoded msg.data = encoded
result = nats_client.decode_msgpack(msg) result = nats_client.decode_msgpack(msg)
assert result == data assert result == data
def test_decode_json(self, nats_client): def test_decode_json(self, nats_client):
"""Test JSON decoding.""" """Test JSON decoding."""
import json import json
data = {"query": "hello"} data = {"query": "hello"}
msg = MagicMock() msg = MagicMock()
msg.data = json.dumps(data).encode() msg.data = json.dumps(data).encode()
result = nats_client.decode_json(msg) result = nats_client.decode_json(msg)
assert result == data assert result == data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect(self, nats_client): async def test_connect(self, nats_client):
"""Test NATS connection.""" """Test NATS connection."""
@@ -51,30 +55,30 @@ class TestNATSClient:
mock_js = MagicMock() mock_js = MagicMock()
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
mock_nats.connect = AsyncMock(return_value=mock_nc) mock_nats.connect = AsyncMock(return_value=mock_nc)
await nats_client.connect() await nats_client.connect()
assert nats_client._nc == mock_nc assert nats_client._nc == mock_nc
assert nats_client._js == mock_js assert nats_client._js == mock_js
mock_nats.connect.assert_called_once() mock_nats.connect.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_publish(self, nats_client): async def test_publish(self, nats_client):
"""Test publishing a message.""" """Test publishing a message."""
mock_nc = AsyncMock() mock_nc = AsyncMock()
nats_client._nc = mock_nc nats_client._nc = mock_nc
data = {"key": "value"} data = {"key": "value"}
await nats_client.publish("test.subject", data) await nats_client.publish("test.subject", data)
mock_nc.publish.assert_called_once() mock_nc.publish.assert_called_once()
call_args = mock_nc.publish.call_args call_args = mock_nc.publish.call_args
assert call_args.args[0] == "test.subject" assert call_args.args[0] == "test.subject"
# Verify msgpack encoding # Verify msgpack encoding
decoded = msgpack.unpackb(call_args.args[1], raw=False) decoded = msgpack.unpackb(call_args.args[1], raw=False)
assert decoded == data assert decoded == data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subscribe(self, nats_client): async def test_subscribe(self, nats_client):
"""Test subscribing to a subject.""" """Test subscribing to a subject."""
@@ -82,10 +86,10 @@ class TestNATSClient:
mock_sub = MagicMock() mock_sub = MagicMock()
mock_nc.subscribe = AsyncMock(return_value=mock_sub) mock_nc.subscribe = AsyncMock(return_value=mock_sub)
nats_client._nc = mock_nc nats_client._nc = mock_nc
handler = AsyncMock() handler = AsyncMock()
await nats_client.subscribe("test.subject", handler, queue="test-queue") await nats_client.subscribe("test.subject", handler, queue="test-queue")
mock_nc.subscribe.assert_called_once() mock_nc.subscribe.assert_called_once()
call_kwargs = mock_nc.subscribe.call_args.kwargs call_kwargs = mock_nc.subscribe.call_args.kwargs
assert call_kwargs["queue"] == "test-queue" assert call_kwargs["queue"] == "test-queue"