fixing up chat-handler.
Some checks failed
CI / Lint (push) Failing after 1m39s
CI / Test (push) Failing after 1m37s
CI / Release (push) Has been skipped
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-18 07:29:41 -05:00
parent b34e8d2e1c
commit 24a4098c9a

View File

@@ -3,13 +3,12 @@
Chat Handler Service (Refactored) Chat Handler Service (Refactored)
Text-based chat pipeline using handler-base: Text-based chat pipeline using handler-base:
1. Listen for text on NATS subject "ai.chat.request" 1. Listen for text on NATS subject "ai.chat.user.*.message"
2. Generate embeddings for RAG 2. If RAG enabled (premium/explicit): embed → Milvus search → rerank
3. Retrieve context from Milvus 3. Generate response with vLLM (with or without RAG context)
4. Rerank with BGE reranker 4. Optionally stream chunks to "ai.chat.response.stream.{request_id}"
5. Generate response with vLLM 5. Optionally synthesize speech with XTTS
6. Optionally synthesize speech with XTTS 6. Publish result to "ai.chat.response.{request_id}" (or custom response_subject)
7. Publish result to NATS "ai.chat.response.{request_id}"
""" """
import base64 import base64
@@ -51,28 +50,38 @@ class ChatHandler(Handler):
""" """
Chat request handler with RAG pipeline. Chat request handler with RAG pipeline.
Request format: Subscribes to: ai.chat.user.*.message (JetStream durable "chat-handler")
Request format (msgpack):
{ {
"request_id": "uuid", "request_id": "uuid",
"query": "user question", "user_id": "user-123",
"collection": "optional collection name", "username": "john_doe",
"enable_tts": false, "message": "user question",
"premium": false,
"enable_rag": true,
"enable_reranker": true,
"enable_streaming": true,
"top_k": 5,
"session_id": "session-abc",
"system_prompt": "optional custom system prompt" "system_prompt": "optional custom system prompt"
} }
Response format: Response format (msgpack):
{ {
"request_id": "uuid", "user_id": "user-123",
"response": "generated response", "response": "generated response",
"sources": [{"text": "...", "score": 0.95}], "response_text": "generated response",
"audio": "base64 encoded audio (if tts enabled)" "used_rag": true,
"rag_sources": ["source1", "source2"],
"success": true
} }
""" """
def __init__(self): def __init__(self):
self.chat_settings = ChatSettings() self.chat_settings = ChatSettings()
super().__init__( super().__init__(
subject="ai.chat.request", subject="ai.chat.user.*.message",
settings=self.chat_settings, settings=self.chat_settings,
queue_group="chat-handlers", queue_group="chat-handlers",
) )
@@ -114,56 +123,93 @@ class ChatHandler(Handler):
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle incoming chat request.""" """Handle incoming chat request."""
request_id = data.get("request_id", "unknown") request_id = data.get("request_id", "unknown")
query = data.get("query", "") user_id = data.get("user_id", "unknown")
query = data.get("message", "") or data.get("query", "")
premium = data.get("premium", False)
enable_rag = data.get("enable_rag", premium)
enable_reranker = data.get("enable_reranker", enable_rag)
enable_streaming = data.get("enable_streaming", False)
top_k = data.get("top_k", self.chat_settings.rag_top_k)
collection = data.get("collection", self.chat_settings.rag_collection) collection = data.get("collection", self.chat_settings.rag_collection)
enable_tts = data.get("enable_tts", self.chat_settings.enable_tts) enable_tts = data.get("enable_tts", self.chat_settings.enable_tts)
system_prompt = data.get("system_prompt") system_prompt = data.get("system_prompt")
# companions-frontend may set a custom response subject
response_subject = data.get(
"response_subject", f"ai.chat.response.{request_id}"
)
logger.info(f"Processing request {request_id}: {query[:50]}...") logger.info(f"Processing request {request_id}: {query[:50]}...")
with create_span("chat.process") as span: with create_span("chat.process") as span:
if span: if span:
span.set_attribute("request.id", request_id) span.set_attribute("request.id", request_id)
span.set_attribute("user.id", user_id)
span.set_attribute("query.length", len(query)) span.set_attribute("query.length", len(query))
span.set_attribute("premium", premium)
span.set_attribute("rag.enabled", enable_rag)
context = ""
rag_sources: list[str] = []
used_rag = False
# Only run RAG pipeline when enabled (premium users or explicit flag)
if enable_rag:
# 1. Generate query embedding # 1. Generate query embedding
embedding = await self._get_embedding(query) embedding = await self._get_embedding(query)
# 2. Search Milvus for context # 2. Search Milvus for context
documents = await self._search_context(embedding, collection) documents = await self._search_context(
embedding, collection, top_k=top_k,
)
# 3. Rerank documents # 3. Optionally rerank documents
if enable_reranker and documents:
reranked = await self._rerank_documents(query, documents) reranked = await self._rerank_documents(query, documents)
else:
reranked = documents
# 4. Build context from top documents # 4. Build context from top documents
if reranked:
context = self._build_context(reranked) context = self._build_context(reranked)
rag_sources = [
d.get("source", d.get("document", "")[:80])
for d in reranked[:3]
]
used_rag = True
# 5. Generate LLM response # 5. Generate LLM response (with or without RAG context)
response_text = await self._generate_response(query, context, system_prompt) response_text = await self._generate_response(
query, context or None, system_prompt,
)
# 6. Optionally synthesize speech # 6. Stream response chunks if requested
if enable_streaming:
stream_subject = f"ai.chat.response.stream.{request_id}"
await self._publish_streaming_chunks(
stream_subject, request_id, response_text,
)
# 7. Optionally synthesize speech
audio_b64 = None audio_b64 = None
if enable_tts and self.tts: if enable_tts and self.tts:
audio_b64 = await self._synthesize_speech(response_text) audio_b64 = await self._synthesize_speech(response_text)
# Build response # Build response (compatible with companions-frontend NATSChatResponse)
result = { result: dict[str, Any] = {
"request_id": request_id, "user_id": user_id,
"response": response_text, "response": response_text,
"response_text": response_text,
"used_rag": used_rag,
"rag_sources": rag_sources,
"success": True,
} }
if self.chat_settings.include_sources:
result["sources"] = [
{"text": d["document"][:200], "score": d["score"]} for d in reranked[:3]
]
if audio_b64: if audio_b64:
result["audio"] = audio_b64 result["audio"] = audio_b64
logger.info(f"Completed request {request_id}") logger.info(f"Completed request {request_id} (rag={used_rag})")
# Publish to response subject # Publish to the response subject the frontend is waiting on
response_subject = f"ai.chat.response.{request_id}"
await self.nats.publish(response_subject, result) await self.nats.publish(response_subject, result)
return result return result
@@ -173,12 +219,17 @@ class ChatHandler(Handler):
with create_span("chat.embedding"): with create_span("chat.embedding"):
return await self.embeddings.embed_single(text) return await self.embeddings.embed_single(text)
async def _search_context(self, embedding: list[float], collection: str) -> list[dict]: async def _search_context(
self,
embedding: list[float],
collection: str,
top_k: int | None = None,
) -> list[dict]:
"""Search Milvus for relevant documents.""" """Search Milvus for relevant documents."""
with create_span("chat.search"): with create_span("chat.search"):
return await self.milvus.search_with_texts( return await self.milvus.search_with_texts(
embedding, embedding,
limit=self.chat_settings.rag_top_k, limit=top_k or self.chat_settings.rag_top_k,
text_field="text", text_field="text",
metadata_fields=["source", "title"], metadata_fields=["source", "title"],
) )
@@ -202,10 +253,10 @@ class ChatHandler(Handler):
async def _generate_response( async def _generate_response(
self, self,
query: str, query: str,
context: str, context: Optional[str] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
) -> str: ) -> str:
"""Generate LLM response with context.""" """Generate LLM response, optionally augmented with RAG context."""
with create_span("chat.generate"): with create_span("chat.generate"):
return await self.llm.generate( return await self.llm.generate(
query, query,
@@ -213,6 +264,41 @@ class ChatHandler(Handler):
system_prompt=system_prompt, system_prompt=system_prompt,
) )
async def _publish_streaming_chunks(
self,
subject: str,
request_id: str,
full_text: str,
) -> None:
"""Publish response as streaming chunks for real-time display."""
import time
words = full_text.split(" ")
chunk_size = 4
for i in range(0, len(words), chunk_size):
token_chunk = " ".join(words[i : i + chunk_size])
await self.nats.publish(
subject,
{
"request_id": request_id,
"type": "chunk",
"content": token_chunk,
"done": False,
"timestamp": time.time(),
},
)
# Send done marker
await self.nats.publish(
subject,
{
"request_id": request_id,
"type": "done",
"content": "",
"done": True,
"timestamp": time.time(),
},
)
async def _synthesize_speech(self, text: str) -> str: async def _synthesize_speech(self, text: str) -> str:
"""Synthesize speech and return base64 encoded audio.""" """Synthesize speech and return base64 encoded audio."""
with create_span("chat.tts"): with create_span("chat.tts"):