diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..838a01b --- /dev/null +++ b/.gitignore @@ -0,0 +1,42 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +.venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Compiled KFP pipelines +*.yaml +!pipelines/*.py + +# Local +.env +.env.local +*.log diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..80bab16 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.13-slim + +WORKDIR /app + +# Install uv for fast, reliable package management +COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first for better caching +COPY requirements.txt . +RUN uv pip install --system --no-cache -r requirements.txt + +# Copy application code +COPY voice_assistant.py . + +# Set environment variables +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python -c "print('healthy')" || exit 1 + +# Run the application +CMD ["python", "voice_assistant.py"] diff --git a/Dockerfile.v2 b/Dockerfile.v2 new file mode 100644 index 0000000..6a0d2d3 --- /dev/null +++ b/Dockerfile.v2 @@ -0,0 +1,9 @@ +# Voice Assistant v2 - Using handler-base with audio support +ARG BASE_TAG=local-audio +FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} + +WORKDIR /app + +COPY voice_assistant_v2.py ./voice_assistant.py + +CMD ["python", "voice_assistant.py"] diff --git a/README.md b/README.md index c3ba671..6d1b6b3 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,120 @@ -# voice-assistant +# Voice Assistant -voice assistance please \ No newline at end of file +End-to-end voice assistant pipeline for the DaviesTechLabs AI/ML platform. + +## Components + +### Real-time Handler (NATS-based) + +The voice assistant service listens on NATS for audio requests and returns synthesized speech responses. + +**Pipeline:** STT → Embeddings → Milvus RAG → Rerank → LLM → TTS + +| File | Description | +|------|-------------| +| `voice_assistant.py` | Standalone handler (v1) | +| `voice_assistant_v2.py` | Handler using handler-base library | +| `Dockerfile` | Standalone image | +| `Dockerfile.v2` | Handler-base image | + +### Kubeflow Pipeline (Batch) + +For batch processing or async workflows via Kubeflow Pipelines. + +| Pipeline | Description | +|----------|-------------| +| `voice_pipeline.yaml` | Full STT → RAG → TTS pipeline | +| `rag_pipeline.yaml` | Text-only RAG pipeline | +| `tts_pipeline.yaml` | Simple text-to-speech | + +```bash +# Compile pipelines +cd pipelines +pip install kfp==2.12.1 +python voice_pipeline.py +``` + +## Architecture + +``` +NATS (voice.request) + │ + ▼ +┌───────────────────┐ +│ Voice Assistant │ +│ Handler │ +└───────────────────┘ + │ + ├──▶ Whisper STT (elminster) + │ │ + │ ▼ + ├──▶ BGE Embeddings (drizzt) + │ │ + │ ▼ + ├──▶ Milvus Vector Search + │ │ + │ ▼ + ├──▶ BGE Reranker (danilo) + │ │ + │ ▼ + ├──▶ vLLM (khelben) + │ │ + │ ▼ + └──▶ XTTS TTS (elminster) + │ + ▼ + NATS (voice.response.{id}) +``` + +## Configuration + +| Environment Variable | Default | Description | +|---------------------|---------|-------------| +| `NATS_URL` | `nats://nats.ai-ml.svc.cluster.local:4222` | NATS server | +| `WHISPER_URL` | `http://whisper-predictor.ai-ml.svc.cluster.local` | STT service | +| `EMBEDDINGS_URL` | `http://embeddings-predictor.ai-ml.svc.cluster.local` | Embeddings | +| `RERANKER_URL` | `http://reranker-predictor.ai-ml.svc.cluster.local` | Reranker | +| `VLLM_URL` | `http://llm-draft.ai-ml.svc.cluster.local:8000` | LLM service | +| `TTS_URL` | `http://tts-predictor.ai-ml.svc.cluster.local` | TTS service | +| `MILVUS_HOST` | `milvus.ai-ml.svc.cluster.local` | Vector DB | +| `COLLECTION_NAME` | `knowledge_base` | Milvus collection | + +## NATS Message Format + +### Request (voice.request) + +```json +{ + "request_id": "uuid", + "audio": "base64-encoded-audio", + "language": "en", + "collection": "knowledge_base" +} +``` + +### Response (voice.response.{request_id}) + +```json +{ + "request_id": "uuid", + "transcription": "user question", + "response": "assistant answer", + "audio": "base64-encoded-audio" +} +``` + +## Building + +```bash +# Standalone image (v1) +docker build -f Dockerfile -t voice-assistant:latest . + +# Handler-base image (v2 - recommended) +docker build -f Dockerfile.v2 -t voice-assistant:v2 . +``` + +## Related + +- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs +- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images +- [handler-base](https://github.com/Billy-Davies-2/llm-workflows/tree/main/handler-base) - Base handler library diff --git a/pipelines/voice_pipeline.py b/pipelines/voice_pipeline.py new file mode 100644 index 0000000..886f6c0 --- /dev/null +++ b/pipelines/voice_pipeline.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +""" +Voice Pipeline - Kubeflow Pipelines SDK + +Compile this to create a Kubeflow Pipeline for voice assistant workflows. + +Usage: + pip install kfp==2.12.1 + python voice_pipeline.py + # Upload voice_pipeline.yaml to Kubeflow Pipelines UI +""" + +from kfp import dsl +from kfp import compiler + + +@dsl.component( + base_image="python:3.13-slim", + packages_to_install=["httpx"] +) +def transcribe_audio( + audio_b64: str, + whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local" +) -> str: + """Transcribe audio using Whisper STT service.""" + import base64 + import httpx + + audio_bytes = base64.b64decode(audio_b64) + + with httpx.Client(timeout=120.0) as client: + response = client.post( + f"{whisper_url}/v1/audio/transcriptions", + files={"file": ("audio.wav", audio_bytes, "audio/wav")}, + data={"model": "whisper-large-v3", "language": "en"} + ) + result = response.json() + + return result.get("text", "") + + +@dsl.component( + base_image="python:3.13-slim", + packages_to_install=["httpx"] +) +def generate_embeddings( + text: str, + embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local" +) -> list: + """Generate embeddings for RAG retrieval.""" + import httpx + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{embeddings_url}/embeddings", + json={"input": text, "model": "bge-small-en-v1.5"} + ) + result = response.json() + + return result["data"][0]["embedding"] + + +@dsl.component( + base_image="python:3.13-slim", + packages_to_install=["pymilvus"] +) +def retrieve_context( + embedding: list, + milvus_host: str = "milvus.ai-ml.svc.cluster.local", + collection_name: str = "knowledge_base", + top_k: int = 5 +) -> list: + """Retrieve relevant documents from Milvus vector database.""" + from pymilvus import connections, Collection, utility + + connections.connect(host=milvus_host, port=19530) + + if not utility.has_collection(collection_name): + return [] + + collection = Collection(collection_name) + collection.load() + + results = collection.search( + data=[embedding], + anns_field="embedding", + param={"metric_type": "COSINE", "params": {"nprobe": 10}}, + limit=top_k, + output_fields=["text", "source"] + ) + + documents = [] + for hits in results: + for hit in hits: + documents.append({ + "text": hit.entity.get("text"), + "source": hit.entity.get("source"), + "score": hit.distance + }) + + return documents + + +@dsl.component( + base_image="python:3.13-slim", + packages_to_install=["httpx"] +) +def rerank_documents( + query: str, + documents: list, + reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local", + top_k: int = 3 +) -> list: + """Rerank documents using BGE reranker.""" + import httpx + + if not documents: + return [] + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{reranker_url}/v1/rerank", + json={ + "query": query, + "documents": [doc["text"] for doc in documents], + "model": "bge-reranker-v2-m3" + } + ) + result = response.json() + + # Sort by rerank score + reranked = sorted( + zip(documents, result.get("scores", [0] * len(documents))), + key=lambda x: x[1], + reverse=True + )[:top_k] + + return [doc for doc, score in reranked] + + +@dsl.component( + base_image="python:3.13-slim", + packages_to_install=["httpx"] +) +def generate_response( + query: str, + context: list, + vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000", + model: str = "mistralai/Mistral-7B-Instruct-v0.3" +) -> str: + """Generate response using vLLM.""" + import httpx + + # Build context + if context: + context_text = "\n\n".join([doc["text"] for doc in context]) + user_content = f"Context:\n{context_text}\n\nQuestion: {query}" + else: + user_content = query + + system_prompt = """You are a helpful voice assistant. +Answer questions based on the provided context when available. +Keep responses concise and natural for speech synthesis.""" + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content} + ] + + with httpx.Client(timeout=180.0) as client: + response = client.post( + f"{vllm_url}/v1/chat/completions", + json={ + "model": model, + "messages": messages, + "max_tokens": 512, + "temperature": 0.7 + } + ) + result = response.json() + + return result["choices"][0]["message"]["content"] + + +@dsl.component( + base_image="python:3.13-slim", + packages_to_install=["httpx"] +) +def synthesize_speech( + text: str, + tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local" +) -> str: + """Convert text to speech using TTS service.""" + import base64 + import httpx + + with httpx.Client(timeout=120.0) as client: + response = client.post( + f"{tts_url}/v1/audio/speech", + json={ + "input": text, + "voice": "en_US-lessac-high", + "response_format": "wav" + } + ) + audio_b64 = base64.b64encode(response.content).decode("utf-8") + + return audio_b64 + + +@dsl.pipeline( + name="voice-assistant-rag-pipeline", + description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS" +) +def voice_assistant_pipeline( + audio_b64: str, + collection_name: str = "knowledge_base" +): + """ + Voice Assistant Pipeline with RAG + + Args: + audio_b64: Base64-encoded audio file + collection_name: Milvus collection for RAG + """ + + # Step 1: Transcribe audio with Whisper + transcribe_task = transcribe_audio(audio_b64=audio_b64) + transcribe_task.set_caching_options(enable_caching=False) + + # Step 2: Generate embeddings + embed_task = generate_embeddings(text=transcribe_task.output) + embed_task.set_caching_options(enable_caching=True) + + # Step 3: Retrieve context from Milvus + retrieve_task = retrieve_context( + embedding=embed_task.output, + collection_name=collection_name + ) + + # Step 4: Rerank documents + rerank_task = rerank_documents( + query=transcribe_task.output, + documents=retrieve_task.output + ) + + # Step 5: Generate response with context + llm_task = generate_response( + query=transcribe_task.output, + context=rerank_task.output + ) + + # Step 6: Synthesize speech + tts_task = synthesize_speech(text=llm_task.output) + + +@dsl.pipeline( + name="text-to-speech-pipeline", + description="Simple text to speech pipeline" +) +def text_to_speech_pipeline(text: str): + """Simple TTS pipeline for testing.""" + tts_task = synthesize_speech(text=text) + + +@dsl.pipeline( + name="rag-query-pipeline", + description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM" +) +def rag_query_pipeline( + query: str, + collection_name: str = "knowledge_base" +): + """ + RAG Query Pipeline (text input, no voice) + + Args: + query: Text query + collection_name: Milvus collection name + """ + # Embed the query + embed_task = generate_embeddings(text=query) + + # Retrieve from Milvus + retrieve_task = retrieve_context( + embedding=embed_task.output, + collection_name=collection_name + ) + + # Rerank + rerank_task = rerank_documents( + query=query, + documents=retrieve_task.output + ) + + # Generate response + llm_task = generate_response( + query=query, + context=rerank_task.output + ) + + +if __name__ == "__main__": + # Compile all pipelines + pipelines = [ + ("voice_pipeline.yaml", voice_assistant_pipeline), + ("tts_pipeline.yaml", text_to_speech_pipeline), + ("rag_pipeline.yaml", rag_query_pipeline), + ] + + for filename, pipeline_func in pipelines: + compiler.Compiler().compile(pipeline_func, filename) + print(f"Compiled: {filename}") + + print("\nUpload these YAML files to Kubeflow Pipelines UI at:") + print(" http://kubeflow.example.com/pipelines") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ef68ffa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +nats-py +httpx +pymilvus +numpy +msgpack +redis>=5.0.0 +opentelemetry-api +opentelemetry-sdk +opentelemetry-exporter-otlp-proto-grpc +opentelemetry-exporter-otlp-proto-http +opentelemetry-instrumentation-httpx +opentelemetry-instrumentation-logging +# MLflow for inference metrics tracking +mlflow>=2.10.0 +psycopg2-binary>=2.9.0 diff --git a/voice_assistant.py b/voice_assistant.py new file mode 100644 index 0000000..87c1587 --- /dev/null +++ b/voice_assistant.py @@ -0,0 +1,875 @@ +#!/usr/bin/env python3 +""" +Voice Assistant Service + +End-to-end voice assistant pipeline: +1. Listen for audio on NATS subject "voice.request" +2. Transcribe with Whisper (STT) +3. Generate embeddings for RAG +4. Retrieve context from Milvus +5. Rerank with BGE reranker +6. Generate response with vLLM +7. Synthesize speech with XTTS +8. Publish result to NATS "voice.response" +""" +import asyncio +import base64 +import json +import logging +import os +import signal +import subprocess +import sys +import time +from typing import List, Dict, Optional + +# Install dependencies on startup +subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-q", + "-r", "/app/requirements.txt" +]) + +import httpx +import msgpack +import nats +import redis.asyncio as redis +from pymilvus import connections, Collection, utility + +# OpenTelemetry imports +from opentelemetry import trace, metrics +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP +from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.instrumentation.logging import LoggingInstrumentor + +# MLflow inference tracking +try: + from mlflow_utils import InferenceMetricsTracker + from mlflow_utils.inference_tracker import InferenceMetrics + MLFLOW_AVAILABLE = True +except ImportError: + MLFLOW_AVAILABLE = False + InferenceMetricsTracker = None + InferenceMetrics = None + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger("voice-assistant") + + +def setup_telemetry(): + """Initialize OpenTelemetry tracing and metrics.""" + otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true" + if not otel_enabled: + logger.info("OpenTelemetry disabled") + return None, None + + otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317") + service_name = os.environ.get("OTEL_SERVICE_NAME", "voice-assistant") + service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml") + + # HyperDX configuration + hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "") + hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io") + use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key + + resource = Resource.create({ + SERVICE_NAME: service_name, + SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"), + SERVICE_NAMESPACE: service_namespace, + "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"), + "host.name": os.environ.get("HOSTNAME", "unknown"), + }) + + trace_provider = TracerProvider(resource=resource) + + if use_hyperdx: + logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}") + headers = {"authorization": hyperdx_api_key} + otlp_span_exporter = OTLPSpanExporterHTTP( + endpoint=f"{hyperdx_endpoint}/v1/traces", + headers=headers + ) + otlp_metric_exporter = OTLPMetricExporterHTTP( + endpoint=f"{hyperdx_endpoint}/v1/metrics", + headers=headers + ) + else: + otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True) + otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True) + + trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter)) + trace.set_tracer_provider(trace_provider) + + metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000) + meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(meter_provider) + + HTTPXClientInstrumentor().instrument() + LoggingInstrumentor().instrument(set_logging_format=True) + + destination = "HyperDX" if use_hyperdx else "OTEL Collector" + logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}") + + return trace.get_tracer(__name__), metrics.get_meter(__name__) + +# Configuration from environment +WHISPER_URL = os.environ.get( + "WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local" +) +TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local:5002") +EMBEDDINGS_URL = os.environ.get( + "EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local" +) +RERANKER_URL = os.environ.get( + "RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local" +) +VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000") +LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3") +MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local") +MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530")) +COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base") +NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") +VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379") + +# MLflow configuration +MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true" +MLFLOW_TRACKING_URI = os.environ.get( + "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80" +) + +# Context window limits (characters) +MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Prevent unbounded growth + +# NATS subjects (ai.* schema) +# Per-user channels matching companions-frontend pattern +REQUEST_SUBJECT = "ai.voice.user.*.request" # Wildcard subscription for all users +PREMIUM_REQUEST_SUBJECT = "ai.voice.premium.user.*.request" # Premium users +RESPONSE_SUBJECT = "ai.voice.response" # Response published to specific request_id +STREAM_RESPONSE_SUBJECT = "ai.voice.response.stream" # Streaming responses (token chunks) + +# System prompt for the assistant +SYSTEM_PROMPT = """You are a helpful voice assistant. +Answer questions based on the provided context when available. +Keep responses concise and natural for speech synthesis. +If you don't know the answer, say so clearly.""" + + +class VoiceAssistant: + def __init__(self): + self.nc = None + self.http_client = None + self.collection = None + self.valkey_client = None + self.running = True + self.tracer = None + self.meter = None + self.request_counter = None + self.request_duration = None + self.stt_duration = None + self.tts_duration = None + # MLflow inference tracker + self.mlflow_tracker = None + + async def setup(self): + """Initialize all connections.""" + # Initialize OpenTelemetry + self.tracer, self.meter = setup_telemetry() + + # Setup metrics + if self.meter: + self.request_counter = self.meter.create_counter( + "voice.requests", + description="Number of voice requests processed", + unit="1" + ) + self.request_duration = self.meter.create_histogram( + "voice.request_duration", + description="Duration of voice request processing", + unit="s" + ) + self.stt_duration = self.meter.create_histogram( + "voice.stt_duration", + description="Duration of speech-to-text processing", + unit="s" + ) + self.tts_duration = self.meter.create_histogram( + "voice.tts_duration", + description="Duration of text-to-speech processing", + unit="s" + ) + + # Initialize MLflow inference tracker + if MLFLOW_ENABLED and MLFLOW_AVAILABLE: + try: + self.mlflow_tracker = InferenceMetricsTracker( + service_name="voice-assistant", + experiment_name="voice-inference", + tracking_uri=MLFLOW_TRACKING_URI, + batch_size=50, + flush_interval_seconds=60.0, + ) + await self.mlflow_tracker.start() + logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}") + except Exception as e: + logger.warning(f"MLflow initialization failed: {e}, tracking disabled") + self.mlflow_tracker = None + elif not MLFLOW_AVAILABLE: + logger.info("MLflow utils not available, inference tracking disabled") + else: + logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false") + + # NATS connection + self.nc = await nats.connect(NATS_URL) + logger.info(f"Connected to NATS at {NATS_URL}") + + # HTTP client for services + self.http_client = httpx.AsyncClient(timeout=180.0) + + # Connect to Valkey for conversation history and context caching + try: + self.valkey_client = redis.from_url( + VALKEY_URL, + encoding="utf-8", + decode_responses=True, + socket_connect_timeout=5 + ) + await self.valkey_client.ping() + logger.info(f"Connected to Valkey at {VALKEY_URL}") + except Exception as e: + logger.warning(f"Valkey connection failed: {e}, conversation history disabled") + self.valkey_client = None + + # Connect to Milvus if collection exists + try: + connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) + if utility.has_collection(COLLECTION_NAME): + self.collection = Collection(COLLECTION_NAME) + self.collection.load() + logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}") + else: + logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled") + except Exception as e: + logger.warning(f"Milvus connection failed: {e}, RAG disabled") + + async def transcribe(self, audio_b64: str) -> str: + """Transcribe audio using Whisper.""" + try: + audio_bytes = base64.b64decode(audio_b64) + files = {"file": ("audio.wav", audio_bytes, "audio/wav")} + response = await self.http_client.post( + f"{WHISPER_URL}/v1/audio/transcriptions", files=files + ) + result = response.json() + transcript = result.get("text", "") + logger.info(f"Transcribed: {transcript[:100]}...") + return transcript + except Exception as e: + logger.error(f"Transcription failed: {e}") + return "" + + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get embeddings from the embedding service.""" + try: + response = await self.http_client.post( + f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"} + ) + result = response.json() + return [d["embedding"] for d in result.get("data", [])] + except Exception as e: + logger.error(f"Embedding failed: {e}") + return [] + + async def search_milvus( + self, query_embedding: List[float], top_k: int = 5 + ) -> List[Dict]: + """Search Milvus for relevant documents.""" + if not self.collection: + return [] + try: + results = self.collection.search( + data=[query_embedding], + anns_field="embedding", + param={"metric_type": "COSINE", "params": {"ef": 64}}, + limit=top_k, + output_fields=["text", "book_name", "page_num"], + ) + docs = [] + for hits in results: + for hit in hits: + docs.append( + { + "text": hit.entity.get("text", ""), + "source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}', + "score": hit.score, + } + ) + return docs + except Exception as e: + logger.error(f"Milvus search failed: {e}") + return [] + + async def rerank(self, query: str, documents: List[str]) -> List[Dict]: + """Rerank documents using the reranker service.""" + if not documents: + return [] + try: + response = await self.http_client.post( + f"{RERANKER_URL}/v1/rerank", + json={"query": query, "documents": documents}, + ) + return response.json().get("results", []) + except Exception as e: + logger.error(f"Reranking failed: {e}") + return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))] + + async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]: + """Retrieve conversation history from Valkey.""" + if not self.valkey_client or not session_id: + return [] + try: + key = f"voice:history:{session_id}" + # Get the most recent messages (stored as a list) + history_json = await self.valkey_client.lrange(key, -max_messages, -1) + history = [json.loads(msg) for msg in history_json] + logger.info(f"Retrieved {len(history)} messages from history for session {session_id}") + return history + except Exception as e: + logger.error(f"Failed to get conversation history: {e}") + return [] + + async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600): + """Save a message to conversation history in Valkey.""" + if not self.valkey_client or not session_id: + return + try: + key = f"voice:history:{session_id}" + message = json.dumps({"role": role, "content": content, "timestamp": time.time()}) + # Use RPUSH to append to the list + await self.valkey_client.rpush(key, message) + # Set TTL on the key (1 hour by default) + await self.valkey_client.expire(key, ttl) + logger.debug(f"Saved {role} message to history for session {session_id}") + except Exception as e: + logger.error(f"Failed to save message to history: {e}") + + async def get_context_window(self, session_id: str) -> Optional[str]: + """Retrieve cached context window from Valkey for attention offloading.""" + if not self.valkey_client or not session_id: + return None + try: + key = f"voice:context:{session_id}" + context = await self.valkey_client.get(key) + if context: + logger.info(f"Retrieved cached context window for session {session_id}") + return context + except Exception as e: + logger.error(f"Failed to get context window: {e}") + return None + + async def save_context_window(self, session_id: str, context: str, ttl: int = 3600): + """Save context window to Valkey for attention offloading.""" + if not self.valkey_client or not session_id: + return + try: + key = f"voice:context:{session_id}" + await self.valkey_client.set(key, context, ex=ttl) + logger.debug(f"Saved context window for session {session_id}") + except Exception as e: + logger.error(f"Failed to save context window: {e}") + + async def generate_response(self, query: str, context: str = "", session_id: str = None) -> str: + """Generate response using vLLM with conversation history from Valkey.""" + try: + messages = [{"role": "system", "content": SYSTEM_PROMPT}] + + # Add conversation history from Valkey if session exists + if session_id: + history = await self.get_conversation_history(session_id) + messages.extend(history) + + if context: + messages.append( + { + "role": "user", + "content": f"Context:\n{context}\n\nQuestion: {query}", + } + ) + else: + messages.append({"role": "user", "content": query}) + + response = await self.http_client.post( + f"{VLLM_URL}/v1/chat/completions", + json={ + "model": LLM_MODEL, + "messages": messages, + "max_tokens": 500, + "temperature": 0.7, + }, + ) + result = response.json() + answer = result["choices"][0]["message"]["content"] + logger.info(f"Generated response: {answer[:100]}...") + + # Save messages to conversation history + if session_id: + await self.save_message_to_history(session_id, "user", query) + await self.save_message_to_history(session_id, "assistant", answer) + + return answer + except Exception as e: + logger.error(f"LLM generation failed: {e}") + return "I'm sorry, I couldn't generate a response." + + async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: str = None): + """Generate streaming response using vLLM and publish chunks to NATS. + + Yields tokens as they are generated and publishes them to NATS streaming subject. + Returns the complete response text. + """ + try: + messages = [{"role": "system", "content": SYSTEM_PROMPT}] + + # Add conversation history from Valkey if session exists + if session_id: + history = await self.get_conversation_history(session_id) + messages.extend(history) + + if context: + messages.append( + { + "role": "user", + "content": f"Context:\n{context}\n\nQuestion: {query}", + } + ) + else: + messages.append({"role": "user", "content": query}) + + full_response = "" + + # Stream response from vLLM + async with self.http_client.stream( + "POST", + f"{VLLM_URL}/v1/chat/completions", + json={ + "model": LLM_MODEL, + "messages": messages, + "max_tokens": 500, + "temperature": 0.7, + "stream": True, + }, + timeout=60.0, + ) as response: + # Parse SSE (Server-Sent Events) stream + async for line in response.aiter_lines(): + if not line or not line.startswith("data: "): + continue + + data_str = line[6:] # Remove "data: " prefix + if data_str.strip() == "[DONE]": + break + + try: + chunk_data = json.loads(data_str) + + # Extract token from delta + if chunk_data.get("choices") and len(chunk_data["choices"]) > 0: + delta = chunk_data["choices"][0].get("delta", {}) + content = delta.get("content", "") + + if content: + full_response += content + + # Publish token chunk to NATS streaming subject + chunk_msg = { + "request_id": request_id, + "type": "chunk", + "content": content, + "done": False, + } + await self.nc.publish( + f"{STREAM_RESPONSE_SUBJECT}.{request_id}", + msgpack.packb(chunk_msg) + ) + except json.JSONDecodeError: + continue + + # Send completion message + completion_msg = { + "request_id": request_id, + "type": "done", + "content": "", + "done": True, + } + await self.nc.publish( + f"{STREAM_RESPONSE_SUBJECT}.{request_id}", + msgpack.packb(completion_msg) + ) + + logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}") + + # Save messages to conversation history + if session_id: + await self.save_message_to_history(session_id, "user", query) + await self.save_message_to_history(session_id, "assistant", full_response) + + return full_response + + except Exception as e: + logger.error(f"Streaming LLM generation failed: {e}") + # Send error message + error_msg = { + "request_id": request_id, + "type": "error", + "content": "I'm sorry, I couldn't generate a response.", + "done": True, + "error": str(e), + } + await self.nc.publish( + f"{STREAM_RESPONSE_SUBJECT}.{request_id}", + msgpack.packb(error_msg) + ) + return "I'm sorry, I couldn't generate a response." + + async def synthesize_speech(self, text: str, language: str = "en") -> str: + """Convert text to speech using XTTS (Coqui TTS).""" + try: + # XTTS API endpoint - uses /api/tts for synthesis + # The Coqui TTS server API accepts text and returns wav audio + response = await self.http_client.get( + f"{TTS_URL}/api/tts", + params={ + "text": text, + "language_id": language, + # Optional: specify speaker_id for multi-speaker models + }, + ) + if response.status_code == 200: + audio_b64 = base64.b64encode(response.content).decode("utf-8") + logger.info(f"Synthesized {len(response.content)} bytes of audio") + return audio_b64 + else: + logger.error( + f"TTS returned status {response.status_code}: {response.text}" + ) + return "" + except Exception as e: + logger.error(f"TTS failed: {e}") + return "" + + async def process_request(self, msg, is_premium=False): + """Process a voice assistant request.""" + start_time = time.time() + span = None + + # MLflow metrics tracking + mlflow_metrics = None + stt_start = None + embedding_start = None + rag_start = None + rerank_start = None + llm_start = None + tts_start = None + + try: + data = msgpack.unpackb(msg.data, raw=False) + request_id = data.get("request_id", "unknown") + audio_b64 = data.get("audio_b64", "") + user_id = data.get("user_id") + + # Initialize MLflow metrics if available + if self.mlflow_tracker and MLFLOW_AVAILABLE: + mlflow_metrics = InferenceMetrics( + request_id=request_id, + user_id=user_id, + session_id=data.get("session_id"), + model_name=LLM_MODEL, + model_endpoint=VLLM_URL, + ) + + # Start tracing span + if self.tracer: + span = self.tracer.start_span("voice.process_request") + span.set_attribute("request_id", request_id) + span.set_attribute("user_id", user_id or "anonymous") + span.set_attribute("premium", is_premium) + + # Support both new parameters and backward compatibility with use_rag + use_rag = data.get("use_rag") # Legacy parameter + enable_rag = data.get( + "enable_rag", use_rag if use_rag is not None else True + ) + enable_reranker = data.get( + "enable_reranker", use_rag if use_rag is not None else True + ) + enable_streaming = data.get("enable_streaming", False) # New parameter for streaming + + # Premium channel retrieves more documents for deeper RAG + default_top_k = 15 if is_premium else 5 + top_k = data.get("top_k", default_top_k) + + language = data.get("language", "en") + session_id = data.get("session_id") + + # Update MLflow metrics with request params + if mlflow_metrics: + mlflow_metrics.rag_enabled = enable_rag + mlflow_metrics.reranker_enabled = enable_reranker + mlflow_metrics.is_streaming = enable_streaming + mlflow_metrics.is_premium = is_premium + + # Add attributes to span + if span: + span.set_attribute("enable_rag", enable_rag) + span.set_attribute("enable_reranker", enable_reranker) + span.set_attribute("enable_streaming", enable_streaming) + span.set_attribute("top_k", top_k) + + logger.info( + f"Processing {'premium ' if is_premium else ''}voice request {request_id} (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})" + ) + + # Warn if reranker is enabled without RAG + if enable_reranker and not enable_rag: + logger.warning( + f"Request {request_id}: Reranker enabled without RAG - no documents to rerank" + ) + + # Step 1: Transcribe audio + stt_start = time.time() + transcript = await self.transcribe(audio_b64) + if mlflow_metrics: + mlflow_metrics.stt_latency = time.time() - stt_start + mlflow_metrics.prompt_length = len(transcript) if transcript else 0 + + if not transcript: + if mlflow_metrics: + mlflow_metrics.has_error = True + mlflow_metrics.error_message = "Transcription failed" + await self.publish_error(request_id, "Transcription failed") + return + + context = "" + rag_sources = [] + docs = [] + + # Step 2: RAG retrieval (if enabled) + if enable_rag and self.collection: + # Get embeddings + embedding_start = time.time() + embeddings = await self.get_embeddings([transcript]) + if mlflow_metrics: + mlflow_metrics.embedding_latency = time.time() - embedding_start + + if embeddings: + # Search Milvus with configurable top_k + rag_start = time.time() + docs = await self.search_milvus(embeddings[0], top_k=top_k) + if mlflow_metrics: + mlflow_metrics.rag_search_latency = time.time() - rag_start + mlflow_metrics.rag_documents_retrieved = len(docs) + if docs: + rag_sources = [d.get("source", "") for d in docs] + + # Step 3: Reranking (if enabled and we have documents) + if enable_reranker and docs: + # Rerank documents + rerank_start = time.time() + doc_texts = [d["text"] for d in docs] + reranked = await self.rerank(transcript, doc_texts) + if mlflow_metrics: + mlflow_metrics.rerank_latency = time.time() - rerank_start + # Take top 3 reranked documents with bounds checking + sorted_docs = sorted( + reranked, key=lambda x: x.get("relevance_score", 0), reverse=True + )[:3] + # Build context with bounds checking + # Note: doc_texts and docs have the same length (doc_texts derived from docs) + context_parts = [] + sources = [] + for item in sorted_docs: + idx = item.get("index", -1) + if 0 <= idx < len(docs): + context_parts.append(doc_texts[idx]) + sources.append(docs[idx].get("source", "")) + else: + logger.warning( + f"Reranker returned invalid index {idx} for {len(docs)} docs" + ) + context = "\n\n".join(context_parts) + rag_sources = sources + elif docs: + # Use documents without reranking (take top 3) + doc_texts = [d["text"] for d in docs[:3]] + context = "\n\n".join(doc_texts) + rag_sources = [d.get("source", "") for d in docs[:3]] + + # Step 4: Generate response (streaming or non-streaming) + # Check for cached context window from Valkey (for attention offloading) + cached_context = None + if session_id: + cached_context = await self.get_context_window(session_id) + + # Combine RAG context with cached context if available + if cached_context and context: + # Prepend cached context to current RAG context + combined_context = f"{cached_context}\n\n{context}" + # Truncate to prevent unbounded growth + if len(combined_context) > MAX_CONTEXT_LENGTH: + logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating") + # Keep the most recent context (from the end) + combined_context = combined_context[-MAX_CONTEXT_LENGTH:] + context = combined_context + elif cached_context: + # Only cached context, still need to check length + if len(cached_context) > MAX_CONTEXT_LENGTH: + logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating") + cached_context = cached_context[-MAX_CONTEXT_LENGTH:] + context = cached_context + + # Save the combined context for future use (already truncated if needed) + if session_id and context: + await self.save_context_window(session_id, context) + + # Track number of RAG docs used after reranking + if mlflow_metrics and enable_rag: + mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0 + + llm_start = time.time() + if enable_streaming: + # Use streaming response + answer = await self.generate_response_streaming(transcript, context, request_id, session_id) + else: + # Use non-streaming response + answer = await self.generate_response(transcript, context, session_id) + + if mlflow_metrics: + mlflow_metrics.llm_latency = time.time() - llm_start + mlflow_metrics.response_length = len(answer) + # Estimate token counts (rough approximation: 4 chars per token) + mlflow_metrics.input_tokens = len(transcript) // 4 + mlflow_metrics.output_tokens = len(answer) // 4 + mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens + + # Step 5: Synthesize speech + tts_start = time.time() + audio_response = await self.synthesize_speech(answer, language) + if mlflow_metrics: + mlflow_metrics.tts_latency = time.time() - tts_start + + # Publish result + result = { + "request_id": request_id, + "user_id": user_id, + "transcript": transcript, + "response_text": answer, + "audio_b64": audio_response, + "used_rag": bool(context), + "rag_enabled": enable_rag, + "reranker_enabled": enable_reranker, + "rag_sources": rag_sources, + "success": True, + } + await self.nc.publish( + f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result) + ) + logger.info(f"Published response for request {request_id}") + + # Record metrics + duration = time.time() - start_time + if self.request_counter: + self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"}) + if self.request_duration: + self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)}) + if span: + span.set_attribute("success", True) + span.set_attribute("response_length", len(answer)) + span.set_attribute("transcript_length", len(transcript)) + + # Log to MLflow + if self.mlflow_tracker and mlflow_metrics: + mlflow_metrics.total_latency = duration + await self.mlflow_tracker.log_inference(mlflow_metrics) + + except Exception as e: + logger.error(f"Request processing failed: {e}") + if self.request_counter: + self.request_counter.add(1, {"premium": str(is_premium), "success": "false"}) + if span: + span.set_attribute("success", False) + span.set_attribute("error", str(e)) + + # Log error to MLflow + if self.mlflow_tracker and mlflow_metrics: + mlflow_metrics.has_error = True + mlflow_metrics.error_message = str(e) + mlflow_metrics.total_latency = time.time() - start_time + await self.mlflow_tracker.log_inference(mlflow_metrics) + + await self.publish_error(data.get("request_id", "unknown"), str(e)) + finally: + if span: + span.end() + + async def publish_error(self, request_id: str, error: str): + """Publish an error response.""" + result = {"request_id": request_id, "error": error, "success": False} + await self.nc.publish( + f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result) + ) + + async def process_premium_request(self, msg): + """Process a premium voice request (wrapper for deeper RAG).""" + await self.process_request(msg, is_premium=True) + + async def run(self): + """Main run loop.""" + await self.setup() + + # Subscribe to standard voice requests + sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request) + logger.info(f"Subscribed to {REQUEST_SUBJECT}") + + # Subscribe to premium voice requests (deeper RAG retrieval) + premium_sub = await self.nc.subscribe( + PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request + ) + logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}") + + # Handle shutdown + def signal_handler(): + self.running = False + + loop = asyncio.get_event_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + # Keep running + while self.running: + await asyncio.sleep(1) + + # Cleanup + await sub.unsubscribe() + await premium_sub.unsubscribe() + await self.nc.close() + if self.valkey_client: + await self.valkey_client.close() + if self.collection: + connections.disconnect("default") + if self.mlflow_tracker: + await self.mlflow_tracker.stop() + logger.info("Shutdown complete") + + +if __name__ == "__main__": + assistant = VoiceAssistant() + asyncio.run(assistant.run()) diff --git a/voice_assistant_v2.py b/voice_assistant_v2.py new file mode 100644 index 0000000..de04ead --- /dev/null +++ b/voice_assistant_v2.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +""" +Voice Assistant Service (Refactored) + +End-to-end voice assistant pipeline using handler-base: +1. Listen for audio on NATS subject "voice.request" +2. Transcribe with Whisper (STT) +3. Generate embeddings for RAG +4. Retrieve context from Milvus +5. Rerank with BGE reranker +6. Generate response with vLLM +7. Synthesize speech with XTTS +8. Publish result to NATS "voice.response.{request_id}" +""" +import base64 +import logging +from typing import Any, Optional + +from nats.aio.msg import Msg + +from handler_base import Handler, Settings +from handler_base.clients import ( + EmbeddingsClient, + RerankerClient, + LLMClient, + TTSClient, + STTClient, + MilvusClient, +) +from handler_base.telemetry import create_span + +logger = logging.getLogger("voice-assistant") + + +class VoiceSettings(Settings): + """Voice assistant specific settings.""" + + service_name: str = "voice-assistant" + + # RAG settings + rag_top_k: int = 10 + rag_rerank_top_k: int = 5 + rag_collection: str = "documents" + + # Audio settings + stt_language: Optional[str] = None # Auto-detect + tts_language: str = "en" + + # Response settings + include_transcription: bool = True + include_sources: bool = False + + +class VoiceAssistant(Handler): + """ + Voice request handler with full STT -> RAG -> LLM -> TTS pipeline. + + Request format (msgpack): + { + "request_id": "uuid", + "audio": "base64 encoded audio", + "language": "optional language code", + "collection": "optional collection name" + } + + Response format: + { + "request_id": "uuid", + "transcription": "what the user said", + "response": "generated text response", + "audio": "base64 encoded response audio" + } + """ + + def __init__(self): + self.voice_settings = VoiceSettings() + super().__init__( + subject="voice.request", + settings=self.voice_settings, + queue_group="voice-assistants", + ) + + async def setup(self) -> None: + """Initialize service clients.""" + logger.info("Initializing voice assistant clients...") + + self.stt = STTClient(self.voice_settings) + self.embeddings = EmbeddingsClient(self.voice_settings) + self.reranker = RerankerClient(self.voice_settings) + self.llm = LLMClient(self.voice_settings) + self.tts = TTSClient(self.voice_settings) + self.milvus = MilvusClient(self.voice_settings) + + await self.milvus.connect(self.voice_settings.rag_collection) + + logger.info("Voice assistant clients initialized") + + async def teardown(self) -> None: + """Clean up service clients.""" + logger.info("Closing voice assistant clients...") + + await self.stt.close() + await self.embeddings.close() + await self.reranker.close() + await self.llm.close() + await self.tts.close() + await self.milvus.close() + + logger.info("Voice assistant clients closed") + + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: + """Handle incoming voice request.""" + request_id = data.get("request_id", "unknown") + audio_b64 = data.get("audio", "") + language = data.get("language", self.voice_settings.stt_language) + collection = data.get("collection", self.voice_settings.rag_collection) + + logger.info(f"Processing voice request {request_id}") + + with create_span("voice.process") as span: + if span: + span.set_attribute("request.id", request_id) + + # 1. Decode audio + audio_bytes = base64.b64decode(audio_b64) + + # 2. Transcribe audio to text + transcription = await self._transcribe(audio_bytes, language) + query = transcription.get("text", "") + + if not query.strip(): + logger.warning(f"Empty transcription for request {request_id}") + return { + "request_id": request_id, + "error": "Could not transcribe audio", + } + + logger.info(f"Transcribed: {query[:50]}...") + + # 3. Generate query embedding + embedding = await self._get_embedding(query) + + # 4. Search Milvus for context + documents = await self._search_context(embedding, collection) + + # 5. Rerank documents + reranked = await self._rerank_documents(query, documents) + + # 6. Build context + context = self._build_context(reranked) + + # 7. Generate LLM response + response_text = await self._generate_response(query, context) + + # 8. Synthesize speech + response_audio = await self._synthesize_speech(response_text) + + # Build response + result = { + "request_id": request_id, + "response": response_text, + "audio": response_audio, + } + + if self.voice_settings.include_transcription: + result["transcription"] = query + + if self.voice_settings.include_sources: + result["sources"] = [ + {"text": d["document"][:200], "score": d["score"]} + for d in reranked[:3] + ] + + logger.info(f"Completed voice request {request_id}") + + # Publish to response subject + response_subject = f"voice.response.{request_id}" + await self.nats.publish(response_subject, result) + + return result + + async def _transcribe( + self, audio: bytes, language: Optional[str] + ) -> dict: + """Transcribe audio to text.""" + with create_span("voice.stt"): + return await self.stt.transcribe(audio, language=language) + + async def _get_embedding(self, text: str) -> list[float]: + """Generate embedding for query text.""" + with create_span("voice.embedding"): + return await self.embeddings.embed_single(text) + + async def _search_context( + self, embedding: list[float], collection: str + ) -> list[dict]: + """Search Milvus for relevant documents.""" + with create_span("voice.search"): + return await self.milvus.search_with_texts( + embedding, + limit=self.voice_settings.rag_top_k, + text_field="text", + ) + + async def _rerank_documents( + self, query: str, documents: list[dict] + ) -> list[dict]: + """Rerank documents by relevance.""" + with create_span("voice.rerank"): + texts = [d.get("text", "") for d in documents] + return await self.reranker.rerank( + query, texts, top_k=self.voice_settings.rag_rerank_top_k + ) + + def _build_context(self, documents: list[dict]) -> str: + """Build context string from ranked documents.""" + return "\n\n".join(d.get("document", "") for d in documents) + + async def _generate_response(self, query: str, context: str) -> str: + """Generate LLM response.""" + with create_span("voice.generate"): + return await self.llm.generate(query, context=context) + + async def _synthesize_speech(self, text: str) -> str: + """Synthesize speech and return base64.""" + with create_span("voice.tts"): + audio_bytes = await self.tts.synthesize( + text, language=self.voice_settings.tts_language + ) + return base64.b64encode(audio_bytes).decode() + + +if __name__ == "__main__": + VoiceAssistant().run()