diff --git a/chat_handler.py b/chat_handler.py
index 8cc3197..20a75e6 100644
--- a/chat_handler.py
+++ b/chat_handler.py
@@ -3,13 +3,12 @@
 Chat Handler Service (Refactored)
 
 Text-based chat pipeline using handler-base:
-1. Listen for text on NATS subject "ai.chat.request"
-2. Generate embeddings for RAG
-3. Retrieve context from Milvus
-4. Rerank with BGE reranker
-5. Generate response with vLLM
-6. Optionally synthesize speech with XTTS
-7. Publish result to NATS "ai.chat.response.{request_id}"
+1. Listen for text on NATS subject "ai.chat.user.*.message"
+2. If RAG enabled (premium/explicit): embed → Milvus search → rerank
+3. Generate response with vLLM (with or without RAG context)
+4. Optionally stream chunks to "ai.chat.response.stream.{request_id}"
+5. Optionally synthesize speech with XTTS
+6. Publish result to "ai.chat.response.{request_id}" (or custom response_subject)
 """
 
 import base64
@@ -51,28 +50,38 @@ class ChatHandler(Handler):
     """
     Chat request handler with RAG pipeline.
 
-    Request format:
+    Subscribes to: ai.chat.user.*.message (JetStream durable "chat-handler")
+
+    Request format (msgpack):
     {
         "request_id": "uuid",
-        "query": "user question",
-        "collection": "optional collection name",
-        "enable_tts": false,
+        "user_id": "user-123",
+        "username": "john_doe",
+        "message": "user question",
+        "premium": false,
+        "enable_rag": true,
+        "enable_reranker": true,
+        "enable_streaming": true,
+        "top_k": 5,
+        "session_id": "session-abc",
         "system_prompt": "optional custom system prompt"
     }
 
-    Response format:
+    Response format (msgpack):
     {
-        "request_id": "uuid",
+        "user_id": "user-123",
         "response": "generated response",
-        "sources": [{"text": "...", "score": 0.95}],
-        "audio": "base64 encoded audio (if tts enabled)"
+        "response_text": "generated response",
+        "used_rag": true,
+        "rag_sources": ["source1", "source2"],
+        "success": true
    }
    """

    def __init__(self):
        self.chat_settings = ChatSettings()
        super().__init__(
-            subject="ai.chat.request",
+            subject="ai.chat.user.*.message",
            settings=self.chat_settings,
            queue_group="chat-handlers",
        )
@@ -114,56 +123,93 @@ class ChatHandler(Handler):
    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle incoming chat request."""
        request_id = data.get("request_id", "unknown")
-        query = data.get("query", "")
+        user_id = data.get("user_id", "unknown")
+        query = data.get("message", "") or data.get("query", "")
+        premium = data.get("premium", False)
+        enable_rag = data.get("enable_rag", premium)
+        enable_reranker = data.get("enable_reranker", enable_rag)
+        enable_streaming = data.get("enable_streaming", False)
+        top_k = data.get("top_k", self.chat_settings.rag_top_k)
        collection = data.get("collection", self.chat_settings.rag_collection)
        enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
        system_prompt = data.get("system_prompt")
+        # companions-frontend may set a custom response subject
+        response_subject = data.get(
+            "response_subject", f"ai.chat.response.{request_id}"
+        )

        logger.info(f"Processing request {request_id}: {query[:50]}...")

        with create_span("chat.process") as span:
            if span:
                span.set_attribute("request.id", request_id)
+                span.set_attribute("user.id", user_id)
                span.set_attribute("query.length", len(query))
+                span.set_attribute("premium", premium)
+                span.set_attribute("rag.enabled", enable_rag)

-            # 1. Generate query embedding
-            embedding = await self._get_embedding(query)
+            context = ""
+            rag_sources: list[str] = []
+            used_rag = False

-            # 2. Search Milvus for context
-            documents = await self._search_context(embedding, collection)
+            # Only run RAG pipeline when enabled (premium users or explicit flag)
+            if enable_rag:
+                # 1. Generate query embedding
+                embedding = await self._get_embedding(query)

-            # 3. Rerank documents
-            reranked = await self._rerank_documents(query, documents)
+                # 2. Search Milvus for context
+                documents = await self._search_context(
+                    embedding, collection, top_k=top_k,
+                )

-            # 4. Build context from top documents
-            context = self._build_context(reranked)
+                # 3. Optionally rerank documents
+                if enable_reranker and documents:
+                    reranked = await self._rerank_documents(query, documents)
+                else:
+                    reranked = documents

-            # 5. Generate LLM response
-            response_text = await self._generate_response(query, context, system_prompt)
+                # 4. Build context from top documents
+                if reranked:
+                    context = self._build_context(reranked)
+                    rag_sources = [
+                        d.get("source", d.get("document", "")[:80])
+                        for d in reranked[:3]
+                    ]
+                    used_rag = True

-            # 6. Optionally synthesize speech
+            # 5. Generate LLM response (with or without RAG context)
+            response_text = await self._generate_response(
+                query, context or None, system_prompt,
+            )
+
+            # 6. Stream response chunks if requested
+            if enable_streaming:
+                stream_subject = f"ai.chat.response.stream.{request_id}"
+                await self._publish_streaming_chunks(
+                    stream_subject, request_id, response_text,
+                )
+
+            # 7. Optionally synthesize speech
            audio_b64 = None
            if enable_tts and self.tts:
                audio_b64 = await self._synthesize_speech(response_text)

-            # Build response
-            result = {
-                "request_id": request_id,
+            # Build response (compatible with companions-frontend NATSChatResponse)
+            result: dict[str, Any] = {
+                "user_id": user_id,
                "response": response_text,
+                "response_text": response_text,
+                "used_rag": used_rag,
+                "rag_sources": rag_sources,
+                "success": True,
            }

-            if self.chat_settings.include_sources:
-                result["sources"] = [
-                    {"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
-                ]
-
            if audio_b64:
                result["audio"] = audio_b64

-            logger.info(f"Completed request {request_id}")
+            logger.info(f"Completed request {request_id} (rag={used_rag})")

-            # Publish to response subject
-            response_subject = f"ai.chat.response.{request_id}"
+            # Publish to the response subject the frontend is waiting on
            await self.nats.publish(response_subject, result)

            return result
@@ -173,12 +219,17 @@ class ChatHandler(Handler):
        with create_span("chat.embedding"):
            return await self.embeddings.embed_single(text)

-    async def _search_context(self, embedding: list[float], collection: str) -> list[dict]:
+    async def _search_context(
+        self,
+        embedding: list[float],
+        collection: str,
+        top_k: int | None = None,
+    ) -> list[dict]:
        """Search Milvus for relevant documents."""
        with create_span("chat.search"):
            return await self.milvus.search_with_texts(
                embedding,
-                limit=self.chat_settings.rag_top_k,
+                limit=top_k or self.chat_settings.rag_top_k,
                text_field="text",
                metadata_fields=["source", "title"],
            )
@@ -202,10 +253,10 @@ class ChatHandler(Handler):
    async def _generate_response(
        self,
        query: str,
-        context: str,
+        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
-        """Generate LLM response with context."""
+        """Generate LLM response, optionally augmented with RAG context."""
        with create_span("chat.generate"):
            return await self.llm.generate(
                query,
@@ -213,6 +264,41 @@ class ChatHandler(Handler):
                system_prompt=system_prompt,
            )

+    async def _publish_streaming_chunks(
+        self,
+        subject: str,
+        request_id: str,
+        full_text: str,
+    ) -> None:
+        """Publish response as streaming chunks for real-time display."""
+        import time
+
+        words = full_text.split(" ")
+        chunk_size = 4
+        for i in range(0, len(words), chunk_size):
+            token_chunk = " ".join(words[i : i + chunk_size])
+            await self.nats.publish(
+                subject,
+                {
+                    "request_id": request_id,
+                    "type": "chunk",
+                    "content": token_chunk,
+                    "done": False,
+                    "timestamp": time.time(),
+                },
+            )
+        # Send done marker
+        await self.nats.publish(
+            subject,
+            {
+                "request_id": request_id,
+                "type": "done",
+                "content": "",
+                "done": True,
+                "timestamp": time.time(),
+            },
+        )
+
    async def _synthesize_speech(self, text: str) -> str:
        """Synthesize speech and return base64 encoded audio."""
        with create_span("chat.tts"):
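
Reviewer note, not part of the patch: a minimal client-side sketch of the new round trip, handy for exercising the handler end to end. It assumes nats-py and msgpack are installed, a broker at nats://localhost:4222 (hypothetical URL), a JetStream stream already capturing ai.chat.user.*.message for the durable consumer noted in the docstring, and that handler-base encodes payloads as msgpack dicts as the updated docstring states; field names follow the request/response formats documented above.

import asyncio
import uuid

import msgpack
import nats


async def main() -> None:
    nc = await nats.connect("nats://localhost:4222")  # assumed broker URL
    request_id = str(uuid.uuid4())

    # Subscribe to both per-request subjects *before* publishing so no
    # streamed chunk or final response can be missed.
    stream_sub = await nc.subscribe(f"ai.chat.response.stream.{request_id}")
    response_sub = await nc.subscribe(f"ai.chat.response.{request_id}")

    request = {
        "request_id": request_id,
        "user_id": "user-123",
        "username": "john_doe",
        "message": "What does the RAG pipeline retrieve?",
        "premium": True,  # enable_rag defaults to this value
        "enable_streaming": True,
        "top_k": 5,
    }
    await nc.publish("ai.chat.user.user-123.message", msgpack.packb(request))

    # Drain streamed chunks until the done marker arrives.
    while True:
        msg = await stream_sub.next_msg(timeout=60)
        chunk = msgpack.unpackb(msg.data, raw=False)
        if chunk["done"]:
            break
        print(chunk["content"], end=" ", flush=True)

    # The aggregated response is published after the stream completes.
    msg = await response_sub.next_msg(timeout=60)
    result = msgpack.unpackb(msg.data, raw=False)
    print(f"\nused_rag={result['used_rag']} sources={result['rag_sources']}")

    await nc.close()


if __name__ == "__main__":
    asyncio.run(main())

Subscribing before publishing matters here: the stream chunks are fire-and-forget core NATS publishes, so a subscriber that attaches after the handler starts emitting would silently miss them.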