feat: Add voice assistant handler and Kubeflow pipeline

- voice_assistant.py: Standalone NATS handler with full RAG pipeline
- voice_assistant_v2.py: Handler-base implementation
- pipelines/voice_pipeline.py: KFP SDK pipeline definitions
- Dockerfiles for both standalone and handler-base versions

Pipeline: STT → Embeddings → Milvus → Rerank → LLM → TTS
2026-02-01 20:32:37 -05:00
parent 08a17ebd8e
commit f0b626a5e7
8 changed files with 1639 additions and 2 deletions

.gitignore vendored Normal file

@@ -0,0 +1,42 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
venv/
.venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Compiled KFP pipelines
*.yaml
!pipelines/*.py
# Local
.env
.env.local
*.log

Dockerfile Normal file

@@ -0,0 +1,29 @@
FROM python:3.13-slim
WORKDIR /app
# Install uv for fast, reliable package management
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt
# Copy application code
COPY voice_assistant.py .
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Health check (placeholder: only verifies the Python interpreter runs, not service readiness)
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "print('healthy')" || exit 1
# Run the application
CMD ["python", "voice_assistant.py"]

Dockerfile.v2 Normal file

@@ -0,0 +1,9 @@
# Voice Assistant v2 - Using handler-base with audio support
ARG BASE_TAG=local-audio
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app
COPY voice_assistant_v2.py ./voice_assistant.py
CMD ["python", "voice_assistant.py"]

README.md

@@ -1,3 +1,120 @@
# voice-assistant
# Voice Assistant
voice assistance please
End-to-end voice assistant pipeline for the DaviesTechLabs AI/ML platform.
## Components
### Real-time Handler (NATS-based)
The voice assistant service listens on NATS for audio requests and returns synthesized speech responses.
**Pipeline:** STT → Embeddings → Milvus RAG → Rerank → LLM → TTS
| File | Description |
|------|-------------|
| `voice_assistant.py` | Standalone handler (v1) |
| `voice_assistant_v2.py` | Handler using handler-base library |
| `Dockerfile` | Standalone image |
| `Dockerfile.v2` | Handler-base image |
### Kubeflow Pipeline (Batch)
For batch processing or asynchronous workflows, use the Kubeflow Pipelines definitions below.
| Pipeline | Description |
|----------|-------------|
| `voice_pipeline.yaml` | Full STT → RAG → TTS pipeline |
| `rag_pipeline.yaml` | Text-only RAG pipeline |
| `tts_pipeline.yaml` | Simple text-to-speech |
```bash
# Compile pipelines
cd pipelines
pip install kfp==2.12.1
python voice_pipeline.py
```
## Architecture
```
NATS (voice.request)
┌───────────────────┐
│ Voice Assistant │
│ Handler │
└───────────────────┘
├──▶ Whisper STT (elminster)
│ │
│ ▼
├──▶ BGE Embeddings (drizzt)
│ │
│ ▼
├──▶ Milvus Vector Search
│ │
│ ▼
├──▶ BGE Reranker (danilo)
│ │
│ ▼
├──▶ vLLM (khelben)
│ │
│ ▼
└──▶ XTTS TTS (elminster)
NATS (voice.response.{id})
```
## Configuration
| Environment Variable | Default | Description |
|---------------------|---------|-------------|
| `NATS_URL` | `nats://nats.ai-ml.svc.cluster.local:4222` | NATS server |
| `WHISPER_URL` | `http://whisper-predictor.ai-ml.svc.cluster.local` | STT service |
| `EMBEDDINGS_URL` | `http://embeddings-predictor.ai-ml.svc.cluster.local` | Embeddings |
| `RERANKER_URL` | `http://reranker-predictor.ai-ml.svc.cluster.local` | Reranker |
| `VLLM_URL` | `http://llm-draft.ai-ml.svc.cluster.local:8000` | LLM service |
| `TTS_URL` | `http://tts-predictor.ai-ml.svc.cluster.local` | TTS service |
| `MILVUS_HOST` | `milvus.ai-ml.svc.cluster.local` | Vector DB |
| `COLLECTION_NAME` | `knowledge_base` | Milvus collection |
## NATS Message Format
### Request (voice.request)
```json
{
"request_id": "uuid",
"audio": "base64-encoded-audio",
"language": "en",
"collection": "knowledge_base"
}
```
### Response (voice.response.{request_id})
```json
{
"request_id": "uuid",
"transcription": "user question",
"response": "assistant answer",
"audio": "base64-encoded-audio"
}
```
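As an illustration, the request and response payloads above can be assembled and decoded with a small client-side helper. This is a hedged sketch using plain JSON as shown in the schemas; note that `voice_assistant.py` itself unpacks messages with msgpack on the `ai.voice.*` subjects, so the wire encoding depends on which handler you target. `build_voice_request` and `decode_voice_response` are illustrative names, not part of the codebase.

```python
import base64
import json
import uuid


def build_voice_request(audio_bytes: bytes, language: str = "en",
                        collection: str = "knowledge_base") -> bytes:
    """Encode a voice.request payload matching the schema above.

    Illustrative helper (not part of this repo); assumes JSON on the wire.
    """
    payload = {
        "request_id": str(uuid.uuid4()),
        "audio": base64.b64encode(audio_bytes).decode("utf-8"),
        "language": language,
        "collection": collection,
    }
    return json.dumps(payload).encode("utf-8")


def decode_voice_response(raw: bytes) -> dict:
    """Decode a voice.response.{request_id} payload, including its audio field."""
    msg = json.loads(raw.decode("utf-8"))
    # Add the decoded audio alongside the base64 field for convenience.
    msg["audio_bytes"] = base64.b64decode(msg.get("audio", ""))
    return msg
```

With nats-py, the request would be published to the request subject and the response awaited on `voice.response.{request_id}`.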
## Building
```bash
# Standalone image (v1)
docker build -f Dockerfile -t voice-assistant:latest .
# Handler-base image (v2 - recommended)
docker build -f Dockerfile.v2 -t voice-assistant:v2 .
```
## Related
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images
- [handler-base](https://github.com/Billy-Davies-2/llm-workflows/tree/main/handler-base) - Base handler library

pipelines/voice_pipeline.py Normal file

@@ -0,0 +1,316 @@
#!/usr/bin/env python3
"""
Voice Pipeline - Kubeflow Pipelines SDK
Compile this to create a Kubeflow Pipeline for voice assistant workflows.
Usage:
pip install kfp==2.12.1
python voice_pipeline.py
# Upload voice_pipeline.yaml to Kubeflow Pipelines UI
"""
from kfp import dsl
from kfp import compiler
@dsl.component(
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def transcribe_audio(
audio_b64: str,
whisper_url: str = "http://whisper-predictor.ai-ml.svc.cluster.local"
) -> str:
"""Transcribe audio using Whisper STT service."""
import base64
import httpx
audio_bytes = base64.b64decode(audio_b64)
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{whisper_url}/v1/audio/transcriptions",
files={"file": ("audio.wav", audio_bytes, "audio/wav")},
data={"model": "whisper-large-v3", "language": "en"}
)
result = response.json()
return result.get("text", "")
@dsl.component(
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def generate_embeddings(
text: str,
embeddings_url: str = "http://embeddings-predictor.ai-ml.svc.cluster.local"
) -> list:
"""Generate embeddings for RAG retrieval."""
import httpx
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{embeddings_url}/embeddings",
json={"input": text, "model": "bge-small-en-v1.5"}
)
result = response.json()
return result["data"][0]["embedding"]
@dsl.component(
base_image="python:3.13-slim",
packages_to_install=["pymilvus"]
)
def retrieve_context(
embedding: list,
milvus_host: str = "milvus.ai-ml.svc.cluster.local",
collection_name: str = "knowledge_base",
top_k: int = 5
) -> list:
"""Retrieve relevant documents from Milvus vector database."""
from pymilvus import connections, Collection, utility
connections.connect(host=milvus_host, port=19530)
if not utility.has_collection(collection_name):
return []
collection = Collection(collection_name)
collection.load()
results = collection.search(
data=[embedding],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": 10}},
limit=top_k,
output_fields=["text", "source"]
)
documents = []
for hits in results:
for hit in hits:
documents.append({
"text": hit.entity.get("text"),
"source": hit.entity.get("source"),
"score": hit.distance
})
return documents
@dsl.component(
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def rerank_documents(
query: str,
documents: list,
reranker_url: str = "http://reranker-predictor.ai-ml.svc.cluster.local",
top_k: int = 3
) -> list:
"""Rerank documents using BGE reranker."""
import httpx
if not documents:
return []
with httpx.Client(timeout=60.0) as client:
response = client.post(
f"{reranker_url}/v1/rerank",
json={
"query": query,
"documents": [doc["text"] for doc in documents],
"model": "bge-reranker-v2-m3"
}
)
result = response.json()
# Sort by rerank score
reranked = sorted(
zip(documents, result.get("scores", [0] * len(documents))),
key=lambda x: x[1],
reverse=True
)[:top_k]
return [doc for doc, score in reranked]
@dsl.component(
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def generate_response(
query: str,
context: list,
vllm_url: str = "http://llm-draft.ai-ml.svc.cluster.local:8000",
model: str = "mistralai/Mistral-7B-Instruct-v0.3"
) -> str:
"""Generate response using vLLM."""
import httpx
# Build context
if context:
context_text = "\n\n".join([doc["text"] for doc in context])
user_content = f"Context:\n{context_text}\n\nQuestion: {query}"
else:
user_content = query
system_prompt = """You are a helpful voice assistant.
Answer questions based on the provided context when available.
Keep responses concise and natural for speech synthesis."""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content}
]
with httpx.Client(timeout=180.0) as client:
response = client.post(
f"{vllm_url}/v1/chat/completions",
json={
"model": model,
"messages": messages,
"max_tokens": 512,
"temperature": 0.7
}
)
result = response.json()
return result["choices"][0]["message"]["content"]
@dsl.component(
base_image="python:3.13-slim",
packages_to_install=["httpx"]
)
def synthesize_speech(
text: str,
tts_url: str = "http://tts-predictor.ai-ml.svc.cluster.local"
) -> str:
"""Convert text to speech using TTS service."""
import base64
import httpx
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{tts_url}/v1/audio/speech",
json={
"input": text,
"voice": "en_US-lessac-high",
"response_format": "wav"
}
)
audio_b64 = base64.b64encode(response.content).decode("utf-8")
return audio_b64
@dsl.pipeline(
name="voice-assistant-rag-pipeline",
description="End-to-end voice assistant with RAG: STT -> Embeddings -> Milvus -> Rerank -> LLM -> TTS"
)
def voice_assistant_pipeline(
audio_b64: str,
collection_name: str = "knowledge_base"
):
"""
Voice Assistant Pipeline with RAG
Args:
audio_b64: Base64-encoded audio file
collection_name: Milvus collection for RAG
"""
# Step 1: Transcribe audio with Whisper
transcribe_task = transcribe_audio(audio_b64=audio_b64)
transcribe_task.set_caching_options(enable_caching=False)
# Step 2: Generate embeddings
embed_task = generate_embeddings(text=transcribe_task.output)
embed_task.set_caching_options(enable_caching=True)
# Step 3: Retrieve context from Milvus
retrieve_task = retrieve_context(
embedding=embed_task.output,
collection_name=collection_name
)
# Step 4: Rerank documents
rerank_task = rerank_documents(
query=transcribe_task.output,
documents=retrieve_task.output
)
# Step 5: Generate response with context
llm_task = generate_response(
query=transcribe_task.output,
context=rerank_task.output
)
# Step 6: Synthesize speech
tts_task = synthesize_speech(text=llm_task.output)
@dsl.pipeline(
name="text-to-speech-pipeline",
description="Simple text to speech pipeline"
)
def text_to_speech_pipeline(text: str):
"""Simple TTS pipeline for testing."""
tts_task = synthesize_speech(text=text)
@dsl.pipeline(
name="rag-query-pipeline",
description="RAG query pipeline: Embed -> Retrieve -> Rerank -> LLM"
)
def rag_query_pipeline(
query: str,
collection_name: str = "knowledge_base"
):
"""
RAG Query Pipeline (text input, no voice)
Args:
query: Text query
collection_name: Milvus collection name
"""
# Embed the query
embed_task = generate_embeddings(text=query)
# Retrieve from Milvus
retrieve_task = retrieve_context(
embedding=embed_task.output,
collection_name=collection_name
)
# Rerank
rerank_task = rerank_documents(
query=query,
documents=retrieve_task.output
)
# Generate response
llm_task = generate_response(
query=query,
context=rerank_task.output
)
if __name__ == "__main__":
# Compile all pipelines
pipelines = [
("voice_pipeline.yaml", voice_assistant_pipeline),
("tts_pipeline.yaml", text_to_speech_pipeline),
("rag_pipeline.yaml", rag_query_pipeline),
]
for filename, pipeline_func in pipelines:
compiler.Compiler().compile(pipeline_func, filename)
print(f"Compiled: {filename}")
print("\nUpload these YAML files to Kubeflow Pipelines UI at:")
print(" http://kubeflow.example.com/pipelines")

requirements.txt Normal file

@@ -0,0 +1,15 @@
nats-py
httpx
pymilvus
numpy
msgpack
redis>=5.0.0
opentelemetry-api
opentelemetry-sdk
opentelemetry-exporter-otlp-proto-grpc
opentelemetry-exporter-otlp-proto-http
opentelemetry-instrumentation-httpx
opentelemetry-instrumentation-logging
# MLflow for inference metrics tracking
mlflow>=2.10.0
psycopg2-binary>=2.9.0

voice_assistant.py Normal file

@@ -0,0 +1,875 @@
#!/usr/bin/env python3
"""
Voice Assistant Service
End-to-end voice assistant pipeline:
1. Listen for audio on NATS subject "voice.request"
2. Transcribe with Whisper (STT)
3. Generate embeddings for RAG
4. Retrieve context from Milvus
5. Rerank with BGE reranker
6. Generate response with vLLM
7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response"
"""
import asyncio
import base64
import json
import logging
import os
import signal
import subprocess
import sys
import time
from typing import List, Dict, Optional
# Install dependencies on startup
subprocess.check_call([
sys.executable, "-m", "pip", "install", "-q",
"-r", "/app/requirements.txt"
])
import httpx
import msgpack
import nats
import redis.asyncio as redis
from pymilvus import connections, Collection, utility
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
# MLflow inference tracking
try:
from mlflow_utils import InferenceMetricsTracker
from mlflow_utils.inference_tracker import InferenceMetrics
MLFLOW_AVAILABLE = True
except ImportError:
MLFLOW_AVAILABLE = False
InferenceMetricsTracker = None
InferenceMetrics = None
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("voice-assistant")
def setup_telemetry():
"""Initialize OpenTelemetry tracing and metrics."""
otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
if not otel_enabled:
logger.info("OpenTelemetry disabled")
return None, None
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
service_name = os.environ.get("OTEL_SERVICE_NAME", "voice-assistant")
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
# HyperDX configuration
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
resource = Resource.create({
SERVICE_NAME: service_name,
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
SERVICE_NAMESPACE: service_namespace,
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
"host.name": os.environ.get("HOSTNAME", "unknown"),
})
trace_provider = TracerProvider(resource=resource)
if use_hyperdx:
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
headers = {"authorization": hyperdx_api_key}
otlp_span_exporter = OTLPSpanExporterHTTP(
endpoint=f"{hyperdx_endpoint}/v1/traces",
headers=headers
)
otlp_metric_exporter = OTLPMetricExporterHTTP(
endpoint=f"{hyperdx_endpoint}/v1/metrics",
headers=headers
)
else:
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
trace.set_tracer_provider(trace_provider)
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
HTTPXClientInstrumentor().instrument()
LoggingInstrumentor().instrument(set_logging_format=True)
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
return trace.get_tracer(__name__), metrics.get_meter(__name__)
# Configuration from environment
WHISPER_URL = os.environ.get(
"WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"
)
TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local:5002")
EMBEDDINGS_URL = os.environ.get(
"EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"
)
RERANKER_URL = os.environ.get(
"RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"
)
VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000")
LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local")
MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530"))
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379")
# MLflow configuration
MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true"
MLFLOW_TRACKING_URI = os.environ.get(
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# Context window limits (characters)
MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Prevent unbounded growth
# NATS subjects (ai.* schema)
# Per-user channels matching companions-frontend pattern
REQUEST_SUBJECT = "ai.voice.user.*.request" # Wildcard subscription for all users
PREMIUM_REQUEST_SUBJECT = "ai.voice.premium.user.*.request" # Premium users
RESPONSE_SUBJECT = "ai.voice.response" # Response published to specific request_id
STREAM_RESPONSE_SUBJECT = "ai.voice.response.stream" # Streaming responses (token chunks)
# System prompt for the assistant
SYSTEM_PROMPT = """You are a helpful voice assistant.
Answer questions based on the provided context when available.
Keep responses concise and natural for speech synthesis.
If you don't know the answer, say so clearly."""
class VoiceAssistant:
def __init__(self):
self.nc = None
self.http_client = None
self.collection = None
self.valkey_client = None
self.running = True
self.tracer = None
self.meter = None
self.request_counter = None
self.request_duration = None
self.stt_duration = None
self.tts_duration = None
# MLflow inference tracker
self.mlflow_tracker = None
async def setup(self):
"""Initialize all connections."""
# Initialize OpenTelemetry
self.tracer, self.meter = setup_telemetry()
# Setup metrics
if self.meter:
self.request_counter = self.meter.create_counter(
"voice.requests",
description="Number of voice requests processed",
unit="1"
)
self.request_duration = self.meter.create_histogram(
"voice.request_duration",
description="Duration of voice request processing",
unit="s"
)
self.stt_duration = self.meter.create_histogram(
"voice.stt_duration",
description="Duration of speech-to-text processing",
unit="s"
)
self.tts_duration = self.meter.create_histogram(
"voice.tts_duration",
description="Duration of text-to-speech processing",
unit="s"
)
# Initialize MLflow inference tracker
if MLFLOW_ENABLED and MLFLOW_AVAILABLE:
try:
self.mlflow_tracker = InferenceMetricsTracker(
service_name="voice-assistant",
experiment_name="voice-inference",
tracking_uri=MLFLOW_TRACKING_URI,
batch_size=50,
flush_interval_seconds=60.0,
)
await self.mlflow_tracker.start()
logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}")
except Exception as e:
logger.warning(f"MLflow initialization failed: {e}, tracking disabled")
self.mlflow_tracker = None
elif not MLFLOW_AVAILABLE:
logger.info("MLflow utils not available, inference tracking disabled")
else:
logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false")
# NATS connection
self.nc = await nats.connect(NATS_URL)
logger.info(f"Connected to NATS at {NATS_URL}")
# HTTP client for services
self.http_client = httpx.AsyncClient(timeout=180.0)
# Connect to Valkey for conversation history and context caching
try:
self.valkey_client = redis.from_url(
VALKEY_URL,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=5
)
await self.valkey_client.ping()
logger.info(f"Connected to Valkey at {VALKEY_URL}")
except Exception as e:
logger.warning(f"Valkey connection failed: {e}, conversation history disabled")
self.valkey_client = None
# Connect to Milvus if collection exists
try:
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
if utility.has_collection(COLLECTION_NAME):
self.collection = Collection(COLLECTION_NAME)
self.collection.load()
logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}")
else:
logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled")
except Exception as e:
logger.warning(f"Milvus connection failed: {e}, RAG disabled")
async def transcribe(self, audio_b64: str) -> str:
"""Transcribe audio using Whisper."""
try:
audio_bytes = base64.b64decode(audio_b64)
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
response = await self.http_client.post(
f"{WHISPER_URL}/v1/audio/transcriptions", files=files
)
result = response.json()
transcript = result.get("text", "")
logger.info(f"Transcribed: {transcript[:100]}...")
return transcript
except Exception as e:
logger.error(f"Transcription failed: {e}")
return ""
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings from the embedding service."""
try:
response = await self.http_client.post(
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
)
result = response.json()
return [d["embedding"] for d in result.get("data", [])]
except Exception as e:
logger.error(f"Embedding failed: {e}")
return []
async def search_milvus(
self, query_embedding: List[float], top_k: int = 5
) -> List[Dict]:
"""Search Milvus for relevant documents."""
if not self.collection:
return []
try:
results = self.collection.search(
data=[query_embedding],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"ef": 64}},
limit=top_k,
output_fields=["text", "book_name", "page_num"],
)
docs = []
for hits in results:
for hit in hits:
docs.append(
{
"text": hit.entity.get("text", ""),
"source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}',
"score": hit.score,
}
)
return docs
except Exception as e:
logger.error(f"Milvus search failed: {e}")
return []
async def rerank(self, query: str, documents: List[str]) -> List[Dict]:
"""Rerank documents using the reranker service."""
if not documents:
return []
try:
response = await self.http_client.post(
f"{RERANKER_URL}/v1/rerank",
json={"query": query, "documents": documents},
)
return response.json().get("results", [])
except Exception as e:
logger.error(f"Reranking failed: {e}")
return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))]
async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]:
"""Retrieve conversation history from Valkey."""
if not self.valkey_client or not session_id:
return []
try:
key = f"voice:history:{session_id}"
# Get the most recent messages (stored as a list)
history_json = await self.valkey_client.lrange(key, -max_messages, -1)
history = [json.loads(msg) for msg in history_json]
logger.info(f"Retrieved {len(history)} messages from history for session {session_id}")
return history
except Exception as e:
logger.error(f"Failed to get conversation history: {e}")
return []
async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600):
"""Save a message to conversation history in Valkey."""
if not self.valkey_client or not session_id:
return
try:
key = f"voice:history:{session_id}"
message = json.dumps({"role": role, "content": content, "timestamp": time.time()})
# Use RPUSH to append to the list
await self.valkey_client.rpush(key, message)
# Set TTL on the key (1 hour by default)
await self.valkey_client.expire(key, ttl)
logger.debug(f"Saved {role} message to history for session {session_id}")
except Exception as e:
logger.error(f"Failed to save message to history: {e}")
async def get_context_window(self, session_id: str) -> Optional[str]:
"""Retrieve cached context window from Valkey for attention offloading."""
if not self.valkey_client or not session_id:
return None
try:
key = f"voice:context:{session_id}"
context = await self.valkey_client.get(key)
if context:
logger.info(f"Retrieved cached context window for session {session_id}")
return context
except Exception as e:
logger.error(f"Failed to get context window: {e}")
return None
async def save_context_window(self, session_id: str, context: str, ttl: int = 3600):
"""Save context window to Valkey for attention offloading."""
if not self.valkey_client or not session_id:
return
try:
key = f"voice:context:{session_id}"
await self.valkey_client.set(key, context, ex=ttl)
logger.debug(f"Saved context window for session {session_id}")
except Exception as e:
logger.error(f"Failed to save context window: {e}")
async def generate_response(self, query: str, context: str = "", session_id: str = None) -> str:
"""Generate response using vLLM with conversation history from Valkey."""
try:
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
# Add conversation history from Valkey if session exists
if session_id:
history = await self.get_conversation_history(session_id)
messages.extend(history)
if context:
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {query}",
}
)
else:
messages.append({"role": "user", "content": query})
response = await self.http_client.post(
f"{VLLM_URL}/v1/chat/completions",
json={
"model": LLM_MODEL,
"messages": messages,
"max_tokens": 500,
"temperature": 0.7,
},
)
result = response.json()
answer = result["choices"][0]["message"]["content"]
logger.info(f"Generated response: {answer[:100]}...")
# Save messages to conversation history
if session_id:
await self.save_message_to_history(session_id, "user", query)
await self.save_message_to_history(session_id, "assistant", answer)
return answer
except Exception as e:
logger.error(f"LLM generation failed: {e}")
return "I'm sorry, I couldn't generate a response."
async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: str = None):
"""Generate streaming response using vLLM and publish chunks to NATS.
Yields tokens as they are generated and publishes them to NATS streaming subject.
Returns the complete response text.
"""
try:
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
# Add conversation history from Valkey if session exists
if session_id:
history = await self.get_conversation_history(session_id)
messages.extend(history)
if context:
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {query}",
}
)
else:
messages.append({"role": "user", "content": query})
full_response = ""
# Stream response from vLLM
async with self.http_client.stream(
"POST",
f"{VLLM_URL}/v1/chat/completions",
json={
"model": LLM_MODEL,
"messages": messages,
"max_tokens": 500,
"temperature": 0.7,
"stream": True,
},
timeout=60.0,
) as response:
# Parse SSE (Server-Sent Events) stream
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
if data_str.strip() == "[DONE]":
break
try:
chunk_data = json.loads(data_str)
# Extract token from delta
if chunk_data.get("choices") and len(chunk_data["choices"]) > 0:
delta = chunk_data["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
full_response += content
# Publish token chunk to NATS streaming subject
chunk_msg = {
"request_id": request_id,
"type": "chunk",
"content": content,
"done": False,
}
await self.nc.publish(
f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
msgpack.packb(chunk_msg)
)
except json.JSONDecodeError:
continue
# Send completion message
completion_msg = {
"request_id": request_id,
"type": "done",
"content": "",
"done": True,
}
await self.nc.publish(
f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
msgpack.packb(completion_msg)
)
logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}")
# Save messages to conversation history
if session_id:
await self.save_message_to_history(session_id, "user", query)
await self.save_message_to_history(session_id, "assistant", full_response)
return full_response
except Exception as e:
logger.error(f"Streaming LLM generation failed: {e}")
# Send error message
error_msg = {
"request_id": request_id,
"type": "error",
"content": "I'm sorry, I couldn't generate a response.",
"done": True,
"error": str(e),
}
await self.nc.publish(
f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
msgpack.packb(error_msg)
)
return "I'm sorry, I couldn't generate a response."
async def synthesize_speech(self, text: str, language: str = "en") -> str:
"""Convert text to speech using XTTS (Coqui TTS)."""
try:
# XTTS API endpoint - uses /api/tts for synthesis
# The Coqui TTS server API accepts text and returns wav audio
response = await self.http_client.get(
f"{TTS_URL}/api/tts",
params={
"text": text,
"language_id": language,
# Optional: specify speaker_id for multi-speaker models
},
)
if response.status_code == 200:
audio_b64 = base64.b64encode(response.content).decode("utf-8")
logger.info(f"Synthesized {len(response.content)} bytes of audio")
return audio_b64
else:
logger.error(
f"TTS returned status {response.status_code}: {response.text}"
)
return ""
except Exception as e:
logger.error(f"TTS failed: {e}")
return ""
async def process_request(self, msg, is_premium=False):
"""Process a voice assistant request."""
start_time = time.time()
span = None
# MLflow metrics tracking
mlflow_metrics = None
stt_start = None
embedding_start = None
rag_start = None
rerank_start = None
llm_start = None
tts_start = None
try:
data = msgpack.unpackb(msg.data, raw=False)
request_id = data.get("request_id", "unknown")
audio_b64 = data.get("audio_b64", "")
user_id = data.get("user_id")
# Initialize MLflow metrics if available
if self.mlflow_tracker and MLFLOW_AVAILABLE:
mlflow_metrics = InferenceMetrics(
request_id=request_id,
user_id=user_id,
session_id=data.get("session_id"),
model_name=LLM_MODEL,
model_endpoint=VLLM_URL,
)
# Start tracing span
if self.tracer:
span = self.tracer.start_span("voice.process_request")
span.set_attribute("request_id", request_id)
span.set_attribute("user_id", user_id or "anonymous")
span.set_attribute("premium", is_premium)
# Support both new parameters and backward compatibility with use_rag
use_rag = data.get("use_rag") # Legacy parameter
enable_rag = data.get(
"enable_rag", use_rag if use_rag is not None else True
)
enable_reranker = data.get(
"enable_reranker", use_rag if use_rag is not None else True
)
enable_streaming = data.get("enable_streaming", False) # New parameter for streaming
# Premium channel retrieves more documents for deeper RAG
default_top_k = 15 if is_premium else 5
top_k = data.get("top_k", default_top_k)
language = data.get("language", "en")
session_id = data.get("session_id")
# Update MLflow metrics with request params
if mlflow_metrics:
mlflow_metrics.rag_enabled = enable_rag
mlflow_metrics.reranker_enabled = enable_reranker
mlflow_metrics.is_streaming = enable_streaming
mlflow_metrics.is_premium = is_premium
# Add attributes to span
if span:
span.set_attribute("enable_rag", enable_rag)
span.set_attribute("enable_reranker", enable_reranker)
span.set_attribute("enable_streaming", enable_streaming)
span.set_attribute("top_k", top_k)
logger.info(
f"Processing {'premium ' if is_premium else ''}voice request {request_id} (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})"
)
# Warn if reranker is enabled without RAG
if enable_reranker and not enable_rag:
logger.warning(
f"Request {request_id}: Reranker enabled without RAG - no documents to rerank"
)
# Step 1: Transcribe audio
stt_start = time.time()
transcript = await self.transcribe(audio_b64)
if mlflow_metrics:
mlflow_metrics.stt_latency = time.time() - stt_start
mlflow_metrics.prompt_length = len(transcript) if transcript else 0
if not transcript:
if mlflow_metrics:
mlflow_metrics.has_error = True
mlflow_metrics.error_message = "Transcription failed"
await self.publish_error(request_id, "Transcription failed")
return
context = ""
rag_sources = []
docs = []
# Step 2: RAG retrieval (if enabled)
if enable_rag and self.collection:
# Get embeddings
embedding_start = time.time()
embeddings = await self.get_embeddings([transcript])
if mlflow_metrics:
mlflow_metrics.embedding_latency = time.time() - embedding_start
if embeddings:
# Search Milvus with configurable top_k
rag_start = time.time()
docs = await self.search_milvus(embeddings[0], top_k=top_k)
if mlflow_metrics:
mlflow_metrics.rag_search_latency = time.time() - rag_start
mlflow_metrics.rag_documents_retrieved = len(docs)
if docs:
rag_sources = [d.get("source", "") for d in docs]
# Step 3: Reranking (if enabled and we have documents)
if enable_reranker and docs:
# Rerank documents
rerank_start = time.time()
doc_texts = [d["text"] for d in docs]
reranked = await self.rerank(transcript, doc_texts)
if mlflow_metrics:
mlflow_metrics.rerank_latency = time.time() - rerank_start
# Take top 3 reranked documents with bounds checking
sorted_docs = sorted(
reranked, key=lambda x: x.get("relevance_score", 0), reverse=True
)[:3]
# Build context with bounds checking
# Note: doc_texts and docs have the same length (doc_texts derived from docs)
context_parts = []
sources = []
for item in sorted_docs:
idx = item.get("index", -1)
if 0 <= idx < len(docs):
context_parts.append(doc_texts[idx])
sources.append(docs[idx].get("source", ""))
else:
logger.warning(
f"Reranker returned invalid index {idx} for {len(docs)} docs"
)
context = "\n\n".join(context_parts)
rag_sources = sources
elif docs:
# Use documents without reranking (take top 3)
doc_texts = [d["text"] for d in docs[:3]]
context = "\n\n".join(doc_texts)
rag_sources = [d.get("source", "") for d in docs[:3]]
# Step 4: Generate response (streaming or non-streaming)
# Check for cached context window from Valkey (for attention offloading)
cached_context = None
if session_id:
cached_context = await self.get_context_window(session_id)
# Combine RAG context with cached context if available
if cached_context and context:
# Prepend cached context to current RAG context
combined_context = f"{cached_context}\n\n{context}"
# Truncate to prevent unbounded growth
if len(combined_context) > MAX_CONTEXT_LENGTH:
logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
# Keep the most recent context (from the end)
combined_context = combined_context[-MAX_CONTEXT_LENGTH:]
context = combined_context
elif cached_context:
# Only cached context, still need to check length
if len(cached_context) > MAX_CONTEXT_LENGTH:
logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
cached_context = cached_context[-MAX_CONTEXT_LENGTH:]
context = cached_context
# Save the combined context for future use (already truncated if needed)
if session_id and context:
await self.save_context_window(session_id, context)
# Track number of RAG docs used after reranking
if mlflow_metrics and enable_rag:
mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0
llm_start = time.time()
if enable_streaming:
# Use streaming response
answer = await self.generate_response_streaming(transcript, context, request_id, session_id)
else:
# Use non-streaming response
answer = await self.generate_response(transcript, context, session_id)
if mlflow_metrics:
mlflow_metrics.llm_latency = time.time() - llm_start
mlflow_metrics.response_length = len(answer)
# Estimate token counts (rough approximation: 4 chars per token)
mlflow_metrics.input_tokens = len(transcript) // 4
mlflow_metrics.output_tokens = len(answer) // 4
mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens
# Step 5: Synthesize speech
tts_start = time.time()
audio_response = await self.synthesize_speech(answer, language)
if mlflow_metrics:
mlflow_metrics.tts_latency = time.time() - tts_start
# Publish result
result = {
"request_id": request_id,
"user_id": user_id,
"transcript": transcript,
"response_text": answer,
"audio_b64": audio_response,
"used_rag": bool(context),
"rag_enabled": enable_rag,
"reranker_enabled": enable_reranker,
"rag_sources": rag_sources,
"success": True,
}
await self.nc.publish(
f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
)
logger.info(f"Published response for request {request_id}")
# Record metrics
duration = time.time() - start_time
if self.request_counter:
self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"})
if self.request_duration:
self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)})
if span:
span.set_attribute("success", True)
span.set_attribute("response_length", len(answer))
span.set_attribute("transcript_length", len(transcript))
# Log to MLflow
if self.mlflow_tracker and mlflow_metrics:
mlflow_metrics.total_latency = duration
await self.mlflow_tracker.log_inference(mlflow_metrics)
except Exception as e:
logger.error(f"Request processing failed: {e}")
if self.request_counter:
self.request_counter.add(1, {"premium": str(is_premium), "success": "false"})
if span:
span.set_attribute("success", False)
span.set_attribute("error", str(e))
# Log error to MLflow
if self.mlflow_tracker and mlflow_metrics:
mlflow_metrics.has_error = True
mlflow_metrics.error_message = str(e)
mlflow_metrics.total_latency = time.time() - start_time
await self.mlflow_tracker.log_inference(mlflow_metrics)
await self.publish_error(data.get("request_id", "unknown"), str(e))
finally:
if span:
span.end()
async def publish_error(self, request_id: str, error: str):
"""Publish an error response."""
result = {"request_id": request_id, "error": error, "success": False}
await self.nc.publish(
f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
)
async def process_premium_request(self, msg):
"""Process a premium voice request (wrapper for deeper RAG)."""
await self.process_request(msg, is_premium=True)
async def run(self):
"""Main run loop."""
await self.setup()
# Subscribe to standard voice requests
sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request)
logger.info(f"Subscribed to {REQUEST_SUBJECT}")
# Subscribe to premium voice requests (deeper RAG retrieval)
premium_sub = await self.nc.subscribe(
PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request
)
logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}")
# Handle shutdown
def signal_handler():
self.running = False
        loop = asyncio.get_running_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Keep running
while self.running:
await asyncio.sleep(1)
# Cleanup
await sub.unsubscribe()
await premium_sub.unsubscribe()
await self.nc.close()
if self.valkey_client:
await self.valkey_client.close()
if self.collection:
connections.disconnect("default")
if self.mlflow_tracker:
await self.mlflow_tracker.stop()
logger.info("Shutdown complete")

if __name__ == "__main__":
assistant = VoiceAssistant()
asyncio.run(assistant.run())
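The session-context handling in `process_request` above (prepend the cached Valkey context to the fresh RAG context, then keep only the most recent `MAX_CONTEXT_LENGTH` characters) can be sketched in isolation. This is a minimal illustration, not the handler's code: `combine_context` and the small `MAX_CONTEXT_LENGTH` value are hypothetical stand-ins for demonstration.

```python
# Standalone sketch of the context-combination logic in process_request.
# MAX_CONTEXT_LENGTH is set artificially small here to show truncation.
MAX_CONTEXT_LENGTH = 20


def combine_context(cached: str, fresh: str) -> str:
    """Prepend cached context, then keep only the most recent characters."""
    combined = f"{cached}\n\n{fresh}" if cached and fresh else (cached or fresh)
    if len(combined) > MAX_CONTEXT_LENGTH:
        # Slice from the end: recent conversation turns are kept,
        # older ones are dropped first.
        combined = combined[-MAX_CONTEXT_LENGTH:]
    return combined


result = combine_context("old session text", "new rag text")
```

Truncating from the end is why the handler slices with a negative index: the most recent turns of the conversation survive while the oldest context is discarded first.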

234
voice_assistant_v2.py Normal file

@@ -0,0 +1,234 @@
#!/usr/bin/env python3
"""
Voice Assistant Service (Refactored)
End-to-end voice assistant pipeline using handler-base:
1. Listen for audio on NATS subject "voice.request"
2. Transcribe with Whisper (STT)
3. Generate embeddings for RAG
4. Retrieve context from Milvus
5. Rerank with BGE reranker
6. Generate response with vLLM
7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional
from nats.aio.msg import Msg
from handler_base import Handler, Settings
from handler_base.clients import (
EmbeddingsClient,
RerankerClient,
LLMClient,
TTSClient,
STTClient,
MilvusClient,
)
from handler_base.telemetry import create_span
logger = logging.getLogger("voice-assistant")
class VoiceSettings(Settings):
"""Voice assistant specific settings."""
service_name: str = "voice-assistant"
# RAG settings
rag_top_k: int = 10
rag_rerank_top_k: int = 5
rag_collection: str = "documents"
# Audio settings
stt_language: Optional[str] = None # Auto-detect
tts_language: str = "en"
# Response settings
include_transcription: bool = True
include_sources: bool = False
class VoiceAssistant(Handler):
"""
Voice request handler with full STT -> RAG -> LLM -> TTS pipeline.
Request format (msgpack):
{
"request_id": "uuid",
"audio": "base64 encoded audio",
"language": "optional language code",
"collection": "optional collection name"
}
Response format:
{
"request_id": "uuid",
"transcription": "what the user said",
"response": "generated text response",
"audio": "base64 encoded response audio"
}
"""
def __init__(self):
self.voice_settings = VoiceSettings()
super().__init__(
subject="voice.request",
settings=self.voice_settings,
queue_group="voice-assistants",
)
async def setup(self) -> None:
"""Initialize service clients."""
logger.info("Initializing voice assistant clients...")
self.stt = STTClient(self.voice_settings)
self.embeddings = EmbeddingsClient(self.voice_settings)
self.reranker = RerankerClient(self.voice_settings)
self.llm = LLMClient(self.voice_settings)
self.tts = TTSClient(self.voice_settings)
self.milvus = MilvusClient(self.voice_settings)
await self.milvus.connect(self.voice_settings.rag_collection)
logger.info("Voice assistant clients initialized")
async def teardown(self) -> None:
"""Clean up service clients."""
logger.info("Closing voice assistant clients...")
await self.stt.close()
await self.embeddings.close()
await self.reranker.close()
await self.llm.close()
await self.tts.close()
await self.milvus.close()
logger.info("Voice assistant clients closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle incoming voice request."""
request_id = data.get("request_id", "unknown")
audio_b64 = data.get("audio", "")
language = data.get("language", self.voice_settings.stt_language)
collection = data.get("collection", self.voice_settings.rag_collection)
logger.info(f"Processing voice request {request_id}")
with create_span("voice.process") as span:
if span:
span.set_attribute("request.id", request_id)
# 1. Decode audio
audio_bytes = base64.b64decode(audio_b64)
# 2. Transcribe audio to text
transcription = await self._transcribe(audio_bytes, language)
query = transcription.get("text", "")
            if not query.strip():
                logger.warning(f"Empty transcription for request {request_id}")
                error_result = {
                    "request_id": request_id,
                    "error": "Could not transcribe audio",
                }
                # Publish the error so callers waiting on the response
                # subject are notified, matching the success path below.
                await self.nats.publish(f"voice.response.{request_id}", error_result)
                return error_result
logger.info(f"Transcribed: {query[:50]}...")
# 3. Generate query embedding
embedding = await self._get_embedding(query)
# 4. Search Milvus for context
documents = await self._search_context(embedding, collection)
# 5. Rerank documents
reranked = await self._rerank_documents(query, documents)
# 6. Build context
context = self._build_context(reranked)
# 7. Generate LLM response
response_text = await self._generate_response(query, context)
# 8. Synthesize speech
response_audio = await self._synthesize_speech(response_text)
# Build response
result = {
"request_id": request_id,
"response": response_text,
"audio": response_audio,
}
if self.voice_settings.include_transcription:
result["transcription"] = query
if self.voice_settings.include_sources:
result["sources"] = [
{"text": d["document"][:200], "score": d["score"]}
for d in reranked[:3]
]
logger.info(f"Completed voice request {request_id}")
# Publish to response subject
response_subject = f"voice.response.{request_id}"
await self.nats.publish(response_subject, result)
return result
async def _transcribe(
self, audio: bytes, language: Optional[str]
) -> dict:
"""Transcribe audio to text."""
with create_span("voice.stt"):
return await self.stt.transcribe(audio, language=language)
async def _get_embedding(self, text: str) -> list[float]:
"""Generate embedding for query text."""
with create_span("voice.embedding"):
return await self.embeddings.embed_single(text)
async def _search_context(
self, embedding: list[float], collection: str
) -> list[dict]:
"""Search Milvus for relevant documents."""
with create_span("voice.search"):
return await self.milvus.search_with_texts(
embedding,
limit=self.voice_settings.rag_top_k,
text_field="text",
)
async def _rerank_documents(
self, query: str, documents: list[dict]
) -> list[dict]:
"""Rerank documents by relevance."""
with create_span("voice.rerank"):
texts = [d.get("text", "") for d in documents]
return await self.reranker.rerank(
query, texts, top_k=self.voice_settings.rag_rerank_top_k
)
def _build_context(self, documents: list[dict]) -> str:
"""Build context string from ranked documents."""
return "\n\n".join(d.get("document", "") for d in documents)
async def _generate_response(self, query: str, context: str) -> str:
"""Generate LLM response."""
with create_span("voice.generate"):
return await self.llm.generate(query, context=context)
async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64."""
with create_span("voice.tts"):
audio_bytes = await self.tts.synthesize(
text, language=self.voice_settings.tts_language
)
return base64.b64encode(audio_bytes).decode()

if __name__ == "__main__":
VoiceAssistant().run()
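The v2 handler's post-rerank steps assume each reranked item carries the passage under a `"document"` key alongside a `"score"`; that shape is inferred from how `handle_message` and `_build_context` consume the reranker output, not confirmed by handler-base documentation. A self-contained sketch of that transformation:

```python
# Sketch of the v2 post-rerank steps: join top passages into a prompt
# context and build the optional "sources" snippets. The item shape
# {"document": ..., "score": ...} is an assumption inferred from
# handle_message, not a documented handler-base contract.
def build_context(reranked: list[dict]) -> str:
    """Join ranked passages with blank lines, mirroring _build_context."""
    return "\n\n".join(d.get("document", "") for d in reranked)


def top_sources(reranked: list[dict], n: int = 3) -> list[dict]:
    """Trim each passage to 200 chars, mirroring the include_sources path."""
    return [{"text": d["document"][:200], "score": d["score"]} for d in reranked[:n]]


reranked = [
    {"document": "Milvus stores embedding vectors.", "score": 0.92},
    {"document": "BGE rerankers score query-passage pairs.", "score": 0.85},
]
context = build_context(reranked)
sources = top_sources(reranked)
```

Keeping these as small pure functions makes the document-shape assumption easy to adjust if the real reranker client returns a different key (for example `"text"` instead of `"document"`).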