refactor: consolidate to handler-base, migrate to pyproject.toml, add tests

This commit is contained in:
2026-02-02 07:11:08 -05:00
parent 50b1835688
commit 3f1e05eaad
10 changed files with 594 additions and 577 deletions

View File

@@ -1,351 +1,241 @@
#!/usr/bin/env python3
"""
Pipeline Bridge Service (Refactored)

Bridges NATS events to workflow engines using handler-base:
1. Listen for pipeline triggers on "ai.pipeline.trigger"
2. Submit to Kubeflow Pipelines or Argo Workflows
3. Monitor execution and publish status updates
4. Publish completion to "ai.pipeline.status.{request_id}"

Supported pipelines:
- document-ingestion: Ingest documents into Milvus
- batch-inference: Run batch LLM inference
- model-evaluation: Evaluate model performance
"""
import asyncio
import json
import logging
import os
import signal
from datetime import datetime
from typing import Any, Dict, Optional

# NOTE: dependencies are installed via packaging (pyproject.toml); the old
# runtime `pip install -r /app/requirements.txt` subprocess hack was removed —
# installing packages at import time is slow, non-reproducible, and fails
# closed when the registry is unreachable.
import httpx
import nats
from kubernetes import client, config
from nats.aio.msg import Msg

from handler_base import Handler, Settings
from handler_base.telemetry import create_span

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("pipeline-bridge")

# Configuration from environment (legacy path; the refactored handler reads
# these values through PipelineSettings instead, but the constants are kept
# so older helpers that still reference them continue to resolve).
NATS_URL = os.environ.get("NATS_URL", "nats://nats.ai-ml.svc.cluster.local:4222")
KUBEFLOW_HOST = os.environ.get("KUBEFLOW_HOST", "http://ml-pipeline.kubeflow.svc.cluster.local:8888")
ARGO_HOST = os.environ.get("ARGO_HOST", "http://argo-server.argo.svc.cluster.local:2746")
ARGO_NAMESPACE = os.environ.get("ARGO_NAMESPACE", "ai-ml")

# NATS subjects
TRIGGER_SUBJECT = "ai.pipeline.trigger"
STATUS_SUBJECT = "ai.pipeline.status"
class PipelineSettings(Settings):
    """Pipeline bridge specific settings."""

    # Service identity used by handler-base (logging/telemetry naming).
    service_name: str = "pipeline-bridge"
    # Kubeflow Pipelines API endpoint (in-cluster default).
    # NOTE(review): presumably overridable via environment through the
    # handler_base Settings machinery — confirm against handler_base.
    kubeflow_host: str = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
    # Argo Workflows server endpoint (in-cluster default).
    argo_host: str = "http://argo-server.argo.svc.cluster.local:2746"
    # Namespace in which Argo Workflow objects are created.
    argo_namespace: str = "ai-ml"
# Pipeline definitions - maps pipeline names to their configurations.
# Each entry names the execution engine ("argo" uses a WorkflowTemplate,
# "kubeflow" uses a pipeline id) plus a human-readable description.
# The diff-render had duplicated "description" lines (old + new); this is
# the reconstructed, syntactically valid dict.
PIPELINES = {
    "document-ingestion": {
        "engine": "argo",
        "template": "document-ingestion",
        "description": "Ingest documents into Milvus vector database",
    },
    "batch-inference": {
        "engine": "argo",
        "template": "batch-inference",
        "description": "Run batch LLM inference on a dataset",
    },
    "rag-query": {
        "engine": "kubeflow",
        "pipeline_id": "rag-pipeline",
        "description": "Execute RAG query pipeline",
    },
    "voice-pipeline": {
        "engine": "kubeflow",
        "pipeline_id": "voice-pipeline",
        "description": "Full voice assistant pipeline",
    },
    "model-evaluation": {
        "engine": "argo",
        "template": "model-evaluation",
        "description": "Evaluate model performance",
    },
}
class PipelineBridge(Handler):
    """
    Pipeline trigger handler.

    Request format:
    {
        "request_id": "uuid",
        "pipeline": "document-ingestion",
        "parameters": {"key": "value"}
    }

    Response format:
    {
        "request_id": "uuid",
        "status": "submitted",
        "run_id": "workflow-run-id",
        "engine": "argo|kubeflow"
    }
    """

    def __init__(self):
        # Kept as an attribute so the submit helpers can read the Argo /
        # Kubeflow host and namespace settings.
        self.pipeline_settings = PipelineSettings()
        # handler-base wires the NATS subscription on this subject with the
        # given queue group so multiple bridge replicas share the load.
        super().__init__(
            subject="ai.pipeline.trigger",
            settings=self.pipeline_settings,
            queue_group="pipeline-bridges",
        )
        # Created in setup(), closed in teardown().
        self._http: Optional[httpx.AsyncClient] = None
async def setup(self) -> None:
    """Initialize HTTP client."""
    logger.info("Initializing pipeline bridge...")
    # One shared async client for both Argo and Kubeflow submissions;
    # generous timeout because workflow submission can be slow.
    self._http = httpx.AsyncClient(timeout=60.0)
    logger.info(f"Pipeline bridge ready. Available pipelines: {list(PIPELINES.keys())}")
async def teardown(self) -> None:
    """Clean up HTTP client."""
    # Close the shared client if setup() ever created it.
    if self._http:
        await self._http.aclose()
    # NOTE(review): indentation was stripped in this view — the log line is
    # reconstructed at method level (logged even when no client existed);
    # confirm against the original file.
    logger.info("Pipeline bridge closed")
async def handle_message(self, msg: Msg, data: Any) -> Optional[dict]:
    """Handle pipeline trigger request.

    ``data`` is the decoded trigger payload (handler-base delivers it
    alongside the raw NATS message). Returns a response dict — either a
    submission receipt or an error envelope carrying ``request_id``.
    """
    request_id = data.get("request_id", "unknown")
    pipeline_name = data.get("pipeline", "")
    parameters = data.get("parameters", {})
    logger.info(f"Triggering pipeline '{pipeline_name}' for request {request_id}")
    with create_span("pipeline.trigger") as span:
        # create_span may yield None when telemetry is disabled.
        if span:
            span.set_attribute("request.id", request_id)
            span.set_attribute("pipeline.name", pipeline_name)
        # Validate pipeline: reject unknown names with the list of valid ones.
        if pipeline_name not in PIPELINES:
            error = f"Unknown pipeline: {pipeline_name}"
            logger.error(error)
            return {
                "request_id": request_id,
                "status": "error",
                "error": error,
                "available_pipelines": list(PIPELINES.keys()),
            }
        pipeline = PIPELINES[pipeline_name]
        engine = pipeline["engine"]
        try:
            # Dispatch on engine: "argo" submits a WorkflowTemplate run,
            # anything else falls through to Kubeflow by pipeline_id.
            if engine == "argo":
                run_id = await self._submit_argo(
                    pipeline["template"], parameters, request_id
                )
            else:
                run_id = await self._submit_kubeflow(
                    pipeline["pipeline_id"], parameters, request_id
                )
            result = {
                "request_id": request_id,
                "status": "submitted",
                "run_id": run_id,
                "engine": engine,
                "pipeline": pipeline_name,
                "submitted_at": datetime.utcnow().isoformat(),
            }
            # Publish status update on the per-request status subject.
            # NOTE(review): self.nats presumably serializes the dict —
            # confirm against handler_base.
            await self.nats.publish(
                f"ai.pipeline.status.{request_id}", result
            )
            logger.info(f"Pipeline {pipeline_name} submitted: {run_id}")
            return result
        except Exception as e:
            # Submission failures are reported to the caller, not raised.
            logger.exception(f"Failed to submit pipeline {pipeline_name}")
            return {
                "request_id": request_id,
                "status": "error",
                "error": str(e),
            }
async def _submit_argo(
    self, template: str, parameters: dict, request_id: str
) -> str:
    """Submit workflow to Argo Workflows.

    Builds a Workflow referencing the named WorkflowTemplate and POSTs it
    to the Argo server's HTTP API. Parameters are stringified into Argo
    workflow arguments. Returns the generated workflow name.
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    with create_span("pipeline.submit.argo") as span:
        if span:
            span.set_attribute("argo.template", template)
        workflow = {
            "apiVersion": "argoproj.io/v1alpha1",
            "kind": "Workflow",
            "metadata": {
                # generateName lets the server append a unique suffix.
                "generateName": f"{template}-",
                "namespace": self.pipeline_settings.argo_namespace,
                "labels": {
                    "app.kubernetes.io/managed-by": "pipeline-bridge",
                    "request-id": request_id,
                },
            },
            "spec": {
                "workflowTemplateRef": {"name": template},
                "arguments": {
                    "parameters": [
                        {"name": k, "value": str(v)}
                        for k, v in parameters.items()
                    ]
                },
            },
        }
        # HTTP submission to the Argo server (replaces the removed
        # kubernetes CustomObjectsApi path).
        response = await self._http.post(
            f"{self.pipeline_settings.argo_host}/api/v1/workflows/{self.pipeline_settings.argo_namespace}",
            json={"workflow": workflow},
        )
        response.raise_for_status()
        result = response.json()
        return result["metadata"]["name"]
async def _submit_kubeflow(
    self, pipeline_id: str, parameters: dict, request_id: str
) -> str:
    """Submit run to Kubeflow Pipelines.

    POSTs a run request to the Kubeflow v1beta1 runs API, naming the run
    after the pipeline id plus a request-id prefix. Returns the run id.
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    with create_span("pipeline.submit.kubeflow") as span:
        if span:
            span.set_attribute("kubeflow.pipeline_id", pipeline_id)
        run_request = {
            # First 8 chars of the request id keep the run name short
            # while staying traceable back to the trigger.
            "name": f"{pipeline_id}-{request_id[:8]}",
            "pipeline_spec": {
                "pipeline_id": pipeline_id,
                "parameters": [
                    {"name": k, "value": str(v)}
                    for k, v in parameters.items()
                ],
            },
            "resource_references": [],
        }
        response = await self._http.post(
            f"{self.pipeline_settings.kubeflow_host}/apis/v1beta1/runs",
            json=run_request,
        )
        # raise_for_status replaces the old manual status-code branch;
        # the reconstructed success path reads result["run"]["id"]
        # (its trailing lines appeared displaced at the end of the diff).
        response.raise_for_status()
        result = response.json()
        return result["run"]["id"]
async def get_argo_workflow_status(self, workflow_name: str) -> Dict:
    """Get status of an Argo Workflow.

    Reads the Workflow custom object via the kubernetes client and returns
    a flat dict with phase/message/timestamps/progress. Never raises:
    errors are folded into a {"phase": "Error"} result.
    """
    # NOTE(review): legacy path — relies on self.k8s_custom, which the
    # refactored __init__ no longer creates; confirm this method is still
    # reachable after the refactor.
    if not self.k8s_custom:
        return {"phase": "Unknown", "message": "Kubernetes client not available"}
    try:
        result = self.k8s_custom.get_namespaced_custom_object(
            group="argoproj.io",
            version="v1alpha1",
            namespace=ARGO_NAMESPACE,
            plural="workflows",
            name=workflow_name
        )
        status = result.get("status", {})
        # Normalize: "Pending" / empty defaults when fields are absent.
        return {
            "phase": status.get("phase", "Pending"),
            "message": status.get("message", ""),
            "startedAt": status.get("startedAt"),
            "finishedAt": status.get("finishedAt"),
            "progress": status.get("progress", "0/0")
        }
    except Exception as e:
        logger.error(f"Failed to get workflow status: {e}")
        return {"phase": "Error", "message": str(e)}
async def process_trigger(self, msg):
    """Process a pipeline trigger request.

    Decodes the JSON trigger payload, validates the pipeline name,
    submits to the matching engine, and publishes a status update for
    every outcome (unknown pipeline, submitted, or failed).
    """
    # Pre-bind so the except-handler can safely read the request id even
    # when decode/json.loads raises. Previously `data` was first assigned
    # inside the try, so a parse failure made `data.get(...)` in the
    # handler raise NameError instead of publishing the error status.
    data = {}
    try:
        data = json.loads(msg.data.decode())
        request_id = data.get("request_id", "unknown")
        pipeline_name = data.get("pipeline", "")
        parameters = data.get("parameters", {})
        logger.info(f"Processing pipeline trigger {request_id}: {pipeline_name}")
        # Validate pipeline
        if pipeline_name not in PIPELINES:
            await self.publish_status(request_id, {
                "status": "error",
                "message": f"Unknown pipeline: {pipeline_name}",
                "available_pipelines": list(PIPELINES.keys())
            })
            return
        pipeline_config = PIPELINES[pipeline_name]
        engine = pipeline_config["engine"]
        # Submit to appropriate engine
        run_id = None
        if engine == "argo":
            run_id = await self.submit_argo_workflow(
                pipeline_config["template"],
                parameters,
                request_id
            )
        elif engine == "kubeflow":
            run_id = await self.submit_kubeflow_pipeline(
                pipeline_config["pipeline_id"],
                parameters,
                request_id
            )
        if run_id:
            # Track workflow so monitor_workflows() can publish progress.
            self.active_workflows[request_id] = {
                "engine": engine,
                "run_id": run_id,
                "started_at": datetime.utcnow().isoformat()
            }
            await self.publish_status(request_id, {
                "status": "submitted",
                "pipeline": pipeline_name,
                "engine": engine,
                "run_id": run_id,
                "message": "Pipeline submitted successfully"
            })
        else:
            await self.publish_status(request_id, {
                "status": "error",
                "pipeline": pipeline_name,
                "message": "Failed to submit pipeline"
            })
    except Exception as e:
        logger.error(f"Trigger processing failed: {e}")
        await self.publish_status(
            data.get("request_id", "unknown"),
            {"status": "error", "message": str(e)}
        )
async def publish_status(self, request_id: str, status: Dict):
    """Stamp *status* with the request id and timestamp, then publish it.

    Note: the caller's dict is mutated in place before serialization;
    the message goes out on ``{STATUS_SUBJECT}.{request_id}``.
    """
    status.update(
        request_id=request_id,
        timestamp=datetime.utcnow().isoformat(),
    )
    subject = f"{STATUS_SUBJECT}.{request_id}"
    payload = json.dumps(status).encode()
    await self.nc.publish(subject, payload)
    logger.info(f"Published status for {request_id}: {status.get('status')}")
async def monitor_workflows(self):
    """Periodically check and publish status of active workflows."""
    # NOTE(review): only "argo" entries are polled; "kubeflow" entries added
    # by process_trigger are never updated or removed from
    # self.active_workflows — confirm whether that is intentional.
    while self.running:
        completed = []
        for request_id, workflow in self.active_workflows.items():
            try:
                if workflow["engine"] == "argo":
                    status = await self.get_argo_workflow_status(workflow["run_id"])
                    # Publish status update
                    await self.publish_status(request_id, {
                        "status": status["phase"].lower(),
                        "run_id": workflow["run_id"],
                        "progress": status.get("progress"),
                        "message": status.get("message", "")
                    })
                    # Check if completed
                    if status["phase"] in ["Succeeded", "Failed", "Error"]:
                        completed.append(request_id)
            except Exception as e:
                logger.error(f"Error monitoring workflow {request_id}: {e}")
        # Remove completed workflows from tracking — deferred so the dict is
        # not mutated while it is being iterated.
        for request_id in completed:
            del self.active_workflows[request_id]
        await asyncio.sleep(10)  # Check every 10 seconds
async def run(self):
    """Main run loop.

    Connects (setup), subscribes to pipeline triggers, starts the status
    monitor, then idles until a SIGTERM/SIGINT flips self.running, after
    which it unwinds the subscription and NATS connection.
    """
    await self.setup()
    # Subscribe to pipeline triggers
    sub = await self.nc.subscribe(TRIGGER_SUBJECT, cb=self.process_trigger)
    logger.info(f"Subscribed to {TRIGGER_SUBJECT}")
    # Start workflow monitor
    monitor_task = asyncio.create_task(self.monitor_workflows())
    # Handle shutdown: signal handlers just clear the run flag so the loop
    # below exits and cleanup happens on the normal path.
    def signal_handler():
        self.running = False
    # NOTE(review): asyncio.get_event_loop() is deprecated inside a running
    # coroutine — asyncio.get_running_loop() is the modern equivalent.
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)
    # Keep running
    while self.running:
        await asyncio.sleep(1)
    # Cleanup
    monitor_task.cancel()
    await sub.unsubscribe()
    await self.nc.close()
    logger.info("Shutdown complete")
# (The two stray lines that preceded this guard in the diff view —
# `result = response.json()` / `return result["run"]["id"]` — are the
# displaced tail of _submit_kubeflow, not module-level code.)
if __name__ == "__main__":
    # NOTE(review): the refactored entrypoint calls Handler.run() directly —
    # presumably it manages its own event loop (the old version wrapped an
    # async run() in asyncio.run()); confirm against handler_base.
    PipelineBridge().run()