From 849da661e6f704bdb70b8c3f0f7e25f9c7da0ece Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Mon, 2 Feb 2026 06:23:44 -0500 Subject: [PATCH] test: add unit tests for handler-base - tests/conftest.py: Pytest fixtures and configuration - tests/unit/test_config.py: Settings tests - tests/unit/test_nats_client.py: NATS client tests - tests/unit/test_health.py: Health server tests - tests/unit/test_clients.py: Service client tests - pytest.ini: Pytest configuration --- pytest.ini | 7 ++ tests/conftest.py | 76 ++++++++++++++++ tests/unit/__init__.py | 1 + tests/unit/test_clients.py | 156 +++++++++++++++++++++++++++++++++ tests/unit/test_config.py | 47 ++++++++++ tests/unit/test_health.py | 122 ++++++++++++++++++++++++++ tests/unit/test_nats_client.py | 91 +++++++++++++++++++ 7 files changed, 500 insertions(+) create mode 100644 pytest.ini create mode 100644 tests/conftest.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_clients.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_health.py create mode 100644 tests/unit/test_nats_client.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..a1d9fad --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +asyncio_mode = auto +addopts = -v --tb=short diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5581e1f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,76 @@ +""" +Pytest configuration and fixtures. +""" +import asyncio +import os +from typing import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +# Set test environment variables before importing handler_base +os.environ.setdefault("NATS_URL", "nats://localhost:4222") +os.environ.setdefault("REDIS_URL", "redis://localhost:6379") +os.environ.setdefault("MILVUS_HOST", "localhost") +os.environ.setdefault("OTEL_ENABLED", "false") +os.environ.setdefault("MLFLOW_ENABLED", "false") + + +@pytest.fixture(scope="session") +def event_loop(): + """Create event loop for async tests.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def settings(): + """Create test settings.""" + from handler_base.config import Settings + return Settings( + service_name="test-service", + service_version="1.0.0-test", + otel_enabled=False, + mlflow_enabled=False, + nats_url="nats://localhost:4222", + redis_url="redis://localhost:6379", + milvus_host="localhost", + ) + + +@pytest.fixture +def mock_httpx_client(): + """Create a mock httpx AsyncClient.""" + client = AsyncMock() + client.post = AsyncMock() + client.get = AsyncMock() + client.aclose = AsyncMock() + return client + + +@pytest.fixture +def mock_nats_message(): + """Create a mock NATS message.""" + msg = MagicMock() + msg.subject = "test.subject" + msg.reply = "test.reply" + msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack + return msg + + +@pytest.fixture +def sample_embedding(): + """Sample embedding vector.""" + return [0.1] * 1024 + + +@pytest.fixture +def sample_documents(): + """Sample documents for testing.""" + return [ + {"text": "Python is a programming language.", "source": "doc1"}, + {"text": "Machine learning is a subset of AI.", "source": "doc2"}, + {"text": "Deep learning uses neural networks.", "source": "doc3"}, + ] diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..4a5d263 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +# Unit tests package diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py new file mode 100644 index 0000000..b187758 --- /dev/null +++ b/tests/unit/test_clients.py @@ -0,0 +1,156 @@ +""" +Unit tests for service clients. +""" +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +class TestEmbeddingsClient: + """Tests for EmbeddingsClient.""" + + @pytest.fixture + def embeddings_client(self, mock_httpx_client): + """Create an EmbeddingsClient with mocked HTTP.""" + from handler_base.clients.embeddings import EmbeddingsClient + + client = EmbeddingsClient() + client._client = mock_httpx_client + return client + + @pytest.mark.asyncio + async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding): + """Test embedding a single text.""" + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [{"embedding": sample_embedding, "index": 0}] + } + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + + result = await embeddings_client.embed_single("Hello world") + + assert result == sample_embedding + mock_httpx_client.post.assert_called_once() + + @pytest.mark.asyncio + async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding): + """Test embedding multiple texts.""" + texts = ["Hello", "World"] + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [ + {"embedding": sample_embedding, "index": 0}, + {"embedding": sample_embedding, "index": 1}, + ] + } + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + + result = await embeddings_client.embed(texts) + + assert len(result) == 2 + assert all(len(e) == len(sample_embedding) for e in result) + + @pytest.mark.asyncio + async def test_health_check(self, embeddings_client, mock_httpx_client): + """Test health check.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_httpx_client.get.return_value = mock_response + + result = await embeddings_client.health() + + assert result is True + + +class TestRerankerClient: + """Tests for RerankerClient.""" + + @pytest.fixture + def reranker_client(self, mock_httpx_client): + """Create a RerankerClient with mocked HTTP.""" + from handler_base.clients.reranker import RerankerClient + + client = RerankerClient() + client._client = mock_httpx_client + return client + + @pytest.mark.asyncio + async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents): + """Test reranking documents.""" + texts = [d["text"] for d in sample_documents] + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 1, "relevance_score": 0.95}, + {"index": 0, "relevance_score": 0.80}, + {"index": 2, "relevance_score": 0.65}, + ] + } + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + + result = await reranker_client.rerank("What is ML?", texts) + + assert len(result) == 3 + assert result[0]["score"] == 0.95 + assert result[0]["index"] == 1 + + +class TestLLMClient: + """Tests for LLMClient.""" + + @pytest.fixture + def llm_client(self, mock_httpx_client): + """Create an LLMClient with mocked HTTP.""" + from handler_base.clients.llm import LLMClient + + client = LLMClient() + client._client = mock_httpx_client + return client + + @pytest.mark.asyncio + async def test_generate(self, llm_client, mock_httpx_client): + """Test generating a response.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "Hello! I'm an AI assistant."}} + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20} + } + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + + result = await llm_client.generate("Hello") + + assert result == "Hello! I'm an AI assistant." + + @pytest.mark.asyncio + async def test_generate_with_context(self, llm_client, mock_httpx_client): + """Test generating with RAG context.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "Based on the context..."}} + ], + "usage": {} + } + mock_response.raise_for_status = MagicMock() + mock_httpx_client.post.return_value = mock_response + + result = await llm_client.generate( + "What is Python?", + context="Python is a programming language." + ) + + assert "Based on the context" in result + + # Verify context was included in the request + call_args = mock_httpx_client.post.call_args + messages = call_args.kwargs["json"]["messages"] + assert any("Context:" in m["content"] for m in messages if m["role"] == "user") diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..c9134ce --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,47 @@ +""" +Unit tests for handler_base.config module. +""" +import os +import pytest + + +class TestSettings: + """Tests for Settings configuration.""" + + def test_default_settings(self, settings): + """Test that default settings are loaded correctly.""" + assert settings.service_name == "test-service" + assert settings.service_version == "1.0.0-test" + assert settings.otel_enabled is False + + def test_settings_from_env(self, monkeypatch): + """Test that settings can be loaded from environment variables.""" + monkeypatch.setenv("SERVICE_NAME", "env-service") + monkeypatch.setenv("SERVICE_VERSION", "2.0.0") + monkeypatch.setenv("NATS_URL", "nats://custom:4222") + + # Need to reimport to pick up env changes + from handler_base.config import Settings + s = Settings() + + assert s.service_name == "env-service" + assert s.service_version == "2.0.0" + assert s.nats_url == "nats://custom:4222" + + def test_embeddings_settings(self): + """Test EmbeddingsSettings extends base correctly.""" + from handler_base.config import EmbeddingsSettings + + s = EmbeddingsSettings() + assert hasattr(s, "embeddings_model") + assert hasattr(s, "embeddings_batch_size") + assert s.embeddings_model == "bge" + + def test_llm_settings(self): + """Test LLMSettings has expected defaults.""" + from handler_base.config import LLMSettings + + s = LLMSettings() + assert s.llm_max_tokens == 2048 + assert s.llm_temperature == 0.7 + assert 0 <= s.llm_top_p <= 1 diff --git a/tests/unit/test_health.py b/tests/unit/test_health.py new file mode 100644 index 0000000..a5cde25 --- /dev/null +++ b/tests/unit/test_health.py @@ -0,0 +1,122 @@ +""" +Unit tests for handler_base.health module. +""" +import pytest +import json +import threading +import time +from http.client import HTTPConnection +from unittest.mock import AsyncMock + + +class TestHealthServer: + """Tests for HealthServer.""" + + @pytest.fixture + def health_server(self, settings): + """Create a HealthServer instance.""" + from handler_base.health import HealthServer + + # Use a random high port to avoid conflicts + settings.health_port = 18080 + return HealthServer(settings) + + def test_start_stop(self, health_server): + """Test starting and stopping the health server.""" + health_server.start() + time.sleep(0.1) # Give server time to start + + # Verify server is running + assert health_server._server is not None + assert health_server._thread is not None + assert health_server._thread.is_alive() + + health_server.stop() + time.sleep(0.1) + + assert health_server._server is None + + def test_health_endpoint(self, health_server): + """Test the /health endpoint.""" + health_server.start() + time.sleep(0.1) + + try: + conn = HTTPConnection("localhost", 18080, timeout=5) + conn.request("GET", "/health") + response = conn.getresponse() + + assert response.status == 200 + data = json.loads(response.read().decode()) + assert data["status"] == "healthy" + finally: + conn.close() + health_server.stop() + + def test_ready_endpoint_default(self, health_server): + """Test the /ready endpoint with no custom check.""" + health_server.start() + time.sleep(0.1) + + try: + conn = HTTPConnection("localhost", 18080, timeout=5) + conn.request("GET", "/ready") + response = conn.getresponse() + + assert response.status == 200 + data = json.loads(response.read().decode()) + assert data["status"] == "ready" + finally: + conn.close() + health_server.stop() + + def test_ready_endpoint_with_check(self, settings): + """Test /ready endpoint with custom readiness check.""" + from handler_base.health import HealthServer + + ready_flag = [False] # Use list to allow mutation in closure + + async def check_ready(): + return ready_flag[0] + + settings.health_port = 18081 + server = HealthServer(settings, ready_check=check_ready) + server.start() + time.sleep(0.2) + + try: + conn = HTTPConnection("localhost", 18081, timeout=5) + + # Should be not ready initially + conn.request("GET", "/ready") + response = conn.getresponse() + response.read() # Consume response body + assert response.status == 503 + + # Mark as ready + ready_flag[0] = True + + # Need new connection after consuming response + conn.close() + conn = HTTPConnection("localhost", 18081, timeout=5) + conn.request("GET", "/ready") + response = conn.getresponse() + assert response.status == 200 + finally: + conn.close() + server.stop() + + def test_404_for_unknown_path(self, health_server): + """Test that unknown paths return 404.""" + health_server.start() + time.sleep(0.1) + + try: + conn = HTTPConnection("localhost", 18080, timeout=5) + conn.request("GET", "/unknown") + response = conn.getresponse() + + assert response.status == 404 + finally: + conn.close() + health_server.stop() diff --git a/tests/unit/test_nats_client.py b/tests/unit/test_nats_client.py new file mode 100644 index 0000000..5fb5507 --- /dev/null +++ b/tests/unit/test_nats_client.py @@ -0,0 +1,91 @@ +""" +Unit tests for handler_base.nats_client module. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import msgpack + + +class TestNATSClient: + """Tests for NATSClient.""" + + @pytest.fixture + def nats_client(self, settings): + """Create a NATSClient instance.""" + from handler_base.nats_client import NATSClient + return NATSClient(settings) + + def test_init(self, nats_client, settings): + """Test NATSClient initialization.""" + assert nats_client.settings == settings + assert nats_client._nc is None + assert nats_client._js is None + + def test_decode_msgpack(self, nats_client): + """Test msgpack decoding.""" + data = {"query": "hello", "request_id": "123"} + encoded = msgpack.packb(data, use_bin_type=True) + + msg = MagicMock() + msg.data = encoded + + result = nats_client.decode_msgpack(msg) + assert result == data + + def test_decode_json(self, nats_client): + """Test JSON decoding.""" + import json + data = {"query": "hello"} + + msg = MagicMock() + msg.data = json.dumps(data).encode() + + result = nats_client.decode_json(msg) + assert result == data + + @pytest.mark.asyncio + async def test_connect(self, nats_client): + """Test NATS connection.""" + with patch("handler_base.nats_client.nats") as mock_nats: + mock_nc = AsyncMock() + mock_js = MagicMock() + mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async + mock_nats.connect = AsyncMock(return_value=mock_nc) + + await nats_client.connect() + + assert nats_client._nc == mock_nc + assert nats_client._js == mock_js + mock_nats.connect.assert_called_once() + + @pytest.mark.asyncio + async def test_publish(self, nats_client): + """Test publishing a message.""" + mock_nc = AsyncMock() + nats_client._nc = mock_nc + + data = {"key": "value"} + await nats_client.publish("test.subject", data) + + mock_nc.publish.assert_called_once() + call_args = mock_nc.publish.call_args + assert call_args.args[0] == "test.subject" + + # Verify msgpack encoding + decoded = msgpack.unpackb(call_args.args[1], raw=False) + assert decoded == data + + @pytest.mark.asyncio + async def test_subscribe(self, nats_client): + """Test subscribing to a subject.""" + mock_nc = AsyncMock() + mock_sub = MagicMock() + mock_nc.subscribe = AsyncMock(return_value=mock_sub) + nats_client._nc = mock_nc + + handler = AsyncMock() + await nats_client.subscribe("test.subject", handler, queue="test-queue") + + mock_nc.subscribe.assert_called_once() + call_kwargs = mock_nc.subscribe.call_args.kwargs + assert call_kwargs["queue"] == "test-queue"