fix: auto-fix ruff linting errors and remove unsupported upload-artifact
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
"""
|
||||
Pytest configuration and fixtures.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Set test environment variables before importing handler_base
|
||||
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
|
||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379")
|
||||
@@ -29,6 +28,7 @@ def event_loop():
|
||||
def settings():
|
||||
"""Create test settings."""
|
||||
from handler_base.config import Settings
|
||||
|
||||
return Settings(
|
||||
service_name="test-service",
|
||||
service_version="1.0.0-test",
|
||||
@@ -56,7 +56,7 @@ def mock_nats_message():
|
||||
msg = MagicMock()
|
||||
msg.subject = "test.subject"
|
||||
msg.reply = "test.reply"
|
||||
msg.data = b'\x82\xa8query\xa5hello\xaarequest_id\xa4test' # msgpack
|
||||
msg.data = b"\x82\xa8query\xa5hello\xaarequest_id\xa4test" # msgpack
|
||||
return msg
|
||||
|
||||
|
||||
|
||||
@@ -1,44 +1,43 @@
|
||||
"""
|
||||
Unit tests for service clients.
|
||||
"""
|
||||
import json
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
class TestEmbeddingsClient:
|
||||
"""Tests for EmbeddingsClient."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embeddings_client(self, mock_httpx_client):
|
||||
"""Create an EmbeddingsClient with mocked HTTP."""
|
||||
from handler_base.clients.embeddings import EmbeddingsClient
|
||||
|
||||
|
||||
client = EmbeddingsClient()
|
||||
client._client = mock_httpx_client
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_single(self, embeddings_client, mock_httpx_client, sample_embedding):
|
||||
"""Test embedding a single text."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"embedding": sample_embedding, "index": 0}]
|
||||
}
|
||||
mock_response.json.return_value = {"data": [{"embedding": sample_embedding, "index": 0}]}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx_client.post.return_value = mock_response
|
||||
|
||||
|
||||
result = await embeddings_client.embed_single("Hello world")
|
||||
|
||||
|
||||
assert result == sample_embedding
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_batch(self, embeddings_client, mock_httpx_client, sample_embedding):
|
||||
"""Test embedding multiple texts."""
|
||||
texts = ["Hello", "World"]
|
||||
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
@@ -48,41 +47,41 @@ class TestEmbeddingsClient:
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx_client.post.return_value = mock_response
|
||||
|
||||
|
||||
result = await embeddings_client.embed(texts)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(len(e) == len(sample_embedding) for e in result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, embeddings_client, mock_httpx_client):
|
||||
"""Test health check."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_httpx_client.get.return_value = mock_response
|
||||
|
||||
|
||||
result = await embeddings_client.health()
|
||||
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestRerankerClient:
|
||||
"""Tests for RerankerClient."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reranker_client(self, mock_httpx_client):
|
||||
"""Create a RerankerClient with mocked HTTP."""
|
||||
from handler_base.clients.reranker import RerankerClient
|
||||
|
||||
|
||||
client = RerankerClient()
|
||||
client._client = mock_httpx_client
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank(self, reranker_client, mock_httpx_client, sample_documents):
|
||||
"""Test reranking documents."""
|
||||
texts = [d["text"] for d in sample_documents]
|
||||
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"results": [
|
||||
@@ -93,9 +92,9 @@ class TestRerankerClient:
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx_client.post.return_value = mock_response
|
||||
|
||||
|
||||
result = await reranker_client.rerank("What is ML?", texts)
|
||||
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0]["score"] == 0.95
|
||||
assert result[0]["index"] == 1
|
||||
@@ -103,53 +102,48 @@ class TestRerankerClient:
|
||||
|
||||
class TestLLMClient:
|
||||
"""Tests for LLMClient."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_client(self, mock_httpx_client):
|
||||
"""Create an LLMClient with mocked HTTP."""
|
||||
from handler_base.clients.llm import LLMClient
|
||||
|
||||
|
||||
client = LLMClient()
|
||||
client._client = mock_httpx_client
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate(self, llm_client, mock_httpx_client):
|
||||
"""Test generating a response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "Hello! I'm an AI assistant."}}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20}
|
||||
"choices": [{"message": {"content": "Hello! I'm an AI assistant."}}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx_client.post.return_value = mock_response
|
||||
|
||||
|
||||
result = await llm_client.generate("Hello")
|
||||
|
||||
|
||||
assert result == "Hello! I'm an AI assistant."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_with_context(self, llm_client, mock_httpx_client):
|
||||
"""Test generating with RAG context."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{"message": {"content": "Based on the context..."}}
|
||||
],
|
||||
"usage": {}
|
||||
"choices": [{"message": {"content": "Based on the context..."}}],
|
||||
"usage": {},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx_client.post.return_value = mock_response
|
||||
|
||||
|
||||
result = await llm_client.generate(
|
||||
"What is Python?",
|
||||
context="Python is a programming language."
|
||||
"What is Python?", context="Python is a programming language."
|
||||
)
|
||||
|
||||
|
||||
assert "Based on the context" in result
|
||||
|
||||
|
||||
# Verify context was included in the request
|
||||
call_args = mock_httpx_client.post.call_args
|
||||
messages = call_args.kwargs["json"]["messages"]
|
||||
|
||||
@@ -1,46 +1,45 @@
|
||||
"""
|
||||
Unit tests for handler_base.config module.
|
||||
"""
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSettings:
|
||||
"""Tests for Settings configuration."""
|
||||
|
||||
|
||||
def test_default_settings(self, settings):
|
||||
"""Test that default settings are loaded correctly."""
|
||||
assert settings.service_name == "test-service"
|
||||
assert settings.service_version == "1.0.0-test"
|
||||
assert settings.otel_enabled is False
|
||||
|
||||
|
||||
def test_settings_from_env(self, monkeypatch):
|
||||
"""Test that settings can be loaded from environment variables."""
|
||||
monkeypatch.setenv("SERVICE_NAME", "env-service")
|
||||
monkeypatch.setenv("SERVICE_VERSION", "2.0.0")
|
||||
monkeypatch.setenv("NATS_URL", "nats://custom:4222")
|
||||
|
||||
|
||||
# Need to reimport to pick up env changes
|
||||
from handler_base.config import Settings
|
||||
|
||||
s = Settings()
|
||||
|
||||
|
||||
assert s.service_name == "env-service"
|
||||
assert s.service_version == "2.0.0"
|
||||
assert s.nats_url == "nats://custom:4222"
|
||||
|
||||
|
||||
def test_embeddings_settings(self):
|
||||
"""Test EmbeddingsSettings extends base correctly."""
|
||||
from handler_base.config import EmbeddingsSettings
|
||||
|
||||
|
||||
s = EmbeddingsSettings()
|
||||
assert hasattr(s, "embeddings_model")
|
||||
assert hasattr(s, "embeddings_batch_size")
|
||||
assert s.embeddings_model == "bge"
|
||||
|
||||
|
||||
def test_llm_settings(self):
|
||||
"""Test LLMSettings has expected defaults."""
|
||||
from handler_base.config import LLMSettings
|
||||
|
||||
|
||||
s = LLMSettings()
|
||||
assert s.llm_max_tokens == 2048
|
||||
assert s.llm_temperature == 0.7
|
||||
|
||||
@@ -1,101 +1,101 @@
|
||||
"""
|
||||
Unit tests for handler_base.health module.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from http.client import HTTPConnection
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestHealthServer:
|
||||
"""Tests for HealthServer."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def health_server(self, settings):
|
||||
"""Create a HealthServer instance."""
|
||||
from handler_base.health import HealthServer
|
||||
|
||||
|
||||
# Use a random high port to avoid conflicts
|
||||
settings.health_port = 18080
|
||||
return HealthServer(settings)
|
||||
|
||||
|
||||
def test_start_stop(self, health_server):
|
||||
"""Test starting and stopping the health server."""
|
||||
health_server.start()
|
||||
time.sleep(0.1) # Give server time to start
|
||||
|
||||
|
||||
# Verify server is running
|
||||
assert health_server._server is not None
|
||||
assert health_server._thread is not None
|
||||
assert health_server._thread.is_alive()
|
||||
|
||||
|
||||
health_server.stop()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
assert health_server._server is None
|
||||
|
||||
|
||||
def test_health_endpoint(self, health_server):
|
||||
"""Test the /health endpoint."""
|
||||
health_server.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
try:
|
||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
||||
conn.request("GET", "/health")
|
||||
response = conn.getresponse()
|
||||
|
||||
|
||||
assert response.status == 200
|
||||
data = json.loads(response.read().decode())
|
||||
assert data["status"] == "healthy"
|
||||
finally:
|
||||
conn.close()
|
||||
health_server.stop()
|
||||
|
||||
|
||||
def test_ready_endpoint_default(self, health_server):
|
||||
"""Test the /ready endpoint with no custom check."""
|
||||
health_server.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
try:
|
||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
||||
conn.request("GET", "/ready")
|
||||
response = conn.getresponse()
|
||||
|
||||
|
||||
assert response.status == 200
|
||||
data = json.loads(response.read().decode())
|
||||
assert data["status"] == "ready"
|
||||
finally:
|
||||
conn.close()
|
||||
health_server.stop()
|
||||
|
||||
|
||||
def test_ready_endpoint_with_check(self, settings):
|
||||
"""Test /ready endpoint with custom readiness check."""
|
||||
from handler_base.health import HealthServer
|
||||
|
||||
|
||||
ready_flag = [False] # Use list to allow mutation in closure
|
||||
|
||||
|
||||
async def check_ready():
|
||||
return ready_flag[0]
|
||||
|
||||
|
||||
settings.health_port = 18081
|
||||
server = HealthServer(settings, ready_check=check_ready)
|
||||
server.start()
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
try:
|
||||
conn = HTTPConnection("localhost", 18081, timeout=5)
|
||||
|
||||
|
||||
# Should be not ready initially
|
||||
conn.request("GET", "/ready")
|
||||
response = conn.getresponse()
|
||||
response.read() # Consume response body
|
||||
assert response.status == 503
|
||||
|
||||
|
||||
# Mark as ready
|
||||
ready_flag[0] = True
|
||||
|
||||
|
||||
# Need new connection after consuming response
|
||||
conn.close()
|
||||
conn = HTTPConnection("localhost", 18081, timeout=5)
|
||||
@@ -105,17 +105,17 @@ class TestHealthServer:
|
||||
finally:
|
||||
conn.close()
|
||||
server.stop()
|
||||
|
||||
|
||||
def test_404_for_unknown_path(self, health_server):
|
||||
"""Test that unknown paths return 404."""
|
||||
health_server.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
try:
|
||||
conn = HTTPConnection("localhost", 18080, timeout=5)
|
||||
conn.request("GET", "/unknown")
|
||||
response = conn.getresponse()
|
||||
|
||||
|
||||
assert response.status == 404
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@@ -1,48 +1,52 @@
|
||||
"""
|
||||
Unit tests for handler_base.nats_client module.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import msgpack
|
||||
import pytest
|
||||
|
||||
|
||||
class TestNATSClient:
|
||||
"""Tests for NATSClient."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nats_client(self, settings):
|
||||
"""Create a NATSClient instance."""
|
||||
from handler_base.nats_client import NATSClient
|
||||
|
||||
return NATSClient(settings)
|
||||
|
||||
|
||||
def test_init(self, nats_client, settings):
|
||||
"""Test NATSClient initialization."""
|
||||
assert nats_client.settings == settings
|
||||
assert nats_client._nc is None
|
||||
assert nats_client._js is None
|
||||
|
||||
|
||||
def test_decode_msgpack(self, nats_client):
|
||||
"""Test msgpack decoding."""
|
||||
data = {"query": "hello", "request_id": "123"}
|
||||
encoded = msgpack.packb(data, use_bin_type=True)
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.data = encoded
|
||||
|
||||
|
||||
result = nats_client.decode_msgpack(msg)
|
||||
assert result == data
|
||||
|
||||
|
||||
def test_decode_json(self, nats_client):
|
||||
"""Test JSON decoding."""
|
||||
import json
|
||||
|
||||
data = {"query": "hello"}
|
||||
|
||||
|
||||
msg = MagicMock()
|
||||
msg.data = json.dumps(data).encode()
|
||||
|
||||
|
||||
result = nats_client.decode_json(msg)
|
||||
assert result == data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(self, nats_client):
|
||||
"""Test NATS connection."""
|
||||
@@ -51,30 +55,30 @@ class TestNATSClient:
|
||||
mock_js = MagicMock()
|
||||
mock_nc.jetstream = MagicMock(return_value=mock_js) # Not async
|
||||
mock_nats.connect = AsyncMock(return_value=mock_nc)
|
||||
|
||||
|
||||
await nats_client.connect()
|
||||
|
||||
|
||||
assert nats_client._nc == mock_nc
|
||||
assert nats_client._js == mock_js
|
||||
mock_nats.connect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish(self, nats_client):
|
||||
"""Test publishing a message."""
|
||||
mock_nc = AsyncMock()
|
||||
nats_client._nc = mock_nc
|
||||
|
||||
|
||||
data = {"key": "value"}
|
||||
await nats_client.publish("test.subject", data)
|
||||
|
||||
|
||||
mock_nc.publish.assert_called_once()
|
||||
call_args = mock_nc.publish.call_args
|
||||
assert call_args.args[0] == "test.subject"
|
||||
|
||||
|
||||
# Verify msgpack encoding
|
||||
decoded = msgpack.unpackb(call_args.args[1], raw=False)
|
||||
assert decoded == data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe(self, nats_client):
|
||||
"""Test subscribing to a subject."""
|
||||
@@ -82,10 +86,10 @@ class TestNATSClient:
|
||||
mock_sub = MagicMock()
|
||||
mock_nc.subscribe = AsyncMock(return_value=mock_sub)
|
||||
nats_client._nc = mock_nc
|
||||
|
||||
|
||||
handler = AsyncMock()
|
||||
await nats_client.subscribe("test.subject", handler, queue="test-queue")
|
||||
|
||||
|
||||
mock_nc.subscribe.assert_called_once()
|
||||
call_kwargs = mock_nc.subscribe.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "test-queue"
|
||||
|
||||
Reference in New Issue
Block a user