feat: Add chat handler with RAG pipeline

- chat_handler.py: Standalone NATS handler with RAG
- chat_handler_v2.py: Handler-base implementation
- Dockerfiles for both versions

Pipeline: Embeddings → Milvus → Rerank → LLM → (optional TTS)
This commit is contained in:
2026-02-01 20:37:34 -05:00
parent cf859ead4e
commit 6ef42b3d2c
7 changed files with 1290 additions and 1 deletions

867
chat_handler.py Normal file
View File

@@ -0,0 +1,867 @@
#!/usr/bin/env python3
"""
Chat Handler Service
Text-based chat pipeline:
1. Listen for text on NATS subject "ai.chat.request"
2. Generate embeddings for RAG (optional)
3. Retrieve context from Milvus
4. Rerank with BGE reranker
5. Generate response with vLLM
6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}"
"""
import asyncio
import base64
import json
import logging
import os
import signal
import subprocess
import sys
import time
from typing import List, Dict, Optional
# Install dependencies on startup
subprocess.check_call([
sys.executable, "-m", "pip", "install", "-q",
"--root-user-action=ignore",
"-r", "/app/requirements.txt"
])
import httpx
import msgpack
import nats
import redis.asyncio as redis
from pymilvus import connections, Collection, utility
# OpenTelemetry imports
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.logging import LoggingInstrumentor
# MLflow inference tracking
try:
from mlflow_utils import InferenceMetricsTracker
from mlflow_utils.inference_tracker import InferenceMetrics
MLFLOW_AVAILABLE = True
except ImportError:
MLFLOW_AVAILABLE = False
InferenceMetricsTracker = None
InferenceMetrics = None
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("chat-handler")
def setup_telemetry():
"""Initialize OpenTelemetry tracing and metrics."""
otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true"
if not otel_enabled:
logger.info("OpenTelemetry disabled")
return None, None
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317")
service_name = os.environ.get("OTEL_SERVICE_NAME", "chat-handler")
service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml")
# HyperDX configuration
hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "")
hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io")
use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key
resource = Resource.create({
SERVICE_NAME: service_name,
SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"),
SERVICE_NAMESPACE: service_namespace,
"deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"),
"host.name": os.environ.get("HOSTNAME", "unknown"),
})
trace_provider = TracerProvider(resource=resource)
if use_hyperdx:
logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}")
headers = {"authorization": hyperdx_api_key}
otlp_span_exporter = OTLPSpanExporterHTTP(
endpoint=f"{hyperdx_endpoint}/v1/traces",
headers=headers
)
otlp_metric_exporter = OTLPMetricExporterHTTP(
endpoint=f"{hyperdx_endpoint}/v1/metrics",
headers=headers
)
else:
otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True)
otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True)
trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter))
trace.set_tracer_provider(trace_provider)
metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
HTTPXClientInstrumentor().instrument()
LoggingInstrumentor().instrument(set_logging_format=True)
destination = "HyperDX" if use_hyperdx else "OTEL Collector"
logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}")
return trace.get_tracer(__name__), metrics.get_meter(__name__)
# Configuration from environment
TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local")
EMBEDDINGS_URL = os.environ.get(
"EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local"
)
RERANKER_URL = os.environ.get(
"RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local"
)
VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000")
LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local")
MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530"))
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base")
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379")
# MLflow configuration
MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true"
MLFLOW_TRACKING_URI = os.environ.get(
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# Context window limits (characters)
MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Prevent unbounded growth
# NATS subjects (ai.* schema)
# Per-user channels matching companions-frontend pattern
REQUEST_SUBJECT = "ai.chat.user.*.message" # Wildcard subscription for all users
PREMIUM_REQUEST_SUBJECT = "ai.chat.premium.user.*.message" # Premium users
RESPONSE_SUBJECT = "ai.chat.response" # Response published to specific request_id
STREAM_RESPONSE_SUBJECT = "ai.chat.response.stream" # Streaming responses (token chunks)
# System prompt for the assistant
SYSTEM_PROMPT = """You are a helpful AI assistant.
Answer questions based on the provided context when available.
Be concise and informative. If you don't know the answer, say so clearly."""
class ChatHandler:
def __init__(self):
self.nc = None
self.http_client = None
self.collection = None
self.valkey_client = None
self.running = True
self.tracer = None
self.meter = None
self.request_counter = None
self.request_duration = None
self.rag_search_duration = None
# MLflow inference tracker
self.mlflow_tracker = None
async def setup(self):
"""Initialize all connections."""
# Initialize OpenTelemetry
self.tracer, self.meter = setup_telemetry()
# Setup metrics
if self.meter:
self.request_counter = self.meter.create_counter(
"chat.requests",
description="Number of chat requests processed",
unit="1"
)
self.request_duration = self.meter.create_histogram(
"chat.request_duration",
description="Duration of chat request processing",
unit="s"
)
self.rag_search_duration = self.meter.create_histogram(
"chat.rag_search_duration",
description="Duration of RAG search operations",
unit="s"
)
# Initialize MLflow inference tracker
if MLFLOW_ENABLED and MLFLOW_AVAILABLE:
try:
self.mlflow_tracker = InferenceMetricsTracker(
service_name="chat-handler",
experiment_name="chat-inference",
tracking_uri=MLFLOW_TRACKING_URI,
batch_size=50,
flush_interval_seconds=60.0,
)
await self.mlflow_tracker.start()
logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}")
except Exception as e:
logger.warning(f"MLflow initialization failed: {e}, tracking disabled")
self.mlflow_tracker = None
elif not MLFLOW_AVAILABLE:
logger.info("MLflow utils not available, inference tracking disabled")
else:
logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false")
# NATS connection with reconnection support
async def disconnected_cb():
logger.warning("NATS disconnected, attempting reconnection...")
async def reconnected_cb():
logger.info(f"NATS reconnected to {self.nc.connected_url.netloc}")
async def error_cb(e):
logger.error(f"NATS error: {e}")
async def closed_cb():
logger.warning("NATS connection closed")
self.nc = await nats.connect(
NATS_URL,
reconnect_time_wait=2,
max_reconnect_attempts=-1, # Infinite reconnection attempts
disconnected_cb=disconnected_cb,
reconnected_cb=reconnected_cb,
error_cb=error_cb,
closed_cb=closed_cb,
)
logger.info(f"Connected to NATS at {NATS_URL}")
# HTTP client for services
self.http_client = httpx.AsyncClient(timeout=180.0)
# Connect to Valkey for conversation history and context caching
try:
self.valkey_client = redis.from_url(
VALKEY_URL,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=5
)
await self.valkey_client.ping()
logger.info(f"Connected to Valkey at {VALKEY_URL}")
except Exception as e:
logger.warning(f"Valkey connection failed: {e}, conversation history disabled")
self.valkey_client = None
# Connect to Milvus if collection exists
try:
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
if utility.has_collection(COLLECTION_NAME):
self.collection = Collection(COLLECTION_NAME)
self.collection.load()
logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}")
else:
logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled")
except Exception as e:
logger.warning(f"Milvus connection failed: {e}, RAG disabled")
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings from the embedding service."""
try:
response = await self.http_client.post(
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
)
result = response.json()
return [d["embedding"] for d in result.get("data", [])]
except Exception as e:
logger.error(f"Embedding failed: {e}")
return []
async def search_milvus(
self, query_embedding: List[float], top_k: int = 5
) -> List[Dict]:
"""Search Milvus for relevant documents."""
if not self.collection:
return []
try:
results = self.collection.search(
data=[query_embedding],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"ef": 64}},
limit=top_k,
output_fields=["text", "book_name", "page_num"],
)
docs = []
for hits in results:
for hit in hits:
docs.append(
{
"text": hit.entity.get("text", ""),
"source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}',
"score": hit.score,
}
)
return docs
except Exception as e:
logger.error(f"Milvus search failed: {e}")
return []
async def rerank(self, query: str, documents: List[str]) -> List[Dict]:
"""Rerank documents using the reranker service."""
if not documents:
return []
try:
response = await self.http_client.post(
f"{RERANKER_URL}/v1/rerank",
json={"query": query, "documents": documents},
)
return response.json().get("results", [])
except Exception as e:
logger.error(f"Reranking failed: {e}")
return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))]
async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]:
"""Retrieve conversation history from Valkey."""
if not self.valkey_client or not session_id:
return []
try:
key = f"chat:history:{session_id}"
# Get the most recent messages (stored as a list)
history_json = await self.valkey_client.lrange(key, -max_messages, -1)
history = [json.loads(msg) for msg in history_json]
logger.info(f"Retrieved {len(history)} messages from history for session {session_id}")
return history
except Exception as e:
logger.error(f"Failed to get conversation history: {e}")
return []
async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600):
"""Save a message to conversation history in Valkey."""
if not self.valkey_client or not session_id:
return
try:
key = f"chat:history:{session_id}"
message = json.dumps({"role": role, "content": content, "timestamp": time.time()})
# Use RPUSH to append to the list
await self.valkey_client.rpush(key, message)
# Set TTL on the key (1 hour by default)
await self.valkey_client.expire(key, ttl)
logger.debug(f"Saved {role} message to history for session {session_id}")
except Exception as e:
logger.error(f"Failed to save message to history: {e}")
async def get_context_window(self, session_id: str) -> Optional[str]:
"""Retrieve cached context window from Valkey for attention offloading."""
if not self.valkey_client or not session_id:
return None
try:
key = f"chat:context:{session_id}"
context = await self.valkey_client.get(key)
if context:
logger.info(f"Retrieved cached context window for session {session_id}")
return context
except Exception as e:
logger.error(f"Failed to get context window: {e}")
return None
async def save_context_window(self, session_id: str, context: str, ttl: int = 3600):
"""Save context window to Valkey for attention offloading."""
if not self.valkey_client or not session_id:
return
try:
key = f"chat:context:{session_id}"
await self.valkey_client.set(key, context, ex=ttl)
logger.debug(f"Saved context window for session {session_id}")
except Exception as e:
logger.error(f"Failed to save context window: {e}")
async def generate_response(self, query: str, context: str = "", session_id: str = None) -> str:
"""Generate response using vLLM with conversation history from Valkey."""
try:
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
# Add conversation history from Valkey if session exists
if session_id:
history = await self.get_conversation_history(session_id)
messages.extend(history)
if context:
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {query}",
}
)
else:
messages.append({"role": "user", "content": query})
response = await self.http_client.post(
f"{VLLM_URL}/v1/chat/completions",
json={
"model": LLM_MODEL,
"messages": messages,
"max_tokens": 1000,
"temperature": 0.7,
},
)
result = response.json()
answer = result["choices"][0]["message"]["content"]
logger.info(f"Generated response: {answer[:100]}...")
# Save messages to conversation history
if session_id:
await self.save_message_to_history(session_id, "user", query)
await self.save_message_to_history(session_id, "assistant", answer)
return answer
except Exception as e:
logger.error(f"LLM generation failed: {e}")
return "I'm sorry, I couldn't generate a response."
async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: str = None):
"""Generate streaming response using vLLM and publish chunks to NATS.
Yields tokens as they are generated and publishes them to NATS streaming subject.
Returns the complete response text.
"""
try:
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
# Add conversation history from Valkey if session exists
if session_id:
history = await self.get_conversation_history(session_id)
messages.extend(history)
if context:
messages.append(
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {query}",
}
)
else:
messages.append({"role": "user", "content": query})
full_response = ""
# Stream response from vLLM
async with self.http_client.stream(
"POST",
f"{VLLM_URL}/v1/chat/completions",
json={
"model": LLM_MODEL,
"messages": messages,
"max_tokens": 1000,
"temperature": 0.7,
"stream": True,
},
timeout=60.0,
) as response:
# Parse SSE (Server-Sent Events) stream
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
if data_str.strip() == "[DONE]":
break
try:
chunk_data = json.loads(data_str)
# Extract token from delta
if chunk_data.get("choices") and len(chunk_data["choices"]) > 0:
delta = chunk_data["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
full_response += content
# Publish token chunk to NATS streaming subject
chunk_msg = {
"request_id": request_id,
"type": "chunk",
"content": content,
"done": False,
}
await self.nc.publish(
f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
msgpack.packb(chunk_msg)
)
except json.JSONDecodeError:
continue
# Send completion message
completion_msg = {
"request_id": request_id,
"type": "done",
"content": "",
"done": True,
}
await self.nc.publish(
f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
msgpack.packb(completion_msg)
)
logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}")
# Save messages to conversation history
if session_id:
await self.save_message_to_history(session_id, "user", query)
await self.save_message_to_history(session_id, "assistant", full_response)
return full_response
except Exception as e:
logger.error(f"Streaming LLM generation failed: {e}")
# Send error message
error_msg = {
"request_id": request_id,
"type": "error",
"content": "I'm sorry, I couldn't generate a response.",
"done": True,
"error": str(e),
}
await self.nc.publish(
f"{STREAM_RESPONSE_SUBJECT}.{request_id}",
msgpack.packb(error_msg)
)
return "I'm sorry, I couldn't generate a response."
async def synthesize_speech(self, text: str, language: str = "en") -> str:
"""Convert text to speech using XTTS (Coqui TTS)."""
try:
response = await self.http_client.get(
f"{TTS_URL}/api/tts", params={"text": text, "language_id": language}
)
if response.status_code == 200:
audio_b64 = base64.b64encode(response.content).decode("utf-8")
logger.info(f"Synthesized {len(response.content)} bytes of audio")
return audio_b64
else:
logger.error(
f"TTS returned status {response.status_code}: {response.text}"
)
return ""
except Exception as e:
logger.error(f"TTS failed: {e}")
return ""
async def process_request(self, msg, is_premium=False):
"""Process a chat request."""
start_time = time.time()
span = None
# MLflow metrics tracking
mlflow_metrics = None
embedding_start = None
rag_start = None
rerank_start = None
llm_start = None
try:
data = msgpack.unpackb(msg.data, raw=False)
# Support companions-frontend format (user_id, username, message, premium)
# as well as the original format (request_id, text, enable_rag, etc.)
user_id = data.get("user_id")
username = data.get("username", "")
# Get text from either 'message' (companions-frontend) or 'text' (original)
text = data.get("message") or data.get("text", "")
# Generate request_id from user_id if not provided
import uuid
request_id = data.get("request_id") or f"{user_id or 'anon'}-{uuid.uuid4().hex[:8]}"
# Initialize MLflow metrics if available
if self.mlflow_tracker and MLFLOW_AVAILABLE:
mlflow_metrics = InferenceMetrics(
request_id=request_id,
user_id=user_id,
session_id=data.get("session_id"),
model_name=LLM_MODEL,
model_endpoint=VLLM_URL,
)
# Start tracing span
if self.tracer:
span = self.tracer.start_span("chat.process_request")
span.set_attribute("request_id", request_id)
span.set_attribute("user_id", user_id or "anonymous")
span.set_attribute("premium", is_premium)
# Premium status from message or channel
is_premium = is_premium or data.get("premium", False)
# Support both new parameters and backward compatibility with use_rag
use_rag = data.get("use_rag") # Legacy parameter
enable_rag = data.get(
"enable_rag", use_rag if use_rag is not None else True
)
enable_reranker = data.get(
"enable_reranker", use_rag if use_rag is not None else True
)
# Premium users get more documents for deeper RAG
default_top_k = 15 if is_premium else 5
top_k = data.get("top_k", default_top_k)
# Get request parameters
enable_tts = data.get("enable_tts", False)
enable_streaming = data.get("enable_streaming", False) # New parameter for streaming
language = data.get("language", "en")
session_id = data.get("session_id")
# Update MLflow metrics with request params
if mlflow_metrics:
mlflow_metrics.rag_enabled = enable_rag
mlflow_metrics.reranker_enabled = enable_reranker
mlflow_metrics.is_streaming = enable_streaming
mlflow_metrics.is_premium = is_premium
mlflow_metrics.prompt_length = len(text)
# Add attributes to span
if span:
span.set_attribute("enable_rag", enable_rag)
span.set_attribute("enable_reranker", enable_reranker)
span.set_attribute("top_k", top_k)
span.set_attribute("enable_tts", enable_tts)
span.set_attribute("enable_streaming", enable_streaming)
logger.info(
f"Processing {'premium ' if is_premium else ''}chat request {request_id} from {username or user_id or 'anonymous'}: {text[:50]}... (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})"
)
# Warn if reranker is enabled without RAG
if enable_reranker and not enable_rag:
logger.warning(
f"Request {request_id}: Reranker enabled without RAG - no documents to rerank"
)
if not text:
await self.publish_error(request_id, "No text provided")
return
context = ""
rag_sources = []
docs = []
# Step 1: RAG retrieval (if enabled)
if enable_rag and self.collection:
# Get embeddings for RAG
embedding_start = time.time()
embeddings = await self.get_embeddings([text])
if mlflow_metrics:
mlflow_metrics.embedding_latency = time.time() - embedding_start
if embeddings:
# Search Milvus with configurable top_k
rag_start = time.time()
docs = await self.search_milvus(embeddings[0], top_k=top_k)
if mlflow_metrics:
mlflow_metrics.rag_search_latency = time.time() - rag_start
mlflow_metrics.rag_documents_retrieved = len(docs)
if docs:
rag_sources = [d.get("source", "") for d in docs]
# Step 2: Reranking (if enabled and we have documents)
if enable_reranker and docs:
# Rerank documents
rerank_start = time.time()
doc_texts = [d["text"] for d in docs]
reranked = await self.rerank(text, doc_texts)
if mlflow_metrics:
mlflow_metrics.rerank_latency = time.time() - rerank_start
# Take top 3 reranked documents with bounds checking
sorted_docs = sorted(
reranked, key=lambda x: x.get("relevance_score", 0), reverse=True
)[:3]
# Build context with bounds checking
# Note: doc_texts and docs have the same length (doc_texts derived from docs)
context_parts = []
sources = []
for item in sorted_docs:
idx = item.get("index", -1)
if 0 <= idx < len(docs):
context_parts.append(doc_texts[idx])
sources.append(docs[idx].get("source", ""))
else:
logger.warning(
f"Reranker returned invalid index {idx} for {len(docs)} docs"
)
context = "\n\n".join(context_parts)
rag_sources = sources
elif docs:
# Use documents without reranking (take top 3)
doc_texts = [d["text"] for d in docs[:3]]
context = "\n\n".join(doc_texts)
rag_sources = [d.get("source", "") for d in docs[:3]]
# Step 3: Generate response (streaming or non-streaming)
# Check for cached context window from Valkey (for attention offloading)
cached_context = None
if session_id:
cached_context = await self.get_context_window(session_id)
# Combine RAG context with cached context if available
if cached_context and context:
# Prepend cached context to current RAG context
combined_context = f"{cached_context}\n\n{context}"
# Truncate to prevent unbounded growth
if len(combined_context) > MAX_CONTEXT_LENGTH:
logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
# Keep the most recent context (from the end)
combined_context = combined_context[-MAX_CONTEXT_LENGTH:]
context = combined_context
elif cached_context:
# Only cached context, still need to check length
if len(cached_context) > MAX_CONTEXT_LENGTH:
logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating")
cached_context = cached_context[-MAX_CONTEXT_LENGTH:]
context = cached_context
# Save the combined context for future use (already truncated if needed)
if session_id and context:
await self.save_context_window(session_id, context)
# Track number of RAG docs used after reranking
if mlflow_metrics and enable_rag:
mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0
llm_start = time.time()
if enable_streaming:
# Use streaming response
answer = await self.generate_response_streaming(text, context, request_id, session_id)
else:
# Use non-streaming response
answer = await self.generate_response(text, context, session_id)
if mlflow_metrics:
mlflow_metrics.llm_latency = time.time() - llm_start
mlflow_metrics.response_length = len(answer)
# Estimate token counts (rough approximation: 4 chars per token)
mlflow_metrics.input_tokens = len(text) // 4
mlflow_metrics.output_tokens = len(answer) // 4
mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens
# Step 4: Optionally synthesize speech
audio_b64 = ""
if enable_tts:
audio_b64 = await self.synthesize_speech(answer, language)
# Publish result
# Include both 'response' (companions-frontend) and 'response_text' (original) for compatibility
result = {
"request_id": request_id,
"user_id": user_id,
"text": text,
"response": answer, # companions-frontend expects 'response'
"response_text": answer, # original format
"audio_b64": audio_b64 if enable_tts else None,
"used_rag": bool(context),
"rag_enabled": enable_rag,
"reranker_enabled": enable_reranker,
"rag_sources": rag_sources,
"session_id": session_id,
"success": True,
}
await self.nc.publish(
f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
)
logger.info(f"Published response for request {request_id}")
# Record metrics
duration = time.time() - start_time
if self.request_counter:
self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"})
if self.request_duration:
self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)})
if span:
span.set_attribute("success", True)
span.set_attribute("response_length", len(answer))
# Log to MLflow
if self.mlflow_tracker and mlflow_metrics:
mlflow_metrics.total_latency = duration
await self.mlflow_tracker.log_inference(mlflow_metrics)
except Exception as e:
logger.error(f"Request processing failed: {e}")
if self.request_counter:
self.request_counter.add(1, {"premium": str(is_premium), "success": "false"})
if span:
span.set_attribute("success", False)
span.set_attribute("error", str(e))
# Log error to MLflow
if self.mlflow_tracker and mlflow_metrics:
mlflow_metrics.has_error = True
mlflow_metrics.error_message = str(e)
mlflow_metrics.total_latency = time.time() - start_time
await self.mlflow_tracker.log_inference(mlflow_metrics)
await self.publish_error(data.get("request_id", "unknown"), str(e))
finally:
if span:
span.end()
async def publish_error(self, request_id: str, error: str):
"""Publish an error response."""
result = {"request_id": request_id, "error": error, "success": False}
await self.nc.publish(
f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result)
)
async def process_premium_request(self, msg):
"""Process a premium chat request (wrapper for deeper RAG)."""
await self.process_request(msg, is_premium=True)
async def run(self):
"""Main run loop."""
await self.setup()
# Subscribe to standard chat requests
sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request)
logger.info(f"Subscribed to {REQUEST_SUBJECT}")
# Subscribe to premium chat requests (deeper RAG retrieval)
premium_sub = await self.nc.subscribe(
PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request
)
logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}")
# Handle shutdown
def signal_handler():
self.running = False
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Keep running
while self.running:
await asyncio.sleep(1)
# Cleanup
await sub.unsubscribe()
await premium_sub.unsubscribe()
await self.nc.close()
if self.valkey_client:
await self.valkey_client.close()
if self.collection:
connections.disconnect("default")
if self.mlflow_tracker:
await self.mlflow_tracker.stop()
logger.info("Shutdown complete")
if __name__ == "__main__":
handler = ChatHandler()
asyncio.run(handler.run())