#!/usr/bin/env python3
"""
Chat Handler Service (Refactored)

Text-based chat pipeline using handler-base:
1. Listen for text on NATS subject "ai.chat.request"
2. Generate embeddings for RAG
3. Retrieve context from Milvus
4. Rerank with BGE reranker
5. Generate response with vLLM
6. Optionally synthesize speech with XTTS
7. Publish result to NATS "ai.chat.response.{request_id}"
"""
import base64
import logging
from typing import Any, Optional

from nats.aio.msg import Msg

from handler_base import Handler, Settings
from handler_base.clients import (
    EmbeddingsClient,
    RerankerClient,
    LLMClient,
    TTSClient,
    MilvusClient,
)
from handler_base.telemetry import create_span

logger = logging.getLogger("chat-handler")


class ChatSettings(Settings):
    """Chat handler specific settings."""

    service_name: str = "chat-handler"

    # RAG settings
    # Number of candidates requested from Milvus.
    rag_top_k: int = 10
    # Number of candidates kept after reranking (these form the LLM context).
    rag_rerank_top_k: int = 5
    # Collection connected to in ChatHandler.setup().
    rag_collection: str = "documents"

    # Response settings
    # Whether to echo the top reranked documents back as "sources".
    include_sources: bool = True
    # Maximum number of reranked documents returned as "sources".
    max_sources: int = 3
    # Number of leading characters of each source document returned.
    source_snippet_length: int = 200
    # TTS client is only created when this is true; a request can opt out
    # but cannot enable TTS if this is false.
    enable_tts: bool = False
    tts_language: str = "en"


class ChatHandler(Handler):
    """
    Chat request handler with RAG pipeline.

    Request format:
    {
        "request_id": "uuid",
        "query": "user question",
        "collection": "optional collection name",
        "enable_tts": false,
        "system_prompt": "optional custom system prompt"
    }

    Response format:
    {
        "request_id": "uuid",
        "response": "generated response",
        "sources": [{"text": "...", "score": 0.95}],
        "audio": "base64 encoded audio (if tts enabled)"
    }

    "sources" is present only when ``include_sources`` is set; "audio" only
    when TTS is enabled both in settings and in the request.

    NOTE(review): the "collection" request field is currently NOT applied to
    the Milvus search (see ``_search_context``).
    """

    def __init__(self):
        self.chat_settings = ChatSettings()
        super().__init__(
            subject="ai.chat.request",
            settings=self.chat_settings,
            queue_group="chat-handlers",
        )

    async def setup(self) -> None:
        """Create the service clients and connect Milvus to the RAG collection."""
        logger.info("Initializing service clients...")

        self.embeddings = EmbeddingsClient(self.chat_settings)
        self.reranker = RerankerClient(self.chat_settings)
        self.llm = LLMClient(self.chat_settings)
        self.milvus = MilvusClient(self.chat_settings)

        # TTS is optional
        if self.chat_settings.enable_tts:
            self.tts = TTSClient(self.chat_settings)
        else:
            self.tts = None

        # Connect to Milvus
        await self.milvus.connect(self.chat_settings.rag_collection)

        logger.info("Service clients initialized")

    async def teardown(self) -> None:
        """Close every service client created in ``setup``.

        Clients are closed independently: a client that was never created
        (e.g. ``setup`` failed part-way) is skipped, and a failure closing one
        client is logged without preventing the others from being closed.
        """
        logger.info("Closing service clients...")

        for name in ("embeddings", "reranker", "llm", "milvus", "tts"):
            client = getattr(self, name, None)
            if client is None:
                continue
            try:
                await client.close()
            except Exception:
                # Best-effort cleanup at shutdown: record and keep going.
                logger.exception("Failed to close %s client", name)

        logger.info("Service clients closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Run the RAG pipeline for one chat request and publish the result.

        Args:
            msg: The raw NATS message (not used directly here).
            data: Decoded request payload; read via ``.get``, so it is
                expected to be a mapping shaped like the class docstring's
                request format.

        Returns:
            The response dict, which is also published to
            ``ai.chat.response.{request_id}``.
        """
        request_id = data.get("request_id", "unknown")
        # ``or ""`` also covers an explicit null, which would otherwise break
        # the slicing and len() calls below.
        query = data.get("query") or ""
        collection = data.get("collection", self.chat_settings.rag_collection)
        enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
        system_prompt = data.get("system_prompt")

        logger.info("Processing request %s: %s...", request_id, query[:50])

        with create_span("chat.process") as span:
            if span:
                span.set_attribute("request.id", request_id)
                span.set_attribute("query.length", len(query))

            # 1. Generate query embedding
            embedding = await self._get_embedding(query)

            # 2. Search Milvus for context
            documents = await self._search_context(embedding, collection)

            # 3. Rerank documents
            reranked = await self._rerank_documents(query, documents)

            # 4. Build context from top documents
            context = self._build_context(reranked)

            # 5. Generate LLM response
            response_text = await self._generate_response(
                query, context, system_prompt
            )

            # 6. Optionally synthesize speech (requires the TTS client to
            # have been created in setup as well as the request opting in).
            audio_b64 = None
            if enable_tts and self.tts:
                audio_b64 = await self._synthesize_speech(response_text)

            # Build response
            result = {
                "request_id": request_id,
                "response": response_text,
            }
            if self.chat_settings.include_sources:
                result["sources"] = self._format_sources(reranked)
            if audio_b64:
                result["audio"] = audio_b64

            logger.info("Completed request %s", request_id)

            # Publish to response subject
            response_subject = f"ai.chat.response.{request_id}"
            await self.nats.publish(response_subject, result)

            return result

    def _format_sources(self, documents: list[dict]) -> list[dict]:
        """Summarise the top reranked documents for the response payload.

        Uses ``.get`` like ``_build_context`` so a reranker entry missing a
        key yields an empty snippet / ``None`` score instead of a KeyError.
        """
        limit = self.chat_settings.max_sources
        snippet_len = self.chat_settings.source_snippet_length
        return [
            {
                "text": (doc.get("document") or "")[:snippet_len],
                "score": doc.get("score"),
            }
            for doc in documents[:limit]
        ]

    async def _get_embedding(self, text: str) -> list[float]:
        """Generate embedding for query text."""
        with create_span("chat.embedding"):
            return await self.embeddings.embed_single(text)

    async def _search_context(
        self, embedding: list[float], collection: str
    ) -> list[dict]:
        """Search Milvus for the ``rag_top_k`` documents nearest ``embedding``.

        NOTE(review): ``collection`` is accepted but not forwarded to the
        search, so the search runs against the collection connected in
        ``setup``. Forward it once ``MilvusClient.search_with_texts`` is
        confirmed to accept a collection argument; reconnecting the shared
        client per request would race with concurrent requests.
        """
        if collection != self.chat_settings.rag_collection:
            logger.warning(
                "Collection override %r is not supported; using the "
                "collection connected at setup (%r)",
                collection,
                self.chat_settings.rag_collection,
            )
        with create_span("chat.search"):
            return await self.milvus.search_with_texts(
                embedding,
                limit=self.chat_settings.rag_top_k,
                text_field="text",
                metadata_fields=["source", "title"],
            )

    async def _rerank_documents(
        self, query: str, documents: list[dict]
    ) -> list[dict]:
        """Rerank retrieved documents by relevance to ``query``.

        Returns an empty list without calling the reranker when nothing was
        retrieved.
        """
        if not documents:
            return []
        with create_span("chat.rerank"):
            texts = [doc.get("text", "") for doc in documents]
            return await self.reranker.rerank(
                query, texts, top_k=self.chat_settings.rag_rerank_top_k
            )

    def _build_context(self, documents: list[dict]) -> str:
        """Join reranked documents as ``[n] text`` entries separated by blank lines."""
        return "\n\n".join(
            f"[{i}] {doc.get('document', '')}"
            for i, doc in enumerate(documents, 1)
        )

    async def _generate_response(
        self,
        query: str,
        context: str,
        system_prompt: Optional[str] = None,
    ) -> str:
        """Generate LLM response with context."""
        with create_span("chat.generate"):
            return await self.llm.generate(
                query,
                context=context,
                system_prompt=system_prompt,
            )

    async def _synthesize_speech(self, text: str) -> str:
        """Synthesize speech and return it base64-encoded as an ASCII string.

        Callers must ensure ``self.tts`` is not None.
        """
        with create_span("chat.tts"):
            audio_bytes = await self.tts.synthesize(
                text,
                language=self.chat_settings.tts_language,
            )
            return base64.b64encode(audio_bytes).decode()


if __name__ == "__main__":
    ChatHandler().run()