From b8ba2379460104572d8203eb112d8ab07a7fd815 Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Mon, 2 Feb 2026 08:44:34 -0500 Subject: [PATCH] fix: ruff formatting and allow-direct-references for handler-base dep --- chat_handler.py | 88 +++++++++++++-------------- pyproject.toml | 3 + tests/conftest.py | 3 +- tests/test_chat_handler.py | 119 +++++++++++++++++++------------------ 4 files changed, 107 insertions(+), 106 deletions(-) diff --git a/chat_handler.py b/chat_handler.py index f720e4e..8cc3197 100644 --- a/chat_handler.py +++ b/chat_handler.py @@ -11,6 +11,7 @@ Text-based chat pipeline using handler-base: 6. Optionally synthesize speech with XTTS 7. Publish result to NATS "ai.chat.response.{request_id}" """ + import base64 import logging from typing import Any, Optional @@ -32,14 +33,14 @@ logger = logging.getLogger("chat-handler") class ChatSettings(Settings): """Chat handler specific settings.""" - + service_name: str = "chat-handler" - + # RAG settings rag_top_k: int = 10 rag_rerank_top_k: int = 5 rag_collection: str = "documents" - + # Response settings include_sources: bool = True enable_tts: bool = False @@ -49,7 +50,7 @@ class ChatSettings(Settings): class ChatHandler(Handler): """ Chat request handler with RAG pipeline. - + Request format: { "request_id": "uuid", @@ -58,7 +59,7 @@ class ChatHandler(Handler): "enable_tts": false, "system_prompt": "optional custom system prompt" } - + Response format: { "request_id": "uuid", @@ -67,7 +68,7 @@ class ChatHandler(Handler): "audio": "base64 encoded audio (if tts enabled)" } """ - + def __init__(self): self.chat_settings = ChatSettings() super().__init__( @@ -75,41 +76,41 @@ class ChatHandler(Handler): settings=self.chat_settings, queue_group="chat-handlers", ) - + async def setup(self) -> None: """Initialize service clients.""" logger.info("Initializing service clients...") - + self.embeddings = EmbeddingsClient(self.chat_settings) self.reranker = RerankerClient(self.chat_settings) self.llm = LLMClient(self.chat_settings) self.milvus = MilvusClient(self.chat_settings) - + # TTS is optional if self.chat_settings.enable_tts: self.tts = TTSClient(self.chat_settings) else: self.tts = None - + # Connect to Milvus await self.milvus.connect(self.chat_settings.rag_collection) - + logger.info("Service clients initialized") - + async def teardown(self) -> None: """Clean up service clients.""" logger.info("Closing service clients...") - + await self.embeddings.close() await self.reranker.close() await self.llm.close() await self.milvus.close() - + if self.tts: await self.tts.close() - + logger.info("Service clients closed") - + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: """Handle incoming chat request.""" request_id = data.get("request_id", "unknown") @@ -117,67 +118,62 @@ class ChatHandler(Handler): collection = data.get("collection", self.chat_settings.rag_collection) enable_tts = data.get("enable_tts", self.chat_settings.enable_tts) system_prompt = data.get("system_prompt") - + logger.info(f"Processing request {request_id}: {query[:50]}...") - + with create_span("chat.process") as span: if span: span.set_attribute("request.id", request_id) span.set_attribute("query.length", len(query)) - + # 1. Generate query embedding embedding = await self._get_embedding(query) - + # 2. Search Milvus for context documents = await self._search_context(embedding, collection) - + # 3. Rerank documents reranked = await self._rerank_documents(query, documents) - + # 4. Build context from top documents context = self._build_context(reranked) - + # 5. Generate LLM response - response_text = await self._generate_response( - query, context, system_prompt - ) - + response_text = await self._generate_response(query, context, system_prompt) + # 6. Optionally synthesize speech audio_b64 = None if enable_tts and self.tts: audio_b64 = await self._synthesize_speech(response_text) - + # Build response result = { "request_id": request_id, "response": response_text, } - + if self.chat_settings.include_sources: result["sources"] = [ - {"text": d["document"][:200], "score": d["score"]} - for d in reranked[:3] + {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3] ] - + if audio_b64: result["audio"] = audio_b64 - + logger.info(f"Completed request {request_id}") - + # Publish to response subject response_subject = f"ai.chat.response.{request_id}" await self.nats.publish(response_subject, result) - + return result - + async def _get_embedding(self, text: str) -> list[float]: """Generate embedding for query text.""" with create_span("chat.embedding"): return await self.embeddings.embed_single(text) - - async def _search_context( - self, embedding: list[float], collection: str - ) -> list[dict]: + + async def _search_context(self, embedding: list[float], collection: str) -> list[dict]: """Search Milvus for relevant documents.""" with create_span("chat.search"): return await self.milvus.search_with_texts( @@ -186,17 +182,15 @@ class ChatHandler(Handler): text_field="text", metadata_fields=["source", "title"], ) - - async def _rerank_documents( - self, query: str, documents: list[dict] - ) -> list[dict]: + + async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]: """Rerank documents by relevance to query.""" with create_span("chat.rerank"): texts = [d.get("text", "") for d in documents] return await self.reranker.rerank( query, texts, top_k=self.chat_settings.rag_rerank_top_k ) - + def _build_context(self, documents: list[dict]) -> str: """Build context string from ranked documents.""" context_parts = [] @@ -204,7 +198,7 @@ class ChatHandler(Handler): text = doc.get("document", "") context_parts.append(f"[{i}] {text}") return "\n\n".join(context_parts) - + async def _generate_response( self, query: str, @@ -218,7 +212,7 @@ class ChatHandler(Handler): context=context, system_prompt=system_prompt, ) - + async def _synthesize_speech(self, text: str) -> str: """Synthesize speech and return base64 encoded audio.""" with create_span("chat.tts"): diff --git a/pyproject.toml b/pyproject.toml index 610cf74..d1225d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,9 @@ dev = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["."] only-include = ["chat_handler.py"] diff --git a/tests/conftest.py b/tests/conftest.py index 3c8bef8..7a29ae7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ """ Pytest configuration and fixtures for chat-handler tests. """ + import asyncio import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest diff --git a/tests/test_chat_handler.py b/tests/test_chat_handler.py index e435f7f..d857283 100644 --- a/tests/test_chat_handler.py +++ b/tests/test_chat_handler.py @@ -1,20 +1,20 @@ """ Unit tests for ChatHandler. """ -import base64 + import pytest -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch from chat_handler import ChatHandler, ChatSettings class TestChatSettings: """Tests for ChatSettings configuration.""" - + def test_default_settings(self): """Test default settings values.""" settings = ChatSettings() - + assert settings.service_name == "chat-handler" assert settings.rag_top_k == 10 assert settings.rag_rerank_top_k == 5 @@ -22,7 +22,7 @@ class TestChatSettings: assert settings.include_sources is True assert settings.enable_tts is False assert settings.tts_language == "en" - + def test_custom_settings(self): """Test custom settings.""" settings = ChatSettings( @@ -30,7 +30,7 @@ class TestChatSettings: rag_collection="custom_docs", enable_tts=True, ) - + assert settings.rag_top_k == 20 assert settings.rag_collection == "custom_docs" assert settings.enable_tts is True @@ -38,18 +38,19 @@ class TestChatSettings: class TestChatHandler: """Tests for ChatHandler.""" - + @pytest.fixture def handler(self): """Create handler with mocked clients.""" - with patch("chat_handler.EmbeddingsClient"), \ - patch("chat_handler.RerankerClient"), \ - patch("chat_handler.LLMClient"), \ - patch("chat_handler.TTSClient"), \ - patch("chat_handler.MilvusClient"): - + with ( + patch("chat_handler.EmbeddingsClient"), + patch("chat_handler.RerankerClient"), + patch("chat_handler.LLMClient"), + patch("chat_handler.TTSClient"), + patch("chat_handler.MilvusClient"), + ): handler = ChatHandler() - + # Setup mock clients handler.embeddings = AsyncMock() handler.reranker = AsyncMock() @@ -57,21 +58,22 @@ class TestChatHandler: handler.milvus = AsyncMock() handler.tts = None # TTS disabled by default handler.nats = AsyncMock() - + yield handler - + @pytest.fixture def handler_with_tts(self): """Create handler with TTS enabled.""" - with patch("chat_handler.EmbeddingsClient"), \ - patch("chat_handler.RerankerClient"), \ - patch("chat_handler.LLMClient"), \ - patch("chat_handler.TTSClient"), \ - patch("chat_handler.MilvusClient"): - + with ( + patch("chat_handler.EmbeddingsClient"), + patch("chat_handler.RerankerClient"), + patch("chat_handler.LLMClient"), + patch("chat_handler.TTSClient"), + patch("chat_handler.MilvusClient"), + ): handler = ChatHandler() handler.chat_settings.enable_tts = True - + # Setup mock clients handler.embeddings = AsyncMock() handler.reranker = AsyncMock() @@ -79,15 +81,15 @@ class TestChatHandler: handler.milvus = AsyncMock() handler.tts = AsyncMock() handler.nats = AsyncMock() - + yield handler - + def test_init(self, handler): """Test handler initialization.""" assert handler.subject == "ai.chat.request" assert handler.queue_group == "chat-handlers" assert handler.chat_settings.service_name == "chat-handler" - + @pytest.mark.asyncio async def test_handle_message_success( self, @@ -104,22 +106,22 @@ class TestChatHandler: handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Machine learning is a subset of AI that..." - + # Execute result = await handler.handle_message(mock_nats_message, mock_chat_request) - + # Verify assert result["request_id"] == "test-request-123" assert "response" in result assert result["response"] == "Machine learning is a subset of AI that..." assert "sources" in result # include_sources is True by default - + # Verify pipeline was called handler.embeddings.embed_single.assert_called_once() handler.milvus.search_with_texts.assert_called_once() handler.reranker.rerank.assert_called_once() handler.llm.generate.assert_called_once() - + @pytest.mark.asyncio async def test_handle_message_without_sources( self, @@ -132,16 +134,16 @@ class TestChatHandler: ): """Test response without sources when disabled.""" handler.chat_settings.include_sources = False - + handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Response text" - + result = await handler.handle_message(mock_nats_message, mock_chat_request) - + assert "sources" not in result - + @pytest.mark.asyncio async def test_handle_message_with_tts( self, @@ -154,18 +156,18 @@ class TestChatHandler: ): """Test response with TTS audio.""" handler = handler_with_tts - + handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "AI response" handler.tts.synthesize.return_value = b"audio_bytes" - + result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts) - + assert "audio" in result handler.tts.synthesize.assert_called_once() - + @pytest.mark.asyncio async def test_handle_message_with_custom_system_prompt( self, @@ -181,64 +183,65 @@ class TestChatHandler: "query": "Hello", "system_prompt": "You are a pirate. Respond like one.", } - + handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Ahoy!" - + await handler.handle_message(mock_nats_message, request) - + # Verify system_prompt was passed to LLM handler.llm.generate.assert_called_once() call_kwargs = handler.llm.generate.call_args.kwargs assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one." - + def test_build_context(self, handler): """Test context building with numbered sources.""" documents = [ {"document": "First doc content"}, {"document": "Second doc content"}, ] - + context = handler._build_context(documents) - + assert "[1]" in context assert "[2]" in context assert "First doc content" in context assert "Second doc content" in context - + @pytest.mark.asyncio async def test_setup_initializes_clients(self): """Test that setup initializes all required clients.""" - with patch("chat_handler.EmbeddingsClient") as emb_cls, \ - patch("chat_handler.RerankerClient") as rer_cls, \ - patch("chat_handler.LLMClient") as llm_cls, \ - patch("chat_handler.TTSClient") as tts_cls, \ - patch("chat_handler.MilvusClient") as mil_cls: - + with ( + patch("chat_handler.EmbeddingsClient") as emb_cls, + patch("chat_handler.RerankerClient") as rer_cls, + patch("chat_handler.LLMClient") as llm_cls, + patch("chat_handler.TTSClient") as tts_cls, + patch("chat_handler.MilvusClient") as mil_cls, + ): mil_cls.return_value.connect = AsyncMock() - + handler = ChatHandler() await handler.setup() - + emb_cls.assert_called_once() rer_cls.assert_called_once() llm_cls.assert_called_once() mil_cls.assert_called_once() # TTS should not be initialized when disabled tts_cls.assert_not_called() - + @pytest.mark.asyncio async def test_teardown_closes_clients(self, handler): """Test that teardown closes all clients.""" await handler.teardown() - + handler.embeddings.close.assert_called_once() handler.reranker.close.assert_called_once() handler.llm.close.assert_called_once() handler.milvus.close.assert_called_once() - + @pytest.mark.asyncio async def test_publishes_to_response_subject( self, @@ -254,9 +257,9 @@ class TestChatHandler: handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Response" - + await handler.handle_message(mock_nats_message, mock_chat_request) - + handler.nats.publish.assert_called_once() call_args = handler.nats.publish.call_args assert "ai.chat.response.test-request-123" in str(call_args)