#!/usr/bin/env python3 """ Voice Assistant Service End-to-end voice assistant pipeline: 1. Listen for audio on NATS subject "voice.request" 2. Transcribe with Whisper (STT) 3. Generate embeddings for RAG 4. Retrieve context from Milvus 5. Rerank with BGE reranker 6. Generate response with vLLM 7. Synthesize speech with XTTS 8. Publish result to NATS "voice.response" """ import asyncio import base64 import json import logging import os import signal import subprocess import sys import time from typing import List, Dict, Optional # Install dependencies on startup subprocess.check_call([ sys.executable, "-m", "pip", "install", "-q", "-r", "/app/requirements.txt" ]) import httpx import msgpack import nats import redis.asyncio as redis from pymilvus import connections, Collection, utility # OpenTelemetry imports from opentelemetry import trace, metrics from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.instrumentation.logging import LoggingInstrumentor # MLflow inference tracking try: from mlflow_utils import InferenceMetricsTracker from mlflow_utils.inference_tracker import InferenceMetrics MLFLOW_AVAILABLE = True except ImportError: MLFLOW_AVAILABLE = False InferenceMetricsTracker = None InferenceMetrics = None # 
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("voice-assistant")


def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics.

    Returns a ``(tracer, meter)`` pair, or ``(None, None)`` when telemetry is
    disabled via ``OTEL_ENABLED=false``. Exports go either to HyperDX (HTTP
    OTLP, when ``HYPERDX_ENABLED=true`` and an API key is set) or to an
    in-cluster OTEL collector (gRPC OTLP).
    """
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    otel_endpoint = os.environ.get(
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "http://opentelemetry-collector.observability.svc.cluster.local:4317",
    )
    service_name = os.environ.get("OTEL_SERVICE_NAME", "voice-assistant")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    # HyperDX configuration
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces", headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics", headers=headers
        )
    else:
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    # Auto-instrument outgoing httpx calls and log records with trace context.
    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
    return trace.get_tracer(__name__), metrics.get_meter(__name__)


# Configuration from environment
WHISPER_URL = os.environ.get(
    "WHISPER_URL", "http://whisper-predictor.ai-ml.svc.cluster.local"
)
TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local:5002")
EMBEDDINGS_URL = os.environ.get(
    "EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"
)
RERANKER_URL = os.environ.get(
    "RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"
)
VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000")
LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local")
MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530"))
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379")

# MLflow configuration
MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true"
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)

# Context window limits (characters)
MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000"))  # Prevent unbounded growth

# NATS subjects (ai.* schema)
# Per-user channels matching companions-frontend pattern
REQUEST_SUBJECT = "ai.voice.user.*.request"  # Wildcard subscription for all users
PREMIUM_REQUEST_SUBJECT = "ai.voice.premium.user.*.request"  # Premium users
RESPONSE_SUBJECT = "ai.voice.response"  # Response published to specific request_id
STREAM_RESPONSE_SUBJECT = "ai.voice.response.stream"  # Streaming responses (token chunks)

# System prompt for the assistant
SYSTEM_PROMPT = """You are a helpful voice assistant. Answer questions based on the provided context when available. Keep responses concise and natural for speech synthesis. If you don't know the answer, say so clearly."""


class VoiceAssistant:
    """NATS-driven voice pipeline: STT -> RAG -> rerank -> LLM -> TTS.

    All external connections (NATS, Valkey, Milvus, HTTP services) are created
    in :meth:`setup`; each optional backend degrades gracefully to a disabled
    state when unreachable.
    """

    def __init__(self):
        # Connections (populated by setup()); None means "not connected".
        self.nc = None                # NATS client
        self.http_client = None      # shared httpx.AsyncClient
        self.collection = None       # Milvus collection (None => RAG disabled)
        self.valkey_client = None    # Valkey/Redis (None => history disabled)
        self.running = True
        # Telemetry handles (None when OTEL is disabled).
        self.tracer = None
        self.meter = None
        self.request_counter = None
        self.request_duration = None
        self.stt_duration = None
        self.tts_duration = None
        # MLflow inference tracker (None when disabled/unavailable).
        self.mlflow_tracker = None

    async def setup(self):
        """Initialize all connections.

        Order: telemetry -> metrics instruments -> MLflow -> NATS ->
        HTTP client -> Valkey -> Milvus. Only NATS is required; everything
        else logs a warning and disables its feature on failure.
        """
        # Initialize OpenTelemetry
        self.tracer, self.meter = setup_telemetry()

        # Setup metrics
        if self.meter:
            self.request_counter = self.meter.create_counter(
                "voice.requests",
                description="Number of voice requests processed",
                unit="1"
            )
            self.request_duration = self.meter.create_histogram(
                "voice.request_duration",
                description="Duration of voice request processing",
                unit="s"
            )
            self.stt_duration = self.meter.create_histogram(
                "voice.stt_duration",
                description="Duration of speech-to-text processing",
                unit="s"
            )
            self.tts_duration = self.meter.create_histogram(
                "voice.tts_duration",
                description="Duration of text-to-speech processing",
                unit="s"
            )

        # Initialize MLflow inference tracker
        if MLFLOW_ENABLED and MLFLOW_AVAILABLE:
            try:
                self.mlflow_tracker = InferenceMetricsTracker(
                    service_name="voice-assistant",
                    experiment_name="voice-inference",
                    tracking_uri=MLFLOW_TRACKING_URI,
                    batch_size=50,
                    flush_interval_seconds=60.0,
                )
                await self.mlflow_tracker.start()
                logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}")
            except Exception as e:
                logger.warning(f"MLflow initialization failed: {e}, tracking disabled")
                self.mlflow_tracker = None
        elif not MLFLOW_AVAILABLE:
            logger.info("MLflow utils not available, inference tracking disabled")
        else:
            logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false")

        # NATS connection (required — exceptions here propagate)
        self.nc = await nats.connect(NATS_URL)
        logger.info(f"Connected to NATS at {NATS_URL}")

        # HTTP client for services
        self.http_client = httpx.AsyncClient(timeout=180.0)

        # Connect to Valkey for conversation history and context caching
        try:
            self.valkey_client = redis.from_url(
                VALKEY_URL,
                encoding="utf-8",
                decode_responses=True,
                socket_connect_timeout=5
            )
            await self.valkey_client.ping()
            logger.info(f"Connected to Valkey at {VALKEY_URL}")
        except Exception as e:
            logger.warning(f"Valkey connection failed: {e}, conversation history disabled")
            self.valkey_client = None

        # Connect to Milvus if collection exists
        try:
            connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
            if utility.has_collection(COLLECTION_NAME):
                self.collection = Collection(COLLECTION_NAME)
                self.collection.load()
                logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}")
            else:
                logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled")
        except Exception as e:
            logger.warning(f"Milvus connection failed: {e}, RAG disabled")

    async def transcribe(self, audio_b64: str) -> str:
        """Transcribe base64-encoded WAV audio using Whisper.

        Returns the transcript text, or "" on any failure.
        """
        try:
            audio_bytes = base64.b64decode(audio_b64)
            files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
            response = await self.http_client.post(
                f"{WHISPER_URL}/v1/audio/transcriptions", files=files
            )
            # Fix: surface non-2xx responses as errors instead of attempting
            # to JSON-parse an error page.
            response.raise_for_status()
            result = response.json()
            transcript = result.get("text", "")
            logger.info(f"Transcribed: {transcript[:100]}...")
            return transcript
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            return ""

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings from the embedding service ([] on failure)."""
        try:
            response = await self.http_client.post(
                f"{EMBEDDINGS_URL}/embeddings",
                json={"input": texts, "model": "bge"}
            )
            result = response.json()
            return [d["embedding"] for d in result.get("data", [])]
        except Exception as e:
            logger.error(f"Embedding failed: {e}")
            return []

    async def search_milvus(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict]:
        """Search Milvus for relevant documents.

        Returns a list of {"text", "source", "score"} dicts; [] when RAG is
        disabled or the search fails.
        """
        if not self.collection:
            return []
        try:
            results = self.collection.search(
                data=[query_embedding],
                anns_field="embedding",
                param={"metric_type": "COSINE", "params": {"ef": 64}},
                limit=top_k,
                output_fields=["text", "book_name", "page_num"],
            )
            docs = []
            for hits in results:
                for hit in hits:
                    docs.append(
                        {
                            "text": hit.entity.get("text", ""),
                            "source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}',
                            "score": hit.score,
                        }
                    )
            return docs
        except Exception as e:
            logger.error(f"Milvus search failed: {e}")
            return []

    async def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        """Rerank documents using the reranker service.

        On failure, falls back to neutral 0.5 scores so the pipeline can
        proceed with the original retrieval order.
        """
        if not documents:
            return []
        try:
            response = await self.http_client.post(
                f"{RERANKER_URL}/v1/rerank",
                json={"query": query, "documents": documents},
            )
            return response.json().get("results", [])
        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))]

    async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]:
        """Retrieve conversation history from Valkey.

        Returns OpenAI-style ``{"role", "content"}`` message dicts (most
        recent ``max_messages``); [] when history is disabled or on error.
        """
        if not self.valkey_client or not session_id:
            return []
        try:
            key = f"voice:history:{session_id}"
            # Get the most recent messages (stored as a list)
            history_json = await self.valkey_client.lrange(key, -max_messages, -1)
            # Fix: project stored entries down to role/content — the stored
            # "timestamp" field is storage-only and must not be forwarded to
            # the OpenAI-compatible chat completions API.
            history = [
                {"role": m["role"], "content": m["content"]}
                for m in (json.loads(msg) for msg in history_json)
            ]
            logger.info(f"Retrieved {len(history)} messages from history for session {session_id}")
            return history
        except Exception as e:
            logger.error(f"Failed to get conversation history: {e}")
            return []

    async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600):
        """Save a message to conversation history in Valkey."""
        if not self.valkey_client or not session_id:
            return
        try:
            key = f"voice:history:{session_id}"
            message = json.dumps({"role": role, "content": content, "timestamp": time.time()})
            # Use RPUSH to append to the list
            await self.valkey_client.rpush(key, message)
            # Set TTL on the key (1 hour by default)
            await self.valkey_client.expire(key, ttl)
            logger.debug(f"Saved {role} message to history for session {session_id}")
        except Exception as e:
            logger.error(f"Failed to save message to history: {e}")

    async def get_context_window(self, session_id: str) -> Optional[str]:
        """Retrieve cached context window from Valkey for attention offloading."""
        if not self.valkey_client or not session_id:
            return None
        try:
            key = f"voice:context:{session_id}"
            context = await self.valkey_client.get(key)
            if context:
                logger.info(f"Retrieved cached context window for session {session_id}")
                return context
        except Exception as e:
            logger.error(f"Failed to get context window: {e}")
        return None

    async def save_context_window(self, session_id: str, context: str, ttl: int = 3600):
        """Save context window to Valkey for attention offloading."""
        if not self.valkey_client or not session_id:
            return
        try:
            key = f"voice:context:{session_id}"
            await self.valkey_client.set(key, context, ex=ttl)
            logger.debug(f"Saved context window for session {session_id}")
        except Exception as e:
            logger.error(f"Failed to save context window: {e}")

    def _build_messages(self, query: str, context: str, history: List[Dict]) -> List[Dict]:
        """Assemble the chat messages list: system prompt, history, then the
        user turn (with RAG context prepended when available)."""
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        messages.extend(history)
        if context:
            messages.append(
                {
                    "role": "user",
                    "content": f"Context:\n{context}\n\nQuestion: {query}",
                }
            )
        else:
            messages.append({"role": "user", "content": query})
        return messages

    async def generate_response(self, query: str, context: str = "", session_id: Optional[str] = None) -> str:
        """Generate response using vLLM with conversation history from Valkey.

        Returns the answer text, or a fixed apology string on failure.
        """
        try:
            # Add conversation history from Valkey if session exists
            history = await self.get_conversation_history(session_id) if session_id else []
            messages = self._build_messages(query, context, history)

            response = await self.http_client.post(
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": LLM_MODEL,
                    "messages": messages,
                    "max_tokens": 500,
                    "temperature": 0.7,
                },
            )
            result = response.json()
            answer = result["choices"][0]["message"]["content"]
            logger.info(f"Generated response: {answer[:100]}...")

            # Save messages to conversation history
            if session_id:
                await self.save_message_to_history(session_id, "user", query)
                await self.save_message_to_history(session_id, "assistant", answer)
            return answer
        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            return "I'm sorry, I couldn't generate a response."

    async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: Optional[str] = None):
        """Generate streaming response using vLLM and publish chunks to NATS.

        Publishes token chunks to the per-request streaming subject, then a
        terminal "done" (or "error") message. Returns the complete response
        text (or the apology string on failure).
        """
        try:
            # Add conversation history from Valkey if session exists
            history = await self.get_conversation_history(session_id) if session_id else []
            messages = self._build_messages(query, context, history)

            full_response = ""

            # Stream response from vLLM
            async with self.http_client.stream(
                "POST",
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": LLM_MODEL,
                    "messages": messages,
                    "max_tokens": 500,
                    "temperature": 0.7,
                    "stream": True,
                },
                timeout=60.0,
            ) as response:
                # Parse SSE (Server-Sent Events) stream
                async for line in response.aiter_lines():
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]  # Remove "data: " prefix
                    if data_str.strip() == "[DONE]":
                        break
                    try:
                        chunk_data = json.loads(data_str)
                        # Extract token from delta
                        if chunk_data.get("choices") and len(chunk_data["choices"]) > 0:
                            delta = chunk_data["choices"][0].get("delta", {})
                            content = delta.get("content", "")
                            if content:
                                full_response += content
                                # Publish token chunk to NATS streaming subject
                                chunk_msg = {
                                    "request_id": request_id,
                                    "type": "chunk",
                                    "content": content,
                                    "done": False,
                                }
                                await self.nc.publish(
                                    f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                                    msgpack.packb(chunk_msg)
                                )
                    except json.JSONDecodeError:
                        continue

            # Send completion message
            completion_msg = {
                "request_id": request_id,
                "type": "done",
                "content": "",
                "done": True,
            }
            await self.nc.publish(
                f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                msgpack.packb(completion_msg)
            )
            logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}")

            # Save messages to conversation history
            if session_id:
                await self.save_message_to_history(session_id, "user", query)
                await self.save_message_to_history(session_id, "assistant", full_response)
            return full_response
        except Exception as e:
            logger.error(f"Streaming LLM generation failed: {e}")
            # Send error message
            error_msg = {
                "request_id": request_id,
                "type": "error",
                "content": "I'm sorry, I couldn't generate a response.",
                "done": True,
                "error": str(e),
            }
            await self.nc.publish(
                f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
                msgpack.packb(error_msg)
            )
            return "I'm sorry, I couldn't generate a response."

    async def synthesize_speech(self, text: str, language: str = "en") -> str:
        """Convert text to speech using XTTS (Coqui TTS).

        Returns base64-encoded WAV audio, or "" on failure.
        """
        try:
            # XTTS API endpoint - uses /api/tts for synthesis
            # The Coqui TTS server API accepts text and returns wav audio
            response = await self.http_client.get(
                f"{TTS_URL}/api/tts",
                params={
                    "text": text,
                    "language_id": language,
                    # Optional: specify speaker_id for multi-speaker models
                },
            )
            if response.status_code == 200:
                audio_b64 = base64.b64encode(response.content).decode("utf-8")
                logger.info(f"Synthesized {len(response.content)} bytes of audio")
                return audio_b64
            else:
                logger.error(
                    f"TTS returned status {response.status_code}: {response.text}"
                )
                return ""
        except Exception as e:
            logger.error(f"TTS failed: {e}")
            return ""

    async def process_request(self, msg, is_premium=False):
        """Process a voice assistant request.

        Full pipeline: decode -> STT -> (optional) RAG retrieval + rerank ->
        context-window merge -> LLM (streaming or not) -> TTS -> publish.
        Errors are reported on the response subject, never raised to NATS.
        """
        start_time = time.time()
        span = None
        # MLflow metrics tracking
        mlflow_metrics = None
        stt_start = None
        embedding_start = None
        rag_start = None
        rerank_start = None
        llm_start = None
        tts_start = None
        # Fix: pre-bind data so the except path can safely read request_id
        # even when msgpack.unpackb itself raised (previously a NameError).
        data = {}
        try:
            data = msgpack.unpackb(msg.data, raw=False)
            request_id = data.get("request_id", "unknown")
            audio_b64 = data.get("audio_b64", "")
            user_id = data.get("user_id")

            # Initialize MLflow metrics if available
            if self.mlflow_tracker and MLFLOW_AVAILABLE:
                mlflow_metrics = InferenceMetrics(
                    request_id=request_id,
                    user_id=user_id,
                    session_id=data.get("session_id"),
                    model_name=LLM_MODEL,
                    model_endpoint=VLLM_URL,
                )

            # Start tracing span
            if self.tracer:
                span = self.tracer.start_span("voice.process_request")
                span.set_attribute("request_id", request_id)
                span.set_attribute("user_id", user_id or "anonymous")
                span.set_attribute("premium", is_premium)

            # Support both new parameters and backward compatibility with use_rag
            use_rag = data.get("use_rag")  # Legacy parameter
            enable_rag = data.get(
                "enable_rag", use_rag if use_rag is not None else True
            )
            enable_reranker = data.get(
                "enable_reranker", use_rag if use_rag is not None else True
            )
            enable_streaming = data.get("enable_streaming", False)  # New parameter for streaming

            # Premium channel retrieves more documents for deeper RAG
            default_top_k = 15 if is_premium else 5
            top_k = data.get("top_k", default_top_k)
            language = data.get("language", "en")
            session_id = data.get("session_id")

            # Update MLflow metrics with request params
            if mlflow_metrics:
                mlflow_metrics.rag_enabled = enable_rag
                mlflow_metrics.reranker_enabled = enable_reranker
                mlflow_metrics.is_streaming = enable_streaming
                mlflow_metrics.is_premium = is_premium

            # Add attributes to span
            if span:
                span.set_attribute("enable_rag", enable_rag)
                span.set_attribute("enable_reranker", enable_reranker)
                span.set_attribute("enable_streaming", enable_streaming)
                span.set_attribute("top_k", top_k)

            logger.info(
                f"Processing {'premium ' if is_premium else ''}voice request {request_id} (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})"
            )

            # Warn if reranker is enabled without RAG
            if enable_reranker and not enable_rag:
                logger.warning(
                    f"Request {request_id}: Reranker enabled without RAG - no documents to rerank"
                )

            # Step 1: Transcribe audio
            stt_start = time.time()
            transcript = await self.transcribe(audio_b64)
            if mlflow_metrics:
                mlflow_metrics.stt_latency = time.time() - stt_start
                mlflow_metrics.prompt_length = len(transcript) if transcript else 0
            if not transcript:
                if mlflow_metrics:
                    mlflow_metrics.has_error = True
                    mlflow_metrics.error_message = "Transcription failed"
                await self.publish_error(request_id, "Transcription failed")
                return

            context = ""
            rag_sources = []
            docs = []

            # Step 2: RAG retrieval (if enabled)
            if enable_rag and self.collection:
                # Get embeddings
                embedding_start = time.time()
                embeddings = await self.get_embeddings([transcript])
                if mlflow_metrics:
                    mlflow_metrics.embedding_latency = time.time() - embedding_start
                if embeddings:
                    # Search Milvus with configurable top_k
                    rag_start = time.time()
                    docs = await self.search_milvus(embeddings[0], top_k=top_k)
                    if mlflow_metrics:
                        mlflow_metrics.rag_search_latency = time.time() - rag_start
                        mlflow_metrics.rag_documents_retrieved = len(docs)
                    if docs:
                        rag_sources = [d.get("source", "") for d in docs]

            # Step 3: Reranking (if enabled and we have documents)
            if enable_reranker and docs:
                # Rerank documents
                rerank_start = time.time()
                doc_texts = [d["text"] for d in docs]
                reranked = await self.rerank(transcript, doc_texts)
                if mlflow_metrics:
                    mlflow_metrics.rerank_latency = time.time() - rerank_start

                # Take top 3 reranked documents with bounds checking
                sorted_docs = sorted(
                    reranked, key=lambda x: x.get("relevance_score", 0), reverse=True
                )[:3]

                # Build context with bounds checking
                # Note: doc_texts and docs have the same length (doc_texts derived from docs)
                context_parts = []
                sources = []
                for item in sorted_docs:
                    idx = item.get("index", -1)
                    if 0 <= idx < len(docs):
                        context_parts.append(doc_texts[idx])
                        sources.append(docs[idx].get("source", ""))
                    else:
                        logger.warning(
                            f"Reranker returned invalid index {idx} for {len(docs)} docs"
                        )
                context = "\n\n".join(context_parts)
                rag_sources = sources
            elif docs:
                # Use documents without reranking (take top 3)
                doc_texts = [d["text"] for d in docs[:3]]
                context = "\n\n".join(doc_texts)
                rag_sources = [d.get("source", "") for d in docs[:3]]

            # Step 4: Generate response (streaming or non-streaming)
            # Check for cached context window from Valkey (for attention offloading)
            cached_context = None
            if session_id:
                cached_context = await self.get_context_window(session_id)

            # Combine RAG context with cached context if available
            if cached_context and context:
                # Prepend cached context to current RAG context
                combined_context = f"{cached_context}\n\n{context}"
                # Truncate to prevent unbounded growth
                if len(combined_context) > MAX_CONTEXT_LENGTH:
                    logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                    # Keep the most recent context (from the end)
                    combined_context = combined_context[-MAX_CONTEXT_LENGTH:]
                context = combined_context
            elif cached_context:
                # Only cached context, still need to check length
                if len(cached_context) > MAX_CONTEXT_LENGTH:
                    logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                    cached_context = cached_context[-MAX_CONTEXT_LENGTH:]
                context = cached_context

            # Save the combined context for future use (already truncated if needed)
            if session_id and context:
                await self.save_context_window(session_id, context)

            # Track number of RAG docs used after reranking
            if mlflow_metrics and enable_rag:
                mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0

            llm_start = time.time()
            if enable_streaming:
                # Use streaming response
                answer = await self.generate_response_streaming(transcript, context, request_id, session_id)
            else:
                # Use non-streaming response
                answer = await self.generate_response(transcript, context, session_id)
            if mlflow_metrics:
                mlflow_metrics.llm_latency = time.time() - llm_start
                mlflow_metrics.response_length = len(answer)
                # Estimate token counts (rough approximation: 4 chars per token)
                mlflow_metrics.input_tokens = len(transcript) // 4
                mlflow_metrics.output_tokens = len(answer) // 4
                mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens

            # Step 5: Synthesize speech
            tts_start = time.time()
            audio_response = await self.synthesize_speech(answer, language)
            if mlflow_metrics:
                mlflow_metrics.tts_latency = time.time() - tts_start

            # Publish result
            result = {
                "request_id": request_id,
                "user_id": user_id,
                "transcript": transcript,
                "response_text": answer,
                "audio_b64": audio_response,
                "used_rag": bool(context),
                "rag_enabled": enable_rag,
                "reranker_enabled": enable_reranker,
                "rag_sources": rag_sources,
                "success": True,
            }
            await self.nc.publish(
                f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
            )
            logger.info(f"Published response for request {request_id}")

            # Record metrics
            duration = time.time() - start_time
            if self.request_counter:
                self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"})
            if self.request_duration:
                self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)})
            if span:
                span.set_attribute("success", True)
                span.set_attribute("response_length", len(answer))
                span.set_attribute("transcript_length", len(transcript))

            # Log to MLflow
            if self.mlflow_tracker and mlflow_metrics:
                mlflow_metrics.total_latency = duration
                await self.mlflow_tracker.log_inference(mlflow_metrics)
        except Exception as e:
            logger.error(f"Request processing failed: {e}")
            if self.request_counter:
                self.request_counter.add(1, {"premium": str(is_premium), "success": "false"})
            if span:
                span.set_attribute("success", False)
                span.set_attribute("error", str(e))
            # Log error to MLflow
            if self.mlflow_tracker and mlflow_metrics:
                mlflow_metrics.has_error = True
                mlflow_metrics.error_message = str(e)
                mlflow_metrics.total_latency = time.time() - start_time
                await self.mlflow_tracker.log_inference(mlflow_metrics)
            await self.publish_error(data.get("request_id", "unknown"), str(e))
        finally:
            if span:
                span.end()

    async def publish_error(self, request_id: str, error: str):
        """Publish an error response."""
        result = {"request_id": request_id, "error": error, "success": False}
        await self.nc.publish(
            f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
        )

    async def process_premium_request(self, msg):
        """Process a premium voice request (wrapper for deeper RAG)."""
        await self.process_request(msg, is_premium=True)

    async def run(self):
        """Main run loop: connect, subscribe, serve until SIGTERM/SIGINT."""
        await self.setup()

        # Subscribe to standard voice requests
        sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT}")

        # Subscribe to premium voice requests (deeper RAG retrieval)
        premium_sub = await self.nc.subscribe(
            PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request
        )
        logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}")

        # Handle shutdown
        def signal_handler():
            self.running = False

        # Fix: get_event_loop() is deprecated inside a running coroutine;
        # get_running_loop() is the correct call here.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running
        while self.running:
            await asyncio.sleep(1)

        # Cleanup
        await sub.unsubscribe()
        await premium_sub.unsubscribe()
        await self.nc.close()
        if self.valkey_client:
            await self.valkey_client.close()
        if self.collection:
            connections.disconnect("default")
        if self.mlflow_tracker:
            await self.mlflow_tracker.stop()
        logger.info("Shutdown complete")


if __name__ == "__main__":
    assistant = VoiceAssistant()
    asyncio.run(assistant.run())