refactor: consolidate to handler-base, migrate to pyproject.toml, add tests
This commit is contained in:
@@ -1,351 +1,241 @@
|
||||
#!/usr/bin/env python3
"""
Pipeline Bridge Service (Refactored)

Bridges NATS events to workflow engines using handler-base:
1. Listen for pipeline triggers on "ai.pipeline.trigger"
2. Submit to Kubeflow Pipelines or Argo Workflows
3. Monitor execution and publish status updates
4. Publish completion to "ai.pipeline.status.{request_id}"

Supported pipelines:
- document-ingestion: Ingest documents into Milvus
- batch-inference: Run batch LLM inference
- model-evaluation: Evaluate model performance
"""
import asyncio
import json
import logging
import os
import signal
from datetime import datetime
from typing import Any, Dict, Optional

import httpx
import nats
from kubernetes import client, config
from nats.aio.msg import Msg

from handler_base import Handler, Settings
from handler_base.telemetry import create_span

# NOTE: dependencies are installed from pyproject.toml at image build time;
# the previous runtime `pip install -r requirements.txt` on startup was removed.

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("pipeline-bridge")

# Legacy configuration from environment.
# New code should prefer PipelineSettings; these remain for callers that
# still reference the module-level constants.
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
KUBEFLOW_HOST = os.environ.get("KUBEFLOW_HOST", "http://ml-pipeline.kubeflow.svc.cluster.local:8888")
ARGO_HOST = os.environ.get("ARGO_HOST", "http://argo-server.argo.svc.cluster.local:2746")
ARGO_NAMESPACE = os.environ.get("ARGO_NAMESPACE", "ai-ml")

# NATS subjects
TRIGGER_SUBJECT = "ai.pipeline.trigger"
STATUS_SUBJECT = "ai.pipeline.status"
||||
class PipelineSettings(Settings):
    """Pipeline bridge specific settings.

    Inherits from handler-base Settings; presumably field values are
    overridable via environment variables — confirm against handler-base.
    """

    # Service identity reported by handler-base.
    service_name: str = "pipeline-bridge"

    # Kubeflow Pipelines API endpoint.
    kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"

    # Argo Workflows server endpoint and the namespace workflows run in.
    argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
    argo_namespace: str = "ai-ml"
||||
# Pipeline definitions - maps pipeline names to their configurations.
# "argo" entries reference a WorkflowTemplate name; "kubeflow" entries
# reference an uploaded pipeline id.
PIPELINES = {
    "document-ingestion": {
        "engine": "argo",
        "template": "document-ingestion",
        "description": "Ingest documents into Milvus vector database",
    },
    "batch-inference": {
        "engine": "argo",
        "template": "batch-inference",
        "description": "Run batch LLM inference on a dataset",
    },
    "rag-query": {
        "engine": "kubeflow",
        "pipeline_id": "rag-pipeline",
        "description": "Execute RAG query pipeline",
    },
    "voice-pipeline": {
        "engine": "kubeflow",
        "pipeline_id": "voice-pipeline",
        "description": "Full voice assistant pipeline",
    },
    "model-evaluation": {
        "engine": "argo",
        "template": "model-evaluation",
        "description": "Evaluate model performance",
    },
}
|
||||
|
||||
class PipelineBridge(Handler):
    """
    Pipeline trigger handler built on handler-base.

    Subscribes to "ai.pipeline.trigger" (queue group "pipeline-bridges") and
    submits the requested pipeline to Argo Workflows or Kubeflow Pipelines.

    Request format:
        {
            "request_id": "uuid",
            "pipeline": "document-ingestion",
            "parameters": {"key": "value"}
        }

    Response format:
        {
            "request_id": "uuid",
            "status": "submitted",
            "run_id": "workflow-run-id",
            "engine": "argo|kubeflow"
        }
    """

    def __init__(self):
        # Settings come from PipelineSettings (env-overridable via handler-base).
        self.pipeline_settings = PipelineSettings()
        super().__init__(
            subject="ai.pipeline.trigger",
            settings=self.pipeline_settings,
            queue_group="pipeline-bridges",
        )

        # Shared HTTP client; created in setup(), closed in teardown().
        self._http: Optional[httpx.AsyncClient] = None

    async def setup(self) -> None:
        """Initialize HTTP client."""
        logger.info("Initializing pipeline bridge...")

        self._http = httpx.AsyncClient(timeout=60.0)

        logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")

    async def teardown(self) -> None:
        """Clean up HTTP client."""
        if self._http:
            await self._http.aclose()
        logger.info("Pipeline bridge closed")

    async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
        """Handle a pipeline trigger request.

        Returns a response dict (submitted or error); also publishes the
        same payload on "ai.pipeline.status.{request_id}".
        """
        request_id = data.get("request_id", "unknown")
        pipeline_name = data.get("pipeline", "")
        parameters = data.get("parameters", {})

        logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")

        with create_span("pipeline.trigger") as span:
            if span:
                span.set_attribute("request.id", request_id)
                span.set_attribute("pipeline.name", pipeline_name)

            # Validate pipeline name before dispatching to an engine.
            if pipeline_name not in PIPELINES:
                error = f"Unknown pipeline: {pipeline_name}"
                logger.error(error)
                return {
                    "request_id": request_id,
                    "status": "error",
                    "error": error,
                    "available_pipelines": list(PIPELINES.keys()),
                }

            pipeline = PIPELINES[pipeline_name]
            engine = pipeline["engine"]

            try:
                if engine == "argo":
                    run_id = await self._submit_argo(
                        pipeline["template"], parameters, request_id
                    )
                else:
                    run_id = await self._submit_kubeflow(
                        pipeline["pipeline_id"], parameters, request_id
                    )

                result = {
                    "request_id": request_id,
                    "status": "submitted",
                    "run_id": run_id,
                    "engine": engine,
                    "pipeline": pipeline_name,
                    "submitted_at": datetime.utcnow().isoformat(),
                }

                # Publish a per-request status update in addition to the reply.
                await self.nats.publish(
                    f"ai.pipeline.status.{request_id}", result
                )

                logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
                return result

            except Exception as e:
                # Report submission failures to the requester instead of
                # letting the handler crash.
                logger.exception(f"Failed to submit pipeline {pipeline_name}")
                return {
                    "request_id": request_id,
                    "status": "error",
                    "error": str(e),
                }

    async def _submit_argo(
        self, template: str, parameters: dict, request_id: str
    ) -> str:
        """Submit a Workflow built from a WorkflowTemplate via the Argo server API.

        Returns the generated workflow name. Raises httpx.HTTPStatusError on a
        non-2xx response (caught by handle_message).
        """
        with create_span("pipeline.submit.argo") as span:
            if span:
                span.set_attribute("argo.template", template)

            workflow = {
                "apiVersion": "argoproj.io/v1alpha1",
                "kind": "Workflow",
                "metadata": {
                    "generateName": f"{template}-",
                    "namespace": self.pipeline_settings.argo_namespace,
                    "labels": {
                        "app.kubernetes.io/managed-by": "pipeline-bridge",
                        # Label lets status lookups find the workflow by request.
                        "request-id": request_id,
                    },
                },
                "spec": {
                    "workflowTemplateRef": {"name": template},
                    "arguments": {
                        "parameters": [
                            {"name": k, "value": str(v)}
                            for k, v in parameters.items()
                        ],
                    },
                },
            }

            response = await self._http.post(
                f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
                json={"workflow": workflow},
            )
            response.raise_for_status()

            result = response.json()
            return result["metadata"]["name"]

    async def _submit_kubeflow(
        self, pipeline_id: str, parameters: dict, request_id: str
    ) -> str:
        """Submit a run to Kubeflow Pipelines (v1beta1 API).

        Returns the run id. Raises httpx.HTTPStatusError on a non-2xx
        response (caught by handle_message).
        """
        with create_span("pipeline.submit.kubeflow") as span:
            if span:
                span.set_attribute("kubeflow.pipeline_id", pipeline_id)

            run_request = {
                # Short request-id suffix keeps run names unique but readable.
                "name": f"{pipeline_id}-{request_id[:8]}",
                "pipeline_spec": {
                    "pipeline_id": pipeline_id,
                    "parameters": [
                        {"name": k, "value": str(v)}
                        for k, v in parameters.items()
                    ],
                },
                "resource_references": [],
            }

            response = await self._http.post(
                f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
                json=run_request,
            )
            response.raise_for_status()

            result = response.json()
            return result["run"]["id"]
||||
if __name__ == "__main__":
    # handler-base's Handler.run() owns the event loop and signal handling;
    # no explicit asyncio.run() wrapper is needed.
    PipelineBridge().run()
||||
Reference in New Issue
Block a user