fix: ruff formatting and allow-direct-references for handler-base dep

This commit is contained in:
2026-02-02 08:44:34 -05:00
parent 09e8135f93
commit 9d26facfaa
4 changed files with 107 additions and 106 deletions

View File

@@ -11,6 +11,7 @@ Text-based chat pipeline using handler-base:
6. Optionally synthesize speech with XTTS 6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}" 7. Publish result to NATS "ai.chat.response.{request_id}"
""" """
import base64 import base64
import logging import logging
from typing import Any, Optional from typing import Any, Optional
@@ -32,14 +33,14 @@ logger = logging.getLogger("chat-handler")
class ChatSettings(Settings): class ChatSettings(Settings):
"""Chat handler specific settings.""" """Chat handler specific settings."""
service_name: str = "chat-handler" service_name: str = "chat-handler"
# RAG settings # RAG settings
rag_top_k: int = 10 rag_top_k: int = 10
rag_rerank_top_k: int = 5 rag_rerank_top_k: int = 5
rag_collection: str = "documents" rag_collection: str = "documents"
# Response settings # Response settings
include_sources: bool = True include_sources: bool = True
enable_tts: bool = False enable_tts: bool = False
@@ -49,7 +50,7 @@ class ChatSettings(Settings):
class ChatHandler(Handler): class ChatHandler(Handler):
""" """
Chat request handler with RAG pipeline. Chat request handler with RAG pipeline.
Request format: Request format:
{ {
"request_id": "uuid", "request_id": "uuid",
@@ -58,7 +59,7 @@ class ChatHandler(Handler):
"enable_tts": false, "enable_tts": false,
"system_prompt": "optional custom system prompt" "system_prompt": "optional custom system prompt"
} }
Response format: Response format:
{ {
"request_id": "uuid", "request_id": "uuid",
@@ -67,7 +68,7 @@ class ChatHandler(Handler):
"audio": "base64 encoded audio (if tts enabled)" "audio": "base64 encoded audio (if tts enabled)"
} }
""" """
def __init__(self): def __init__(self):
self.chat_settings = ChatSettings() self.chat_settings = ChatSettings()
super().__init__( super().__init__(
@@ -75,41 +76,41 @@ class ChatHandler(Handler):
settings=self.chat_settings, settings=self.chat_settings,
queue_group="chat-handlers", queue_group="chat-handlers",
) )
async def setup(self) -> None: async def setup(self) -> None:
"""Initialize service clients.""" """Initialize service clients."""
logger.info("Initializing service clients...") logger.info("Initializing service clients...")
self.embeddings = EmbeddingsClient(self.chat_settings) self.embeddings = EmbeddingsClient(self.chat_settings)
self.reranker = RerankerClient(self.chat_settings) self.reranker = RerankerClient(self.chat_settings)
self.llm = LLMClient(self.chat_settings) self.llm = LLMClient(self.chat_settings)
self.milvus = MilvusClient(self.chat_settings) self.milvus = MilvusClient(self.chat_settings)
# TTS is optional # TTS is optional
if self.chat_settings.enable_tts: if self.chat_settings.enable_tts:
self.tts = TTSClient(self.chat_settings) self.tts = TTSClient(self.chat_settings)
else: else:
self.tts = None self.tts = None
# Connect to Milvus # Connect to Milvus
await self.milvus.connect(self.chat_settings.rag_collection) await self.milvus.connect(self.chat_settings.rag_collection)
logger.info("Service clients initialized") logger.info("Service clients initialized")
async def teardown(self) -> None: async def teardown(self) -> None:
"""Clean up service clients.""" """Clean up service clients."""
logger.info("Closing service clients...") logger.info("Closing service clients...")
await self.embeddings.close() await self.embeddings.close()
await self.reranker.close() await self.reranker.close()
await self.llm.close() await self.llm.close()
await self.milvus.close() await self.milvus.close()
if self.tts: if self.tts:
await self.tts.close() await self.tts.close()
logger.info("Service clients closed") logger.info("Service clients closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle incoming chat request.""" """Handle incoming chat request."""
request_id = data.get("request_id", "unknown") request_id = data.get("request_id", "unknown")
@@ -117,67 +118,62 @@ class ChatHandler(Handler):
collection = data.get("collection", self.chat_settings.rag_collection) collection = data.get("collection", self.chat_settings.rag_collection)
enable_tts = data.get("enable_tts", self.chat_settings.enable_tts) enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
system_prompt = data.get("system_prompt") system_prompt = data.get("system_prompt")
logger.info(f"Processing request {request_id}: {query[:50]}...") logger.info(f"Processing request {request_id}: {query[:50]}...")
with create_span("chat.process") as span: with create_span("chat.process") as span:
if span: if span:
span.set_attribute("request.id", request_id) span.set_attribute("request.id", request_id)
span.set_attribute("query.length", len(query)) span.set_attribute("query.length", len(query))
# 1. Generate query embedding # 1. Generate query embedding
embedding = await self._get_embedding(query) embedding = await self._get_embedding(query)
# 2. Search Milvus for context # 2. Search Milvus for context
documents = await self._search_context(embedding, collection) documents = await self._search_context(embedding, collection)
# 3. Rerank documents # 3. Rerank documents
reranked = await self._rerank_documents(query, documents) reranked = await self._rerank_documents(query, documents)
# 4. Build context from top documents # 4. Build context from top documents
context = self._build_context(reranked) context = self._build_context(reranked)
# 5. Generate LLM response # 5. Generate LLM response
response_text = await self._generate_response( response_text = await self._generate_response(query, context, system_prompt)
query, context, system_prompt
)
# 6. Optionally synthesize speech # 6. Optionally synthesize speech
audio_b64 = None audio_b64 = None
if enable_tts and self.tts: if enable_tts and self.tts:
audio_b64 = await self._synthesize_speech(response_text) audio_b64 = await self._synthesize_speech(response_text)
# Build response # Build response
result = { result = {
"request_id": request_id, "request_id": request_id,
"response": response_text, "response": response_text,
} }
if self.chat_settings.include_sources: if self.chat_settings.include_sources:
result["sources"] = [ result["sources"] = [
{"text": d["document"][:200], "score": d["score"]} {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
for d in reranked[:3]
] ]
if audio_b64: if audio_b64:
result["audio"] = audio_b64 result["audio"] = audio_b64
logger.info(f"Completed request {request_id}") logger.info(f"Completed request {request_id}")
# Publish to response subject # Publish to response subject
response_subject = f"ai.chat.response.{request_id}" response_subject = f"ai.chat.response.{request_id}"
await self.nats.publish(response_subject, result) await self.nats.publish(response_subject, result)
return result return result
async def _get_embedding(self, text: str) -> list[float]: async def _get_embedding(self, text: str) -> list[float]:
"""Generate embedding for query text.""" """Generate embedding for query text."""
with create_span("chat.embedding"): with create_span("chat.embedding"):
return await self.embeddings.embed_single(text) return await self.embeddings.embed_single(text)
async def _search_context( async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
self, embedding: list[float], collection: str
) -> list[dict]:
"""Search Milvus for relevant documents.""" """Search Milvus for relevant documents."""
with create_span("chat.search"): with create_span("chat.search"):
return await self.milvus.search_with_texts( return await self.milvus.search_with_texts(
@@ -186,17 +182,15 @@ class ChatHandler(Handler):
text_field="text", text_field="text",
metadata_fields=["source", "title"], metadata_fields=["source", "title"],
) )
async def _rerank_documents( async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
self, query: str, documents: list[dict]
) -> list[dict]:
"""Rerank documents by relevance to query.""" """Rerank documents by relevance to query."""
with create_span("chat.rerank"): with create_span("chat.rerank"):
texts = [d.get("text", "") for d in documents] texts = [d.get("text", "") for d in documents]
return await self.reranker.rerank( return await self.reranker.rerank(
query, texts, top_k=self.chat_settings.rag_rerank_top_k query, texts, top_k=self.chat_settings.rag_rerank_top_k
) )
def _build_context(self, documents: list[dict]) -> str: def _build_context(self, documents: list[dict]) -> str:
"""Build context string from ranked documents.""" """Build context string from ranked documents."""
context_parts = [] context_parts = []
@@ -204,7 +198,7 @@ class ChatHandler(Handler):
text = doc.get("document", "") text = doc.get("document", "")
context_parts.append(f"[{i}] {text}") context_parts.append(f"[{i}] {text}")
return "\n\n".join(context_parts) return "\n\n".join(context_parts)
async def _generate_response( async def _generate_response(
self, self,
query: str, query: str,
@@ -218,7 +212,7 @@ class ChatHandler(Handler):
context=context, context=context,
system_prompt=system_prompt, system_prompt=system_prompt,
) )
async def _synthesize_speech(self, text: str) -> str: async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64 encoded audio.""" """Synthesize speech and return base64 encoded audio."""
with create_span("chat.tts"): with create_span("chat.tts"):

View File

@@ -22,6 +22,9 @@ dev = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["."] packages = ["."]
only-include = ["chat_handler.py"] only-include = ["chat_handler.py"]

View File

@@ -1,9 +1,10 @@
""" """
Pytest configuration and fixtures for chat-handler tests. Pytest configuration and fixtures for chat-handler tests.
""" """
import asyncio import asyncio
import os import os
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import MagicMock
import pytest import pytest

View File

@@ -1,20 +1,20 @@
""" """
Unit tests for ChatHandler. Unit tests for ChatHandler.
""" """
import base64
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
from chat_handler import ChatHandler, ChatSettings from chat_handler import ChatHandler, ChatSettings
class TestChatSettings: class TestChatSettings:
"""Tests for ChatSettings configuration.""" """Tests for ChatSettings configuration."""
def test_default_settings(self): def test_default_settings(self):
"""Test default settings values.""" """Test default settings values."""
settings = ChatSettings() settings = ChatSettings()
assert settings.service_name == "chat-handler" assert settings.service_name == "chat-handler"
assert settings.rag_top_k == 10 assert settings.rag_top_k == 10
assert settings.rag_rerank_top_k == 5 assert settings.rag_rerank_top_k == 5
@@ -22,7 +22,7 @@ class TestChatSettings:
assert settings.include_sources is True assert settings.include_sources is True
assert settings.enable_tts is False assert settings.enable_tts is False
assert settings.tts_language == "en" assert settings.tts_language == "en"
def test_custom_settings(self): def test_custom_settings(self):
"""Test custom settings.""" """Test custom settings."""
settings = ChatSettings( settings = ChatSettings(
@@ -30,7 +30,7 @@ class TestChatSettings:
rag_collection="custom_docs", rag_collection="custom_docs",
enable_tts=True, enable_tts=True,
) )
assert settings.rag_top_k == 20 assert settings.rag_top_k == 20
assert settings.rag_collection == "custom_docs" assert settings.rag_collection == "custom_docs"
assert settings.enable_tts is True assert settings.enable_tts is True
@@ -38,18 +38,19 @@ class TestChatSettings:
class TestChatHandler: class TestChatHandler:
"""Tests for ChatHandler.""" """Tests for ChatHandler."""
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked clients.""" """Create handler with mocked clients."""
with patch("chat_handler.EmbeddingsClient"), \ with (
patch("chat_handler.RerankerClient"), \ patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.LLMClient"), \ patch("chat_handler.RerankerClient"),
patch("chat_handler.TTSClient"), \ patch("chat_handler.LLMClient"),
patch("chat_handler.MilvusClient"): patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler() handler = ChatHandler()
# Setup mock clients # Setup mock clients
handler.embeddings = AsyncMock() handler.embeddings = AsyncMock()
handler.reranker = AsyncMock() handler.reranker = AsyncMock()
@@ -57,21 +58,22 @@ class TestChatHandler:
handler.milvus = AsyncMock() handler.milvus = AsyncMock()
handler.tts = None # TTS disabled by default handler.tts = None # TTS disabled by default
handler.nats = AsyncMock() handler.nats = AsyncMock()
yield handler yield handler
@pytest.fixture @pytest.fixture
def handler_with_tts(self): def handler_with_tts(self):
"""Create handler with TTS enabled.""" """Create handler with TTS enabled."""
with patch("chat_handler.EmbeddingsClient"), \ with (
patch("chat_handler.RerankerClient"), \ patch("chat_handler.EmbeddingsClient"),
patch("chat_handler.LLMClient"), \ patch("chat_handler.RerankerClient"),
patch("chat_handler.TTSClient"), \ patch("chat_handler.LLMClient"),
patch("chat_handler.MilvusClient"): patch("chat_handler.TTSClient"),
patch("chat_handler.MilvusClient"),
):
handler = ChatHandler() handler = ChatHandler()
handler.chat_settings.enable_tts = True handler.chat_settings.enable_tts = True
# Setup mock clients # Setup mock clients
handler.embeddings = AsyncMock() handler.embeddings = AsyncMock()
handler.reranker = AsyncMock() handler.reranker = AsyncMock()
@@ -79,15 +81,15 @@ class TestChatHandler:
handler.milvus = AsyncMock() handler.milvus = AsyncMock()
handler.tts = AsyncMock() handler.tts = AsyncMock()
handler.nats = AsyncMock() handler.nats = AsyncMock()
yield handler yield handler
def test_init(self, handler): def test_init(self, handler):
"""Test handler initialization.""" """Test handler initialization."""
assert handler.subject == "ai.chat.request" assert handler.subject == "ai.chat.request"
assert handler.queue_group == "chat-handlers" assert handler.queue_group == "chat-handlers"
assert handler.chat_settings.service_name == "chat-handler" assert handler.chat_settings.service_name == "chat-handler"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_success( async def test_handle_message_success(
self, self,
@@ -104,22 +106,22 @@ class TestChatHandler:
handler.milvus.search_with_texts.return_value = sample_documents handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Machine learning is a subset of AI that..." handler.llm.generate.return_value = "Machine learning is a subset of AI that..."
# Execute # Execute
result = await handler.handle_message(mock_nats_message, mock_chat_request) result = await handler.handle_message(mock_nats_message, mock_chat_request)
# Verify # Verify
assert result["request_id"] == "test-request-123" assert result["request_id"] == "test-request-123"
assert "response" in result assert "response" in result
assert result["response"] == "Machine learning is a subset of AI that..." assert result["response"] == "Machine learning is a subset of AI that..."
assert "sources" in result # include_sources is True by default assert "sources" in result # include_sources is True by default
# Verify pipeline was called # Verify pipeline was called
handler.embeddings.embed_single.assert_called_once() handler.embeddings.embed_single.assert_called_once()
handler.milvus.search_with_texts.assert_called_once() handler.milvus.search_with_texts.assert_called_once()
handler.reranker.rerank.assert_called_once() handler.reranker.rerank.assert_called_once()
handler.llm.generate.assert_called_once() handler.llm.generate.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_without_sources( async def test_handle_message_without_sources(
self, self,
@@ -132,16 +134,16 @@ class TestChatHandler:
): ):
"""Test response without sources when disabled.""" """Test response without sources when disabled."""
handler.chat_settings.include_sources = False handler.chat_settings.include_sources = False
handler.embeddings.embed_single.return_value = sample_embedding handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Response text" handler.llm.generate.return_value = "Response text"
result = await handler.handle_message(mock_nats_message, mock_chat_request) result = await handler.handle_message(mock_nats_message, mock_chat_request)
assert "sources" not in result assert "sources" not in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_with_tts( async def test_handle_message_with_tts(
self, self,
@@ -154,18 +156,18 @@ class TestChatHandler:
): ):
"""Test response with TTS audio.""" """Test response with TTS audio."""
handler = handler_with_tts handler = handler_with_tts
handler.embeddings.embed_single.return_value = sample_embedding handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "AI response" handler.llm.generate.return_value = "AI response"
handler.tts.synthesize.return_value = b"audio_bytes" handler.tts.synthesize.return_value = b"audio_bytes"
result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts) result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts)
assert "audio" in result assert "audio" in result
handler.tts.synthesize.assert_called_once() handler.tts.synthesize.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_with_custom_system_prompt( async def test_handle_message_with_custom_system_prompt(
self, self,
@@ -181,64 +183,65 @@ class TestChatHandler:
"query": "Hello", "query": "Hello",
"system_prompt": "You are a pirate. Respond like one.", "system_prompt": "You are a pirate. Respond like one.",
} }
handler.embeddings.embed_single.return_value = sample_embedding handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Ahoy!" handler.llm.generate.return_value = "Ahoy!"
await handler.handle_message(mock_nats_message, request) await handler.handle_message(mock_nats_message, request)
# Verify system_prompt was passed to LLM # Verify system_prompt was passed to LLM
handler.llm.generate.assert_called_once() handler.llm.generate.assert_called_once()
call_kwargs = handler.llm.generate.call_args.kwargs call_kwargs = handler.llm.generate.call_args.kwargs
assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one." assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one."
def test_build_context(self, handler): def test_build_context(self, handler):
"""Test context building with numbered sources.""" """Test context building with numbered sources."""
documents = [ documents = [
{"document": "First doc content"}, {"document": "First doc content"},
{"document": "Second doc content"}, {"document": "Second doc content"},
] ]
context = handler._build_context(documents) context = handler._build_context(documents)
assert "[1]" in context assert "[1]" in context
assert "[2]" in context assert "[2]" in context
assert "First doc content" in context assert "First doc content" in context
assert "Second doc content" in context assert "Second doc content" in context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_initializes_clients(self): async def test_setup_initializes_clients(self):
"""Test that setup initializes all required clients.""" """Test that setup initializes all required clients."""
with patch("chat_handler.EmbeddingsClient") as emb_cls, \ with (
patch("chat_handler.RerankerClient") as rer_cls, \ patch("chat_handler.EmbeddingsClient") as emb_cls,
patch("chat_handler.LLMClient") as llm_cls, \ patch("chat_handler.RerankerClient") as rer_cls,
patch("chat_handler.TTSClient") as tts_cls, \ patch("chat_handler.LLMClient") as llm_cls,
patch("chat_handler.MilvusClient") as mil_cls: patch("chat_handler.TTSClient") as tts_cls,
patch("chat_handler.MilvusClient") as mil_cls,
):
mil_cls.return_value.connect = AsyncMock() mil_cls.return_value.connect = AsyncMock()
handler = ChatHandler() handler = ChatHandler()
await handler.setup() await handler.setup()
emb_cls.assert_called_once() emb_cls.assert_called_once()
rer_cls.assert_called_once() rer_cls.assert_called_once()
llm_cls.assert_called_once() llm_cls.assert_called_once()
mil_cls.assert_called_once() mil_cls.assert_called_once()
# TTS should not be initialized when disabled # TTS should not be initialized when disabled
tts_cls.assert_not_called() tts_cls.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_teardown_closes_clients(self, handler): async def test_teardown_closes_clients(self, handler):
"""Test that teardown closes all clients.""" """Test that teardown closes all clients."""
await handler.teardown() await handler.teardown()
handler.embeddings.close.assert_called_once() handler.embeddings.close.assert_called_once()
handler.reranker.close.assert_called_once() handler.reranker.close.assert_called_once()
handler.llm.close.assert_called_once() handler.llm.close.assert_called_once()
handler.milvus.close.assert_called_once() handler.milvus.close.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_publishes_to_response_subject( async def test_publishes_to_response_subject(
self, self,
@@ -254,9 +257,9 @@ class TestChatHandler:
handler.milvus.search_with_texts.return_value = sample_documents handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Response" handler.llm.generate.return_value = "Response"
await handler.handle_message(mock_nats_message, mock_chat_request) await handler.handle_message(mock_nats_message, mock_chat_request)
handler.nats.publish.assert_called_once() handler.nats.publish.assert_called_once()
call_args = handler.nats.publish.call_args call_args = handler.nats.publish.call_args
assert "ai.chat.response.test-request-123" in str(call_args) assert "ai.chat.response.test-request-123" in str(call_args)