fix: auto-fix ruff linting errors and remove unsupported upload-artifact
All checks were successful
CI / Lint (push) Successful in 52s
CI / Test (push) Successful in 1m1s
CI / Release (push) Successful in 5s
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-02 08:34:00 -05:00
parent 7b30ff6a05
commit dbf1a93141
19 changed files with 414 additions and 400 deletions

View File

@@ -57,12 +57,6 @@ jobs:
- name: Run tests with coverage
run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term
- name: Upload coverage artifact
uses: actions/upload-artifact@v4
with:
name: coverage
path: coverage.xml
release:
name: Release
runs-on: ubuntu-latest

View File

@@ -8,11 +8,12 @@ Provides consistent patterns for:
- Graceful shutdown
- Service client wrappers
"""
from handler_base.config import Settings
from handler_base.handler import Handler
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import setup_telemetry, get_tracer, get_meter
from handler_base.telemetry import get_meter, get_tracer, setup_telemetry
__all__ = [
"Handler",

View File

@@ -1,12 +1,13 @@
"""
Service client wrappers for AI/ML backends.
"""
from handler_base.clients.embeddings import EmbeddingsClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.llm import LLMClient
from handler_base.clients.tts import TTSClient
from handler_base.clients.stt import STTClient
from handler_base.clients.milvus import MilvusClient
from handler_base.clients.reranker import RerankerClient
from handler_base.clients.stt import STTClient
from handler_base.clients.tts import TTSClient
__all__ = [
"EmbeddingsClient",

View File

@@ -1,6 +1,7 @@
"""
Embeddings service client (Infinity/BGE).
"""
import logging
from typing import Optional
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
class EmbeddingsClient:
"""
Client for the embeddings service (Infinity with BGE models).
Usage:
client = EmbeddingsClient()
embeddings = await client.embed(["Hello world"])
"""
def __init__(self, settings: Optional[EmbeddingsSettings] = None):
    """Create a client bound to the configured embeddings endpoint.

    Args:
        settings: Optional configuration; environment-driven defaults
            are loaded when omitted.
    """
    cfg = settings or EmbeddingsSettings()
    self.settings = cfg
    self._client = httpx.AsyncClient(
        base_url=cfg.embeddings_url,
        timeout=cfg.http_timeout,
    )
async def close(self) -> None:
    """Close the HTTP client."""
    # Releases pooled connections held by the httpx.AsyncClient.
    await self._client.aclose()
async def embed(
self,
texts: list[str],
@@ -39,49 +40,49 @@ class EmbeddingsClient:
) -> list[list[float]]:
"""
Generate embeddings for a list of texts.
Args:
texts: List of texts to embed
model: Model name (defaults to settings)
Returns:
List of embedding vectors
"""
model = model or self.settings.embeddings_model
with create_span("embeddings.embed") as span:
if span:
span.set_attribute("embeddings.model", model)
span.set_attribute("embeddings.batch_size", len(texts))
response = await self._client.post(
"/embeddings",
json={"input": texts, "model": model},
)
response.raise_for_status()
result = response.json()
embeddings = [d["embedding"] for d in result.get("data", [])]
if span:
span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0)
return embeddings
async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]:
    """Embed one text and return its vector.

    Args:
        text: Text to embed.
        model: Optional model name (defaults to settings).

    Returns:
        The embedding vector, or an empty list when the service
        returned no data.
    """
    vectors = await self.embed([text], model)
    if not vectors:
        return []
    return vectors[0]
async def health(self) -> bool:
"""Check if the embeddings service is healthy."""
try:

View File

@@ -1,8 +1,9 @@
"""
LLM service client (vLLM/OpenAI-compatible).
"""
import logging
from typing import Optional, AsyncIterator
from typing import AsyncIterator, Optional
import httpx
@@ -15,33 +16,33 @@ logger = logging.getLogger(__name__)
class LLMClient:
"""
Client for the LLM service (vLLM with OpenAI-compatible API).
Usage:
client = LLMClient()
response = await client.generate("Hello, how are you?")
# With context for RAG
response = await client.generate(
"What is the capital?",
context="France is a country in Europe..."
)
# Streaming
async for chunk in client.stream("Tell me a story"):
print(chunk, end="")
"""
def __init__(self, settings: Optional[LLMSettings] = None):
    # Environment-driven defaults are loaded when no settings are given.
    self.settings = settings or LLMSettings()
    # Shared async HTTP client pointed at the OpenAI-compatible vLLM server.
    self._client = httpx.AsyncClient(
        base_url=self.settings.llm_url,
        timeout=self.settings.http_timeout,
    )
async def close(self) -> None:
    """Close the HTTP client."""
    # Releases pooled connections held by the httpx.AsyncClient.
    await self._client.aclose()
async def generate(
self,
prompt: str,
@@ -54,7 +55,7 @@ class LLMClient:
) -> str:
"""
Generate a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
@@ -63,19 +64,19 @@ class LLMClient:
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
Returns:
Generated text response
"""
with create_span("llm.generate") as span:
messages = self._build_messages(prompt, context, system_prompt)
if span:
span.set_attribute("llm.model", self.settings.llm_model)
span.set_attribute("llm.prompt_length", len(prompt))
if context:
span.set_attribute("llm.context_length", len(context))
payload = {
"model": self.settings.llm_model,
"messages": messages,
@@ -85,21 +86,21 @@ class LLMClient:
}
if stop:
payload["stop"] = stop
response = await self._client.post("/v1/chat/completions", json=payload)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
if span:
span.set_attribute("llm.response_length", len(content))
usage = result.get("usage", {})
span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0))
span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0))
return content
async def stream(
self,
prompt: str,
@@ -110,19 +111,19 @@ class LLMClient:
) -> AsyncIterator[str]:
"""
Stream a response from the LLM.
Args:
prompt: User prompt/query
context: Optional context for RAG
system_prompt: Optional system prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
Yields:
Text chunks as they're generated
"""
messages = self._build_messages(prompt, context, system_prompt)
payload = {
"model": self.settings.llm_model,
"messages": messages,
@@ -130,25 +131,24 @@ class LLMClient:
"temperature": temperature or self.settings.llm_temperature,
"stream": True,
}
async with self._client.stream(
"POST", "/v1/chat/completions", json=payload
) as response:
async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
break
import json
chunk = json.loads(data)
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
def _build_messages(
self,
prompt: str,
@@ -157,32 +157,36 @@ class LLMClient:
) -> list[dict]:
"""Build the messages list for the API call."""
messages = []
# System prompt
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
elif context:
# Default RAG system prompt
messages.append({
"role": "system",
"content": (
"You are a helpful assistant. Use the provided context to answer "
"the user's question. If the context doesn't contain relevant "
"information, say so."
),
})
messages.append(
{
"role": "system",
"content": (
"You are a helpful assistant. Use the provided context to answer "
"the user's question. If the context doesn't contain relevant "
"information, say so."
),
}
)
# Add context as a separate message if provided
if context:
messages.append({
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
})
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {prompt}",
}
)
else:
messages.append({"role": "user", "content": prompt})
return messages
async def health(self) -> bool:
"""Check if the LLM service is healthy."""
try:

View File

@@ -1,10 +1,11 @@
"""
Milvus vector database client.
"""
import logging
from typing import Optional, Any
from pymilvus import connections, Collection, utility
import logging
from typing import Optional
from pymilvus import Collection, connections, utility
from handler_base.config import Settings
from handler_base.telemetry import create_span
@@ -15,42 +16,42 @@ logger = logging.getLogger(__name__)
class MilvusClient:
"""
Client for Milvus vector database.
Usage:
client = MilvusClient()
await client.connect()
results = await client.search(embedding, limit=10)
"""
def __init__(self, settings: Optional[Settings] = None):
    # Environment-driven defaults are loaded when no settings are given.
    self.settings = settings or Settings()
    # Connection state; both are populated by connect().
    self._connected = False
    self._collection: Optional[Collection] = None
async def connect(self, collection_name: Optional[str] = None) -> None:
    """
    Connect to Milvus and load collection.
    Args:
        collection_name: Collection to use (defaults to settings)
    """
    collection_name = collection_name or self.settings.milvus_collection
    # NOTE(review): connections.connect() and Collection.load() are
    # blocking pymilvus calls despite the async signature — they will
    # stall the event loop while connecting/loading.
    connections.connect(
        alias="default",
        host=self.settings.milvus_host,
        port=self.settings.milvus_port,
    )
    if utility.has_collection(collection_name):
        self._collection = Collection(collection_name)
        # Load the collection so searches can run against it.
        self._collection.load()
        logger.info(f"Connected to Milvus collection: {collection_name}")
    else:
        # Missing collection is tolerated here; search()/insert() will
        # later raise RuntimeError because self._collection stays None.
        logger.warning(f"Collection {collection_name} does not exist")
    # Marked connected even when the collection is absent.
    self._connected = True
async def close(self) -> None:
"""Close Milvus connection."""
if self._collection:
@@ -58,7 +59,7 @@ class MilvusClient:
connections.disconnect("default")
self._connected = False
logger.info("Disconnected from Milvus")
async def search(
self,
embedding: list[float],
@@ -68,26 +69,26 @@ class MilvusClient:
) -> list[dict]:
"""
Search for similar vectors.
Args:
embedding: Query embedding vector
limit: Maximum number of results
output_fields: Fields to return (default: all)
filter_expr: Optional filter expression
Returns:
List of results with 'id', 'distance', and requested fields
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.search") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.limit", limit)
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
results = self._collection.search(
data=[embedding],
anns_field="embedding",
@@ -96,7 +97,7 @@ class MilvusClient:
output_fields=output_fields,
expr=filter_expr,
)
# Convert to list of dicts
hits = []
for hit in results[0]:
@@ -111,12 +112,12 @@ class MilvusClient:
if hasattr(hit.entity, field):
item[field] = getattr(hit.entity, field)
hits.append(item)
if span:
span.set_attribute("milvus.results", len(hits))
return hits
async def search_with_texts(
self,
embedding: list[float],
@@ -126,22 +127,22 @@ class MilvusClient:
) -> list[dict]:
"""
Search and return text content with metadata.
Args:
embedding: Query embedding
limit: Maximum results
text_field: Name of text field in collection
metadata_fields: Additional metadata fields to return
Returns:
List of results with text and metadata
"""
output_fields = [text_field]
if metadata_fields:
output_fields.extend(metadata_fields)
return await self.search(embedding, limit, output_fields)
async def insert(
self,
embeddings: list[list[float]],
@@ -149,34 +150,34 @@ class MilvusClient:
) -> list[int]:
"""
Insert vectors with data into the collection.
Args:
embeddings: List of embedding vectors
data: List of dicts with field values
Returns:
List of inserted IDs
"""
if not self._collection:
raise RuntimeError("Not connected to collection")
with create_span("milvus.insert") as span:
if span:
span.set_attribute("milvus.collection", self._collection.name)
span.set_attribute("milvus.count", len(embeddings))
# Build insert data
insert_data = [embeddings]
for field in self._collection.schema.fields:
if field.name not in ("id", "embedding"):
field_values = [d.get(field.name) for d in data]
insert_data.append(field_values)
result = self._collection.insert(insert_data)
self._collection.flush()
return result.primary_keys
def health(self) -> bool:
    """Check if connected to Milvus."""
    # NOTE(review): utility.get_connection_addr may return an empty
    # mapping rather than None after a disconnect — confirm this
    # `is not None` check actually detects a dropped connection.
    return self._connected and utility.get_connection_addr("default") is not None

View File

@@ -1,6 +1,7 @@
"""
Reranker service client (Infinity/BGE Reranker).
"""
import logging
from typing import Optional
@@ -15,23 +16,23 @@ logger = logging.getLogger(__name__)
class RerankerClient:
"""
Client for the reranker service (Infinity with BGE Reranker).
Usage:
client = RerankerClient()
reranked = await client.rerank("query", ["doc1", "doc2"])
"""
def __init__(self, settings: Optional[Settings] = None):
    """Create a client for the reranker HTTP service.

    Args:
        settings: Optional configuration; environment-driven defaults
            are loaded when omitted.
    """
    cfg = settings or Settings()
    self.settings = cfg
    self._client = httpx.AsyncClient(
        base_url=cfg.reranker_url,
        timeout=cfg.http_timeout,
    )
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def rerank(
self,
query: str,
@@ -40,12 +41,12 @@ class RerankerClient:
) -> list[dict]:
"""
Rerank documents based on relevance to query.
Args:
query: Query text
documents: List of documents to rerank
top_k: Number of top results to return (default: all)
Returns:
List of dicts with 'index', 'score', and 'document' keys,
sorted by relevance score descending.
@@ -55,32 +56,34 @@ class RerankerClient:
span.set_attribute("reranker.num_documents", len(documents))
if top_k:
span.set_attribute("reranker.top_k", top_k)
payload = {
"query": query,
"documents": documents,
}
if top_k:
payload["top_n"] = top_k
response = await self._client.post("/rerank", json=payload)
response.raise_for_status()
result = response.json()
results = result.get("results", [])
# Enrich with original documents
enriched = []
for r in results:
idx = r.get("index", 0)
enriched.append({
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
})
enriched.append(
{
"index": idx,
"score": r.get("relevance_score", r.get("score", 0)),
"document": documents[idx] if idx < len(documents) else "",
}
)
return enriched
async def rerank_with_metadata(
self,
query: str,
@@ -90,27 +93,27 @@ class RerankerClient:
) -> list[dict]:
"""
Rerank documents with metadata, preserving metadata in results.
Args:
query: Query text
documents: List of dicts with text and metadata
text_key: Key containing text in each document dict
top_k: Number of top results to return
Returns:
Reranked documents with original metadata preserved.
"""
texts = [d.get(text_key, "") for d in documents]
reranked = await self.rerank(query, texts, top_k)
# Merge back metadata
for r in reranked:
idx = r["index"]
if idx < len(documents):
r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key}
return reranked
async def health(self) -> bool:
"""Check if the reranker service is healthy."""
try:

View File

@@ -1,7 +1,7 @@
"""
STT service client (Whisper/faster-whisper).
"""
import io
import logging
from typing import Optional
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
class STTClient:
"""
Client for the STT service (Whisper/faster-whisper).
Usage:
client = STTClient()
text = await client.transcribe(audio_bytes)
"""
def __init__(self, settings: Optional[STTSettings] = None):
    # Environment-driven defaults are loaded when no settings are given.
    self.settings = settings or STTSettings()
    # Deliberately overrides the generic http_timeout from settings.
    self._client = httpx.AsyncClient(
        base_url=self.settings.stt_url,
        timeout=180.0,  # Transcription can be slow
    )
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def transcribe(
self,
audio: bytes,
@@ -42,54 +42,54 @@ class STTClient:
) -> dict:
"""
Transcribe audio to text.
Args:
audio: Audio bytes (WAV, MP3, etc.)
language: Language code (None for auto-detect)
task: "transcribe" or "translate"
response_format: "json", "text", "srt", "vtt"
Returns:
Dict with 'text', 'language', and optional 'segments'
"""
language = language or self.settings.stt_language
task = task or self.settings.stt_task
with create_span("stt.transcribe") as span:
if span:
span.set_attribute("stt.task", task)
span.set_attribute("stt.audio_size", len(audio))
if language:
span.set_attribute("stt.language", language)
files = {"file": ("audio.wav", audio, "audio/wav")}
data = {
"response_format": response_format,
}
if language:
data["language"] = language
# Choose endpoint based on task
if task == "translate":
endpoint = "/v1/audio/translations"
else:
endpoint = "/v1/audio/transcriptions"
response = await self._client.post(endpoint, files=files, data=data)
response.raise_for_status()
if response_format == "text":
return {"text": response.text}
result = response.json()
if span:
span.set_attribute("stt.result_length", len(result.get("text", "")))
if result.get("language"):
span.set_attribute("stt.detected_language", result["language"])
return result
async def transcribe_file(
self,
file_path: str,
@@ -98,31 +98,31 @@ class STTClient:
) -> dict:
"""
Transcribe an audio file.
Args:
file_path: Path to audio file
language: Language code
task: "transcribe" or "translate"
Returns:
Transcription result
"""
with open(file_path, "rb") as f:
audio = f.read()
return await self.transcribe(audio, language, task)
async def translate(self, audio: bytes) -> dict:
    """Translate spoken audio into English text.

    Args:
        audio: Raw audio bytes.

    Returns:
        Translation result dict with a 'text' key.
    """
    result = await self.transcribe(audio, task="translate")
    return result
async def health(self) -> bool:
"""Check if the STT service is healthy."""
try:

View File

@@ -1,7 +1,7 @@
"""
TTS service client (Coqui XTTS).
"""
import io
import logging
from typing import Optional
@@ -16,23 +16,23 @@ logger = logging.getLogger(__name__)
class TTSClient:
"""
Client for the TTS service (Coqui XTTS).
Usage:
client = TTSClient()
audio_bytes = await client.synthesize("Hello world")
"""
def __init__(self, settings: Optional[TTSSettings] = None):
    # Environment-driven defaults are loaded when no settings are given.
    self.settings = settings or TTSSettings()
    # Deliberately overrides the generic http_timeout from settings.
    self._client = httpx.AsyncClient(
        base_url=self.settings.tts_url,
        timeout=120.0,  # TTS can be slow
    )
async def close(self) -> None:
"""Close the HTTP client."""
await self._client.aclose()
async def synthesize(
self,
text: str,
@@ -41,39 +41,39 @@ class TTSClient:
) -> bytes:
"""
Synthesize speech from text.
Args:
text: Text to synthesize
language: Language code (e.g., "en", "es", "fr")
speaker: Speaker ID or reference
Returns:
WAV audio bytes
"""
language = language or self.settings.tts_language
with create_span("tts.synthesize") as span:
if span:
span.set_attribute("tts.language", language)
span.set_attribute("tts.text_length", len(text))
params = {
"text": text,
"language_id": language,
}
if speaker:
params["speaker_id"] = speaker
response = await self._client.get("/api/tts", params=params)
response.raise_for_status()
audio_bytes = response.content
if span:
span.set_attribute("tts.audio_size", len(audio_bytes))
return audio_bytes
async def synthesize_to_file(
self,
text: str,
@@ -83,7 +83,7 @@ class TTSClient:
) -> None:
"""
Synthesize speech and save to a file.
Args:
text: Text to synthesize
output_path: Path to save the audio file
@@ -91,10 +91,10 @@ class TTSClient:
speaker: Speaker ID
"""
audio_bytes = await self.synthesize(text, language, speaker)
with open(output_path, "wb") as f:
f.write(audio_bytes)
async def get_speakers(self) -> list[dict]:
"""Get available speakers/voices."""
try:
@@ -103,7 +103,7 @@ class TTSClient:
return response.json()
except Exception:
return []
async def health(self) -> bool:
"""Check if the TTS service is healthy."""
try:

View File

@@ -3,67 +3,69 @@ Configuration management using Pydantic Settings.
Environment variables are automatically loaded and validated.
"""
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Base settings for all handler services."""
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# Service identification
service_name: str = "handler"
service_version: str = "1.0.0"
service_namespace: str = "ai-ml"
deployment_env: str = "production"
# NATS configuration
nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222"
nats_user: Optional[str] = None
nats_password: Optional[str] = None
nats_queue_group: Optional[str] = None
# Redis/Valkey configuration
redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379"
redis_password: Optional[str] = None
# Milvus configuration
milvus_host: str = "milvus.ai-ml.svc.cluster.local"
milvus_port: int = 19530
milvus_collection: str = "documents"
# Service endpoints
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local"
llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local"
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
# OpenTelemetry configuration
otel_enabled: bool = True
otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317"
otel_use_http: bool = False
# HyperDX configuration
hyperdx_enabled: bool = False
hyperdx_api_key: Optional[str] = None
hyperdx_endpoint: str = "https://in-otel.hyperdx.io"
# MLflow configuration
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80"
mlflow_experiment_name: Optional[str] = None
mlflow_enabled: bool = True
# Health check configuration
health_port: int = 8080
health_path: str = "/health"
ready_path: str = "/ready"
# Timeouts (seconds)
http_timeout: float = 60.0
nats_timeout: float = 30.0
@@ -71,14 +73,14 @@ class Settings(BaseSettings):
class EmbeddingsSettings(Settings):
    """Settings for embeddings service client."""

    # Model name sent in the /embeddings request payload.
    embeddings_model: str = "bge"
    # Intended request batch size — not referenced by the visible client
    # code; TODO confirm callers honor it.
    embeddings_batch_size: int = 32
class LLMSettings(Settings):
"""Settings for LLM service client."""
llm_model: str = "default"
llm_max_tokens: int = 2048
llm_temperature: float = 0.7
@@ -87,13 +89,13 @@ class LLMSettings(Settings):
class TTSSettings(Settings):
    """Settings for TTS service client."""

    # Default language code used when synthesize() gets no language.
    tts_language: str = "en"
    # Optional default speaker/voice id; service default applies when None.
    tts_speaker: Optional[str] = None
class STTSettings(Settings):
    """Settings for STT service client."""

    # Default language for transcription; None lets the service auto-detect.
    stt_language: Optional[str] = None  # Auto-detect
    # Default task used when transcribe() gets no task argument.
    stt_task: str = "transcribe"  # or "translate"

View File

@@ -1,6 +1,7 @@
"""
Base handler class for building NATS-based services.
"""
import asyncio
import logging
import signal
@@ -12,7 +13,7 @@ from nats.aio.msg import Msg
from handler_base.config import Settings
from handler_base.health import HealthServer
from handler_base.nats_client import NATSClient
from handler_base.telemetry import setup_telemetry, create_span
from handler_base.telemetry import create_span, setup_telemetry
logger = logging.getLogger(__name__)
@@ -20,25 +21,25 @@ logger = logging.getLogger(__name__)
class Handler(ABC):
"""
Base class for NATS message handlers.
Subclass and implement:
- setup(): Initialize your service clients
- handle_message(): Process incoming messages
- teardown(): Clean up resources (optional)
Example:
class MyHandler(Handler):
async def setup(self):
self.embeddings = EmbeddingsClient()
async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]:
result = await self.embeddings.embed(data["text"])
return {"embedding": result}
if __name__ == "__main__":
MyHandler(subject="my.subject").run()
"""
def __init__(
self,
subject: str,
@@ -47,7 +48,7 @@ class Handler(ABC):
):
"""
Initialize the handler.
Args:
subject: NATS subject to subscribe to
settings: Configuration settings
@@ -56,78 +57,78 @@ class Handler(ABC):
self.subject = subject
self.settings = settings or Settings()
self.queue_group = queue_group or self.settings.nats_queue_group
self.nats = NATSClient(self.settings)
self.health_server = HealthServer(self.settings, self._check_ready)
self._running = False
self._shutdown_event = asyncio.Event()
@abstractmethod
async def setup(self) -> None:
"""
Initialize service clients and resources.
Called once before starting to handle messages.
Override this to set up your service-specific clients.
"""
pass
@abstractmethod
async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]:
"""
Handle an incoming message.
Args:
msg: Raw NATS message
data: Decoded message data (msgpack unpacked)
Returns:
Optional response data. If returned and msg has a reply subject,
the response will be sent automatically.
"""
pass
async def teardown(self) -> None:
    """Subclass hook for releasing resources.

    Invoked once during graceful shutdown, before the NATS connection
    is closed. The base implementation is a no-op.
    """
async def _check_ready(self) -> bool:
    """Check if the service is ready to handle requests."""
    # Ready means the run loop is active and NATS has connected.
    # NOTE(review): reaches into NATSClient's private _nc attribute;
    # a public is_connected accessor would be cleaner.
    return self._running and self.nats._nc is not None
async def _message_handler(self, msg: Msg) -> None:
"""Internal message handler with tracing and error handling."""
with create_span(f"handle.{self.subject}") as span:
try:
# Decode message
data = NATSClient.decode_msgpack(msg)
if span:
span.set_attribute("messaging.destination", msg.subject)
if isinstance(data, dict):
request_id = data.get("request_id", data.get("id"))
if request_id:
span.set_attribute("request.id", str(request_id))
# Handle message
response = await self.handle_message(msg, data)
# Send response if applicable
if response is not None and msg.reply:
await self.nats.publish(msg.reply, response)
except Exception as e:
logger.exception(f"Error handling message on {msg.subject}")
if span:
span.set_attribute("error", True)
span.set_attribute("error.message", str(e))
# Send error response if reply expected
if msg.reply:
error_response = {
@@ -136,71 +137,71 @@ class Handler(ABC):
"type": type(e).__name__,
}
await self.nats.publish(msg.reply, error_response)
def _setup_signals(self) -> None:
    """Set up signal handlers for graceful shutdown."""
    # NOTE(review): asyncio.get_event_loop() is deprecated for this use
    # on newer Pythons — assumes we are called from inside the running
    # loop; asyncio.get_running_loop() would be safer.
    loop = asyncio.get_event_loop()
    # SIGTERM (Kubernetes) and SIGINT (Ctrl-C) both trigger shutdown.
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, self._handle_signal, sig)
def _handle_signal(self, sig: signal.Signals) -> None:
    """Handle shutdown signal."""
    logger.info(f"Received {sig.name}, initiating graceful shutdown...")
    # Setting the event wakes the main loop's wait and starts shutdown.
    self._shutdown_event.set()
async def _run(self) -> None:
"""Main async run loop."""
# Setup telemetry
setup_telemetry(self.settings)
# Start health server
self.health_server.start()
try:
# Connect to NATS
await self.nats.connect()
# Run user setup
logger.info("Running service setup...")
await self.setup()
# Subscribe to subject
await self.nats.subscribe(
self.subject,
self._message_handler,
queue=self.queue_group,
)
self._running = True
logger.info(f"Handler ready, listening on {self.subject}")
# Wait for shutdown signal
await self._shutdown_event.wait()
except Exception as e:
except Exception:
logger.exception("Fatal error in handler")
raise
finally:
self._running = False
# Graceful shutdown
logger.info("Shutting down...")
try:
await self.teardown()
except Exception as e:
logger.warning(f"Error during teardown: {e}")
await self.nats.close()
self.health_server.stop()
logger.info("Shutdown complete")
def run(self) -> None:
"""
Run the handler.
This is the main entry point. It sets up signal handlers
and runs the async event loop.
"""
@@ -209,12 +210,12 @@ class Handler(ABC):
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}")
# Run the async loop
asyncio.run(self._run_with_signals())
async def _run_with_signals(self) -> None:
"""Run with signal handling."""
self._setup_signals()

View File

@@ -3,12 +3,13 @@ HTTP health check server.
Provides /health and /ready endpoints for Kubernetes probes.
"""
import asyncio
import logging
from typing import Callable, Optional, Awaitable
from http.server import HTTPServer, BaseHTTPRequestHandler
import threading
import json
import logging
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Awaitable, Callable, Optional
from handler_base.config import Settings
@@ -17,16 +18,16 @@ logger = logging.getLogger(__name__)
class HealthHandler(BaseHTTPRequestHandler):
"""HTTP request handler for health checks."""
# Class-level state
ready_check: Optional[Callable[[], Awaitable[bool]]] = None
health_path: str = "/health"
ready_path: str = "/ready"
def log_message(self, format, *args):
    """Suppress default logging."""
    # BaseHTTPRequestHandler logs every request to stderr by default;
    # silence it so Kubernetes probe traffic doesn't flood the logs.
    pass
def do_GET(self):
"""Handle GET requests for health/ready endpoints."""
if self.path == self.health_path:
@@ -35,7 +36,7 @@ class HealthHandler(BaseHTTPRequestHandler):
self._handle_ready()
else:
self._respond_not_found()
def _handle_ready(self):
"""Check readiness and respond."""
# Access via class to avoid method binding issues
@@ -43,7 +44,7 @@ class HealthHandler(BaseHTTPRequestHandler):
if ready_check is None:
self._respond_ok({"status": "ready"})
return
try:
# Run the async check in a new event loop
loop = asyncio.new_event_loop()
@@ -51,7 +52,7 @@ class HealthHandler(BaseHTTPRequestHandler):
is_ready = loop.run_until_complete(ready_check())
finally:
loop.close()
if is_ready:
self._respond_ok({"status": "ready"})
else:
@@ -59,19 +60,19 @@ class HealthHandler(BaseHTTPRequestHandler):
except Exception as e:
logger.exception("Readiness check failed")
self._respond_unavailable({"status": "error", "message": str(e)})
def _respond_ok(self, data: dict):
    """Send a 200 response with a JSON-encoded payload."""
    body = json.dumps(data).encode()
    self.send_response(200)
    self.send_header("Content-Type", "application/json")
    self.end_headers()
    self.wfile.write(body)
def _respond_unavailable(self, data: dict):
    """Send a 503 response with a JSON-encoded payload."""
    body = json.dumps(data).encode()
    self.send_response(503)
    self.send_header("Content-Type", "application/json")
    self.end_headers()
    self.wfile.write(body)
def _respond_not_found(self):
    # 404 for any path other than the configured health/ready endpoints.
    self.send_response(404)
    self.end_headers()
@@ -80,14 +81,14 @@ class HealthHandler(BaseHTTPRequestHandler):
class HealthServer:
"""
Background HTTP server for health checks.
Usage:
server = HealthServer(settings)
server.start()
# ... run your service ...
server.stop()
"""
def __init__(
self,
settings: Optional[Settings] = None,
@@ -97,24 +98,24 @@ class HealthServer:
self.ready_check = ready_check
self._server: Optional[HTTPServer] = None
self._thread: Optional[threading.Thread] = None
def start(self) -> None:
"""Start the health check server in a background thread."""
# Configure handler class
HealthHandler.ready_check = self.ready_check
HealthHandler.health_path = self.settings.health_path
HealthHandler.ready_path = self.settings.ready_path
# Create and start server
self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler)
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
self._thread.start()
logger.info(
f"Health server started on port {self.settings.health_port} "
f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})"
)
def stop(self) -> None:
"""Stop the health check server."""
if self._server:

View File

@@ -1,9 +1,9 @@
"""
NATS client wrapper with connection management and utilities.
"""
import asyncio
import logging
from typing import Any, Callable, Optional, Awaitable
from typing import Any, Awaitable, Callable, Optional
import msgpack
import nats
@@ -20,34 +20,34 @@ logger = logging.getLogger(__name__)
class NATSClient:
"""
NATS client with automatic connection management.
Supports:
- Core NATS pub/sub
- JetStream for persistence
- Queue groups for load balancing
- Msgpack serialization
"""
def __init__(self, settings: Optional[Settings] = None):
self.settings = settings or Settings()
self._nc: Optional[Client] = None
self._js: Optional[JetStreamContext] = None
self._subscriptions: list = []
@property
def nc(self) -> Client:
    """Return the live NATS connection, raising if connect() was never called."""
    conn = self._nc
    if conn is None:
        raise RuntimeError("NATS client not connected. Call connect() first.")
    return conn
@property
def js(self) -> JetStreamContext:
    """Return the JetStream context, raising if connect() was never called."""
    ctx = self._js
    if ctx is None:
        raise RuntimeError("JetStream not initialized. Call connect() first.")
    return ctx
async def connect(self) -> None:
"""Connect to NATS server."""
connect_opts = {
@@ -55,16 +55,16 @@ class NATSClient:
"reconnect_time_wait": 2,
"max_reconnect_attempts": -1, # Infinite
}
if self.settings.nats_user and self.settings.nats_password:
connect_opts["user"] = self.settings.nats_user
connect_opts["password"] = self.settings.nats_password
logger.info(f"Connecting to NATS at {self.settings.nats_url}")
self._nc = await nats.connect(**connect_opts)
self._js = self._nc.jetstream()
logger.info("Connected to NATS")
async def close(self) -> None:
"""Close NATS connection gracefully."""
if self._nc:
@@ -74,13 +74,13 @@ class NATSClient:
await sub.drain()
except Exception as e:
logger.warning(f"Error draining subscription: {e}")
await self._nc.drain()
await self._nc.close()
self._nc = None
self._js = None
logger.info("NATS connection closed")
async def subscribe(
self,
subject: str,
@@ -89,24 +89,24 @@ class NATSClient:
):
"""
Subscribe to a subject with a handler function.
Args:
subject: NATS subject to subscribe to
handler: Async function to handle messages
queue: Optional queue group for load balancing
"""
queue = queue or self.settings.nats_queue_group
if queue:
sub = await self.nc.subscribe(subject, queue=queue, cb=handler)
logger.info(f"Subscribed to {subject} (queue: {queue})")
else:
sub = await self.nc.subscribe(subject, cb=handler)
logger.info(f"Subscribed to {subject}")
self._subscriptions.append(sub)
return sub
async def publish(
self,
subject: str,
@@ -115,7 +115,7 @@ class NATSClient:
) -> None:
"""
Publish a message to a subject.
Args:
subject: NATS subject to publish to
data: Data to publish (will be serialized)
@@ -124,15 +124,16 @@ class NATSClient:
with create_span("nats.publish") as span:
if span:
span.set_attribute("messaging.destination", subject)
if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True)
else:
import json
payload = json.dumps(data).encode()
await self.nc.publish(subject, payload)
async def request(
self,
subject: str,
@@ -142,43 +143,46 @@ class NATSClient:
) -> Any:
"""
Send a request and wait for response.
Args:
subject: NATS subject to send request to
data: Request data
timeout: Response timeout in seconds
use_msgpack: Whether to use msgpack serialization
Returns:
Decoded response data
"""
timeout = timeout or self.settings.nats_timeout
with create_span("nats.request") as span:
if span:
span.set_attribute("messaging.destination", subject)
if use_msgpack:
payload = msgpack.packb(data, use_bin_type=True)
else:
import json
payload = json.dumps(data).encode()
response = await self.nc.request(subject, payload, timeout=timeout)
if use_msgpack:
return msgpack.unpackb(response.data, raw=False)
else:
import json
return json.loads(response.data.decode())
@staticmethod
def decode_msgpack(msg: Msg) -> Any:
"""Decode a msgpack message."""
return msgpack.unpackb(msg.data, raw=False)
@staticmethod
def decode_json(msg: Msg) -> Any:
"""Decode a JSON message."""
import json
return json.loads(msg.data.decode())

View File

@@ -3,26 +3,27 @@ OpenTelemetry setup for tracing and metrics.
Supports both gRPC and HTTP exporters, with optional HyperDX integration.
"""
import logging
import os
from typing import Optional, Tuple
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as OTLPMetricExporterHTTP,
)
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from handler_base.config import Settings
@@ -39,35 +40,37 @@ def setup_telemetry(
) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]:
"""
Initialize OpenTelemetry tracing and metrics.
Args:
settings: Configuration settings. If None, loads from environment.
Returns:
Tuple of (tracer, meter) or (None, None) if disabled.
"""
global _tracer, _meter, _initialized
if _initialized:
return _tracer, _meter
if settings is None:
settings = Settings()
if not settings.otel_enabled:
logger.info("OpenTelemetry disabled")
_initialized = True
return None, None
# Create resource with service information
resource = Resource.create({
SERVICE_NAME: settings.service_name,
SERVICE_VERSION: settings.service_version,
SERVICE_NAMESPACE: settings.service_namespace,
"deployment.environment": settings.deployment_env,
"host.name": os.environ.get("HOSTNAME", "unknown"),
})
resource = Resource.create(
{
SERVICE_NAME: settings.service_name,
SERVICE_VERSION: settings.service_version,
SERVICE_NAMESPACE: settings.service_namespace,
"deployment.environment": settings.deployment_env,
"host.name": os.environ.get("HOSTNAME", "unknown"),
}
)
# Determine endpoint and exporter type
if settings.hyperdx_enabled and settings.hyperdx_api_key:
# HyperDX uses HTTP with API key header
@@ -80,7 +83,7 @@ def setup_telemetry(
headers = None
use_http = settings.otel_use_http
logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})")
# Setup tracing
if use_http:
trace_exporter = OTLPSpanExporterHTTP(
@@ -91,11 +94,11 @@ def setup_telemetry(
trace_exporter = OTLPSpanExporter(
endpoint=endpoint,
)
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
trace.set_tracer_provider(tracer_provider)
# Setup metrics
if use_http:
metric_exporter = OTLPMetricExporterHTTP(
@@ -106,25 +109,25 @@ def setup_telemetry(
metric_exporter = OTLPMetricExporter(
endpoint=endpoint,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter,
export_interval_millis=60000,
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
# Instrument libraries
HTTPXClientInstrumentor().instrument()
LoggingInstrumentor().instrument(set_logging_format=True)
# Create tracer and meter for this service
_tracer = trace.get_tracer(settings.service_name, settings.service_version)
_meter = metrics.get_meter(settings.service_name, settings.service_version)
logger.info(f"OpenTelemetry initialized for {settings.service_name}")
_initialized = True
return _tracer, _meter
@@ -141,7 +144,7 @@ def get_meter() -> Optional[metrics.Meter]:
def create_span(name: str, **kwargs):
"""
Create a new span.
Usage:
with create_span("my_operation") as span:
span.set_attribute("key", "value")
@@ -150,5 +153,6 @@ def create_span(name: str, **kwargs):
if _tracer is None:
# Return a no-op context manager
from contextlib import nullcontext
return nullcontext()
return _tracer.start_as_current_span(name, **kwargs)

View File

@@ -1,14 +1,13 @@
"""
Pytest configuration and fixtures.
"""
import asyncio
import os
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import pytest
# Set test environment variables before importing handler_base
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
@@ -29,6 +28,7 @@ def event_loop():
def settings():
"""Create test settings."""
from handler_base.config import Settings
return Settings(
service_name="test-service",
service_version="1.0.0-test",
@@ -56,7 +56,7 @@ def mock_nats_message():
msg = MagicMock()
msg.subject = "test.subject"
msg.reply = "test.reply"
msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack
msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
return msg

View File

@@ -1,44 +1,43 @@
"""
Unit tests for service clients.
"""
import json
from unittest.mock import MagicMock
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
class TestEmbeddingsClient:
"""Tests for EmbeddingsClient."""
@pytest.fixture
def embeddings_client(self, mock_httpx_client):
"""Create an EmbeddingsClient with mocked HTTP."""
from handler_base.clients.embeddings import EmbeddingsClient
client = EmbeddingsClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding a single text."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [{"embedding": sample_embedding, "index": 0}]
}
mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed_single("Hello world")
assert result == sample_embedding
mock_httpx_client.post.assert_called_once()
@pytest.mark.asyncio
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding multiple texts."""
texts = ["Hello", "World"]
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [
@@ -48,41 +47,41 @@ class TestEmbeddingsClient:
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed(texts)
assert len(result) == 2
assert all(len(e) == len(sample_embedding) for e in result)
@pytest.mark.asyncio
async def test_health_check(self, embeddings_client, mock_httpx_client):
"""Test health check."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_httpx_client.get.return_value = mock_response
result = await embeddings_client.health()
assert result is True
class TestRerankerClient:
"""Tests for RerankerClient."""
@pytest.fixture
def reranker_client(self, mock_httpx_client):
"""Create a RerankerClient with mocked HTTP."""
from handler_base.clients.reranker import RerankerClient
client = RerankerClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
"""Test reranking documents."""
texts = [d["text"] for d in sample_documents]
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [
@@ -93,9 +92,9 @@ class TestRerankerClient:
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await reranker_client.rerank("What is ML?", texts)
assert len(result) == 3
assert result[0]["score"] == 0.95
assert result[0]["index"] == 1
@@ -103,53 +102,48 @@ class TestRerankerClient:
class TestLLMClient:
"""Tests for LLMClient."""
@pytest.fixture
def llm_client(self, mock_httpx_client):
"""Create an LLMClient with mocked HTTP."""
from handler_base.clients.llm import LLMClient
client = LLMClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_generate(self, llm_client, mock_httpx_client):
"""Test generating a response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{"message": {"content": "Hello! I'm an AI assistant."}}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 20}
"choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate("Hello")
assert result == "Hello! I'm an AI assistant."
@pytest.mark.asyncio
async def test_generate_with_context(self, llm_client, mock_httpx_client):
"""Test generating with RAG context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{"message": {"content": "Based on the context..."}}
],
"usage": {}
"choices": [{"message": {"content": "Based on the context..."}}],
"usage": {},
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate(
"What is Python?",
context="Python is a programming language."
"What is Python?", context="Python is a programming language."
)
assert "Based on the context" in result
# Verify context was included in the request
call_args = mock_httpx_client.post.call_args
messages = call_args.kwargs["json"]["messages"]

View File

@@ -1,46 +1,45 @@
"""
Unit tests for handler_base.config module.
"""
import os
import pytest
class TestSettings:
"""Tests for Settings configuration."""
def test_default_settings(self, settings):
"""Test that default settings are loaded correctly."""
assert settings.service_name == "test-service"
assert settings.service_version == "1.0.0-test"
assert settings.otel_enabled is False
def test_settings_from_env(self, monkeypatch):
"""Test that settings can be loaded from environment variables."""
monkeypatch.setenv("SERVICE_NAME", "env-service")
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
# Need to reimport to pick up env changes
from handler_base.config import Settings
s = Settings()
assert s.service_name == "env-service"
assert s.service_version == "2.0.0"
assert s.nats_url == "nats://custom:4222"
def test_embeddings_settings(self):
"""Test EmbeddingsSettings extends base correctly."""
from handler_base.config import EmbeddingsSettings
s = EmbeddingsSettings()
assert hasattr(s, "embeddings_model")
assert hasattr(s, "embeddings_batch_size")
assert s.embeddings_model == "bge"
def test_llm_settings(self):
"""Test LLMSettings has expected defaults."""
from handler_base.config import LLMSettings
s = LLMSettings()
assert s.llm_max_tokens == 2048
assert s.llm_temperature == 0.7

View File

@@ -1,101 +1,101 @@
"""
Unit tests for handler_base.health module.
"""
import pytest
import json
import threading
import time
from http.client import HTTPConnection
from unittest.mock import AsyncMock
import pytest
class TestHealthServer:
"""Tests for HealthServer."""
@pytest.fixture
def health_server(self, settings):
"""Create a HealthServer instance."""
from handler_base.health import HealthServer
# Use a random high port to avoid conflicts
settings.health_port = 18080
return HealthServer(settings)
def test_start_stop(self, health_server):
"""Test starting and stopping the health server."""
health_server.start()
time.sleep(0.1) # Give server time to start
# Verify server is running
assert health_server._server is not None
assert health_server._thread is not None
assert health_server._thread.is_alive()
health_server.stop()
time.sleep(0.1)
assert health_server._server is None
def test_health_endpoint(self, health_server):
"""Test the /health endpoint."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/health")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "healthy"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_default(self, health_server):
"""Test the /ready endpoint with no custom check."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "ready"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_with_check(self, settings):
"""Test /ready endpoint with custom readiness check."""
from handler_base.health import HealthServer
ready_flag = [False] # Use list to allow mutation in closure
async def check_ready():
return ready_flag[0]
settings.health_port = 18081
server = HealthServer(settings, ready_check=check_ready)
server.start()
time.sleep(0.2)
try:
conn = HTTPConnection("localhost", 18081, timeout=5)
# Should be not ready initially
conn.request("GET", "/ready")
response = conn.getresponse()
response.read() # Consume response body
assert response.status == 503
# Mark as ready
ready_flag[0] = True
# Need new connection after consuming response
conn.close()
conn = HTTPConnection("localhost", 18081, timeout=5)
@@ -105,17 +105,17 @@ class TestHealthServer:
finally:
conn.close()
server.stop()
def test_404_for_unknown_path(self, health_server):
"""Test that unknown paths return 404."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/unknown")
response = conn.getresponse()
assert response.status == 404
finally:
conn.close()

View File

@@ -1,48 +1,52 @@
"""
Unit tests for handler_base.nats_client module.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import msgpack
import pytest
class TestNATSClient:
"""Tests for NATSClient."""
@pytest.fixture
def nats_client(self, settings):
"""Create a NATSClient instance."""
from handler_base.nats_client import NATSClient
return NATSClient(settings)
def test_init(self, nats_client, settings):
"""Test NATSClient initialization."""
assert nats_client.settings == settings
assert nats_client._nc is None
assert nats_client._js is None
def test_decode_msgpack(self, nats_client):
"""Test msgpack decoding."""
data = {"query": "hello", "request_id": "123"}
encoded = msgpack.packb(data, use_bin_type=True)
msg = MagicMock()
msg.data = encoded
result = nats_client.decode_msgpack(msg)
assert result == data
def test_decode_json(self, nats_client):
"""Test JSON decoding."""
import json
data = {"query": "hello"}
msg = MagicMock()
msg.data = json.dumps(data).encode()
result = nats_client.decode_json(msg)
assert result == data
@pytest.mark.asyncio
async def test_connect(self, nats_client):
"""Test NATS connection."""
@@ -51,30 +55,30 @@ class TestNATSClient:
mock_js = MagicMock()
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
mock_nats.connect = AsyncMock(return_value=mock_nc)
await nats_client.connect()
assert nats_client._nc == mock_nc
assert nats_client._js == mock_js
mock_nats.connect.assert_called_once()
@pytest.mark.asyncio
async def test_publish(self, nats_client):
"""Test publishing a message."""
mock_nc = AsyncMock()
nats_client._nc = mock_nc
data = {"key": "value"}
await nats_client.publish("test.subject", data)
mock_nc.publish.assert_called_once()
call_args = mock_nc.publish.call_args
assert call_args.args[0] == "test.subject"
# Verify msgpack encoding
decoded = msgpack.unpackb(call_args.args[1], raw=False)
assert decoded == data
@pytest.mark.asyncio
async def test_subscribe(self, nats_client):
"""Test subscribing to a subject."""
@@ -82,10 +86,10 @@ class TestNATSClient:
mock_sub = MagicMock()
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
nats_client._nc = mock_nc
handler = AsyncMock()
await nats_client.subscribe("test.subject", handler, queue="test-queue")
mock_nc.subscribe.assert_called_once()
call_kwargs = mock_nc.subscribe.call_args.kwargs
assert call_kwargs["queue"] == "test-queue"