From 49c804e23442e245efb2c4d717e750a253ecd6dd Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Mon, 2 Feb 2026 08:44:51 -0500 Subject: [PATCH] fix: ruff formatting and allow-direct-references for handler-base dep --- pipeline_bridge.py | 85 +++++++++++++++-------------------- pyproject.toml | 3 ++ tests/conftest.py | 3 +- tests/test_pipeline_bridge.py | 81 ++++++++++++++++----------------- 4 files changed, 82 insertions(+), 90 deletions(-) diff --git a/pipeline_bridge.py b/pipeline_bridge.py index 3b573d4..3e94758 100644 --- a/pipeline_bridge.py +++ b/pipeline_bridge.py @@ -8,6 +8,7 @@ Bridges NATS events to workflow engines using handler-base: 3. Monitor execution and publish status updates 4. Publish completion to "ai.pipeline.status.{request_id}" """ + import logging from typing import Any, Optional from datetime import datetime @@ -23,12 +24,12 @@ logger = logging.getLogger("pipeline-bridge") class PipelineSettings(Settings): """Pipeline bridge specific settings.""" - + service_name: str = "pipeline-bridge" - + # Kubeflow Pipelines kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888" - + # Argo Workflows argo_host: str = "http://argo-server.argo.svc.cluster.local:2746" argo_namespace: str = "ai-ml" @@ -67,14 +68,14 @@ PIPELINES = { class PipelineBridge(Handler): """ Pipeline trigger handler. - + Request format: { "request_id": "uuid", "pipeline": "document-ingestion", "parameters": {"key": "value"} } - + Response format: { "request_id": "uuid", @@ -83,7 +84,7 @@ class PipelineBridge(Handler): "engine": "argo|kubeflow" } """ - + def __init__(self): self.pipeline_settings = PipelineSettings() super().__init__( @@ -91,36 +92,36 @@ class PipelineBridge(Handler): settings=self.pipeline_settings, queue_group="pipeline-bridges", ) - + self._http: Optional[httpx.AsyncClient] = None - + async def setup(self) -> None: """Initialize HTTP client.""" logger.info("Initializing pipeline bridge...") - + self._http = httpx.AsyncClient(timeout=60.0) - + logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}") - + async def teardown(self) -> None: """Clean up HTTP client.""" if self._http: await self._http.aclose() logger.info("Pipeline bridge closed") - + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: """Handle pipeline trigger request.""" request_id = data.get("request_id", "unknown") pipeline_name = data.get("pipeline", "") parameters = data.get("parameters", {}) - + logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}") - + with create_span("pipeline.trigger") as span: if span: span.set_attribute("request.id", request_id) span.set_attribute("pipeline.name", pipeline_name) - + # Validate pipeline if pipeline_name not in PIPELINES: error = f"Unknown pipeline: {pipeline_name}" @@ -131,20 +132,18 @@ class PipelineBridge(Handler): "error": error, "available_pipelines": list(PIPELINES.keys()), } - + pipeline = PIPELINES[pipeline_name] engine = pipeline["engine"] - + try: if engine == "argo": - run_id = await self._submit_argo( - pipeline["template"], parameters, request_id - ) + run_id = await self._submit_argo(pipeline["template"], parameters, request_id) else: run_id = await self._submit_kubeflow( pipeline["pipeline_id"], parameters, request_id ) - + result = { "request_id": request_id, "status": "submitted", @@ -153,15 +152,13 @@ class PipelineBridge(Handler): "pipeline": pipeline_name, "submitted_at": datetime.utcnow().isoformat(), } - + # Publish status update - await self.nats.publish( - f"ai.pipeline.status.{request_id}", result - ) - + await self.nats.publish(f"ai.pipeline.status.{request_id}", result) + logger.info(f"Pipeline {pipeline_name} submitted: {run_id}") return result - + except Exception as e: logger.exception(f"Failed to submit pipeline {pipeline_name}") return { @@ -169,15 +166,13 @@ class PipelineBridge(Handler): "status": "error", "error": str(e), } - - async def _submit_argo( - self, template: str, parameters: dict, request_id: str - ) -> str: + + async def _submit_argo(self, template: str, parameters: dict, request_id: str) -> str: """Submit workflow to Argo Workflows.""" with create_span("pipeline.submit.argo") as span: if span: span.set_attribute("argo.template", template) - + workflow = { "apiVersion": "argoproj.io/v1alpha1", "kind": "Workflow", @@ -191,48 +186,40 @@ class PipelineBridge(Handler): "spec": { "workflowTemplateRef": {"name": template}, "arguments": { - "parameters": [ - {"name": k, "value": str(v)} - for k, v in parameters.items() - ] + "parameters": [{"name": k, "value": str(v)} for k, v in parameters.items()] }, }, } - + response = await self._http.post( f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}", json={"workflow": workflow}, ) response.raise_for_status() - + result = response.json() return result["metadata"]["name"] - - async def _submit_kubeflow( - self, pipeline_id: str, parameters: dict, request_id: str - ) -> str: + + async def _submit_kubeflow(self, pipeline_id: str, parameters: dict, request_id: str) -> str: """Submit run to Kubeflow Pipelines.""" with create_span("pipeline.submit.kubeflow") as span: if span: span.set_attribute("kubeflow.pipeline_id", pipeline_id) - + run_request = { "name": f"{pipeline_id}-{request_id[:8]}", "pipeline_spec": { "pipeline_id": pipeline_id, - "parameters": [ - {"name": k, "value": str(v)} - for k, v in parameters.items() - ], + "parameters": [{"name": k, "value": str(v)} for k, v in parameters.items()], }, } - + response = await self._http.post( f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs", json=run_request, ) response.raise_for_status() - + result = response.json() return result["run"]["id"] diff --git a/pyproject.toml b/pyproject.toml index 6d5da00..32f5117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ dev = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["."] only-include = ["pipeline_bridge.py"] diff --git a/tests/conftest.py b/tests/conftest.py index 53bad0f..d1ea6f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ """ Pytest configuration and fixtures for pipeline-bridge tests. """ + import asyncio import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest diff --git a/tests/test_pipeline_bridge.py b/tests/test_pipeline_bridge.py index 1e819d4..c0f8859 100644 --- a/tests/test_pipeline_bridge.py +++ b/tests/test_pipeline_bridge.py @@ -1,6 +1,7 @@ """ Unit tests for PipelineBridge handler. """ + import pytest from unittest.mock import AsyncMock, MagicMock, patch @@ -9,48 +10,48 @@ from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES class TestPipelineSettings: """Tests for PipelineSettings configuration.""" - + def test_default_settings(self): """Test default settings values.""" settings = PipelineSettings() - + assert settings.service_name == "pipeline-bridge" assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888" assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746" assert settings.argo_namespace == "ai-ml" - + def test_custom_settings(self): """Test custom settings.""" settings = PipelineSettings( kubeflow_host="http://custom-kubeflow:8888", argo_namespace="custom-ns", ) - + assert settings.kubeflow_host == "http://custom-kubeflow:8888" assert settings.argo_namespace == "custom-ns" class TestPipelineDefinitions: """Tests for pipeline definitions.""" - + def test_required_pipelines_exist(self): """Test that required pipelines are defined.""" required = ["document-ingestion", "batch-inference", "rag-query", "voice-pipeline"] for name in required: assert name in PIPELINES, f"Pipeline {name} should be defined" - + def test_argo_pipelines_have_template(self): """Test Argo pipelines have template field.""" for name, config in PIPELINES.items(): if config["engine"] == "argo": assert "template" in config, f"Argo pipeline {name} missing template" - + def test_kubeflow_pipelines_have_pipeline_id(self): """Test Kubeflow pipelines have pipeline_id field.""" for name, config in PIPELINES.items(): if config["engine"] == "kubeflow": assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id" - + def test_all_pipelines_have_description(self): """Test all pipelines have descriptions.""" for name, config in PIPELINES.items(): @@ -59,7 +60,7 @@ class TestPipelineDefinitions: class TestPipelineBridge: """Tests for PipelineBridge handler.""" - + @pytest.fixture def handler(self): """Create handler with mocked HTTP client.""" @@ -67,13 +68,13 @@ class TestPipelineBridge: handler._http = AsyncMock() handler.nats = AsyncMock() return handler - + def test_init(self, handler): """Test handler initialization.""" assert handler.subject == "ai.pipeline.trigger" assert handler.queue_group == "pipeline-bridges" assert handler.pipeline_settings.service_name == "pipeline-bridge" - + @pytest.mark.asyncio async def test_handle_unknown_pipeline( self, @@ -83,12 +84,12 @@ class TestPipelineBridge: ): """Test handling unknown pipeline.""" result = await handler.handle_message(mock_nats_message, unknown_pipeline_request) - + assert result["status"] == "error" assert "Unknown pipeline" in result["error"] assert "available_pipelines" in result assert "document-ingestion" in result["available_pipelines"] - + @pytest.mark.asyncio async def test_handle_argo_pipeline( self, @@ -103,20 +104,20 @@ class TestPipelineBridge: mock_response.json.return_value = mock_argo_response mock_response.raise_for_status = MagicMock() handler._http.post.return_value = mock_response - + result = await handler.handle_message(mock_nats_message, argo_pipeline_request) - + assert result["status"] == "submitted" assert result["engine"] == "argo" assert result["run_id"] == "document-ingestion-abc123" assert result["pipeline"] == "document-ingestion" assert "submitted_at" in result - + # Verify API call handler._http.post.assert_called_once() call_args = handler._http.post.call_args assert "argo-server" in str(call_args) - + @pytest.mark.asyncio async def test_handle_kubeflow_pipeline( self, @@ -131,19 +132,19 @@ class TestPipelineBridge: mock_response.json.return_value = mock_kubeflow_response mock_response.raise_for_status = MagicMock() handler._http.post.return_value = mock_response - + result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request) - + assert result["status"] == "submitted" assert result["engine"] == "kubeflow" assert result["run_id"] == "run-xyz-789" assert result["pipeline"] == "rag-query" - + # Verify API call handler._http.post.assert_called_once() call_args = handler._http.post.call_args assert "ml-pipeline" in str(call_args) - + @pytest.mark.asyncio async def test_handle_api_error( self, @@ -153,12 +154,12 @@ class TestPipelineBridge: ): """Test handling API errors.""" handler._http.post.side_effect = Exception("Connection refused") - + result = await handler.handle_message(mock_nats_message, argo_pipeline_request) - + assert result["status"] == "error" assert "Connection refused" in result["error"] - + @pytest.mark.asyncio async def test_publishes_status_update( self, @@ -172,40 +173,40 @@ class TestPipelineBridge: mock_response.json.return_value = mock_argo_response mock_response.raise_for_status = MagicMock() handler._http.post.return_value = mock_response - + await handler.handle_message(mock_nats_message, argo_pipeline_request) - + handler.nats.publish.assert_called_once() call_args = handler.nats.publish.call_args assert "ai.pipeline.status.test-request-123" in str(call_args) - + @pytest.mark.asyncio async def test_setup_creates_http_client(self): """Test that setup initializes HTTP client.""" with patch("pipeline_bridge.httpx.AsyncClient") as mock_client: handler = PipelineBridge() await handler.setup() - + mock_client.assert_called_once() - + @pytest.mark.asyncio async def test_teardown_closes_http_client(self, handler): """Test that teardown closes HTTP client.""" await handler.teardown() - + handler._http.aclose.assert_called_once() class TestArgoSubmission: """Tests for Argo workflow submission.""" - + @pytest.fixture def handler(self): """Create handler with mocked HTTP client.""" handler = PipelineBridge() handler._http = AsyncMock() return handler - + @pytest.mark.asyncio async def test_argo_workflow_structure( self, @@ -217,17 +218,17 @@ class TestArgoSubmission: mock_response.json.return_value = mock_argo_response mock_response.raise_for_status = MagicMock() handler._http.post.return_value = mock_response - + await handler._submit_argo( template="document-ingestion", parameters={"key": "value"}, request_id="test-123", ) - + # Verify workflow structure call_kwargs = handler._http.post.call_args.kwargs workflow = call_kwargs["json"]["workflow"] - + assert workflow["apiVersion"] == "argoproj.io/v1alpha1" assert workflow["kind"] == "Workflow" assert "workflowTemplateRef" in workflow["spec"] @@ -237,14 +238,14 @@ class TestArgoSubmission: class TestKubeflowSubmission: """Tests for Kubeflow pipeline submission.""" - + @pytest.fixture def handler(self): """Create handler with mocked HTTP client.""" handler = PipelineBridge() handler._http = AsyncMock() return handler - + @pytest.mark.asyncio async def test_kubeflow_run_structure( self, @@ -256,16 +257,16 @@ class TestKubeflowSubmission: mock_response.json.return_value = mock_kubeflow_response mock_response.raise_for_status = MagicMock() handler._http.post.return_value = mock_response - + await handler._submit_kubeflow( pipeline_id="rag-pipeline", parameters={"query": "test"}, request_id="test-456", ) - + # Verify run request structure call_kwargs = handler._http.post.call_args.kwargs run_request = call_kwargs["json"] - + assert "rag-pipeline" in run_request["name"] assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"