""" Unit tests for ChatHandler. """ import base64 import pytest from unittest.mock import AsyncMock, MagicMock, patch from chat_handler import ChatHandler, ChatSettings class TestChatSettings: """Tests for ChatSettings configuration.""" def test_default_settings(self): """Test default settings values.""" settings = ChatSettings() assert settings.service_name == "chat-handler" assert settings.rag_top_k == 10 assert settings.rag_rerank_top_k == 5 assert settings.rag_collection == "documents" assert settings.include_sources is True assert settings.enable_tts is False assert settings.tts_language == "en" def test_custom_settings(self): """Test custom settings.""" settings = ChatSettings( rag_top_k=20, rag_collection="custom_docs", enable_tts=True, ) assert settings.rag_top_k == 20 assert settings.rag_collection == "custom_docs" assert settings.enable_tts is True class TestChatHandler: """Tests for ChatHandler.""" @pytest.fixture def handler(self): """Create handler with mocked clients.""" with patch("chat_handler.EmbeddingsClient"), \ patch("chat_handler.RerankerClient"), \ patch("chat_handler.LLMClient"), \ patch("chat_handler.TTSClient"), \ patch("chat_handler.MilvusClient"): handler = ChatHandler() # Setup mock clients handler.embeddings = AsyncMock() handler.reranker = AsyncMock() handler.llm = AsyncMock() handler.milvus = AsyncMock() handler.tts = None # TTS disabled by default handler.nats = AsyncMock() yield handler @pytest.fixture def handler_with_tts(self): """Create handler with TTS enabled.""" with patch("chat_handler.EmbeddingsClient"), \ patch("chat_handler.RerankerClient"), \ patch("chat_handler.LLMClient"), \ patch("chat_handler.TTSClient"), \ patch("chat_handler.MilvusClient"): handler = ChatHandler() handler.chat_settings.enable_tts = True # Setup mock clients handler.embeddings = AsyncMock() handler.reranker = AsyncMock() handler.llm = AsyncMock() handler.milvus = AsyncMock() handler.tts = AsyncMock() handler.nats = AsyncMock() yield handler def test_init(self, handler): """Test handler initialization.""" assert handler.subject == "ai.chat.request" assert handler.queue_group == "chat-handlers" assert handler.chat_settings.service_name == "chat-handler" @pytest.mark.asyncio async def test_handle_message_success( self, handler, mock_nats_message, mock_chat_request, sample_embedding, sample_documents, sample_reranked, ): """Test successful chat request handling.""" # Setup mocks handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Machine learning is a subset of AI that..." # Execute result = await handler.handle_message(mock_nats_message, mock_chat_request) # Verify assert result["request_id"] == "test-request-123" assert "response" in result assert result["response"] == "Machine learning is a subset of AI that..." assert "sources" in result # include_sources is True by default # Verify pipeline was called handler.embeddings.embed_single.assert_called_once() handler.milvus.search_with_texts.assert_called_once() handler.reranker.rerank.assert_called_once() handler.llm.generate.assert_called_once() @pytest.mark.asyncio async def test_handle_message_without_sources( self, handler, mock_nats_message, mock_chat_request, sample_embedding, sample_documents, sample_reranked, ): """Test response without sources when disabled.""" handler.chat_settings.include_sources = False handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Response text" result = await handler.handle_message(mock_nats_message, mock_chat_request) assert "sources" not in result @pytest.mark.asyncio async def test_handle_message_with_tts( self, handler_with_tts, mock_nats_message, mock_chat_request_with_tts, sample_embedding, sample_documents, sample_reranked, ): """Test response with TTS audio.""" handler = handler_with_tts handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "AI response" handler.tts.synthesize.return_value = b"audio_bytes" result = await handler.handle_message(mock_nats_message, mock_chat_request_with_tts) assert "audio" in result handler.tts.synthesize.assert_called_once() @pytest.mark.asyncio async def test_handle_message_with_custom_system_prompt( self, handler, mock_nats_message, sample_embedding, sample_documents, sample_reranked, ): """Test LLM is called with custom system prompt.""" request = { "request_id": "test-123", "query": "Hello", "system_prompt": "You are a pirate. Respond like one.", } handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Ahoy!" await handler.handle_message(mock_nats_message, request) # Verify system_prompt was passed to LLM handler.llm.generate.assert_called_once() call_kwargs = handler.llm.generate.call_args.kwargs assert call_kwargs.get("system_prompt") == "You are a pirate. Respond like one." def test_build_context(self, handler): """Test context building with numbered sources.""" documents = [ {"document": "First doc content"}, {"document": "Second doc content"}, ] context = handler._build_context(documents) assert "[1]" in context assert "[2]" in context assert "First doc content" in context assert "Second doc content" in context @pytest.mark.asyncio async def test_setup_initializes_clients(self): """Test that setup initializes all required clients.""" with patch("chat_handler.EmbeddingsClient") as emb_cls, \ patch("chat_handler.RerankerClient") as rer_cls, \ patch("chat_handler.LLMClient") as llm_cls, \ patch("chat_handler.TTSClient") as tts_cls, \ patch("chat_handler.MilvusClient") as mil_cls: mil_cls.return_value.connect = AsyncMock() handler = ChatHandler() await handler.setup() emb_cls.assert_called_once() rer_cls.assert_called_once() llm_cls.assert_called_once() mil_cls.assert_called_once() # TTS should not be initialized when disabled tts_cls.assert_not_called() @pytest.mark.asyncio async def test_teardown_closes_clients(self, handler): """Test that teardown closes all clients.""" await handler.teardown() handler.embeddings.close.assert_called_once() handler.reranker.close.assert_called_once() handler.llm.close.assert_called_once() handler.milvus.close.assert_called_once() @pytest.mark.asyncio async def test_publishes_to_response_subject( self, handler, mock_nats_message, mock_chat_request, sample_embedding, sample_documents, sample_reranked, ): """Test that result is published to response subject.""" handler.embeddings.embed_single.return_value = sample_embedding handler.milvus.search_with_texts.return_value = sample_documents handler.reranker.rerank.return_value = sample_reranked handler.llm.generate.return_value = "Response" await handler.handle_message(mock_nats_message, mock_chat_request) handler.nats.publish.assert_called_once() call_args = handler.nats.publish.call_args assert "ai.chat.response.test-request-123" in str(call_args)