refactor: consolidate to handler-base, migrate to pyproject.toml, add tests

This commit is contained in:
2026-02-02 07:11:02 -05:00
parent 6ef42b3d2c
commit bed9fa4297
10 changed files with 585 additions and 1132 deletions

View File

@@ -1,29 +1,9 @@
FROM python:3.13-slim # Chat Handler - Using handler-base
ARG BASE_TAG=latest
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app WORKDIR /app
# Install uv for fast, reliable package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt
# Copy application code
COPY chat_handler.py . COPY chat_handler.py .
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "print('healthy')" || exit 1
# Run the application
CMD ["python", "chat_handler.py"] CMD ["python", "chat_handler.py"]

View File

@@ -1,11 +0,0 @@
# Chat Handler v2 - Using handler-base
ARG BASE_TAG=local
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app
# Copy only the handler code (dependencies are in base image)
COPY chat_handler_v2.py ./chat_handler.py
# Run the handler
CMD ["python", "chat_handler.py"]

View File

@@ -4,19 +4,10 @@ Text-based chat pipeline for the DaviesTechLabs AI/ML platform.
## Overview ## Overview
A NATS-based service that handles chat completion requests with RAG (Retrieval Augmented Generation). A NATS-based service that handles chat completion requests with RAG (Retrieval Augmented Generation). It uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks.
**Pipeline:** Query → Embeddings → Milvus → Rerank → LLM → (optional TTS) **Pipeline:** Query → Embeddings → Milvus → Rerank → LLM → (optional TTS)
## Versions
| File | Description |
|------|-------------|
| `chat_handler.py` | Standalone implementation (v1) |
| `chat_handler_v2.py` | Uses handler-base library (recommended) |
| `Dockerfile` | Standalone image |
| `Dockerfile.v2` | Handler-base image |
## Architecture ## Architecture
``` ```
@@ -88,19 +79,10 @@ NATS (ai.chat.request)
## Building ## Building
```bash ```bash
# Standalone image (v1) docker build -t chat-handler:latest .
docker build -f Dockerfile -t chat-handler:latest .
# Handler-base image (v2 - recommended) # With specific handler-base tag
docker build -f Dockerfile.v2 -t chat-handler:v2 . docker build --build-arg BASE_TAG=latest -t chat-handler:latest .
```
## Dependencies
The v2 handler depends on [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base):
```bash
pip install git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git
``` ```
## Related ## Related

File diff suppressed because it is too large Load Diff

View File

@@ -1,233 +0,0 @@
#!/usr/bin/env python3
"""
Chat Handler Service (Refactored)
Text-based chat pipeline using handler-base:
1. Listen for text on NATS subject "ai.chat.request"
2. Generate embeddings for RAG
3. Retrieve context from Milvus
4. Rerank with BGE reranker
5. Generate response with vLLM
6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base import Handler, Settings
from handler_base.clients import (
EmbeddingsClient,
RerankerClient,
LLMClient,
TTSClient,
MilvusClient,
)
from handler_base.telemetry import create_span
logger = logging.getLogger("chat-handler")
class ChatSettings(Settings):
"""Chat handler specific settings."""
service_name: str = "chat-handler"
# RAG settings
rag_top_k: int = 10
rag_rerank_top_k: int = 5
rag_collection: str = "documents"
# Response settings
include_sources: bool = True
enable_tts: bool = False
tts_language: str = "en"
class ChatHandler(Handler):
"""
Chat request handler with RAG pipeline.
Request format:
{
"request_id": "uuid",
"query": "user question",
"collection": "optional collection name",
"enable_tts": false,
"system_prompt": "optional custom system prompt"
}
Response format:
{
"request_id": "uuid",
"response": "generated response",
"sources": [{"text": "...", "score": 0.95}],
"audio": "base64 encoded audio (if tts enabled)"
}
"""
def __init__(self):
self.chat_settings = ChatSettings()
super().__init__(
subject="ai.chat.request",
settings=self.chat_settings,
queue_group="chat-handlers",
)
async def setup(self) -> None:
"""Initialize service clients."""
logger.info("Initializing service clients...")
self.embeddings = EmbeddingsClient(self.chat_settings)
self.reranker = RerankerClient(self.chat_settings)
self.llm = LLMClient(self.chat_settings)
self.milvus = MilvusClient(self.chat_settings)
# TTS is optional
if self.chat_settings.enable_tts:
self.tts = TTSClient(self.chat_settings)
else:
self.tts = None
# Connect to Milvus
await self.milvus.connect(self.chat_settings.rag_collection)
logger.info("Service clients initialized")
async def teardown(self) -> None:
"""Clean up service clients."""
logger.info("Closing service clients...")
await self.embeddings.close()
await self.reranker.close()
await self.llm.close()
await self.milvus.close()
if self.tts:
await self.tts.close()
logger.info("Service clients closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle incoming chat request."""
request_id = data.get("request_id", "unknown")
query = data.get("query", "")
collection = data.get("collection", self.chat_settings.rag_collection)
enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
system_prompt = data.get("system_prompt")
logger.info(f"Processing request {request_id}: {query[:50]}...")
with create_span("chat.process") as span:
if span:
span.set_attribute("request.id", request_id)
span.set_attribute("query.length", len(query))
# 1. Generate query embedding
embedding = await self._get_embedding(query)
# 2. Search Milvus for context
documents = await self._search_context(embedding, collection)
# 3. Rerank documents
reranked = await self._rerank_documents(query, documents)
# 4. Build context from top documents
context = self._build_context(reranked)
# 5. Generate LLM response
response_text = await self._generate_response(
query, context, system_prompt
)
# 6. Optionally synthesize speech
audio_b64 = None
if enable_tts and self.tts:
audio_b64 = await self._synthesize_speech(response_text)
# Build response
result = {
"request_id": request_id,
"response": response_text,
}
if self.chat_settings.include_sources:
result["sources"] = [
{"text": d["document"][:200], "score": d["score"]}
for d in reranked[:3]
]
if audio_b64:
result["audio"] = audio_b64
logger.info(f"Completed request {request_id}")
# Publish to response subject
response_subject = f"ai.chat.response.{request_id}"
await self.nats.publish(response_subject, result)
return result
async def _get_embedding(self, text: str) -> list[float]:
"""Generate embedding for query text."""
with create_span("chat.embedding"):
return await self.embeddings.embed_single(text)
async def _search_context(
self, embedding: list[float], collection: str
) -> list[dict]:
"""Search Milvus for relevant documents."""
with create_span("chat.search"):
return await self.milvus.search_with_texts(
embedding,
limit=self.chat_settings.rag_top_k,
text_field="text",
metadata_fields=["source", "title"],
)
async def _rerank_documents(
self, query: str, documents: list[dict]
) -> list[dict]:
"""Rerank documents by relevance to query."""
with create_span("chat.rerank"):
texts = [d.get("text", "") for d in documents]
return await self.reranker.rerank(
query, texts, top_k=self.chat_settings.rag_rerank_top_k
)
def _build_context(self, documents: list[dict]) -> str:
"""Build context string from ranked documents."""
context_parts = []
for i, doc in enumerate(documents, 1):
text = doc.get("document", "")
context_parts.append(f"[{i}] {text}")
return "\n\n".join(context_parts)
async def _generate_response(
self,
query: str,
context: str,
system_prompt: Optional[str] = None,
) -> str:
"""Generate LLM response with context."""
with create_span("chat.generate"):
return await self.llm.generate(
query,
context=context,
system_prompt=system_prompt,
)
async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64 encoded audio."""
with create_span("chat.tts"):
audio_bytes = await self.tts.synthesize(
text,
language=self.chat_settings.tts_language,
)
return base64.b64encode(audio_bytes).decode()
if __name__ == "__main__":
ChatHandler().run()

40
pyproject.toml Normal file
View File

@@ -0,0 +1,40 @@
[project]
name = "chat-handler"
version = "1.0.0"
description = "Text chat pipeline with RAG - Query → Embeddings → Milvus → Rerank → LLM"
readme = "README.md"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Davies Tech Labs" }]
dependencies = [
"handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"ruff>=0.1.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
only-include = ["chat_handler.py"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = ["ignore::DeprecationWarning"]

View File

@@ -1,15 +0,0 @@
nats-py
httpx
pymilvus
numpy
msgpack
redis>=5.0.0
opentelemetry-api
opentelemetry-sdk
opentelemetry-exporter-otlp-proto-grpc
opentelemetry-exporter-otlp-proto-http
opentelemetry-instrumentation-httpx
opentelemetry-instrumentation-logging
# MLflow for inference metrics tracking
mlflow>=2.10.0
psycopg2-binary>=2.9.0

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Chat Handler Tests

81
tests/conftest.py Normal file
View File

@@ -0,0 +1,81 @@
"""
Pytest configuration and fixtures for chat-handler tests.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Set test environment variables before importing
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
os.environ.setdefault("MILVUS_HOST", "localhost")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def sample_embedding():
"""Sample embedding vector."""
return [0.1] * 1024
@pytest.fixture
def sample_documents():
"""Sample search results."""
return [
{"text": "Machine learning is a subset of AI.", "score": 0.95},
{"text": "Deep learning uses neural networks.", "score": 0.90},
{"text": "AI enables intelligent automation.", "score": 0.85},
]
@pytest.fixture
def sample_reranked():
"""Sample reranked results."""
return [
{"document": "Machine learning is a subset of AI.", "score": 0.98},
{"document": "Deep learning uses neural networks.", "score": 0.85},
]
@pytest.fixture
def mock_nats_message():
"""Create a mock NATS message."""
msg = MagicMock()
msg.subject = "ai.chat.request"
msg.reply = "ai.chat.response.test-123"
return msg
@pytest.fixture
def mock_chat_request():
"""Sample chat request payload."""
return {
"request_id": "test-request-123",
"query": "What is machine learning?",
"collection": "test_collection",
"enable_tts": False,
"system_prompt": None,
}
@pytest.fixture
def mock_chat_request_with_tts():
"""Sample chat request with TTS enabled."""
return {
"request_id": "test-request-456",
"query": "Tell me about AI",
"collection": "documents",
"enable_tts": True,
"system_prompt": "You are a helpful assistant.",
}

262
tests/test_chat_handler.py Normal file
View File

@@ -0,0 +1,262 @@
"""
Unit tests for ChatHandler.
"""
import base64
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from chat_handler import ChatHandler, ChatSettings
class TestChatSettings:
"""Tests for ChatSettings configuration."""
def test_default_settings(self):
"""Test default settings values."""
settings = ChatSettings()
assert settings.service_name == "chat-handler"
assert settings.rag_top_k == 10
assert settings.rag_rerank_top_k == 5
assert settings.rag_collection == "documents"
assert settings.include_sources is True
assert settings.enable_tts is False
assert settings.tts_language == "en"
def test_custom_settings(self):
"""Test custom settings."""
settings = ChatSettings(
rag_top_k=20,
rag_collection="custom_docs",
enable_tts=True,
)
assert settings.rag_top_k == 20
assert settings.rag_collection == "custom_docs"
assert settings.enable_tts is True
class TestChatHandler:
"""Tests for ChatHandler."""
@pytest.fixture
def handler(self):
"""Create handler with mocked clients."""
with patch("chat_handler.EmbeddingsClient"), \
patch("chat_handler.RerankerClient"), \
patch("chat_handler.LLMClient"), \
patch("chat_handler.TTSClient"), \
patch("chat_handler.MilvusClient"):
handler = ChatHandler()
# Setup mock clients
handler.embeddings = AsyncMock()
handler.reranker = AsyncMock()
handler.llm = AsyncMock()
handler.milvus = AsyncMock()
handler.tts = None # TTS disabled by default
handler.nats = AsyncMock()
yield handler
@pytest.fixture
def handler_with_tts(self):
"""Create handler with TTS enabled."""
with patch("chat_handler.EmbeddingsClient"), \
patch("chat_handler.RerankerClient"), \
patch("chat_handler.LLMClient"), \
patch("chat_handler.TTSClient"), \
patch("chat_handler.MilvusClient"):
handler = ChatHandler()
handler.chat_settings.enable_tts = True
# Setup mock clients
handler.embeddings = AsyncMock()
handler.reranker = AsyncMock()
handler.llm = AsyncMock()
handler.milvus = AsyncMock()
handler.tts = AsyncMock()
handler.nats = AsyncMock()
yield handler
def test_init(self, handler):
"""Test handler initialization."""
assert handler.subject == "ai.chat.request"
assert handler.queue_group == "chat-handlers"
assert handler.chat_settings.service_name == "chat-handler"
@pytest.mark.asyncio
async def test_handle_message_success(
self,
handler,
mock_nats_message,
mock_chat_request,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test successful chat request handling."""
# Setup mocks
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Machine learning is a subset of AI that..."
# Execute
result = await handler.handle_message(mock_nats_message, mock_chat_request)
# Verify
assert result["request_id"] == "test-request-123"
assert "response" in result
assert result["response"] == "Machine learning is a subset of AI that..."
assert "sources" in result # include_sources is True by default
# Verify pipeline was called
handler.embeddings.embed_single.assert_called_once()
handler.milvus.search_with_texts.assert_called_once()
handler.reranker.rerank.assert_called_once()
handler.llm.generate.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_without_sources(
self,
handler,
mock_nats_message,
mock_chat_request,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test response without sources when disabled."""
handler.chat_settings.include_sources = False
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Response text"
result = await handler.handle_message(mock_nats_message, mock_chat_request)
assert "sources" not in result
@pytest.mark.asyncio
async def test_handle_message_with_tts(
self,
handler_with_tts,
mock_nats_message,
mock_chat_request_with_tts,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test response with TTS audio."""
handler = handler_with_tts
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "AI response"
handler.tts.synthesize.return_value = b"audio_bytes"
result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts)
assert "audio" in result
handler.tts.synthesize.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_with_custom_system_prompt(
self,
handler,
mock_nats_message,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test LLM is called with custom system prompt."""
request = {
"request_id": "test-123",
"query": "Hello",
"system_prompt": "You are a pirate. Respond like one.",
}
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Ahoy!"
await handler.handle_message(mock_nats_message, request)
# Verify system_prompt was passed to LLM
handler.llm.generate.assert_called_once()
call_kwargs = handler.llm.generate.call_args.kwargs
assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one."
def test_build_context(self, handler):
"""Test context building with numbered sources."""
documents = [
{"document": "First doc content"},
{"document": "Second doc content"},
]
context = handler._build_context(documents)
assert "[1]" in context
assert "[2]" in context
assert "First doc content" in context
assert "Second doc content" in context
@pytest.mark.asyncio
async def test_setup_initializes_clients(self):
"""Test that setup initializes all required clients."""
with patch("chat_handler.EmbeddingsClient") as emb_cls, \
patch("chat_handler.RerankerClient") as rer_cls, \
patch("chat_handler.LLMClient") as llm_cls, \
patch("chat_handler.TTSClient") as tts_cls, \
patch("chat_handler.MilvusClient") as mil_cls:
mil_cls.return_value.connect = AsyncMock()
handler = ChatHandler()
await handler.setup()
emb_cls.assert_called_once()
rer_cls.assert_called_once()
llm_cls.assert_called_once()
mil_cls.assert_called_once()
# TTS should not be initialized when disabled
tts_cls.assert_not_called()
@pytest.mark.asyncio
async def test_teardown_closes_clients(self, handler):
"""Test that teardown closes all clients."""
await handler.teardown()
handler.embeddings.close.assert_called_once()
handler.reranker.close.assert_called_once()
handler.llm.close.assert_called_once()
handler.milvus.close.assert_called_once()
@pytest.mark.asyncio
async def test_publishes_to_response_subject(
self,
handler,
mock_nats_message,
mock_chat_request,
sample_embedding,
sample_documents,
sample_reranked,
):
"""Test that result is published to response subject."""
handler.embeddings.embed_single.return_value = sample_embedding
handler.milvus.search_with_texts.return_value = sample_documents
handler.reranker.rerank.return_value = sample_reranked
handler.llm.generate.return_value = "Response"
await handler.handle_message(mock_nats_message, mock_chat_request)
handler.nats.publish.assert_called_once()
call_args = handler.nats.publish.call_args
assert "ai.chat.response.test-request-123" in str(call_args)