diff --git a/Dockerfile b/Dockerfile index cdde590..1478ed2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,9 @@ -FROM python:3.13-slim +# Chat Handler - Using handler-base +ARG BASE_TAG=latest +FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} WORKDIR /app -# Install uv for fast, reliable package management -COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv - -# Install system dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements first for better caching -COPY requirements.txt . -RUN uv pip install --system --no-cache -r requirements.txt - -# Copy application code COPY chat_handler.py . -# Set environment variables -ENV PYTHONUNBUFFERED=1 -ENV PYTHONDONTWRITEBYTECODE=1 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD python -c "print('healthy')" || exit 1 - -# Run the application CMD ["python", "chat_handler.py"] diff --git a/Dockerfile.v2 b/Dockerfile.v2 deleted file mode 100644 index 9756a0b..0000000 --- a/Dockerfile.v2 +++ /dev/null @@ -1,11 +0,0 @@ -# Chat Handler v2 - Using handler-base -ARG BASE_TAG=local -FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} - -WORKDIR /app - -# Copy only the handler code (dependencies are in base image) -COPY chat_handler_v2.py ./chat_handler.py - -# Run the handler -CMD ["python", "chat_handler.py"] diff --git a/README.md b/README.md index 55acd68..e85810b 100644 --- a/README.md +++ b/README.md @@ -4,19 +4,10 @@ Text-based chat pipeline for the DaviesTechLabs AI/ML platform. ## Overview -A NATS-based service that handles chat completion requests with RAG (Retrieval Augmented Generation). +A NATS-based service that handles chat completion requests with RAG (Retrieval Augmented Generation). It uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks. **Pipeline:** Query → Embeddings → Milvus → Rerank → LLM → (optional TTS) -## Versions - -| File | Description | -|------|-------------| -| `chat_handler.py` | Standalone implementation (v1) | -| `chat_handler_v2.py` | Uses handler-base library (recommended) | -| `Dockerfile` | Standalone image | -| `Dockerfile.v2` | Handler-base image | - ## Architecture ``` @@ -88,19 +79,10 @@ NATS (ai.chat.request) ## Building ```bash -# Standalone image (v1) -docker build -f Dockerfile -t chat-handler:latest . +docker build -t chat-handler:latest . -# Handler-base image (v2 - recommended) -docker build -f Dockerfile.v2 -t chat-handler:v2 . -``` - -## Dependencies - -The v2 handler depends on [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base): - -```bash -pip install git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git +# With specific handler-base tag +docker build --build-arg BASE_TAG=latest -t chat-handler:latest . ``` ## Related diff --git a/chat_handler.py b/chat_handler.py index 2ce7ef1..f720e4e 100644 --- a/chat_handler.py +++ b/chat_handler.py @@ -1,867 +1,233 @@ #!/usr/bin/env python3 """ -Chat Handler Service +Chat Handler Service (Refactored) -Text-based chat pipeline: +Text-based chat pipeline using handler-base: 1. Listen for text on NATS subject "ai.chat.request" -2. Generate embeddings for RAG (optional) +2. Generate embeddings for RAG 3. Retrieve context from Milvus 4. Rerank with BGE reranker 5. Generate response with vLLM 6. Optionally synthesize speech with XTTS 7. Publish result to NATS "ai.chat.response.{request_id}" """ -import asyncio import base64 -import json import logging -import os -import signal -import subprocess -import sys -import time -from typing import List, Dict, Optional +from typing import Any, Optional -# Install dependencies on startup -subprocess.check_call([ - sys.executable, "-m", "pip", "install", "-q", - "--root-user-action=ignore", - "-r", "/app/requirements.txt" -]) +from nats.aio.msg import Msg -import httpx -import msgpack -import nats -import redis.asyncio as redis -from pymilvus import connections, Collection, utility - -# OpenTelemetry imports -from opentelemetry import trace, metrics -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter -from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as OTLPSpanExporterHTTP -from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as OTLPMetricExporterHTTP -from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION, SERVICE_NAMESPACE -from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor -from opentelemetry.instrumentation.logging import LoggingInstrumentor - -# MLflow inference tracking -try: - from mlflow_utils import InferenceMetricsTracker - from mlflow_utils.inference_tracker import InferenceMetrics - MLFLOW_AVAILABLE = True -except ImportError: - MLFLOW_AVAILABLE = False - InferenceMetricsTracker = None - InferenceMetrics = None - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +from handler_base import Handler, Settings +from handler_base.clients import ( + EmbeddingsClient, + RerankerClient, + LLMClient, + TTSClient, + MilvusClient, ) +from handler_base.telemetry import create_span + logger = logging.getLogger("chat-handler") -def setup_telemetry(): - """Initialize OpenTelemetry tracing and metrics.""" - otel_enabled = os.environ.get("OTEL_ENABLED", "true").lower() == "true" - if not otel_enabled: - logger.info("OpenTelemetry disabled") - return None, None +class ChatSettings(Settings): + """Chat handler specific settings.""" - otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://opentelemetry-collector.observability.svc.cluster.local:4317") - service_name = os.environ.get("OTEL_SERVICE_NAME", "chat-handler") - service_namespace = os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml") + service_name: str = "chat-handler" - # HyperDX configuration - hyperdx_api_key = os.environ.get("HYPERDX_API_KEY", "") - hyperdx_endpoint = os.environ.get("HYPERDX_ENDPOINT", "https://in-otel.hyperdx.io") - use_hyperdx = os.environ.get("HYPERDX_ENABLED", "false").lower() == "true" and hyperdx_api_key + # RAG settings + rag_top_k: int = 10 + rag_rerank_top_k: int = 5 + rag_collection: str = "documents" - resource = Resource.create({ - SERVICE_NAME: service_name, - SERVICE_VERSION: os.environ.get("SERVICE_VERSION", "1.0.0"), - SERVICE_NAMESPACE: service_namespace, - "deployment.environment": os.environ.get("DEPLOYMENT_ENV", "production"), - "host.name": os.environ.get("HOSTNAME", "unknown"), - }) - - trace_provider = TracerProvider(resource=resource) - - if use_hyperdx: - logger.info(f"Configuring HyperDX exporter at {hyperdx_endpoint}") - headers = {"authorization": hyperdx_api_key} - otlp_span_exporter = OTLPSpanExporterHTTP( - endpoint=f"{hyperdx_endpoint}/v1/traces", - headers=headers - ) - otlp_metric_exporter = OTLPMetricExporterHTTP( - endpoint=f"{hyperdx_endpoint}/v1/metrics", - headers=headers - ) - else: - otlp_span_exporter = OTLPSpanExporter(endpoint=otel_endpoint, insecure=True) - otlp_metric_exporter = OTLPMetricExporter(endpoint=otel_endpoint, insecure=True) - - trace_provider.add_span_processor(BatchSpanProcessor(otlp_span_exporter)) - trace.set_tracer_provider(trace_provider) - - metric_reader = PeriodicExportingMetricReader(otlp_metric_exporter, export_interval_millis=60000) - meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) - metrics.set_meter_provider(meter_provider) - - HTTPXClientInstrumentor().instrument() - LoggingInstrumentor().instrument(set_logging_format=True) - - destination = "HyperDX" if use_hyperdx else "OTEL Collector" - logger.info(f"OpenTelemetry initialized - destination: {destination}, service: {service_name}") - - return trace.get_tracer(__name__), metrics.get_meter(__name__) - -# Configuration from environment -TTS_URL = os.environ.get("TTS_URL", "http://tts-predictor.ai-ml.svc.cluster.local") -EMBEDDINGS_URL = os.environ.get( - "EMBEDDINGS_URL", "http://embeddings-predictor.ai-ml.svc.cluster.local" -) -RERANKER_URL = os.environ.get( - "RERANKER_URL", "http://reranker-predictor.ai-ml.svc.cluster.local" -) -VLLM_URL = os.environ.get("VLLM_URL", "http://llm-draft.ai-ml.svc.cluster.local:8000") -LLM_MODEL = os.environ.get("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.3") -MILVUS_HOST = os.environ.get("MILVUS_HOST", "milvus.ai-ml.svc.cluster.local") -MILVUS_PORT = int(os.environ.get("MILVUS_PORT", "19530")) -COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "knowledge_base") -NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") -VALKEY_URL = os.environ.get("VALKEY_URL", "redis://valkey.ai-ml.svc.cluster.local:6379") - -# MLflow configuration -MLFLOW_ENABLED = os.environ.get("MLFLOW_ENABLED", "true").lower() == "true" -MLFLOW_TRACKING_URI = os.environ.get( - "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80" -) - -# Context window limits (characters) -MAX_CONTEXT_LENGTH = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Prevent unbounded growth - -# NATS subjects (ai.* schema) -# Per-user channels matching companions-frontend pattern -REQUEST_SUBJECT = "ai.chat.user.*.message" # Wildcard subscription for all users -PREMIUM_REQUEST_SUBJECT = "ai.chat.premium.user.*.message" # Premium users -RESPONSE_SUBJECT = "ai.chat.response" # Response published to specific request_id -STREAM_RESPONSE_SUBJECT = "ai.chat.response.stream" # Streaming responses (token chunks) - -# System prompt for the assistant -SYSTEM_PROMPT = """You are a helpful AI assistant. -Answer questions based on the provided context when available. -Be concise and informative. If you don't know the answer, say so clearly.""" + # Response settings + include_sources: bool = True + enable_tts: bool = False + tts_language: str = "en" -class ChatHandler: +class ChatHandler(Handler): + """ + Chat request handler with RAG pipeline. + + Request format: + { + "request_id": "uuid", + "query": "user question", + "collection": "optional collection name", + "enable_tts": false, + "system_prompt": "optional custom system prompt" + } + + Response format: + { + "request_id": "uuid", + "response": "generated response", + "sources": [{"text": "...", "score": 0.95}], + "audio": "base64 encoded audio (if tts enabled)" + } + """ + def __init__(self): - self.nc = None - self.http_client = None - self.collection = None - self.valkey_client = None - self.running = True - self.tracer = None - self.meter = None - self.request_counter = None - self.request_duration = None - self.rag_search_duration = None - # MLflow inference tracker - self.mlflow_tracker = None - - async def setup(self): - """Initialize all connections.""" - # Initialize OpenTelemetry - self.tracer, self.meter = setup_telemetry() - - # Setup metrics - if self.meter: - self.request_counter = self.meter.create_counter( - "chat.requests", - description="Number of chat requests processed", - unit="1" - ) - self.request_duration = self.meter.create_histogram( - "chat.request_duration", - description="Duration of chat request processing", - unit="s" - ) - self.rag_search_duration = self.meter.create_histogram( - "chat.rag_search_duration", - description="Duration of RAG search operations", - unit="s" - ) - - # Initialize MLflow inference tracker - if MLFLOW_ENABLED and MLFLOW_AVAILABLE: - try: - self.mlflow_tracker = InferenceMetricsTracker( - service_name="chat-handler", - experiment_name="chat-inference", - tracking_uri=MLFLOW_TRACKING_URI, - batch_size=50, - flush_interval_seconds=60.0, - ) - await self.mlflow_tracker.start() - logger.info(f"MLflow inference tracking enabled at {MLFLOW_TRACKING_URI}") - except Exception as e: - logger.warning(f"MLflow initialization failed: {e}, tracking disabled") - self.mlflow_tracker = None - elif not MLFLOW_AVAILABLE: - logger.info("MLflow utils not available, inference tracking disabled") - else: - logger.info("MLflow tracking disabled via MLFLOW_ENABLED=false") - - # NATS connection with reconnection support - async def disconnected_cb(): - logger.warning("NATS disconnected, attempting reconnection...") - - async def reconnected_cb(): - logger.info(f"NATS reconnected to {self.nc.connected_url.netloc}") - - async def error_cb(e): - logger.error(f"NATS error: {e}") - - async def closed_cb(): - logger.warning("NATS connection closed") - - self.nc = await nats.connect( - NATS_URL, - reconnect_time_wait=2, - max_reconnect_attempts=-1, # Infinite reconnection attempts - disconnected_cb=disconnected_cb, - reconnected_cb=reconnected_cb, - error_cb=error_cb, - closed_cb=closed_cb, + self.chat_settings = ChatSettings() + super().__init__( + subject="ai.chat.request", + settings=self.chat_settings, + queue_group="chat-handlers", ) - logger.info(f"Connected to NATS at {NATS_URL}") - - # HTTP client for services - self.http_client = httpx.AsyncClient(timeout=180.0) - - # Connect to Valkey for conversation history and context caching - try: - self.valkey_client = redis.from_url( - VALKEY_URL, - encoding="utf-8", - decode_responses=True, - socket_connect_timeout=5 - ) - await self.valkey_client.ping() - logger.info(f"Connected to Valkey at {VALKEY_URL}") - except Exception as e: - logger.warning(f"Valkey connection failed: {e}, conversation history disabled") - self.valkey_client = None - - # Connect to Milvus if collection exists - try: - connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) - if utility.has_collection(COLLECTION_NAME): - self.collection = Collection(COLLECTION_NAME) - self.collection.load() - logger.info(f"Connected to Milvus collection: {COLLECTION_NAME}") - else: - logger.warning(f"Collection {COLLECTION_NAME} not found, RAG disabled") - except Exception as e: - logger.warning(f"Milvus connection failed: {e}, RAG disabled") - - async def get_embeddings(self, texts: List[str]) -> List[List[float]]: - """Get embeddings from the embedding service.""" - try: - response = await self.http_client.post( - f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"} - ) - result = response.json() - return [d["embedding"] for d in result.get("data", [])] - except Exception as e: - logger.error(f"Embedding failed: {e}") - return [] - - async def search_milvus( - self, query_embedding: List[float], top_k: int = 5 - ) -> List[Dict]: - """Search Milvus for relevant documents.""" - if not self.collection: - return [] - try: - results = self.collection.search( - data=[query_embedding], - anns_field="embedding", - param={"metric_type": "COSINE", "params": {"ef": 64}}, - limit=top_k, - output_fields=["text", "book_name", "page_num"], - ) - docs = [] - for hits in results: - for hit in hits: - docs.append( - { - "text": hit.entity.get("text", ""), - "source": f'{hit.entity.get("book_name", "")} p.{hit.entity.get("page_num", "")}', - "score": hit.score, - } - ) - return docs - except Exception as e: - logger.error(f"Milvus search failed: {e}") - return [] - - async def rerank(self, query: str, documents: List[str]) -> List[Dict]: - """Rerank documents using the reranker service.""" - if not documents: - return [] - try: - response = await self.http_client.post( - f"{RERANKER_URL}/v1/rerank", - json={"query": query, "documents": documents}, - ) - return response.json().get("results", []) - except Exception as e: - logger.error(f"Reranking failed: {e}") - return [{"index": i, "relevance_score": 0.5} for i in range(len(documents))] - - async def get_conversation_history(self, session_id: str, max_messages: int = 10) -> List[Dict]: - """Retrieve conversation history from Valkey.""" - if not self.valkey_client or not session_id: - return [] - try: - key = f"chat:history:{session_id}" - # Get the most recent messages (stored as a list) - history_json = await self.valkey_client.lrange(key, -max_messages, -1) - history = [json.loads(msg) for msg in history_json] - logger.info(f"Retrieved {len(history)} messages from history for session {session_id}") - return history - except Exception as e: - logger.error(f"Failed to get conversation history: {e}") - return [] - - async def save_message_to_history(self, session_id: str, role: str, content: str, ttl: int = 3600): - """Save a message to conversation history in Valkey.""" - if not self.valkey_client or not session_id: - return - try: - key = f"chat:history:{session_id}" - message = json.dumps({"role": role, "content": content, "timestamp": time.time()}) - # Use RPUSH to append to the list - await self.valkey_client.rpush(key, message) - # Set TTL on the key (1 hour by default) - await self.valkey_client.expire(key, ttl) - logger.debug(f"Saved {role} message to history for session {session_id}") - except Exception as e: - logger.error(f"Failed to save message to history: {e}") - - async def get_context_window(self, session_id: str) -> Optional[str]: - """Retrieve cached context window from Valkey for attention offloading.""" - if not self.valkey_client or not session_id: - return None - try: - key = f"chat:context:{session_id}" - context = await self.valkey_client.get(key) - if context: - logger.info(f"Retrieved cached context window for session {session_id}") - return context - except Exception as e: - logger.error(f"Failed to get context window: {e}") - return None - - async def save_context_window(self, session_id: str, context: str, ttl: int = 3600): - """Save context window to Valkey for attention offloading.""" - if not self.valkey_client or not session_id: - return - try: - key = f"chat:context:{session_id}" - await self.valkey_client.set(key, context, ex=ttl) - logger.debug(f"Saved context window for session {session_id}") - except Exception as e: - logger.error(f"Failed to save context window: {e}") - - async def generate_response(self, query: str, context: str = "", session_id: str = None) -> str: - """Generate response using vLLM with conversation history from Valkey.""" - try: - messages = [{"role": "system", "content": SYSTEM_PROMPT}] - - # Add conversation history from Valkey if session exists - if session_id: - history = await self.get_conversation_history(session_id) - messages.extend(history) - - if context: - messages.append( - { - "role": "user", - "content": f"Context:\n{context}\n\nQuestion: {query}", - } - ) - else: - messages.append({"role": "user", "content": query}) - - response = await self.http_client.post( - f"{VLLM_URL}/v1/chat/completions", - json={ - "model": LLM_MODEL, - "messages": messages, - "max_tokens": 1000, - "temperature": 0.7, - }, - ) - result = response.json() - answer = result["choices"][0]["message"]["content"] - logger.info(f"Generated response: {answer[:100]}...") - - # Save messages to conversation history - if session_id: - await self.save_message_to_history(session_id, "user", query) - await self.save_message_to_history(session_id, "assistant", answer) - - return answer - except Exception as e: - logger.error(f"LLM generation failed: {e}") - return "I'm sorry, I couldn't generate a response." - - async def generate_response_streaming(self, query: str, context: str = "", request_id: str = "", session_id: str = None): - """Generate streaming response using vLLM and publish chunks to NATS. + + async def setup(self) -> None: + """Initialize service clients.""" + logger.info("Initializing service clients...") - Yields tokens as they are generated and publishes them to NATS streaming subject. - Returns the complete response text. - """ - try: - messages = [{"role": "system", "content": SYSTEM_PROMPT}] - - # Add conversation history from Valkey if session exists - if session_id: - history = await self.get_conversation_history(session_id) - messages.extend(history) - - if context: - messages.append( - { - "role": "user", - "content": f"Context:\n{context}\n\nQuestion: {query}", - } - ) - else: - messages.append({"role": "user", "content": query}) - - full_response = "" - - # Stream response from vLLM - async with self.http_client.stream( - "POST", - f"{VLLM_URL}/v1/chat/completions", - json={ - "model": LLM_MODEL, - "messages": messages, - "max_tokens": 1000, - "temperature": 0.7, - "stream": True, - }, - timeout=60.0, - ) as response: - # Parse SSE (Server-Sent Events) stream - async for line in response.aiter_lines(): - if not line or not line.startswith("data: "): - continue - - data_str = line[6:] # Remove "data: " prefix - if data_str.strip() == "[DONE]": - break - - try: - chunk_data = json.loads(data_str) - - # Extract token from delta - if chunk_data.get("choices") and len(chunk_data["choices"]) > 0: - delta = chunk_data["choices"][0].get("delta", {}) - content = delta.get("content", "") - - if content: - full_response += content - - # Publish token chunk to NATS streaming subject - chunk_msg = { - "request_id": request_id, - "type": "chunk", - "content": content, - "done": False, - } - await self.nc.publish( - f"{STREAM_RESPONSE_SUBJECT}.{request_id}", - msgpack.packb(chunk_msg) - ) - except json.JSONDecodeError: - continue - - # Send completion message - completion_msg = { - "request_id": request_id, - "type": "done", - "content": "", - "done": True, - } - await self.nc.publish( - f"{STREAM_RESPONSE_SUBJECT}.{request_id}", - msgpack.packb(completion_msg) - ) - - logger.info(f"Streamed complete response ({len(full_response)} chars) for request {request_id}") - - # Save messages to conversation history - if session_id: - await self.save_message_to_history(session_id, "user", query) - await self.save_message_to_history(session_id, "assistant", full_response) - - return full_response - - except Exception as e: - logger.error(f"Streaming LLM generation failed: {e}") - # Send error message - error_msg = { - "request_id": request_id, - "type": "error", - "content": "I'm sorry, I couldn't generate a response.", - "done": True, - "error": str(e), - } - await self.nc.publish( - f"{STREAM_RESPONSE_SUBJECT}.{request_id}", - msgpack.packb(error_msg) - ) - return "I'm sorry, I couldn't generate a response." - - async def synthesize_speech(self, text: str, language: str = "en") -> str: - """Convert text to speech using XTTS (Coqui TTS).""" - try: - response = await self.http_client.get( - f"{TTS_URL}/api/tts", params={"text": text, "language_id": language} - ) - if response.status_code == 200: - audio_b64 = base64.b64encode(response.content).decode("utf-8") - logger.info(f"Synthesized {len(response.content)} bytes of audio") - return audio_b64 - else: - logger.error( - f"TTS returned status {response.status_code}: {response.text}" - ) - return "" - except Exception as e: - logger.error(f"TTS failed: {e}") - return "" - - async def process_request(self, msg, is_premium=False): - """Process a chat request.""" - start_time = time.time() - span = None + self.embeddings = EmbeddingsClient(self.chat_settings) + self.reranker = RerankerClient(self.chat_settings) + self.llm = LLMClient(self.chat_settings) + self.milvus = MilvusClient(self.chat_settings) - # MLflow metrics tracking - mlflow_metrics = None - embedding_start = None - rag_start = None - rerank_start = None - llm_start = None + # TTS is optional + if self.chat_settings.enable_tts: + self.tts = TTSClient(self.chat_settings) + else: + self.tts = None - try: - data = msgpack.unpackb(msg.data, raw=False) - - # Support companions-frontend format (user_id, username, message, premium) - # as well as the original format (request_id, text, enable_rag, etc.) - user_id = data.get("user_id") - username = data.get("username", "") - - # Get text from either 'message' (companions-frontend) or 'text' (original) - text = data.get("message") or data.get("text", "") - - # Generate request_id from user_id if not provided - import uuid - request_id = data.get("request_id") or f"{user_id or 'anon'}-{uuid.uuid4().hex[:8]}" - - # Initialize MLflow metrics if available - if self.mlflow_tracker and MLFLOW_AVAILABLE: - mlflow_metrics = InferenceMetrics( - request_id=request_id, - user_id=user_id, - session_id=data.get("session_id"), - model_name=LLM_MODEL, - model_endpoint=VLLM_URL, - ) - - # Start tracing span - if self.tracer: - span = self.tracer.start_span("chat.process_request") - span.set_attribute("request_id", request_id) - span.set_attribute("user_id", user_id or "anonymous") - span.set_attribute("premium", is_premium) - - # Premium status from message or channel - is_premium = is_premium or data.get("premium", False) - - # Support both new parameters and backward compatibility with use_rag - use_rag = data.get("use_rag") # Legacy parameter - enable_rag = data.get( - "enable_rag", use_rag if use_rag is not None else True - ) - enable_reranker = data.get( - "enable_reranker", use_rag if use_rag is not None else True - ) - - # Premium users get more documents for deeper RAG - default_top_k = 15 if is_premium else 5 - top_k = data.get("top_k", default_top_k) - - # Get request parameters - enable_tts = data.get("enable_tts", False) - enable_streaming = data.get("enable_streaming", False) # New parameter for streaming - language = data.get("language", "en") - session_id = data.get("session_id") - - # Update MLflow metrics with request params - if mlflow_metrics: - mlflow_metrics.rag_enabled = enable_rag - mlflow_metrics.reranker_enabled = enable_reranker - mlflow_metrics.is_streaming = enable_streaming - mlflow_metrics.is_premium = is_premium - mlflow_metrics.prompt_length = len(text) - - # Add attributes to span + # Connect to Milvus + await self.milvus.connect(self.chat_settings.rag_collection) + + logger.info("Service clients initialized") + + async def teardown(self) -> None: + """Clean up service clients.""" + logger.info("Closing service clients...") + + await self.embeddings.close() + await self.reranker.close() + await self.llm.close() + await self.milvus.close() + + if self.tts: + await self.tts.close() + + logger.info("Service clients closed") + + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: + """Handle incoming chat request.""" + request_id = data.get("request_id", "unknown") + query = data.get("query", "") + collection = data.get("collection", self.chat_settings.rag_collection) + enable_tts = data.get("enable_tts", self.chat_settings.enable_tts) + system_prompt = data.get("system_prompt") + + logger.info(f"Processing request {request_id}: {query[:50]}...") + + with create_span("chat.process") as span: if span: - span.set_attribute("enable_rag", enable_rag) - span.set_attribute("enable_reranker", enable_reranker) - span.set_attribute("top_k", top_k) - span.set_attribute("enable_tts", enable_tts) - span.set_attribute("enable_streaming", enable_streaming) - - logger.info( - f"Processing {'premium ' if is_premium else ''}chat request {request_id} from {username or user_id or 'anonymous'}: {text[:50]}... (RAG: {enable_rag}, Reranker: {enable_reranker}, top_k: {top_k})" + span.set_attribute("request.id", request_id) + span.set_attribute("query.length", len(query)) + + # 1. Generate query embedding + embedding = await self._get_embedding(query) + + # 2. Search Milvus for context + documents = await self._search_context(embedding, collection) + + # 3. Rerank documents + reranked = await self._rerank_documents(query, documents) + + # 4. Build context from top documents + context = self._build_context(reranked) + + # 5. Generate LLM response + response_text = await self._generate_response( + query, context, system_prompt ) - - # Warn if reranker is enabled without RAG - if enable_reranker and not enable_rag: - logger.warning( - f"Request {request_id}: Reranker enabled without RAG - no documents to rerank" - ) - - if not text: - await self.publish_error(request_id, "No text provided") - return - - context = "" - rag_sources = [] - docs = [] - - # Step 1: RAG retrieval (if enabled) - if enable_rag and self.collection: - # Get embeddings for RAG - embedding_start = time.time() - embeddings = await self.get_embeddings([text]) - if mlflow_metrics: - mlflow_metrics.embedding_latency = time.time() - embedding_start - - if embeddings: - # Search Milvus with configurable top_k - rag_start = time.time() - docs = await self.search_milvus(embeddings[0], top_k=top_k) - if mlflow_metrics: - mlflow_metrics.rag_search_latency = time.time() - rag_start - mlflow_metrics.rag_documents_retrieved = len(docs) - if docs: - rag_sources = [d.get("source", "") for d in docs] - - # Step 2: Reranking (if enabled and we have documents) - if enable_reranker and docs: - # Rerank documents - rerank_start = time.time() - doc_texts = [d["text"] for d in docs] - reranked = await self.rerank(text, doc_texts) - if mlflow_metrics: - mlflow_metrics.rerank_latency = time.time() - rerank_start - # Take top 3 reranked documents with bounds checking - sorted_docs = sorted( - reranked, key=lambda x: x.get("relevance_score", 0), reverse=True - )[:3] - # Build context with bounds checking - # Note: doc_texts and docs have the same length (doc_texts derived from docs) - context_parts = [] - sources = [] - for item in sorted_docs: - idx = item.get("index", -1) - if 0 <= idx < len(docs): - context_parts.append(doc_texts[idx]) - sources.append(docs[idx].get("source", "")) - else: - logger.warning( - f"Reranker returned invalid index {idx} for {len(docs)} docs" - ) - context = "\n\n".join(context_parts) - rag_sources = sources - elif docs: - # Use documents without reranking (take top 3) - doc_texts = [d["text"] for d in docs[:3]] - context = "\n\n".join(doc_texts) - rag_sources = [d.get("source", "") for d in docs[:3]] - - # Step 3: Generate response (streaming or non-streaming) - # Check for cached context window from Valkey (for attention offloading) - cached_context = None - if session_id: - cached_context = await self.get_context_window(session_id) - # Combine RAG context with cached context if available - if cached_context and context: - # Prepend cached context to current RAG context - combined_context = f"{cached_context}\n\n{context}" - # Truncate to prevent unbounded growth - if len(combined_context) > MAX_CONTEXT_LENGTH: - logger.warning(f"Context length {len(combined_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating") - # Keep the most recent context (from the end) - combined_context = combined_context[-MAX_CONTEXT_LENGTH:] - context = combined_context - elif cached_context: - # Only cached context, still need to check length - if len(cached_context) > MAX_CONTEXT_LENGTH: - logger.warning(f"Cached context length {len(cached_context)} exceeds max {MAX_CONTEXT_LENGTH}, truncating") - cached_context = cached_context[-MAX_CONTEXT_LENGTH:] - context = cached_context + # 6. Optionally synthesize speech + audio_b64 = None + if enable_tts and self.tts: + audio_b64 = await self._synthesize_speech(response_text) - # Save the combined context for future use (already truncated if needed) - if session_id and context: - await self.save_context_window(session_id, context) - - # Track number of RAG docs used after reranking - if mlflow_metrics and enable_rag: - mlflow_metrics.rag_documents_used = min(3, len(docs)) if docs else 0 - - llm_start = time.time() - if enable_streaming: - # Use streaming response - answer = await self.generate_response_streaming(text, context, request_id, session_id) - else: - # Use non-streaming response - answer = await self.generate_response(text, context, session_id) - - if mlflow_metrics: - mlflow_metrics.llm_latency = time.time() - llm_start - mlflow_metrics.response_length = len(answer) - # Estimate token counts (rough approximation: 4 chars per token) - mlflow_metrics.input_tokens = len(text) // 4 - mlflow_metrics.output_tokens = len(answer) // 4 - mlflow_metrics.total_tokens = mlflow_metrics.input_tokens + mlflow_metrics.output_tokens - - # Step 4: Optionally synthesize speech - audio_b64 = "" - if enable_tts: - audio_b64 = await self.synthesize_speech(answer, language) - - # Publish result - # Include both 'response' (companions-frontend) and 'response_text' (original) for compatibility + # Build response result = { "request_id": request_id, - "user_id": user_id, - "text": text, - "response": answer, # companions-frontend expects 'response' - "response_text": answer, # original format - "audio_b64": audio_b64 if enable_tts else None, - "used_rag": bool(context), - "rag_enabled": enable_rag, - "reranker_enabled": enable_reranker, - "rag_sources": rag_sources, - "session_id": session_id, - "success": True, + "response": response_text, } - await self.nc.publish( - f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result) + + if self.chat_settings.include_sources: + result["sources"] = [ + {"text": d["document"][:200], "score": d["score"]} + for d in reranked[:3] + ] + + if audio_b64: + result["audio"] = audio_b64 + + logger.info(f"Completed request {request_id}") + + # Publish to response subject + response_subject = f"ai.chat.response.{request_id}" + await self.nats.publish(response_subject, result) + + return result + + async def _get_embedding(self, text: str) -> list[float]: + """Generate embedding for query text.""" + with create_span("chat.embedding"): + return await self.embeddings.embed_single(text) + + async def _search_context( + self, embedding: list[float], collection: str + ) -> list[dict]: + """Search Milvus for relevant documents.""" + with create_span("chat.search"): + return await self.milvus.search_with_texts( + embedding, + limit=self.chat_settings.rag_top_k, + text_field="text", + metadata_fields=["source", "title"], ) - logger.info(f"Published response for request {request_id}") - - # Record metrics - duration = time.time() - start_time - if self.request_counter: - self.request_counter.add(1, {"premium": str(is_premium), "rag_enabled": str(enable_rag), "success": "true"}) - if self.request_duration: - self.request_duration.record(duration, {"premium": str(is_premium), "rag_enabled": str(enable_rag)}) - if span: - span.set_attribute("success", True) - span.set_attribute("response_length", len(answer)) - - # Log to MLflow - if self.mlflow_tracker and mlflow_metrics: - mlflow_metrics.total_latency = duration - await self.mlflow_tracker.log_inference(mlflow_metrics) - - except Exception as e: - logger.error(f"Request processing failed: {e}") - if self.request_counter: - self.request_counter.add(1, {"premium": str(is_premium), "success": "false"}) - if span: - span.set_attribute("success", False) - span.set_attribute("error", str(e)) - - # Log error to MLflow - if self.mlflow_tracker and mlflow_metrics: - mlflow_metrics.has_error = True - mlflow_metrics.error_message = str(e) - mlflow_metrics.total_latency = time.time() - start_time - await self.mlflow_tracker.log_inference(mlflow_metrics) - - await self.publish_error(data.get("request_id", "unknown"), str(e)) - finally: - if span: - span.end() - - async def publish_error(self, request_id: str, error: str): - """Publish an error response.""" - result = {"request_id": request_id, "error": error, "success": False} - await self.nc.publish( - f"{RESPONSE_SUBJECT}.{request_id}", msgpack.packb(result) - ) - - async def process_premium_request(self, msg): - """Process a premium chat request (wrapper for deeper RAG).""" - await self.process_request(msg, is_premium=True) - - async def run(self): - """Main run loop.""" - await self.setup() - - # Subscribe to standard chat requests - sub = await self.nc.subscribe(REQUEST_SUBJECT, cb=self.process_request) - logger.info(f"Subscribed to {REQUEST_SUBJECT}") - - # Subscribe to premium chat requests (deeper RAG retrieval) - premium_sub = await self.nc.subscribe( - PREMIUM_REQUEST_SUBJECT, cb=self.process_premium_request - ) - logger.info(f"Subscribed to {PREMIUM_REQUEST_SUBJECT}") - - # Handle shutdown - def signal_handler(): - self.running = False - - loop = asyncio.get_event_loop() - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, signal_handler) - - # Keep running - while self.running: - await asyncio.sleep(1) - - # Cleanup - await sub.unsubscribe() - await premium_sub.unsubscribe() - await self.nc.close() - if self.valkey_client: - await self.valkey_client.close() - if self.collection: - connections.disconnect("default") - if self.mlflow_tracker: - await self.mlflow_tracker.stop() - logger.info("Shutdown complete") + + async def _rerank_documents( + self, query: str, documents: list[dict] + ) -> list[dict]: + """Rerank documents by relevance to query.""" + with create_span("chat.rerank"): + texts = [d.get("text", "") for d in documents] + return await self.reranker.rerank( + query, texts, top_k=self.chat_settings.rag_rerank_top_k + ) + + def _build_context(self, documents: list[dict]) -> str: + """Build context string from ranked documents.""" + context_parts = [] + for i, doc in enumerate(documents, 1): + text = doc.get("document", "") + context_parts.append(f"[{i}] {text}") + return "\n\n".join(context_parts) + + async def _generate_response( + self, + query: str, + context: str, + system_prompt: Optional[str] = None, + ) -> str: + """Generate LLM response with context.""" + with create_span("chat.generate"): + return await self.llm.generate( + query, + context=context, + system_prompt=system_prompt, + ) + + async def _synthesize_speech(self, text: str) -> str: + """Synthesize speech and return base64 encoded audio.""" + with create_span("chat.tts"): + audio_bytes = await self.tts.synthesize( + text, + language=self.chat_settings.tts_language, + ) + return base64.b64encode(audio_bytes).decode() if __name__ == "__main__": - handler = ChatHandler() - asyncio.run(handler.run()) + ChatHandler().run() diff --git a/chat_handler_v2.py b/chat_handler_v2.py deleted file mode 100644 index f720e4e..0000000 --- a/chat_handler_v2.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python3 -""" -Chat Handler Service (Refactored) - -Text-based chat pipeline using handler-base: -1. Listen for text on NATS subject "ai.chat.request" -2. Generate embeddings for RAG -3. Retrieve context from Milvus -4. Rerank with BGE reranker -5. Generate response with vLLM -6. Optionally synthesize speech with XTTS -7. Publish result to NATS "ai.chat.response.{request_id}" -""" -import base64 -import logging -from typing import Any, Optional - -from nats.aio.msg import Msg - -from handler_base import Handler, Settings -from handler_base.clients import ( - EmbeddingsClient, - RerankerClient, - LLMClient, - TTSClient, - MilvusClient, -) -from handler_base.telemetry import create_span - -logger = logging.getLogger("chat-handler") - - -class ChatSettings(Settings): - """Chat handler specific settings.""" - - service_name: str = "chat-handler" - - # RAG settings - rag_top_k: int = 10 - rag_rerank_top_k: int = 5 - rag_collection: str = "documents" - - # Response settings - include_sources: bool = True - enable_tts: bool = False - tts_language: str = "en" - - -class ChatHandler(Handler): - """ - Chat request handler with RAG pipeline. - - Request format: - { - "request_id": "uuid", - "query": "user question", - "collection": "optional collection name", - "enable_tts": false, - "system_prompt": "optional custom system prompt" - } - - Response format: - { - "request_id": "uuid", - "response": "generated response", - "sources": [{"text": "...", "score": 0.95}], - "audio": "base64 encoded audio (if tts enabled)" - } - """ - - def __init__(self): - self.chat_settings = ChatSettings() - super().__init__( - subject="ai.chat.request", - settings=self.chat_settings, - queue_group="chat-handlers", - ) - - async def setup(self) -> None: - """Initialize service clients.""" - logger.info("Initializing service clients...") - - self.embeddings = EmbeddingsClient(self.chat_settings) - self.reranker = RerankerClient(self.chat_settings) - self.llm = LLMClient(self.chat_settings) - self.milvus = MilvusClient(self.chat_settings) - - # TTS is optional - if self.chat_settings.enable_tts: - self.tts = TTSClient(self.chat_settings) - else: - self.tts = None - - # Connect to Milvus - await self.milvus.connect(self.chat_settings.rag_collection) - - logger.info("Service clients initialized") - - async def teardown(self) -> None: - """Clean up service clients.""" - logger.info("Closing service clients...") - - await self.embeddings.close() - await self.reranker.close() - await self.llm.close() - await self.milvus.close() - - if self.tts: - await self.tts.close() - - logger.info("Service clients closed") - - async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: - """Handle incoming chat request.""" - request_id = data.get("request_id", "unknown") - query = data.get("query", "") - collection = data.get("collection", self.chat_settings.rag_collection) - enable_tts = data.get("enable_tts", self.chat_settings.enable_tts) - system_prompt = data.get("system_prompt") - - logger.info(f"Processing request {request_id}: {query[:50]}...") - - with create_span("chat.process") as span: - if span: - span.set_attribute("request.id", request_id) - span.set_attribute("query.length", len(query)) - - # 1. Generate query embedding - embedding = await self._get_embedding(query) - - # 2. Search Milvus for context - documents = await self._search_context(embedding, collection) - - # 3. Rerank documents - reranked = await self._rerank_documents(query, documents) - - # 4. Build context from top documents - context = self._build_context(reranked) - - # 5. Generate LLM response - response_text = await self._generate_response( - query, context, system_prompt - ) - - # 6. Optionally synthesize speech - audio_b64 = None - if enable_tts and self.tts: - audio_b64 = await self._synthesize_speech(response_text) - - # Build response - result = { - "request_id": request_id, - "response": response_text, - } - - if self.chat_settings.include_sources: - result["sources"] = [ - {"text": d["document"][:200], "score": d["score"]} - for d in reranked[:3] - ] - - if audio_b64: - result["audio"] = audio_b64 - - logger.info(f"Completed request {request_id}") - - # Publish to response subject - response_subject = f"ai.chat.response.{request_id}" - await self.nats.publish(response_subject, result) - - return result - - async def _get_embedding(self, text: str) -> list[float]: - """Generate embedding for query text.""" - with create_span("chat.embedding"): - return await self.embeddings.embed_single(text) - - async def _search_context( - self, embedding: list[float], collection: str - ) -> list[dict]: - """Search Milvus for relevant documents.""" - with create_span("chat.search"): - return await self.milvus.search_with_texts( - embedding, - limit=self.chat_settings.rag_top_k, - text_field="text", - metadata_fields=["source", "title"], - ) - - async def _rerank_documents( - self, query: str, documents: list[dict] - ) -> list[dict]: - """Rerank documents by relevance to query.""" - with create_span("chat.rerank"): - texts = [d.get("text", "") for d in documents] - return await self.reranker.rerank( - query, texts, top_k=self.chat_settings.rag_rerank_top_k - ) - - def _build_context(self, documents: list[dict]) -> str: - """Build context string from ranked documents.""" - context_parts = [] - for i, doc in enumerate(documents, 1): - text = doc.get("document", "") - context_parts.append(f"[{i}] {text}") - return "\n\n".join(context_parts) - - async def _generate_response( - self, - query: str, - context: str, - system_prompt: Optional[str] = None, - ) -> str: - """Generate LLM response with context.""" - with create_span("chat.generate"): - return await self.llm.generate( - query, - context=context, - system_prompt=system_prompt, - ) - - async def _synthesize_speech(self, text: str) -> str: - """Synthesize speech and return base64 encoded audio.""" - with create_span("chat.tts"): - audio_bytes = await self.tts.synthesize( - text, - language=self.chat_settings.tts_language, - ) - return base64.b64encode(audio_bytes).decode() - - -if __name__ == "__main__": - ChatHandler().run() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..610cf74 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "chat-handler" +version = "1.0.0" +description = "Text chat pipeline with RAG - Query → Embeddings → Milvus → Rerank → LLM" +readme = "README.md" +requires-python = ">=3.11" +license = { text = "MIT" } +authors = [{ name = "Davies Tech Labs" }] + +dependencies = [ + "handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "ruff>=0.1.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["."] +only-include = ["chat_handler.py"] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" +filterwarnings = ["ignore::DeprecationWarning"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ef68ffa..0000000 --- a/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -nats-py -httpx -pymilvus -numpy -msgpack -redis>=5.0.0 -opentelemetry-api -opentelemetry-sdk -opentelemetry-exporter-otlp-proto-grpc -opentelemetry-exporter-otlp-proto-http -opentelemetry-instrumentation-httpx -opentelemetry-instrumentation-logging -# MLflow for inference metrics tracking -mlflow>=2.10.0 -psycopg2-binary>=2.9.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..926f4ef --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Chat Handler Tests diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3c8bef8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,81 @@ +""" +Pytest configuration and fixtures for chat-handler tests. +""" +import asyncio +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Set test environment variables before importing +os.environ.setdefault("NATS_URL", "nats://localhost:4222") +os.environ.setdefault("REDIS_URL", "redis://localhost:6379") +os.environ.setdefault("MILVUS_HOST", "localhost") +os.environ.setdefault("OTEL_ENABLED", "false") +os.environ.setdefault("MLFLOW_ENABLED", "false") + + +@pytest.fixture(scope="session") +def event_loop(): + """Create event loop for async tests.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def sample_embedding(): + """Sample embedding vector.""" + return [0.1] * 1024 + + +@pytest.fixture +def sample_documents(): + """Sample search results.""" + return [ + {"text": "Machine learning is a subset of AI.", "score": 0.95}, + {"text": "Deep learning uses neural networks.", "score": 0.90}, + {"text": "AI enables intelligent automation.", "score": 0.85}, + ] + + +@pytest.fixture +def sample_reranked(): + """Sample reranked results.""" + return [ + {"document": "Machine learning is a subset of AI.", "score": 0.98}, + {"document": "Deep learning uses neural networks.", "score": 0.85}, + ] + + +@pytest.fixture +def mock_nats_message(): + """Create a mock NATS message.""" + msg = MagicMock() + msg.subject = "ai.chat.request" + msg.reply = "ai.chat.response.test-123" + return msg + + +@pytest.fixture +def mock_chat_request(): + """Sample chat request payload.""" + return { + "request_id": "test-request-123", + "query": "What is machine learning?", + "collection": "test_collection", + "enable_tts": False, + "system_prompt": None, + } + + +@pytest.fixture +def mock_chat_request_with_tts(): + """Sample chat request with TTS enabled.""" + return { + "request_id": "test-request-456", + "query": "Tell me about AI", + "collection": "documents", + "enable_tts": True, + "system_prompt": "You are a helpful assistant.", + } diff --git a/tests/test_chat_handler.py b/tests/test_chat_handler.py new file mode 100644 index 0000000..e435f7f --- /dev/null +++ b/tests/test_chat_handler.py @@ -0,0 +1,262 @@ +""" +Unit tests for ChatHandler. +""" +import base64 +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from chat_handler import ChatHandler, ChatSettings + + +class TestChatSettings: + """Tests for ChatSettings configuration.""" + + def test_default_settings(self): + """Test default settings values.""" + settings = ChatSettings() + + assert settings.service_name == "chat-handler" + assert settings.rag_top_k == 10 + assert settings.rag_rerank_top_k == 5 + assert settings.rag_collection == "documents" + assert settings.include_sources is True + assert settings.enable_tts is False + assert settings.tts_language == "en" + + def test_custom_settings(self): + """Test custom settings.""" + settings = ChatSettings( + rag_top_k=20, + rag_collection="custom_docs", + enable_tts=True, + ) + + assert settings.rag_top_k == 20 + assert settings.rag_collection == "custom_docs" + assert settings.enable_tts is True + + +class TestChatHandler: + """Tests for ChatHandler.""" + + @pytest.fixture + def handler(self): + """Create handler with mocked clients.""" + with patch("chat_handler.EmbeddingsClient"), \ + patch("chat_handler.RerankerClient"), \ + patch("chat_handler.LLMClient"), \ + patch("chat_handler.TTSClient"), \ + patch("chat_handler.MilvusClient"): + + handler = ChatHandler() + + # Setup mock clients + handler.embeddings = AsyncMock() + handler.reranker = AsyncMock() + handler.llm = AsyncMock() + handler.milvus = AsyncMock() + handler.tts = None # TTS disabled by default + handler.nats = AsyncMock() + + yield handler + + @pytest.fixture + def handler_with_tts(self): + """Create handler with TTS enabled.""" + with patch("chat_handler.EmbeddingsClient"), \ + patch("chat_handler.RerankerClient"), \ + patch("chat_handler.LLMClient"), \ + patch("chat_handler.TTSClient"), \ + patch("chat_handler.MilvusClient"): + + handler = ChatHandler() + handler.chat_settings.enable_tts = True + + # Setup mock clients + handler.embeddings = AsyncMock() + handler.reranker = AsyncMock() + handler.llm = AsyncMock() + handler.milvus = AsyncMock() + handler.tts = AsyncMock() + handler.nats = AsyncMock() + + yield handler + + def test_init(self, handler): + """Test handler initialization.""" + assert handler.subject == "ai.chat.request" + assert handler.queue_group == "chat-handlers" + assert handler.chat_settings.service_name == "chat-handler" + + @pytest.mark.asyncio + async def test_handle_message_success( + self, + handler, + mock_nats_message, + mock_chat_request, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test successful chat request handling.""" + # Setup mocks + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "Machine learning is a subset of AI that..." + + # Execute + result = await handler.handle_message(mock_nats_message, mock_chat_request) + + # Verify + assert result["request_id"] == "test-request-123" + assert "response" in result + assert result["response"] == "Machine learning is a subset of AI that..." + assert "sources" in result # include_sources is True by default + + # Verify pipeline was called + handler.embeddings.embed_single.assert_called_once() + handler.milvus.search_with_texts.assert_called_once() + handler.reranker.rerank.assert_called_once() + handler.llm.generate.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_message_without_sources( + self, + handler, + mock_nats_message, + mock_chat_request, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test response without sources when disabled.""" + handler.chat_settings.include_sources = False + + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "Response text" + + result = await handler.handle_message(mock_nats_message, mock_chat_request) + + assert "sources" not in result + + @pytest.mark.asyncio + async def test_handle_message_with_tts( + self, + handler_with_tts, + mock_nats_message, + mock_chat_request_with_tts, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test response with TTS audio.""" + handler = handler_with_tts + + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "AI response" + handler.tts.synthesize.return_value = b"audio_bytes" + + result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts) + + assert "audio" in result + handler.tts.synthesize.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_message_with_custom_system_prompt( + self, + handler, + mock_nats_message, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test LLM is called with custom system prompt.""" + request = { + "request_id": "test-123", + "query": "Hello", + "system_prompt": "You are a pirate. Respond like one.", + } + + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "Ahoy!" + + await handler.handle_message(mock_nats_message, request) + + # Verify system_prompt was passed to LLM + handler.llm.generate.assert_called_once() + call_kwargs = handler.llm.generate.call_args.kwargs + assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one." + + def test_build_context(self, handler): + """Test context building with numbered sources.""" + documents = [ + {"document": "First doc content"}, + {"document": "Second doc content"}, + ] + + context = handler._build_context(documents) + + assert "[1]" in context + assert "[2]" in context + assert "First doc content" in context + assert "Second doc content" in context + + @pytest.mark.asyncio + async def test_setup_initializes_clients(self): + """Test that setup initializes all required clients.""" + with patch("chat_handler.EmbeddingsClient") as emb_cls, \ + patch("chat_handler.RerankerClient") as rer_cls, \ + patch("chat_handler.LLMClient") as llm_cls, \ + patch("chat_handler.TTSClient") as tts_cls, \ + patch("chat_handler.MilvusClient") as mil_cls: + + mil_cls.return_value.connect = AsyncMock() + + handler = ChatHandler() + await handler.setup() + + emb_cls.assert_called_once() + rer_cls.assert_called_once() + llm_cls.assert_called_once() + mil_cls.assert_called_once() + # TTS should not be initialized when disabled + tts_cls.assert_not_called() + + @pytest.mark.asyncio + async def test_teardown_closes_clients(self, handler): + """Test that teardown closes all clients.""" + await handler.teardown() + + handler.embeddings.close.assert_called_once() + handler.reranker.close.assert_called_once() + handler.llm.close.assert_called_once() + handler.milvus.close.assert_called_once() + + @pytest.mark.asyncio + async def test_publishes_to_response_subject( + self, + handler, + mock_nats_message, + mock_chat_request, + sample_embedding, + sample_documents, + sample_reranked, + ): + """Test that result is published to response subject.""" + handler.embeddings.embed_single.return_value = sample_embedding + handler.milvus.search_with_texts.return_value = sample_documents + handler.reranker.rerank.return_value = sample_reranked + handler.llm.generate.return_value = "Response" + + await handler.handle_message(mock_nats_message, mock_chat_request) + + handler.nats.publish.assert_called_once() + call_args = handler.nats.publish.call_args + assert "ai.chat.response.test-request-123" in str(call_args)