fix: ruff formatting and allow-direct-references for handler-base dep

This commit is contained in:
2026-02-02 08:44:51 -05:00
parent 7c7a147db6
commit 4fbde95eb7
4 changed files with 82 additions and 90 deletions

View File

@@ -8,6 +8,7 @@ Bridges NATS events to workflow engines using handler-base:
3. Monitor execution and publish status updates 3. Monitor execution and publish status updates
4. Publish completion to "ai.pipeline.status.{request_id}" 4. Publish completion to "ai.pipeline.status.{request_id}"
""" """
import logging import logging
from typing import Any, Optional from typing import Any, Optional
from datetime import datetime from datetime import datetime
@@ -23,12 +24,12 @@ logger = logging.getLogger("pipeline-bridge")
class PipelineSettings(Settings): class PipelineSettings(Settings):
"""Pipeline bridge specific settings.""" """Pipeline bridge specific settings."""
service_name: str = "pipeline-bridge" service_name: str = "pipeline-bridge"
# Kubeflow Pipelines # Kubeflow Pipelines
kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888" kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
# Argo Workflows # Argo Workflows
argo_host: str = "http://argo-server.argo.svc.cluster.local:2746" argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
argo_namespace: str = "ai-ml" argo_namespace: str = "ai-ml"
@@ -67,14 +68,14 @@ PIPELINES = {
class PipelineBridge(Handler): class PipelineBridge(Handler):
""" """
Pipeline trigger handler. Pipeline trigger handler.
Request format: Request format:
{ {
"request_id": "uuid", "request_id": "uuid",
"pipeline": "document-ingestion", "pipeline": "document-ingestion",
"parameters": {"key": "value"} "parameters": {"key": "value"}
} }
Response format: Response format:
{ {
"request_id": "uuid", "request_id": "uuid",
@@ -83,7 +84,7 @@ class PipelineBridge(Handler):
"engine": "argo|kubeflow" "engine": "argo|kubeflow"
} }
""" """
def __init__(self): def __init__(self):
self.pipeline_settings = PipelineSettings() self.pipeline_settings = PipelineSettings()
super().__init__( super().__init__(
@@ -91,36 +92,36 @@ class PipelineBridge(Handler):
settings=self.pipeline_settings, settings=self.pipeline_settings,
queue_group="pipeline-bridges", queue_group="pipeline-bridges",
) )
self._http: Optional[httpx.AsyncClient] = None self._http: Optional[httpx.AsyncClient] = None
async def setup(self) -> None: async def setup(self) -> None:
"""Initialize HTTP client.""" """Initialize HTTP client."""
logger.info("Initializing pipeline bridge...") logger.info("Initializing pipeline bridge...")
self._http = httpx.AsyncClient(timeout=60.0) self._http = httpx.AsyncClient(timeout=60.0)
logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}") logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
async def teardown(self) -> None: async def teardown(self) -> None:
"""Clean up HTTP client.""" """Clean up HTTP client."""
if self._http: if self._http:
await self._http.aclose() await self._http.aclose()
logger.info("Pipeline bridge closed") logger.info("Pipeline bridge closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle pipeline trigger request.""" """Handle pipeline trigger request."""
request_id = data.get("request_id", "unknown") request_id = data.get("request_id", "unknown")
pipeline_name = data.get("pipeline", "") pipeline_name = data.get("pipeline", "")
parameters = data.get("parameters", {}) parameters = data.get("parameters", {})
logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}") logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
with create_span("pipeline.trigger") as span: with create_span("pipeline.trigger") as span:
if span: if span:
span.set_attribute("request.id", request_id) span.set_attribute("request.id", request_id)
span.set_attribute("pipeline.name", pipeline_name) span.set_attribute("pipeline.name", pipeline_name)
# Validate pipeline # Validate pipeline
if pipeline_name not in PIPELINES: if pipeline_name not in PIPELINES:
error = f"Unknown pipeline: {pipeline_name}" error = f"Unknown pipeline: {pipeline_name}"
@@ -131,20 +132,18 @@ class PipelineBridge(Handler):
"error": error, "error": error,
"available_pipelines": list(PIPELINES.keys()), "available_pipelines": list(PIPELINES.keys()),
} }
pipeline = PIPELINES[pipeline_name] pipeline = PIPELINES[pipeline_name]
engine = pipeline["engine"] engine = pipeline["engine"]
try: try:
if engine == "argo": if engine == "argo":
run_id = await self._submit_argo( run_id = await self._submit_argo(pipeline["template"], parameters, request_id)
pipeline["template"], parameters, request_id
)
else: else:
run_id = await self._submit_kubeflow( run_id = await self._submit_kubeflow(
pipeline["pipeline_id"], parameters, request_id pipeline["pipeline_id"], parameters, request_id
) )
result = { result = {
"request_id": request_id, "request_id": request_id,
"status": "submitted", "status": "submitted",
@@ -153,15 +152,13 @@ class PipelineBridge(Handler):
"pipeline": pipeline_name, "pipeline": pipeline_name,
"submitted_at": datetime.utcnow().isoformat(), "submitted_at": datetime.utcnow().isoformat(),
} }
# Publish status update # Publish status update
await self.nats.publish( await self.nats.publish(f"ai.pipeline.status.{request_id}", result)
f"ai.pipeline.status.{request_id}", result
)
logger.info(f"Pipeline {pipeline_name} submitted: {run_id}") logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
return result return result
except Exception as e: except Exception as e:
logger.exception(f"Failed to submit pipeline {pipeline_name}") logger.exception(f"Failed to submit pipeline {pipeline_name}")
return { return {
@@ -169,15 +166,13 @@ class PipelineBridge(Handler):
"status": "error", "status": "error",
"error": str(e), "error": str(e),
} }
async def _submit_argo( async def _submit_argo(self, template: str, parameters: dict, request_id: str) -> str:
self, template: str, parameters: dict, request_id: str
) -> str:
"""Submit workflow to Argo Workflows.""" """Submit workflow to Argo Workflows."""
with create_span("pipeline.submit.argo") as span: with create_span("pipeline.submit.argo") as span:
if span: if span:
span.set_attribute("argo.template", template) span.set_attribute("argo.template", template)
workflow = { workflow = {
"apiVersion": "argoproj.io/v1alpha1", "apiVersion": "argoproj.io/v1alpha1",
"kind": "Workflow", "kind": "Workflow",
@@ -191,48 +186,40 @@ class PipelineBridge(Handler):
"spec": { "spec": {
"workflowTemplateRef": {"name": template}, "workflowTemplateRef": {"name": template},
"arguments": { "arguments": {
"parameters": [ "parameters": [{"name": k, "value": str(v)} for k, v in parameters.items()]
{"name": k, "value": str(v)}
for k, v in parameters.items()
]
}, },
}, },
} }
response = await self._http.post( response = await self._http.post(
f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}", f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
json={"workflow": workflow}, json={"workflow": workflow},
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
return result["metadata"]["name"] return result["metadata"]["name"]
async def _submit_kubeflow( async def _submit_kubeflow(self, pipeline_id: str, parameters: dict, request_id: str) -> str:
self, pipeline_id: str, parameters: dict, request_id: str
) -> str:
"""Submit run to Kubeflow Pipelines.""" """Submit run to Kubeflow Pipelines."""
with create_span("pipeline.submit.kubeflow") as span: with create_span("pipeline.submit.kubeflow") as span:
if span: if span:
span.set_attribute("kubeflow.pipeline_id", pipeline_id) span.set_attribute("kubeflow.pipeline_id", pipeline_id)
run_request = { run_request = {
"name": f"{pipeline_id}-{request_id[:8]}", "name": f"{pipeline_id}-{request_id[:8]}",
"pipeline_spec": { "pipeline_spec": {
"pipeline_id": pipeline_id, "pipeline_id": pipeline_id,
"parameters": [ "parameters": [{"name": k, "value": str(v)} for k, v in parameters.items()],
{"name": k, "value": str(v)}
for k, v in parameters.items()
],
}, },
} }
response = await self._http.post( response = await self._http.post(
f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs", f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
json=run_request, json=run_request,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
return result["run"]["id"] return result["run"]["id"]

View File

@@ -24,6 +24,9 @@ dev = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["."] packages = ["."]
only-include = ["pipeline_bridge.py"] only-include = ["pipeline_bridge.py"]

View File

@@ -1,9 +1,10 @@
""" """
Pytest configuration and fixtures for pipeline-bridge tests. Pytest configuration and fixtures for pipeline-bridge tests.
""" """
import asyncio import asyncio
import os import os
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import MagicMock
import pytest import pytest

View File

@@ -1,6 +1,7 @@
""" """
Unit tests for PipelineBridge handler. Unit tests for PipelineBridge handler.
""" """
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -9,48 +10,48 @@ from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES
class TestPipelineSettings: class TestPipelineSettings:
"""Tests for PipelineSettings configuration.""" """Tests for PipelineSettings configuration."""
def test_default_settings(self): def test_default_settings(self):
"""Test default settings values.""" """Test default settings values."""
settings = PipelineSettings() settings = PipelineSettings()
assert settings.service_name == "pipeline-bridge" assert settings.service_name == "pipeline-bridge"
assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888" assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746" assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746"
assert settings.argo_namespace == "ai-ml" assert settings.argo_namespace == "ai-ml"
def test_custom_settings(self): def test_custom_settings(self):
"""Test custom settings.""" """Test custom settings."""
settings = PipelineSettings( settings = PipelineSettings(
kubeflow_host="http://custom-kubeflow:8888", kubeflow_host="http://custom-kubeflow:8888",
argo_namespace="custom-ns", argo_namespace="custom-ns",
) )
assert settings.kubeflow_host == "http://custom-kubeflow:8888" assert settings.kubeflow_host == "http://custom-kubeflow:8888"
assert settings.argo_namespace == "custom-ns" assert settings.argo_namespace == "custom-ns"
class TestPipelineDefinitions: class TestPipelineDefinitions:
"""Tests for pipeline definitions.""" """Tests for pipeline definitions."""
def test_required_pipelines_exist(self): def test_required_pipelines_exist(self):
"""Test that required pipelines are defined.""" """Test that required pipelines are defined."""
required = ["document-ingestion", "batch-inference", "rag-query", "voice-pipeline"] required = ["document-ingestion", "batch-inference", "rag-query", "voice-pipeline"]
for name in required: for name in required:
assert name in PIPELINES, f"Pipeline {name} should be defined" assert name in PIPELINES, f"Pipeline {name} should be defined"
def test_argo_pipelines_have_template(self): def test_argo_pipelines_have_template(self):
"""Test Argo pipelines have template field.""" """Test Argo pipelines have template field."""
for name, config in PIPELINES.items(): for name, config in PIPELINES.items():
if config["engine"] == "argo": if config["engine"] == "argo":
assert "template" in config, f"Argo pipeline {name} missing template" assert "template" in config, f"Argo pipeline {name} missing template"
def test_kubeflow_pipelines_have_pipeline_id(self): def test_kubeflow_pipelines_have_pipeline_id(self):
"""Test Kubeflow pipelines have pipeline_id field.""" """Test Kubeflow pipelines have pipeline_id field."""
for name, config in PIPELINES.items(): for name, config in PIPELINES.items():
if config["engine"] == "kubeflow": if config["engine"] == "kubeflow":
assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id" assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id"
def test_all_pipelines_have_description(self): def test_all_pipelines_have_description(self):
"""Test all pipelines have descriptions.""" """Test all pipelines have descriptions."""
for name, config in PIPELINES.items(): for name, config in PIPELINES.items():
@@ -59,7 +60,7 @@ class TestPipelineDefinitions:
class TestPipelineBridge: class TestPipelineBridge:
"""Tests for PipelineBridge handler.""" """Tests for PipelineBridge handler."""
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked HTTP client.""" """Create handler with mocked HTTP client."""
@@ -67,13 +68,13 @@ class TestPipelineBridge:
handler._http = AsyncMock() handler._http = AsyncMock()
handler.nats = AsyncMock() handler.nats = AsyncMock()
return handler return handler
def test_init(self, handler): def test_init(self, handler):
"""Test handler initialization.""" """Test handler initialization."""
assert handler.subject == "ai.pipeline.trigger" assert handler.subject == "ai.pipeline.trigger"
assert handler.queue_group == "pipeline-bridges" assert handler.queue_group == "pipeline-bridges"
assert handler.pipeline_settings.service_name == "pipeline-bridge" assert handler.pipeline_settings.service_name == "pipeline-bridge"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_unknown_pipeline( async def test_handle_unknown_pipeline(
self, self,
@@ -83,12 +84,12 @@ class TestPipelineBridge:
): ):
"""Test handling unknown pipeline.""" """Test handling unknown pipeline."""
result = await handler.handle_message(mock_nats_message, unknown_pipeline_request) result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)
assert result["status"] == "error" assert result["status"] == "error"
assert "Unknown pipeline" in result["error"] assert "Unknown pipeline" in result["error"]
assert "available_pipelines" in result assert "available_pipelines" in result
assert "document-ingestion" in result["available_pipelines"] assert "document-ingestion" in result["available_pipelines"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_argo_pipeline( async def test_handle_argo_pipeline(
self, self,
@@ -103,20 +104,20 @@ class TestPipelineBridge:
mock_response.json.return_value = mock_argo_response mock_response.json.return_value = mock_argo_response
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
handler._http.post.return_value = mock_response handler._http.post.return_value = mock_response
result = await handler.handle_message(mock_nats_message, argo_pipeline_request) result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
assert result["status"] == "submitted" assert result["status"] == "submitted"
assert result["engine"] == "argo" assert result["engine"] == "argo"
assert result["run_id"] == "document-ingestion-abc123" assert result["run_id"] == "document-ingestion-abc123"
assert result["pipeline"] == "document-ingestion" assert result["pipeline"] == "document-ingestion"
assert "submitted_at" in result assert "submitted_at" in result
# Verify API call # Verify API call
handler._http.post.assert_called_once() handler._http.post.assert_called_once()
call_args = handler._http.post.call_args call_args = handler._http.post.call_args
assert "argo-server" in str(call_args) assert "argo-server" in str(call_args)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_kubeflow_pipeline( async def test_handle_kubeflow_pipeline(
self, self,
@@ -131,19 +132,19 @@ class TestPipelineBridge:
mock_response.json.return_value = mock_kubeflow_response mock_response.json.return_value = mock_kubeflow_response
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
handler._http.post.return_value = mock_response handler._http.post.return_value = mock_response
result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request) result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)
assert result["status"] == "submitted" assert result["status"] == "submitted"
assert result["engine"] == "kubeflow" assert result["engine"] == "kubeflow"
assert result["run_id"] == "run-xyz-789" assert result["run_id"] == "run-xyz-789"
assert result["pipeline"] == "rag-query" assert result["pipeline"] == "rag-query"
# Verify API call # Verify API call
handler._http.post.assert_called_once() handler._http.post.assert_called_once()
call_args = handler._http.post.call_args call_args = handler._http.post.call_args
assert "ml-pipeline" in str(call_args) assert "ml-pipeline" in str(call_args)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_api_error( async def test_handle_api_error(
self, self,
@@ -153,12 +154,12 @@ class TestPipelineBridge:
): ):
"""Test handling API errors.""" """Test handling API errors."""
handler._http.post.side_effect = Exception("Connection refused") handler._http.post.side_effect = Exception("Connection refused")
result = await handler.handle_message(mock_nats_message, argo_pipeline_request) result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
assert result["status"] == "error" assert result["status"] == "error"
assert "Connection refused" in result["error"] assert "Connection refused" in result["error"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_publishes_status_update( async def test_publishes_status_update(
self, self,
@@ -172,40 +173,40 @@ class TestPipelineBridge:
mock_response.json.return_value = mock_argo_response mock_response.json.return_value = mock_argo_response
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
handler._http.post.return_value = mock_response handler._http.post.return_value = mock_response
await handler.handle_message(mock_nats_message, argo_pipeline_request) await handler.handle_message(mock_nats_message, argo_pipeline_request)
handler.nats.publish.assert_called_once() handler.nats.publish.assert_called_once()
call_args = handler.nats.publish.call_args call_args = handler.nats.publish.call_args
assert "ai.pipeline.status.test-request-123" in str(call_args) assert "ai.pipeline.status.test-request-123" in str(call_args)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_http_client(self): async def test_setup_creates_http_client(self):
"""Test that setup initializes HTTP client.""" """Test that setup initializes HTTP client."""
with patch("pipeline_bridge.httpx.AsyncClient") as mock_client: with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
handler = PipelineBridge() handler = PipelineBridge()
await handler.setup() await handler.setup()
mock_client.assert_called_once() mock_client.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_teardown_closes_http_client(self, handler): async def test_teardown_closes_http_client(self, handler):
"""Test that teardown closes HTTP client.""" """Test that teardown closes HTTP client."""
await handler.teardown() await handler.teardown()
handler._http.aclose.assert_called_once() handler._http.aclose.assert_called_once()
class TestArgoSubmission: class TestArgoSubmission:
"""Tests for Argo workflow submission.""" """Tests for Argo workflow submission."""
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked HTTP client.""" """Create handler with mocked HTTP client."""
handler = PipelineBridge() handler = PipelineBridge()
handler._http = AsyncMock() handler._http = AsyncMock()
return handler return handler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_argo_workflow_structure( async def test_argo_workflow_structure(
self, self,
@@ -217,17 +218,17 @@ class TestArgoSubmission:
mock_response.json.return_value = mock_argo_response mock_response.json.return_value = mock_argo_response
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
handler._http.post.return_value = mock_response handler._http.post.return_value = mock_response
await handler._submit_argo( await handler._submit_argo(
template="document-ingestion", template="document-ingestion",
parameters={"key": "value"}, parameters={"key": "value"},
request_id="test-123", request_id="test-123",
) )
# Verify workflow structure # Verify workflow structure
call_kwargs = handler._http.post.call_args.kwargs call_kwargs = handler._http.post.call_args.kwargs
workflow = call_kwargs["json"]["workflow"] workflow = call_kwargs["json"]["workflow"]
assert workflow["apiVersion"] == "argoproj.io/v1alpha1" assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
assert workflow["kind"] == "Workflow" assert workflow["kind"] == "Workflow"
assert "workflowTemplateRef" in workflow["spec"] assert "workflowTemplateRef" in workflow["spec"]
@@ -237,14 +238,14 @@ class TestArgoSubmission:
class TestKubeflowSubmission: class TestKubeflowSubmission:
"""Tests for Kubeflow pipeline submission.""" """Tests for Kubeflow pipeline submission."""
@pytest.fixture @pytest.fixture
def handler(self): def handler(self):
"""Create handler with mocked HTTP client.""" """Create handler with mocked HTTP client."""
handler = PipelineBridge() handler = PipelineBridge()
handler._http = AsyncMock() handler._http = AsyncMock()
return handler return handler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_kubeflow_run_structure( async def test_kubeflow_run_structure(
self, self,
@@ -256,16 +257,16 @@ class TestKubeflowSubmission:
mock_response.json.return_value = mock_kubeflow_response mock_response.json.return_value = mock_kubeflow_response
mock_response.raise_for_status = MagicMock() mock_response.raise_for_status = MagicMock()
handler._http.post.return_value = mock_response handler._http.post.return_value = mock_response
await handler._submit_kubeflow( await handler._submit_kubeflow(
pipeline_id="rag-pipeline", pipeline_id="rag-pipeline",
parameters={"query": "test"}, parameters={"query": "test"},
request_id="test-456", request_id="test-456",
) )
# Verify run request structure # Verify run request structure
call_kwargs = handler._http.post.call_args.kwargs call_kwargs = handler._http.post.call_args.kwargs
run_request = call_kwargs["json"] run_request = call_kwargs["json"]
assert "rag-pipeline" in run_request["name"] assert "rag-pipeline" in run_request["name"]
assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline" assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"