Files
chat-handler/chat_handler.py
Billy D. 24a4098c9a
Some checks failed
CI / Lint (push) Failing after 1m39s
CI / Test (push) Failing after 1m37s
CI / Release (push) Has been skipped
CI / Notify (push) Successful in 1s
fixing up chat-handler.
2026-02-18 07:29:41 -05:00

314 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Chat Handler Service (Refactored)
Text-based chat pipeline using handler-base:
1. Listen for text on NATS subject "ai.chat.user.*.message"
2. If RAG enabled (premium/explicit): embed → Milvus search → rerank
3. Generate response with vLLM (with or without RAG context)
4. Optionally stream chunks to "ai.chat.response.stream.{request_id}"
5. Optionally synthesize speech with XTTS
6. Publish result to "ai.chat.response.{request_id}" (or custom response_subject)
"""
import base64
import logging
import time
from typing import Any, Optional

from nats.aio.msg import Msg

from handler_base import Handler, Settings
from handler_base.clients import (
    EmbeddingsClient,
    RerankerClient,
    LLMClient,
    TTSClient,
    MilvusClient,
)
from handler_base.telemetry import create_span
logger = logging.getLogger("chat-handler")
class ChatSettings(Settings):
    """Chat handler specific settings."""
    service_name: str = "chat-handler"
    # RAG settings
    rag_top_k: int = 10  # documents fetched from Milvus per query (overridable per request)
    rag_rerank_top_k: int = 5  # documents kept after reranking
    rag_collection: str = "documents"  # default Milvus collection (overridable per request)
    # Response settings
    include_sources: bool = True  # NOTE(review): not read anywhere in this file — confirm callers use it
    enable_tts: bool = False  # when True, a TTSClient is created in setup()
    tts_language: str = "en"  # language code passed to XTTS synthesis
class ChatHandler(Handler):
    """
    Chat request handler with RAG pipeline.

    Subscribes to: ai.chat.user.*.message (JetStream durable "chat-handler")

    Request format (msgpack):
        {
            "request_id": "uuid",
            "user_id": "user-123",
            "username": "john_doe",
            "message": "user question",
            "premium": false,
            "enable_rag": true,
            "enable_reranker": true,
            "enable_streaming": true,
            "top_k": 5,
            "session_id": "session-abc",
            "system_prompt": "optional custom system prompt"
        }

    Response format (msgpack):
        {
            "user_id": "user-123",
            "response": "generated response",
            "response_text": "generated response",
            "used_rag": true,
            "rag_sources": ["source1", "source2"],
            "success": true
        }
    """

    def __init__(self):
        self.chat_settings = ChatSettings()
        super().__init__(
            subject="ai.chat.user.*.message",
            settings=self.chat_settings,
            queue_group="chat-handlers",
        )

    async def setup(self) -> None:
        """Initialize service clients and connect to the Milvus collection."""
        logger.info("Initializing service clients...")
        self.embeddings = EmbeddingsClient(self.chat_settings)
        self.reranker = RerankerClient(self.chat_settings)
        self.llm = LLMClient(self.chat_settings)
        self.milvus = MilvusClient(self.chat_settings)
        # TTS is optional: only create a client when enabled in settings.
        self.tts = TTSClient(self.chat_settings) if self.chat_settings.enable_tts else None
        # Connect to the default collection; per-request overrides are not
        # currently applied at search time (see _search_context TODO).
        await self.milvus.connect(self.chat_settings.rag_collection)
        logger.info("Service clients initialized")

    async def teardown(self) -> None:
        """Close all service clients, continuing past individual failures.

        A failed close of one client must not prevent the remaining
        clients from being closed (the original early-returned on the
        first exception, leaking the rest).
        """
        logger.info("Closing service clients...")
        clients: list[Any] = [self.embeddings, self.reranker, self.llm, self.milvus]
        if self.tts:
            clients.append(self.tts)
        for client in clients:
            try:
                await client.close()
            except Exception:
                logger.exception("Error closing client %r", client)
        logger.info("Service clients closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle an incoming chat request.

        Runs the optional RAG pipeline, generates the LLM response,
        optionally streams chunks and synthesizes speech, publishes the
        result to the response subject, and returns the result dict.

        Args:
            msg: Raw NATS message (unused beyond delivery).
            data: Decoded msgpack request (see class docstring).

        Returns:
            The published response dict.
        """
        request_id = data.get("request_id", "unknown")
        user_id = data.get("user_id", "unknown")
        # Accept "message" (primary) with "query" as a fallback key.
        query = data.get("message", "") or data.get("query", "")
        premium = data.get("premium", False)
        # RAG defaults on for premium users; reranking follows RAG unless overridden.
        enable_rag = data.get("enable_rag", premium)
        enable_reranker = data.get("enable_reranker", enable_rag)
        enable_streaming = data.get("enable_streaming", False)
        top_k = data.get("top_k", self.chat_settings.rag_top_k)
        collection = data.get("collection", self.chat_settings.rag_collection)
        enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
        system_prompt = data.get("system_prompt")
        # companions-frontend may set a custom response subject
        response_subject = data.get(
            "response_subject", f"ai.chat.response.{request_id}"
        )
        logger.info("Processing request %s: %.50s...", request_id, query)
        with create_span("chat.process") as span:
            if span:
                span.set_attribute("request.id", request_id)
                span.set_attribute("user.id", user_id)
                span.set_attribute("query.length", len(query))
                span.set_attribute("premium", premium)
                span.set_attribute("rag.enabled", enable_rag)

            context = ""
            rag_sources: list[str] = []
            used_rag = False
            # Only run RAG pipeline when enabled (premium users or explicit flag)
            if enable_rag:
                context, rag_sources, used_rag = await self._run_rag_pipeline(
                    query, collection, top_k, enable_reranker
                )

            # Generate LLM response (with or without RAG context)
            response_text = await self._generate_response(
                query, context or None, system_prompt,
            )

            # Stream response chunks if requested
            if enable_streaming:
                stream_subject = f"ai.chat.response.stream.{request_id}"
                await self._publish_streaming_chunks(
                    stream_subject, request_id, response_text,
                )

            # Optionally synthesize speech
            audio_b64 = None
            if enable_tts and self.tts:
                audio_b64 = await self._synthesize_speech(response_text)

            # Build response (compatible with companions-frontend NATSChatResponse)
            result: dict[str, Any] = {
                "user_id": user_id,
                "response": response_text,
                "response_text": response_text,
                "used_rag": used_rag,
                "rag_sources": rag_sources,
                "success": True,
            }
            if audio_b64:
                result["audio"] = audio_b64
            logger.info("Completed request %s (rag=%s)", request_id, used_rag)
            # Publish to the response subject the frontend is waiting on
            await self.nats.publish(response_subject, result)
            return result

    async def _run_rag_pipeline(
        self,
        query: str,
        collection: str,
        top_k: int,
        enable_reranker: bool,
    ) -> tuple[str, list[str], bool]:
        """Run embed -> search -> (optional) rerank and build the RAG context.

        Returns:
            (context, rag_sources, used_rag); used_rag is False when the
            search (or reranker) produced no documents.
        """
        # 1. Generate query embedding
        embedding = await self._get_embedding(query)
        # 2. Search Milvus for context
        documents = await self._search_context(embedding, collection, top_k=top_k)
        # 3. Optionally rerank documents
        if enable_reranker and documents:
            documents = await self._rerank_documents(query, documents)
        if not documents:
            return "", [], False
        # 4. Build context from top documents; prefer the stored "source"
        # metadata and fall back to a snippet of the document text.
        context = self._build_context(documents)
        rag_sources = [
            d.get("source", self._doc_text(d)[:80]) for d in documents[:3]
        ]
        return context, rag_sources, True

    @staticmethod
    def _doc_text(doc: dict) -> str:
        """Return a document's text regardless of which key the producer used.

        Milvus search results store text under "text" (see _search_context's
        text_field="text"), while reranked entries appear to use "document"
        — presumably the reranker's output key; TODO confirm against
        RerankerClient. The original code only read "document", yielding
        empty context whenever the reranker was disabled.
        """
        return doc.get("document") or doc.get("text", "")

    async def _get_embedding(self, text: str) -> list[float]:
        """Generate embedding for query text."""
        with create_span("chat.embedding"):
            return await self.embeddings.embed_single(text)

    async def _search_context(
        self,
        embedding: list[float],
        collection: str,
        top_k: int | None = None,
    ) -> list[dict]:
        """Search Milvus for relevant documents.

        Args:
            embedding: Query vector.
            collection: Requested collection name. TODO(review): this is
                currently ignored — search always hits the collection
                connected in setup(); MilvusClient API for per-call
                collection selection needs confirming before wiring through.
            top_k: Result limit; falls back to settings when falsy.
        """
        with create_span("chat.search"):
            return await self.milvus.search_with_texts(
                embedding,
                limit=top_k or self.chat_settings.rag_top_k,
                text_field="text",
                metadata_fields=["source", "title"],
            )

    async def _rerank_documents(self, query: str, documents: list[dict]) -> list[dict]:
        """Rerank documents by relevance to query.

        Note: only the raw texts are passed to the reranker, so "source"/
        "title" metadata from the Milvus hits is not carried through to
        the reranked results.
        """
        with create_span("chat.rerank"):
            texts = [self._doc_text(d) for d in documents]
            return await self.reranker.rerank(
                query, texts, top_k=self.chat_settings.rag_rerank_top_k
            )

    def _build_context(self, documents: list[dict]) -> str:
        """Build a numbered context string from ranked documents."""
        context_parts = []
        for i, doc in enumerate(documents, 1):
            context_parts.append(f"[{i}] {self._doc_text(doc)}")
        return "\n\n".join(context_parts)

    async def _generate_response(
        self,
        query: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """Generate LLM response, optionally augmented with RAG context."""
        with create_span("chat.generate"):
            return await self.llm.generate(
                query,
                context=context,
                system_prompt=system_prompt,
            )

    async def _publish_streaming_chunks(
        self,
        subject: str,
        request_id: str,
        full_text: str,
    ) -> None:
        """Publish response as streaming chunks for real-time display.

        Splits the full text into 4-word chunks, then sends a final
        "done" marker so the consumer knows the stream is complete.
        """
        words = full_text.split(" ")
        chunk_size = 4
        for i in range(0, len(words), chunk_size):
            token_chunk = " ".join(words[i : i + chunk_size])
            await self.nats.publish(
                subject,
                {
                    "request_id": request_id,
                    "type": "chunk",
                    "content": token_chunk,
                    "done": False,
                    "timestamp": time.time(),
                },
            )
        # Send done marker
        await self.nats.publish(
            subject,
            {
                "request_id": request_id,
                "type": "done",
                "content": "",
                "done": True,
                "timestamp": time.time(),
            },
        )

    async def _synthesize_speech(self, text: str) -> str:
        """Synthesize speech and return base64 encoded audio."""
        with create_span("chat.tts"):
            audio_bytes = await self.tts.synthesize(
                text,
                language=self.chat_settings.tts_language,
            )
            return base64.b64encode(audio_bytes).decode()
if __name__ == "__main__":
    # Script entry point: build the handler and start its message loop.
    handler = ChatHandler()
    handler.run()