#!/usr/bin/env python3
"""
Chat Handler Service (Refactored)

Text-based chat pipeline using handler-base:

1. Listen for text on NATS subject "ai.chat.user.*.message"
2. If RAG enabled (premium/explicit): embed → Milvus search → rerank
3. Generate response with vLLM (with or without RAG context)
4. Optionally stream chunks to "ai.chat.response.stream.{request_id}"
5. Optionally synthesize speech with XTTS
6. Publish result to "ai.chat.response.{request_id}" (or custom response_subject)
"""

import base64
import logging
import time
from typing import Any, Optional

from nats.aio.msg import Msg

from handler_base import Handler, Settings
from handler_base.clients import (
    EmbeddingsClient,
    RerankerClient,
    LLMClient,
    TTSClient,
    MilvusClient,
)
from handler_base.telemetry import create_span

logger = logging.getLogger("chat-handler")

# Number of space-separated words sent per streamed chunk.
STREAM_CHUNK_WORDS = 4
# Number of top-ranked documents whose labels are returned in "rag_sources".
MAX_RAG_SOURCES = 3
# Length of the text prefix used as a source label when a document has no "source".
SOURCE_LABEL_CHARS = 80
# Generic error text sent to clients; internal exception details stay in the logs.
FAILURE_MESSAGE = "Failed to process chat request"


class ChatSettings(Settings):
    """Chat handler specific settings."""

    service_name: str = "chat-handler"

    # RAG settings
    rag_top_k: int = 10
    rag_rerank_top_k: int = 5
    rag_collection: str = "documents"

    # Response settings
    include_sources: bool = True  # when False, "rag_sources" is returned empty
    enable_tts: bool = False  # also controls whether a TTS client is created at all
    tts_language: str = "en"


class ChatHandler(Handler):
    """
    Chat request handler with RAG pipeline.

    Subscribes to: ai.chat.user.*.message (queue group "chat-handlers")

    Request format (msgpack):
    {
        "request_id": "uuid",
        "user_id": "user-123",
        "username": "john_doe",
        "message": "user question",          # "query" is accepted as a fallback
        "premium": false,
        "enable_rag": true,                  # defaults to the value of "premium"
        "enable_reranker": true,             # defaults to the value of "enable_rag"
        "enable_streaming": true,
        "enable_tts": false,                 # defaults to ChatSettings.enable_tts
        "top_k": 5,
        "collection": "documents",           # see NOTE in _retrieve_context
        "session_id": "session-abc",
        "system_prompt": "optional custom system prompt",
        "response_subject": "optional subject to publish the result to"
    }

    Response format (msgpack):
    {
        "user_id": "user-123",
        "response": "generated response",
        "response_text": "generated response",
        "used_rag": true,
        "rag_sources": ["source1", "source2"],
        "success": true,
        "audio": "<base64 audio>"            # only present when TTS produced audio
    }

    On failure, a response with "success": false and an "error" message is
    published to the same subject before the exception is re-raised.
    """

    def __init__(self):
        self.chat_settings = ChatSettings()
        super().__init__(
            subject="ai.chat.user.*.message",
            settings=self.chat_settings,
            queue_group="chat-handlers",
        )

    async def setup(self) -> None:
        """Initialize service clients and connect Milvus to the configured collection."""
        logger.info("Initializing service clients...")

        self.embeddings = EmbeddingsClient(self.chat_settings)
        self.reranker = RerankerClient(self.chat_settings)
        self.llm = LLMClient(self.chat_settings)
        self.milvus = MilvusClient(self.chat_settings)

        # TTS is optional. When disabled here, per-request "enable_tts" has no
        # effect because no client exists to synthesize with.
        if self.chat_settings.enable_tts:
            self.tts = TTSClient(self.chat_settings)
        else:
            self.tts = None

        # Connect to Milvus
        await self.milvus.connect(self.chat_settings.rag_collection)

        logger.info("Service clients initialized")

    async def teardown(self) -> None:
        """Close every service client, continuing past individual failures."""
        logger.info("Closing service clients...")

        for attr in ("embeddings", "reranker", "llm", "milvus", "tts"):
            # getattr: setup() may have failed before creating every client.
            client = getattr(self, attr, None)
            if client is None:
                continue
            try:
                await client.close()
            except Exception:
                # Best-effort shutdown: one failing client must not leave the
                # remaining ones open.
                logger.exception("Failed to close %s client", attr)

        logger.info("Service clients closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle one incoming chat request.

        Args:
            msg: The raw NATS message (not read here; part of the Handler interface).
            data: Decoded request payload; see the class docstring for its keys.

        Returns:
            The response dict, which is also published to the response subject.

        Raises:
            Re-raises any pipeline failure after publishing a ``success: False``
            response, so the waiting client is not left hanging.
        """
        request_id = data.get("request_id", "unknown")
        user_id = data.get("user_id", "unknown")
        # Trailing `or ""` guards against explicit nulls in both fields.
        query = data.get("message") or data.get("query") or ""
        premium = data.get("premium", False)
        enable_rag = data.get("enable_rag", premium)
        enable_reranker = data.get("enable_reranker", enable_rag)
        enable_streaming = data.get("enable_streaming", False)
        top_k = data.get("top_k", self.chat_settings.rag_top_k)
        collection = data.get("collection", self.chat_settings.rag_collection)
        enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
        system_prompt = data.get("system_prompt")
        # companions-frontend may set a custom response subject
        response_subject = data.get(
            "response_subject", f"ai.chat.response.{request_id}"
        )

        logger.info("Processing request %s: %s...", request_id, query[:50])

        try:
            with create_span("chat.process") as span:
                if span:
                    span.set_attribute("request.id", request_id)
                    span.set_attribute("user.id", user_id)
                    span.set_attribute("query.length", len(query))
                    span.set_attribute("premium", premium)
                    span.set_attribute("rag.enabled", enable_rag)

                context = ""
                rag_sources: list[str] = []
                used_rag = False

                # Only run RAG pipeline when enabled (premium users or explicit flag)
                if enable_rag:
                    context, rag_sources, used_rag = await self._retrieve_context(
                        query, collection, top_k, enable_reranker
                    )

                # Generate LLM response (with or without RAG context)
                response_text = await self._generate_response(
                    query,
                    context or None,
                    system_prompt,
                )

                # Stream response chunks if requested
                if enable_streaming:
                    await self._publish_streaming_chunks(
                        f"ai.chat.response.stream.{request_id}",
                        request_id,
                        response_text,
                    )

                # Optionally synthesize speech
                audio_b64 = None
                if enable_tts:
                    if self.tts:
                        audio_b64 = await self._synthesize_speech(response_text)
                    else:
                        logger.debug(
                            "TTS requested for %s but no TTS client is configured",
                            request_id,
                        )

                # Build response (compatible with companions-frontend NATSChatResponse)
                result: dict[str, Any] = {
                    "user_id": user_id,
                    "response": response_text,
                    "response_text": response_text,
                    "used_rag": used_rag,
                    "rag_sources": (
                        rag_sources if self.chat_settings.include_sources else []
                    ),
                    "success": True,
                }
                if audio_b64:
                    result["audio"] = audio_b64

                logger.info("Completed request %s (rag=%s)", request_id, used_rag)

                # Publish to the response subject the frontend is waiting on
                await self.nats.publish(response_subject, result)

                return result
        except Exception:
            logger.exception("Failed to process request %s", request_id)
            await self._publish_failure(response_subject, user_id)
            raise

    async def _publish_failure(self, subject: str, user_id: str) -> None:
        """Best-effort publish of a ``success: False`` response to *subject*.

        Errors while publishing are logged and swallowed so they do not mask
        the original pipeline exception.
        """
        payload: dict[str, Any] = {
            "user_id": user_id,
            "response": "",
            "response_text": "",
            "used_rag": False,
            "rag_sources": [],
            "success": False,
            "error": FAILURE_MESSAGE,
        }
        try:
            await self.nats.publish(subject, payload)
        except Exception:
            logger.exception("Failed to publish error response to %s", subject)

    async def _retrieve_context(
        self,
        query: str,
        collection: str,
        top_k: int | None,
        enable_reranker: bool,
    ) -> tuple[str, list[str], bool]:
        """Run embed → Milvus search → optional rerank and build the prompt context.

        Returns:
            ``(context, rag_sources, used_rag)``. ``context`` is empty and
            ``used_rag`` is False when no documents were found.
        """
        if collection != self.chat_settings.rag_collection:
            # NOTE(review): the Milvus client is connected to rag_collection in
            # setup() and _search_context does not forward a collection name,
            # so a per-request collection is currently not honored.
            logger.warning(
                "Requested collection %r ignored; searching %r",
                collection,
                self.chat_settings.rag_collection,
            )

        # 1. Generate query embedding
        embedding = await self._get_embedding(query)

        # 2. Search Milvus for context
        documents = await self._search_context(embedding, collection, top_k=top_k)

        # 3. Optionally rerank documents
        if enable_reranker and documents:
            documents = await self._rerank_documents(query, documents)

        if not documents:
            return "", [], False

        # 4. Build context and source labels from the top documents
        context = self._build_context(documents)
        rag_sources = [
            self._source_label(doc) for doc in documents[:MAX_RAG_SOURCES]
        ]
        return context, rag_sources, True

    @staticmethod
    def _document_text(doc: dict) -> str:
        """Return a document's text regardless of which pipeline stage produced it.

        Search hits carry their text under "text" (that is the text_field
        requested in _search_context); reranked results have been read via
        "document". Missing or null values yield "".
        """
        return doc.get("document") or doc.get("text") or ""

    @staticmethod
    def _source_label(doc: dict) -> str:
        """Label a document by its "source", falling back to a short text prefix."""
        source = doc.get("source")
        if source:
            return source
        return ChatHandler._document_text(doc)[:SOURCE_LABEL_CHARS]

    async def _get_embedding(self, text: str) -> list[float]:
        """Generate embedding for query text."""
        with create_span("chat.embedding"):
            return await self.embeddings.embed_single(text)

    async def _search_context(
        self,
        embedding: list[float],
        collection: str,
        top_k: int | None = None,
    ) -> list[dict]:
        """Search Milvus for relevant documents.

        ``collection`` is accepted for interface compatibility but is not
        forwarded; the search runs against the collection connected in setup().
        """
        with create_span("chat.search"):
            return await self.milvus.search_with_texts(
                embedding,
                limit=top_k or self.chat_settings.rag_top_k,
                text_field="text",
                metadata_fields=["source", "title"],
            )

    async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
        """Rerank documents by relevance to query, keeping ``rag_rerank_top_k`` results.

        NOTE(review): only the document texts are sent to the reranker, so
        metadata such as "source" may not survive reranking — confirm against
        RerankerClient.rerank's return shape.
        """
        with create_span("chat.rerank"):
            texts = [self._document_text(d) for d in documents]
            return await self.reranker.rerank(
                query, texts, top_k=self.chat_settings.rag_rerank_top_k
            )

    def _build_context(self, documents: list[dict]) -> str:
        """Build a numbered context string ("[1] ...", "[2] ...") from ranked documents."""
        return "\n\n".join(
            f"[{i}] {self._document_text(doc)}"
            for i, doc in enumerate(documents, 1)
        )

    async def _generate_response(
        self,
        query: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """Generate LLM response, optionally augmented with RAG context."""
        with create_span("chat.generate"):
            return await self.llm.generate(
                query,
                context=context,
                system_prompt=system_prompt,
            )

    async def _publish_streaming_chunks(
        self,
        subject: str,
        request_id: str,
        full_text: str,
        chunk_size: int = STREAM_CHUNK_WORDS,
    ) -> None:
        """Publish *full_text* as word chunks followed by a "done" marker.

        Each chunk except the last keeps its trailing space, so concatenating
        every chunk's "content" reproduces *full_text* exactly.
        """
        chunk_size = max(1, chunk_size)
        words = full_text.split(" ")
        for start in range(0, len(words), chunk_size):
            end = start + chunk_size
            token_chunk = " ".join(words[start:end])
            if end < len(words):
                # Preserve the separator between this chunk and the next one.
                token_chunk += " "
            await self.nats.publish(
                subject,
                {
                    "request_id": request_id,
                    "type": "chunk",
                    "content": token_chunk,
                    "done": False,
                    "timestamp": time.time(),
                },
            )

        # Send done marker
        await self.nats.publish(
            subject,
            {
                "request_id": request_id,
                "type": "done",
                "content": "",
                "done": True,
                "timestamp": time.time(),
            },
        )

    async def _synthesize_speech(self, text: str) -> str:
        """Synthesize speech and return base64 encoded audio."""
        with create_span("chat.tts"):
            audio_bytes = await self.tts.synthesize(
                text,
                language=self.chat_settings.tts_language,
            )
            return base64.b64encode(audio_bytes).decode()


if __name__ == "__main__":
    ChatHandler().run()