test: add unit tests for handler-base

- tests/conftest.py: Pytest fixtures and configuration
- tests/unit/test_config.py: Settings tests
- tests/unit/test_nats_client.py: NATS client tests
- tests/unit/test_health.py: Health server tests
- tests/unit/test_clients.py: Service client tests
- pytest.ini: Pytest configuration
This commit is contained in:
2026-02-02 06:23:44 -05:00
parent 99c97b7973
commit 849da661e6
7 changed files with 500 additions and 0 deletions

76
tests/conftest.py Normal file
View File

@@ -0,0 +1,76 @@
"""
Pytest configuration and fixtures.
"""
import asyncio
import os
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import pytest
# Set test environment variables before importing handler_base
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
os.environ.setdefault("MILVUS_HOST", "localhost")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture
def settings():
"""Create test settings."""
from handler_base.config import Settings
return Settings(
service_name="test-service",
service_version="1.0.0-test",
otel_enabled=False,
mlflow_enabled=False,
nats_url="nats://localhost:4222",
redis_url="redis://localhost:6379",
milvus_host="localhost",
)
@pytest.fixture
def mock_httpx_client():
"""Create a mock httpx AsyncClient."""
client = AsyncMock()
client.post = AsyncMock()
client.get = AsyncMock()
client.aclose = AsyncMock()
return client
@pytest.fixture
def mock_nats_message():
"""Create a mock NATS message."""
msg = MagicMock()
msg.subject = "test.subject"
msg.reply = "test.reply"
msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack
return msg
@pytest.fixture
def sample_embedding():
"""Sample embedding vector."""
return [0.1] * 1024
@pytest.fixture
def sample_documents():
"""Sample documents for testing."""
return [
{"text": "Python is a programming language.", "source": "doc1"},
{"text": "Machine learning is a subset of AI.", "source": "doc2"},
{"text": "Deep learning uses neural networks.", "source": "doc3"},
]

1
tests/unit/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Unit tests package

156
tests/unit/test_clients.py Normal file
View File

@@ -0,0 +1,156 @@
"""
Unit tests for service clients.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
class TestEmbeddingsClient:
"""Tests for EmbeddingsClient."""
@pytest.fixture
def embeddings_client(self, mock_httpx_client):
"""Create an EmbeddingsClient with mocked HTTP."""
from handler_base.clients.embeddings import EmbeddingsClient
client = EmbeddingsClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding a single text."""
# Setup mock response
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [{"embedding": sample_embedding, "index": 0}]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed_single("Hello world")
assert result == sample_embedding
mock_httpx_client.post.assert_called_once()
@pytest.mark.asyncio
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
"""Test embedding multiple texts."""
texts = ["Hello", "World"]
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [
{"embedding": sample_embedding, "index": 0},
{"embedding": sample_embedding, "index": 1},
]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await embeddings_client.embed(texts)
assert len(result) == 2
assert all(len(e) == len(sample_embedding) for e in result)
@pytest.mark.asyncio
async def test_health_check(self, embeddings_client, mock_httpx_client):
"""Test health check."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_httpx_client.get.return_value = mock_response
result = await embeddings_client.health()
assert result is True
class TestRerankerClient:
"""Tests for RerankerClient."""
@pytest.fixture
def reranker_client(self, mock_httpx_client):
"""Create a RerankerClient with mocked HTTP."""
from handler_base.clients.reranker import RerankerClient
client = RerankerClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
"""Test reranking documents."""
texts = [d["text"] for d in sample_documents]
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [
{"index": 1, "relevance_score": 0.95},
{"index": 0, "relevance_score": 0.80},
{"index": 2, "relevance_score": 0.65},
]
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await reranker_client.rerank("What is ML?", texts)
assert len(result) == 3
assert result[0]["score"] == 0.95
assert result[0]["index"] == 1
class TestLLMClient:
"""Tests for LLMClient."""
@pytest.fixture
def llm_client(self, mock_httpx_client):
"""Create an LLMClient with mocked HTTP."""
from handler_base.clients.llm import LLMClient
client = LLMClient()
client._client = mock_httpx_client
return client
@pytest.mark.asyncio
async def test_generate(self, llm_client, mock_httpx_client):
"""Test generating a response."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{"message": {"content": "Hello! I'm an AI assistant."}}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 20}
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate("Hello")
assert result == "Hello! I'm an AI assistant."
@pytest.mark.asyncio
async def test_generate_with_context(self, llm_client, mock_httpx_client):
"""Test generating with RAG context."""
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{"message": {"content": "Based on the context..."}}
],
"usage": {}
}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post.return_value = mock_response
result = await llm_client.generate(
"What is Python?",
context="Python is a programming language."
)
assert "Based on the context" in result
# Verify context was included in the request
call_args = mock_httpx_client.post.call_args
messages = call_args.kwargs["json"]["messages"]
assert any("Context:" in m["content"] for m in messages if m["role"] == "user")

47
tests/unit/test_config.py Normal file
View File

@@ -0,0 +1,47 @@
"""
Unit tests for handler_base.config module.
"""
import os
import pytest
class TestSettings:
"""Tests for Settings configuration."""
def test_default_settings(self, settings):
"""Test that default settings are loaded correctly."""
assert settings.service_name == "test-service"
assert settings.service_version == "1.0.0-test"
assert settings.otel_enabled is False
def test_settings_from_env(self, monkeypatch):
"""Test that settings can be loaded from environment variables."""
monkeypatch.setenv("SERVICE_NAME", "env-service")
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
# Need to reimport to pick up env changes
from handler_base.config import Settings
s = Settings()
assert s.service_name == "env-service"
assert s.service_version == "2.0.0"
assert s.nats_url == "nats://custom:4222"
def test_embeddings_settings(self):
"""Test EmbeddingsSettings extends base correctly."""
from handler_base.config import EmbeddingsSettings
s = EmbeddingsSettings()
assert hasattr(s, "embeddings_model")
assert hasattr(s, "embeddings_batch_size")
assert s.embeddings_model == "bge"
def test_llm_settings(self):
"""Test LLMSettings has expected defaults."""
from handler_base.config import LLMSettings
s = LLMSettings()
assert s.llm_max_tokens == 2048
assert s.llm_temperature == 0.7
assert 0 <= s.llm_top_p <= 1

122
tests/unit/test_health.py Normal file
View File

@@ -0,0 +1,122 @@
"""
Unit tests for handler_base.health module.
"""
import pytest
import json
import threading
import time
from http.client import HTTPConnection
from unittest.mock import AsyncMock
class TestHealthServer:
"""Tests for HealthServer."""
@pytest.fixture
def health_server(self, settings):
"""Create a HealthServer instance."""
from handler_base.health import HealthServer
# Use a random high port to avoid conflicts
settings.health_port = 18080
return HealthServer(settings)
def test_start_stop(self, health_server):
"""Test starting and stopping the health server."""
health_server.start()
time.sleep(0.1) # Give server time to start
# Verify server is running
assert health_server._server is not None
assert health_server._thread is not None
assert health_server._thread.is_alive()
health_server.stop()
time.sleep(0.1)
assert health_server._server is None
def test_health_endpoint(self, health_server):
"""Test the /health endpoint."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/health")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "healthy"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_default(self, health_server):
"""Test the /ready endpoint with no custom check."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
data = json.loads(response.read().decode())
assert data["status"] == "ready"
finally:
conn.close()
health_server.stop()
def test_ready_endpoint_with_check(self, settings):
"""Test /ready endpoint with custom readiness check."""
from handler_base.health import HealthServer
ready_flag = [False] # Use list to allow mutation in closure
async def check_ready():
return ready_flag[0]
settings.health_port = 18081
server = HealthServer(settings, ready_check=check_ready)
server.start()
time.sleep(0.2)
try:
conn = HTTPConnection("localhost", 18081, timeout=5)
# Should be not ready initially
conn.request("GET", "/ready")
response = conn.getresponse()
response.read() # Consume response body
assert response.status == 503
# Mark as ready
ready_flag[0] = True
# Need new connection after consuming response
conn.close()
conn = HTTPConnection("localhost", 18081, timeout=5)
conn.request("GET", "/ready")
response = conn.getresponse()
assert response.status == 200
finally:
conn.close()
server.stop()
def test_404_for_unknown_path(self, health_server):
"""Test that unknown paths return 404."""
health_server.start()
time.sleep(0.1)
try:
conn = HTTPConnection("localhost", 18080, timeout=5)
conn.request("GET", "/unknown")
response = conn.getresponse()
assert response.status == 404
finally:
conn.close()
health_server.stop()

View File

@@ -0,0 +1,91 @@
"""
Unit tests for handler_base.nats_client module.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import msgpack
class TestNATSClient:
"""Tests for NATSClient."""
@pytest.fixture
def nats_client(self, settings):
"""Create a NATSClient instance."""
from handler_base.nats_client import NATSClient
return NATSClient(settings)
def test_init(self, nats_client, settings):
"""Test NATSClient initialization."""
assert nats_client.settings == settings
assert nats_client._nc is None
assert nats_client._js is None
def test_decode_msgpack(self, nats_client):
"""Test msgpack decoding."""
data = {"query": "hello", "request_id": "123"}
encoded = msgpack.packb(data, use_bin_type=True)
msg = MagicMock()
msg.data = encoded
result = nats_client.decode_msgpack(msg)
assert result == data
def test_decode_json(self, nats_client):
"""Test JSON decoding."""
import json
data = {"query": "hello"}
msg = MagicMock()
msg.data = json.dumps(data).encode()
result = nats_client.decode_json(msg)
assert result == data
@pytest.mark.asyncio
async def test_connect(self, nats_client):
"""Test NATS connection."""
with patch("handler_base.nats_client.nats") as mock_nats:
mock_nc = AsyncMock()
mock_js = MagicMock()
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
mock_nats.connect = AsyncMock(return_value=mock_nc)
await nats_client.connect()
assert nats_client._nc == mock_nc
assert nats_client._js == mock_js
mock_nats.connect.assert_called_once()
@pytest.mark.asyncio
async def test_publish(self, nats_client):
"""Test publishing a message."""
mock_nc = AsyncMock()
nats_client._nc = mock_nc
data = {"key": "value"}
await nats_client.publish("test.subject", data)
mock_nc.publish.assert_called_once()
call_args = mock_nc.publish.call_args
assert call_args.args[0] == "test.subject"
# Verify msgpack encoding
decoded = msgpack.unpackb(call_args.args[1], raw=False)
assert decoded == data
@pytest.mark.asyncio
async def test_subscribe(self, nats_client):
"""Test subscribing to a subject."""
mock_nc = AsyncMock()
mock_sub = MagicMock()
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
nats_client._nc = mock_nc
handler = AsyncMock()
await nats_client.subscribe("test.subject", handler, queue="test-queue")
mock_nc.subscribe.assert_called_once()
call_kwargs = mock_nc.subscribe.call_args.kwargs
assert call_kwargs["queue"] == "test-queue"