""" Unit tests for service clients. """ from unittest.mock import MagicMock import pytest class TestEmbeddingsClient: """Tests for EmbeddingsClient.""" @pytest.fixture def embeddings_client(self, mock_httpx_client): """Create an EmbeddingsClient with mocked HTTP.""" from handler_base.clients.embeddings import EmbeddingsClient client = EmbeddingsClient() client._client = mock_httpx_client return client @pytest.mark.asyncio async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding): """Test embedding a single text.""" # Setup mock response mock_response = MagicMock() mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]} mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response result = await embeddings_client.embed_single("Hello world") assert result == sample_embedding mock_httpx_client.post.assert_called_once() @pytest.mark.asyncio async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding): """Test embedding multiple texts.""" texts = ["Hello", "World"] mock_response = MagicMock() mock_response.json.return_value = { "data": [ {"embedding": sample_embedding, "index": 0}, {"embedding": sample_embedding, "index": 1}, ] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response result = await embeddings_client.embed(texts) assert len(result) == 2 assert all(len(e) == len(sample_embedding) for e in result) @pytest.mark.asyncio async def test_health_check(self, embeddings_client, mock_httpx_client): """Test health check.""" mock_response = MagicMock() mock_response.status_code = 200 mock_httpx_client.get.return_value = mock_response result = await embeddings_client.health() assert result is True class TestRerankerClient: """Tests for RerankerClient.""" @pytest.fixture def reranker_client(self, mock_httpx_client): """Create a RerankerClient with mocked HTTP.""" from handler_base.clients.reranker import RerankerClient client = RerankerClient() client._client = mock_httpx_client return client @pytest.mark.asyncio async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents): """Test reranking documents.""" texts = [d["text"] for d in sample_documents] mock_response = MagicMock() mock_response.json.return_value = { "results": [ {"index": 1, "relevance_score": 0.95}, {"index": 0, "relevance_score": 0.80}, {"index": 2, "relevance_score": 0.65}, ] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response result = await reranker_client.rerank("What is ML?", texts) assert len(result) == 3 assert result[0]["score"] == 0.95 assert result[0]["index"] == 1 class TestLLMClient: """Tests for LLMClient.""" @pytest.fixture def llm_client(self, mock_httpx_client): """Create an LLMClient with mocked HTTP.""" from handler_base.clients.llm import LLMClient client = LLMClient() client._client = mock_httpx_client return client @pytest.mark.asyncio async def test_generate(self, llm_client, mock_httpx_client): """Test generating a response.""" mock_response = MagicMock() mock_response.json.return_value = { "choices": [{"message": {"content": "Hello! I'm an AI assistant."}}], "usage": {"prompt_tokens": 10, "completion_tokens": 20}, } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response result = await llm_client.generate("Hello") assert result == "Hello! I'm an AI assistant." @pytest.mark.asyncio async def test_generate_with_context(self, llm_client, mock_httpx_client): """Test generating with RAG context.""" mock_response = MagicMock() mock_response.json.return_value = { "choices": [{"message": {"content": "Based on the context..."}}], "usage": {}, } mock_response.raise_for_status = MagicMock() mock_httpx_client.post.return_value = mock_response result = await llm_client.generate( "What is Python?", context="Python is a programming language." ) assert "Based on the context" in result # Verify context was included in the request call_args = mock_httpx_client.post.call_args messages = call_args.kwargs["json"]["messages"] assert any("Context:" in m["content"] for m in messages if m["role"] == "user")