diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 05b5cef..db0009a 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -57,12 +57,6 @@ jobs: - name: Run tests with coverage run: uv run pytest --cov=handler_base --cov-report=xml --cov-report=term - - name: Upload coverage artifact - uses: actions/upload-artifact@v4 - with: - name: coverage - path: coverage.xml - release: name: Release runs-on: ubuntu-latest diff --git a/handler_base/__init__.py b/handler_base/__init__.py index e812c3f..1a011d8 100644 --- a/handler_base/__init__.py +++ b/handler_base/__init__.py @@ -8,11 +8,12 @@ Provides consistent patterns for: - Graceful shutdown - Service client wrappers """ + from handler_base.config import Settings from handler_base.handler import Handler from handler_base.health import HealthServer from handler_base.nats_client import NATSClient -from handler_base.telemetry import setup_telemetry, get_tracer, get_meter +from handler_base.telemetry import get_meter, get_tracer, setup_telemetry __all__ = [ "Handler", diff --git a/handler_base/clients/__init__.py b/handler_base/clients/__init__.py index c9df8ec..40e28e7 100644 --- a/handler_base/clients/__init__.py +++ b/handler_base/clients/__init__.py @@ -1,12 +1,13 @@ """ Service client wrappers for AI/ML backends. """ + from handler_base.clients.embeddings import EmbeddingsClient -from handler_base.clients.reranker import RerankerClient from handler_base.clients.llm import LLMClient -from handler_base.clients.tts import TTSClient -from handler_base.clients.stt import STTClient from handler_base.clients.milvus import MilvusClient +from handler_base.clients.reranker import RerankerClient +from handler_base.clients.stt import STTClient +from handler_base.clients.tts import TTSClient __all__ = [ "EmbeddingsClient", diff --git a/handler_base/clients/embeddings.py b/handler_base/clients/embeddings.py index 5a99b23..6fc5868 100644 --- a/handler_base/clients/embeddings.py +++ b/handler_base/clients/embeddings.py @@ -1,6 +1,7 @@ """ Embeddings service client (Infinity/BGE). """ + import logging from typing import Optional @@ -15,23 +16,23 @@ logger = logging.getLogger(__name__) class EmbeddingsClient: """ Client for the embeddings service (Infinity with BGE models). - + Usage: client = EmbeddingsClient() embeddings = await client.embed(["Hello world"]) """ - + def __init__(self, settings: Optional[EmbeddingsSettings] = None): self.settings = settings or EmbeddingsSettings() self._client = httpx.AsyncClient( base_url=self.settings.embeddings_url, timeout=self.settings.http_timeout, ) - + async def close(self) -> None: """Close the HTTP client.""" await self._client.aclose() - + async def embed( self, texts: list[str], @@ -39,49 +40,49 @@ class EmbeddingsClient: ) -> list[list[float]]: """ Generate embeddings for a list of texts. - + Args: texts: List of texts to embed model: Model name (defaults to settings) - + Returns: List of embedding vectors """ model = model or self.settings.embeddings_model - + with create_span("embeddings.embed") as span: if span: span.set_attribute("embeddings.model", model) span.set_attribute("embeddings.batch_size", len(texts)) - + response = await self._client.post( "/embeddings", json={"input": texts, "model": model}, ) response.raise_for_status() - + result = response.json() embeddings = [d["embedding"] for d in result.get("data", [])] - + if span: span.set_attribute("embeddings.dimensions", len(embeddings[0]) if embeddings else 0) - + return embeddings - + async def embed_single(self, text: str, model: Optional[str] = None) -> list[float]: """ Generate embedding for a single text. - + Args: text: Text to embed model: Model name (defaults to settings) - + Returns: Embedding vector """ embeddings = await self.embed([text], model) return embeddings[0] if embeddings else [] - + async def health(self) -> bool: """Check if the embeddings service is healthy.""" try: diff --git a/handler_base/clients/llm.py b/handler_base/clients/llm.py index 8e432c9..13baee9 100644 --- a/handler_base/clients/llm.py +++ b/handler_base/clients/llm.py @@ -1,8 +1,9 @@ """ LLM service client (vLLM/OpenAI-compatible). """ + import logging -from typing import Optional, AsyncIterator +from typing import AsyncIterator, Optional import httpx @@ -15,33 +16,33 @@ logger = logging.getLogger(__name__) class LLMClient: """ Client for the LLM service (vLLM with OpenAI-compatible API). - + Usage: client = LLMClient() response = await client.generate("Hello, how are you?") - + # With context for RAG response = await client.generate( "What is the capital?", context="France is a country in Europe..." ) - + # Streaming async for chunk in client.stream("Tell me a story"): print(chunk, end="") """ - + def __init__(self, settings: Optional[LLMSettings] = None): self.settings = settings or LLMSettings() self._client = httpx.AsyncClient( base_url=self.settings.llm_url, timeout=self.settings.http_timeout, ) - + async def close(self) -> None: """Close the HTTP client.""" await self._client.aclose() - + async def generate( self, prompt: str, @@ -54,7 +55,7 @@ class LLMClient: ) -> str: """ Generate a response from the LLM. - + Args: prompt: User prompt/query context: Optional context for RAG @@ -63,19 +64,19 @@ class LLMClient: temperature: Sampling temperature top_p: Top-p sampling stop: Stop sequences - + Returns: Generated text response """ with create_span("llm.generate") as span: messages = self._build_messages(prompt, context, system_prompt) - + if span: span.set_attribute("llm.model", self.settings.llm_model) span.set_attribute("llm.prompt_length", len(prompt)) if context: span.set_attribute("llm.context_length", len(context)) - + payload = { "model": self.settings.llm_model, "messages": messages, @@ -85,21 +86,21 @@ class LLMClient: } if stop: payload["stop"] = stop - + response = await self._client.post("/v1/chat/completions", json=payload) response.raise_for_status() - + result = response.json() content = result["choices"][0]["message"]["content"] - + if span: span.set_attribute("llm.response_length", len(content)) usage = result.get("usage", {}) span.set_attribute("llm.prompt_tokens", usage.get("prompt_tokens", 0)) span.set_attribute("llm.completion_tokens", usage.get("completion_tokens", 0)) - + return content - + async def stream( self, prompt: str, @@ -110,19 +111,19 @@ class LLMClient: ) -> AsyncIterator[str]: """ Stream a response from the LLM. - + Args: prompt: User prompt/query context: Optional context for RAG system_prompt: Optional system prompt max_tokens: Maximum tokens to generate temperature: Sampling temperature - + Yields: Text chunks as they're generated """ messages = self._build_messages(prompt, context, system_prompt) - + payload = { "model": self.settings.llm_model, "messages": messages, @@ -130,25 +131,24 @@ class LLMClient: "temperature": temperature or self.settings.llm_temperature, "stream": True, } - - async with self._client.stream( - "POST", "/v1/chat/completions", json=payload - ) as response: + + async with self._client.stream("POST", "/v1/chat/completions", json=payload) as response: response.raise_for_status() - + async for line in response.aiter_lines(): if line.startswith("data: "): data = line[6:] if data == "[DONE]": break - + import json + chunk = json.loads(data) delta = chunk["choices"][0].get("delta", {}) content = delta.get("content", "") if content: yield content - + def _build_messages( self, prompt: str, @@ -157,32 +157,36 @@ class LLMClient: ) -> list[dict]: """Build the messages list for the API call.""" messages = [] - + # System prompt if system_prompt: messages.append({"role": "system", "content": system_prompt}) elif context: # Default RAG system prompt - messages.append({ - "role": "system", - "content": ( - "You are a helpful assistant. Use the provided context to answer " - "the user's question. If the context doesn't contain relevant " - "information, say so." - ), - }) - + messages.append( + { + "role": "system", + "content": ( + "You are a helpful assistant. Use the provided context to answer " + "the user's question. If the context doesn't contain relevant " + "information, say so." + ), + } + ) + # Add context as a separate message if provided if context: - messages.append({ - "role": "user", - "content": f"Context:\n{context}\n\nQuestion: {prompt}", - }) + messages.append( + { + "role": "user", + "content": f"Context:\n{context}\n\nQuestion: {prompt}", + } + ) else: messages.append({"role": "user", "content": prompt}) - + return messages - + async def health(self) -> bool: """Check if the LLM service is healthy.""" try: diff --git a/handler_base/clients/milvus.py b/handler_base/clients/milvus.py index 8035a6c..1b4e309 100644 --- a/handler_base/clients/milvus.py +++ b/handler_base/clients/milvus.py @@ -1,10 +1,11 @@ """ Milvus vector database client. """ -import logging -from typing import Optional, Any -from pymilvus import connections, Collection, utility +import logging +from typing import Optional + +from pymilvus import Collection, connections, utility from handler_base.config import Settings from handler_base.telemetry import create_span @@ -15,42 +16,42 @@ logger = logging.getLogger(__name__) class MilvusClient: """ Client for Milvus vector database. - + Usage: client = MilvusClient() await client.connect() results = await client.search(embedding, limit=10) """ - + def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() self._connected = False self._collection: Optional[Collection] = None - + async def connect(self, collection_name: Optional[str] = None) -> None: """ Connect to Milvus and load collection. - + Args: collection_name: Collection to use (defaults to settings) """ collection_name = collection_name or self.settings.milvus_collection - + connections.connect( alias="default", host=self.settings.milvus_host, port=self.settings.milvus_port, ) - + if utility.has_collection(collection_name): self._collection = Collection(collection_name) self._collection.load() logger.info(f"Connected to Milvus collection: {collection_name}") else: logger.warning(f"Collection {collection_name} does not exist") - + self._connected = True - + async def close(self) -> None: """Close Milvus connection.""" if self._collection: @@ -58,7 +59,7 @@ class MilvusClient: connections.disconnect("default") self._connected = False logger.info("Disconnected from Milvus") - + async def search( self, embedding: list[float], @@ -68,26 +69,26 @@ class MilvusClient: ) -> list[dict]: """ Search for similar vectors. - + Args: embedding: Query embedding vector limit: Maximum number of results output_fields: Fields to return (default: all) filter_expr: Optional filter expression - + Returns: List of results with 'id', 'distance', and requested fields """ if not self._collection: raise RuntimeError("Not connected to collection") - + with create_span("milvus.search") as span: if span: span.set_attribute("milvus.collection", self._collection.name) span.set_attribute("milvus.limit", limit) - + search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} - + results = self._collection.search( data=[embedding], anns_field="embedding", @@ -96,7 +97,7 @@ class MilvusClient: output_fields=output_fields, expr=filter_expr, ) - + # Convert to list of dicts hits = [] for hit in results[0]: @@ -111,12 +112,12 @@ class MilvusClient: if hasattr(hit.entity, field): item[field] = getattr(hit.entity, field) hits.append(item) - + if span: span.set_attribute("milvus.results", len(hits)) - + return hits - + async def search_with_texts( self, embedding: list[float], @@ -126,22 +127,22 @@ class MilvusClient: ) -> list[dict]: """ Search and return text content with metadata. - + Args: embedding: Query embedding limit: Maximum results text_field: Name of text field in collection metadata_fields: Additional metadata fields to return - + Returns: List of results with text and metadata """ output_fields = [text_field] if metadata_fields: output_fields.extend(metadata_fields) - + return await self.search(embedding, limit, output_fields) - + async def insert( self, embeddings: list[list[float]], @@ -149,34 +150,34 @@ class MilvusClient: ) -> list[int]: """ Insert vectors with data into the collection. - + Args: embeddings: List of embedding vectors data: List of dicts with field values - + Returns: List of inserted IDs """ if not self._collection: raise RuntimeError("Not connected to collection") - + with create_span("milvus.insert") as span: if span: span.set_attribute("milvus.collection", self._collection.name) span.set_attribute("milvus.count", len(embeddings)) - + # Build insert data insert_data = [embeddings] for field in self._collection.schema.fields: if field.name not in ("id", "embedding"): field_values = [d.get(field.name) for d in data] insert_data.append(field_values) - + result = self._collection.insert(insert_data) self._collection.flush() - + return result.primary_keys - + def health(self) -> bool: """Check if connected to Milvus.""" return self._connected and utility.get_connection_addr("default") is not None diff --git a/handler_base/clients/reranker.py b/handler_base/clients/reranker.py index 4297d31..cfd9f37 100644 --- a/handler_base/clients/reranker.py +++ b/handler_base/clients/reranker.py @@ -1,6 +1,7 @@ """ Reranker service client (Infinity/BGE Reranker). """ + import logging from typing import Optional @@ -15,23 +16,23 @@ logger = logging.getLogger(__name__) class RerankerClient: """ Client for the reranker service (Infinity with BGE Reranker). - + Usage: client = RerankerClient() reranked = await client.rerank("query", ["doc1", "doc2"]) """ - + def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() self._client = httpx.AsyncClient( base_url=self.settings.reranker_url, timeout=self.settings.http_timeout, ) - + async def close(self) -> None: """Close the HTTP client.""" await self._client.aclose() - + async def rerank( self, query: str, @@ -40,12 +41,12 @@ class RerankerClient: ) -> list[dict]: """ Rerank documents based on relevance to query. - + Args: query: Query text documents: List of documents to rerank top_k: Number of top results to return (default: all) - + Returns: List of dicts with 'index', 'score', and 'document' keys, sorted by relevance score descending. @@ -55,32 +56,34 @@ class RerankerClient: span.set_attribute("reranker.num_documents", len(documents)) if top_k: span.set_attribute("reranker.top_k", top_k) - + payload = { "query": query, "documents": documents, } if top_k: payload["top_n"] = top_k - + response = await self._client.post("/rerank", json=payload) response.raise_for_status() - + result = response.json() results = result.get("results", []) - + # Enrich with original documents enriched = [] for r in results: idx = r.get("index", 0) - enriched.append({ - "index": idx, - "score": r.get("relevance_score", r.get("score", 0)), - "document": documents[idx] if idx < len(documents) else "", - }) - + enriched.append( + { + "index": idx, + "score": r.get("relevance_score", r.get("score", 0)), + "document": documents[idx] if idx < len(documents) else "", + } + ) + return enriched - + async def rerank_with_metadata( self, query: str, @@ -90,27 +93,27 @@ class RerankerClient: ) -> list[dict]: """ Rerank documents with metadata, preserving metadata in results. - + Args: query: Query text documents: List of dicts with text and metadata text_key: Key containing text in each document dict top_k: Number of top results to return - + Returns: Reranked documents with original metadata preserved. """ texts = [d.get(text_key, "") for d in documents] reranked = await self.rerank(query, texts, top_k) - + # Merge back metadata for r in reranked: idx = r["index"] if idx < len(documents): r["metadata"] = {k: v for k, v in documents[idx].items() if k != text_key} - + return reranked - + async def health(self) -> bool: """Check if the reranker service is healthy.""" try: diff --git a/handler_base/clients/stt.py b/handler_base/clients/stt.py index 0436e4e..cfb901f 100644 --- a/handler_base/clients/stt.py +++ b/handler_base/clients/stt.py @@ -1,7 +1,7 @@ """ STT service client (Whisper/faster-whisper). """ -import io + import logging from typing import Optional @@ -16,23 +16,23 @@ logger = logging.getLogger(__name__) class STTClient: """ Client for the STT service (Whisper/faster-whisper). - + Usage: client = STTClient() text = await client.transcribe(audio_bytes) """ - + def __init__(self, settings: Optional[STTSettings] = None): self.settings = settings or STTSettings() self._client = httpx.AsyncClient( base_url=self.settings.stt_url, timeout=180.0, # Transcription can be slow ) - + async def close(self) -> None: """Close the HTTP client.""" await self._client.aclose() - + async def transcribe( self, audio: bytes, @@ -42,54 +42,54 @@ class STTClient: ) -> dict: """ Transcribe audio to text. - + Args: audio: Audio bytes (WAV, MP3, etc.) language: Language code (None for auto-detect) task: "transcribe" or "translate" response_format: "json", "text", "srt", "vtt" - + Returns: Dict with 'text', 'language', and optional 'segments' """ language = language or self.settings.stt_language task = task or self.settings.stt_task - + with create_span("stt.transcribe") as span: if span: span.set_attribute("stt.task", task) span.set_attribute("stt.audio_size", len(audio)) if language: span.set_attribute("stt.language", language) - + files = {"file": ("audio.wav", audio, "audio/wav")} data = { "response_format": response_format, } if language: data["language"] = language - + # Choose endpoint based on task if task == "translate": endpoint = "/v1/audio/translations" else: endpoint = "/v1/audio/transcriptions" - + response = await self._client.post(endpoint, files=files, data=data) response.raise_for_status() - + if response_format == "text": return {"text": response.text} - + result = response.json() - + if span: span.set_attribute("stt.result_length", len(result.get("text", ""))) if result.get("language"): span.set_attribute("stt.detected_language", result["language"]) - + return result - + async def transcribe_file( self, file_path: str, @@ -98,31 +98,31 @@ class STTClient: ) -> dict: """ Transcribe an audio file. - + Args: file_path: Path to audio file language: Language code task: "transcribe" or "translate" - + Returns: Transcription result """ with open(file_path, "rb") as f: audio = f.read() return await self.transcribe(audio, language, task) - + async def translate(self, audio: bytes) -> dict: """ Translate audio to English. - + Args: audio: Audio bytes - + Returns: Translation result with 'text' key """ return await self.transcribe(audio, task="translate") - + async def health(self) -> bool: """Check if the STT service is healthy.""" try: diff --git a/handler_base/clients/tts.py b/handler_base/clients/tts.py index 691261a..ef7d7c3 100644 --- a/handler_base/clients/tts.py +++ b/handler_base/clients/tts.py @@ -1,7 +1,7 @@ """ TTS service client (Coqui XTTS). """ -import io + import logging from typing import Optional @@ -16,23 +16,23 @@ logger = logging.getLogger(__name__) class TTSClient: """ Client for the TTS service (Coqui XTTS). - + Usage: client = TTSClient() audio_bytes = await client.synthesize("Hello world") """ - + def __init__(self, settings: Optional[TTSSettings] = None): self.settings = settings or TTSSettings() self._client = httpx.AsyncClient( base_url=self.settings.tts_url, timeout=120.0, # TTS can be slow ) - + async def close(self) -> None: """Close the HTTP client.""" await self._client.aclose() - + async def synthesize( self, text: str, @@ -41,39 +41,39 @@ class TTSClient: ) -> bytes: """ Synthesize speech from text. - + Args: text: Text to synthesize language: Language code (e.g., "en", "es", "fr") speaker: Speaker ID or reference - + Returns: WAV audio bytes """ language = language or self.settings.tts_language - + with create_span("tts.synthesize") as span: if span: span.set_attribute("tts.language", language) span.set_attribute("tts.text_length", len(text)) - + params = { "text": text, "language_id": language, } if speaker: params["speaker_id"] = speaker - + response = await self._client.get("/api/tts", params=params) response.raise_for_status() - + audio_bytes = response.content - + if span: span.set_attribute("tts.audio_size", len(audio_bytes)) - + return audio_bytes - + async def synthesize_to_file( self, text: str, @@ -83,7 +83,7 @@ class TTSClient: ) -> None: """ Synthesize speech and save to a file. - + Args: text: Text to synthesize output_path: Path to save the audio file @@ -91,10 +91,10 @@ class TTSClient: speaker: Speaker ID """ audio_bytes = await self.synthesize(text, language, speaker) - + with open(output_path, "wb") as f: f.write(audio_bytes) - + async def get_speakers(self) -> list[dict]: """Get available speakers/voices.""" try: @@ -103,7 +103,7 @@ class TTSClient: return response.json() except Exception: return [] - + async def health(self) -> bool: """Check if the TTS service is healthy.""" try: diff --git a/handler_base/config.py b/handler_base/config.py index 4fc0d6e..fb1facd 100644 --- a/handler_base/config.py +++ b/handler_base/config.py @@ -3,67 +3,69 @@ Configuration management using Pydantic Settings. Environment variables are automatically loaded and validated. """ + from typing import Optional + from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): """Base settings for all handler services.""" - + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", extra="ignore", ) - + # Service identification service_name: str = "handler" service_version: str = "1.0.0" service_namespace: str = "ai-ml" deployment_env: str = "production" - + # NATS configuration nats_url: str = "nats://nats.ai-ml.svc.cluster.local:4222" nats_user: Optional[str] = None nats_password: Optional[str] = None nats_queue_group: Optional[str] = None - + # Redis/Valkey configuration redis_url: str = "redis://valkey.ai-ml.svc.cluster.local:6379" redis_password: Optional[str] = None - + # Milvus configuration milvus_host: str = "milvus.ai-ml.svc.cluster.local" milvus_port: int = 19530 milvus_collection: str = "documents" - + # Service endpoints embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local" reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local" llm_url: str = "http://vllm-predictor.ai-ml.svc.cluster.local" tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local" stt_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local" - + # OpenTelemetry configuration otel_enabled: bool = True otel_endpoint: str = "http://opentelemetry-collector.observability.svc.cluster.local:4317" otel_use_http: bool = False - + # HyperDX configuration hyperdx_enabled: bool = False hyperdx_api_key: Optional[str] = None hyperdx_endpoint: str = "https://in-otel.hyperdx.io" - + # MLflow configuration mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80" mlflow_experiment_name: Optional[str] = None mlflow_enabled: bool = True - + # Health check configuration health_port: int = 8080 health_path: str = "/health" ready_path: str = "/ready" - + # Timeouts (seconds) http_timeout: float = 60.0 nats_timeout: float = 30.0 @@ -71,14 +73,14 @@ class Settings(BaseSettings): class EmbeddingsSettings(Settings): """Settings for embeddings service client.""" - + embeddings_model: str = "bge" embeddings_batch_size: int = 32 class LLMSettings(Settings): """Settings for LLM service client.""" - + llm_model: str = "default" llm_max_tokens: int = 2048 llm_temperature: float = 0.7 @@ -87,13 +89,13 @@ class LLMSettings(Settings): class TTSSettings(Settings): """Settings for TTS service client.""" - + tts_language: str = "en" tts_speaker: Optional[str] = None class STTSettings(Settings): """Settings for STT service client.""" - + stt_language: Optional[str] = None # Auto-detect stt_task: str = "transcribe" # or "translate" diff --git a/handler_base/handler.py b/handler_base/handler.py index 4e70d2b..b2a1528 100644 --- a/handler_base/handler.py +++ b/handler_base/handler.py @@ -1,6 +1,7 @@ """ Base handler class for building NATS-based services. """ + import asyncio import logging import signal @@ -12,7 +13,7 @@ from nats.aio.msg import Msg from handler_base.config import Settings from handler_base.health import HealthServer from handler_base.nats_client import NATSClient -from handler_base.telemetry import setup_telemetry, create_span +from handler_base.telemetry import create_span, setup_telemetry logger = logging.getLogger(__name__) @@ -20,25 +21,25 @@ logger = logging.getLogger(__name__) class Handler(ABC): """ Base class for NATS message handlers. - + Subclass and implement: - setup(): Initialize your service clients - handle_message(): Process incoming messages - teardown(): Clean up resources (optional) - + Example: class MyHandler(Handler): async def setup(self): self.embeddings = EmbeddingsClient() - + async def handle_message(self, msg: Msg, data: dict) -> Optional[dict]: result = await self.embeddings.embed(data["text"]) return {"embedding": result} - + if __name__ == "__main__": MyHandler(subject="my.subject").run() """ - + def __init__( self, subject: str, @@ -47,7 +48,7 @@ class Handler(ABC): ): """ Initialize the handler. - + Args: subject: NATS subject to subscribe to settings: Configuration settings @@ -56,78 +57,78 @@ class Handler(ABC): self.subject = subject self.settings = settings or Settings() self.queue_group = queue_group or self.settings.nats_queue_group - + self.nats = NATSClient(self.settings) self.health_server = HealthServer(self.settings, self._check_ready) - + self._running = False self._shutdown_event = asyncio.Event() - + @abstractmethod async def setup(self) -> None: """ Initialize service clients and resources. - + Called once before starting to handle messages. Override this to set up your service-specific clients. """ pass - + @abstractmethod async def handle_message(self, msg: Msg, data: Any) -> Optional[Any]: """ Handle an incoming message. - + Args: msg: Raw NATS message data: Decoded message data (msgpack unpacked) - + Returns: Optional response data. If returned and msg has a reply subject, the response will be sent automatically. """ pass - + async def teardown(self) -> None: """ Clean up resources. - + Called during graceful shutdown. Override to add custom cleanup logic. """ pass - + async def _check_ready(self) -> bool: """Check if the service is ready to handle requests.""" return self._running and self.nats._nc is not None - + async def _message_handler(self, msg: Msg) -> None: """Internal message handler with tracing and error handling.""" with create_span(f"handle.{self.subject}") as span: try: # Decode message data = NATSClient.decode_msgpack(msg) - + if span: span.set_attribute("messaging.destination", msg.subject) if isinstance(data, dict): request_id = data.get("request_id", data.get("id")) if request_id: span.set_attribute("request.id", str(request_id)) - + # Handle message response = await self.handle_message(msg, data) - + # Send response if applicable if response is not None and msg.reply: await self.nats.publish(msg.reply, response) - + except Exception as e: logger.exception(f"Error handling message on {msg.subject}") if span: span.set_attribute("error", True) span.set_attribute("error.message", str(e)) - + # Send error response if reply expected if msg.reply: error_response = { @@ -136,71 +137,71 @@ class Handler(ABC): "type": type(e).__name__, } await self.nats.publish(msg.reply, error_response) - + def _setup_signals(self) -> None: """Set up signal handlers for graceful shutdown.""" loop = asyncio.get_event_loop() - + for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, self._handle_signal, sig) - + def _handle_signal(self, sig: signal.Signals) -> None: """Handle shutdown signal.""" logger.info(f"Received {sig.name}, initiating graceful shutdown...") self._shutdown_event.set() - + async def _run(self) -> None: """Main async run loop.""" # Setup telemetry setup_telemetry(self.settings) - + # Start health server self.health_server.start() - + try: # Connect to NATS await self.nats.connect() - + # Run user setup logger.info("Running service setup...") await self.setup() - + # Subscribe to subject await self.nats.subscribe( self.subject, self._message_handler, queue=self.queue_group, ) - + self._running = True logger.info(f"Handler ready, listening on {self.subject}") - + # Wait for shutdown signal await self._shutdown_event.wait() - - except Exception as e: + + except Exception: logger.exception("Fatal error in handler") raise finally: self._running = False - + # Graceful shutdown logger.info("Shutting down...") - + try: await self.teardown() except Exception as e: logger.warning(f"Error during teardown: {e}") - + await self.nats.close() self.health_server.stop() - + logger.info("Shutdown complete") - + def run(self) -> None: """ Run the handler. - + This is the main entry point. It sets up signal handlers and runs the async event loop. """ @@ -209,12 +210,12 @@ class Handler(ABC): level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - + logger.info(f"Starting {self.settings.service_name} v{self.settings.service_version}") - + # Run the async loop asyncio.run(self._run_with_signals()) - + async def _run_with_signals(self) -> None: """Run with signal handling.""" self._setup_signals() diff --git a/handler_base/health.py b/handler_base/health.py index 8f841b3..5e0b514 100644 --- a/handler_base/health.py +++ b/handler_base/health.py @@ -3,12 +3,13 @@ HTTP health check server. Provides /health and /ready endpoints for Kubernetes probes. """ + import asyncio -import logging -from typing import Callable, Optional, Awaitable -from http.server import HTTPServer, BaseHTTPRequestHandler -import threading import json +import logging +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Awaitable, Callable, Optional from handler_base.config import Settings @@ -17,16 +18,16 @@ logger = logging.getLogger(__name__) class HealthHandler(BaseHTTPRequestHandler): """HTTP request handler for health checks.""" - + # Class-level state ready_check: Optional[Callable[[], Awaitable[bool]]] = None health_path: str = "/health" ready_path: str = "/ready" - + def log_message(self, format, *args): """Suppress default logging.""" pass - + def do_GET(self): """Handle GET requests for health/ready endpoints.""" if self.path == self.health_path: @@ -35,7 +36,7 @@ class HealthHandler(BaseHTTPRequestHandler): self._handle_ready() else: self._respond_not_found() - + def _handle_ready(self): """Check readiness and respond.""" # Access via class to avoid method binding issues @@ -43,7 +44,7 @@ class HealthHandler(BaseHTTPRequestHandler): if ready_check is None: self._respond_ok({"status": "ready"}) return - + try: # Run the async check in a new event loop loop = asyncio.new_event_loop() @@ -51,7 +52,7 @@ class HealthHandler(BaseHTTPRequestHandler): is_ready = loop.run_until_complete(ready_check()) finally: loop.close() - + if is_ready: self._respond_ok({"status": "ready"}) else: @@ -59,19 +60,19 @@ class HealthHandler(BaseHTTPRequestHandler): except Exception as e: logger.exception("Readiness check failed") self._respond_unavailable({"status": "error", "message": str(e)}) - + def _respond_ok(self, data: dict): self.send_response(200) self.send_header("Content-Type", "application/json") self.end_headers() self.wfile.write(json.dumps(data).encode()) - + def _respond_unavailable(self, data: dict): self.send_response(503) self.send_header("Content-Type", "application/json") self.end_headers() self.wfile.write(json.dumps(data).encode()) - + def _respond_not_found(self): self.send_response(404) self.end_headers() @@ -80,14 +81,14 @@ class HealthHandler(BaseHTTPRequestHandler): class HealthServer: """ Background HTTP server for health checks. - + Usage: server = HealthServer(settings) server.start() # ... run your service ... server.stop() """ - + def __init__( self, settings: Optional[Settings] = None, @@ -97,24 +98,24 @@ class HealthServer: self.ready_check = ready_check self._server: Optional[HTTPServer] = None self._thread: Optional[threading.Thread] = None - + def start(self) -> None: """Start the health check server in a background thread.""" # Configure handler class HealthHandler.ready_check = self.ready_check HealthHandler.health_path = self.settings.health_path HealthHandler.ready_path = self.settings.ready_path - + # Create and start server self._server = HTTPServer(("0.0.0.0", self.settings.health_port), HealthHandler) self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) self._thread.start() - + logger.info( f"Health server started on port {self.settings.health_port} " f"(health: {self.settings.health_path}, ready: {self.settings.ready_path})" ) - + def stop(self) -> None: """Stop the health check server.""" if self._server: diff --git a/handler_base/nats_client.py b/handler_base/nats_client.py index 9c0a996..c42fb41 100644 --- a/handler_base/nats_client.py +++ b/handler_base/nats_client.py @@ -1,9 +1,9 @@ """ NATS client wrapper with connection management and utilities. """ -import asyncio + import logging -from typing import Any, Callable, Optional, Awaitable +from typing import Any, Awaitable, Callable, Optional import msgpack import nats @@ -20,34 +20,34 @@ logger = logging.getLogger(__name__) class NATSClient: """ NATS client with automatic connection management. - + Supports: - Core NATS pub/sub - JetStream for persistence - Queue groups for load balancing - Msgpack serialization """ - + def __init__(self, settings: Optional[Settings] = None): self.settings = settings or Settings() self._nc: Optional[Client] = None self._js: Optional[JetStreamContext] = None self._subscriptions: list = [] - + @property def nc(self) -> Client: """Get the NATS client, raising if not connected.""" if self._nc is None: raise RuntimeError("NATS client not connected. Call connect() first.") return self._nc - + @property def js(self) -> JetStreamContext: """Get JetStream context, raising if not connected.""" if self._js is None: raise RuntimeError("JetStream not initialized. Call connect() first.") return self._js - + async def connect(self) -> None: """Connect to NATS server.""" connect_opts = { @@ -55,16 +55,16 @@ class NATSClient: "reconnect_time_wait": 2, "max_reconnect_attempts": -1, # Infinite } - + if self.settings.nats_user and self.settings.nats_password: connect_opts["user"] = self.settings.nats_user connect_opts["password"] = self.settings.nats_password - + logger.info(f"Connecting to NATS at {self.settings.nats_url}") self._nc = await nats.connect(**connect_opts) self._js = self._nc.jetstream() logger.info("Connected to NATS") - + async def close(self) -> None: """Close NATS connection gracefully.""" if self._nc: @@ -74,13 +74,13 @@ class NATSClient: await sub.drain() except Exception as e: logger.warning(f"Error draining subscription: {e}") - + await self._nc.drain() await self._nc.close() self._nc = None self._js = None logger.info("NATS connection closed") - + async def subscribe( self, subject: str, @@ -89,24 +89,24 @@ class NATSClient: ): """ Subscribe to a subject with a handler function. - + Args: subject: NATS subject to subscribe to handler: Async function to handle messages queue: Optional queue group for load balancing """ queue = queue or self.settings.nats_queue_group - + if queue: sub = await self.nc.subscribe(subject, queue=queue, cb=handler) logger.info(f"Subscribed to {subject} (queue: {queue})") else: sub = await self.nc.subscribe(subject, cb=handler) logger.info(f"Subscribed to {subject}") - + self._subscriptions.append(sub) return sub - + async def publish( self, subject: str, @@ -115,7 +115,7 @@ class NATSClient: ) -> None: """ Publish a message to a subject. - + Args: subject: NATS subject to publish to data: Data to publish (will be serialized) @@ -124,15 +124,16 @@ class NATSClient: with create_span("nats.publish") as span: if span: span.set_attribute("messaging.destination", subject) - + if use_msgpack: payload = msgpack.packb(data, use_bin_type=True) else: import json + payload = json.dumps(data).encode() - + await self.nc.publish(subject, payload) - + async def request( self, subject: str, @@ -142,43 +143,46 @@ class NATSClient: ) -> Any: """ Send a request and wait for response. - + Args: subject: NATS subject to send request to data: Request data timeout: Response timeout in seconds use_msgpack: Whether to use msgpack serialization - + Returns: Decoded response data """ timeout = timeout or self.settings.nats_timeout - + with create_span("nats.request") as span: if span: span.set_attribute("messaging.destination", subject) - + if use_msgpack: payload = msgpack.packb(data, use_bin_type=True) else: import json + payload = json.dumps(data).encode() - + response = await self.nc.request(subject, payload, timeout=timeout) - + if use_msgpack: return msgpack.unpackb(response.data, raw=False) else: import json + return json.loads(response.data.decode()) - + @staticmethod def decode_msgpack(msg: Msg) -> Any: """Decode a msgpack message.""" return msgpack.unpackb(msg.data, raw=False) - + @staticmethod def decode_json(msg: Msg) -> Any: """Decode a JSON message.""" import json + return json.loads(msg.data.decode()) diff --git a/handler_base/telemetry.py b/handler_base/telemetry.py index 67369ad..b8c8484 100644 --- a/handler_base/telemetry.py +++ b/handler_base/telemetry.py @@ -3,26 +3,27 @@ OpenTelemetry setup for tracing and metrics. Supports both gRPC and HTTP exporters, with optional HyperDX integration. """ + import logging import os from typing import Optional, Tuple -from opentelemetry import trace, metrics -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry import metrics, trace from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter -from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter as OTLPSpanExporterHTTP, -) +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( OTLPMetricExporter as OTLPMetricExporterHTTP, ) -from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE +from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter as OTLPSpanExporterHTTP, +) from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.instrumentation.logging import LoggingInstrumentor +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_NAMESPACE, SERVICE_VERSION, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor from handler_base.config import Settings @@ -39,35 +40,37 @@ def setup_telemetry( ) -> Tuple[Optional[trace.Tracer], Optional[metrics.Meter]]: """ Initialize OpenTelemetry tracing and metrics. - + Args: settings: Configuration settings. If None, loads from environment. - + Returns: Tuple of (tracer, meter) or (None, None) if disabled. """ global _tracer, _meter, _initialized - + if _initialized: return _tracer, _meter - + if settings is None: settings = Settings() - + if not settings.otel_enabled: logger.info("OpenTelemetry disabled") _initialized = True return None, None - + # Create resource with service information - resource = Resource.create({ - SERVICE_NAME: settings.service_name, - SERVICE_VERSION: settings.service_version, - SERVICE_NAMESPACE: settings.service_namespace, - "deployment.environment": settings.deployment_env, - "host.name": os.environ.get("HOSTNAME", "unknown"), - }) - + resource = Resource.create( + { + SERVICE_NAME: settings.service_name, + SERVICE_VERSION: settings.service_version, + SERVICE_NAMESPACE: settings.service_namespace, + "deployment.environment": settings.deployment_env, + "host.name": os.environ.get("HOSTNAME", "unknown"), + } + ) + # Determine endpoint and exporter type if settings.hyperdx_enabled and settings.hyperdx_api_key: # HyperDX uses HTTP with API key header @@ -80,7 +83,7 @@ def setup_telemetry( headers = None use_http = settings.otel_use_http logger.info(f"Using OTEL endpoint: {endpoint} (HTTP: {use_http})") - + # Setup tracing if use_http: trace_exporter = OTLPSpanExporterHTTP( @@ -91,11 +94,11 @@ def setup_telemetry( trace_exporter = OTLPSpanExporter( endpoint=endpoint, ) - + tracer_provider = TracerProvider(resource=resource) tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) trace.set_tracer_provider(tracer_provider) - + # Setup metrics if use_http: metric_exporter = OTLPMetricExporterHTTP( @@ -106,25 +109,25 @@ def setup_telemetry( metric_exporter = OTLPMetricExporter( endpoint=endpoint, ) - + metric_reader = PeriodicExportingMetricReader( metric_exporter, export_interval_millis=60000, ) meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) metrics.set_meter_provider(meter_provider) - + # Instrument libraries HTTPXClientInstrumentor().instrument() LoggingInstrumentor().instrument(set_logging_format=True) - + # Create tracer and meter for this service _tracer = trace.get_tracer(settings.service_name, settings.service_version) _meter = metrics.get_meter(settings.service_name, settings.service_version) - + logger.info(f"OpenTelemetry initialized for {settings.service_name}") _initialized = True - + return _tracer, _meter @@ -141,7 +144,7 @@ def get_meter() -> Optional[metrics.Meter]: def create_span(name: str, **kwargs): """ Create a new span. - + Usage: with create_span("my_operation") as span: span.set_attribute("key", "value") @@ -150,5 +153,6 @@ def create_span(name: str, **kwargs): if _tracer is None: # Return a no-op context manager from contextlib import nullcontext + return nullcontext() return _tracer.start_as_current_span(name, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 5581e1f..91e8332 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,13 @@ """ Pytest configuration and fixtures. """ + import asyncio import os -from typing import AsyncGenerator from unittest.mock import AsyncMock, MagicMock import pytest - # Set test environment variables before importing handler_base os.environ.setdefault("NATS_URL", "nats://localhost:4222") os.environ.setdefault("REDIS_URL", "redis://localhost:6379") @@ -29,6 +28,7 @@ def event_loop(): def settings(): """Create test settings.""" from handler_base.config import Settings + return Settings( service_name="test-service", service_version="1.0.0-test", @@ -56,7 +56,7 @@ def mock_nats_message(): msg = MagicMock() msg.subject = "test.subject" msg.reply = "test.reply" - msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack + msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack return msg diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py index b187758..3dfc7fa 100644 --- a/tests/unit/test_clients.py +++ b/tests/unit/test_clients.py @@ -1,44 +1,43 @@ """ Unit tests for service clients. """ -import json + +from unittest.mock import MagicMock + import pytest -from unittest.mock import AsyncMock, MagicMock, patch class TestEmbeddingsClient: """Tests for EmbeddingsClient.""" - + @pytest.fixture def embeddings_client(self, mock_httpx_client): """Create an EmbeddingsClient with mocked HTTP.""" from handler_base.clients.embeddings import EmbeddingsClient - + client = EmbeddingsClient() client._client = mock_httpx_client return client - + @pytest.mark.asyncio async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding): """Test embedding a single text.""" # Setup mock response mock_response = MagicMock() - mock_response.json.return_value = { - "data": [{"embedding": sample_embedding, "index": 0}] - } + mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]} mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response - + result = await embeddings_client.embed_single("Hello world") - + assert result == sample_embedding mock_httpx_client.post.assert_called_once() - + @pytest.mark.asyncio async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding): """Test embedding multiple texts.""" texts = ["Hello", "World"] - + mock_response = MagicMock() mock_response.json.return_value = { "data": [ @@ -48,41 +47,41 @@ class TestEmbeddingsClient: } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response - + result = await embeddings_client.embed(texts) - + assert len(result) == 2 assert all(len(e) == len(sample_embedding) for e in result) - + @pytest.mark.asyncio async def test_health_check(self, embeddings_client, mock_httpx_client): """Test health check.""" mock_response = MagicMock() mock_response.status_code = 200 mock_httpx_client.get.return_value = mock_response - + result = await embeddings_client.health() - + assert result is True class TestRerankerClient: """Tests for RerankerClient.""" - + @pytest.fixture def reranker_client(self, mock_httpx_client): """Create a RerankerClient with mocked HTTP.""" from handler_base.clients.reranker import RerankerClient - + client = RerankerClient() client._client = mock_httpx_client return client - + @pytest.mark.asyncio async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents): """Test reranking documents.""" texts = [d["text"] for d in sample_documents] - + mock_response = MagicMock() mock_response.json.return_value = { "results": [ @@ -93,9 +92,9 @@ class TestRerankerClient: } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response - + result = await reranker_client.rerank("What is ML?", texts) - + assert len(result) == 3 assert result[0]["score"] == 0.95 assert result[0]["index"] == 1 @@ -103,53 +102,48 @@ class TestRerankerClient: class TestLLMClient: """Tests for LLMClient.""" - + @pytest.fixture def llm_client(self, mock_httpx_client): """Create an LLMClient with mocked HTTP.""" from handler_base.clients.llm import LLMClient - + client = LLMClient() client._client = mock_httpx_client return client - + @pytest.mark.asyncio async def test_generate(self, llm_client, mock_httpx_client): """Test generating a response.""" mock_response = MagicMock() mock_response.json.return_value = { - "choices": [ - {"message": {"content": "Hello! I'm an AI assistant."}} - ], - "usage": {"prompt_tokens": 10, "completion_tokens": 20} + "choices": [{"message": {"content": "Hello! I'm an AI assistant."}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response - + result = await llm_client.generate("Hello") - + assert result == "Hello! I'm an AI assistant." - + @pytest.mark.asyncio async def test_generate_with_context(self, llm_client, mock_httpx_client): """Test generating with RAG context.""" mock_response = MagicMock() mock_response.json.return_value = { - "choices": [ - {"message": {"content": "Based on the context..."}} - ], - "usage": {} + "choices": [{"message": {"content": "Based on the context..."}}], + "usage": {}, } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response - + result = await llm_client.generate( - "What is Python?", - context="Python is a programming language." + "What is Python?", context="Python is a programming language." ) - + assert "Based on the context" in result - + # Verify context was included in the request call_args = mock_httpx_client.post.call_args messages = call_args.kwargs["json"]["messages"] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index c9134ce..3795152 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,46 +1,45 @@ """ Unit tests for handler_base.config module. """ -import os -import pytest class TestSettings: """Tests for Settings configuration.""" - + def test_default_settings(self, settings): """Test that default settings are loaded correctly.""" assert settings.service_name == "test-service" assert settings.service_version == "1.0.0-test" assert settings.otel_enabled is False - + def test_settings_from_env(self, monkeypatch): """Test that settings can be loaded from environment variables.""" monkeypatch.setenv("SERVICE_NAME", "env-service") monkeypatch.setenv("SERVICE_VERSION", "2.0.0") monkeypatch.setenv("NATS_URL", "nats://custom:4222") - + # Need to reimport to pick up env changes from handler_base.config import Settings + s = Settings() - + assert s.service_name == "env-service" assert s.service_version == "2.0.0" assert s.nats_url == "nats://custom:4222" - + def test_embeddings_settings(self): """Test EmbeddingsSettings extends base correctly.""" from handler_base.config import EmbeddingsSettings - + s = EmbeddingsSettings() assert hasattr(s, "embeddings_model") assert hasattr(s, "embeddings_batch_size") assert s.embeddings_model == "bge" - + def test_llm_settings(self): """Test LLMSettings has expected defaults.""" from handler_base.config import LLMSettings - + s = LLMSettings() assert s.llm_max_tokens == 2048 assert s.llm_temperature == 0.7 diff --git a/tests/unit/test_health.py b/tests/unit/test_health.py index a5cde25..71d0539 100644 --- a/tests/unit/test_health.py +++ b/tests/unit/test_health.py @@ -1,101 +1,101 @@ """ Unit tests for handler_base.health module. """ -import pytest + import json -import threading import time from http.client import HTTPConnection -from unittest.mock import AsyncMock + +import pytest class TestHealthServer: """Tests for HealthServer.""" - + @pytest.fixture def health_server(self, settings): """Create a HealthServer instance.""" from handler_base.health import HealthServer - + # Use a random high port to avoid conflicts settings.health_port = 18080 return HealthServer(settings) - + def test_start_stop(self, health_server): """Test starting and stopping the health server.""" health_server.start() time.sleep(0.1) # Give server time to start - + # Verify server is running assert health_server._server is not None assert health_server._thread is not None assert health_server._thread.is_alive() - + health_server.stop() time.sleep(0.1) - + assert health_server._server is None - + def test_health_endpoint(self, health_server): """Test the /health endpoint.""" health_server.start() time.sleep(0.1) - + try: conn = HTTPConnection("localhost", 18080, timeout=5) conn.request("GET", "/health") response = conn.getresponse() - + assert response.status == 200 data = json.loads(response.read().decode()) assert data["status"] == "healthy" finally: conn.close() health_server.stop() - + def test_ready_endpoint_default(self, health_server): """Test the /ready endpoint with no custom check.""" health_server.start() time.sleep(0.1) - + try: conn = HTTPConnection("localhost", 18080, timeout=5) conn.request("GET", "/ready") response = conn.getresponse() - + assert response.status == 200 data = json.loads(response.read().decode()) assert data["status"] == "ready" finally: conn.close() health_server.stop() - + def test_ready_endpoint_with_check(self, settings): """Test /ready endpoint with custom readiness check.""" from handler_base.health import HealthServer - + ready_flag = [False] # Use list to allow mutation in closure - + async def check_ready(): return ready_flag[0] - + settings.health_port = 18081 server = HealthServer(settings, ready_check=check_ready) server.start() time.sleep(0.2) - + try: conn = HTTPConnection("localhost", 18081, timeout=5) - + # Should be not ready initially conn.request("GET", "/ready") response = conn.getresponse() response.read() # Consume response body assert response.status == 503 - + # Mark as ready ready_flag[0] = True - + # Need new connection after consuming response conn.close() conn = HTTPConnection("localhost", 18081, timeout=5) @@ -105,17 +105,17 @@ class TestHealthServer: finally: conn.close() server.stop() - + def test_404_for_unknown_path(self, health_server): """Test that unknown paths return 404.""" health_server.start() time.sleep(0.1) - + try: conn = HTTPConnection("localhost", 18080, timeout=5) conn.request("GET", "/unknown") response = conn.getresponse() - + assert response.status == 404 finally: conn.close() diff --git a/tests/unit/test_nats_client.py b/tests/unit/test_nats_client.py index 5fb5507..4f8b667 100644 --- a/tests/unit/test_nats_client.py +++ b/tests/unit/test_nats_client.py @@ -1,48 +1,52 @@ """ Unit tests for handler_base.nats_client module. """ -import pytest + from unittest.mock import AsyncMock, MagicMock, patch + import msgpack +import pytest class TestNATSClient: """Tests for NATSClient.""" - + @pytest.fixture def nats_client(self, settings): """Create a NATSClient instance.""" from handler_base.nats_client import NATSClient + return NATSClient(settings) - + def test_init(self, nats_client, settings): """Test NATSClient initialization.""" assert nats_client.settings == settings assert nats_client._nc is None assert nats_client._js is None - + def test_decode_msgpack(self, nats_client): """Test msgpack decoding.""" data = {"query": "hello", "request_id": "123"} encoded = msgpack.packb(data, use_bin_type=True) - + msg = MagicMock() msg.data = encoded - + result = nats_client.decode_msgpack(msg) assert result == data - + def test_decode_json(self, nats_client): """Test JSON decoding.""" import json + data = {"query": "hello"} - + msg = MagicMock() msg.data = json.dumps(data).encode() - + result = nats_client.decode_json(msg) assert result == data - + @pytest.mark.asyncio async def test_connect(self, nats_client): """Test NATS connection.""" @@ -51,30 +55,30 @@ class TestNATSClient: mock_js = MagicMock() mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async mock_nats.connect = AsyncMock(return_value=mock_nc) - + await nats_client.connect() - + assert nats_client._nc == mock_nc assert nats_client._js == mock_js mock_nats.connect.assert_called_once() - + @pytest.mark.asyncio async def test_publish(self, nats_client): """Test publishing a message.""" mock_nc = AsyncMock() nats_client._nc = mock_nc - + data = {"key": "value"} await nats_client.publish("test.subject", data) - + mock_nc.publish.assert_called_once() call_args = mock_nc.publish.call_args assert call_args.args[0] == "test.subject" - + # Verify msgpack encoding decoded = msgpack.unpackb(call_args.args[1], raw=False) assert decoded == data - + @pytest.mark.asyncio async def test_subscribe(self, nats_client): """Test subscribing to a subject.""" @@ -82,10 +86,10 @@ class TestNATSClient: mock_sub = MagicMock() mock_nc.subscribe = AsyncMock(return_value=mock_sub) nats_client._nc = mock_nc - + handler = AsyncMock() await nats_client.subscribe("test.subject", handler, queue="test-queue") - + mock_nc.subscribe.assert_called_once() call_kwargs = mock_nc.subscribe.call_args.kwargs assert call_kwargs["queue"] == "test-queue"