fix: ruff formatting and allow-direct-references for handler-base dep
This commit is contained in:
@@ -8,6 +8,7 @@ Bridges NATS events to workflow engines using handler-base:
|
|||||||
3. Monitor execution and publish status updates
|
3. Monitor execution and publish status updates
|
||||||
4. Publish completion to "ai.pipeline.status.{request_id}"
|
4. Publish completion to "ai.pipeline.status.{request_id}"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -23,12 +24,12 @@ logger = logging.getLogger("pipeline-bridge")
|
|||||||
|
|
||||||
class PipelineSettings(Settings):
|
class PipelineSettings(Settings):
|
||||||
"""Pipeline bridge specific settings."""
|
"""Pipeline bridge specific settings."""
|
||||||
|
|
||||||
service_name: str = "pipeline-bridge"
|
service_name: str = "pipeline-bridge"
|
||||||
|
|
||||||
# Kubeflow Pipelines
|
# Kubeflow Pipelines
|
||||||
kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
||||||
|
|
||||||
# Argo Workflows
|
# Argo Workflows
|
||||||
argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
|
argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
|
||||||
argo_namespace: str = "ai-ml"
|
argo_namespace: str = "ai-ml"
|
||||||
@@ -67,14 +68,14 @@ PIPELINES = {
|
|||||||
class PipelineBridge(Handler):
|
class PipelineBridge(Handler):
|
||||||
"""
|
"""
|
||||||
Pipeline trigger handler.
|
Pipeline trigger handler.
|
||||||
|
|
||||||
Request format:
|
Request format:
|
||||||
{
|
{
|
||||||
"request_id": "uuid",
|
"request_id": "uuid",
|
||||||
"pipeline": "document-ingestion",
|
"pipeline": "document-ingestion",
|
||||||
"parameters": {"key": "value"}
|
"parameters": {"key": "value"}
|
||||||
}
|
}
|
||||||
|
|
||||||
Response format:
|
Response format:
|
||||||
{
|
{
|
||||||
"request_id": "uuid",
|
"request_id": "uuid",
|
||||||
@@ -83,7 +84,7 @@ class PipelineBridge(Handler):
|
|||||||
"engine": "argo|kubeflow"
|
"engine": "argo|kubeflow"
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.pipeline_settings = PipelineSettings()
|
self.pipeline_settings = PipelineSettings()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -91,36 +92,36 @@ class PipelineBridge(Handler):
|
|||||||
settings=self.pipeline_settings,
|
settings=self.pipeline_settings,
|
||||||
queue_group="pipeline-bridges",
|
queue_group="pipeline-bridges",
|
||||||
)
|
)
|
||||||
|
|
||||||
self._http: Optional[httpx.AsyncClient] = None
|
self._http: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
async def setup(self) -> None:
|
async def setup(self) -> None:
|
||||||
"""Initialize HTTP client."""
|
"""Initialize HTTP client."""
|
||||||
logger.info("Initializing pipeline bridge...")
|
logger.info("Initializing pipeline bridge...")
|
||||||
|
|
||||||
self._http = httpx.AsyncClient(timeout=60.0)
|
self._http = httpx.AsyncClient(timeout=60.0)
|
||||||
|
|
||||||
logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
|
logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
|
||||||
|
|
||||||
async def teardown(self) -> None:
|
async def teardown(self) -> None:
|
||||||
"""Clean up HTTP client."""
|
"""Clean up HTTP client."""
|
||||||
if self._http:
|
if self._http:
|
||||||
await self._http.aclose()
|
await self._http.aclose()
|
||||||
logger.info("Pipeline bridge closed")
|
logger.info("Pipeline bridge closed")
|
||||||
|
|
||||||
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
|
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
|
||||||
"""Handle pipeline trigger request."""
|
"""Handle pipeline trigger request."""
|
||||||
request_id = data.get("request_id", "unknown")
|
request_id = data.get("request_id", "unknown")
|
||||||
pipeline_name = data.get("pipeline", "")
|
pipeline_name = data.get("pipeline", "")
|
||||||
parameters = data.get("parameters", {})
|
parameters = data.get("parameters", {})
|
||||||
|
|
||||||
logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
|
logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
|
||||||
|
|
||||||
with create_span("pipeline.trigger") as span:
|
with create_span("pipeline.trigger") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("request.id", request_id)
|
span.set_attribute("request.id", request_id)
|
||||||
span.set_attribute("pipeline.name", pipeline_name)
|
span.set_attribute("pipeline.name", pipeline_name)
|
||||||
|
|
||||||
# Validate pipeline
|
# Validate pipeline
|
||||||
if pipeline_name not in PIPELINES:
|
if pipeline_name not in PIPELINES:
|
||||||
error = f"Unknown pipeline: {pipeline_name}"
|
error = f"Unknown pipeline: {pipeline_name}"
|
||||||
@@ -131,20 +132,18 @@ class PipelineBridge(Handler):
|
|||||||
"error": error,
|
"error": error,
|
||||||
"available_pipelines": list(PIPELINES.keys()),
|
"available_pipelines": list(PIPELINES.keys()),
|
||||||
}
|
}
|
||||||
|
|
||||||
pipeline = PIPELINES[pipeline_name]
|
pipeline = PIPELINES[pipeline_name]
|
||||||
engine = pipeline["engine"]
|
engine = pipeline["engine"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if engine == "argo":
|
if engine == "argo":
|
||||||
run_id = await self._submit_argo(
|
run_id = await self._submit_argo(pipeline["template"], parameters, request_id)
|
||||||
pipeline["template"], parameters, request_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
run_id = await self._submit_kubeflow(
|
run_id = await self._submit_kubeflow(
|
||||||
pipeline["pipeline_id"], parameters, request_id
|
pipeline["pipeline_id"], parameters, request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"request_id": request_id,
|
"request_id": request_id,
|
||||||
"status": "submitted",
|
"status": "submitted",
|
||||||
@@ -153,15 +152,13 @@ class PipelineBridge(Handler):
|
|||||||
"pipeline": pipeline_name,
|
"pipeline": pipeline_name,
|
||||||
"submitted_at": datetime.utcnow().isoformat(),
|
"submitted_at": datetime.utcnow().isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Publish status update
|
# Publish status update
|
||||||
await self.nats.publish(
|
await self.nats.publish(f"ai.pipeline.status.{request_id}", result)
|
||||||
f"ai.pipeline.status.{request_id}", result
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
|
logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Failed to submit pipeline {pipeline_name}")
|
logger.exception(f"Failed to submit pipeline {pipeline_name}")
|
||||||
return {
|
return {
|
||||||
@@ -169,15 +166,13 @@ class PipelineBridge(Handler):
|
|||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _submit_argo(
|
async def _submit_argo(self, template: str, parameters: dict, request_id: str) -> str:
|
||||||
self, template: str, parameters: dict, request_id: str
|
|
||||||
) -> str:
|
|
||||||
"""Submit workflow to Argo Workflows."""
|
"""Submit workflow to Argo Workflows."""
|
||||||
with create_span("pipeline.submit.argo") as span:
|
with create_span("pipeline.submit.argo") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("argo.template", template)
|
span.set_attribute("argo.template", template)
|
||||||
|
|
||||||
workflow = {
|
workflow = {
|
||||||
"apiVersion": "argoproj.io/v1alpha1",
|
"apiVersion": "argoproj.io/v1alpha1",
|
||||||
"kind": "Workflow",
|
"kind": "Workflow",
|
||||||
@@ -191,48 +186,40 @@ class PipelineBridge(Handler):
|
|||||||
"spec": {
|
"spec": {
|
||||||
"workflowTemplateRef": {"name": template},
|
"workflowTemplateRef": {"name": template},
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"parameters": [
|
"parameters": [{"name": k, "value": str(v)} for k, v in parameters.items()]
|
||||||
{"name": k, "value": str(v)}
|
|
||||||
for k, v in parameters.items()
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await self._http.post(
|
response = await self._http.post(
|
||||||
f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
|
f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
|
||||||
json={"workflow": workflow},
|
json={"workflow": workflow},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result["metadata"]["name"]
|
return result["metadata"]["name"]
|
||||||
|
|
||||||
async def _submit_kubeflow(
|
async def _submit_kubeflow(self, pipeline_id: str, parameters: dict, request_id: str) -> str:
|
||||||
self, pipeline_id: str, parameters: dict, request_id: str
|
|
||||||
) -> str:
|
|
||||||
"""Submit run to Kubeflow Pipelines."""
|
"""Submit run to Kubeflow Pipelines."""
|
||||||
with create_span("pipeline.submit.kubeflow") as span:
|
with create_span("pipeline.submit.kubeflow") as span:
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("kubeflow.pipeline_id", pipeline_id)
|
span.set_attribute("kubeflow.pipeline_id", pipeline_id)
|
||||||
|
|
||||||
run_request = {
|
run_request = {
|
||||||
"name": f"{pipeline_id}-{request_id[:8]}",
|
"name": f"{pipeline_id}-{request_id[:8]}",
|
||||||
"pipeline_spec": {
|
"pipeline_spec": {
|
||||||
"pipeline_id": pipeline_id,
|
"pipeline_id": pipeline_id,
|
||||||
"parameters": [
|
"parameters": [{"name": k, "value": str(v)} for k, v in parameters.items()],
|
||||||
{"name": k, "value": str(v)}
|
|
||||||
for k, v in parameters.items()
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await self._http.post(
|
response = await self._http.post(
|
||||||
f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
|
f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
|
||||||
json=run_request,
|
json=run_request,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result["run"]["id"]
|
return result["run"]["id"]
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ dev = [
|
|||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.metadata]
|
||||||
|
allow-direct-references = true
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["."]
|
packages = ["."]
|
||||||
only-include = ["pipeline_bridge.py"]
|
only-include = ["pipeline_bridge.py"]
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Pytest configuration and fixtures for pipeline-bridge tests.
|
Pytest configuration and fixtures for pipeline-bridge tests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for PipelineBridge handler.
|
Unit tests for PipelineBridge handler.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
@@ -9,48 +10,48 @@ from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES
|
|||||||
|
|
||||||
class TestPipelineSettings:
|
class TestPipelineSettings:
|
||||||
"""Tests for PipelineSettings configuration."""
|
"""Tests for PipelineSettings configuration."""
|
||||||
|
|
||||||
def test_default_settings(self):
|
def test_default_settings(self):
|
||||||
"""Test default settings values."""
|
"""Test default settings values."""
|
||||||
settings = PipelineSettings()
|
settings = PipelineSettings()
|
||||||
|
|
||||||
assert settings.service_name == "pipeline-bridge"
|
assert settings.service_name == "pipeline-bridge"
|
||||||
assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
||||||
assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746"
|
assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746"
|
||||||
assert settings.argo_namespace == "ai-ml"
|
assert settings.argo_namespace == "ai-ml"
|
||||||
|
|
||||||
def test_custom_settings(self):
|
def test_custom_settings(self):
|
||||||
"""Test custom settings."""
|
"""Test custom settings."""
|
||||||
settings = PipelineSettings(
|
settings = PipelineSettings(
|
||||||
kubeflow_host="http://custom-kubeflow:8888",
|
kubeflow_host="http://custom-kubeflow:8888",
|
||||||
argo_namespace="custom-ns",
|
argo_namespace="custom-ns",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert settings.kubeflow_host == "http://custom-kubeflow:8888"
|
assert settings.kubeflow_host == "http://custom-kubeflow:8888"
|
||||||
assert settings.argo_namespace == "custom-ns"
|
assert settings.argo_namespace == "custom-ns"
|
||||||
|
|
||||||
|
|
||||||
class TestPipelineDefinitions:
|
class TestPipelineDefinitions:
|
||||||
"""Tests for pipeline definitions."""
|
"""Tests for pipeline definitions."""
|
||||||
|
|
||||||
def test_required_pipelines_exist(self):
|
def test_required_pipelines_exist(self):
|
||||||
"""Test that required pipelines are defined."""
|
"""Test that required pipelines are defined."""
|
||||||
required = ["document-ingestion", "batch-inference", "rag-query", "voice-pipeline"]
|
required = ["document-ingestion", "batch-inference", "rag-query", "voice-pipeline"]
|
||||||
for name in required:
|
for name in required:
|
||||||
assert name in PIPELINES, f"Pipeline {name} should be defined"
|
assert name in PIPELINES, f"Pipeline {name} should be defined"
|
||||||
|
|
||||||
def test_argo_pipelines_have_template(self):
|
def test_argo_pipelines_have_template(self):
|
||||||
"""Test Argo pipelines have template field."""
|
"""Test Argo pipelines have template field."""
|
||||||
for name, config in PIPELINES.items():
|
for name, config in PIPELINES.items():
|
||||||
if config["engine"] == "argo":
|
if config["engine"] == "argo":
|
||||||
assert "template" in config, f"Argo pipeline {name} missing template"
|
assert "template" in config, f"Argo pipeline {name} missing template"
|
||||||
|
|
||||||
def test_kubeflow_pipelines_have_pipeline_id(self):
|
def test_kubeflow_pipelines_have_pipeline_id(self):
|
||||||
"""Test Kubeflow pipelines have pipeline_id field."""
|
"""Test Kubeflow pipelines have pipeline_id field."""
|
||||||
for name, config in PIPELINES.items():
|
for name, config in PIPELINES.items():
|
||||||
if config["engine"] == "kubeflow":
|
if config["engine"] == "kubeflow":
|
||||||
assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id"
|
assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id"
|
||||||
|
|
||||||
def test_all_pipelines_have_description(self):
|
def test_all_pipelines_have_description(self):
|
||||||
"""Test all pipelines have descriptions."""
|
"""Test all pipelines have descriptions."""
|
||||||
for name, config in PIPELINES.items():
|
for name, config in PIPELINES.items():
|
||||||
@@ -59,7 +60,7 @@ class TestPipelineDefinitions:
|
|||||||
|
|
||||||
class TestPipelineBridge:
|
class TestPipelineBridge:
|
||||||
"""Tests for PipelineBridge handler."""
|
"""Tests for PipelineBridge handler."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def handler(self):
|
def handler(self):
|
||||||
"""Create handler with mocked HTTP client."""
|
"""Create handler with mocked HTTP client."""
|
||||||
@@ -67,13 +68,13 @@ class TestPipelineBridge:
|
|||||||
handler._http = AsyncMock()
|
handler._http = AsyncMock()
|
||||||
handler.nats = AsyncMock()
|
handler.nats = AsyncMock()
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
def test_init(self, handler):
|
def test_init(self, handler):
|
||||||
"""Test handler initialization."""
|
"""Test handler initialization."""
|
||||||
assert handler.subject == "ai.pipeline.trigger"
|
assert handler.subject == "ai.pipeline.trigger"
|
||||||
assert handler.queue_group == "pipeline-bridges"
|
assert handler.queue_group == "pipeline-bridges"
|
||||||
assert handler.pipeline_settings.service_name == "pipeline-bridge"
|
assert handler.pipeline_settings.service_name == "pipeline-bridge"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_unknown_pipeline(
|
async def test_handle_unknown_pipeline(
|
||||||
self,
|
self,
|
||||||
@@ -83,12 +84,12 @@ class TestPipelineBridge:
|
|||||||
):
|
):
|
||||||
"""Test handling unknown pipeline."""
|
"""Test handling unknown pipeline."""
|
||||||
result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)
|
result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)
|
||||||
|
|
||||||
assert result["status"] == "error"
|
assert result["status"] == "error"
|
||||||
assert "Unknown pipeline" in result["error"]
|
assert "Unknown pipeline" in result["error"]
|
||||||
assert "available_pipelines" in result
|
assert "available_pipelines" in result
|
||||||
assert "document-ingestion" in result["available_pipelines"]
|
assert "document-ingestion" in result["available_pipelines"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_argo_pipeline(
|
async def test_handle_argo_pipeline(
|
||||||
self,
|
self,
|
||||||
@@ -103,20 +104,20 @@ class TestPipelineBridge:
|
|||||||
mock_response.json.return_value = mock_argo_response
|
mock_response.json.return_value = mock_argo_response
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
handler._http.post.return_value = mock_response
|
handler._http.post.return_value = mock_response
|
||||||
|
|
||||||
result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
|
result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
|
||||||
|
|
||||||
assert result["status"] == "submitted"
|
assert result["status"] == "submitted"
|
||||||
assert result["engine"] == "argo"
|
assert result["engine"] == "argo"
|
||||||
assert result["run_id"] == "document-ingestion-abc123"
|
assert result["run_id"] == "document-ingestion-abc123"
|
||||||
assert result["pipeline"] == "document-ingestion"
|
assert result["pipeline"] == "document-ingestion"
|
||||||
assert "submitted_at" in result
|
assert "submitted_at" in result
|
||||||
|
|
||||||
# Verify API call
|
# Verify API call
|
||||||
handler._http.post.assert_called_once()
|
handler._http.post.assert_called_once()
|
||||||
call_args = handler._http.post.call_args
|
call_args = handler._http.post.call_args
|
||||||
assert "argo-server" in str(call_args)
|
assert "argo-server" in str(call_args)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_kubeflow_pipeline(
|
async def test_handle_kubeflow_pipeline(
|
||||||
self,
|
self,
|
||||||
@@ -131,19 +132,19 @@ class TestPipelineBridge:
|
|||||||
mock_response.json.return_value = mock_kubeflow_response
|
mock_response.json.return_value = mock_kubeflow_response
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
handler._http.post.return_value = mock_response
|
handler._http.post.return_value = mock_response
|
||||||
|
|
||||||
result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)
|
result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)
|
||||||
|
|
||||||
assert result["status"] == "submitted"
|
assert result["status"] == "submitted"
|
||||||
assert result["engine"] == "kubeflow"
|
assert result["engine"] == "kubeflow"
|
||||||
assert result["run_id"] == "run-xyz-789"
|
assert result["run_id"] == "run-xyz-789"
|
||||||
assert result["pipeline"] == "rag-query"
|
assert result["pipeline"] == "rag-query"
|
||||||
|
|
||||||
# Verify API call
|
# Verify API call
|
||||||
handler._http.post.assert_called_once()
|
handler._http.post.assert_called_once()
|
||||||
call_args = handler._http.post.call_args
|
call_args = handler._http.post.call_args
|
||||||
assert "ml-pipeline" in str(call_args)
|
assert "ml-pipeline" in str(call_args)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_api_error(
|
async def test_handle_api_error(
|
||||||
self,
|
self,
|
||||||
@@ -153,12 +154,12 @@ class TestPipelineBridge:
|
|||||||
):
|
):
|
||||||
"""Test handling API errors."""
|
"""Test handling API errors."""
|
||||||
handler._http.post.side_effect = Exception("Connection refused")
|
handler._http.post.side_effect = Exception("Connection refused")
|
||||||
|
|
||||||
result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
|
result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
|
||||||
|
|
||||||
assert result["status"] == "error"
|
assert result["status"] == "error"
|
||||||
assert "Connection refused" in result["error"]
|
assert "Connection refused" in result["error"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_publishes_status_update(
|
async def test_publishes_status_update(
|
||||||
self,
|
self,
|
||||||
@@ -172,40 +173,40 @@ class TestPipelineBridge:
|
|||||||
mock_response.json.return_value = mock_argo_response
|
mock_response.json.return_value = mock_argo_response
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
handler._http.post.return_value = mock_response
|
handler._http.post.return_value = mock_response
|
||||||
|
|
||||||
await handler.handle_message(mock_nats_message, argo_pipeline_request)
|
await handler.handle_message(mock_nats_message, argo_pipeline_request)
|
||||||
|
|
||||||
handler.nats.publish.assert_called_once()
|
handler.nats.publish.assert_called_once()
|
||||||
call_args = handler.nats.publish.call_args
|
call_args = handler.nats.publish.call_args
|
||||||
assert "ai.pipeline.status.test-request-123" in str(call_args)
|
assert "ai.pipeline.status.test-request-123" in str(call_args)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_http_client(self):
|
async def test_setup_creates_http_client(self):
|
||||||
"""Test that setup initializes HTTP client."""
|
"""Test that setup initializes HTTP client."""
|
||||||
with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
|
with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
|
||||||
handler = PipelineBridge()
|
handler = PipelineBridge()
|
||||||
await handler.setup()
|
await handler.setup()
|
||||||
|
|
||||||
mock_client.assert_called_once()
|
mock_client.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_teardown_closes_http_client(self, handler):
|
async def test_teardown_closes_http_client(self, handler):
|
||||||
"""Test that teardown closes HTTP client."""
|
"""Test that teardown closes HTTP client."""
|
||||||
await handler.teardown()
|
await handler.teardown()
|
||||||
|
|
||||||
handler._http.aclose.assert_called_once()
|
handler._http.aclose.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestArgoSubmission:
|
class TestArgoSubmission:
|
||||||
"""Tests for Argo workflow submission."""
|
"""Tests for Argo workflow submission."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def handler(self):
|
def handler(self):
|
||||||
"""Create handler with mocked HTTP client."""
|
"""Create handler with mocked HTTP client."""
|
||||||
handler = PipelineBridge()
|
handler = PipelineBridge()
|
||||||
handler._http = AsyncMock()
|
handler._http = AsyncMock()
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_argo_workflow_structure(
|
async def test_argo_workflow_structure(
|
||||||
self,
|
self,
|
||||||
@@ -217,17 +218,17 @@ class TestArgoSubmission:
|
|||||||
mock_response.json.return_value = mock_argo_response
|
mock_response.json.return_value = mock_argo_response
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
handler._http.post.return_value = mock_response
|
handler._http.post.return_value = mock_response
|
||||||
|
|
||||||
await handler._submit_argo(
|
await handler._submit_argo(
|
||||||
template="document-ingestion",
|
template="document-ingestion",
|
||||||
parameters={"key": "value"},
|
parameters={"key": "value"},
|
||||||
request_id="test-123",
|
request_id="test-123",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify workflow structure
|
# Verify workflow structure
|
||||||
call_kwargs = handler._http.post.call_args.kwargs
|
call_kwargs = handler._http.post.call_args.kwargs
|
||||||
workflow = call_kwargs["json"]["workflow"]
|
workflow = call_kwargs["json"]["workflow"]
|
||||||
|
|
||||||
assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
|
assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
|
||||||
assert workflow["kind"] == "Workflow"
|
assert workflow["kind"] == "Workflow"
|
||||||
assert "workflowTemplateRef" in workflow["spec"]
|
assert "workflowTemplateRef" in workflow["spec"]
|
||||||
@@ -237,14 +238,14 @@ class TestArgoSubmission:
|
|||||||
|
|
||||||
class TestKubeflowSubmission:
|
class TestKubeflowSubmission:
|
||||||
"""Tests for Kubeflow pipeline submission."""
|
"""Tests for Kubeflow pipeline submission."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def handler(self):
|
def handler(self):
|
||||||
"""Create handler with mocked HTTP client."""
|
"""Create handler with mocked HTTP client."""
|
||||||
handler = PipelineBridge()
|
handler = PipelineBridge()
|
||||||
handler._http = AsyncMock()
|
handler._http = AsyncMock()
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_kubeflow_run_structure(
|
async def test_kubeflow_run_structure(
|
||||||
self,
|
self,
|
||||||
@@ -256,16 +257,16 @@ class TestKubeflowSubmission:
|
|||||||
mock_response.json.return_value = mock_kubeflow_response
|
mock_response.json.return_value = mock_kubeflow_response
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
handler._http.post.return_value = mock_response
|
handler._http.post.return_value = mock_response
|
||||||
|
|
||||||
await handler._submit_kubeflow(
|
await handler._submit_kubeflow(
|
||||||
pipeline_id="rag-pipeline",
|
pipeline_id="rag-pipeline",
|
||||||
parameters={"query": "test"},
|
parameters={"query": "test"},
|
||||||
request_id="test-456",
|
request_id="test-456",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify run request structure
|
# Verify run request structure
|
||||||
call_kwargs = handler._http.post.call_args.kwargs
|
call_kwargs = handler._http.post.call_args.kwargs
|
||||||
run_request = call_kwargs["json"]
|
run_request = call_kwargs["json"]
|
||||||
|
|
||||||
assert "rag-pipeline" in run_request["name"]
|
assert "rag-pipeline" in run_request["name"]
|
||||||
assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"
|
assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"
|
||||||
|
|||||||
Reference in New Issue
Block a user