"""
Kubeflow Pipeline Components with MLflow Tracking

Provides reusable KFP components that integrate MLflow experiment tracking
into Kubeflow Pipelines. These components can be used directly in pipelines
or as wrappers around existing pipeline steps.

Usage in a Kubeflow Pipeline:

    from mlflow_utils.kfp_components import (
        create_mlflow_run,
        log_metrics_component,
        log_model_artifact,
        end_mlflow_run,
    )

    @dsl.pipeline(name="my-pipeline")
    def my_pipeline():
        # Start MLflow run
        run_info = create_mlflow_run(
            experiment_name="my-experiment",
            run_name="training-run-1"
        )

        # ... your pipeline steps ...

        # Log metrics
        log_step = log_metrics_component(
            run_id=run_info.outputs["run_id"],
            metrics={"accuracy": 0.95, "loss": 0.05}
        )

        # End run
        end_mlflow_run(run_id=run_info.outputs["run_id"])
"""

from typing import Any, Dict, List, NamedTuple

from kfp import dsl

# NOTE: every component below imports what it needs inside its own body and
# uses no module-level helpers. Each step runs as an isolated component, so
# keep the functions self-contained when editing them.

# MLflow component image with all required dependencies
MLFLOW_IMAGE = "python:3.13-slim"
MLFLOW_PACKAGES = [
    "mlflow>=2.10.0",
    "boto3",  # For S3 artifact storage if needed
    "psycopg2-binary",  # For PostgreSQL backend
]


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def create_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    tags: Dict[str, str] = None,
    params: Dict[str, str] = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
    """
    Create a new MLflow run for the pipeline.

    This should be called at the start of a pipeline to initialize tracking.
    The returned run_id should be passed to subsequent components for
    logging.

    The experiment is looked up by name and created if it does not exist
    (with artifact location ``/mlflow/artifacts/<experiment_name>``). The run
    is opened, tagged, optionally given its initial params, and then closed
    again inside this component; later components attach to it by run_id
    through ``MlflowClient``.

    Args:
        experiment_name: Name of the MLflow experiment
        run_name: Name for this specific run
        mlflow_tracking_uri: MLflow tracking server URI
        tags: Optional tags to add to the run (override the default tags on
            key collision)
        params: Optional parameters to log

    Returns:
        NamedTuple with run_id, experiment_id, and artifact_uri
    """
    import os
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient

    # Set tracking URI
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Get or create experiment
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}"
        )
    else:
        experiment_id = experiment.experiment_id

    # Create default tags; caller-supplied tags take precedence.
    default_tags = {
        "pipeline.type": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
        "kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
    }
    if tags:
        default_tags.update(tags)

    # Open the run as a context manager so it is always closed: on success
    # it ends normally, and if logging the initial params raises, the run is
    # marked FAILED instead of being left dangling.
    # (KFP components are isolated, we'll resume in other components.)
    with mlflow.start_run(
        experiment_id=experiment_id,
        run_name=run_name,
        tags=default_tags,
    ) as run:
        # Log initial params
        if params:
            mlflow.log_params(params)

        run_id = run.info.run_id
        artifact_uri = run.info.artifact_uri

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
    return RunInfo(run_id, experiment_id, artifact_uri)


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_params_component(
    run_id: str,
    params: Dict[str, str],
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log parameters to an existing MLflow run.

    Each value is converted to ``str`` and truncated to 500 characters
    before being logged.

    Args:
        run_id: The MLflow run ID to log to
        params: Dictionary of parameters to log
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    for key, value in params.items():
        client.log_param(run_id, key, str(value)[:500])

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_metrics_component(
    run_id: str,
    metrics: Dict[str, float],
    step: int = 0,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log metrics to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        metrics: Dictionary of metrics to log (values are cast to float)
        step: Step number for time-series metrics
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    for key, value in metrics.items():
        client.log_metric(run_id, key, float(value), step=step)

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_artifact_component(
    run_id: str,
    artifact_path: str,
    artifact_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log an artifact file to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        artifact_path: Path to the artifact file
        artifact_name: Optional destination name in artifact store; an empty
            string is passed to MLflow as ``None``
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    client.log_artifact(run_id, artifact_path, artifact_name or None)

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_dict_artifact(
    run_id: str,
    data: Dict[str, Any],
    filename: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log a dictionary as a JSON artifact.

    Args:
        run_id: The MLflow run ID to log to
        data: Dictionary to save as JSON
        filename: Name for the JSON file (``.json`` is appended if missing)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Ensure .json extension
    if not filename.endswith('.json'):
        filename += '.json'

    # Write to temp file and log; the upload must happen while the temporary
    # directory still exists.
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = Path(tmpdir) / filename
        with open(filepath, 'w', encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        client.log_artifact(run_id, str(filepath))

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def end_mlflow_run(
    run_id: str,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    End an MLflow run with the specified status.

    Args:
        run_id: The MLflow run ID to end
        status: Run status (FINISHED, FAILED, KILLED), case-insensitive.
            Any other value falls back to FINISHED and a warning is printed.
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id
    """
    import mlflow
    from mlflow.entities import RunStatus
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # MlflowClient.set_terminated expects the *string* name of the status
    # (it converts it itself), not the RunStatus enum value, so work with
    # the string names throughout.
    finished = RunStatus.to_string(RunStatus.FINISHED)
    allowed_statuses = {
        finished,
        RunStatus.to_string(RunStatus.FAILED),
        RunStatus.to_string(RunStatus.KILLED),
    }

    run_status = status.upper()
    if run_status not in allowed_statuses:
        # Keep the lenient fallback, but make it visible in the step logs.
        print(
            f"WARNING: unknown run status {status!r}; "
            f"falling back to {finished}"
        )
        run_status = finished

    client.set_terminated(run_id, status=run_status)

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES + ["httpx"]
)
def log_training_metrics(
    run_id: str,
    model_type: str,
    training_config: Dict[str, Any],
    final_metrics: Dict[str, float],
    model_path: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log comprehensive training metrics for ML models.

    Designed for use with QLoRA training, voice training, and other ML
    training pipelines in the llm-workflows repository.

    The training config is logged twice: flattened into ``config.<key>``
    params (each value truncated to 500 characters) and in full as a
    ``training_config.json`` artifact.

    Args:
        run_id: The MLflow run ID to log to
        model_type: Type of model (llm, stt, tts, embeddings)
        training_config: Training configuration dict
        final_metrics: Final training metrics
        model_path: Path to saved model (if applicable)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log training config as params; nested values are JSON-encoded so they
    # stay readable, everything else is stringified.
    for key, value in training_config.items():
        if isinstance(value, (dict, list)):
            rendered = json.dumps(value)
        else:
            rendered = str(value)
        client.log_param(run_id, f"config.{key}", rendered[:500])

    # Log model type tag
    client.set_tag(run_id, "model.type", model_type)

    # Log metrics
    for key, value in final_metrics.items():
        client.log_metric(run_id, key, float(value))

    # Log full config as artifact
    with tempfile.TemporaryDirectory() as tmpdir:
        config_path = Path(tmpdir) / "training_config.json"
        with open(config_path, 'w', encoding="utf-8") as f:
            json.dump(training_config, f, indent=2)
        client.log_artifact(run_id, str(config_path))

    # Log model path if provided
    if model_path:
        client.log_param(run_id, "model.path", model_path)
        client.set_tag(run_id, "model.saved", "true")

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_document_ingestion_metrics(
    run_id: str,
    source_url: str,
    collection_name: str,
    chunks_created: int,
    documents_processed: int,
    processing_time_seconds: float,
    embeddings_model: str = "bge-small-en-v1.5",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log document ingestion pipeline metrics.

    Designed for use with the document_ingestion_pipeline. Besides the
    supplied values, a derived ``chunks_per_second`` metric is logged (0 when
    the processing time is not positive), and the run's ``pipeline.type``
    tag is set to ``document-ingestion``.

    Args:
        run_id: The MLflow run ID to log to
        source_url: URL of the source document (truncated to 500 characters)
        collection_name: Milvus collection name
        chunks_created: Number of chunks created
        documents_processed: Number of documents processed
        processing_time_seconds: Total processing time
        embeddings_model: Embeddings model used
        chunk_size: Chunk size in tokens
        chunk_overlap: Chunk overlap in tokens
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log params
    params = {
        "source_url": source_url[:500],
        "collection_name": collection_name,
        "embeddings_model": embeddings_model,
        "chunk_size": str(chunk_size),
        "chunk_overlap": str(chunk_overlap),
    }
    for key, value in params.items():
        client.log_param(run_id, key, value)

    # Log metrics; guard the throughput division against a zero/negative
    # duration.
    metrics = {
        "chunks_created": chunks_created,
        "documents_processed": documents_processed,
        "processing_time_seconds": processing_time_seconds,
        "chunks_per_second": (
            chunks_created / processing_time_seconds
            if processing_time_seconds > 0
            else 0
        ),
    }
    for key, value in metrics.items():
        client.log_metric(run_id, key, float(value))

    # Set pipeline type tag
    client.set_tag(run_id, "pipeline.type", "document-ingestion")
    client.set_tag(run_id, "milvus.collection", collection_name)

    return run_id


@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_evaluation_results(
    run_id: str,
    model_name: str,
    dataset_name: str,
    metrics: Dict[str, float],
    sample_results: List[Dict[str, Any]] = None,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log model evaluation results.

    Designed for use with the evaluation_pipeline. Metrics are logged under
    an ``eval.`` prefix, and an ``eval.passed`` tag ("True"/"False") is set:
    it follows the truthiness of ``metrics["pass"]`` when that key is
    present, otherwise it is whether ``accuracy`` (default 0) is >= 0.7.

    Args:
        run_id: The MLflow run ID to log to
        model_name: Name of the evaluated model
        dataset_name: Name of the evaluation dataset
        metrics: Evaluation metrics (accuracy, etc.)
        sample_results: Optional sample predictions, stored as the
            ``evaluation_results.json`` artifact
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log params
    client.log_param(run_id, "eval.model_name", model_name)
    client.log_param(run_id, "eval.dataset", dataset_name)

    # Log metrics
    for key, value in metrics.items():
        client.log_metric(run_id, f"eval.{key}", float(value))

    # Log sample results as artifact
    if sample_results:
        with tempfile.TemporaryDirectory() as tmpdir:
            results_path = Path(tmpdir) / "evaluation_results.json"
            with open(results_path, 'w', encoding="utf-8") as f:
                json.dump(sample_results, f, indent=2)
            client.log_artifact(run_id, str(results_path))

    # Set tags
    client.set_tag(run_id, "pipeline.type", "evaluation")
    client.set_tag(run_id, "model.name", model_name)

    # Determine if passed. An explicit "pass" metric arrives as a number, so
    # coerce it to bool; otherwise the tag would read "1.0"/"0.0" instead of
    # "True"/"False" like the accuracy-threshold branch.
    if "pass" in metrics:
        passed = bool(metrics["pass"])
    else:
        passed = metrics.get("accuracy", 0) >= 0.7
    client.set_tag(run_id, "eval.passed", str(passed))

    return run_id