Files
voice-assistant/voice_assistant.py
Billy D. f0b626a5e7 feat: Add voice assistant handler and Kubeflow pipeline
- voice_assistant.py: Standalone NATS handler with full RAG pipeline
- voice_assistant_v2.py: Handler-based implementation
- pipelines/voice_pipeline.py: KFP SDK pipeline definitions
- Dockerfiles for both standalone and handler-base versions

Pipeline: STT → Embeddings → Milvus → Rerank → LLM → TTS
2026-02-01 20:32:37 -05:00

876 lines
36 KiB
Python

#!/usr/bin/env python3
"""
Voice Assistant Service
End-to-end voice assistant pipeline:
1. Listen for audio on NATS subject "voice.request"
2. Transcribe with Whisper (STT)
3. Generate embeddings for RAG
4. Retrieve context from Milvus
5. Rerank with BGE reranker
6. Generate response with vLLM
7. Synthesize speech with XTTS
8. Publish result to NATS "voice.response"
"""
import asyncio
import base64
import json
import logging
import os
import signal
import subprocess
import sys
import time
from typing import List, Dict, Optional
# Install dependencies on startup
# NOTE(review): installing packages at import time via pip is fragile
# (network access required, versions resolved at boot, slow cold starts) —
# presumably done so a generic base image can be reused; consider baking
# requirements into the container image instead. TODO confirm with owners.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "-r", "/app/requirements.txt"
])
import httpx
import msgpack
import nats
import redis.asyncio as redis
from pymilvus import connections, Collection, utility
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
# MLflow inference tracking — optional dependency. When the in-cluster
# mlflow_utils package is absent, the flag and both names are set so the
# rest of the file can degrade gracefully (tracking is simply disabled).
try:
    from mlflow_utils import InferenceMetricsTracker
    from mlflow_utils.inference_tracker import InferenceMetrics
    MLFLOW_AVAILABLE = True
except ImportError:
    MLFLOW_AVAILABLE = False
    InferenceMetricsTracker = None
    InferenceMetrics = None

# Configure logging (LoggingInstrumentor may later rewrite the format to
# include trace/span ids when OpenTelemetry is enabled)
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("voice-assistant")
def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics.

    Returns a ``(tracer, meter)`` pair, or ``(None, None)`` when disabled
    via ``OTEL_ENABLED=false``. Spans and metrics are exported either to
    HyperDX over OTLP/HTTP (when ``HYPERDX_ENABLED=true`` and an API key is
    set) or to the in-cluster OTel collector over OTLP/gRPC.
    """
    if os.environ.get("OTEL_ENABLED", "true").lower() != "true":
        logger.info("OpenTelemetry disabled")
        return None, None

    env = os.environ.get
    otel_endpoint = env("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
    service_name = env("OTEL_SERVICE_NAME", "voice-assistant")
    # HyperDX configuration (API key doubles as the enable gate)
    hyperdx_api_key = env("HYPERDX_API_KEY", "")
    hyperdx_endpoint = env("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = env("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    # Resource attributes attached to every span/metric this service emits.
    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: env("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: env("OTEL_SERVICE_NAMESPACE", "ai-ml"),
        "deployment.environment": env("DEPLOYMENT_ENV", "production"),
        "host.name": env("HOSTNAME", "unknown"),
    })

    if use_hyperdx:
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        auth_headers = {"authorization": hyperdx_api_key}
        span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces", headers=auth_headers
        )
        metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=auth_headers
        )
    else:
        span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    # Wire up the global tracer provider.
    provider = TracerProvider(resource=resource)
    provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(provider)

    # Metrics are pushed once a minute.
    reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
    metrics.set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))

    # Auto-instrument outgoing HTTP calls and log records.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
    return trace.get_tracer(__name__), metrics.get_meter(__name__)
# Configuration from environment
# Upstream model-serving endpoints (in-cluster KServe-style predictors).
WHISPER_URL = os.environ.get(
    "WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"
)
# Coqui/XTTS text-to-speech server
TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local:5002")
EMBEDDINGS_URL = os.environ.get(
    "EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"
)
RERANKER_URL = os.environ.get(
    "RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"
)
# OpenAI-compatible vLLM endpoint and the model name sent in each request
VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000")
LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
# Milvus vector store used for RAG retrieval
MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local")
MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530"))
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
# Messaging (NATS) and cache/history backend (Valkey, Redis-compatible)
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379")
# MLflow configuration
MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true"
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# Context window limits (characters)
MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000"))  # Prevent unbounded growth
# NATS subjects (ai.* schema)
# Per-user channels matching companions-frontend pattern
REQUEST_SUBJECT = "ai.voice.user.*.request"  # Wildcard subscription for all users
PREMIUM_REQUEST_SUBJECT = "ai.voice.premium.user.*.request"  # Premium users
RESPONSE_SUBJECT = "ai.voice.response"  # Response published to specific request_id
STREAM_RESPONSE_SUBJECT = "ai.voice.response.stream"  # Streaming responses (token chunks)
# System prompt for the assistant (kept short — responses are spoken aloud)
SYSTEM_PROMPT = """You are a helpful voice assistant.
Answer questions based on the provided context when available.
Keep responses concise and natural for speech synthesis.
If you don't know the answer, say so clearly."""
class VoiceAssistant:
def __init__(self):
    """Create an unconnected assistant; setup() wires every handle below."""
    # Connection handles, telemetry instruments, and the MLflow tracker all
    # start unset and are populated by setup().
    for attr in (
        "nc", "http_client", "collection", "valkey_client",
        "tracer", "meter", "request_counter", "request_duration",
        "stt_duration", "tts_duration", "mlflow_tracker",
    ):
        setattr(self, attr, None)
    # Main-loop flag; cleared by the signal handler installed in run().
    self.running = True
async def setup(self):
    """Initialize all connections.

    Order matters: telemetry first (so later clients are instrumented),
    then MLflow, NATS, the shared HTTP client, Valkey, and Milvus.
    Valkey and Milvus failures are non-fatal — the corresponding features
    (conversation history, RAG) are simply disabled.
    """
    # Initialize OpenTelemetry
    self.tracer, self.meter = setup_telemetry()
    # Setup metrics (only when telemetry is enabled)
    if self.meter:
        self.request_counter = self.meter.create_counter(
            "voice.requests",
            description="Number of voice requests processed",
            unit="1"
        )
        self.request_duration = self.meter.create_histogram(
            "voice.request_duration",
            description="Duration of voice request processing",
            unit="s"
        )
        self.stt_duration = self.meter.create_histogram(
            "voice.stt_duration",
            description="Duration of speech-to-text processing",
            unit="s"
        )
        self.tts_duration = self.meter.create_histogram(
            "voice.tts_duration",
            description="Duration of text-to-speech processing",
            unit="s"
        )
    # Initialize MLflow inference tracker (optional; see MLFLOW_AVAILABLE)
    if MLFLOW_ENABLED and MLFLOW_AVAILABLE:
        try:
            self.mlflow_tracker = InferenceMetricsTracker(
                service_name="voice-assistant",
                experiment_name="voice-inference",
                tracking_uri=MLFLOW_TRACKING_URI,
                batch_size=50,
                flush_interval_seconds=60.0,
            )
            await self.mlflow_tracker.start()
            logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}")
        except Exception as e:
            logger.warning(f"MLflow initialization failed: {e}, tracking disabled")
            self.mlflow_tracker = None
    elif not MLFLOW_AVAILABLE:
        logger.info("MLflow utils not available, inference tracking disabled")
    else:
        logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false")
    # NATS connection (fatal on failure — the service is useless without it)
    self.nc = await nats.connect(NATS_URL)
    logger.info(f"Connected to NATS at {NATS_URL}")
    # HTTP client for services (generous timeout: LLM + TTS calls can be slow)
    self.http_client = httpx.AsyncClient(timeout=180.0)
    # Connect to Valkey for conversation history and context caching
    try:
        self.valkey_client = redis.from_url(
            VALKEY_URL,
            encoding="utf-8",
            decode_responses=True,
            socket_connect_timeout=5
        )
        await self.valkey_client.ping()
        logger.info(f"Connected to Valkey at {VALKEY_URL}")
    except Exception as e:
        logger.warning(f"Valkey connection failed: {e}, conversation history disabled")
        self.valkey_client = None
    # Connect to Milvus if collection exists (RAG stays disabled otherwise)
    try:
        connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
        if utility.has_collection(COLLECTION_NAME):
            self.collection = Collection(COLLECTION_NAME)
            # load() brings the collection into memory so search() can run
            self.collection.load()
            logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}")
        else:
            logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled")
    except Exception as e:
        logger.warning(f"Milvus connection failed: {e}, RAG disabled")
async def transcribe(self, audio_b64: str) -> str:
    """Transcribe base64-encoded WAV audio via the Whisper STT service.

    Args:
        audio_b64: base64-encoded WAV bytes.

    Returns:
        The transcript text, or "" on any failure (decode, HTTP, parse).
    """
    try:
        audio_bytes = base64.b64decode(audio_b64)
        files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
        response = await self.http_client.post(
            f"{WHISPER_URL}/v1/audio/transcriptions", files=files
        )
        # Fix: fail fast on non-2xx instead of silently parsing an error body.
        response.raise_for_status()
        transcript = response.json().get("text", "")
        logger.info(f"Transcribed: {transcript[:100]}...")
        return transcript
    except Exception as e:
        logger.error(f"Transcription failed: {e}")
        return ""
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of texts via the embedding service.

    Args:
        texts: strings to embed.

    Returns:
        One embedding vector per input text, or [] on any failure.
    """
    try:
        response = await self.http_client.post(
            f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
        )
        # Fix: surface HTTP errors instead of mis-parsing an error response.
        response.raise_for_status()
        return [d["embedding"] for d in response.json().get("data", [])]
    except Exception as e:
        logger.error(f"Embedding failed: {e}")
        return []
async def search_milvus(
    self, query_embedding: List[float], top_k: int = 5
) -> List[Dict]:
    """Vector-search the Milvus collection for relevant documents.

    Returns up to top_k docs as {"text", "source", "score"} dicts;
    [] when RAG is disabled (no collection) or on any search error.
    """
    if not self.collection:
        return []
    try:
        hits_per_query = self.collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param={"metric_type": "COSINE", "params": {"ef": 64}},
            limit=top_k,
            output_fields=["text", "book_name", "page_num"],
        )
        # Flatten the per-query hit lists into plain dicts for the pipeline.
        return [
            {
                "text": hit.entity.get("text", ""),
                "source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}',
                "score": hit.score,
            }
            for hits in hits_per_query
            for hit in hits
        ]
    except Exception as e:
        logger.error(f"Milvus search failed: {e}")
        return []
async def rerank(self, query: str, documents: List[str]) -> List[Dict]:
    """Score documents against the query via the BGE reranker service.

    Args:
        query: user query text.
        documents: candidate document texts.

    Returns:
        Reranker results (dicts with "index" and "relevance_score").
        [] for empty input; on failure, a neutral 0.5 score per document
        so the pipeline can still proceed in original order.
    """
    if not documents:
        return []
    try:
        response = await self.http_client.post(
            f"{RERANKER_URL}/v1/rerank",
            json={"query": query, "documents": documents},
        )
        # Fix: check HTTP status so a 5xx doesn't yield a bogus empty result.
        response.raise_for_status()
        return response.json().get("results", [])
    except Exception as e:
        logger.error(f"Reranking failed: {e}")
        # Neutral fallback: keep original order with a flat score.
        return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))]
async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]:
    """Load up to max_messages most-recent chat turns for a session from Valkey.

    Returns [] when history is disabled, the session id is empty, or on error.
    """
    if not (self.valkey_client and session_id):
        return []
    try:
        # Negative range indices select the tail (newest) of the list.
        raw = await self.valkey_client.lrange(
            f"voice:history:{session_id}", -max_messages, -1
        )
        messages = [json.loads(item) for item in raw]
        logger.info(f"Retrieved {len(messages)} messages from history for session {session_id}")
        return messages
    except Exception as e:
        logger.error(f"Failed to get conversation history: {e}")
        return []
async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600):
    """Append one chat turn to the session's Valkey history list.

    No-op when history is disabled or the session id is empty. The whole
    list expires ttl seconds (default 1 hour) after the last write.
    """
    if not (self.valkey_client and session_id):
        return
    key = f"voice:history:{session_id}"
    try:
        payload = json.dumps({"role": role, "content": content, "timestamp": time.time()})
        # Append, then refresh the TTL so active conversations stay alive.
        await self.valkey_client.rpush(key, payload)
        await self.valkey_client.expire(key, ttl)
        logger.debug(f"Saved {role} message to history for session {session_id}")
    except Exception as e:
        logger.error(f"Failed to save message to history: {e}")
async def get_context_window(self, session_id: str) -> Optional[str]:
    """Fetch the cached context window for a session from Valkey.

    Returns the cached string, or None when caching is disabled, the key
    is missing/empty, or the lookup fails.
    """
    if not (self.valkey_client and session_id):
        return None
    try:
        cached = await self.valkey_client.get(f"voice:context:{session_id}")
    except Exception as e:
        logger.error(f"Failed to get context window: {e}")
        return None
    if cached:
        logger.info(f"Retrieved cached context window for session {session_id}")
        return cached
    return None
async def save_context_window(self, session_id: str, context: str, ttl: int = 3600):
    """Cache the session's context window in Valkey with a TTL (default 1h).

    No-op when caching is disabled or the session id is empty; failures
    are logged and swallowed (caching is best-effort).
    """
    if not (self.valkey_client and session_id):
        return
    try:
        await self.valkey_client.set(f"voice:context:{session_id}", context, ex=ttl)
        logger.debug(f"Saved context window for session {session_id}")
    except Exception as e:
        logger.error(f"Failed to save context window: {e}")
async def generate_response(self, query: str, context: str = "", session_id: Optional[str] = None) -> str:
    """Generate a chat completion from vLLM, with optional RAG context and
    per-session conversation history loaded from Valkey.

    Args:
        query: the user's question (transcript).
        context: optional retrieved context prepended to the question.
        session_id: when set, prior turns are included and this exchange
            is persisted back to history.

    Returns:
        The assistant's answer, or a fixed apology string on any failure.
    """
    try:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        # Prior turns come first so the model sees the running conversation.
        if session_id:
            messages.extend(await self.get_conversation_history(session_id))
        if context:
            messages.append(
                {
                    "role": "user",
                    "content": f"Context:\n{context}\n\nQuestion: {query}",
                }
            )
        else:
            messages.append({"role": "user", "content": query})
        response = await self.http_client.post(
            f"{VLLM_URL}/v1/chat/completions",
            json={
                "model": LLM_MODEL,
                "messages": messages,
                "max_tokens": 500,
                "temperature": 0.7,
            },
        )
        # Fix: check HTTP status so errors log clearly instead of failing
        # on a missing "choices" key.
        response.raise_for_status()
        answer = response.json()["choices"][0]["message"]["content"]
        logger.info(f"Generated response: {answer[:100]}...")
        # Persist both sides of the exchange for future turns.
        if session_id:
            await self.save_message_to_history(session_id, "user", query)
            await self.save_message_to_history(session_id, "assistant", answer)
        return answer
    except Exception as e:
        logger.error(f"LLM generation failed: {e}")
        return "I'm sorry, I couldn't generate a response."
async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: Optional[str] = None) -> str:
    """Generate streaming response using vLLM and publish chunks to NATS.

    Each generated token is published to
    "{STREAM_RESPONSE_SUBJECT}.{request_id}" as a msgpack message of shape
    {"request_id", "type": "chunk", "content", "done": False}, followed by a
    final {"type": "done"} message. On failure a {"type": "error"} message
    is published so waiting clients are not left hanging.

    Returns the complete response text (or an apology string on failure).
    """
    try:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        # Add conversation history from Valkey if session exists
        if session_id:
            history = await self.get_conversation_history(session_id)
            messages.extend(history)
        if context:
            messages.append(
                {
                    "role": "user",
                    "content": f"Context:\n{context}\n\nQuestion: {query}",
                }
            )
        else:
            messages.append({"role": "user", "content": query})
        full_response = ""
        # Stream response from vLLM (OpenAI-compatible SSE via stream=True)
        async with self.http_client.stream(
            "POST",
            f"{VLLM_URL}/v1/chat/completions",
            json={
                "model": LLM_MODEL,
                "messages": messages,
                "max_tokens": 500,
                "temperature": 0.7,
                "stream": True,
            },
            timeout=60.0,
        ) as response:
            # Parse SSE (Server-Sent Events) stream
            async for line in response.aiter_lines():
                if not line or not line.startswith("data: "):
                    continue
                data_str = line[6:]  # Remove "data: " prefix
                if data_str.strip() == "[DONE]":
                    break
                try:
                    chunk_data = json.loads(data_str)
                    # Extract token from delta
                    if chunk_data.get("choices") and len(chunk_data["choices"]) > 0:
                        delta = chunk_data["choices"][0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
                            full_response += content
                            # Publish token chunk to NATS streaming subject
                            chunk_msg = {
                                "request_id": request_id,
                                "type": "chunk",
                                "content": content,
                                "done": False,
                            }
                            await self.nc.publish(
                                f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                                msgpack.packb(chunk_msg)
                            )
                except json.JSONDecodeError:
                    # Skip malformed SSE payloads rather than aborting the stream
                    continue
        # Send completion message so subscribers know the stream is finished
        completion_msg = {
            "request_id": request_id,
            "type": "done",
            "content": "",
            "done": True,
        }
        await self.nc.publish(
            f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
            msgpack.packb(completion_msg)
        )
        logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}")
        # Save messages to conversation history
        if session_id:
            await self.save_message_to_history(session_id, "user", query)
            await self.save_message_to_history(session_id, "assistant", full_response)
        return full_response
    except Exception as e:
        logger.error(f"Streaming LLM generation failed: {e}")
        # Send error message (clients treat it as terminal: done=True)
        error_msg = {
            "request_id": request_id,
            "type": "error",
            "content": "I'm sorry, I couldn't generate a response.",
            "done": True,
            "error": str(e),
        }
        await self.nc.publish(
            f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
            msgpack.packb(error_msg)
        )
        return "I'm sorry, I couldn't generate a response."
async def synthesize_speech(self, text: str, language: str = "en") -> str:
    """Synthesize speech for *text* with the XTTS (Coqui) server.

    Returns base64-encoded audio bytes, or "" on any failure.
    """
    try:
        # Coqui TTS server: GET /api/tts with text + language_id returns wav.
        response = await self.http_client.get(
            f"{TTS_URL}/api/tts",
            params={
                "text": text,
                "language_id": language,
                # Optional: specify speaker_id for multi-speaker models
            },
        )
        if response.status_code != 200:
            logger.error(
                f"TTS returned status {response.status_code}: {response.text}"
            )
            return ""
        encoded = base64.b64encode(response.content).decode("utf-8")
        logger.info(f"Synthesized {len(response.content)} bytes of audio")
        return encoded
    except Exception as e:
        logger.error(f"TTS failed: {e}")
        return ""
async def process_request(self, msg, is_premium=False):
    """Process one voice request end-to-end.

    Pipeline: decode msgpack request → STT → optional RAG retrieval →
    optional rerank → context-window merge → LLM (streaming or not) →
    TTS → publish the result on "{RESPONSE_SUBJECT}.{request_id}".
    Failures are reported via publish_error(). OTel span/metrics and
    MLflow inference metrics are recorded throughout.

    Args:
        msg: NATS message whose .data is a msgpack-encoded request dict.
        is_premium: premium requests default to deeper RAG (top_k=15 vs 5).
    """
    start_time = time.time()
    span = None
    # Fix: pre-bind data so the except handler below can read request_id
    # even when msgpack decoding itself raises (previously a NameError).
    data = {}
    # MLflow metrics tracking
    mlflow_metrics = None
    try:
        data = msgpack.unpackb(msg.data, raw=False)
        request_id = data.get("request_id", "unknown")
        audio_b64 = data.get("audio_b64", "")
        user_id = data.get("user_id")
        # Initialize MLflow metrics if available
        if self.mlflow_tracker and MLFLOW_AVAILABLE:
            mlflow_metrics = InferenceMetrics(
                request_id=request_id,
                user_id=user_id,
                session_id=data.get("session_id"),
                model_name=LLM_MODEL,
                model_endpoint=VLLM_URL,
            )
        # Start tracing span
        if self.tracer:
            span = self.tracer.start_span("voice.process_request")
            span.set_attribute("request_id", request_id)
            span.set_attribute("user_id", user_id or "anonymous")
            span.set_attribute("premium", is_premium)
        # Support both new parameters and backward compatibility with use_rag
        use_rag = data.get("use_rag")  # Legacy parameter
        enable_rag = data.get(
            "enable_rag", use_rag if use_rag is not None else True
        )
        enable_reranker = data.get(
            "enable_reranker", use_rag if use_rag is not None else True
        )
        enable_streaming = data.get("enable_streaming", False)  # New parameter for streaming
        # Premium channel retrieves more documents for deeper RAG
        default_top_k = 15 if is_premium else 5
        top_k = data.get("top_k", default_top_k)
        language = data.get("language", "en")
        session_id = data.get("session_id")
        # Update MLflow metrics with request params
        if mlflow_metrics:
            mlflow_metrics.rag_enabled = enable_rag
            mlflow_metrics.reranker_enabled = enable_reranker
            mlflow_metrics.is_streaming = enable_streaming
            mlflow_metrics.is_premium = is_premium
        # Add attributes to span
        if span:
            span.set_attribute("enable_rag", enable_rag)
            span.set_attribute("enable_reranker", enable_reranker)
            span.set_attribute("enable_streaming", enable_streaming)
            span.set_attribute("top_k", top_k)
        logger.info(
            f"Processing {'premium ' if is_premium else ''}voice request {request_id} (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})"
        )
        # Warn if reranker is enabled without RAG
        if enable_reranker and not enable_rag:
            logger.warning(
                f"Request {request_id}: Reranker enabled without RAG - no documents to rerank"
            )
        # Step 1: Transcribe audio
        stt_start = time.time()
        transcript = await self.transcribe(audio_b64)
        if mlflow_metrics:
            mlflow_metrics.stt_latency = time.time() - stt_start
            mlflow_metrics.prompt_length = len(transcript) if transcript else 0
        if not transcript:
            if mlflow_metrics:
                mlflow_metrics.has_error = True
                mlflow_metrics.error_message = "Transcription failed"
            await self.publish_error(request_id, "Transcription failed")
            return
        context = ""
        rag_sources = []
        docs = []
        # Step 2: RAG retrieval (if enabled)
        if enable_rag and self.collection:
            # Get embeddings
            embedding_start = time.time()
            embeddings = await self.get_embeddings([transcript])
            if mlflow_metrics:
                mlflow_metrics.embedding_latency = time.time() - embedding_start
            if embeddings:
                # Search Milvus with configurable top_k
                rag_start = time.time()
                docs = await self.search_milvus(embeddings[0], top_k=top_k)
                if mlflow_metrics:
                    mlflow_metrics.rag_search_latency = time.time() - rag_start
                    mlflow_metrics.rag_documents_retrieved = len(docs)
                if docs:
                    rag_sources = [d.get("source", "") for d in docs]
        # Step 3: Reranking (if enabled and we have documents)
        if enable_reranker and docs:
            # Rerank documents
            rerank_start = time.time()
            doc_texts = [d["text"] for d in docs]
            reranked = await self.rerank(transcript, doc_texts)
            if mlflow_metrics:
                mlflow_metrics.rerank_latency = time.time() - rerank_start
            # Take top 3 reranked documents with bounds checking
            sorted_docs = sorted(
                reranked, key=lambda x: x.get("relevance_score", 0), reverse=True
            )[:3]
            # Build context with bounds checking
            # Note: doc_texts and docs have the same length (doc_texts derived from docs)
            context_parts = []
            sources = []
            for item in sorted_docs:
                idx = item.get("index", -1)
                if 0 <= idx < len(docs):
                    context_parts.append(doc_texts[idx])
                    sources.append(docs[idx].get("source", ""))
                else:
                    logger.warning(
                        f"Reranker returned invalid index {idx} for {len(docs)} docs"
                    )
            context = "\n\n".join(context_parts)
            rag_sources = sources
        elif docs:
            # Use documents without reranking (take top 3)
            doc_texts = [d["text"] for d in docs[:3]]
            context = "\n\n".join(doc_texts)
            rag_sources = [d.get("source", "") for d in docs[:3]]
        # Step 4: Generate response (streaming or non-streaming)
        # Check for cached context window from Valkey (for attention offloading)
        cached_context = None
        if session_id:
            cached_context = await self.get_context_window(session_id)
        # Combine RAG context with cached context if available
        if cached_context and context:
            # Prepend cached context to current RAG context
            combined_context = f"{cached_context}\n\n{context}"
            # Truncate to prevent unbounded growth
            if len(combined_context) > MAX_CONTEXT_LENGTH:
                logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                # Keep the most recent context (from the end)
                combined_context = combined_context[-MAX_CONTEXT_LENGTH:]
            context = combined_context
        elif cached_context:
            # Only cached context, still need to check length
            if len(cached_context) > MAX_CONTEXT_LENGTH:
                logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                cached_context = cached_context[-MAX_CONTEXT_LENGTH:]
            context = cached_context
        # Save the combined context for future use (already truncated if needed)
        if session_id and context:
            await self.save_context_window(session_id, context)
        # Track number of RAG docs used after reranking
        if mlflow_metrics and enable_rag:
            mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0
        llm_start = time.time()
        if enable_streaming:
            # Streaming: tokens are published to NATS as they arrive
            answer = await self.generate_response_streaming(transcript, context, request_id, session_id)
        else:
            # Non-streaming: single chat completion
            answer = await self.generate_response(transcript, context, session_id)
        if mlflow_metrics:
            mlflow_metrics.llm_latency = time.time() - llm_start
            mlflow_metrics.response_length = len(answer)
            # Estimate token counts (rough approximation: 4 chars per token)
            mlflow_metrics.input_tokens = len(transcript) // 4
            mlflow_metrics.output_tokens = len(answer) // 4
            mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens
        # Step 5: Synthesize speech
        tts_start = time.time()
        audio_response = await self.synthesize_speech(answer, language)
        if mlflow_metrics:
            mlflow_metrics.tts_latency = time.time() - tts_start
        # Publish result to the per-request response subject
        result = {
            "request_id": request_id,
            "user_id": user_id,
            "transcript": transcript,
            "response_text": answer,
            "audio_b64": audio_response,
            "used_rag": bool(context),
            "rag_enabled": enable_rag,
            "reranker_enabled": enable_reranker,
            "rag_sources": rag_sources,
            "success": True,
        }
        await self.nc.publish(
            f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
        )
        logger.info(f"Published response for request {request_id}")
        # Record metrics
        duration = time.time() - start_time
        if self.request_counter:
            self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"})
        if self.request_duration:
            self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)})
        if span:
            span.set_attribute("success", True)
            span.set_attribute("response_length", len(answer))
            span.set_attribute("transcript_length", len(transcript))
        # Log to MLflow
        if self.mlflow_tracker and mlflow_metrics:
            mlflow_metrics.total_latency = duration
            await self.mlflow_tracker.log_inference(mlflow_metrics)
    except Exception as e:
        logger.error(f"Request processing failed: {e}")
        if self.request_counter:
            self.request_counter.add(1, {"premium": str(is_premium), "success": "false"})
        if span:
            span.set_attribute("success", False)
            span.set_attribute("error", str(e))
        # Log error to MLflow
        if self.mlflow_tracker and mlflow_metrics:
            mlflow_metrics.has_error = True
            mlflow_metrics.error_message = str(e)
            mlflow_metrics.total_latency = time.time() - start_time
            await self.mlflow_tracker.log_inference(mlflow_metrics)
        # Safe even when decoding failed: data defaults to {} above.
        await self.publish_error(data.get("request_id", "unknown"), str(e))
    finally:
        if span:
            span.end()
async def publish_error(self, request_id: str, error: str):
    """Publish a failure payload to the request's response subject."""
    payload = msgpack.packb(
        {"request_id": request_id, "error": error, "success": False}
    )
    await self.nc.publish(f"{RESPONSE_SUBJECT}.{request_id}", payload)
async def process_premium_request(self, msg):
    """NATS callback for the premium subject — same pipeline, deeper RAG."""
    return await self.process_request(msg, is_premium=True)
async def run(self):
    """Main run loop: connect, subscribe, serve until SIGTERM/SIGINT,
    then tear everything down in reverse order."""
    await self.setup()
    # Subscribe to standard voice requests
    sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request)
    logger.info(f"Subscribed to {REQUEST_SUBJECT}")
    # Subscribe to premium voice requests (deeper RAG retrieval)
    premium_sub = await self.nc.subscribe(
        PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request
    )
    logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}")

    # Graceful shutdown: the handler just clears the flag; the loop below
    # notices within a second and falls through to cleanup.
    def signal_handler():
        self.running = False

    # Fix: get_running_loop() is the supported call inside a coroutine
    # (get_event_loop() is deprecated here since Python 3.10).
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    # Keep running until a signal clears the flag
    while self.running:
        await asyncio.sleep(1)

    # Cleanup — release every connection opened in setup()
    await sub.unsubscribe()
    await premium_sub.unsubscribe()
    await self.nc.close()
    if self.http_client:
        # Fix: the shared HTTPX client was previously never closed (leak).
        await self.http_client.aclose()
    if self.valkey_client:
        await self.valkey_client.close()
    if self.collection:
        connections.disconnect("default")
    if self.mlflow_tracker:
        await self.mlflow_tracker.stop()
    logger.info("Shutdown complete")
if __name__ == "__main__":
    # Entry point: build the assistant and drive its event loop to completion.
    asyncio.run(VoiceAssistant().run())