#!/usr/bin/env python3
"""
Chat Handler Service

Text-based chat pipeline:
1. Listen for chat messages on NATS subjects "ai.chat.user.*.message" and
   "ai.chat.premium.user.*.message" (premium users get deeper RAG)
2. Generate embeddings for RAG (optional)
3. Retrieve context from Milvus
4. Rerank with BGE reranker
5. Generate response with vLLM (optionally streamed token-by-token to
   NATS "ai.chat.response.stream.{request_id}")
6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}"
"""
import asyncio
import base64
import json
import logging
import os
import signal
import subprocess
import sys
import time
import uuid
from typing import List, Dict, Optional, Tuple

# Install dependencies on startup
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "--root-user-action=ignore",
    "-r", "/app/requirements.txt"
])

import httpx
import msgpack
import nats
import redis.asyncio as redis
from pymilvus import connections, Collection, utility

# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor

# MLflow inference tracking (optional; disabled when mlflow_utils is missing)
try:
    from mlflow_utils import InferenceMetricsTracker
    from mlflow_utils.inference_tracker import InferenceMetrics
    MLFLOW_AVAILABLE = True
except ImportError:
    MLFLOW_AVAILABLE = False
    InferenceMetricsTracker = None
    InferenceMetrics = None

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("chat-handler")


def setup_telemetry():
    """Initialize OpenTelemetry tracing and metrics.

    Exports over OTLP/HTTP to HyperDX when HYPERDX_ENABLED is "true" and
    HYPERDX_API_KEY is non-empty; otherwise exports over OTLP/gRPC to
    OTEL_EXPORTER_OTLP_ENDPOINT. Also instruments httpx clients and logging.

    Returns:
        ``(tracer, meter)``, or ``(None, None)`` when OTEL_ENABLED is not "true".
    """
    otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
    if not otel_enabled:
        logger.info("OpenTelemetry disabled")
        return None, None

    otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
    service_name = os.environ.get("OTEL_SERVICE_NAME", "chat-handler")
    service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")

    # HyperDX configuration
    hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
    hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
    use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key

    resource = Resource.create({
        SERVICE_NAME: service_name,
        SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
        SERVICE_NAMESPACE: service_namespace,
        "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
        "host.name": os.environ.get("HOSTNAME", "unknown"),
    })

    trace_provider = TracerProvider(resource=resource)

    if use_hyperdx:
        logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
        headers = {"authorization": hyperdx_api_key}
        otlp_span_exporter = OTLPSpanExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/traces",
            headers=headers
        )
        otlp_metric_exporter = OTLPMetricExporterHTTP(
            endpoint=f"{hyperdx_endpoint}/v1/metrics",
            headers=headers
        )
    else:
        otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
        otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)

    trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
    trace.set_tracer_provider(trace_provider)

    metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
    meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
    metrics.set_meter_provider(meter_provider)

    HTTPXClientInstrumentor().instrument()
    LoggingInstrumentor().instrument(set_logging_format=True)

    destination = "HyperDX" if use_hyperdx else "OTEL Collector"
    logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")

    return trace.get_tracer(__name__), metrics.get_meter(__name__)


# Configuration from environment
TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local")
EMBEDDINGS_URL = os.environ.get(
    "EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"
)
RERANKER_URL = os.environ.get(
    "RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"
)
VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000")
LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local")
MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530"))
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379")

# MLflow configuration
MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true"
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)

# Context window limits (characters)
MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000"))  # Prevent unbounded growth

# NATS subjects (ai.* schema)
# Per-user channels matching companions-frontend pattern
REQUEST_SUBJECT = "ai.chat.user.*.message"  # Wildcard subscription for all users
PREMIUM_REQUEST_SUBJECT = "ai.chat.premium.user.*.message"  # Premium users
RESPONSE_SUBJECT = "ai.chat.response"  # Response published to specific request_id
STREAM_RESPONSE_SUBJECT = "ai.chat.response.stream"  # Streaming responses (token chunks)

# System prompt for the assistant
SYSTEM_PROMPT = """You are a helpful AI assistant. Answer questions based on the provided context when available. Be concise and informative. If you don't know the answer, say so clearly."""


class ChatHandler:
    """NATS-driven chat service: RAG retrieval, LLM generation, optional TTS."""

    def __init__(self):
        self.nc = None
        self.http_client = None
        self.collection = None
        self.valkey_client = None
        self.running = True
        self.tracer = None
        self.meter = None
        self.request_counter = None
        self.request_duration = None
        self.rag_search_duration = None
        # MLflow inference tracker
        self.mlflow_tracker = None

    async def setup(self):
        """Initialize telemetry, MLflow, NATS, HTTP, Valkey and Milvus connections.

        NATS is mandatory (a connect failure propagates). Valkey, Milvus and
        MLflow are optional: on failure the related feature is disabled and a
        warning is logged.
        """
        # Initialize OpenTelemetry
        self.tracer, self.meter = setup_telemetry()

        # Setup metrics
        if self.meter:
            self.request_counter = self.meter.create_counter(
                "chat.requests",
                description="Number of chat requests processed",
                unit="1"
            )
            self.request_duration = self.meter.create_histogram(
                "chat.request_duration",
                description="Duration of chat request processing",
                unit="s"
            )
            self.rag_search_duration = self.meter.create_histogram(
                "chat.rag_search_duration",
                description="Duration of RAG search operations",
                unit="s"
            )

        # Initialize MLflow inference tracker
        if MLFLOW_ENABLED and MLFLOW_AVAILABLE:
            try:
                self.mlflow_tracker = InferenceMetricsTracker(
                    service_name="chat-handler",
                    experiment_name="chat-inference",
                    tracking_uri=MLFLOW_TRACKING_URI,
                    batch_size=50,
                    flush_interval_seconds=60.0,
                )
                await self.mlflow_tracker.start()
                logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}")
            except Exception as e:
                logger.warning(f"MLflow initialization failed: {e}, tracking disabled")
                self.mlflow_tracker = None
        elif not MLFLOW_AVAILABLE:
            logger.info("MLflow utils not available, inference tracking disabled")
        else:
            logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false")

        # NATS connection with reconnection support
        async def disconnected_cb():
            logger.warning("NATS disconnected, attempting reconnection...")

        async def reconnected_cb():
            logger.info(f"NATS reconnected to {self.nc.connected_url.netloc}")

        async def error_cb(e):
            logger.error(f"NATS error: {e}")

        async def closed_cb():
            logger.warning("NATS connection closed")

        self.nc = await nats.connect(
            NATS_URL,
            reconnect_time_wait=2,
            max_reconnect_attempts=-1,  # Infinite reconnection attempts
            disconnected_cb=disconnected_cb,
            reconnected_cb=reconnected_cb,
            error_cb=error_cb,
            closed_cb=closed_cb,
        )
        logger.info(f"Connected to NATS at {NATS_URL}")

        # HTTP client for services (closed in run() on shutdown)
        self.http_client = httpx.AsyncClient(timeout=180.0)

        # Connect to Valkey for conversation history and context caching
        try:
            self.valkey_client = redis.from_url(
                VALKEY_URL,
                encoding="utf-8",
                decode_responses=True,
                socket_connect_timeout=5
            )
            await self.valkey_client.ping()
            logger.info(f"Connected to Valkey at {VALKEY_URL}")
        except Exception as e:
            logger.warning(f"Valkey connection failed: {e}, conversation history disabled")
            self.valkey_client = None

        # Connect to Milvus if collection exists
        try:
            connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
            if utility.has_collection(COLLECTION_NAME):
                self.collection = Collection(COLLECTION_NAME)
                self.collection.load()
                logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}")
            else:
                logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled")
        except Exception as e:
            logger.warning(f"Milvus connection failed: {e}, RAG disabled")

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the embedding service.

        Returns:
            One embedding per returned ``data`` item, or ``[]`` on any failure
            (including a non-2xx response).
        """
        try:
            response = await self.http_client.post(
                f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
            )
            response.raise_for_status()
            result = response.json()
            return [d["embedding"] for d in result.get("data", [])]
        except Exception as e:
            logger.error(f"Embedding failed: {e}")
            return []

    async def search_milvus(
        self, query_embedding: List[float], top_k: int = 5
    ) -> List[Dict]:
        """Search Milvus for the ``top_k`` documents nearest ``query_embedding``.

        Returns:
            Dicts with ``text``, ``source`` ("<book_name> p.<page_num>") and
            ``score``; ``[]`` when RAG is disabled or the search fails.
        """
        if not self.collection:
            return []

        try:
            # NOTE(review): Collection.search is called without await, so it runs
            # on the event loop thread and blocks other NATS callbacks while it
            # executes; consider an executor if search latency becomes an issue.
            results = self.collection.search(
                data=[query_embedding],
                anns_field="embedding",
                param={"metric_type": "COSINE", "params": {"ef": 64}},
                limit=top_k,
                output_fields=["text", "book_name", "page_num"],
            )

            docs = []
            for hits in results:
                for hit in hits:
                    docs.append(
                        {
                            "text": hit.entity.get("text", ""),
                            "source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}',
                            "score": hit.score,
                        }
                    )
            return docs
        except Exception as e:
            logger.error(f"Milvus search failed: {e}")
            return []

    async def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        """Score ``documents`` against ``query`` with the reranker service.

        Returns:
            The service's ``results`` list (dicts with ``index`` and
            ``relevance_score``). On any failure, including a non-2xx response,
            every document gets a neutral score of 0.5 so that a stable sort by
            score keeps the original retrieval order.
        """
        if not documents:
            return []

        try:
            response = await self.http_client.post(
                f"{RERANKER_URL}/v1/rerank",
                json={"query": query, "documents": documents},
            )
            # Without this, an error body yields results=[] and silently drops
            # all RAG context instead of using the fallback below.
            response.raise_for_status()
            return response.json().get("results", [])
        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))]

    async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]:
        """Return up to the last ``max_messages`` stored messages for ``session_id``.

        Each item is the dict written by ``save_message_to_history``
        (``role``, ``content``, ``timestamp``). Returns ``[]`` when Valkey is
        unavailable, no session is given, or the read fails.
        """
        if not self.valkey_client or not session_id:
            return []

        try:
            key = f"chat:history:{session_id}"
            # Get the most recent messages (stored as a list)
            history_json = await self.valkey_client.lrange(key, -max_messages, -1)
            history = [json.loads(msg) for msg in history_json]
            logger.info(f"Retrieved {len(history)} messages from history for session {session_id}")
            return history
        except Exception as e:
            logger.error(f"Failed to get conversation history: {e}")
            return []

    async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600):
        """Append one message to the session's history list and refresh its TTL (seconds).

        Best-effort: failures are logged, not raised.
        """
        if not self.valkey_client or not session_id:
            return

        try:
            key = f"chat:history:{session_id}"
            message = json.dumps({"role": role, "content": content, "timestamp": time.time()})
            # Use RPUSH to append to the list
            await self.valkey_client.rpush(key, message)
            # Set TTL on the key (1 hour by default)
            await self.valkey_client.expire(key, ttl)
            logger.debug(f"Saved {role} message to history for session {session_id}")
        except Exception as e:
            logger.error(f"Failed to save message to history: {e}")

    async def get_context_window(self, session_id: str) -> Optional[str]:
        """Retrieve cached context window from Valkey for attention offloading.

        Returns ``None`` when Valkey is unavailable, no session is given,
        nothing is cached, or the read fails.
        """
        if not self.valkey_client or not session_id:
            return None

        try:
            key = f"chat:context:{session_id}"
            context = await self.valkey_client.get(key)
            if context:
                logger.info(f"Retrieved cached context window for session {session_id}")
            return context
        except Exception as e:
            logger.error(f"Failed to get context window: {e}")
            return None

    async def save_context_window(self, session_id: str, context: str, ttl: int = 3600):
        """Cache ``context`` for ``session_id`` with a TTL in seconds (best-effort)."""
        if not self.valkey_client or not session_id:
            return

        try:
            key = f"chat:context:{session_id}"
            await self.valkey_client.set(key, context, ex=ttl)
            logger.debug(f"Saved context window for session {session_id}")
        except Exception as e:
            logger.error(f"Failed to save context window: {e}")

    async def _build_messages(self, query: str, context: str = "", session_id: Optional[str] = None) -> List[Dict]:
        """Assemble chat-completions messages: system prompt, session history, current turn."""
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]

        # Add conversation history from Valkey if session exists. Stored entries
        # also carry a "timestamp" (see save_message_to_history); forward only
        # the role/content fields to the LLM API.
        if session_id:
            history = await self.get_conversation_history(session_id)
            messages.extend({"role": m["role"], "content": m["content"]} for m in history)

        if context:
            messages.append(
                {
                    "role": "user",
                    "content": f"Context:\n{context}\n\nQuestion: {query}",
                }
            )
        else:
            messages.append({"role": "user", "content": query})
        return messages

    async def generate_response(self, query: str, context: str = "", session_id: Optional[str] = None) -> str:
        """Generate a complete response with vLLM, using Valkey conversation history.

        On success the user query (without RAG context) and the answer are
        appended to the session history.

        Returns:
            The answer text, or a fixed apology string on any failure.
        """
        try:
            messages = await self._build_messages(query, context, session_id)

            response = await self.http_client.post(
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": LLM_MODEL,
                    "messages": messages,
                    "max_tokens": 1000,
                    "temperature": 0.7,
                },
            )
            response.raise_for_status()
            result = response.json()
            answer = result["choices"][0]["message"]["content"]
            logger.info(f"Generated response: {answer[:100]}...")

            # Save messages to conversation history
            if session_id:
                await self.save_message_to_history(session_id, "user", query)
                await self.save_message_to_history(session_id, "assistant", answer)

            return answer
        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            return "I'm sorry, I couldn't generate a response."

    async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: Optional[str] = None):
        """Generate a response via vLLM's SSE stream and relay tokens over NATS.

        Every non-empty token delta is published as a msgpack ``"chunk"``
        message to ``{STREAM_RESPONSE_SUBJECT}.{request_id}``, followed by a
        ``"done"`` message. On failure an ``"error"`` message is published there
        instead. On success the query and full answer are appended to history.

        Returns:
            The complete response text, or a fixed apology string on failure.
        """
        stream_subject = f"{STREAM_RESPONSE_SUBJECT}.{request_id}"
        try:
            messages = await self._build_messages(query, context, session_id)
            parts: List[str] = []

            # Stream response from vLLM
            async with self.http_client.stream(
                "POST",
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": LLM_MODEL,
                    "messages": messages,
                    "max_tokens": 1000,
                    "temperature": 0.7,
                    "stream": True,
                },
                timeout=60.0,
            ) as response:
                # An error body is not SSE; without this check the client would
                # receive an empty "done" and an empty answer would be saved.
                response.raise_for_status()

                # Parse SSE (Server-Sent Events) stream
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    data_str = line[len("data: "):]
                    if data_str.strip() == "[DONE]":
                        break

                    try:
                        chunk_data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue

                    choices = chunk_data.get("choices")
                    if not choices:
                        continue
                    content = (choices[0].get("delta") or {}).get("content", "")
                    if not content:
                        continue

                    parts.append(content)
                    # Publish token chunk to NATS streaming subject
                    chunk_msg = {
                        "request_id": request_id,
                        "type": "chunk",
                        "content": content,
                        "done": False,
                    }
                    await self.nc.publish(stream_subject, msgpack.packb(chunk_msg))

            full_response = "".join(parts)

            # Send completion message
            completion_msg = {
                "request_id": request_id,
                "type": "done",
                "content": "",
                "done": True,
            }
            await self.nc.publish(stream_subject, msgpack.packb(completion_msg))

            logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}")

            # Save messages to conversation history
            if session_id:
                await self.save_message_to_history(session_id, "user", query)
                await self.save_message_to_history(session_id, "assistant", full_response)

            return full_response

        except Exception as e:
            logger.error(f"Streaming LLM generation failed: {e}")
            # Send error message
            error_msg = {
                "request_id": request_id,
                "type": "error",
                "content": "I'm sorry, I couldn't generate a response.",
                "done": True,
                "error": str(e),
            }
            await self.nc.publish(stream_subject, msgpack.packb(error_msg))
            return "I'm sorry, I couldn't generate a response."

    async def synthesize_speech(self, text: str, language: str = "en") -> str:
        """Convert text to speech using XTTS (Coqui TTS).

        Returns:
            Base64-encoded audio bytes, or ``""`` on a non-200 status or error.
        """
        try:
            response = await self.http_client.get(
                f"{TTS_URL}/api/tts", params={"text": text, "language_id": language}
            )
            if response.status_code == 200:
                audio_b64 = base64.b64encode(response.content).decode("utf-8")
                logger.info(f"Synthesized {len(response.content)} bytes of audio")
                return audio_b64
            else:
                logger.error(
                    f"TTS returned status {response.status_code}: {response.text}"
                )
                return ""
        except Exception as e:
            logger.error(f"TTS failed: {e}")
            return ""

    async def _retrieve_context(self, text: str, top_k: int, enable_reranker: bool, mlflow_metrics=None) -> Tuple[str, List[str], List[Dict]]:
        """Embed ``text``, search Milvus, optionally rerank, and build the RAG context.

        At most three documents are used: the top 3 by reranker score when
        ``enable_reranker`` is set, otherwise the first 3 search hits.

        Returns:
            ``(context, rag_sources, docs)``, where ``docs`` holds all retrieved
            hits (possibly empty).
        """
        context = ""
        rag_sources: List[str] = []

        # Get embeddings for RAG
        embedding_start = time.time()
        embeddings = await self.get_embeddings([text])
        if mlflow_metrics:
            mlflow_metrics.embedding_latency = time.time() - embedding_start
        if not embeddings:
            return context, rag_sources, []

        # Search Milvus with configurable top_k
        rag_start = time.time()
        docs = await self.search_milvus(embeddings[0], top_k=top_k)
        rag_elapsed = time.time() - rag_start
        if self.rag_search_duration:
            self.rag_search_duration.record(rag_elapsed)
        if mlflow_metrics:
            mlflow_metrics.rag_search_latency = rag_elapsed
            mlflow_metrics.rag_documents_retrieved = len(docs)
        if not docs:
            return context, rag_sources, docs

        if enable_reranker:
            rerank_start = time.time()
            doc_texts = [d["text"] for d in docs]
            reranked = await self.rerank(text, doc_texts)
            if mlflow_metrics:
                mlflow_metrics.rerank_latency = time.time() - rerank_start

            # Take top 3 reranked documents
            sorted_docs = sorted(
                reranked, key=lambda x: x.get("relevance_score", 0), reverse=True
            )[:3]

            # Build context with bounds checking: the reranker's indices come
            # from an external service and may be out of range.
            context_parts = []
            for item in sorted_docs:
                idx = item.get("index", -1)
                if 0 <= idx < len(docs):
                    context_parts.append(doc_texts[idx])
                    rag_sources.append(docs[idx].get("source", ""))
                else:
                    logger.warning(
                        f"Reranker returned invalid index {idx} for {len(docs)} docs"
                    )
            context = "\n\n".join(context_parts)
        else:
            # Use documents without reranking (take top 3)
            top_docs = docs[:3]
            context = "\n\n".join(d["text"] for d in top_docs)
            rag_sources = [d.get("source", "") for d in top_docs]

        return context, rag_sources, docs

    async def _merge_cached_context(self, session_id: Optional[str], context: str) -> str:
        """Prepend the session's cached context window (attention offloading) and re-cache it.

        Only text that includes the cached window is truncated, keeping the
        most recent ``MAX_CONTEXT_LENGTH`` characters. Fresh RAG context alone
        is used as-is.
        """
        cached_context = await self.get_context_window(session_id) if session_id else None

        if cached_context:
            combined = f"{cached_context}\n\n{context}" if context else cached_context
            if len(combined) > MAX_CONTEXT_LENGTH:
                label = "Context" if context else "Cached context"
                logger.warning(f"{label} length {len(combined)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
                # Keep the most recent context (from the end)
                combined = combined[-MAX_CONTEXT_LENGTH:]
            context = combined

        # Save the combined context for future use
        if session_id and context:
            await self.save_context_window(session_id, context)
        return context

    async def process_request(self, msg, is_premium=False):
        """Handle one chat request end-to-end and publish the result.

        Accepts both the companions-frontend payload (``user_id``, ``username``,
        ``message``, ``premium``) and the original payload (``request_id``,
        ``text``, ``enable_rag``, ...). The result, or an error, is published
        to ``{RESPONSE_SUBJECT}.{request_id}``. If the payload has no
        ``request_id``, one is generated from ``user_id`` plus a random suffix.

        Args:
            msg: NATS message whose ``data`` is a msgpack-encoded dict.
            is_premium: True when received on the premium subject; the
                payload's ``premium`` flag can also enable premium handling.
        """
        start_time = time.time()
        span = None
        mlflow_metrics = None
        # Bound before the try so the error path always has a subject to report
        # to, even when decoding the payload itself fails.
        request_id = "unknown"

        try:
            data = msgpack.unpackb(msg.data, raw=False)

            # Support companions-frontend format (user_id, username, message, premium)
            # as well as the original format (request_id, text, enable_rag, etc.)
            user_id = data.get("user_id")
            username = data.get("username", "")

            # Get text from either 'message' (companions-frontend) or 'text' (original)
            text = data.get("message") or data.get("text", "")

            # Generate request_id from user_id if not provided
            request_id = data.get("request_id") or f"{user_id or 'anon'}-{uuid.uuid4().hex[:8]}"

            # Premium status from message or channel (resolved before tracing so
            # the span reflects the effective value)
            is_premium = is_premium or data.get("premium", False)

            # Initialize MLflow metrics if available
            if self.mlflow_tracker and MLFLOW_AVAILABLE:
                mlflow_metrics = InferenceMetrics(
                    request_id=request_id,
                    user_id=user_id,
                    session_id=data.get("session_id"),
                    model_name=LLM_MODEL,
                    model_endpoint=VLLM_URL,
                )

            # Start tracing span
            if self.tracer:
                span = self.tracer.start_span("chat.process_request")
                span.set_attribute("request_id", request_id)
                span.set_attribute("user_id", user_id or "anonymous")
                span.set_attribute("premium", bool(is_premium))

            # Support both new parameters and backward compatibility with use_rag
            use_rag = data.get("use_rag")  # Legacy parameter
            rag_default = use_rag if use_rag is not None else True
            enable_rag = data.get("enable_rag", rag_default)
            enable_reranker = data.get("enable_reranker", rag_default)

            # Premium users get more documents for deeper RAG
            default_top_k = 15 if is_premium else 5
            top_k = data.get("top_k", default_top_k)

            # Get request parameters
            enable_tts = data.get("enable_tts", False)
            enable_streaming = data.get("enable_streaming", False)  # New parameter for streaming
            language = data.get("language", "en")
            session_id = data.get("session_id")

            # Update MLflow metrics with request params
            if mlflow_metrics:
                mlflow_metrics.rag_enabled = enable_rag
                mlflow_metrics.reranker_enabled = enable_reranker
                mlflow_metrics.is_streaming = enable_streaming
                mlflow_metrics.is_premium = is_premium
                mlflow_metrics.prompt_length = len(text)

            # Add attributes to span
            if span:
                span.set_attribute("enable_rag", enable_rag)
                span.set_attribute("enable_reranker", enable_reranker)
                span.set_attribute("top_k", top_k)
                span.set_attribute("enable_tts", enable_tts)
                span.set_attribute("enable_streaming", enable_streaming)

            logger.info(
                f"Processing {'premium ' if is_premium else ''}chat request {request_id} from {username or user_id or 'anonymous'}: {text[:50]}... (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})"
            )

            # Warn if reranker is enabled without RAG
            if enable_reranker and not enable_rag:
                logger.warning(
                    f"Request {request_id}: Reranker enabled without RAG - no documents to rerank"
                )

            if not text:
                await self.publish_error(request_id, "No text provided")
                return

            context = ""
            rag_sources: List[str] = []
            docs: List[Dict] = []

            # Steps 1-2: RAG retrieval and optional reranking
            if enable_rag and self.collection:
                context, rag_sources, docs = await self._retrieve_context(
                    text, top_k, enable_reranker, mlflow_metrics
                )

            # Step 3: Generate response (streaming or non-streaming), first
            # merging in any cached context window from Valkey
            context = await self._merge_cached_context(session_id, context)

            # Track number of RAG docs used (at most 3 are put into context)
            if mlflow_metrics and enable_rag:
                mlflow_metrics.rag_documents_used = min(3, len(docs))

            llm_start = time.time()
            if enable_streaming:
                answer = await self.generate_response_streaming(text, context, request_id, session_id)
            else:
                answer = await self.generate_response(text, context, session_id)

            if mlflow_metrics:
                mlflow_metrics.llm_latency = time.time() - llm_start
                mlflow_metrics.response_length = len(answer)
                # Estimate token counts (rough approximation: 4 chars per token)
                mlflow_metrics.input_tokens = len(text) // 4
                mlflow_metrics.output_tokens = len(answer) // 4
                mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens

            # Step 4: Optionally synthesize speech
            audio_b64 = ""
            if enable_tts:
                audio_b64 = await self.synthesize_speech(answer, language)

            # Publish result
            # Include both 'response' (companions-frontend) and 'response_text' (original) for compatibility
            result = {
                "request_id": request_id,
                "user_id": user_id,
                "text": text,
                "response": answer,  # companions-frontend expects 'response'
                "response_text": answer,  # original format
                "audio_b64": audio_b64 if enable_tts else None,
                "used_rag": bool(context),
                "rag_enabled": enable_rag,
                "reranker_enabled": enable_reranker,
                "rag_sources": rag_sources,
                "session_id": session_id,
                "success": True,
            }
            await self.nc.publish(
                f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
            )
            logger.info(f"Published response for request {request_id}")

            # Record metrics
            duration = time.time() - start_time
            if self.request_counter:
                self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"})
            if self.request_duration:
                self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)})
            if span:
                span.set_attribute("success", True)
                span.set_attribute("response_length", len(answer))

            # Log to MLflow
            if self.mlflow_tracker and mlflow_metrics:
                mlflow_metrics.total_latency = duration
                await self.mlflow_tracker.log_inference(mlflow_metrics)

        except Exception as e:
            logger.exception(f"Request processing failed: {e}")
            if self.request_counter:
                self.request_counter.add(1, {"premium": str(is_premium), "success": "false"})
            if span:
                span.set_attribute("success", False)
                span.set_attribute("error", str(e))

            # Log error to MLflow
            if self.mlflow_tracker and mlflow_metrics:
                mlflow_metrics.has_error = True
                mlflow_metrics.error_message = str(e)
                mlflow_metrics.total_latency = time.time() - start_time
                await self.mlflow_tracker.log_inference(mlflow_metrics)

            # Same subject a successful response would have used
            await self.publish_error(request_id, str(e))
        finally:
            if span:
                span.end()

    async def publish_error(self, request_id: str, error: str):
        """Publish an error response to ``{RESPONSE_SUBJECT}.{request_id}``."""
        result = {"request_id": request_id, "error": error, "success": False}
        await self.nc.publish(
            f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
        )

    async def process_premium_request(self, msg):
        """Process a premium chat request (wrapper for deeper RAG)."""
        await self.process_request(msg, is_premium=True)

    async def run(self):
        """Set up connections, subscribe to request subjects, and run until SIGTERM/SIGINT."""
        await self.setup()

        # Subscribe to standard chat requests
        sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request)
        logger.info(f"Subscribed to {REQUEST_SUBJECT}")

        # Subscribe to premium chat requests (deeper RAG retrieval)
        premium_sub = await self.nc.subscribe(
            PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request
        )
        logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}")

        # Handle shutdown
        def signal_handler():
            self.running = False

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        # Keep running
        while self.running:
            await asyncio.sleep(1)

        # Cleanup: stop intake first, then release clients
        await sub.unsubscribe()
        await premium_sub.unsubscribe()
        await self.nc.close()
        if self.http_client:
            await self.http_client.aclose()
        if self.valkey_client:
            await self.valkey_client.close()
        if self.collection:
            connections.disconnect("default")
        if self.mlflow_tracker:
            await self.mlflow_tracker.stop()
        logger.info("Shutdown complete")


if __name__ == "__main__":
    handler = ChatHandler()
    asyncio.run(handler.run())