From 3f1e05eaad91420e1f8cc305ae350730e4d4492f Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Mon, 2 Feb 2026 07:11:08 -0500 Subject: [PATCH] refactor: consolidate to handler-base, migrate to pyproject.toml, add tests --- Dockerfile | 28 +- Dockerfile.v2 | 12 - README.md | 13 +- pipeline_bridge.py | 470 +++++++++++++--------------------- pipeline_bridge_v2.py | 241 ----------------- pyproject.toml | 42 +++ requirements.txt | 3 - tests/__init__.py | 1 + tests/conftest.py | 90 +++++++ tests/test_pipeline_bridge.py | 271 ++++++++++++++++++++ 10 files changed, 594 insertions(+), 577 deletions(-) delete mode 100644 Dockerfile.v2 delete mode 100644 pipeline_bridge_v2.py create mode 100644 pyproject.toml delete mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_pipeline_bridge.py diff --git a/Dockerfile b/Dockerfile index aab0b16..327d0ec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,13 @@ -FROM python:3.13-slim +# Pipeline Bridge - Using handler-base +ARG BASE_TAG=latest +FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} WORKDIR /app -# Install uv for fast, reliable package management -COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv +# Install additional dependencies from pyproject.toml +COPY pyproject.toml . +RUN uv pip install --system --no-cache httpx kubernetes -# Install system dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements first for better caching -COPY requirements.txt . -RUN uv pip install --system --no-cache -r requirements.txt - -# Copy application code COPY pipeline_bridge.py . 
-# Set environment variables -ENV PYTHONUNBUFFERED=1 -ENV PYTHONDONTWRITEBYTECODE=1 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD python -c "print('healthy')" || exit 1 - -# Run the application CMD ["python", "pipeline_bridge.py"] diff --git a/Dockerfile.v2 b/Dockerfile.v2 deleted file mode 100644 index 31934a8..0000000 --- a/Dockerfile.v2 +++ /dev/null @@ -1,12 +0,0 @@ -# Pipeline Bridge v2 - Using handler-base -ARG BASE_TAG=local -FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG} - -WORKDIR /app - -# Additional dependency for Kubernetes API -RUN uv pip install --system --no-cache kubernetes>=28.0.0 - -COPY pipeline_bridge_v2.py ./pipeline_bridge.py - -CMD ["python", "pipeline_bridge.py"] diff --git a/README.md b/README.md index 4b9ce7f..051ba1e 100644 --- a/README.md +++ b/README.md @@ -61,13 +61,9 @@ The bridge publishes status updates as the workflow progresses: - `failed` - Failed - `error` - System error -## Variants +## Implementation -### pipeline_bridge.py (Standalone) -Self-contained service with pip install on startup. Good for simple deployments. - -### pipeline_bridge_v2.py (handler-base) -Uses handler-base library for standardized NATS handling, telemetry, and health checks. +The pipeline bridge uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks. ## Environment Variables @@ -81,11 +77,10 @@ Uses handler-base library for standardized NATS handling, telemetry, and health ## Building ```bash -# Standalone version docker build -t pipeline-bridge:latest . -# handler-base version -docker build -f Dockerfile.v2 -t pipeline-bridge:v2 --build-arg BASE_TAG=latest . +# With specific handler-base tag +docker build --build-arg BASE_TAG=latest -t pipeline-bridge:latest . 
``` ## Testing diff --git a/pipeline_bridge.py b/pipeline_bridge.py index 33fbbd1..3b573d4 100644 --- a/pipeline_bridge.py +++ b/pipeline_bridge.py @@ -1,351 +1,241 @@ #!/usr/bin/env python3 """ -Pipeline Bridge Service +Pipeline Bridge Service (Refactored) -Bridges NATS events to workflow engines: +Bridges NATS events to workflow engines using handler-base: 1. Listen for pipeline triggers on "ai.pipeline.trigger" 2. Submit to Kubeflow Pipelines or Argo Workflows 3. Monitor execution and publish status updates 4. Publish completion to "ai.pipeline.status.{request_id}" - -Supported pipelines: -- document-ingestion: Ingest documents into Milvus -- batch-inference: Run batch LLM inference -- model-evaluation: Evaluate model performance """ -import asyncio -import json import logging -import os -import signal -import subprocess -import sys -from typing import Dict, Optional +from typing import Any, Optional from datetime import datetime -# Install dependencies on startup -subprocess.check_call([ - sys.executable, "-m", "pip", "install", "-q", - "-r", "/app/requirements.txt" -]) - import httpx -import nats -from kubernetes import client, config +from nats.aio.msg import Msg + +from handler_base import Handler, Settings +from handler_base.telemetry import create_span -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) logger = logging.getLogger("pipeline-bridge") -# Configuration from environment -NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222") -KUBEFLOW_HOST = os.environ.get("KUBEFLOW_HOST", "http://ml-pipeline.kubeflow.svc.cluster.local:8888") -ARGO_HOST = os.environ.get("ARGO_HOST", "http://argo-server.argo.svc.cluster.local:2746") -ARGO_NAMESPACE = os.environ.get("ARGO_NAMESPACE", "ai-ml") -# NATS subjects -TRIGGER_SUBJECT = "ai.pipeline.trigger" -STATUS_SUBJECT = "ai.pipeline.status" +class PipelineSettings(Settings): + """Pipeline bridge specific 
settings.""" + + service_name: str = "pipeline-bridge" + + # Kubeflow Pipelines + kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888" + + # Argo Workflows + argo_host: str = "http://argo-server.argo.svc.cluster.local:2746" + argo_namespace: str = "ai-ml" -# Pipeline definitions - maps pipeline names to their configurations + +# Pipeline definitions PIPELINES = { "document-ingestion": { "engine": "argo", "template": "document-ingestion", - "description": "Ingest documents into Milvus vector database" + "description": "Ingest documents into Milvus vector database", }, "batch-inference": { "engine": "argo", "template": "batch-inference", - "description": "Run batch LLM inference on a dataset" + "description": "Run batch LLM inference on a dataset", }, "rag-query": { "engine": "kubeflow", "pipeline_id": "rag-pipeline", - "description": "Execute RAG query pipeline" + "description": "Execute RAG query pipeline", }, "voice-pipeline": { "engine": "kubeflow", "pipeline_id": "voice-pipeline", - "description": "Full voice assistant pipeline" - } + "description": "Full voice assistant pipeline", + }, + "model-evaluation": { + "engine": "argo", + "template": "model-evaluation", + "description": "Evaluate model performance", + }, } -class PipelineBridge: +class PipelineBridge(Handler): + """ + Pipeline trigger handler. 
+ + Request format: + { + "request_id": "uuid", + "pipeline": "document-ingestion", + "parameters": {"key": "value"} + } + + Response format: + { + "request_id": "uuid", + "status": "submitted", + "run_id": "workflow-run-id", + "engine": "argo|kubeflow" + } + """ + def __init__(self): - self.nc = None - self.http_client = None - self.running = True - self.active_workflows = {} # Track running workflows - - async def setup(self): - """Initialize all connections.""" - # NATS connection - self.nc = await nats.connect(NATS_URL) - logger.info(f"Connected to NATS at {NATS_URL}") - - # HTTP client for API calls - self.http_client = httpx.AsyncClient(timeout=60.0) - - # Initialize Kubernetes client for Argo - try: - config.load_incluster_config() - self.k8s_custom = client.CustomObjectsApi() - logger.info("Kubernetes client initialized") - except Exception as e: - logger.warning(f"Kubernetes client failed: {e}") - self.k8s_custom = None - - async def submit_argo_workflow(self, template: str, parameters: Dict, request_id: str) -> Optional[str]: - """Submit an Argo Workflow from a WorkflowTemplate.""" - if not self.k8s_custom: - logger.error("Kubernetes client not available") - return None - - try: - # Create workflow from template + self.pipeline_settings = PipelineSettings() + super().__init__( + subject="ai.pipeline.trigger", + settings=self.pipeline_settings, + queue_group="pipeline-bridges", + ) + + self._http: Optional[httpx.AsyncClient] = None + + async def setup(self) -> None: + """Initialize HTTP client.""" + logger.info("Initializing pipeline bridge...") + + self._http = httpx.AsyncClient(timeout=60.0) + + logger.info(f"Pipeline bridge ready. 
Available pipelines: {list(PIPELINES.keys())}") + + async def teardown(self) -> None: + """Clean up HTTP client.""" + if self._http: + await self._http.aclose() + logger.info("Pipeline bridge closed") + + async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: + """Handle pipeline trigger request.""" + request_id = data.get("request_id", "unknown") + pipeline_name = data.get("pipeline", "") + parameters = data.get("parameters", {}) + + logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}") + + with create_span("pipeline.trigger") as span: + if span: + span.set_attribute("request.id", request_id) + span.set_attribute("pipeline.name", pipeline_name) + + # Validate pipeline + if pipeline_name not in PIPELINES: + error = f"Unknown pipeline: {pipeline_name}" + logger.error(error) + return { + "request_id": request_id, + "status": "error", + "error": error, + "available_pipelines": list(PIPELINES.keys()), + } + + pipeline = PIPELINES[pipeline_name] + engine = pipeline["engine"] + + try: + if engine == "argo": + run_id = await self._submit_argo( + pipeline["template"], parameters, request_id + ) + else: + run_id = await self._submit_kubeflow( + pipeline["pipeline_id"], parameters, request_id + ) + + result = { + "request_id": request_id, + "status": "submitted", + "run_id": run_id, + "engine": engine, + "pipeline": pipeline_name, + "submitted_at": datetime.utcnow().isoformat(), + } + + # Publish status update + await self.nats.publish( + f"ai.pipeline.status.{request_id}", result + ) + + logger.info(f"Pipeline {pipeline_name} submitted: {run_id}") + return result + + except Exception as e: + logger.exception(f"Failed to submit pipeline {pipeline_name}") + return { + "request_id": request_id, + "status": "error", + "error": str(e), + } + + async def _submit_argo( + self, template: str, parameters: dict, request_id: str + ) -> str: + """Submit workflow to Argo Workflows.""" + with create_span("pipeline.submit.argo") as span: + if span: 
+ span.set_attribute("argo.template", template) + workflow = { "apiVersion": "argoproj.io/v1alpha1", "kind": "Workflow", "metadata": { "generateName": f"{template}-", - "namespace": ARGO_NAMESPACE, + "namespace": self.pipeline_settings.argo_namespace, "labels": { - "app.kubernetes.io/managed-by": "pipeline-bridge", - "pipeline-bridge/request-id": request_id - } + "request-id": request_id, + }, }, "spec": { - "workflowTemplateRef": { - "name": template - }, + "workflowTemplateRef": {"name": template}, "arguments": { "parameters": [ - {"name": k, "value": str(v)} + {"name": k, "value": str(v)} for k, v in parameters.items() ] - } - } + }, + }, } - - result = self.k8s_custom.create_namespaced_custom_object( - group="argoproj.io", - version="v1alpha1", - namespace=ARGO_NAMESPACE, - plural="workflows", - body=workflow - ) - workflow_name = result["metadata"]["name"] - logger.info(f"Submitted Argo workflow: {workflow_name}") - return workflow_name - - except Exception as e: - logger.error(f"Failed to submit Argo workflow: {e}") - return None - - async def submit_kubeflow_pipeline(self, pipeline_id: str, parameters: Dict, request_id: str) -> Optional[str]: - """Submit a Kubeflow Pipeline run.""" - try: - # Create pipeline run via Kubeflow API + response = await self._http.post( + f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}", + json={"workflow": workflow}, + ) + response.raise_for_status() + + result = response.json() + return result["metadata"]["name"] + + async def _submit_kubeflow( + self, pipeline_id: str, parameters: dict, request_id: str + ) -> str: + """Submit run to Kubeflow Pipelines.""" + with create_span("pipeline.submit.kubeflow") as span: + if span: + span.set_attribute("kubeflow.pipeline_id", pipeline_id) + run_request = { "name": f"{pipeline_id}-{request_id[:8]}", "pipeline_spec": { - "pipeline_id": pipeline_id + "pipeline_id": pipeline_id, + "parameters": [ + {"name": k, "value": str(v)} + for k, v in 
parameters.items() + ], }, - "resource_references": [], - "parameters": [ - {"name": k, "value": str(v)} - for k, v in parameters.items() - ] } - - response = await self.http_client.post( - f"{KUBEFLOW_HOST}/apis/v1beta1/runs", - json=run_request + + response = await self._http.post( + f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs", + json=run_request, ) + response.raise_for_status() - if response.status_code == 200: - result = response.json() - run_id = result.get("run", {}).get("id") - logger.info(f"Submitted Kubeflow pipeline run: {run_id}") - return run_id - else: - logger.error(f"Kubeflow API error: {response.status_code} - {response.text}") - return None - - except Exception as e: - logger.error(f"Failed to submit Kubeflow pipeline: {e}") - return None - - async def get_argo_workflow_status(self, workflow_name: str) -> Dict: - """Get status of an Argo Workflow.""" - if not self.k8s_custom: - return {"phase": "Unknown", "message": "Kubernetes client not available"} - - try: - result = self.k8s_custom.get_namespaced_custom_object( - group="argoproj.io", - version="v1alpha1", - namespace=ARGO_NAMESPACE, - plural="workflows", - name=workflow_name - ) - - status = result.get("status", {}) - return { - "phase": status.get("phase", "Pending"), - "message": status.get("message", ""), - "startedAt": status.get("startedAt"), - "finishedAt": status.get("finishedAt"), - "progress": status.get("progress", "0/0") - } - - except Exception as e: - logger.error(f"Failed to get workflow status: {e}") - return {"phase": "Error", "message": str(e)} - - async def process_trigger(self, msg): - """Process a pipeline trigger request.""" - try: - data = json.loads(msg.data.decode()) - request_id = data.get("request_id", "unknown") - pipeline_name = data.get("pipeline", "") - parameters = data.get("parameters", {}) - - logger.info(f"Processing pipeline trigger {request_id}: {pipeline_name}") - - # Validate pipeline - if pipeline_name not in PIPELINES: - await 
self.publish_status(request_id, { - "status": "error", - "message": f"Unknown pipeline: {pipeline_name}", - "available_pipelines": list(PIPELINES.keys()) - }) - return - - pipeline_config = PIPELINES[pipeline_name] - engine = pipeline_config["engine"] - - # Submit to appropriate engine - run_id = None - if engine == "argo": - run_id = await self.submit_argo_workflow( - pipeline_config["template"], - parameters, - request_id - ) - elif engine == "kubeflow": - run_id = await self.submit_kubeflow_pipeline( - pipeline_config["pipeline_id"], - parameters, - request_id - ) - - if run_id: - # Track workflow for status updates - self.active_workflows[request_id] = { - "engine": engine, - "run_id": run_id, - "started_at": datetime.utcnow().isoformat() - } - - await self.publish_status(request_id, { - "status": "submitted", - "pipeline": pipeline_name, - "engine": engine, - "run_id": run_id, - "message": f"Pipeline submitted successfully" - }) - else: - await self.publish_status(request_id, { - "status": "error", - "pipeline": pipeline_name, - "message": "Failed to submit pipeline" - }) - - except Exception as e: - logger.error(f"Trigger processing failed: {e}") - await self.publish_status( - data.get("request_id", "unknown"), - {"status": "error", "message": str(e)} - ) - - async def publish_status(self, request_id: str, status: Dict): - """Publish pipeline status update.""" - status["request_id"] = request_id - status["timestamp"] = datetime.utcnow().isoformat() - await self.nc.publish( - f"{STATUS_SUBJECT}.{request_id}", - json.dumps(status).encode() - ) - logger.info(f"Published status for {request_id}: {status.get('status')}") - - async def monitor_workflows(self): - """Periodically check and publish status of active workflows.""" - while self.running: - completed = [] - - for request_id, workflow in self.active_workflows.items(): - try: - if workflow["engine"] == "argo": - status = await self.get_argo_workflow_status(workflow["run_id"]) - - # Publish status update - 
await self.publish_status(request_id, { - "status": status["phase"].lower(), - "run_id": workflow["run_id"], - "progress": status.get("progress"), - "message": status.get("message", "") - }) - - # Check if completed - if status["phase"] in ["Succeeded", "Failed", "Error"]: - completed.append(request_id) - - except Exception as e: - logger.error(f"Error monitoring workflow {request_id}: {e}") - - # Remove completed workflows from tracking - for request_id in completed: - del self.active_workflows[request_id] - - await asyncio.sleep(10) # Check every 10 seconds - - async def run(self): - """Main run loop.""" - await self.setup() - - # Subscribe to pipeline triggers - sub = await self.nc.subscribe(TRIGGER_SUBJECT, cb=self.process_trigger) - logger.info(f"Subscribed to {TRIGGER_SUBJECT}") - - # Start workflow monitor - monitor_task = asyncio.create_task(self.monitor_workflows()) - - # Handle shutdown - def signal_handler(): - self.running = False - - loop = asyncio.get_event_loop() - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, signal_handler) - - # Keep running - while self.running: - await asyncio.sleep(1) - - # Cleanup - monitor_task.cancel() - await sub.unsubscribe() - await self.nc.close() - logger.info("Shutdown complete") + result = response.json() + return result["run"]["id"] if __name__ == "__main__": - bridge = PipelineBridge() - asyncio.run(bridge.run()) + PipelineBridge().run() diff --git a/pipeline_bridge_v2.py b/pipeline_bridge_v2.py deleted file mode 100644 index 3b573d4..0000000 --- a/pipeline_bridge_v2.py +++ /dev/null @@ -1,241 +0,0 @@ -#!/usr/bin/env python3 -""" -Pipeline Bridge Service (Refactored) - -Bridges NATS events to workflow engines using handler-base: -1. Listen for pipeline triggers on "ai.pipeline.trigger" -2. Submit to Kubeflow Pipelines or Argo Workflows -3. Monitor execution and publish status updates -4. 
Publish completion to "ai.pipeline.status.{request_id}" -""" -import logging -from typing import Any, Optional -from datetime import datetime - -import httpx -from nats.aio.msg import Msg - -from handler_base import Handler, Settings -from handler_base.telemetry import create_span - -logger = logging.getLogger("pipeline-bridge") - - -class PipelineSettings(Settings): - """Pipeline bridge specific settings.""" - - service_name: str = "pipeline-bridge" - - # Kubeflow Pipelines - kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888" - - # Argo Workflows - argo_host: str = "http://argo-server.argo.svc.cluster.local:2746" - argo_namespace: str = "ai-ml" - - -# Pipeline definitions -PIPELINES = { - "document-ingestion": { - "engine": "argo", - "template": "document-ingestion", - "description": "Ingest documents into Milvus vector database", - }, - "batch-inference": { - "engine": "argo", - "template": "batch-inference", - "description": "Run batch LLM inference on a dataset", - }, - "rag-query": { - "engine": "kubeflow", - "pipeline_id": "rag-pipeline", - "description": "Execute RAG query pipeline", - }, - "voice-pipeline": { - "engine": "kubeflow", - "pipeline_id": "voice-pipeline", - "description": "Full voice assistant pipeline", - }, - "model-evaluation": { - "engine": "argo", - "template": "model-evaluation", - "description": "Evaluate model performance", - }, -} - - -class PipelineBridge(Handler): - """ - Pipeline trigger handler. 
- - Request format: - { - "request_id": "uuid", - "pipeline": "document-ingestion", - "parameters": {"key": "value"} - } - - Response format: - { - "request_id": "uuid", - "status": "submitted", - "run_id": "workflow-run-id", - "engine": "argo|kubeflow" - } - """ - - def __init__(self): - self.pipeline_settings = PipelineSettings() - super().__init__( - subject="ai.pipeline.trigger", - settings=self.pipeline_settings, - queue_group="pipeline-bridges", - ) - - self._http: Optional[httpx.AsyncClient] = None - - async def setup(self) -> None: - """Initialize HTTP client.""" - logger.info("Initializing pipeline bridge...") - - self._http = httpx.AsyncClient(timeout=60.0) - - logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}") - - async def teardown(self) -> None: - """Clean up HTTP client.""" - if self._http: - await self._http.aclose() - logger.info("Pipeline bridge closed") - - async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]: - """Handle pipeline trigger request.""" - request_id = data.get("request_id", "unknown") - pipeline_name = data.get("pipeline", "") - parameters = data.get("parameters", {}) - - logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}") - - with create_span("pipeline.trigger") as span: - if span: - span.set_attribute("request.id", request_id) - span.set_attribute("pipeline.name", pipeline_name) - - # Validate pipeline - if pipeline_name not in PIPELINES: - error = f"Unknown pipeline: {pipeline_name}" - logger.error(error) - return { - "request_id": request_id, - "status": "error", - "error": error, - "available_pipelines": list(PIPELINES.keys()), - } - - pipeline = PIPELINES[pipeline_name] - engine = pipeline["engine"] - - try: - if engine == "argo": - run_id = await self._submit_argo( - pipeline["template"], parameters, request_id - ) - else: - run_id = await self._submit_kubeflow( - pipeline["pipeline_id"], parameters, request_id - ) - - result = { - "request_id": 
request_id, - "status": "submitted", - "run_id": run_id, - "engine": engine, - "pipeline": pipeline_name, - "submitted_at": datetime.utcnow().isoformat(), - } - - # Publish status update - await self.nats.publish( - f"ai.pipeline.status.{request_id}", result - ) - - logger.info(f"Pipeline {pipeline_name} submitted: {run_id}") - return result - - except Exception as e: - logger.exception(f"Failed to submit pipeline {pipeline_name}") - return { - "request_id": request_id, - "status": "error", - "error": str(e), - } - - async def _submit_argo( - self, template: str, parameters: dict, request_id: str - ) -> str: - """Submit workflow to Argo Workflows.""" - with create_span("pipeline.submit.argo") as span: - if span: - span.set_attribute("argo.template", template) - - workflow = { - "apiVersion": "argoproj.io/v1alpha1", - "kind": "Workflow", - "metadata": { - "generateName": f"{template}-", - "namespace": self.pipeline_settings.argo_namespace, - "labels": { - "request-id": request_id, - }, - }, - "spec": { - "workflowTemplateRef": {"name": template}, - "arguments": { - "parameters": [ - {"name": k, "value": str(v)} - for k, v in parameters.items() - ] - }, - }, - } - - response = await self._http.post( - f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}", - json={"workflow": workflow}, - ) - response.raise_for_status() - - result = response.json() - return result["metadata"]["name"] - - async def _submit_kubeflow( - self, pipeline_id: str, parameters: dict, request_id: str - ) -> str: - """Submit run to Kubeflow Pipelines.""" - with create_span("pipeline.submit.kubeflow") as span: - if span: - span.set_attribute("kubeflow.pipeline_id", pipeline_id) - - run_request = { - "name": f"{pipeline_id}-{request_id[:8]}", - "pipeline_spec": { - "pipeline_id": pipeline_id, - "parameters": [ - {"name": k, "value": str(v)} - for k, v in parameters.items() - ], - }, - } - - response = await self._http.post( - 
f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs", - json=run_request, - ) - response.raise_for_status() - - result = response.json() - return result["run"]["id"] - - -if __name__ == "__main__": - PipelineBridge().run() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6d5da00 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "pipeline-bridge" +version = "1.0.0" +description = "Bridge NATS events to Kubeflow Pipelines and Argo Workflows" +readme = "README.md" +requires-python = ">=3.11" +license = { text = "MIT" } +authors = [{ name = "Davies Tech Labs" }] + +dependencies = [ + "handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git", + "httpx>=0.27.0", + "kubernetes>=28.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "ruff>=0.1.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["."] +only-include = ["pipeline_bridge.py"] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" +filterwarnings = ["ignore::DeprecationWarning"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index cf6b7b2..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -nats-py -httpx -kubernetes diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..a4a56cb --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Pipeline Bridge Tests diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..53bad0f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +""" +Pytest configuration and fixtures for pipeline-bridge tests. 
+""" +import asyncio +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Set test environment variables before importing +os.environ.setdefault("NATS_URL", "nats://localhost:4222") +os.environ.setdefault("OTEL_ENABLED", "false") +os.environ.setdefault("MLFLOW_ENABLED", "false") + + +@pytest.fixture(scope="session") +def event_loop(): + """Create event loop for async tests.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def mock_nats_message(): + """Create a mock NATS message.""" + msg = MagicMock() + msg.subject = "ai.pipeline.trigger" + msg.reply = "ai.pipeline.status.test-123" + return msg + + +@pytest.fixture +def argo_pipeline_request(): + """Sample Argo pipeline trigger request.""" + return { + "request_id": "test-request-123", + "pipeline": "document-ingestion", + "parameters": { + "source_path": "s3://bucket/documents", + "collection_name": "test_collection", + }, + } + + +@pytest.fixture +def kubeflow_pipeline_request(): + """Sample Kubeflow pipeline trigger request.""" + return { + "request_id": "test-request-456", + "pipeline": "rag-query", + "parameters": { + "query": "What is AI?", + "collection": "documents", + }, + } + + +@pytest.fixture +def unknown_pipeline_request(): + """Request for unknown pipeline.""" + return { + "request_id": "test-request-789", + "pipeline": "nonexistent-pipeline", + "parameters": {}, + } + + +@pytest.fixture +def mock_argo_response(): + """Mock Argo Workflows API response.""" + return { + "metadata": { + "name": "document-ingestion-abc123", + "namespace": "ai-ml", + }, + "status": {"phase": "Pending"}, + } + + +@pytest.fixture +def mock_kubeflow_response(): + """Mock Kubeflow Pipelines API response.""" + return { + "run": { + "id": "run-xyz-789", + "name": "rag-query-test", + "status": "Running", + } + } diff --git a/tests/test_pipeline_bridge.py b/tests/test_pipeline_bridge.py new file mode 100644 index 0000000..1e819d4 --- /dev/null +++ 
b/tests/test_pipeline_bridge.py @@ -0,0 +1,271 @@ +""" +Unit tests for PipelineBridge handler. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES + + +class TestPipelineSettings: + """Tests for PipelineSettings configuration.""" + + def test_default_settings(self): + """Test default settings values.""" + settings = PipelineSettings() + + assert settings.service_name == "pipeline-bridge" + assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888" + assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746" + assert settings.argo_namespace == "ai-ml" + + def test_custom_settings(self): + """Test custom settings.""" + settings = PipelineSettings( + kubeflow_host="http://custom-kubeflow:8888", + argo_namespace="custom-ns", + ) + + assert settings.kubeflow_host == "http://custom-kubeflow:8888" + assert settings.argo_namespace == "custom-ns" + + +class TestPipelineDefinitions: + """Tests for pipeline definitions.""" + + def test_required_pipelines_exist(self): + """Test that required pipelines are defined.""" + required = ["document-ingestion", "batch-inference", "rag-query", "voice-pipeline"] + for name in required: + assert name in PIPELINES, f"Pipeline {name} should be defined" + + def test_argo_pipelines_have_template(self): + """Test Argo pipelines have template field.""" + for name, config in PIPELINES.items(): + if config["engine"] == "argo": + assert "template" in config, f"Argo pipeline {name} missing template" + + def test_kubeflow_pipelines_have_pipeline_id(self): + """Test Kubeflow pipelines have pipeline_id field.""" + for name, config in PIPELINES.items(): + if config["engine"] == "kubeflow": + assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id" + + def test_all_pipelines_have_description(self): + """Test all pipelines have descriptions.""" + for name, config in PIPELINES.items(): + 
assert "description" in config, f"Pipeline {name} missing description"


def _json_response(payload):
    """Build a stand-in HTTP response whose ``json()`` returns *payload*.

    A plain ``MagicMock`` (not ``AsyncMock``) is used so that ``json()`` and
    ``raise_for_status()`` are ordinary synchronous calls, matching how these
    tests have always stubbed the response object.
    """
    response = MagicMock()
    response.json.return_value = payload
    response.raise_for_status = MagicMock()
    return response


class TestPipelineBridge:
    """Tests for PipelineBridge handler."""

    @pytest.fixture
    def handler(self):
        """Create handler with mocked HTTP client and mocked NATS client."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        handler.nats = AsyncMock()
        return handler

    def test_init(self, handler):
        """Test handler initialization."""
        assert handler.subject == "ai.pipeline.trigger"
        assert handler.queue_group == "pipeline-bridges"
        assert handler.pipeline_settings.service_name == "pipeline-bridge"

    @pytest.mark.asyncio
    async def test_handle_unknown_pipeline(
        self,
        handler,
        mock_nats_message,
        unknown_pipeline_request,
    ):
        """Test handling unknown pipeline."""
        result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)

        assert result["status"] == "error"
        assert "Unknown pipeline" in result["error"]
        assert "available_pipelines" in result
        assert "document-ingestion" in result["available_pipelines"]

    @pytest.mark.asyncio
    async def test_handle_argo_pipeline(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """Test triggering Argo workflow."""
        handler._http.post.return_value = _json_response(mock_argo_response)

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "argo"
        assert result["run_id"] == "document-ingestion-abc123"
        assert result["pipeline"] == "document-ingestion"
        assert "submitted_at" in result

        # Verify API call. ``assert_awaited_once`` (rather than
        # ``assert_called_once``) also proves the coroutine was awaited.
        handler._http.post.assert_awaited_once()
        call_args = handler._http.post.call_args
        assert "argo-server" in str(call_args)

    @pytest.mark.asyncio
    async def test_handle_kubeflow_pipeline(
        self,
        handler,
        mock_nats_message,
        kubeflow_pipeline_request,
        mock_kubeflow_response,
    ):
        """Test triggering Kubeflow pipeline."""
        handler._http.post.return_value = _json_response(mock_kubeflow_response)

        result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "kubeflow"
        assert result["run_id"] == "run-xyz-789"
        assert result["pipeline"] == "rag-query"

        # Verify API call
        handler._http.post.assert_awaited_once()
        call_args = handler._http.post.call_args
        assert "ml-pipeline" in str(call_args)

    @pytest.mark.asyncio
    async def test_handle_api_error(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
    ):
        """Test handling API errors."""
        handler._http.post.side_effect = Exception("Connection refused")

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "error"
        assert "Connection refused" in result["error"]

    @pytest.mark.asyncio
    async def test_publishes_status_update(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """Test that status is published to NATS."""
        handler._http.post.return_value = _json_response(mock_argo_response)

        await handler.handle_message(mock_nats_message, argo_pipeline_request)

        handler.nats.publish.assert_called_once()
        call_args = handler.nats.publish.call_args
        assert "ai.pipeline.status.test-request-123" in str(call_args)

    @pytest.mark.asyncio
    async def test_setup_creates_http_client(self):
        """Test that setup initializes HTTP client."""
        with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
            handler = PipelineBridge()
            await handler.setup()

            mock_client.assert_called_once()

    @pytest.mark.asyncio
    async def test_teardown_closes_http_client(self, handler):
        """Test that teardown closes HTTP client."""
        # Keep our own reference: teardown is free to clear ``handler._http``,
        # and the assertion must still inspect the client that was closed.
        http_client = handler._http

        await handler.teardown()

        http_client.aclose.assert_awaited_once()


class TestArgoSubmission:
    """Tests for Argo workflow submission."""

    @pytest.fixture
    def handler(self):
        """Create handler with mocked HTTP client."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        return handler

    @pytest.mark.asyncio
    async def test_argo_workflow_structure(
        self,
        handler,
        mock_argo_response,
    ):
        """Test Argo workflow request structure."""
        handler._http.post.return_value = _json_response(mock_argo_response)

        await handler._submit_argo(
            template="document-ingestion",
            parameters={"key": "value"},
            request_id="test-123",
        )

        # Verify workflow structure
        call_kwargs = handler._http.post.call_args.kwargs
        workflow = call_kwargs["json"]["workflow"]

        assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
        assert workflow["kind"] == "Workflow"
        assert "workflowTemplateRef" in workflow["spec"]
        assert workflow["spec"]["workflowTemplateRef"]["name"] == "document-ingestion"
        assert workflow["metadata"]["labels"]["request-id"] == "test-123"


class TestKubeflowSubmission:
    """Tests for Kubeflow pipeline submission."""

    @pytest.fixture
    def handler(self):
        """Create handler with mocked HTTP client."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        return handler

    @pytest.mark.asyncio
    async def test_kubeflow_run_structure(
        self,
        handler,
        mock_kubeflow_response,
    ):
        """Test Kubeflow run request structure."""
        handler._http.post.return_value = _json_response(mock_kubeflow_response)

        await handler._submit_kubeflow(
            pipeline_id="rag-pipeline",
            parameters={"query": "test"},
            request_id="test-456",
        )

        # Verify run request structure
        call_kwargs = handler._http.post.call_args.kwargs
        run_request = call_kwargs["json"]

        assert "rag-pipeline" in run_request["name"]
        assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"