""" Unit tests for ChatHandler. """ import pytest from unittest.mock import AsyncMock, patch from chat_handler import ChatHandler, ChatSettings class TestChatSettings: """Tests for ChatSettings configuration.""" def test_default_settings(self): """Test default settings values.""" settings = ChatSettings() assert settings.service_name == "chat-handler" assert settings.rag_top_k == 10 assert settings.rag_rerank_top_k == 5 assert settings.rag_collection == "documents" assert settings.include_sources is True assert settings.enable_tts is False assert settings.tts_language == "en" def test_custom_settings(self): """Test custom settings.""" settings = ChatSettings( rag_top_k=20, rag_collection="custom_docs", enable_tts=True, ) assert settings.rag_top_k == 20 assert settings.rag_collection == "custom_docs" assert settings.enable_tts is True class TestChatHandler: """Tests for ChatHandler.""" @pytest.fixture def handler(self): """Create handler with mocked clients.""" with ( patch("chat_handler.EmbeddingsClient"), patch("chat_handler.RerankerClient"), patch("chat_handler.LLMClient"), patch("chat_handler.TTSClient"), patch("chat_handler.MilvusClient"), ): handler = ChatHandler() # Setup mock clients handler.embeddings = AsyncMock() handler.reranker = AsyncMock() handler.llm = AsyncMock() handler.milvus = AsyncMock() handler.tts = None # TTS disabled by default handler.nats = AsyncMock() yield handler @pytest.fixture def handler_with_tts(self): """Create handler with TTS enabled.""" with ( patch("chat_handler.EmbeddingsClient"), patch("chat_handler.RerankerClient"), patch("chat_handler.LLMClient"), patch("chat_handler.TTSClient"), patch("chat_handler.MilvusClient"), ): handler = ChatHandler() handler.chat_settings.enable_tts = True # Setup mock clients handler.embeddings = AsyncMock() handler.reranker = AsyncMock() handler.llm = AsyncMock() handler.milvus = AsyncMock() handler.tts = AsyncMock() handler.nats = AsyncMock() yield handler def test_init(self, handler): """Test handler initialization.""" assert handler.subject == "ai.chat.user.*.message" assert handler.queue_group == "chat-handlers" assert handler.chat_settings.service_name == "chat-handler" @pytest.mark.asyncio async def test_handle_message_success( self, handler, mock_nats_message, mock_chat_request, sample_embedding, sample_documents, sample_reranked, ): """Test successful chat request handling.""" # Setup mocks handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Machine learning is a subset of AI that..." # Execute result = await handler.handle_message(mock_nats_message, mock_chat_request) # Verify assert result["user_id"] == "test-user-1" assert result["success"] is True assert "response" in result assert result["response"] == "Machine learning is a subset of AI that..." assert result["response_text"] == result["response"] assert result["used_rag"] is True assert isinstance(result["rag_sources"], list) # Verify RAG pipeline was called (enable_rag=True in fixture) handler.embeddings.embed_single.assert_called_once() handler.milvus.search_with_texts.assert_called_once() handler.reranker.rerank.assert_called_once() handler.llm.generate.assert_called_once() @pytest.mark.asyncio async def test_handle_message_without_sources( self, handler, mock_nats_message, mock_chat_request, sample_embedding, sample_documents, sample_reranked, ): """Test response without sources when disabled.""" handler.chat_settings.include_sources = False handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Response text" result = await handler.handle_message(mock_nats_message, mock_chat_request) # New response format doesn't have a separate "sources" key; # rag_sources is always present (may be empty) assert "rag_sources" in result @pytest.mark.asyncio async def test_handle_message_with_tts( self, handler_with_tts, mock_nats_message, mock_chat_request_with_tts, sample_embedding, sample_documents, sample_reranked, ): """Test response with TTS audio.""" handler = handler_with_tts handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "AI response" handler.tts.synthesize.return_value = b"audio_bytes" result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts) assert "audio" in result handler.tts.synthesize.assert_called_once() @pytest.mark.asyncio async def test_handle_message_with_custom_system_prompt( self, handler, mock_nats_message, sample_embedding, sample_documents, sample_reranked, ): """Test LLM is called with custom system prompt.""" request = { "request_id": "test-123", "user_id": "user-42", "message": "Hello", "premium": True, "enable_rag": True, "system_prompt": "You are a pirate. Respond like one.", } handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Ahoy!" await handler.handle_message(mock_nats_message, request) # Verify system_prompt was passed to LLM handler.llm.generate.assert_called_once() call_kwargs = handler.llm.generate.call_args.kwargs assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one." def test_build_context(self, handler): """Test context building with numbered sources.""" documents = [ {"document": "First doc content"}, {"document": "Second doc content"}, ] context = handler._build_context(documents) assert "[1]" in context assert "[2]" in context assert "First doc content" in context assert "Second doc content" in context @pytest.mark.asyncio async def test_setup_initializes_clients(self): """Test that setup initializes all required clients.""" with ( patch("chat_handler.EmbeddingsClient") as emb_cls, patch("chat_handler.RerankerClient") as rer_cls, patch("chat_handler.LLMClient") as llm_cls, patch("chat_handler.TTSClient") as tts_cls, patch("chat_handler.MilvusClient") as mil_cls, ): mil_cls.return_value.connect = AsyncMock() handler = ChatHandler() await handler.setup() emb_cls.assert_called_once() rer_cls.assert_called_once() llm_cls.assert_called_once() mil_cls.assert_called_once() # TTS should not be initialized when disabled tts_cls.assert_not_called() @pytest.mark.asyncio async def test_teardown_closes_clients(self, handler): """Test that teardown closes all clients.""" await handler.teardown() handler.embeddings.close.assert_called_once() handler.reranker.close.assert_called_once() handler.llm.close.assert_called_once() handler.milvus.close.assert_called_once() @pytest.mark.asyncio async def test_publishes_to_response_subject( self, handler, mock_nats_message, mock_chat_request, sample_embedding, sample_documents, sample_reranked, ): """Test that result is published to response subject.""" handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Response" await handler.handle_message(mock_nats_message, mock_chat_request) handler.nats.publish.assert_called_once() call_args = handler.nats.publish.call_args assert "ai.chat.response.test-request-123" in str(call_args)