Files
pipeline-bridge/tests/test_pipeline_bridge.py

272 lines
9.4 KiB
Python

"""
Unit tests for PipelineBridge handler.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES
class TestPipelineSettings:
    """Verify default and overridden values of PipelineSettings."""

    def test_default_settings(self):
        """A bare PipelineSettings() exposes the expected in-cluster defaults."""
        # NOTE(review): if PipelineSettings reads environment variables (e.g. a
        # pydantic BaseSettings), exported env vars could override these
        # defaults and make this test environment-dependent — confirm.
        expected = {
            "service_name": "pipeline-bridge",
            "kubeflow_host": "http://ml-pipeline.kubeflow.svc.cluster.local:8888",
            "argo_host": "http://argo-server.argo.svc.cluster.local:2746",
            "argo_namespace": "ai-ml",
        }
        defaults = PipelineSettings()
        for field_name, value in expected.items():
            assert getattr(defaults, field_name) == value

    def test_custom_settings(self):
        """Keyword arguments passed to the constructor replace the defaults."""
        overrides = {
            "kubeflow_host": "http://custom-kubeflow:8888",
            "argo_namespace": "custom-ns",
        }
        custom = PipelineSettings(**overrides)
        for field_name, value in overrides.items():
            assert getattr(custom, field_name) == value
class TestPipelineDefinitions:
    """Sanity checks on the static PIPELINES registry."""

    def test_required_pipelines_exist(self):
        """Every pipeline the bridge must expose has an entry in PIPELINES."""
        for pipeline_name in (
            "document-ingestion",
            "batch-inference",
            "rag-query",
            "voice-pipeline",
        ):
            assert pipeline_name in PIPELINES, f"Pipeline {pipeline_name} should be defined"

    def test_argo_pipelines_have_template(self):
        """Entries whose engine is "argo" carry a "template" key."""
        argo_entries = (
            (name, cfg) for name, cfg in PIPELINES.items() if cfg["engine"] == "argo"
        )
        for name, cfg in argo_entries:
            assert "template" in cfg, f"Argo pipeline {name} missing template"

    def test_kubeflow_pipelines_have_pipeline_id(self):
        """Entries whose engine is "kubeflow" carry a "pipeline_id" key."""
        kubeflow_entries = (
            (name, cfg) for name, cfg in PIPELINES.items() if cfg["engine"] == "kubeflow"
        )
        for name, cfg in kubeflow_entries:
            assert "pipeline_id" in cfg, f"Kubeflow pipeline {name} missing pipeline_id"

    def test_all_pipelines_have_description(self):
        """Every registry entry, regardless of engine, has a "description" key."""
        for name in PIPELINES:
            assert "description" in PIPELINES[name], f"Pipeline {name} missing description"
class TestPipelineBridge:
    """Tests for the PipelineBridge message handler."""

    @staticmethod
    def _http_response(payload):
        """Build a fake HTTP response whose ``json()`` returns *payload*.

        ``json`` and ``raise_for_status`` are plain (synchronous) mocks; assign
        the result as ``handler._http.post.return_value`` so that awaiting the
        AsyncMock ``post`` yields this object.
        """
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    @pytest.fixture
    def handler(self):
        """Create a handler whose HTTP client and NATS connection are AsyncMocks."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        handler.nats = AsyncMock()
        return handler

    def test_init(self, handler):
        """Handler is wired to the trigger subject, queue group and service name."""
        assert handler.subject == "ai.pipeline.trigger"
        assert handler.queue_group == "pipeline-bridges"
        assert handler.pipeline_settings.service_name == "pipeline-bridge"

    @pytest.mark.asyncio
    async def test_handle_unknown_pipeline(
        self,
        handler,
        mock_nats_message,
        unknown_pipeline_request,
    ):
        """An unknown pipeline name yields an error listing the known pipelines."""
        result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)

        assert result["status"] == "error"
        assert "Unknown pipeline" in result["error"]
        assert "available_pipelines" in result
        assert "document-ingestion" in result["available_pipelines"]

    @pytest.mark.asyncio
    async def test_handle_argo_pipeline(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """An Argo-backed request is submitted and reported as "submitted"."""
        handler._http.post.return_value = self._http_response(mock_argo_response)

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "argo"
        assert result["run_id"] == "document-ingestion-abc123"
        assert result["pipeline"] == "document-ingestion"
        assert "submitted_at" in result

        # Verify the API call was actually awaited: merely creating the
        # coroutine (assert_called_*) would pass even if no request was sent.
        handler._http.post.assert_awaited_once()
        call_args = handler._http.post.call_args
        assert "argo-server" in str(call_args)

    @pytest.mark.asyncio
    async def test_handle_kubeflow_pipeline(
        self,
        handler,
        mock_nats_message,
        kubeflow_pipeline_request,
        mock_kubeflow_response,
    ):
        """A Kubeflow-backed request is submitted and reported as "submitted"."""
        handler._http.post.return_value = self._http_response(mock_kubeflow_response)

        result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "kubeflow"
        assert result["run_id"] == "run-xyz-789"
        assert result["pipeline"] == "rag-query"

        # Verify the API call was actually awaited, not just created.
        handler._http.post.assert_awaited_once()
        call_args = handler._http.post.call_args
        assert "ml-pipeline" in str(call_args)

    @pytest.mark.asyncio
    async def test_handle_api_error(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
    ):
        """An exception from the HTTP client is reported as an error result."""
        handler._http.post.side_effect = Exception("Connection refused")

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "error"
        assert "Connection refused" in result["error"]

    @pytest.mark.asyncio
    async def test_publishes_status_update(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """A status message is published on the per-request status subject."""
        handler._http.post.return_value = self._http_response(mock_argo_response)

        await handler.handle_message(mock_nats_message, argo_pipeline_request)

        # NOTE(review): kept as assert_called_once — whether the real
        # publish is awaited cannot be confirmed from this file.
        handler.nats.publish.assert_called_once()
        call_args = handler.nats.publish.call_args
        assert "ai.pipeline.status.test-request-123" in str(call_args)

    @pytest.mark.asyncio
    async def test_setup_creates_http_client(self):
        """setup() constructs exactly one httpx.AsyncClient."""
        with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
            handler = PipelineBridge()
            await handler.setup()
            mock_client.assert_called_once()

    @pytest.mark.asyncio
    async def test_teardown_closes_http_client(self, handler):
        """teardown() awaits the HTTP client's aclose()."""
        await handler.teardown()
        # aclose is a coroutine on an async client; it must be awaited to
        # actually release connections, so check the await, not the call.
        handler._http.aclose.assert_awaited_once()
class TestArgoSubmission:
    """Tests for the request body built by ``PipelineBridge._submit_argo``."""

    @pytest.fixture
    def handler(self):
        """Create a handler whose HTTP client is an AsyncMock."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        return handler

    @pytest.mark.asyncio
    async def test_argo_workflow_structure(
        self,
        handler,
        mock_argo_response,
    ):
        """The POSTed workflow references the template and labels the request id."""
        mock_response = MagicMock()
        mock_response.json.return_value = mock_argo_response
        mock_response.raise_for_status = MagicMock()
        handler._http.post.return_value = mock_response

        await handler._submit_argo(
            template="document-ingestion",
            parameters={"key": "value"},
            request_id="test-123",
        )

        # Fail with a clear mock assertion (rather than an AttributeError on
        # ``None.kwargs`` below) if the POST was never made or never awaited.
        handler._http.post.assert_awaited_once()

        # Verify workflow structure
        call_kwargs = handler._http.post.call_args.kwargs
        workflow = call_kwargs["json"]["workflow"]
        assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
        assert workflow["kind"] == "Workflow"
        assert "workflowTemplateRef" in workflow["spec"]
        assert workflow["spec"]["workflowTemplateRef"]["name"] == "document-ingestion"
        assert workflow["metadata"]["labels"]["request-id"] == "test-123"
class TestKubeflowSubmission:
    """Tests for the request body built by ``PipelineBridge._submit_kubeflow``."""

    @pytest.fixture
    def handler(self):
        """Create a handler whose HTTP client is an AsyncMock."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        return handler

    @pytest.mark.asyncio
    async def test_kubeflow_run_structure(
        self,
        handler,
        mock_kubeflow_response,
    ):
        """The POSTed run request names and references the given pipeline id."""
        mock_response = MagicMock()
        mock_response.json.return_value = mock_kubeflow_response
        mock_response.raise_for_status = MagicMock()
        handler._http.post.return_value = mock_response

        await handler._submit_kubeflow(
            pipeline_id="rag-pipeline",
            parameters={"query": "test"},
            request_id="test-456",
        )

        # Fail with a clear mock assertion (rather than an AttributeError on
        # ``None.kwargs`` below) if the POST was never made or never awaited.
        handler._http.post.assert_awaited_once()

        # Verify run request structure
        call_kwargs = handler._http.post.call_args.kwargs
        run_request = call_kwargs["json"]
        assert "rag-pipeline" in run_request["name"]
        assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"