refactor: consolidate to handler-base, migrate to pyproject.toml, add tests

This commit is contained in:
2026-02-02 07:11:08 -05:00
parent 50b1835688
commit 3f1e05eaad
10 changed files with 594 additions and 577 deletions

View File

@@ -1,29 +1,13 @@
FROM python:3.13-slim # Pipeline Bridge - Using handler-base
ARG BASE_TAG=latest
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app WORKDIR /app
# Install uv for fast, reliable package management # Install additional dependencies from pyproject.toml
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv COPY pyproject.toml .
RUN uv pip install --system --no-cache httpx kubernetes
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for better caching
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt
# Copy application code
COPY pipeline_bridge.py . COPY pipeline_bridge.py .
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "print('healthy')" || exit 1
# Run the application
CMD ["python", "pipeline_bridge.py"] CMD ["python", "pipeline_bridge.py"]

View File

@@ -1,12 +0,0 @@
# Pipeline Bridge v2 - Using handler-base
ARG BASE_TAG=local
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
WORKDIR /app
# Additional dependency for Kubernetes API
RUN uv pip install --system --no-cache kubernetes>=28.0.0
COPY pipeline_bridge_v2.py ./pipeline_bridge.py
CMD ["python", "pipeline_bridge.py"]

View File

@@ -61,13 +61,9 @@ The bridge publishes status updates as the workflow progresses:
- `failed` - Failed - `failed` - Failed
- `error` - System error - `error` - System error
## Variants ## Implementation
### pipeline_bridge.py (Standalone) The pipeline bridge uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks.
Self-contained service with pip install on startup. Good for simple deployments.
### pipeline_bridge_v2.py (handler-base)
Uses handler-base library for standardized NATS handling, telemetry, and health checks.
## Environment Variables ## Environment Variables
@@ -81,11 +77,10 @@ Uses handler-base library for standardized NATS handling, telemetry, and health
## Building ## Building
```bash ```bash
# Standalone version
docker build -t pipeline-bridge:latest . docker build -t pipeline-bridge:latest .
# handler-base version # With specific handler-base tag
docker build -f Dockerfile.v2 -t pipeline-bridge:v2 --build-arg BASE_TAG=latest . docker build --build-arg BASE_TAG=latest -t pipeline-bridge:latest .
``` ```
## Testing ## Testing

View File

@@ -1,351 +1,241 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Pipeline Bridge Service Pipeline Bridge Service (Refactored)
Bridges NATS events to workflow engines: Bridges NATS events to workflow engines using handler-base:
1. Listen for pipeline triggers on "ai.pipeline.trigger" 1. Listen for pipeline triggers on "ai.pipeline.trigger"
2. Submit to Kubeflow Pipelines or Argo Workflows 2. Submit to Kubeflow Pipelines or Argo Workflows
3. Monitor execution and publish status updates 3. Monitor execution and publish status updates
4. Publish completion to "ai.pipeline.status.{request_id}" 4. Publish completion to "ai.pipeline.status.{request_id}"
Supported pipelines:
- document-ingestion: Ingest documents into Milvus
- batch-inference: Run batch LLM inference
- model-evaluation: Evaluate model performance
""" """
import asyncio
import json
import logging import logging
import os from typing import Any, Optional
import signal
import subprocess
import sys
from typing import Dict, Optional
from datetime import datetime from datetime import datetime
# Install dependencies on startup
subprocess.check_call([
sys.executable, "-m", "pip", "install", "-q",
"-r", "/app/requirements.txt"
])
import httpx import httpx
import nats from nats.aio.msg import Msg
from kubernetes import client, config
from handler_base import Handler, Settings
from handler_base.telemetry import create_span
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("pipeline-bridge") logger = logging.getLogger("pipeline-bridge")
# Configuration from environment
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
KUBEFLOW_HOST = os.environ.get("KUBEFLOW_HOST", "http://ml-pipeline.kubeflow.svc.cluster.local:8888")
ARGO_HOST = os.environ.get("ARGO_HOST", "http://argo-server.argo.svc.cluster.local:2746")
ARGO_NAMESPACE = os.environ.get("ARGO_NAMESPACE", "ai-ml")
# NATS subjects class PipelineSettings(Settings):
TRIGGER_SUBJECT = "ai.pipeline.trigger" """Pipeline bridge specific settings."""
STATUS_SUBJECT = "ai.pipeline.status"
# Pipeline definitions - maps pipeline names to their configurations service_name: str = "pipeline-bridge"
# Kubeflow Pipelines
kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
# Argo Workflows
argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
argo_namespace: str = "ai-ml"
# Pipeline definitions
PIPELINES = { PIPELINES = {
"document-ingestion": { "document-ingestion": {
"engine": "argo", "engine": "argo",
"template": "document-ingestion", "template": "document-ingestion",
"description": "Ingest documents into Milvus vector database" "description": "Ingest documents into Milvus vector database",
}, },
"batch-inference": { "batch-inference": {
"engine": "argo", "engine": "argo",
"template": "batch-inference", "template": "batch-inference",
"description": "Run batch LLM inference on a dataset" "description": "Run batch LLM inference on a dataset",
}, },
"rag-query": { "rag-query": {
"engine": "kubeflow", "engine": "kubeflow",
"pipeline_id": "rag-pipeline", "pipeline_id": "rag-pipeline",
"description": "Execute RAG query pipeline" "description": "Execute RAG query pipeline",
}, },
"voice-pipeline": { "voice-pipeline": {
"engine": "kubeflow", "engine": "kubeflow",
"pipeline_id": "voice-pipeline", "pipeline_id": "voice-pipeline",
"description": "Full voice assistant pipeline" "description": "Full voice assistant pipeline",
} },
"model-evaluation": {
"engine": "argo",
"template": "model-evaluation",
"description": "Evaluate model performance",
},
} }
class PipelineBridge: class PipelineBridge(Handler):
"""
Pipeline trigger handler.
Request format:
{
"request_id": "uuid",
"pipeline": "document-ingestion",
"parameters": {"key": "value"}
}
Response format:
{
"request_id": "uuid",
"status": "submitted",
"run_id": "workflow-run-id",
"engine": "argo|kubeflow"
}
"""
def __init__(self): def __init__(self):
self.nc = None self.pipeline_settings = PipelineSettings()
self.http_client = None super().__init__(
self.running = True subject="ai.pipeline.trigger",
self.active_workflows = {} # Track running workflows settings=self.pipeline_settings,
queue_group="pipeline-bridges",
)
async def setup(self): self._http: Optional[httpx.AsyncClient] = None
"""Initialize all connections."""
# NATS connection
self.nc = await nats.connect(NATS_URL)
logger.info(f"Connected to NATS at {NATS_URL}")
# HTTP client for API calls async def setup(self) -> None:
self.http_client = httpx.AsyncClient(timeout=60.0) """Initialize HTTP client."""
logger.info("Initializing pipeline bridge...")
self._http = httpx.AsyncClient(timeout=60.0)
logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
async def teardown(self) -> None:
"""Clean up HTTP client."""
if self._http:
await self._http.aclose()
logger.info("Pipeline bridge closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
"""Handle pipeline trigger request."""
request_id = data.get("request_id", "unknown")
pipeline_name = data.get("pipeline", "")
parameters = data.get("parameters", {})
logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
with create_span("pipeline.trigger") as span:
if span:
span.set_attribute("request.id", request_id)
span.set_attribute("pipeline.name", pipeline_name)
# Validate pipeline
if pipeline_name not in PIPELINES:
error = f"Unknown pipeline: {pipeline_name}"
logger.error(error)
return {
"request_id": request_id,
"status": "error",
"error": error,
"available_pipelines": list(PIPELINES.keys()),
}
pipeline = PIPELINES[pipeline_name]
engine = pipeline["engine"]
# Initialize Kubernetes client for Argo
try: try:
config.load_incluster_config() if engine == "argo":
self.k8s_custom = client.CustomObjectsApi() run_id = await self._submit_argo(
logger.info("Kubernetes client initialized") pipeline["template"], parameters, request_id
)
else:
run_id = await self._submit_kubeflow(
pipeline["pipeline_id"], parameters, request_id
)
result = {
"request_id": request_id,
"status": "submitted",
"run_id": run_id,
"engine": engine,
"pipeline": pipeline_name,
"submitted_at": datetime.utcnow().isoformat(),
}
# Publish status update
await self.nats.publish(
f"ai.pipeline.status.{request_id}", result
)
logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
return result
except Exception as e: except Exception as e:
logger.warning(f"Kubernetes client failed: {e}") logger.exception(f"Failed to submit pipeline {pipeline_name}")
self.k8s_custom = None return {
"request_id": request_id,
"status": "error",
"error": str(e),
}
async def submit_argo_workflow(self, template: str, parameters: Dict, request_id: str) -> Optional[str]: async def _submit_argo(
"""Submit an Argo Workflow from a WorkflowTemplate.""" self, template: str, parameters: dict, request_id: str
if not self.k8s_custom: ) -> str:
logger.error("Kubernetes client not available") """Submit workflow to Argo Workflows."""
return None with create_span("pipeline.submit.argo") as span:
if span:
span.set_attribute("argo.template", template)
try:
# Create workflow from template
workflow = { workflow = {
"apiVersion": "argoproj.io/v1alpha1", "apiVersion": "argoproj.io/v1alpha1",
"kind": "Workflow", "kind": "Workflow",
"metadata": { "metadata": {
"generateName": f"{template}-", "generateName": f"{template}-",
"namespace": ARGO_NAMESPACE, "namespace": self.pipeline_settings.argo_namespace,
"labels": { "labels": {
"app.kubernetes.io/managed-by": "pipeline-bridge", "request-id": request_id,
"pipeline-bridge/request-id": request_id },
}
}, },
"spec": { "spec": {
"workflowTemplateRef": { "workflowTemplateRef": {"name": template},
"name": template
},
"arguments": { "arguments": {
"parameters": [ "parameters": [
{"name": k, "value": str(v)} {"name": k, "value": str(v)}
for k, v in parameters.items() for k, v in parameters.items()
] ]
} },
} },
} }
result = self.k8s_custom.create_namespaced_custom_object( response = await self._http.post(
group="argoproj.io", f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
version="v1alpha1", json={"workflow": workflow},
namespace=ARGO_NAMESPACE,
plural="workflows",
body=workflow
) )
response.raise_for_status()
workflow_name = result["metadata"]["name"] result = response.json()
logger.info(f"Submitted Argo workflow: {workflow_name}") return result["metadata"]["name"]
return workflow_name
except Exception as e: async def _submit_kubeflow(
logger.error(f"Failed to submit Argo workflow: {e}") self, pipeline_id: str, parameters: dict, request_id: str
return None ) -> str:
"""Submit run to Kubeflow Pipelines."""
with create_span("pipeline.submit.kubeflow") as span:
if span:
span.set_attribute("kubeflow.pipeline_id", pipeline_id)
async def submit_kubeflow_pipeline(self, pipeline_id: str, parameters: Dict, request_id: str) -> Optional[str]:
"""Submit a Kubeflow Pipeline run."""
try:
# Create pipeline run via Kubeflow API
run_request = { run_request = {
"name": f"{pipeline_id}-{request_id[:8]}", "name": f"{pipeline_id}-{request_id[:8]}",
"pipeline_spec": { "pipeline_spec": {
"pipeline_id": pipeline_id "pipeline_id": pipeline_id,
},
"resource_references": [],
"parameters": [ "parameters": [
{"name": k, "value": str(v)} {"name": k, "value": str(v)}
for k, v in parameters.items() for k, v in parameters.items()
] ],
},
} }
response = await self.http_client.post( response = await self._http.post(
f"{KUBEFLOW_HOST}/apis/v1beta1/runs", f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
json=run_request json=run_request,
) )
response.raise_for_status()
if response.status_code == 200:
result = response.json() result = response.json()
run_id = result.get("run", {}).get("id") return result["run"]["id"]
logger.info(f"Submitted Kubeflow pipeline run: {run_id}")
return run_id
else:
logger.error(f"Kubeflow API error: {response.status_code} - {response.text}")
return None
except Exception as e:
logger.error(f"Failed to submit Kubeflow pipeline: {e}")
return None
async def get_argo_workflow_status(self, workflow_name: str) -> Dict:
"""Get status of an Argo Workflow."""
if not self.k8s_custom:
return {"phase": "Unknown", "message": "Kubernetes client not available"}
try:
result = self.k8s_custom.get_namespaced_custom_object(
group="argoproj.io",
version="v1alpha1",
namespace=ARGO_NAMESPACE,
plural="workflows",
name=workflow_name
)
status = result.get("status", {})
return {
"phase": status.get("phase", "Pending"),
"message": status.get("message", ""),
"startedAt": status.get("startedAt"),
"finishedAt": status.get("finishedAt"),
"progress": status.get("progress", "0/0")
}
except Exception as e:
logger.error(f"Failed to get workflow status: {e}")
return {"phase": "Error", "message": str(e)}
async def process_trigger(self, msg):
"""Process a pipeline trigger request."""
try:
data = json.loads(msg.data.decode())
request_id = data.get("request_id", "unknown")
pipeline_name = data.get("pipeline", "")
parameters = data.get("parameters", {})
logger.info(f"Processing pipeline trigger {request_id}: {pipeline_name}")
# Validate pipeline
if pipeline_name not in PIPELINES:
await self.publish_status(request_id, {
"status": "error",
"message": f"Unknown pipeline: {pipeline_name}",
"available_pipelines": list(PIPELINES.keys())
})
return
pipeline_config = PIPELINES[pipeline_name]
engine = pipeline_config["engine"]
# Submit to appropriate engine
run_id = None
if engine == "argo":
run_id = await self.submit_argo_workflow(
pipeline_config["template"],
parameters,
request_id
)
elif engine == "kubeflow":
run_id = await self.submit_kubeflow_pipeline(
pipeline_config["pipeline_id"],
parameters,
request_id
)
if run_id:
# Track workflow for status updates
self.active_workflows[request_id] = {
"engine": engine,
"run_id": run_id,
"started_at": datetime.utcnow().isoformat()
}
await self.publish_status(request_id, {
"status": "submitted",
"pipeline": pipeline_name,
"engine": engine,
"run_id": run_id,
"message": f"Pipeline submitted successfully"
})
else:
await self.publish_status(request_id, {
"status": "error",
"pipeline": pipeline_name,
"message": "Failed to submit pipeline"
})
except Exception as e:
logger.error(f"Trigger processing failed: {e}")
await self.publish_status(
data.get("request_id", "unknown"),
{"status": "error", "message": str(e)}
)
async def publish_status(self, request_id: str, status: Dict):
"""Publish pipeline status update."""
status["request_id"] = request_id
status["timestamp"] = datetime.utcnow().isoformat()
await self.nc.publish(
f"{STATUS_SUBJECT}.{request_id}",
json.dumps(status).encode()
)
logger.info(f"Published status for {request_id}: {status.get('status')}")
async def monitor_workflows(self):
"""Periodically check and publish status of active workflows."""
while self.running:
completed = []
for request_id, workflow in self.active_workflows.items():
try:
if workflow["engine"] == "argo":
status = await self.get_argo_workflow_status(workflow["run_id"])
# Publish status update
await self.publish_status(request_id, {
"status": status["phase"].lower(),
"run_id": workflow["run_id"],
"progress": status.get("progress"),
"message": status.get("message", "")
})
# Check if completed
if status["phase"] in ["Succeeded", "Failed", "Error"]:
completed.append(request_id)
except Exception as e:
logger.error(f"Error monitoring workflow {request_id}: {e}")
# Remove completed workflows from tracking
for request_id in completed:
del self.active_workflows[request_id]
await asyncio.sleep(10) # Check every 10 seconds
async def run(self):
"""Main run loop."""
await self.setup()
# Subscribe to pipeline triggers
sub = await self.nc.subscribe(TRIGGER_SUBJECT, cb=self.process_trigger)
logger.info(f"Subscribed to {TRIGGER_SUBJECT}")
# Start workflow monitor
monitor_task = asyncio.create_task(self.monitor_workflows())
# Handle shutdown
def signal_handler():
self.running = False
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Keep running
while self.running:
await asyncio.sleep(1)
# Cleanup
monitor_task.cancel()
await sub.unsubscribe()
await self.nc.close()
logger.info("Shutdown complete")
if __name__ == "__main__": if __name__ == "__main__":
bridge = PipelineBridge() PipelineBridge().run()
asyncio.run(bridge.run())

View File

@@ -1,241 +0,0 @@
#!/usr/bin/env python3
"""
Pipeline Bridge Service (Refactored)
Bridges NATS events to workflow engines using handler-base:
1. Listen for pipeline triggers on "ai.pipeline.trigger"
2. Submit to Kubeflow Pipelines or Argo Workflows
3. Monitor execution and publish status updates
4. Publish completion to "ai.pipeline.status.{request_id}"
"""
import logging
from datetime import datetime, timezone
from typing import Any, Optional
import httpx
from nats.aio.msg import Msg
from handler_base import Handler, Settings
from handler_base.telemetry import create_span
logger = logging.getLogger("pipeline-bridge")
class PipelineSettings(Settings):
    """Pipeline bridge specific settings.

    Extends the handler-base ``Settings`` class; presumably each field can
    also be overridden via environment variables — TODO confirm against the
    handler-base Settings implementation.
    """

    # Service identity reported by the base handler (logs/telemetry).
    service_name: str = "pipeline-bridge"
    # Kubeflow Pipelines
    kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
    # Argo Workflows
    argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
    # Namespace that Argo Workflows are created in.
    argo_namespace: str = "ai-ml"
# Pipeline definitions: maps the pipeline name received in a NATS trigger
# message to the engine that runs it plus the engine-specific identifier —
# "template" for Argo WorkflowTemplates, "pipeline_id" for Kubeflow runs.
PIPELINES = {
    "document-ingestion": {
        "engine": "argo",
        "template": "document-ingestion",
        "description": "Ingest documents into Milvus vector database",
    },
    "batch-inference": {
        "engine": "argo",
        "template": "batch-inference",
        "description": "Run batch LLM inference on a dataset",
    },
    "rag-query": {
        "engine": "kubeflow",
        "pipeline_id": "rag-pipeline",
        "description": "Execute RAG query pipeline",
    },
    "voice-pipeline": {
        "engine": "kubeflow",
        "pipeline_id": "voice-pipeline",
        "description": "Full voice assistant pipeline",
    },
    "model-evaluation": {
        "engine": "argo",
        "template": "model-evaluation",
        "description": "Evaluate model performance",
    },
}
class PipelineBridge(Handler):
    """
    Pipeline trigger handler.

    Listens on "ai.pipeline.trigger", submits the requested pipeline to
    Argo Workflows or Kubeflow Pipelines, and publishes the result on
    "ai.pipeline.status.{request_id}".

    Request format:
    {
        "request_id": "uuid",
        "pipeline": "document-ingestion",
        "parameters": {"key": "value"}
    }
    Response format:
    {
        "request_id": "uuid",
        "status": "submitted",
        "run_id": "workflow-run-id",
        "engine": "argo|kubeflow"
    }
    """

    def __init__(self):
        # Settings are created first so they can be handed to the base
        # Handler; the HTTP client is created lazily in setup().
        self.pipeline_settings = PipelineSettings()
        super().__init__(
            subject="ai.pipeline.trigger",
            settings=self.pipeline_settings,
            queue_group="pipeline-bridges",
        )
        self._http: Optional[httpx.AsyncClient] = None

    async def setup(self) -> None:
        """Initialize the shared HTTP client used for both engine APIs."""
        logger.info("Initializing pipeline bridge...")
        self._http = httpx.AsyncClient(timeout=60.0)
        logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")

    async def teardown(self) -> None:
        """Clean up HTTP client."""
        if self._http:
            await self._http.aclose()
        logger.info("Pipeline bridge closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle pipeline trigger request.

        Returns a result dict with status "submitted" or "error"; on
        success the same payload is also published to
        "ai.pipeline.status.{request_id}".
        """
        request_id = data.get("request_id", "unknown")
        pipeline_name = data.get("pipeline", "")
        parameters = data.get("parameters", {})
        logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
        with create_span("pipeline.trigger") as span:
            if span:
                span.set_attribute("request.id", request_id)
                span.set_attribute("pipeline.name", pipeline_name)
            # Validate pipeline
            if pipeline_name not in PIPELINES:
                error = f"Unknown pipeline: {pipeline_name}"
                logger.error(error)
                return {
                    "request_id": request_id,
                    "status": "error",
                    "error": error,
                    "available_pipelines": list(PIPELINES.keys()),
                }
            pipeline = PIPELINES[pipeline_name]
            engine = pipeline["engine"]
            try:
                # Dispatch explicitly per engine; a misconfigured PIPELINES
                # entry with an unknown engine is reported as an error
                # instead of being silently sent to Kubeflow.
                if engine == "argo":
                    run_id = await self._submit_argo(
                        pipeline["template"], parameters, request_id
                    )
                elif engine == "kubeflow":
                    run_id = await self._submit_kubeflow(
                        pipeline["pipeline_id"], parameters, request_id
                    )
                else:
                    raise ValueError(f"Unsupported engine: {engine}")
                result = {
                    "request_id": request_id,
                    "status": "submitted",
                    "run_id": run_id,
                    "engine": engine,
                    "pipeline": pipeline_name,
                    # Timezone-aware UTC: datetime.utcnow() is deprecated
                    # since Python 3.12 and this project targets >=3.11.
                    "submitted_at": datetime.now(timezone.utc).isoformat(),
                }
                # Publish status update
                await self.nats.publish(
                    f"ai.pipeline.status.{request_id}", result
                )
                logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
                return result
            except Exception as e:
                logger.exception(f"Failed to submit pipeline {pipeline_name}")
                return {
                    "request_id": request_id,
                    "status": "error",
                    "error": str(e),
                }

    async def _submit_argo(
        self, template: str, parameters: dict, request_id: str
    ) -> str:
        """Submit workflow to Argo Workflows.

        Creates a Workflow from the named WorkflowTemplate via the Argo
        server REST API and returns the generated workflow name.

        Raises:
            httpx.HTTPStatusError: if the Argo API rejects the submission.
        """
        with create_span("pipeline.submit.argo") as span:
            if span:
                span.set_attribute("argo.template", template)
            workflow = {
                "apiVersion": "argoproj.io/v1alpha1",
                "kind": "Workflow",
                "metadata": {
                    "generateName": f"{template}-",
                    "namespace": self.pipeline_settings.argo_namespace,
                    # Label the workflow so it can be correlated back to
                    # the originating request.
                    "labels": {
                        "request-id": request_id,
                    },
                },
                "spec": {
                    "workflowTemplateRef": {"name": template},
                    "arguments": {
                        "parameters": [
                            {"name": k, "value": str(v)}
                            for k, v in parameters.items()
                        ]
                    },
                },
            }
            response = await self._http.post(
                f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
                json={"workflow": workflow},
            )
            response.raise_for_status()
            result = response.json()
            return result["metadata"]["name"]

    async def _submit_kubeflow(
        self, pipeline_id: str, parameters: dict, request_id: str
    ) -> str:
        """Submit run to Kubeflow Pipelines.

        Creates a run via the Kubeflow Pipelines v1beta1 REST API and
        returns the run id.

        Raises:
            httpx.HTTPStatusError: if the Kubeflow API rejects the run.
        """
        with create_span("pipeline.submit.kubeflow") as span:
            if span:
                span.set_attribute("kubeflow.pipeline_id", pipeline_id)
            run_request = {
                # Truncated request id keeps the run name short but traceable.
                "name": f"{pipeline_id}-{request_id[:8]}",
                "pipeline_spec": {
                    "pipeline_id": pipeline_id,
                    "parameters": [
                        {"name": k, "value": str(v)}
                        for k, v in parameters.items()
                    ],
                },
            }
            response = await self._http.post(
                f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
                json=run_request,
            )
            response.raise_for_status()
            result = response.json()
            return result["run"]["id"]
if __name__ == "__main__":
PipelineBridge().run()

42
pyproject.toml Normal file
View File

@@ -0,0 +1,42 @@
[project]
name = "pipeline-bridge"
version = "1.0.0"
description = "Bridge NATS events to Kubeflow Pipelines and Argo Workflows"
readme = "README.md"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Davies Tech Labs" }]
dependencies = [
"handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git",
"httpx>=0.27.0",
"kubernetes>=28.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"ruff>=0.1.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
only-include = ["pipeline_bridge.py"]
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = ["ignore::DeprecationWarning"]

View File

@@ -1,3 +0,0 @@
nats-py
httpx
kubernetes

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Pipeline Bridge Tests

90
tests/conftest.py Normal file
View File

@@ -0,0 +1,90 @@
"""
Pytest configuration and fixtures for pipeline-bridge tests.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Set test environment variables before importing
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
@pytest.fixture(scope="session")
def event_loop():
    """Create event loop for async tests.

    NOTE(review): overriding the ``event_loop`` fixture is deprecated in
    pytest-asyncio >= 0.23 (the version pinned in pyproject.toml) — confirm
    whether this override is still required, or migrate to
    ``event_loop_policy`` / loop-scope configuration.
    """
    loop = asyncio.new_event_loop()
    yield loop
    # Session teardown: close the loop so no warning is emitted at exit.
    loop.close()
@pytest.fixture
def mock_nats_message():
    """Create a mock NATS message."""
    # MagicMock keyword arguments become attributes on the mock, giving the
    # same subject/reply shape a real trigger message would carry.
    return MagicMock(
        subject="ai.pipeline.trigger",
        reply="ai.pipeline.status.test-123",
    )
@pytest.fixture
def argo_pipeline_request():
    """Sample Argo pipeline trigger request."""
    parameters = {
        "source_path": "s3://bucket/documents",
        "collection_name": "test_collection",
    }
    return {
        "request_id": "test-request-123",
        "pipeline": "document-ingestion",
        "parameters": parameters,
    }
@pytest.fixture
def kubeflow_pipeline_request():
    """Sample Kubeflow pipeline trigger request."""
    parameters = {
        "query": "What is AI?",
        "collection": "documents",
    }
    return {
        "request_id": "test-request-456",
        "pipeline": "rag-query",
        "parameters": parameters,
    }
@pytest.fixture
def unknown_pipeline_request():
    """Request for unknown pipeline."""
    # All keys are valid identifiers, so dict() keyword form is equivalent.
    return dict(
        request_id="test-request-789",
        pipeline="nonexistent-pipeline",
        parameters={},
    )
@pytest.fixture
def mock_argo_response():
    """Mock Argo Workflows API response."""
    metadata = {
        "name": "document-ingestion-abc123",
        "namespace": "ai-ml",
    }
    return {"metadata": metadata, "status": {"phase": "Pending"}}
@pytest.fixture
def mock_kubeflow_response():
    """Mock Kubeflow Pipelines API response."""
    run = {
        "id": "run-xyz-789",
        "name": "rag-query-test",
        "status": "Running",
    }
    return {"run": run}

View File

@@ -0,0 +1,271 @@
"""
Unit tests for PipelineBridge handler.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from pipeline_bridge import PipelineBridge, PipelineSettings, PIPELINES
class TestPipelineSettings:
    """Tests for PipelineSettings configuration."""

    def test_default_settings(self):
        """Every setting falls back to its documented default."""
        expected = {
            "service_name": "pipeline-bridge",
            "kubeflow_host": "http://ml-pipeline.kubeflow.svc.cluster.local:8888",
            "argo_host": "http://argo-server.argo.svc.cluster.local:2746",
            "argo_namespace": "ai-ml",
        }
        settings = PipelineSettings()
        for field, value in expected.items():
            assert getattr(settings, field) == value

    def test_custom_settings(self):
        """Explicit keyword overrides take precedence over defaults."""
        overrides = {
            "kubeflow_host": "http://custom-kubeflow:8888",
            "argo_namespace": "custom-ns",
        }
        settings = PipelineSettings(**overrides)
        assert settings.kubeflow_host == overrides["kubeflow_host"]
        assert settings.argo_namespace == overrides["argo_namespace"]
class TestPipelineDefinitions:
    """Tests for pipeline definitions."""

    def test_required_pipelines_exist(self):
        """Test that required pipelines are defined."""
        for name in ("document-ingestion", "batch-inference", "rag-query", "voice-pipeline"):
            assert name in PIPELINES, f"Pipeline {name} should be defined"

    def test_argo_pipelines_have_template(self):
        """Test Argo pipelines have template field."""
        argo_pipelines = (
            (name, config)
            for name, config in PIPELINES.items()
            if config["engine"] == "argo"
        )
        for name, config in argo_pipelines:
            assert "template" in config, f"Argo pipeline {name} missing template"

    def test_kubeflow_pipelines_have_pipeline_id(self):
        """Test Kubeflow pipelines have pipeline_id field."""
        kubeflow_pipelines = (
            (name, config)
            for name, config in PIPELINES.items()
            if config["engine"] == "kubeflow"
        )
        for name, config in kubeflow_pipelines:
            assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id"

    def test_all_pipelines_have_description(self):
        """Test all pipelines have descriptions."""
        for name, config in PIPELINES.items():
            assert "description" in config, f"Pipeline {name} missing description"
class TestPipelineBridge:
    """Tests for PipelineBridge handler.

    The handler fixture replaces the HTTP client and NATS connection with
    AsyncMock, so no network access happens in any test here.
    """

    @pytest.fixture
    def handler(self):
        """Create handler with mocked HTTP client."""
        # NOTE(review): assumes PipelineBridge() performs no I/O in
        # __init__ (connections are opened in setup()) — confirm against
        # handler-base's Handler constructor.
        handler = PipelineBridge()
        handler._http = AsyncMock()
        handler.nats = AsyncMock()
        return handler

    def test_init(self, handler):
        """Test handler initialization."""
        assert handler.subject == "ai.pipeline.trigger"
        assert handler.queue_group == "pipeline-bridges"
        assert handler.pipeline_settings.service_name == "pipeline-bridge"

    @pytest.mark.asyncio
    async def test_handle_unknown_pipeline(
        self,
        handler,
        mock_nats_message,
        unknown_pipeline_request,
    ):
        """Test handling unknown pipeline."""
        result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)
        # Unknown pipelines produce an error payload listing valid names.
        assert result["status"] == "error"
        assert "Unknown pipeline" in result["error"]
        assert "available_pipelines" in result
        assert "document-ingestion" in result["available_pipelines"]

    @pytest.mark.asyncio
    async def test_handle_argo_pipeline(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """Test triggering Argo workflow."""
        # Setup mock response: json() yields the canned Argo payload and
        # raise_for_status() is a no-op, simulating a 2xx response.
        mock_response = MagicMock()
        mock_response.json.return_value = mock_argo_response
        mock_response.raise_for_status = MagicMock()
        handler._http.post.return_value = mock_response
        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
        assert result["status"] == "submitted"
        assert result["engine"] == "argo"
        # run_id comes from metadata.name in the mocked Argo response.
        assert result["run_id"] == "document-ingestion-abc123"
        assert result["pipeline"] == "document-ingestion"
        assert "submitted_at" in result
        # Verify API call went to the Argo server host.
        handler._http.post.assert_called_once()
        call_args = handler._http.post.call_args
        assert "argo-server" in str(call_args)

    @pytest.mark.asyncio
    async def test_handle_kubeflow_pipeline(
        self,
        handler,
        mock_nats_message,
        kubeflow_pipeline_request,
        mock_kubeflow_response,
    ):
        """Test triggering Kubeflow pipeline."""
        # Setup mock response (2xx with the canned Kubeflow payload).
        mock_response = MagicMock()
        mock_response.json.return_value = mock_kubeflow_response
        mock_response.raise_for_status = MagicMock()
        handler._http.post.return_value = mock_response
        result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)
        assert result["status"] == "submitted"
        assert result["engine"] == "kubeflow"
        # run_id comes from run.id in the mocked Kubeflow response.
        assert result["run_id"] == "run-xyz-789"
        assert result["pipeline"] == "rag-query"
        # Verify API call went to the Kubeflow (ml-pipeline) host.
        handler._http.post.assert_called_once()
        call_args = handler._http.post.call_args
        assert "ml-pipeline" in str(call_args)

    @pytest.mark.asyncio
    async def test_handle_api_error(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
    ):
        """Test handling API errors."""
        # Any exception raised during submission must surface as an error
        # result rather than propagate out of handle_message.
        handler._http.post.side_effect = Exception("Connection refused")
        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)
        assert result["status"] == "error"
        assert "Connection refused" in result["error"]

    @pytest.mark.asyncio
    async def test_publishes_status_update(
        self,
        handler,
        mock_nats_message,
        argo_pipeline_request,
        mock_argo_response,
    ):
        """Test that status is published to NATS."""
        mock_response = MagicMock()
        mock_response.json.return_value = mock_argo_response
        mock_response.raise_for_status = MagicMock()
        handler._http.post.return_value = mock_response
        await handler.handle_message(mock_nats_message, argo_pipeline_request)
        # A successful submission publishes to the per-request status subject.
        handler.nats.publish.assert_called_once()
        call_args = handler.nats.publish.call_args
        assert "ai.pipeline.status.test-request-123" in str(call_args)

    @pytest.mark.asyncio
    async def test_setup_creates_http_client(self):
        """Test that setup initializes HTTP client."""
        # Patch the class referenced inside pipeline_bridge so the real
        # httpx client is never constructed.
        with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
            handler = PipelineBridge()
            await handler.setup()
            mock_client.assert_called_once()

    @pytest.mark.asyncio
    async def test_teardown_closes_http_client(self, handler):
        """Test that teardown closes HTTP client."""
        await handler.teardown()
        handler._http.aclose.assert_called_once()
class TestArgoSubmission:
    """Tests for Argo workflow submission."""

    @pytest.fixture
    def handler(self):
        """Create handler with mocked HTTP client."""
        handler = PipelineBridge()
        handler._http = AsyncMock()
        return handler

    @pytest.mark.asyncio
    async def test_argo_workflow_structure(
        self,
        handler,
        mock_argo_response,
    ):
        """Test Argo workflow request structure."""
        # Simulate a successful 2xx response from the Argo server.
        mock_response = MagicMock()
        mock_response.json.return_value = mock_argo_response
        mock_response.raise_for_status = MagicMock()
        handler._http.post.return_value = mock_response
        await handler._submit_argo(
            template="document-ingestion",
            parameters={"key": "value"},
            request_id="test-123",
        )
        # Verify workflow structure: the JSON body must wrap the manifest in
        # {"workflow": ...} per the Argo REST API, reference the template by
        # name, and label the workflow with the originating request id.
        call_kwargs = handler._http.post.call_args.kwargs
        workflow = call_kwargs["json"]["workflow"]
        assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
        assert workflow["kind"] == "Workflow"
        assert "workflowTemplateRef" in workflow["spec"]
        assert workflow["spec"]["workflowTemplateRef"]["name"] == "document-ingestion"
        assert workflow["metadata"]["labels"]["request-id"] == "test-123"
class TestKubeflowSubmission:
"""Tests for Kubeflow pipeline submission."""
@pytest.fixture
def handler(self):
"""Create handler with mocked HTTP client."""
handler = PipelineBridge()
handler._http = AsyncMock()
return handler
@pytest.mark.asyncio
async def test_kubeflow_run_structure(
self,
handler,
mock_kubeflow_response,
):
"""Test Kubeflow run request structure."""
mock_response = MagicMock()
mock_response.json.return_value = mock_kubeflow_response
mock_response.raise_for_status = MagicMock()
handler._http.post.return_value = mock_response
await handler._submit_kubeflow(
pipeline_id="rag-pipeline",
parameters={"query": "test"},
request_id="test-456",
)
# Verify run request structure
call_kwargs = handler._http.post.call_args.kwargs
run_request = call_kwargs["json"]
assert "rag-pipeline" in run_request["name"]
assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"