273 lines
9.1 KiB
Python
273 lines
9.1 KiB
Python
"""Unit tests for the PipelineBridge handler."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES
|
|
|
|
|
|
class TestPipelineSettings:
    """Checks for the PipelineSettings configuration object."""

    def test_default_settings(self):
        """A bare PipelineSettings() carries the in-cluster default endpoints."""
        cfg = PipelineSettings()

        expected_defaults = {
            "service_name": "pipeline-bridge",
            "kubeflow_host": "http://ml-pipeline.kubeflow.svc.cluster.local:8888",
            "argo_host": "http://argo-server.argo.svc.cluster.local:2746",
            "argo_namespace": "ai-ml",
        }
        # Checked in declaration order, one attribute at a time.
        for attr, wanted in expected_defaults.items():
            assert getattr(cfg, attr) == wanted

    def test_custom_settings(self):
        """Keyword arguments passed to the constructor override the defaults."""
        overrides = {
            "kubeflow_host": "http://custom-kubeflow:8888",
            "argo_namespace": "custom-ns",
        }
        cfg = PipelineSettings(**overrides)

        for attr, wanted in overrides.items():
            assert getattr(cfg, attr) == wanted
|
|
|
|
|
|
class TestPipelineDefinitions:
    """Sanity checks on the static PIPELINES registry."""

    def test_required_pipelines_exist(self):
        """Every pipeline the bridge must expose is registered in PIPELINES."""
        for pipeline in ("document-ingestion", "batch-inference", "rag-query", "voice-pipeline"):
            assert pipeline in PIPELINES, f"Pipeline {pipeline} should be defined"

    def test_argo_pipelines_have_template(self):
        """Entries whose engine is "argo" declare a "template" key."""
        argo_entries = (
            (name, config) for name, config in PIPELINES.items() if config["engine"] == "argo"
        )
        for name, config in argo_entries:
            assert "template" in config, f"Argo pipeline {name} missing template"

    def test_kubeflow_pipelines_have_pipeline_id(self):
        """Entries whose engine is "kubeflow" declare a "pipeline_id" key."""
        for name, config in PIPELINES.items():
            if config["engine"] != "kubeflow":
                continue
            assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id"

    def test_all_pipelines_have_description(self):
        """Every registry entry carries a "description" key."""
        for name in PIPELINES:
            assert "description" in PIPELINES[name], f"Pipeline {name} missing description"
|
|
|
|
|
|
class TestPipelineBridge:
    """Tests for the PipelineBridge handler.

    The HTTP client (``_http``) and the NATS connection (``nats``) are replaced
    by ``AsyncMock`` objects, so no network traffic occurs. Fixtures such as
    ``mock_nats_message`` and ``argo_pipeline_request`` are not defined in this
    file; presumably they live in a ``conftest.py`` (verify).
    """

    @pytest.fixture
    def handler(self):
        """Create a PipelineBridge with mocked HTTP and NATS clients."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        handler.nats = AsyncMock()
        return handler

    @staticmethod
    def _json_response(payload):
        """Return a fake HTTP response whose ``json()`` yields *payload*.

        ``raise_for_status`` is a no-op, i.e. the request "succeeded". Both are
        plain MagicMocks, matching httpx's synchronous ``Response.json`` and
        ``Response.raise_for_status``.
        """
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    def test_init(self, handler):
        """Handler exposes the expected subject, queue group and settings."""
        assert handler.subject == "ai.pipeline.trigger"
        assert handler.queue_group == "pipeline-bridges"
        assert handler.pipeline_settings.service_name == "pipeline-bridge"

    @pytest.mark.asyncio
    async def test_handle_unknown_pipeline(
        self,
        handler,
        mock_nats_message,
        unknown_pipeline_request,
    ):
        """An unknown pipeline yields an error that lists the valid pipelines."""
        result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)

        assert result["status"] == "error"
        assert "Unknown pipeline" in result["error"]
        assert "available_pipelines" in result
        assert "document-ingestion" in result["available_pipelines"]
        # A rejected request must not be submitted to any backend.
        handler._http.post.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_argo_pipeline(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """An Argo pipeline request is submitted and reported as such."""
        handler._http.post.return_value = self._json_response(mock_argo_response)

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "argo"
        assert result["run_id"] == "document-ingestion-abc123"
        assert result["pipeline"] == "document-ingestion"
        assert "submitted_at" in result

        # assert_awaited_once (not assert_called_once): calling an AsyncMock
        # without awaiting it still counts as a call, which would hide a
        # missing ``await`` in the handler.
        handler._http.post.assert_awaited_once()
        assert "argo-server" in str(handler._http.post.call_args)

    @pytest.mark.asyncio
    async def test_handle_kubeflow_pipeline(
        self,
        handler,
        mock_nats_message,
        kubeflow_pipeline_request,
        mock_kubeflow_response,
    ):
        """A Kubeflow pipeline request is submitted and reported as such."""
        handler._http.post.return_value = self._json_response(mock_kubeflow_response)

        result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "kubeflow"
        assert result["run_id"] == "run-xyz-789"
        assert result["pipeline"] == "rag-query"

        handler._http.post.assert_awaited_once()
        assert "ml-pipeline" in str(handler._http.post.call_args)

    @pytest.mark.asyncio
    async def test_handle_api_error(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
    ):
        """A failing backend call is turned into an error result."""
        handler._http.post.side_effect = Exception("Connection refused")

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "error"
        assert "Connection refused" in result["error"]

    @pytest.mark.asyncio
    async def test_publishes_status_update(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """A successful submission publishes a per-request status message."""
        handler._http.post.return_value = self._json_response(mock_argo_response)

        await handler.handle_message(mock_nats_message, argo_pipeline_request)

        # Kept as assert_called_once: whether the NATS publish is awaited
        # depends on the client wrapper, which is not visible here.
        handler.nats.publish.assert_called_once()
        # Assumes argo_pipeline_request carries request_id "test-request-123".
        assert "ai.pipeline.status.test-request-123" in str(handler.nats.publish.call_args)

    @pytest.mark.asyncio
    async def test_setup_creates_http_client(self):
        """setup() constructs an httpx.AsyncClient."""
        with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
            handler = PipelineBridge()
            await handler.setup()

            mock_client.assert_called_once()

    @pytest.mark.asyncio
    async def test_teardown_closes_http_client(self, handler):
        """teardown() awaits the HTTP client's aclose()."""
        await handler.teardown()

        # aclose is a coroutine on httpx.AsyncClient; it must be awaited, not
        # just called, for the connection pool to actually close.
        handler._http.aclose.assert_awaited_once()
|
|
|
|
|
|
class TestArgoSubmission:
    """Shape of the Workflow manifest that PipelineBridge posts to Argo."""

    @pytest.fixture
    def handler(self):
        """PipelineBridge whose HTTP client is an AsyncMock (no network I/O)."""
        bridge = PipelineBridge()
        bridge._http = AsyncMock()
        return bridge

    @pytest.mark.asyncio
    async def test_argo_workflow_structure(
        self,
        handler,
        mock_argo_response,
    ):
        """The posted Workflow references the template and labels the request."""
        canned = MagicMock()
        canned.json.return_value = mock_argo_response
        canned.raise_for_status = MagicMock()
        handler._http.post.return_value = canned

        await handler._submit_argo(
            template="document-ingestion",
            parameters={"key": "value"},
            request_id="test-123",
        )

        # The manifest travels under the "workflow" key of the JSON body.
        last_call = handler._http.post.call_args
        workflow = last_call.kwargs["json"]["workflow"]

        assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
        assert workflow["kind"] == "Workflow"

        spec = workflow["spec"]
        assert "workflowTemplateRef" in spec
        assert spec["workflowTemplateRef"]["name"] == "document-ingestion"
        assert workflow["metadata"]["labels"]["request-id"] == "test-123"
|
|
|
|
|
|
class TestKubeflowSubmission:
    """Shape of the run request that PipelineBridge posts to Kubeflow."""

    @pytest.fixture
    def handler(self):
        """PipelineBridge whose HTTP client is an AsyncMock (no network I/O)."""
        bridge = PipelineBridge()
        bridge._http = AsyncMock()
        return bridge

    @pytest.mark.asyncio
    async def test_kubeflow_run_structure(
        self,
        handler,
        mock_kubeflow_response,
    ):
        """The run request's name contains the pipeline id and its spec references it."""
        canned = MagicMock()
        canned.json.return_value = mock_kubeflow_response
        canned.raise_for_status = MagicMock()
        handler._http.post.return_value = canned

        await handler._submit_kubeflow(
            pipeline_id="rag-pipeline",
            parameters={"query": "test"},
            request_id="test-456",
        )

        # The whole JSON body of the POST is the run request.
        last_call = handler._http.post.call_args
        posted = last_call.kwargs["json"]

        assert "rag-pipeline" in posted["name"]
        assert posted["pipeline_spec"]["pipeline_id"] == "rag-pipeline"
|