refactor: consolidate to handler-base, migrate to pyproject.toml, add tests

This commit is contained in:
2026-02-02 07:10:54 -05:00
parent f0b626a5e7
commit 77d6822a63
10 changed files with 548 additions and 1122 deletions

View File

@@ -1,29 +1,9 @@
FROM python:3.13-slim # Voice Assistant - Using handler-base with audio support
ARG BASE_TAG=latest
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app
# Install uv for fast, reliable package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt
# Copy application code
COPY voice_assistant.py .
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "print('healthy')" || exit 1
# Run the application
CMD ["python", "voice_assistant.py"]

View File

@@ -1,9 +0,0 @@
# Voice Assistant v2 - Using handler-base with audio support
ARG BASE_TAG=local-audio
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app
COPY voice_assistant_v2.py ./voice_assistant.py
CMD ["python", "voice_assistant.py"]

View File

@@ -6,17 +6,10 @@ End-to-end voice assistant pipeline for the DaviesTechLabs AI/ML platform.
### Real-time Handler (NATS-based)
The voice assistant service listens on NATS for audio requests and returns synthesized speech responses. The voice assistant service listens on NATS for audio requests and returns synthesized speech responses. It uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks.
**Pipeline:** STT → Embeddings → Milvus RAG → Rerank → LLM → TTS
| File | Description |
|------|-------------|
| `voice_assistant.py` | Standalone handler (v1) |
| `voice_assistant_v2.py` | Handler using handler-base library |
| `Dockerfile` | Standalone image |
| `Dockerfile.v2` | Handler-base image |
### Kubeflow Pipeline (Batch) ### Kubeflow Pipeline (Batch)
For batch processing or async workflows via Kubeflow Pipelines. For batch processing or async workflows via Kubeflow Pipelines.
@@ -106,11 +99,10 @@ NATS (voice.request)
## Building ## Building
```bash ```bash
# Standalone image (v1) docker build -t voice-assistant:latest .
docker build -f Dockerfile -t voice-assistant:latest .
# Handler-base image (v2 - recommended) # With specific handler-base tag
docker build -f Dockerfile.v2 -t voice-assistant:v2 . docker build --build-arg BASE_TAG=latest -t voice-assistant:latest .
``` ```
## Related ## Related

40
pyproject.toml Normal file
View File

@@ -0,0 +1,40 @@
[project]
name = "voice-assistant"
version = "1.0.0"
description = "Voice assistant pipeline - STT → RAG → LLM → TTS"
readme = "README.md"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Davies Tech Labs" }]
dependencies = [
"handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"ruff>=0.1.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
only-include = ["voice_assistant.py"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = ["ignore::DeprecationWarning"]

View File

@@ -1,15 +0,0 @@
nats-py
httpx
pymilvus
numpy
msgpack
redis>=5.0.0
opentelemetry-api
opentelemetry-sdk
opentelemetry-exporter-otlp-proto-grpc
opentelemetry-exporter-otlp-proto-http
opentelemetry-instrumentation-httpx
opentelemetry-instrumentation-logging
# MLflow for inference metrics tracking
mlflow>=2.10.0
psycopg2-binary>=2.9.0

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Voice Assistant Tests

113
tests/conftest.py Normal file
View File

@@ -0,0 +1,113 @@
"""
Pytest configuration and fixtures for voice-assistant tests.
"""
import asyncio
import base64
import os
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Set test environment variables before importing the handler module, so any
# Settings objects created at import time pick up safe localhost defaults.
# setdefault() preserves values already exported by the caller/CI environment.
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
os.environ.setdefault("MILVUS_HOST", "localhost")
os.environ.setdefault("OTEL_ENABLED", "false")  # disable OpenTelemetry in tests
os.environ.setdefault("MLFLOW_ENABLED", "false")  # disable MLflow in tests
@pytest.fixture(scope="session")
def event_loop():
    """Create a session-scoped event loop for async tests.

    NOTE(review): overriding the ``event_loop`` fixture is deprecated in
    pytest-asyncio >= 0.23 (this project pins >= 0.23 and sets
    ``asyncio_mode = "auto"``); consider removing this fixture or using an
    ``event_loop_policy`` fixture instead -- TODO confirm against pinned version.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@pytest.fixture
def sample_audio_b64():
    """Sample base64 encoded audio for testing.

    Returns a minimal valid mono, 16 kHz, 16-bit PCM WAV file: a 44-byte RIFF
    header followed by 2048 bytes (1024 samples) of silence.
    """
    # 16-bit PCM silence (44-byte header + 1024 samples / 2048 data bytes)
    wav_header = bytes([
        0x52, 0x49, 0x46, 0x46,  # "RIFF"
        0x24, 0x08, 0x00, 0x00,  # File size - 8 (0x0824 = 2084 = 36 + 2048)
        0x57, 0x41, 0x56, 0x45,  # "WAVE"
        0x66, 0x6D, 0x74, 0x20,  # "fmt "
        0x10, 0x00, 0x00, 0x00,  # fmt chunk size (16)
        0x01, 0x00,              # PCM format
        0x01, 0x00,              # Mono
        0x80, 0x3E, 0x00, 0x00,  # Sample rate (16000)
        0x00, 0x7D, 0x00, 0x00,  # Byte rate (16000 * 2 bytes/sample)
        0x02, 0x00,              # Block align
        0x10, 0x00,              # Bits per sample
        0x64, 0x61, 0x74, 0x61,  # "data"
        0x00, 0x08, 0x00, 0x00,  # Data size (2048)
    ])
    silence = bytes([0x00] * 2048)
    return base64.b64encode(wav_header + silence).decode()
@pytest.fixture
def sample_embedding():
    """A 1024-dimensional dummy embedding (every component 0.1)."""
    return [0.1 for _ in range(1024)]
@pytest.fixture
def sample_documents():
    """Milvus-style search results: raw text plus similarity score."""
    hits = [
        ("Machine learning is a subset of AI.", 0.95),
        ("Deep learning uses neural networks.", 0.90),
        ("AI enables intelligent automation.", 0.85),
    ]
    return [{"text": text, "score": score} for text, score in hits]
@pytest.fixture
def sample_reranked():
    """Reranker-style results: document text plus rerank score."""
    ranked = [
        ("Machine learning is a subset of AI.", 0.98),
        ("Deep learning uses neural networks.", 0.85),
    ]
    return [{"document": doc, "score": score} for doc, score in ranked]
@pytest.fixture
def mock_nats_message():
    """A NATS message double carrying the subject fields the handler reads."""
    message = MagicMock()
    message.subject = "voice.request"
    message.reply = "voice.response.test-123"
    return message
@pytest.fixture
def mock_voice_request(sample_audio_b64):
    """A well-formed voice request payload built around the sample audio."""
    payload = dict(
        request_id="test-request-123",
        audio=sample_audio_b64,
        language="en",
        collection="test_collection",
    )
    return payload
@pytest.fixture
def mock_clients():
    """Mock all service clients.

    Patches every client class name as imported into the ``voice_assistant``
    module so that constructing ``VoiceAssistant`` performs no network I/O,
    then yields the patched class objects keyed by role so tests can
    configure their return values.
    """
    with patch("voice_assistant.STTClient") as stt, \
         patch("voice_assistant.EmbeddingsClient") as embeddings, \
         patch("voice_assistant.RerankerClient") as reranker, \
         patch("voice_assistant.LLMClient") as llm, \
         patch("voice_assistant.TTSClient") as tts, \
         patch("voice_assistant.MilvusClient") as milvus:
        yield {
            "stt": stt,
            "embeddings": embeddings,
            "reranker": reranker,
            "llm": llm,
            "tts": tts,
            "milvus": milvus,
        }

View File

@@ -0,0 +1,199 @@
"""
Unit tests for VoiceAssistant handler.
"""
import base64
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
# Import after environment is set up in conftest
from voice_assistant import VoiceAssistant, VoiceSettings
class TestVoiceSettings:
    """Tests for VoiceSettings configuration."""

    def test_default_settings(self):
        """All defaults match the values declared on VoiceSettings."""
        settings = VoiceSettings()
        assert settings.service_name == "voice-assistant"
        assert settings.rag_top_k == 10
        assert settings.rag_rerank_top_k == 5
        assert settings.rag_collection == "documents"
        assert settings.stt_language is None  # Auto-detect
        assert settings.tts_language == "en"
        assert settings.include_transcription is True
        assert settings.include_sources is False

    def test_custom_settings(self):
        """Explicit keyword overrides take effect.

        The previous version also exported RAG_TOP_K / RAG_COLLECTION via
        monkeypatch, but then passed explicit kwargs anyway, so the env vars
        were never exercised; that dead setup has been removed.
        """
        settings = VoiceSettings(
            rag_top_k=20,
            rag_collection="custom_collection",
        )
        assert settings.rag_top_k == 20
        assert settings.rag_collection == "custom_collection"
class TestVoiceAssistant:
    """Tests for VoiceAssistant handler."""

    @pytest.fixture
    def handler(self):
        """Create handler with mocked clients.

        The client classes are patched so ``VoiceAssistant.__init__`` performs
        no I/O; the instance's client attributes are then replaced with
        AsyncMocks so each pipeline stage can be scripted per test.
        """
        with patch("voice_assistant.STTClient"), \
             patch("voice_assistant.EmbeddingsClient"), \
             patch("voice_assistant.RerankerClient"), \
             patch("voice_assistant.LLMClient"), \
             patch("voice_assistant.TTSClient"), \
             patch("voice_assistant.MilvusClient"):
            handler = VoiceAssistant()
            # Setup mock clients
            handler.stt = AsyncMock()
            handler.embeddings = AsyncMock()
            handler.reranker = AsyncMock()
            handler.llm = AsyncMock()
            handler.tts = AsyncMock()
            handler.milvus = AsyncMock()
            # nats is used by handle_message to publish the response subject
            handler.nats = AsyncMock()
            yield handler

    def test_init(self, handler):
        """Test handler initialization (subject, queue group, settings)."""
        assert handler.subject == "voice.request"
        assert handler.queue_group == "voice-assistants"
        assert handler.voice_settings.service_name == "voice-assistant"

    @pytest.mark.asyncio
    async def test_handle_message_success(
        self,
        handler,
        mock_nats_message,
        mock_voice_request,
        sample_embedding,
        sample_documents,
        sample_reranked,
    ):
        """Test successful voice request handling through the full pipeline."""
        # Setup mocks: one canned return value per pipeline stage
        handler.stt.transcribe.return_value = {"text": "What is machine learning?"}
        handler.embeddings.embed_single.return_value = sample_embedding
        handler.milvus.search_with_texts.return_value = sample_documents
        handler.reranker.rerank.return_value = sample_reranked
        handler.llm.generate.return_value = "Machine learning is a type of AI."
        handler.tts.synthesize.return_value = b"audio_bytes"
        # Execute
        result = await handler.handle_message(mock_nats_message, mock_voice_request)
        # Verify response payload shape and content
        assert result["request_id"] == "test-request-123"
        assert result["response"] == "Machine learning is a type of AI."
        assert "audio" in result
        assert result["transcription"] == "What is machine learning?"
        # Verify every pipeline stage was called exactly once
        handler.stt.transcribe.assert_called_once()
        handler.embeddings.embed_single.assert_called_once()
        handler.milvus.search_with_texts.assert_called_once()
        handler.reranker.rerank.assert_called_once()
        handler.llm.generate.assert_called_once()
        handler.tts.synthesize.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_message_empty_transcription(
        self,
        handler,
        mock_nats_message,
        mock_voice_request,
    ):
        """Test handling when transcription is empty (error short-circuit)."""
        handler.stt.transcribe.return_value = {"text": ""}
        result = await handler.handle_message(mock_nats_message, mock_voice_request)
        assert "error" in result
        assert result["error"] == "Could not transcribe audio"
        # Verify pipeline stopped after transcription
        handler.embeddings.embed_single.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_message_with_sources(
        self,
        handler,
        mock_nats_message,
        mock_voice_request,
        sample_embedding,
        sample_documents,
        sample_reranked,
    ):
        """Test response includes sources when include_sources is enabled."""
        handler.voice_settings.include_sources = True
        # Setup mocks
        handler.stt.transcribe.return_value = {"text": "Hello"}
        handler.embeddings.embed_single.return_value = sample_embedding
        handler.milvus.search_with_texts.return_value = sample_documents
        handler.reranker.rerank.return_value = sample_reranked
        handler.llm.generate.return_value = "Hi there!"
        handler.tts.synthesize.return_value = b"audio"
        result = await handler.handle_message(mock_nats_message, mock_voice_request)
        assert "sources" in result
        # Handler truncates sources to at most the top 3 reranked documents
        assert len(result["sources"]) <= 3

    def test_build_context(self, handler):
        """Test context building joins document texts together."""
        documents = [
            {"document": "First doc content"},
            {"document": "Second doc content"},
        ]
        context = handler._build_context(documents)
        assert "First doc content" in context
        assert "Second doc content" in context

    @pytest.mark.asyncio
    async def test_setup_initializes_clients(self):
        """Test that setup() constructs every client exactly once."""
        with patch("voice_assistant.STTClient") as stt_cls, \
             patch("voice_assistant.EmbeddingsClient") as emb_cls, \
             patch("voice_assistant.RerankerClient") as rer_cls, \
             patch("voice_assistant.LLMClient") as llm_cls, \
             patch("voice_assistant.TTSClient") as tts_cls, \
             patch("voice_assistant.MilvusClient") as mil_cls:
            # connect() is awaited inside setup(), so it must be an AsyncMock
            mil_cls.return_value.connect = AsyncMock()
            handler = VoiceAssistant()
            await handler.setup()
            stt_cls.assert_called_once()
            emb_cls.assert_called_once()
            rer_cls.assert_called_once()
            llm_cls.assert_called_once()
            tts_cls.assert_called_once()
            mil_cls.assert_called_once()

    @pytest.mark.asyncio
    async def test_teardown_closes_clients(self, handler):
        """Test that teardown() closes every client exactly once."""
        await handler.teardown()
        handler.stt.close.assert_called_once()
        handler.embeddings.close.assert_called_once()
        handler.reranker.close.assert_called_once()
        handler.llm.close.assert_called_once()
        handler.tts.close.assert_called_once()
        handler.milvus.close.assert_called_once()

File diff suppressed because it is too large Load Diff

View File

@@ -1,234 +0,0 @@
#!/usr/bin/env python3
"""
Voice Assistant Service (Refactored)
End-to-end voice assistant pipeline using handler-base:
1. Listen for audio on NATS subject "voice.request"
2. Transcribe with Whisper (STT)
3. Generate embeddings for RAG
4. Retrieve context from Milvus
5. Rerank with BGE reranker
6. Generate response with vLLM
7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base import Handler, Settings
from handler_base.clients import (
EmbeddingsClient,
RerankerClient,
LLMClient,
TTSClient,
STTClient,
MilvusClient,
)
from handler_base.telemetry import create_span
logger = logging.getLogger("voice-assistant")
class VoiceSettings(Settings):
    """Voice assistant specific settings.

    Extends the handler-base ``Settings`` class; fields are presumably
    overridable via environment variables through the base class's
    pydantic-settings machinery -- TODO confirm env-var naming/prefix.
    """
    service_name: str = "voice-assistant"
    # RAG settings
    rag_top_k: int = 10  # candidates fetched from Milvus before reranking
    rag_rerank_top_k: int = 5  # candidates kept after reranking
    rag_collection: str = "documents"  # default Milvus collection
    # Audio settings
    stt_language: Optional[str] = None  # None => auto-detect in STT
    tts_language: str = "en"  # language passed to TTS synthesis
    # Response settings
    include_transcription: bool = True  # echo STT text back in the response
    include_sources: bool = False  # include top reranked snippets in the response
class VoiceAssistant(Handler):
    """
    Voice request handler with full STT -> RAG -> LLM -> TTS pipeline.

    Request format (msgpack):
    {
        "request_id": "uuid",
        "audio": "base64 encoded audio",
        "language": "optional language code",
        "collection": "optional collection name"
    }

    Response format:
    {
        "request_id": "uuid",
        "transcription": "what the user said",
        "response": "generated text response",
        "audio": "base64 encoded response audio"
    }

    On failure (missing/undecodable audio, empty transcription) a dict with
    only "request_id" and "error" keys is returned instead.
    """

    def __init__(self):
        # Settings are built first so they can both configure the Handler base
        # and be reused for every client constructed in setup().
        self.voice_settings = VoiceSettings()
        super().__init__(
            subject="voice.request",
            settings=self.voice_settings,
            queue_group="voice-assistants",
        )

    async def setup(self) -> None:
        """Initialize service clients and connect to the Milvus collection."""
        logger.info("Initializing voice assistant clients...")
        self.stt = STTClient(self.voice_settings)
        self.embeddings = EmbeddingsClient(self.voice_settings)
        self.reranker = RerankerClient(self.voice_settings)
        self.llm = LLMClient(self.voice_settings)
        self.tts = TTSClient(self.voice_settings)
        self.milvus = MilvusClient(self.voice_settings)
        await self.milvus.connect(self.voice_settings.rag_collection)
        logger.info("Voice assistant clients initialized")

    async def teardown(self) -> None:
        """Clean up service clients (mirror of setup)."""
        logger.info("Closing voice assistant clients...")
        await self.stt.close()
        await self.embeddings.close()
        await self.reranker.close()
        await self.llm.close()
        await self.tts.close()
        await self.milvus.close()
        logger.info("Voice assistant clients closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle incoming voice request.

        Runs the full pipeline, publishes the result to
        ``voice.response.{request_id}`` and also returns it.
        """
        request_id = data.get("request_id", "unknown")
        audio_b64 = data.get("audio", "")
        language = data.get("language", self.voice_settings.stt_language)
        collection = data.get("collection", self.voice_settings.rag_collection)
        logger.info(f"Processing voice request {request_id}")
        with create_span("voice.process") as span:
            if span:
                span.set_attribute("request.id", request_id)
            # 1. Decode audio. Guard against missing or malformed payloads so a
            # bad request yields a structured error instead of an exception.
            if not audio_b64:
                logger.warning(f"Missing audio for request {request_id}")
                return {
                    "request_id": request_id,
                    "error": "Missing audio payload",
                }
            try:
                # binascii.Error (raised on bad padding) subclasses ValueError,
                # so no extra import is needed to catch it.
                audio_bytes = base64.b64decode(audio_b64)
            except ValueError:
                logger.warning(f"Invalid base64 audio for request {request_id}")
                return {
                    "request_id": request_id,
                    "error": "Invalid base64 audio",
                }
            # 2. Transcribe audio to text
            transcription = await self._transcribe(audio_bytes, language)
            query = transcription.get("text", "")
            if not query.strip():
                logger.warning(f"Empty transcription for request {request_id}")
                return {
                    "request_id": request_id,
                    "error": "Could not transcribe audio",
                }
            logger.info(f"Transcribed: {query[:50]}...")
            # 3. Generate query embedding
            embedding = await self._get_embedding(query)
            # 4. Search Milvus for context
            documents = await self._search_context(embedding, collection)
            # 5. Rerank documents
            reranked = await self._rerank_documents(query, documents)
            # 6. Build context
            context = self._build_context(reranked)
            # 7. Generate LLM response
            response_text = await self._generate_response(query, context)
            # 8. Synthesize speech
            response_audio = await self._synthesize_speech(response_text)
            # Build response
            result = {
                "request_id": request_id,
                "response": response_text,
                "audio": response_audio,
            }
            if self.voice_settings.include_transcription:
                result["transcription"] = query
            if self.voice_settings.include_sources:
                # Only the top 3 reranked snippets, truncated to 200 chars each
                result["sources"] = [
                    {"text": d["document"][:200], "score": d["score"]}
                    for d in reranked[:3]
                ]
            logger.info(f"Completed voice request {request_id}")
            # Publish to response subject (in addition to returning the result)
            response_subject = f"voice.response.{request_id}"
            await self.nats.publish(response_subject, result)
            return result

    async def _transcribe(
        self, audio: bytes, language: Optional[str]
    ) -> dict:
        """Transcribe audio to text; returns the STT client's result dict."""
        with create_span("voice.stt"):
            return await self.stt.transcribe(audio, language=language)

    async def _get_embedding(self, text: str) -> list[float]:
        """Generate embedding for query text."""
        with create_span("voice.embedding"):
            return await self.embeddings.embed_single(text)

    async def _search_context(
        self, embedding: list[float], collection: str
    ) -> list[dict]:
        """Search Milvus for relevant documents (dicts with a "text" field)."""
        with create_span("voice.search"):
            return await self.milvus.search_with_texts(
                embedding,
                limit=self.voice_settings.rag_top_k,
                text_field="text",
            )

    async def _rerank_documents(
        self, query: str, documents: list[dict]
    ) -> list[dict]:
        """Rerank documents by relevance; returns dicts with "document"/"score"."""
        with create_span("voice.rerank"):
            texts = [d.get("text", "") for d in documents]
            return await self.reranker.rerank(
                query, texts, top_k=self.voice_settings.rag_rerank_top_k
            )

    def _build_context(self, documents: list[dict]) -> str:
        """Build context string from ranked documents, blank-line separated."""
        return "\n\n".join(d.get("document", "") for d in documents)

    async def _generate_response(self, query: str, context: str) -> str:
        """Generate LLM response grounded in the retrieved context."""
        with create_span("voice.generate"):
            return await self.llm.generate(query, context=context)

    async def _synthesize_speech(self, text: str) -> str:
        """Synthesize speech and return it base64-encoded for transport."""
        with create_span("voice.tts"):
            audio_bytes = await self.tts.synthesize(
                text, language=self.voice_settings.tts_language
            )
            return base64.b64encode(audio_bytes).decode()
if __name__ == "__main__":
    # Entry point: Handler.run() is provided by handler-base and presumably
    # starts the NATS subscription loop and blocks until shutdown -- confirm
    # in handler-base docs.
    VoiceAssistant().run()