From 77d6822a63ebf91d5646fbe151266ad225c0ca75 Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Mon, 2 Feb 2026 07:10:54 -0500 Subject: [PATCH] refactor: consolidate to handler-base, migrate to pyproject.toml, add tests --- Dockerfile | 26 +- Dockerfile.v2 | 9 - README.md | 16 +- pyproject.toml | 40 ++ requirements.txt | 15 - tests/__init__.py | 1 + tests/conftest.py | 113 ++++ tests/test_voice_assistant.py | 199 +++++++ voice_assistant.py | 1017 ++++++--------------------------- voice_assistant_v2.py | 234 -------- 10 files changed, 548 insertions(+), 1122 deletions(-) delete mode 100644 Dockerfile.v2 create mode 100644 pyproject.toml delete mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_voice_assistant.py delete mode 100644 voice_assistant_v2.py diff --git a/Dockerfile b/Dockerfile index 80bab16..b1aa656 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,9 @@ -FROM python:3.13-slim +# Voice Assistant - Using handler-base with audio support +ARG BASE_TAG=latest +FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} WORKDIR /app -# Install uv for fast, reliable package management -COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv - -# Install system dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements first for better caching -COPY requirements.txt . -RUN uv pip install --system --no-cache -r requirements.txt - -# Copy application code COPY voice_assistant.py . 
-# Set environment variables -ENV PYTHONUNBUFFERED=1 -ENV PYTHONDONTWRITEBYTECODE=1 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD python -c "print('healthy')" || exit 1 - -# Run the application CMD ["python", "voice_assistant.py"] diff --git a/Dockerfile.v2 b/Dockerfile.v2 deleted file mode 100644 index 6a0d2d3..0000000 --- a/Dockerfile.v2 +++ /dev/null @@ -1,9 +0,0 @@ -# Voice Assistant v2 - Using handler-base with audio support -ARG BASE_TAG=local-audio -FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} - -WORKDIR /app - -COPY voice_assistant_v2.py ./voice_assistant.py - -CMD ["python", "voice_assistant.py"] diff --git a/README.md b/README.md index 6d1b6b3..ef1888f 100644 --- a/README.md +++ b/README.md @@ -6,17 +6,10 @@ End-to-end voice assistant pipeline for the DaviesTechLabs AI/ML platform. ### Real-time Handler (NATS-based) -The voice assistant service listens on NATS for audio requests and returns synthesized speech responses. +The voice assistant service listens on NATS for audio requests and returns synthesized speech responses. It uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks. **Pipeline:** STT → Embeddings → Milvus RAG → Rerank → LLM → TTS -| File | Description | -|------|-------------| -| `voice_assistant.py` | Standalone handler (v1) | -| `voice_assistant_v2.py` | Handler using handler-base library | -| `Dockerfile` | Standalone image | -| `Dockerfile.v2` | Handler-base image | - ### Kubeflow Pipeline (Batch) For batch processing or async workflows via Kubeflow Pipelines. @@ -106,11 +99,10 @@ NATS (voice.request) ## Building ```bash -# Standalone image (v1) -docker build -f Dockerfile -t voice-assistant:latest . +docker build -t voice-assistant:latest . -# Handler-base image (v2 - recommended) -docker build -f Dockerfile.v2 -t voice-assistant:v2 . 
+# With specific handler-base tag +docker build --build-arg BASE_TAG=latest -t voice-assistant:latest . ``` ## Related diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8b22e63 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "voice-assistant" +version = "1.0.0" +description = "Voice assistant pipeline - STT → RAG → LLM → TTS" +readme = "README.md" +requires-python = ">=3.11" +license = { text = "MIT" } +authors = [{ name = "Davies Tech Labs" }] + +dependencies = [ + "handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "ruff>=0.1.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["."] +only-include = ["voice_assistant.py"] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" +filterwarnings = ["ignore::DeprecationWarning"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ef68ffa..0000000 --- a/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -nats-py -httpx -pymilvus -numpy -msgpack -redis>=5.0.0 -opentelemetry-api -opentelemetry-sdk -opentelemetry-exporter-otlp-proto-grpc -opentelemetry-exporter-otlp-proto-http -opentelemetry-instrumentation-httpx -opentelemetry-instrumentation-logging -# MLflow for inference metrics tracking -mlflow>=2.10.0 -psycopg2-binary>=2.9.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..ee96d26 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Voice Assistant Tests diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7a2939b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,113 @@ 
+""" +Pytest configuration and fixtures for voice-assistant tests. +""" +import asyncio +import base64 +import os +from typing import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Set test environment variables before importing +os.environ.setdefault("NATS_URL", "nats://localhost:4222") +os.environ.setdefault("REDIS_URL", "redis://localhost:6379") +os.environ.setdefault("MILVUS_HOST", "localhost") +os.environ.setdefault("OTEL_ENABLED", "false") +os.environ.setdefault("MLFLOW_ENABLED", "false") + + +@pytest.fixture(scope="session") +def event_loop(): + """Create event loop for async tests.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def sample_audio_b64(): + """Sample base64 encoded audio for testing.""" + # 16-bit PCM silence (44 bytes header + 1024 samples) + wav_header = bytes([ + 0x52, 0x49, 0x46, 0x46, # "RIFF" + 0x24, 0x08, 0x00, 0x00, # File size + 0x57, 0x41, 0x56, 0x45, # "WAVE" + 0x66, 0x6D, 0x74, 0x20, # "fmt " + 0x10, 0x00, 0x00, 0x00, # Chunk size + 0x01, 0x00, # PCM format + 0x01, 0x00, # Mono + 0x80, 0x3E, 0x00, 0x00, # Sample rate (16000) + 0x00, 0x7D, 0x00, 0x00, # Byte rate + 0x02, 0x00, # Block align + 0x10, 0x00, # Bits per sample + 0x64, 0x61, 0x74, 0x61, # "data" + 0x00, 0x08, 0x00, 0x00, # Data size + ]) + silence = bytes([0x00] * 2048) + return base64.b64encode(wav_header + silence).decode() + + +@pytest.fixture +def sample_embedding(): + """Sample embedding vector.""" + return [0.1] * 1024 + + +@pytest.fixture +def sample_documents(): + """Sample search results.""" + return [ + {"text": "Machine learning is a subset of AI.", "score": 0.95}, + {"text": "Deep learning uses neural networks.", "score": 0.90}, + {"text": "AI enables intelligent automation.", "score": 0.85}, + ] + + +@pytest.fixture +def sample_reranked(): + """Sample reranked results.""" + return [ + {"document": "Machine learning is a subset of AI.", "score": 0.98}, + {"document": "Deep 
learning uses neural networks.", "score": 0.85}, + ] + + +@pytest.fixture +def mock_nats_message(): + """Create a mock NATS message.""" + msg = MagicMock() + msg.subject = "voice.request" + msg.reply = "voice.response.test-123" + return msg + + +@pytest.fixture +def mock_voice_request(sample_audio_b64): + """Sample voice request payload.""" + return { + "request_id": "test-request-123", + "audio": sample_audio_b64, + "language": "en", + "collection": "test_collection", + } + + +@pytest.fixture +def mock_clients(): + """Mock all service clients.""" + with patch("voice_assistant.STTClient") as stt, \ + patch("voice_assistant.EmbeddingsClient") as embeddings, \ + patch("voice_assistant.RerankerClient") as reranker, \ + patch("voice_assistant.LLMClient") as llm, \ + patch("voice_assistant.TTSClient") as tts, \ + patch("voice_assistant.MilvusClient") as milvus: + + yield { + "stt": stt, + "embeddings": embeddings, + "reranker": reranker, + "llm": llm, + "tts": tts, + "milvus": milvus, + } diff --git a/tests/test_voice_assistant.py b/tests/test_voice_assistant.py new file mode 100644 index 0000000..69de754 --- /dev/null +++ b/tests/test_voice_assistant.py @@ -0,0 +1,199 @@ +""" +Unit tests for VoiceAssistant handler. 
+""" +import base64 +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +# Import after environment is set up in conftest +from voice_assistant import VoiceAssistant, VoiceSettings + + +class TestVoiceSettings: + """Tests for VoiceSettings configuration.""" + + def test_default_settings(self): + """Test default settings values.""" + settings = VoiceSettings() + + assert settings.service_name == "voice-assistant" + assert settings.rag_top_k == 10 + assert settings.rag_rerank_top_k == 5 + assert settings.rag_collection == "documents" + assert settings.stt_language is None # Auto-detect + assert settings.tts_language == "en" + assert settings.include_transcription is True + assert settings.include_sources is False + + def test_custom_settings(self, monkeypatch): + """Test settings from environment.""" + monkeypatch.setenv("RAG_TOP_K", "20") + monkeypatch.setenv("RAG_COLLECTION", "custom_collection") + + # Note: Would need to re-instantiate settings to pick up env vars + settings = VoiceSettings( + rag_top_k=20, + rag_collection="custom_collection" + ) + + assert settings.rag_top_k == 20 + assert settings.rag_collection == "custom_collection" + + +class TestVoiceAssistant: + """Tests for VoiceAssistant handler.""" + + @pytest.fixture + def handler(self): + """Create handler with mocked clients.""" + with patch("voice_assistant.STTClient"), \ + patch("voice_assistant.EmbeddingsClient"), \ + patch("voice_assistant.RerankerClient"), \ + patch("voice_assistant.LLMClient"), \ + patch("voice_assistant.TTSClient"), \ + patch("voice_assistant.MilvusClient"): + + handler = VoiceAssistant() + + # Setup mock clients + handler.stt = AsyncMock() + handler.embeddings = AsyncMock() + handler.reranker = AsyncMock() + handler.llm = AsyncMock() + handler.tts = AsyncMock() + handler.milvus = AsyncMock() + handler.nats = AsyncMock() + + yield handler + + def test_init(self, handler): + """Test handler initialization.""" + assert handler.subject == "voice.request" + 
assert handler.queue_group == "voice-assistants" + assert handler.voice_settings.service_name == "voice-assistant" + + @pytest.mark.asyncio + async def test_handle_message_success( + self, + handler, + mock_nats_message, + mock_voice_request, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test successful voice request handling.""" + # Setup mocks + handler.stt.transcribe.return_value = {"text": "What is machine learning?"} + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "Machine learning is a type of AI." + handler.tts.synthesize.return_value = b"audio_bytes" + + # Execute + result = await handler.handle_message(mock_nats_message, mock_voice_request) + + # Verify + assert result["request_id"] == "test-request-123" + assert result["response"] == "Machine learning is a type of AI." + assert "audio" in result + assert result["transcription"] == "What is machine learning?" 
+ + # Verify pipeline was called + handler.stt.transcribe.assert_called_once() + handler.embeddings.embed_single.assert_called_once() + handler.milvus.search_with_texts.assert_called_once() + handler.reranker.rerank.assert_called_once() + handler.llm.generate.assert_called_once() + handler.tts.synthesize.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_message_empty_transcription( + self, + handler, + mock_nats_message, + mock_voice_request, + ): + """Test handling when transcription is empty.""" + handler.stt.transcribe.return_value = {"text": ""} + + result = await handler.handle_message(mock_nats_message, mock_voice_request) + + assert "error" in result + assert result["error"] == "Could not transcribe audio" + + # Verify pipeline stopped after transcription + handler.embeddings.embed_single.assert_not_called() + + @pytest.mark.asyncio + async def test_handle_message_with_sources( + self, + handler, + mock_nats_message, + mock_voice_request, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test response includes sources when enabled.""" + handler.voice_settings.include_sources = True + + # Setup mocks + handler.stt.transcribe.return_value = {"text": "Hello"} + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "Hi there!" 
+ handler.tts.synthesize.return_value = b"audio" + + result = await handler.handle_message(mock_nats_message, mock_voice_request) + + assert "sources" in result + assert len(result["sources"]) <= 3 + + def test_build_context(self, handler): + """Test context building from documents.""" + documents = [ + {"document": "First doc content"}, + {"document": "Second doc content"}, + ] + + context = handler._build_context(documents) + + assert "First doc content" in context + assert "Second doc content" in context + + @pytest.mark.asyncio + async def test_setup_initializes_clients(self): + """Test that setup initializes all clients.""" + with patch("voice_assistant.STTClient") as stt_cls, \ + patch("voice_assistant.EmbeddingsClient") as emb_cls, \ + patch("voice_assistant.RerankerClient") as rer_cls, \ + patch("voice_assistant.LLMClient") as llm_cls, \ + patch("voice_assistant.TTSClient") as tts_cls, \ + patch("voice_assistant.MilvusClient") as mil_cls: + + mil_cls.return_value.connect = AsyncMock() + + handler = VoiceAssistant() + await handler.setup() + + stt_cls.assert_called_once() + emb_cls.assert_called_once() + rer_cls.assert_called_once() + llm_cls.assert_called_once() + tts_cls.assert_called_once() + mil_cls.assert_called_once() + + @pytest.mark.asyncio + async def test_teardown_closes_clients(self, handler): + """Test that teardown closes all clients.""" + await handler.teardown() + + handler.stt.close.assert_called_once() + handler.embeddings.close.assert_called_once() + handler.reranker.close.assert_called_once() + handler.llm.close.assert_called_once() + handler.tts.close.assert_called_once() + handler.milvus.close.assert_called_once() diff --git a/voice_assistant.py b/voice_assistant.py index 87c1587..de04ead 100644 --- a/voice_assistant.py +++ b/voice_assistant.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 """ -Voice Assistant Service +Voice Assistant Service (Refactored) -End-to-end voice assistant pipeline: +End-to-end voice assistant pipeline using 
handler-base: 1. Listen for audio on NATS subject "voice.request" 2. Transcribe with Whisper (STT) 3. Generate embeddings for RAG @@ -10,866 +10,225 @@ End-to-end voice assistant pipeline: 5. Rerank with BGE reranker 6. Generate response with vLLM 7. Synthesize speech with XTTS -8. Publish result to NATS "voice.response" +8. Publish result to NATS "voice.response.{request_id}" """ -import asyncio import base64 -import json import logging -import os -import signal -import subprocess -import sys -import time -from typing import List, Dict, Optional +from typing import Any, Optional -# Install dependencies on startup -subprocess.check_call([ - sys.executable, "-m", "pip", "install", "-q", - "-r", "/app/requirements.txt" -]) +from nats.aio.msg import Msg -import httpx -import msgpack -import nats -import redis.asyncio as redis -from pymilvus import connections, Collection, utility - -# OpenTelemetry imports -from opentelemetry import trace, metrics -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter -from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP -from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP -from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE -from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor -from opentelemetry.instrumentation.logging import LoggingInstrumentor - -# MLflow inference tracking -try: - from mlflow_utils import InferenceMetricsTracker - from mlflow_utils.inference_tracker import InferenceMetrics - 
MLFLOW_AVAILABLE = True -except ImportError: - MLFLOW_AVAILABLE = False - InferenceMetricsTracker = None - InferenceMetrics = None - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +from handler_base import Handler, Settings +from handler_base.clients import ( + EmbeddingsClient, + RerankerClient, + LLMClient, + TTSClient, + STTClient, + MilvusClient, ) +from handler_base.telemetry import create_span + logger = logging.getLogger("voice-assistant") -def setup_telemetry(): - """Initialize OpenTelemetry tracing and metrics.""" - otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true" - if not otel_enabled: - logger.info("OpenTelemetry disabled") - return None, None +class VoiceSettings(Settings): + """Voice assistant specific settings.""" - otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317") - service_name = os.environ.get("OTEL_SERVICE_NAME", "voice-assistant") - service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml") + service_name: str = "voice-assistant" - # HyperDX configuration - hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "") - hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io") - use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key + # RAG settings + rag_top_k: int = 10 + rag_rerank_top_k: int = 5 + rag_collection: str = "documents" - resource = Resource.create({ - SERVICE_NAME: service_name, - SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"), - SERVICE_NAMESPACE: service_namespace, - "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"), - "host.name": os.environ.get("HOSTNAME", "unknown"), - }) + # Audio settings + stt_language: Optional[str] = None # Auto-detect + tts_language: str = "en" - trace_provider = TracerProvider(resource=resource) - - if use_hyperdx: 
- logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}") - headers = {"authorization": hyperdx_api_key} - otlp_span_exporter = OTLPSpanExporterHTTP( - endpoint=f"{hyperdx_endpoint}/v1/traces", - headers=headers - ) - otlp_metric_exporter = OTLPMetricExporterHTTP( - endpoint=f"{hyperdx_endpoint}/v1/metrics", - headers=headers - ) - else: - otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True) - otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True) - - trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter)) - trace.set_tracer_provider(trace_provider) - - metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000) - meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) - metrics.set_meter_provider(meter_provider) - - HTTPXClientInstrumentor().instrument() - LoggingInstrumentor().instrument(set_logging_format=True) - - destination = "HyperDX" if use_hyperdx else "OTEL Collector" - logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}") - - return trace.get_tracer(__name__), metrics.get_meter(__name__) - -# Configuration from environment -WHISPER_URL = os.environ.get( - "WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local" -) -TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local:5002") -EMBEDDINGS_URL = os.environ.get( - "EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local" -) -RERANKER_URL = os.environ.get( - "RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local" -) -VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000") -LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3") -MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local") -MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530")) -COLLECTION_NAME = 
os.environ.get("COLLECTION_NAME", "knowledge_base") -NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") -VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379") - -# MLflow configuration -MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true" -MLFLOW_TRACKING_URI = os.environ.get( - "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80" -) - -# Context window limits (characters) -MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Prevent unbounded growth - -# NATS subjects (ai.* schema) -# Per-user channels matching companions-frontend pattern -REQUEST_SUBJECT = "ai.voice.user.*.request" # Wildcard subscription for all users -PREMIUM_REQUEST_SUBJECT = "ai.voice.premium.user.*.request" # Premium users -RESPONSE_SUBJECT = "ai.voice.response" # Response published to specific request_id -STREAM_RESPONSE_SUBJECT = "ai.voice.response.stream" # Streaming responses (token chunks) - -# System prompt for the assistant -SYSTEM_PROMPT = """You are a helpful voice assistant. -Answer questions based on the provided context when available. -Keep responses concise and natural for speech synthesis. -If you don't know the answer, say so clearly.""" + # Response settings + include_transcription: bool = True + include_sources: bool = False -class VoiceAssistant: +class VoiceAssistant(Handler): + """ + Voice request handler with full STT -> RAG -> LLM -> TTS pipeline. 
+ + Request format (msgpack): + { + "request_id": "uuid", + "audio": "base64 encoded audio", + "language": "optional language code", + "collection": "optional collection name" + } + + Response format: + { + "request_id": "uuid", + "transcription": "what the user said", + "response": "generated text response", + "audio": "base64 encoded response audio" + } + """ + def __init__(self): - self.nc = None - self.http_client = None - self.collection = None - self.valkey_client = None - self.running = True - self.tracer = None - self.meter = None - self.request_counter = None - self.request_duration = None - self.stt_duration = None - self.tts_duration = None - # MLflow inference tracker - self.mlflow_tracker = None - - async def setup(self): - """Initialize all connections.""" - # Initialize OpenTelemetry - self.tracer, self.meter = setup_telemetry() + self.voice_settings = VoiceSettings() + super().__init__( + subject="voice.request", + settings=self.voice_settings, + queue_group="voice-assistants", + ) + + async def setup(self) -> None: + """Initialize service clients.""" + logger.info("Initializing voice assistant clients...") - # Setup metrics - if self.meter: - self.request_counter = self.meter.create_counter( - "voice.requests", - description="Number of voice requests processed", - unit="1" - ) - self.request_duration = self.meter.create_histogram( - "voice.request_duration", - description="Duration of voice request processing", - unit="s" - ) - self.stt_duration = self.meter.create_histogram( - "voice.stt_duration", - description="Duration of speech-to-text processing", - unit="s" - ) - self.tts_duration = self.meter.create_histogram( - "voice.tts_duration", - description="Duration of text-to-speech processing", - unit="s" - ) + self.stt = STTClient(self.voice_settings) + self.embeddings = EmbeddingsClient(self.voice_settings) + self.reranker = RerankerClient(self.voice_settings) + self.llm = LLMClient(self.voice_settings) + self.tts = 
TTSClient(self.voice_settings) + self.milvus = MilvusClient(self.voice_settings) - # Initialize MLflow inference tracker - if MLFLOW_ENABLED and MLFLOW_AVAILABLE: - try: - self.mlflow_tracker = InferenceMetricsTracker( - service_name="voice-assistant", - experiment_name="voice-inference", - tracking_uri=MLFLOW_TRACKING_URI, - batch_size=50, - flush_interval_seconds=60.0, - ) - await self.mlflow_tracker.start() - logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}") - except Exception as e: - logger.warning(f"MLflow initialization failed: {e}, tracking disabled") - self.mlflow_tracker = None - elif not MLFLOW_AVAILABLE: - logger.info("MLflow utils not available, inference tracking disabled") - else: - logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false") + await self.milvus.connect(self.voice_settings.rag_collection) - # NATS connection - self.nc = await nats.connect(NATS_URL) - logger.info(f"Connected to NATS at {NATS_URL}") - - # HTTP client for services - self.http_client = httpx.AsyncClient(timeout=180.0) - - # Connect to Valkey for conversation history and context caching - try: - self.valkey_client = redis.from_url( - VALKEY_URL, - encoding="utf-8", - decode_responses=True, - socket_connect_timeout=5 - ) - await self.valkey_client.ping() - logger.info(f"Connected to Valkey at {VALKEY_URL}") - except Exception as e: - logger.warning(f"Valkey connection failed: {e}, conversation history disabled") - self.valkey_client = None - - # Connect to Milvus if collection exists - try: - connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) - if utility.has_collection(COLLECTION_NAME): - self.collection = Collection(COLLECTION_NAME) - self.collection.load() - logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}") - else: - logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled") - except Exception as e: - logger.warning(f"Milvus connection failed: {e}, RAG disabled") - - async def transcribe(self, audio_b64: str) 
-> str: - """Transcribe audio using Whisper.""" - try: - audio_bytes = base64.b64decode(audio_b64) - files = {"file": ("audio.wav", audio_bytes, "audio/wav")} - response = await self.http_client.post( - f"{WHISPER_URL}/v1/audio/transcriptions", files=files - ) - result = response.json() - transcript = result.get("text", "") - logger.info(f"Transcribed: {transcript[:100]}...") - return transcript - except Exception as e: - logger.error(f"Transcription failed: {e}") - return "" - - async def get_embeddings(self, texts: List[str]) -> List[List[float]]: - """Get embeddings from the embedding service.""" - try: - response = await self.http_client.post( - f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"} - ) - result = response.json() - return [d["embedding"] for d in result.get("data", [])] - except Exception as e: - logger.error(f"Embedding failed: {e}") - return [] - - async def search_milvus( - self, query_embedding: List[float], top_k: int = 5 - ) -> List[Dict]: - """Search Milvus for relevant documents.""" - if not self.collection: - return [] - try: - results = self.collection.search( - data=[query_embedding], - anns_field="embedding", - param={"metric_type": "COSINE", "params": {"ef": 64}}, - limit=top_k, - output_fields=["text", "book_name", "page_num"], - ) - docs = [] - for hits in results: - for hit in hits: - docs.append( - { - "text": hit.entity.get("text", ""), - "source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}', - "score": hit.score, - } - ) - return docs - except Exception as e: - logger.error(f"Milvus search failed: {e}") - return [] - - async def rerank(self, query: str, documents: List[str]) -> List[Dict]: - """Rerank documents using the reranker service.""" - if not documents: - return [] - try: - response = await self.http_client.post( - f"{RERANKER_URL}/v1/rerank", - json={"query": query, "documents": documents}, - ) - return response.json().get("results", []) - except Exception as e: - 
logger.error(f"Reranking failed: {e}") - return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))] - - async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]: - """Retrieve conversation history from Valkey.""" - if not self.valkey_client or not session_id: - return [] - try: - key = f"voice:history:{session_id}" - # Get the most recent messages (stored as a list) - history_json = await self.valkey_client.lrange(key, -max_messages, -1) - history = [json.loads(msg) for msg in history_json] - logger.info(f"Retrieved {len(history)} messages from history for session {session_id}") - return history - except Exception as e: - logger.error(f"Failed to get conversation history: {e}") - return [] - - async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600): - """Save a message to conversation history in Valkey.""" - if not self.valkey_client or not session_id: - return - try: - key = f"voice:history:{session_id}" - message = json.dumps({"role": role, "content": content, "timestamp": time.time()}) - # Use RPUSH to append to the list - await self.valkey_client.rpush(key, message) - # Set TTL on the key (1 hour by default) - await self.valkey_client.expire(key, ttl) - logger.debug(f"Saved {role} message to history for session {session_id}") - except Exception as e: - logger.error(f"Failed to save message to history: {e}") - - async def get_context_window(self, session_id: str) -> Optional[str]: - """Retrieve cached context window from Valkey for attention offloading.""" - if not self.valkey_client or not session_id: - return None - try: - key = f"voice:context:{session_id}" - context = await self.valkey_client.get(key) - if context: - logger.info(f"Retrieved cached context window for session {session_id}") - return context - except Exception as e: - logger.error(f"Failed to get context window: {e}") - return None - - async def save_context_window(self, session_id: str, context: 
str, ttl: int = 3600): - """Save context window to Valkey for attention offloading.""" - if not self.valkey_client or not session_id: - return - try: - key = f"voice:context:{session_id}" - await self.valkey_client.set(key, context, ex=ttl) - logger.debug(f"Saved context window for session {session_id}") - except Exception as e: - logger.error(f"Failed to save context window: {e}") - - async def generate_response(self, query: str, context: str = "", session_id: str = None) -> str: - """Generate response using vLLM with conversation history from Valkey.""" - try: - messages = [{"role": "system", "content": SYSTEM_PROMPT}] - - # Add conversation history from Valkey if session exists - if session_id: - history = await self.get_conversation_history(session_id) - messages.extend(history) - - if context: - messages.append( - { - "role": "user", - "content": f"Context:\n{context}\n\nQuestion: {query}", - } - ) - else: - messages.append({"role": "user", "content": query}) - - response = await self.http_client.post( - f"{VLLM_URL}/v1/chat/completions", - json={ - "model": LLM_MODEL, - "messages": messages, - "max_tokens": 500, - "temperature": 0.7, - }, - ) - result = response.json() - answer = result["choices"][0]["message"]["content"] - logger.info(f"Generated response: {answer[:100]}...") - - # Save messages to conversation history - if session_id: - await self.save_message_to_history(session_id, "user", query) - await self.save_message_to_history(session_id, "assistant", answer) - - return answer - except Exception as e: - logger.error(f"LLM generation failed: {e}") - return "I'm sorry, I couldn't generate a response." - - async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: str = None): - """Generate streaming response using vLLM and publish chunks to NATS. 
+ logger.info("Voice assistant clients initialized") + + async def teardown(self) -> None: + """Clean up service clients.""" + logger.info("Closing voice assistant clients...") - Yields tokens as they are generated and publishes them to NATS streaming subject. - Returns the complete response text. - """ - try: - messages = [{"role": "system", "content": SYSTEM_PROMPT}] - - # Add conversation history from Valkey if session exists - if session_id: - history = await self.get_conversation_history(session_id) - messages.extend(history) - - if context: - messages.append( - { - "role": "user", - "content": f"Context:\n{context}\n\nQuestion: {query}", - } - ) - else: - messages.append({"role": "user", "content": query}) - - full_response = "" - - # Stream response from vLLM - async with self.http_client.stream( - "POST", - f"{VLLM_URL}/v1/chat/completions", - json={ - "model": LLM_MODEL, - "messages": messages, - "max_tokens": 500, - "temperature": 0.7, - "stream": True, - }, - timeout=60.0, - ) as response: - # Parse SSE (Server-Sent Events) stream - async for line in response.aiter_lines(): - if not line or not line.startswith("data: "): - continue - - data_str = line[6:] # Remove "data: " prefix - if data_str.strip() == "[DONE]": - break - - try: - chunk_data = json.loads(data_str) - - # Extract token from delta - if chunk_data.get("choices") and len(chunk_data["choices"]) > 0: - delta = chunk_data["choices"][0].get("delta", {}) - content = delta.get("content", "") - - if content: - full_response += content - - # Publish token chunk to NATS streaming subject - chunk_msg = { - "request_id": request_id, - "type": "chunk", - "content": content, - "done": False, - } - await self.nc.publish( - f"{STREAM_RESPONSE_SUBJECT}.{request_id}", - msgpack.packb(chunk_msg) - ) - except json.JSONDecodeError: - continue - - # Send completion message - completion_msg = { - "request_id": request_id, - "type": "done", - "content": "", - "done": True, - } - await self.nc.publish( - 
f"{STREAM_RESPONSE_SUBJECT}.{request_id}", - msgpack.packb(completion_msg) - ) - - logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}") - - # Save messages to conversation history - if session_id: - await self.save_message_to_history(session_id, "user", query) - await self.save_message_to_history(session_id, "assistant", full_response) - - return full_response - - except Exception as e: - logger.error(f"Streaming LLM generation failed: {e}") - # Send error message - error_msg = { - "request_id": request_id, - "type": "error", - "content": "I'm sorry, I couldn't generate a response.", - "done": True, - "error": str(e), - } - await self.nc.publish( - f"{STREAM_RESPONSE_SUBJECT}.{request_id}", - msgpack.packb(error_msg) - ) - return "I'm sorry, I couldn't generate a response." - - async def synthesize_speech(self, text: str, language: str = "en") -> str: - """Convert text to speech using XTTS (Coqui TTS).""" - try: - # XTTS API endpoint - uses /api/tts for synthesis - # The Coqui TTS server API accepts text and returns wav audio - response = await self.http_client.get( - f"{TTS_URL}/api/tts", - params={ - "text": text, - "language_id": language, - # Optional: specify speaker_id for multi-speaker models - }, - ) - if response.status_code == 200: - audio_b64 = base64.b64encode(response.content).decode("utf-8") - logger.info(f"Synthesized {len(response.content)} bytes of audio") - return audio_b64 - else: - logger.error( - f"TTS returned status {response.status_code}: {response.text}" - ) - return "" - except Exception as e: - logger.error(f"TTS failed: {e}") - return "" - - async def process_request(self, msg, is_premium=False): - """Process a voice assistant request.""" - start_time = time.time() - span = None + await self.stt.close() + await self.embeddings.close() + await self.reranker.close() + await self.llm.close() + await self.tts.close() + await self.milvus.close() - # MLflow metrics tracking - mlflow_metrics = None - 
stt_start = None - embedding_start = None - rag_start = None - rerank_start = None - llm_start = None - tts_start = None + logger.info("Voice assistant clients closed") + + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: + """Handle incoming voice request.""" + request_id = data.get("request_id", "unknown") + audio_b64 = data.get("audio", "") + language = data.get("language", self.voice_settings.stt_language) + collection = data.get("collection", self.voice_settings.rag_collection) - try: - data = msgpack.unpackb(msg.data, raw=False) - request_id = data.get("request_id", "unknown") - audio_b64 = data.get("audio_b64", "") - user_id = data.get("user_id") - - # Initialize MLflow metrics if available - if self.mlflow_tracker and MLFLOW_AVAILABLE: - mlflow_metrics = InferenceMetrics( - request_id=request_id, - user_id=user_id, - session_id=data.get("session_id"), - model_name=LLM_MODEL, - model_endpoint=VLLM_URL, - ) - - # Start tracing span - if self.tracer: - span = self.tracer.start_span("voice.process_request") - span.set_attribute("request_id", request_id) - span.set_attribute("user_id", user_id or "anonymous") - span.set_attribute("premium", is_premium) - - # Support both new parameters and backward compatibility with use_rag - use_rag = data.get("use_rag") # Legacy parameter - enable_rag = data.get( - "enable_rag", use_rag if use_rag is not None else True - ) - enable_reranker = data.get( - "enable_reranker", use_rag if use_rag is not None else True - ) - enable_streaming = data.get("enable_streaming", False) # New parameter for streaming - - # Premium channel retrieves more documents for deeper RAG - default_top_k = 15 if is_premium else 5 - top_k = data.get("top_k", default_top_k) - - language = data.get("language", "en") - session_id = data.get("session_id") - - # Update MLflow metrics with request params - if mlflow_metrics: - mlflow_metrics.rag_enabled = enable_rag - mlflow_metrics.reranker_enabled = enable_reranker - 
mlflow_metrics.is_streaming = enable_streaming - mlflow_metrics.is_premium = is_premium - - # Add attributes to span + logger.info(f"Processing voice request {request_id}") + + with create_span("voice.process") as span: if span: - span.set_attribute("enable_rag", enable_rag) - span.set_attribute("enable_reranker", enable_reranker) - span.set_attribute("enable_streaming", enable_streaming) - span.set_attribute("top_k", top_k) - - logger.info( - f"Processing {'premium ' if is_premium else ''}voice request {request_id} (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})" - ) - - # Warn if reranker is enabled without RAG - if enable_reranker and not enable_rag: - logger.warning( - f"Request {request_id}: Reranker enabled without RAG - no documents to rerank" - ) - - # Step 1: Transcribe audio - stt_start = time.time() - transcript = await self.transcribe(audio_b64) - if mlflow_metrics: - mlflow_metrics.stt_latency = time.time() - stt_start - mlflow_metrics.prompt_length = len(transcript) if transcript else 0 + span.set_attribute("request.id", request_id) - if not transcript: - if mlflow_metrics: - mlflow_metrics.has_error = True - mlflow_metrics.error_message = "Transcription failed" - await self.publish_error(request_id, "Transcription failed") - return - - context = "" - rag_sources = [] - docs = [] - - # Step 2: RAG retrieval (if enabled) - if enable_rag and self.collection: - # Get embeddings - embedding_start = time.time() - embeddings = await self.get_embeddings([transcript]) - if mlflow_metrics: - mlflow_metrics.embedding_latency = time.time() - embedding_start - - if embeddings: - # Search Milvus with configurable top_k - rag_start = time.time() - docs = await self.search_milvus(embeddings[0], top_k=top_k) - if mlflow_metrics: - mlflow_metrics.rag_search_latency = time.time() - rag_start - mlflow_metrics.rag_documents_retrieved = len(docs) - if docs: - rag_sources = [d.get("source", "") for d in docs] - - # Step 3: Reranking (if enabled and we have 
documents) - if enable_reranker and docs: - # Rerank documents - rerank_start = time.time() - doc_texts = [d["text"] for d in docs] - reranked = await self.rerank(transcript, doc_texts) - if mlflow_metrics: - mlflow_metrics.rerank_latency = time.time() - rerank_start - # Take top 3 reranked documents with bounds checking - sorted_docs = sorted( - reranked, key=lambda x: x.get("relevance_score", 0), reverse=True - )[:3] - # Build context with bounds checking - # Note: doc_texts and docs have the same length (doc_texts derived from docs) - context_parts = [] - sources = [] - for item in sorted_docs: - idx = item.get("index", -1) - if 0 <= idx < len(docs): - context_parts.append(doc_texts[idx]) - sources.append(docs[idx].get("source", "")) - else: - logger.warning( - f"Reranker returned invalid index {idx} for {len(docs)} docs" - ) - context = "\n\n".join(context_parts) - rag_sources = sources - elif docs: - # Use documents without reranking (take top 3) - doc_texts = [d["text"] for d in docs[:3]] - context = "\n\n".join(doc_texts) - rag_sources = [d.get("source", "") for d in docs[:3]] - - # Step 4: Generate response (streaming or non-streaming) - # Check for cached context window from Valkey (for attention offloading) - cached_context = None - if session_id: - cached_context = await self.get_context_window(session_id) + # 1. 
Decode audio + audio_bytes = base64.b64decode(audio_b64) - # Combine RAG context with cached context if available - if cached_context and context: - # Prepend cached context to current RAG context - combined_context = f"{cached_context}\n\n{context}" - # Truncate to prevent unbounded growth - if len(combined_context) > MAX_CONTEXT_LENGTH: - logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating") - # Keep the most recent context (from the end) - combined_context = combined_context[-MAX_CONTEXT_LENGTH:] - context = combined_context - elif cached_context: - # Only cached context, still need to check length - if len(cached_context) > MAX_CONTEXT_LENGTH: - logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating") - cached_context = cached_context[-MAX_CONTEXT_LENGTH:] - context = cached_context + # 2. Transcribe audio to text + transcription = await self._transcribe(audio_bytes, language) + query = transcription.get("text", "") - # Save the combined context for future use (already truncated if needed) - if session_id and context: - await self.save_context_window(session_id, context) + if not query.strip(): + logger.warning(f"Empty transcription for request {request_id}") + return { + "request_id": request_id, + "error": "Could not transcribe audio", + } - # Track number of RAG docs used after reranking - if mlflow_metrics and enable_rag: - mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0 + logger.info(f"Transcribed: {query[:50]}...") - llm_start = time.time() - if enable_streaming: - # Use streaming response - answer = await self.generate_response_streaming(transcript, context, request_id, session_id) - else: - # Use non-streaming response - answer = await self.generate_response(transcript, context, session_id) + # 3. 
Generate query embedding + embedding = await self._get_embedding(query) - if mlflow_metrics: - mlflow_metrics.llm_latency = time.time() - llm_start - mlflow_metrics.response_length = len(answer) - # Estimate token counts (rough approximation: 4 chars per token) - mlflow_metrics.input_tokens = len(transcript) // 4 - mlflow_metrics.output_tokens = len(answer) // 4 - mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens - - # Step 5: Synthesize speech - tts_start = time.time() - audio_response = await self.synthesize_speech(answer, language) - if mlflow_metrics: - mlflow_metrics.tts_latency = time.time() - tts_start - - # Publish result + # 4. Search Milvus for context + documents = await self._search_context(embedding, collection) + + # 5. Rerank documents + reranked = await self._rerank_documents(query, documents) + + # 6. Build context + context = self._build_context(reranked) + + # 7. Generate LLM response + response_text = await self._generate_response(query, context) + + # 8. 
Synthesize speech + response_audio = await self._synthesize_speech(response_text) + + # Build response result = { "request_id": request_id, - "user_id": user_id, - "transcript": transcript, - "response_text": answer, - "audio_b64": audio_response, - "used_rag": bool(context), - "rag_enabled": enable_rag, - "reranker_enabled": enable_reranker, - "rag_sources": rag_sources, - "success": True, + "response": response_text, + "audio": response_audio, } - await self.nc.publish( - f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result) + + if self.voice_settings.include_transcription: + result["transcription"] = query + + if self.voice_settings.include_sources: + result["sources"] = [ + {"text": d["document"][:200], "score": d["score"]} + for d in reranked[:3] + ] + + logger.info(f"Completed voice request {request_id}") + + # Publish to response subject + response_subject = f"voice.response.{request_id}" + await self.nats.publish(response_subject, result) + + return result + + async def _transcribe( + self, audio: bytes, language: Optional[str] + ) -> dict: + """Transcribe audio to text.""" + with create_span("voice.stt"): + return await self.stt.transcribe(audio, language=language) + + async def _get_embedding(self, text: str) -> list[float]: + """Generate embedding for query text.""" + with create_span("voice.embedding"): + return await self.embeddings.embed_single(text) + + async def _search_context( + self, embedding: list[float], collection: str + ) -> list[dict]: + """Search Milvus for relevant documents.""" + with create_span("voice.search"): + return await self.milvus.search_with_texts( + embedding, + limit=self.voice_settings.rag_top_k, + text_field="text", ) - logger.info(f"Published response for request {request_id}") - - # Record metrics - duration = time.time() - start_time - if self.request_counter: - self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"}) - if self.request_duration: - 
self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)}) - if span: - span.set_attribute("success", True) - span.set_attribute("response_length", len(answer)) - span.set_attribute("transcript_length", len(transcript)) - - # Log to MLflow - if self.mlflow_tracker and mlflow_metrics: - mlflow_metrics.total_latency = duration - await self.mlflow_tracker.log_inference(mlflow_metrics) - - except Exception as e: - logger.error(f"Request processing failed: {e}") - if self.request_counter: - self.request_counter.add(1, {"premium": str(is_premium), "success": "false"}) - if span: - span.set_attribute("success", False) - span.set_attribute("error", str(e)) - - # Log error to MLflow - if self.mlflow_tracker and mlflow_metrics: - mlflow_metrics.has_error = True - mlflow_metrics.error_message = str(e) - mlflow_metrics.total_latency = time.time() - start_time - await self.mlflow_tracker.log_inference(mlflow_metrics) - - await self.publish_error(data.get("request_id", "unknown"), str(e)) - finally: - if span: - span.end() - - async def publish_error(self, request_id: str, error: str): - """Publish an error response.""" - result = {"request_id": request_id, "error": error, "success": False} - await self.nc.publish( - f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result) - ) - - async def process_premium_request(self, msg): - """Process a premium voice request (wrapper for deeper RAG).""" - await self.process_request(msg, is_premium=True) - - async def run(self): - """Main run loop.""" - await self.setup() - - # Subscribe to standard voice requests - sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request) - logger.info(f"Subscribed to {REQUEST_SUBJECT}") - - # Subscribe to premium voice requests (deeper RAG retrieval) - premium_sub = await self.nc.subscribe( - PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request - ) - logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}") - - # Handle shutdown - def 
signal_handler(): - self.running = False - - loop = asyncio.get_event_loop() - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, signal_handler) - - # Keep running - while self.running: - await asyncio.sleep(1) - - # Cleanup - await sub.unsubscribe() - await premium_sub.unsubscribe() - await self.nc.close() - if self.valkey_client: - await self.valkey_client.close() - if self.collection: - connections.disconnect("default") - if self.mlflow_tracker: - await self.mlflow_tracker.stop() - logger.info("Shutdown complete") + + async def _rerank_documents( + self, query: str, documents: list[dict] + ) -> list[dict]: + """Rerank documents by relevance.""" + with create_span("voice.rerank"): + texts = [d.get("text", "") for d in documents] + return await self.reranker.rerank( + query, texts, top_k=self.voice_settings.rag_rerank_top_k + ) + + def _build_context(self, documents: list[dict]) -> str: + """Build context string from ranked documents.""" + return "\n\n".join(d.get("document", "") for d in documents) + + async def _generate_response(self, query: str, context: str) -> str: + """Generate LLM response.""" + with create_span("voice.generate"): + return await self.llm.generate(query, context=context) + + async def _synthesize_speech(self, text: str) -> str: + """Synthesize speech and return base64.""" + with create_span("voice.tts"): + audio_bytes = await self.tts.synthesize( + text, language=self.voice_settings.tts_language + ) + return base64.b64encode(audio_bytes).decode() if __name__ == "__main__": - assistant = VoiceAssistant() - asyncio.run(assistant.run()) + VoiceAssistant().run() diff --git a/voice_assistant_v2.py b/voice_assistant_v2.py deleted file mode 100644 index de04ead..0000000 --- a/voice_assistant_v2.py +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env python3 -""" -Voice Assistant Service (Refactored) - -End-to-end voice assistant pipeline using handler-base: -1. Listen for audio on NATS subject "voice.request" -2. 
Transcribe with Whisper (STT) -3. Generate embeddings for RAG -4. Retrieve context from Milvus -5. Rerank with BGE reranker -6. Generate response with vLLM -7. Synthesize speech with XTTS -8. Publish result to NATS "voice.response.{request_id}" -""" -import base64 -import logging -from typing import Any, Optional - -from nats.aio.msg import Msg - -from handler_base import Handler, Settings -from handler_base.clients import ( - EmbeddingsClient, - RerankerClient, - LLMClient, - TTSClient, - STTClient, - MilvusClient, -) -from handler_base.telemetry import create_span - -logger = logging.getLogger("voice-assistant") - - -class VoiceSettings(Settings): - """Voice assistant specific settings.""" - - service_name: str = "voice-assistant" - - # RAG settings - rag_top_k: int = 10 - rag_rerank_top_k: int = 5 - rag_collection: str = "documents" - - # Audio settings - stt_language: Optional[str] = None # Auto-detect - tts_language: str = "en" - - # Response settings - include_transcription: bool = True - include_sources: bool = False - - -class VoiceAssistant(Handler): - """ - Voice request handler with full STT -> RAG -> LLM -> TTS pipeline. 
- - Request format (msgpack): - { - "request_id": "uuid", - "audio": "base64 encoded audio", - "language": "optional language code", - "collection": "optional collection name" - } - - Response format: - { - "request_id": "uuid", - "transcription": "what the user said", - "response": "generated text response", - "audio": "base64 encoded response audio" - } - """ - - def __init__(self): - self.voice_settings = VoiceSettings() - super().__init__( - subject="voice.request", - settings=self.voice_settings, - queue_group="voice-assistants", - ) - - async def setup(self) -> None: - """Initialize service clients.""" - logger.info("Initializing voice assistant clients...") - - self.stt = STTClient(self.voice_settings) - self.embeddings = EmbeddingsClient(self.voice_settings) - self.reranker = RerankerClient(self.voice_settings) - self.llm = LLMClient(self.voice_settings) - self.tts = TTSClient(self.voice_settings) - self.milvus = MilvusClient(self.voice_settings) - - await self.milvus.connect(self.voice_settings.rag_collection) - - logger.info("Voice assistant clients initialized") - - async def teardown(self) -> None: - """Clean up service clients.""" - logger.info("Closing voice assistant clients...") - - await self.stt.close() - await self.embeddings.close() - await self.reranker.close() - await self.llm.close() - await self.tts.close() - await self.milvus.close() - - logger.info("Voice assistant clients closed") - - async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: - """Handle incoming voice request.""" - request_id = data.get("request_id", "unknown") - audio_b64 = data.get("audio", "") - language = data.get("language", self.voice_settings.stt_language) - collection = data.get("collection", self.voice_settings.rag_collection) - - logger.info(f"Processing voice request {request_id}") - - with create_span("voice.process") as span: - if span: - span.set_attribute("request.id", request_id) - - # 1. 
Decode audio - audio_bytes = base64.b64decode(audio_b64) - - # 2. Transcribe audio to text - transcription = await self._transcribe(audio_bytes, language) - query = transcription.get("text", "") - - if not query.strip(): - logger.warning(f"Empty transcription for request {request_id}") - return { - "request_id": request_id, - "error": "Could not transcribe audio", - } - - logger.info(f"Transcribed: {query[:50]}...") - - # 3. Generate query embedding - embedding = await self._get_embedding(query) - - # 4. Search Milvus for context - documents = await self._search_context(embedding, collection) - - # 5. Rerank documents - reranked = await self._rerank_documents(query, documents) - - # 6. Build context - context = self._build_context(reranked) - - # 7. Generate LLM response - response_text = await self._generate_response(query, context) - - # 8. Synthesize speech - response_audio = await self._synthesize_speech(response_text) - - # Build response - result = { - "request_id": request_id, - "response": response_text, - "audio": response_audio, - } - - if self.voice_settings.include_transcription: - result["transcription"] = query - - if self.voice_settings.include_sources: - result["sources"] = [ - {"text": d["document"][:200], "score": d["score"]} - for d in reranked[:3] - ] - - logger.info(f"Completed voice request {request_id}") - - # Publish to response subject - response_subject = f"voice.response.{request_id}" - await self.nats.publish(response_subject, result) - - return result - - async def _transcribe( - self, audio: bytes, language: Optional[str] - ) -> dict: - """Transcribe audio to text.""" - with create_span("voice.stt"): - return await self.stt.transcribe(audio, language=language) - - async def _get_embedding(self, text: str) -> list[float]: - """Generate embedding for query text.""" - with create_span("voice.embedding"): - return await self.embeddings.embed_single(text) - - async def _search_context( - self, embedding: list[float], collection: str - ) -> 
list[dict]: - """Search Milvus for relevant documents.""" - with create_span("voice.search"): - return await self.milvus.search_with_texts( - embedding, - limit=self.voice_settings.rag_top_k, - text_field="text", - ) - - async def _rerank_documents( - self, query: str, documents: list[dict] - ) -> list[dict]: - """Rerank documents by relevance.""" - with create_span("voice.rerank"): - texts = [d.get("text", "") for d in documents] - return await self.reranker.rerank( - query, texts, top_k=self.voice_settings.rag_rerank_top_k - ) - - def _build_context(self, documents: list[dict]) -> str: - """Build context string from ranked documents.""" - return "\n\n".join(d.get("document", "") for d in documents) - - async def _generate_response(self, query: str, context: str) -> str: - """Generate LLM response.""" - with create_span("voice.generate"): - return await self.llm.generate(query, context=context) - - async def _synthesize_speech(self, text: str) -> str: - """Synthesize speech and return base64.""" - with create_span("voice.tts"): - audio_bytes = await self.tts.synthesize( - text, language=self.voice_settings.tts_language - ) - return base64.b64encode(audio_bytes).decode() - - -if __name__ == "__main__": - VoiceAssistant().run()