refactor: consolidate to handler-base, migrate to pyproject.toml, add tests
This commit is contained in:
28
Dockerfile
28
Dockerfile
@@ -1,29 +1,13 @@
|
|||||||
FROM python:3.13-slim
|
# Pipeline Bridge - Using handler-base
|
||||||
|
ARG BASE_TAG=latest
|
||||||
|
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install uv for fast, reliable package management
|
# Install additional dependencies from pyproject.toml
|
||||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
COPY pyproject.toml .
|
||||||
|
RUN uv pip install --system --no-cache httpx kubernetes
|
||||||
|
|
||||||
# Install system dependencies
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
||||||
curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Copy requirements first for better caching
|
|
||||||
COPY requirements.txt .
|
|
||||||
RUN uv pip install --system --no-cache -r requirements.txt
|
|
||||||
|
|
||||||
# Copy application code
|
|
||||||
COPY pipeline_bridge.py .
|
COPY pipeline_bridge.py .
|
||||||
|
|
||||||
# Set environment variables
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
|
||||||
|
|
||||||
# Health check
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
|
||||||
CMD python -c "print('healthy')" || exit 1
|
|
||||||
|
|
||||||
# Run the application
|
|
||||||
CMD ["python", "pipeline_bridge.py"]
|
CMD ["python", "pipeline_bridge.py"]
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
# Pipeline Bridge v2 - Using handler-base
|
|
||||||
ARG BASE_TAG=local
|
|
||||||
FROM ghcr.io/billy-davies-2/handler-base:${BASE_TAG}
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Additional dependency for Kubernetes API
|
|
||||||
RUN uv pip install --system --no-cache kubernetes>=28.0.0
|
|
||||||
|
|
||||||
COPY pipeline_bridge_v2.py ./pipeline_bridge.py
|
|
||||||
|
|
||||||
CMD ["python", "pipeline_bridge.py"]
|
|
||||||
13
README.md
13
README.md
@@ -61,13 +61,9 @@ The bridge publishes status updates as the workflow progresses:
|
|||||||
- `failed` - Failed
|
- `failed` - Failed
|
||||||
- `error` - System error
|
- `error` - System error
|
||||||
|
|
||||||
## Variants
|
## Implementation
|
||||||
|
|
||||||
### pipeline_bridge.py (Standalone)
|
The pipeline bridge uses the [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) library for standardized NATS handling, telemetry, and health checks.
|
||||||
Self-contained service with pip install on startup. Good for simple deployments.
|
|
||||||
|
|
||||||
### pipeline_bridge_v2.py (handler-base)
|
|
||||||
Uses handler-base library for standardized NATS handling, telemetry, and health checks.
|
|
||||||
|
|
||||||
## Environment Variables
|
## Environment Variables
|
||||||
|
|
||||||
@@ -81,11 +77,10 @@ Uses handler-base library for standardized NATS handling, telemetry, and health
|
|||||||
## Building
|
## Building
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Standalone version
|
|
||||||
docker build -t pipeline-bridge:latest .
|
docker build -t pipeline-bridge:latest .
|
||||||
|
|
||||||
# handler-base version
|
# With specific handler-base tag
|
||||||
docker build -f Dockerfile.v2 -t pipeline-bridge:v2 --build-arg BASE_TAG=latest .
|
docker build --build-arg BASE_TAG=latest -t pipeline-bridge:latest .
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|||||||
@@ -1,351 +1,241 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Pipeline Bridge Service
|
Pipeline Bridge Service (Refactored)
|
||||||
|
|
||||||
Bridges NATS events to workflow engines:
|
Bridges NATS events to workflow engines using handler-base:
|
||||||
1. Listen for pipeline triggers on "ai.pipeline.trigger"
|
1. Listen for pipeline triggers on "ai.pipeline.trigger"
|
||||||
2. Submit to Kubeflow Pipelines or Argo Workflows
|
2. Submit to Kubeflow Pipelines or Argo Workflows
|
||||||
3. Monitor execution and publish status updates
|
3. Monitor execution and publish status updates
|
||||||
4. Publish completion to "ai.pipeline.status.{request_id}"
|
4. Publish completion to "ai.pipeline.status.{request_id}"
|
||||||
|
|
||||||
Supported pipelines:
|
|
||||||
- document-ingestion: Ingest documents into Milvus
|
|
||||||
- batch-inference: Run batch LLM inference
|
|
||||||
- model-evaluation: Evaluate model performance
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
from typing import Any, Optional
|
||||||
import signal
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from typing import Dict, Optional
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
# Install dependencies on startup
|
|
||||||
subprocess.check_call([
|
|
||||||
sys.executable, "-m", "pip", "install", "-q",
|
|
||||||
"-r", "/app/requirements.txt"
|
|
||||||
])
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import nats
|
from nats.aio.msg import Msg
|
||||||
from kubernetes import client, config
|
|
||||||
|
from handler_base import Handler, Settings
|
||||||
|
from handler_base.telemetry import create_span
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger("pipeline-bridge")
|
logger = logging.getLogger("pipeline-bridge")
|
||||||
|
|
||||||
# Configuration from environment
|
|
||||||
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
|
|
||||||
KUBEFLOW_HOST = os.environ.get("KUBEFLOW_HOST", "http://ml-pipeline.kubeflow.svc.cluster.local:8888")
|
|
||||||
ARGO_HOST = os.environ.get("ARGO_HOST", "http://argo-server.argo.svc.cluster.local:2746")
|
|
||||||
ARGO_NAMESPACE = os.environ.get("ARGO_NAMESPACE", "ai-ml")
|
|
||||||
|
|
||||||
# NATS subjects
|
class PipelineSettings(Settings):
|
||||||
TRIGGER_SUBJECT = "ai.pipeline.trigger"
|
"""Pipeline bridge specific settings."""
|
||||||
STATUS_SUBJECT = "ai.pipeline.status"
|
|
||||||
|
|
||||||
# Pipeline definitions - maps pipeline names to their configurations
|
service_name: str = "pipeline-bridge"
|
||||||
|
|
||||||
|
# Kubeflow Pipelines
|
||||||
|
kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
||||||
|
|
||||||
|
# Argo Workflows
|
||||||
|
argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
|
||||||
|
argo_namespace: str = "ai-ml"
|
||||||
|
|
||||||
|
|
||||||
|
# Pipeline definitions
|
||||||
PIPELINES = {
|
PIPELINES = {
|
||||||
"document-ingestion": {
|
"document-ingestion": {
|
||||||
"engine": "argo",
|
"engine": "argo",
|
||||||
"template": "document-ingestion",
|
"template": "document-ingestion",
|
||||||
"description": "Ingest documents into Milvus vector database"
|
"description": "Ingest documents into Milvus vector database",
|
||||||
},
|
},
|
||||||
"batch-inference": {
|
"batch-inference": {
|
||||||
"engine": "argo",
|
"engine": "argo",
|
||||||
"template": "batch-inference",
|
"template": "batch-inference",
|
||||||
"description": "Run batch LLM inference on a dataset"
|
"description": "Run batch LLM inference on a dataset",
|
||||||
},
|
},
|
||||||
"rag-query": {
|
"rag-query": {
|
||||||
"engine": "kubeflow",
|
"engine": "kubeflow",
|
||||||
"pipeline_id": "rag-pipeline",
|
"pipeline_id": "rag-pipeline",
|
||||||
"description": "Execute RAG query pipeline"
|
"description": "Execute RAG query pipeline",
|
||||||
},
|
},
|
||||||
"voice-pipeline": {
|
"voice-pipeline": {
|
||||||
"engine": "kubeflow",
|
"engine": "kubeflow",
|
||||||
"pipeline_id": "voice-pipeline",
|
"pipeline_id": "voice-pipeline",
|
||||||
"description": "Full voice assistant pipeline"
|
"description": "Full voice assistant pipeline",
|
||||||
}
|
},
|
||||||
|
"model-evaluation": {
|
||||||
|
"engine": "argo",
|
||||||
|
"template": "model-evaluation",
|
||||||
|
"description": "Evaluate model performance",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PipelineBridge:
|
class PipelineBridge(Handler):
|
||||||
|
"""
|
||||||
|
Pipeline trigger handler.
|
||||||
|
|
||||||
|
Request format:
|
||||||
|
{
|
||||||
|
"request_id": "uuid",
|
||||||
|
"pipeline": "document-ingestion",
|
||||||
|
"parameters": {"key": "value"}
|
||||||
|
}
|
||||||
|
|
||||||
|
Response format:
|
||||||
|
{
|
||||||
|
"request_id": "uuid",
|
||||||
|
"status": "submitted",
|
||||||
|
"run_id": "workflow-run-id",
|
||||||
|
"engine": "argo|kubeflow"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.nc = None
|
self.pipeline_settings = PipelineSettings()
|
||||||
self.http_client = None
|
super().__init__(
|
||||||
self.running = True
|
subject="ai.pipeline.trigger",
|
||||||
self.active_workflows = {} # Track running workflows
|
settings=self.pipeline_settings,
|
||||||
|
queue_group="pipeline-bridges",
|
||||||
|
)
|
||||||
|
|
||||||
async def setup(self):
|
self._http: Optional[httpx.AsyncClient] = None
|
||||||
"""Initialize all connections."""
|
|
||||||
# NATS connection
|
|
||||||
self.nc = await nats.connect(NATS_URL)
|
|
||||||
logger.info(f"Connected to NATS at {NATS_URL}")
|
|
||||||
|
|
||||||
# HTTP client for API calls
|
async def setup(self) -> None:
|
||||||
self.http_client = httpx.AsyncClient(timeout=60.0)
|
"""Initialize HTTP client."""
|
||||||
|
logger.info("Initializing pipeline bridge...")
|
||||||
|
|
||||||
|
self._http = httpx.AsyncClient(timeout=60.0)
|
||||||
|
|
||||||
|
logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
|
||||||
|
|
||||||
|
async def teardown(self) -> None:
|
||||||
|
"""Clean up HTTP client."""
|
||||||
|
if self._http:
|
||||||
|
await self._http.aclose()
|
||||||
|
logger.info("Pipeline bridge closed")
|
||||||
|
|
||||||
|
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
|
||||||
|
"""Handle pipeline trigger request."""
|
||||||
|
request_id = data.get("request_id", "unknown")
|
||||||
|
pipeline_name = data.get("pipeline", "")
|
||||||
|
parameters = data.get("parameters", {})
|
||||||
|
|
||||||
|
logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
|
||||||
|
|
||||||
|
with create_span("pipeline.trigger") as span:
|
||||||
|
if span:
|
||||||
|
span.set_attribute("request.id", request_id)
|
||||||
|
span.set_attribute("pipeline.name", pipeline_name)
|
||||||
|
|
||||||
|
# Validate pipeline
|
||||||
|
if pipeline_name not in PIPELINES:
|
||||||
|
error = f"Unknown pipeline: {pipeline_name}"
|
||||||
|
logger.error(error)
|
||||||
|
return {
|
||||||
|
"request_id": request_id,
|
||||||
|
"status": "error",
|
||||||
|
"error": error,
|
||||||
|
"available_pipelines": list(PIPELINES.keys()),
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline = PIPELINES[pipeline_name]
|
||||||
|
engine = pipeline["engine"]
|
||||||
|
|
||||||
# Initialize Kubernetes client for Argo
|
|
||||||
try:
|
try:
|
||||||
config.load_incluster_config()
|
if engine == "argo":
|
||||||
self.k8s_custom = client.CustomObjectsApi()
|
run_id = await self._submit_argo(
|
||||||
logger.info("Kubernetes client initialized")
|
pipeline["template"], parameters, request_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
run_id = await self._submit_kubeflow(
|
||||||
|
pipeline["pipeline_id"], parameters, request_id
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"request_id": request_id,
|
||||||
|
"status": "submitted",
|
||||||
|
"run_id": run_id,
|
||||||
|
"engine": engine,
|
||||||
|
"pipeline": pipeline_name,
|
||||||
|
"submitted_at": datetime.utcnow().isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Publish status update
|
||||||
|
await self.nats.publish(
|
||||||
|
f"ai.pipeline.status.{request_id}", result
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Kubernetes client failed: {e}")
|
logger.exception(f"Failed to submit pipeline {pipeline_name}")
|
||||||
self.k8s_custom = None
|
return {
|
||||||
|
"request_id": request_id,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
async def submit_argo_workflow(self, template: str, parameters: Dict, request_id: str) -> Optional[str]:
|
async def _submit_argo(
|
||||||
"""Submit an Argo Workflow from a WorkflowTemplate."""
|
self, template: str, parameters: dict, request_id: str
|
||||||
if not self.k8s_custom:
|
) -> str:
|
||||||
logger.error("Kubernetes client not available")
|
"""Submit workflow to Argo Workflows."""
|
||||||
return None
|
with create_span("pipeline.submit.argo") as span:
|
||||||
|
if span:
|
||||||
|
span.set_attribute("argo.template", template)
|
||||||
|
|
||||||
try:
|
|
||||||
# Create workflow from template
|
|
||||||
workflow = {
|
workflow = {
|
||||||
"apiVersion": "argoproj.io/v1alpha1",
|
"apiVersion": "argoproj.io/v1alpha1",
|
||||||
"kind": "Workflow",
|
"kind": "Workflow",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"generateName": f"{template}-",
|
"generateName": f"{template}-",
|
||||||
"namespace": ARGO_NAMESPACE,
|
"namespace": self.pipeline_settings.argo_namespace,
|
||||||
"labels": {
|
"labels": {
|
||||||
"app.kubernetes.io/managed-by": "pipeline-bridge",
|
"request-id": request_id,
|
||||||
"pipeline-bridge/request-id": request_id
|
},
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"spec": {
|
"spec": {
|
||||||
"workflowTemplateRef": {
|
"workflowTemplateRef": {"name": template},
|
||||||
"name": template
|
|
||||||
},
|
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"parameters": [
|
"parameters": [
|
||||||
{"name": k, "value": str(v)}
|
{"name": k, "value": str(v)}
|
||||||
for k, v in parameters.items()
|
for k, v in parameters.items()
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result = self.k8s_custom.create_namespaced_custom_object(
|
response = await self._http.post(
|
||||||
group="argoproj.io",
|
f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
|
||||||
version="v1alpha1",
|
json={"workflow": workflow},
|
||||||
namespace=ARGO_NAMESPACE,
|
|
||||||
plural="workflows",
|
|
||||||
body=workflow
|
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
workflow_name = result["metadata"]["name"]
|
result = response.json()
|
||||||
logger.info(f"Submitted Argo workflow: {workflow_name}")
|
return result["metadata"]["name"]
|
||||||
return workflow_name
|
|
||||||
|
|
||||||
except Exception as e:
|
async def _submit_kubeflow(
|
||||||
logger.error(f"Failed to submit Argo workflow: {e}")
|
self, pipeline_id: str, parameters: dict, request_id: str
|
||||||
return None
|
) -> str:
|
||||||
|
"""Submit run to Kubeflow Pipelines."""
|
||||||
|
with create_span("pipeline.submit.kubeflow") as span:
|
||||||
|
if span:
|
||||||
|
span.set_attribute("kubeflow.pipeline_id", pipeline_id)
|
||||||
|
|
||||||
async def submit_kubeflow_pipeline(self, pipeline_id: str, parameters: Dict, request_id: str) -> Optional[str]:
|
|
||||||
"""Submit a Kubeflow Pipeline run."""
|
|
||||||
try:
|
|
||||||
# Create pipeline run via Kubeflow API
|
|
||||||
run_request = {
|
run_request = {
|
||||||
"name": f"{pipeline_id}-{request_id[:8]}",
|
"name": f"{pipeline_id}-{request_id[:8]}",
|
||||||
"pipeline_spec": {
|
"pipeline_spec": {
|
||||||
"pipeline_id": pipeline_id
|
"pipeline_id": pipeline_id,
|
||||||
},
|
|
||||||
"resource_references": [],
|
|
||||||
"parameters": [
|
"parameters": [
|
||||||
{"name": k, "value": str(v)}
|
{"name": k, "value": str(v)}
|
||||||
for k, v in parameters.items()
|
for k, v in parameters.items()
|
||||||
]
|
],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await self.http_client.post(
|
response = await self._http.post(
|
||||||
f"{KUBEFLOW_HOST}/apis/v1beta1/runs",
|
f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
|
||||||
json=run_request
|
json=run_request,
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
run_id = result.get("run", {}).get("id")
|
return result["run"]["id"]
|
||||||
logger.info(f"Submitted Kubeflow pipeline run: {run_id}")
|
|
||||||
return run_id
|
|
||||||
else:
|
|
||||||
logger.error(f"Kubeflow API error: {response.status_code} - {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to submit Kubeflow pipeline: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_argo_workflow_status(self, workflow_name: str) -> Dict:
|
|
||||||
"""Get status of an Argo Workflow."""
|
|
||||||
if not self.k8s_custom:
|
|
||||||
return {"phase": "Unknown", "message": "Kubernetes client not available"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = self.k8s_custom.get_namespaced_custom_object(
|
|
||||||
group="argoproj.io",
|
|
||||||
version="v1alpha1",
|
|
||||||
namespace=ARGO_NAMESPACE,
|
|
||||||
plural="workflows",
|
|
||||||
name=workflow_name
|
|
||||||
)
|
|
||||||
|
|
||||||
status = result.get("status", {})
|
|
||||||
return {
|
|
||||||
"phase": status.get("phase", "Pending"),
|
|
||||||
"message": status.get("message", ""),
|
|
||||||
"startedAt": status.get("startedAt"),
|
|
||||||
"finishedAt": status.get("finishedAt"),
|
|
||||||
"progress": status.get("progress", "0/0")
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get workflow status: {e}")
|
|
||||||
return {"phase": "Error", "message": str(e)}
|
|
||||||
|
|
||||||
async def process_trigger(self, msg):
|
|
||||||
"""Process a pipeline trigger request."""
|
|
||||||
try:
|
|
||||||
data = json.loads(msg.data.decode())
|
|
||||||
request_id = data.get("request_id", "unknown")
|
|
||||||
pipeline_name = data.get("pipeline", "")
|
|
||||||
parameters = data.get("parameters", {})
|
|
||||||
|
|
||||||
logger.info(f"Processing pipeline trigger {request_id}: {pipeline_name}")
|
|
||||||
|
|
||||||
# Validate pipeline
|
|
||||||
if pipeline_name not in PIPELINES:
|
|
||||||
await self.publish_status(request_id, {
|
|
||||||
"status": "error",
|
|
||||||
"message": f"Unknown pipeline: {pipeline_name}",
|
|
||||||
"available_pipelines": list(PIPELINES.keys())
|
|
||||||
})
|
|
||||||
return
|
|
||||||
|
|
||||||
pipeline_config = PIPELINES[pipeline_name]
|
|
||||||
engine = pipeline_config["engine"]
|
|
||||||
|
|
||||||
# Submit to appropriate engine
|
|
||||||
run_id = None
|
|
||||||
if engine == "argo":
|
|
||||||
run_id = await self.submit_argo_workflow(
|
|
||||||
pipeline_config["template"],
|
|
||||||
parameters,
|
|
||||||
request_id
|
|
||||||
)
|
|
||||||
elif engine == "kubeflow":
|
|
||||||
run_id = await self.submit_kubeflow_pipeline(
|
|
||||||
pipeline_config["pipeline_id"],
|
|
||||||
parameters,
|
|
||||||
request_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if run_id:
|
|
||||||
# Track workflow for status updates
|
|
||||||
self.active_workflows[request_id] = {
|
|
||||||
"engine": engine,
|
|
||||||
"run_id": run_id,
|
|
||||||
"started_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.publish_status(request_id, {
|
|
||||||
"status": "submitted",
|
|
||||||
"pipeline": pipeline_name,
|
|
||||||
"engine": engine,
|
|
||||||
"run_id": run_id,
|
|
||||||
"message": f"Pipeline submitted successfully"
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
await self.publish_status(request_id, {
|
|
||||||
"status": "error",
|
|
||||||
"pipeline": pipeline_name,
|
|
||||||
"message": "Failed to submit pipeline"
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Trigger processing failed: {e}")
|
|
||||||
await self.publish_status(
|
|
||||||
data.get("request_id", "unknown"),
|
|
||||||
{"status": "error", "message": str(e)}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def publish_status(self, request_id: str, status: Dict):
|
|
||||||
"""Publish pipeline status update."""
|
|
||||||
status["request_id"] = request_id
|
|
||||||
status["timestamp"] = datetime.utcnow().isoformat()
|
|
||||||
await self.nc.publish(
|
|
||||||
f"{STATUS_SUBJECT}.{request_id}",
|
|
||||||
json.dumps(status).encode()
|
|
||||||
)
|
|
||||||
logger.info(f"Published status for {request_id}: {status.get('status')}")
|
|
||||||
|
|
||||||
async def monitor_workflows(self):
|
|
||||||
"""Periodically check and publish status of active workflows."""
|
|
||||||
while self.running:
|
|
||||||
completed = []
|
|
||||||
|
|
||||||
for request_id, workflow in self.active_workflows.items():
|
|
||||||
try:
|
|
||||||
if workflow["engine"] == "argo":
|
|
||||||
status = await self.get_argo_workflow_status(workflow["run_id"])
|
|
||||||
|
|
||||||
# Publish status update
|
|
||||||
await self.publish_status(request_id, {
|
|
||||||
"status": status["phase"].lower(),
|
|
||||||
"run_id": workflow["run_id"],
|
|
||||||
"progress": status.get("progress"),
|
|
||||||
"message": status.get("message", "")
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check if completed
|
|
||||||
if status["phase"] in ["Succeeded", "Failed", "Error"]:
|
|
||||||
completed.append(request_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error monitoring workflow {request_id}: {e}")
|
|
||||||
|
|
||||||
# Remove completed workflows from tracking
|
|
||||||
for request_id in completed:
|
|
||||||
del self.active_workflows[request_id]
|
|
||||||
|
|
||||||
await asyncio.sleep(10) # Check every 10 seconds
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Main run loop."""
|
|
||||||
await self.setup()
|
|
||||||
|
|
||||||
# Subscribe to pipeline triggers
|
|
||||||
sub = await self.nc.subscribe(TRIGGER_SUBJECT, cb=self.process_trigger)
|
|
||||||
logger.info(f"Subscribed to {TRIGGER_SUBJECT}")
|
|
||||||
|
|
||||||
# Start workflow monitor
|
|
||||||
monitor_task = asyncio.create_task(self.monitor_workflows())
|
|
||||||
|
|
||||||
# Handle shutdown
|
|
||||||
def signal_handler():
|
|
||||||
self.running = False
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
|
||||||
loop.add_signal_handler(sig, signal_handler)
|
|
||||||
|
|
||||||
# Keep running
|
|
||||||
while self.running:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
monitor_task.cancel()
|
|
||||||
await sub.unsubscribe()
|
|
||||||
await self.nc.close()
|
|
||||||
logger.info("Shutdown complete")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
bridge = PipelineBridge()
|
PipelineBridge().run()
|
||||||
asyncio.run(bridge.run())
|
|
||||||
|
|||||||
@@ -1,241 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Pipeline Bridge Service (Refactored)
|
|
||||||
|
|
||||||
Bridges NATS events to workflow engines using handler-base:
|
|
||||||
1. Listen for pipeline triggers on "ai.pipeline.trigger"
|
|
||||||
2. Submit to Kubeflow Pipelines or Argo Workflows
|
|
||||||
3. Monitor execution and publish status updates
|
|
||||||
4. Publish completion to "ai.pipeline.status.{request_id}"
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from nats.aio.msg import Msg
|
|
||||||
|
|
||||||
from handler_base import Handler, Settings
|
|
||||||
from handler_base.telemetry import create_span
|
|
||||||
|
|
||||||
logger = logging.getLogger("pipeline-bridge")
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineSettings(Settings):
|
|
||||||
"""Pipeline bridge specific settings."""
|
|
||||||
|
|
||||||
service_name: str = "pipeline-bridge"
|
|
||||||
|
|
||||||
# Kubeflow Pipelines
|
|
||||||
kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
|
||||||
|
|
||||||
# Argo Workflows
|
|
||||||
argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
|
|
||||||
argo_namespace: str = "ai-ml"
|
|
||||||
|
|
||||||
|
|
||||||
# Pipeline definitions
|
|
||||||
PIPELINES = {
|
|
||||||
"document-ingestion": {
|
|
||||||
"engine": "argo",
|
|
||||||
"template": "document-ingestion",
|
|
||||||
"description": "Ingest documents into Milvus vector database",
|
|
||||||
},
|
|
||||||
"batch-inference": {
|
|
||||||
"engine": "argo",
|
|
||||||
"template": "batch-inference",
|
|
||||||
"description": "Run batch LLM inference on a dataset",
|
|
||||||
},
|
|
||||||
"rag-query": {
|
|
||||||
"engine": "kubeflow",
|
|
||||||
"pipeline_id": "rag-pipeline",
|
|
||||||
"description": "Execute RAG query pipeline",
|
|
||||||
},
|
|
||||||
"voice-pipeline": {
|
|
||||||
"engine": "kubeflow",
|
|
||||||
"pipeline_id": "voice-pipeline",
|
|
||||||
"description": "Full voice assistant pipeline",
|
|
||||||
},
|
|
||||||
"model-evaluation": {
|
|
||||||
"engine": "argo",
|
|
||||||
"template": "model-evaluation",
|
|
||||||
"description": "Evaluate model performance",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineBridge(Handler):
|
|
||||||
"""
|
|
||||||
Pipeline trigger handler.
|
|
||||||
|
|
||||||
Request format:
|
|
||||||
{
|
|
||||||
"request_id": "uuid",
|
|
||||||
"pipeline": "document-ingestion",
|
|
||||||
"parameters": {"key": "value"}
|
|
||||||
}
|
|
||||||
|
|
||||||
Response format:
|
|
||||||
{
|
|
||||||
"request_id": "uuid",
|
|
||||||
"status": "submitted",
|
|
||||||
"run_id": "workflow-run-id",
|
|
||||||
"engine": "argo|kubeflow"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.pipeline_settings = PipelineSettings()
|
|
||||||
super().__init__(
|
|
||||||
subject="ai.pipeline.trigger",
|
|
||||||
settings=self.pipeline_settings,
|
|
||||||
queue_group="pipeline-bridges",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._http: Optional[httpx.AsyncClient] = None
|
|
||||||
|
|
||||||
async def setup(self) -> None:
|
|
||||||
"""Initialize HTTP client."""
|
|
||||||
logger.info("Initializing pipeline bridge...")
|
|
||||||
|
|
||||||
self._http = httpx.AsyncClient(timeout=60.0)
|
|
||||||
|
|
||||||
logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
|
|
||||||
|
|
||||||
async def teardown(self) -> None:
|
|
||||||
"""Clean up HTTP client."""
|
|
||||||
if self._http:
|
|
||||||
await self._http.aclose()
|
|
||||||
logger.info("Pipeline bridge closed")
|
|
||||||
|
|
||||||
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
|
|
||||||
"""Handle pipeline trigger request."""
|
|
||||||
request_id = data.get("request_id", "unknown")
|
|
||||||
pipeline_name = data.get("pipeline", "")
|
|
||||||
parameters = data.get("parameters", {})
|
|
||||||
|
|
||||||
logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
|
|
||||||
|
|
||||||
with create_span("pipeline.trigger") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("request.id", request_id)
|
|
||||||
span.set_attribute("pipeline.name", pipeline_name)
|
|
||||||
|
|
||||||
# Validate pipeline
|
|
||||||
if pipeline_name not in PIPELINES:
|
|
||||||
error = f"Unknown pipeline: {pipeline_name}"
|
|
||||||
logger.error(error)
|
|
||||||
return {
|
|
||||||
"request_id": request_id,
|
|
||||||
"status": "error",
|
|
||||||
"error": error,
|
|
||||||
"available_pipelines": list(PIPELINES.keys()),
|
|
||||||
}
|
|
||||||
|
|
||||||
pipeline = PIPELINES[pipeline_name]
|
|
||||||
engine = pipeline["engine"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
if engine == "argo":
|
|
||||||
run_id = await self._submit_argo(
|
|
||||||
pipeline["template"], parameters, request_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
run_id = await self._submit_kubeflow(
|
|
||||||
pipeline["pipeline_id"], parameters, request_id
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"request_id": request_id,
|
|
||||||
"status": "submitted",
|
|
||||||
"run_id": run_id,
|
|
||||||
"engine": engine,
|
|
||||||
"pipeline": pipeline_name,
|
|
||||||
"submitted_at": datetime.utcnow().isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Publish status update
|
|
||||||
await self.nats.publish(
|
|
||||||
f"ai.pipeline.status.{request_id}", result
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to submit pipeline {pipeline_name}")
|
|
||||||
return {
|
|
||||||
"request_id": request_id,
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _submit_argo(
|
|
||||||
self, template: str, parameters: dict, request_id: str
|
|
||||||
) -> str:
|
|
||||||
"""Submit workflow to Argo Workflows."""
|
|
||||||
with create_span("pipeline.submit.argo") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("argo.template", template)
|
|
||||||
|
|
||||||
workflow = {
|
|
||||||
"apiVersion": "argoproj.io/v1alpha1",
|
|
||||||
"kind": "Workflow",
|
|
||||||
"metadata": {
|
|
||||||
"generateName": f"{template}-",
|
|
||||||
"namespace": self.pipeline_settings.argo_namespace,
|
|
||||||
"labels": {
|
|
||||||
"request-id": request_id,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"spec": {
|
|
||||||
"workflowTemplateRef": {"name": template},
|
|
||||||
"arguments": {
|
|
||||||
"parameters": [
|
|
||||||
{"name": k, "value": str(v)}
|
|
||||||
for k, v in parameters.items()
|
|
||||||
]
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self._http.post(
|
|
||||||
f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
|
|
||||||
json={"workflow": workflow},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
return result["metadata"]["name"]
|
|
||||||
|
|
||||||
async def _submit_kubeflow(
|
|
||||||
self, pipeline_id: str, parameters: dict, request_id: str
|
|
||||||
) -> str:
|
|
||||||
"""Submit run to Kubeflow Pipelines."""
|
|
||||||
with create_span("pipeline.submit.kubeflow") as span:
|
|
||||||
if span:
|
|
||||||
span.set_attribute("kubeflow.pipeline_id", pipeline_id)
|
|
||||||
|
|
||||||
run_request = {
|
|
||||||
"name": f"{pipeline_id}-{request_id[:8]}",
|
|
||||||
"pipeline_spec": {
|
|
||||||
"pipeline_id": pipeline_id,
|
|
||||||
"parameters": [
|
|
||||||
{"name": k, "value": str(v)}
|
|
||||||
for k, v in parameters.items()
|
|
||||||
],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
response = await self._http.post(
|
|
||||||
f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
|
|
||||||
json=run_request,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
return result["run"]["id"]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
PipelineBridge().run()
|
|
||||||
42
pyproject.toml
Normal file
42
pyproject.toml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
[project]
|
||||||
|
name = "pipeline-bridge"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "Bridge NATS events to Kubeflow Pipelines and Argo Workflows"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
authors = [{ name = "Davies Tech Labs" }]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"handler-base @ git+https://git.daviestechlabs.io/daviestechlabs/handler-base.git",
|
||||||
|
"httpx>=0.27.0",
|
||||||
|
"kubernetes>=28.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.23.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["."]
|
||||||
|
only-include = ["pipeline_bridge.py"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py311"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
|
filterwarnings = ["ignore::DeprecationWarning"]
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
nats-py
|
|
||||||
httpx
|
|
||||||
kubernetes
|
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Pipeline Bridge Tests
|
||||||
90
tests/conftest.py
Normal file
90
tests/conftest.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""
Pytest configuration and fixtures for pipeline-bridge tests.
"""
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Configure the test environment BEFORE pipeline_bridge is imported:
# point NATS at a local server and disable telemetry/MLflow side effects.
# setdefault keeps any values the developer exported deliberately.
os.environ.setdefault("NATS_URL", "nats://localhost:4222")
os.environ.setdefault("OTEL_ENABLED", "false")
os.environ.setdefault("MLFLOW_ENABLED", "false")
|
@pytest.fixture(scope="session")
def event_loop():
    """Provide a session-scoped event loop for async tests.

    NOTE(review): pytest-asyncio >= 0.23 deprecates overriding the
    ``event_loop`` fixture; with ``asyncio_mode = "auto"`` this override
    should eventually be removed in favor of the plugin's loop policy.

    Yields:
        asyncio.AbstractEventLoop: a fresh loop, always closed on teardown.
    """
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Close in a finally block so the loop is released even if a
        # test (or the fixture teardown chain) raises mid-session.
        loop.close()
|
@pytest.fixture
def mock_nats_message():
    """Return a stub NATS message with trigger subject and reply inbox set."""
    message = MagicMock()
    message.subject = "ai.pipeline.trigger"
    message.reply = "ai.pipeline.status.test-123"
    return message
@pytest.fixture
def argo_pipeline_request():
    """Trigger payload for the Argo-backed document-ingestion pipeline."""
    parameters = {
        "source_path": "s3://bucket/documents",
        "collection_name": "test_collection",
    }
    return {
        "request_id": "test-request-123",
        "pipeline": "document-ingestion",
        "parameters": parameters,
    }
@pytest.fixture
def kubeflow_pipeline_request():
    """Trigger payload for the Kubeflow-backed rag-query pipeline."""
    parameters = {
        "query": "What is AI?",
        "collection": "documents",
    }
    return {
        "request_id": "test-request-456",
        "pipeline": "rag-query",
        "parameters": parameters,
    }
@pytest.fixture
def unknown_pipeline_request():
    """Trigger payload naming a pipeline that is not registered."""
    return {
        "request_id": "test-request-789",
        "pipeline": "nonexistent-pipeline",
        "parameters": {},
    }
@pytest.fixture
def mock_argo_response():
    """Canned Argo Workflows submission response (a Pending workflow)."""
    metadata = {
        "name": "document-ingestion-abc123",
        "namespace": "ai-ml",
    }
    return {
        "metadata": metadata,
        "status": {"phase": "Pending"},
    }
@pytest.fixture
def mock_kubeflow_response():
    """Canned Kubeflow Pipelines run-creation response."""
    run = {
        "id": "run-xyz-789",
        "name": "rag-query-test",
        "status": "Running",
    }
    return {"run": run}
271
tests/test_pipeline_bridge.py
Normal file
271
tests/test_pipeline_bridge.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
"""Unit tests for the PipelineBridge handler."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from pipeline_bridge import PIPELINES, PipelineBridge, PipelineSettings
class TestPipelineSettings:
    """Tests for PipelineSettings configuration."""

    def test_default_settings(self):
        """Defaults point at the in-cluster Kubeflow and Argo endpoints."""
        settings = PipelineSettings()

        assert settings.service_name == "pipeline-bridge"
        assert settings.kubeflow_host == "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
        assert settings.argo_host == "http://argo-server.argo.svc.cluster.local:2746"
        assert settings.argo_namespace == "ai-ml"

    def test_custom_settings(self):
        """Explicit keyword arguments override the defaults."""
        overridden = PipelineSettings(
            kubeflow_host="http://custom-kubeflow:8888",
            argo_namespace="custom-ns",
        )

        assert overridden.kubeflow_host == "http://custom-kubeflow:8888"
        assert overridden.argo_namespace == "custom-ns"
class TestPipelineDefinitions:
    """Tests for the static PIPELINES registry."""

    def test_required_pipelines_exist(self):
        """Every pipeline the platform depends on is registered."""
        for name in ("document-ingestion", "batch-inference", "rag-query", "voice-pipeline"):
            assert name in PIPELINES, f"Pipeline {name} should be defined"

    def test_argo_pipelines_have_template(self):
        """Each Argo-engine entry must name a WorkflowTemplate."""
        for name, config in PIPELINES.items():
            if config["engine"] != "argo":
                continue
            assert "template" in config, f"Argo pipeline {name} missing template"

    def test_kubeflow_pipelines_have_pipeline_id(self):
        """Each Kubeflow-engine entry must carry a pipeline_id."""
        for name, config in PIPELINES.items():
            if config["engine"] != "kubeflow":
                continue
            assert "pipeline_id" in config, f"Kubeflow pipeline {name} missing pipeline_id"

    def test_all_pipelines_have_description(self):
        """Every registry entry documents itself with a description."""
        for name, config in PIPELINES.items():
            assert "description" in config, f"Pipeline {name} missing description"
class TestPipelineBridge:
    """Tests for PipelineBridge message handling."""

    @staticmethod
    def _stub_post_response(payload):
        """Build a MagicMock that mimics a successful httpx response."""
        response = MagicMock()
        response.json.return_value = payload
        response.raise_for_status = MagicMock()
        return response

    @pytest.fixture
    def handler(self):
        """PipelineBridge wired to mocked HTTP and NATS clients."""
        bridge = PipelineBridge()
        bridge._http = AsyncMock()
        bridge.nats = AsyncMock()
        return bridge

    def test_init(self, handler):
        """Handler subscribes on the trigger subject in its queue group."""
        assert handler.subject == "ai.pipeline.trigger"
        assert handler.queue_group == "pipeline-bridges"
        assert handler.pipeline_settings.service_name == "pipeline-bridge"

    @pytest.mark.asyncio
    async def test_handle_unknown_pipeline(self, handler, mock_nats_message, unknown_pipeline_request):
        """Unknown pipelines produce an error listing what is available."""
        result = await handler.handle_message(mock_nats_message, unknown_pipeline_request)

        assert result["status"] == "error"
        assert "Unknown pipeline" in result["error"]
        assert "available_pipelines" in result
        assert "document-ingestion" in result["available_pipelines"]

    @pytest.mark.asyncio
    async def test_handle_argo_pipeline(
        self, handler, mock_nats_message, argo_pipeline_request, mock_argo_response
    ):
        """Argo pipelines are submitted to the Argo server."""
        handler._http.post.return_value = self._stub_post_response(mock_argo_response)

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "argo"
        assert result["run_id"] == "document-ingestion-abc123"
        assert result["pipeline"] == "document-ingestion"
        assert "submitted_at" in result

        # The single POST must target the Argo server endpoint.
        handler._http.post.assert_called_once()
        assert "argo-server" in str(handler._http.post.call_args)

    @pytest.mark.asyncio
    async def test_handle_kubeflow_pipeline(
        self, handler, mock_nats_message, kubeflow_pipeline_request, mock_kubeflow_response
    ):
        """Kubeflow pipelines are submitted to the ml-pipeline service."""
        handler._http.post.return_value = self._stub_post_response(mock_kubeflow_response)

        result = await handler.handle_message(mock_nats_message, kubeflow_pipeline_request)

        assert result["status"] == "submitted"
        assert result["engine"] == "kubeflow"
        assert result["run_id"] == "run-xyz-789"
        assert result["pipeline"] == "rag-query"

        # The single POST must target the Kubeflow Pipelines endpoint.
        handler._http.post.assert_called_once()
        assert "ml-pipeline" in str(handler._http.post.call_args)

    @pytest.mark.asyncio
    async def test_handle_api_error(self, handler, mock_nats_message, argo_pipeline_request):
        """Submission failures surface as an error result, not a crash."""
        handler._http.post.side_effect = Exception("Connection refused")

        result = await handler.handle_message(mock_nats_message, argo_pipeline_request)

        assert result["status"] == "error"
        assert "Connection refused" in result["error"]

    @pytest.mark.asyncio
    async def test_publishes_status_update(
        self, handler, mock_nats_message, argo_pipeline_request, mock_argo_response
    ):
        """A status update is published on the request's status subject."""
        handler._http.post.return_value = self._stub_post_response(mock_argo_response)

        await handler.handle_message(mock_nats_message, argo_pipeline_request)

        handler.nats.publish.assert_called_once()
        assert "ai.pipeline.status.test-request-123" in str(handler.nats.publish.call_args)

    @pytest.mark.asyncio
    async def test_setup_creates_http_client(self):
        """setup() constructs the shared async HTTP client."""
        with patch("pipeline_bridge.httpx.AsyncClient") as mock_client:
            bridge = PipelineBridge()
            await bridge.setup()

            mock_client.assert_called_once()

    @pytest.mark.asyncio
    async def test_teardown_closes_http_client(self, handler):
        """teardown() closes the shared async HTTP client."""
        await handler.teardown()

        handler._http.aclose.assert_called_once()
class TestArgoSubmission:
    """Tests for the Argo workflow submission payload."""

    @pytest.fixture
    def handler(self):
        """PipelineBridge with only the HTTP client mocked."""
        bridge = PipelineBridge()
        bridge._http = AsyncMock()
        return bridge

    @pytest.mark.asyncio
    async def test_argo_workflow_structure(self, handler, mock_argo_response):
        """_submit_argo posts a well-formed Workflow manifest."""
        response = MagicMock()
        response.json.return_value = mock_argo_response
        response.raise_for_status = MagicMock()
        handler._http.post.return_value = response

        await handler._submit_argo(
            template="document-ingestion",
            parameters={"key": "value"},
            request_id="test-123",
        )

        # Inspect the JSON body of the POST for the expected manifest shape.
        workflow = handler._http.post.call_args.kwargs["json"]["workflow"]

        assert workflow["apiVersion"] == "argoproj.io/v1alpha1"
        assert workflow["kind"] == "Workflow"
        assert "workflowTemplateRef" in workflow["spec"]
        assert workflow["spec"]["workflowTemplateRef"]["name"] == "document-ingestion"
        assert workflow["metadata"]["labels"]["request-id"] == "test-123"
class TestKubeflowSubmission:
    """Tests for the Kubeflow pipeline submission payload."""

    @pytest.fixture
    def handler(self):
        """PipelineBridge with only the HTTP client mocked."""
        bridge = PipelineBridge()
        bridge._http = AsyncMock()
        return bridge

    @pytest.mark.asyncio
    async def test_kubeflow_run_structure(self, handler, mock_kubeflow_response):
        """_submit_kubeflow posts a well-formed run request."""
        response = MagicMock()
        response.json.return_value = mock_kubeflow_response
        response.raise_for_status = MagicMock()
        handler._http.post.return_value = response

        await handler._submit_kubeflow(
            pipeline_id="rag-pipeline",
            parameters={"query": "test"},
            request_id="test-456",
        )

        # Inspect the JSON body of the POST for the expected run shape.
        run_request = handler._http.post.call_args.kwargs["json"]

        assert "rag-pipeline" in run_request["name"]
        assert run_request["pipeline_spec"]["pipeline_id"] == "rag-pipeline"
Reference in New Issue
Block a user