""" MLflow Tracker for Kubeflow Pipelines Provides a high-level interface for logging experiments, parameters, metrics, and artifacts from Kubeflow Pipeline components. """ import os import json import time import logging from pathlib import Path from typing import Optional, Dict, Any, List, Union from contextlib import contextmanager from dataclasses import dataclass, field import mlflow from mlflow.tracking import MlflowClient from .client import get_mlflow_client, ensure_experiment, MLflowConfig logger = logging.getLogger(__name__) @dataclass class PipelineMetadata: """Metadata about the Kubeflow Pipeline run.""" pipeline_name: str run_id: str run_name: Optional[str] = None component_name: Optional[str] = None namespace: str = "ai-ml" # KFP-specific metadata (populated from environment if available) kfp_run_id: Optional[str] = field( default_factory=lambda: os.environ.get("KFP_RUN_ID") ) kfp_pod_name: Optional[str] = field( default_factory=lambda: os.environ.get("KFP_POD_NAME") ) def as_tags(self) -> Dict[str, str]: """Convert metadata to MLflow tags.""" tags = { "pipeline.name": self.pipeline_name, "pipeline.run_id": self.run_id, "pipeline.namespace": self.namespace, } if self.run_name: tags["pipeline.run_name"] = self.run_name if self.component_name: tags["pipeline.component"] = self.component_name if self.kfp_run_id: tags["kfp.run_id"] = self.kfp_run_id if self.kfp_pod_name: tags["kfp.pod_name"] = self.kfp_pod_name return tags class MLflowTracker: """ MLflow experiment tracker for Kubeflow Pipeline components. Example usage in a KFP component: from mlflow_utils import MLflowTracker tracker = MLflowTracker( experiment_name="document-ingestion", run_name="batch-ingestion-2024-01" ) with tracker.start_run() as run: tracker.log_params({ "chunk_size": 500, "overlap": 50, "embeddings_model": "bge-small-en-v1.5" }) # ... do work ... tracker.log_metrics({ "documents_processed": 100, "chunks_created": 2500, "processing_time_seconds": 120.5 }) tracker.log_artifact("/path/to/output.json") """ def __init__( self, experiment_name: str, run_name: Optional[str] = None, pipeline_metadata: Optional[PipelineMetadata] = None, tags: Optional[Dict[str, str]] = None, tracking_uri: Optional[str] = None, ): """ Initialize the MLflow tracker. Args: experiment_name: Name of the MLflow experiment run_name: Optional name for this run pipeline_metadata: Metadata about the KFP pipeline tags: Additional tags to add to the run tracking_uri: Override default tracking URI """ self.config = MLflowConfig() self.experiment_name = experiment_name self.run_name = run_name or f"{experiment_name}-{int(time.time())}" self.pipeline_metadata = pipeline_metadata self.user_tags = tags or {} self.tracking_uri = tracking_uri self.client: Optional[MlflowClient] = None self.run: Optional[mlflow.ActiveRun] = None self.run_id: Optional[str] = None self._start_time: Optional[float] = None def _get_all_tags(self) -> Dict[str, str]: """Combine all tags for the run.""" tags = self.config.default_tags.copy() if self.pipeline_metadata: tags.update(self.pipeline_metadata.as_tags()) tags.update(self.user_tags) return tags @contextmanager def start_run( self, nested: bool = False, parent_run_id: Optional[str] = None, ): """ Start an MLflow run as a context manager. Args: nested: If True, create a nested run under the current active run parent_run_id: Explicit parent run ID for nested runs Yields: The MLflow run object """ self.client = get_mlflow_client( tracking_uri=self.tracking_uri, configure_global=True ) # Ensure experiment exists experiment_id = ensure_experiment(self.experiment_name) self._start_time = time.time() try: # Start the run self.run = mlflow.start_run( experiment_id=experiment_id, run_name=self.run_name, nested=nested, tags=self._get_all_tags(), ) self.run_id = self.run.info.run_id logger.info( f"Started MLflow run '{self.run_name}' " f"(ID: {self.run_id}) in experiment '{self.experiment_name}'" ) yield self.run except Exception as e: logger.error(f"MLflow run failed: {e}") if self.run: mlflow.set_tag("run.status", "failed") mlflow.set_tag("run.error", str(e)) raise finally: # Log duration if self._start_time: duration = time.time() - self._start_time try: mlflow.log_metric("run_duration_seconds", duration) except Exception: pass # End the run mlflow.end_run() logger.info(f"Ended MLflow run '{self.run_name}'") def log_params(self, params: Dict[str, Any]) -> None: """ Log parameters to the current run. Args: params: Dictionary of parameter names to values """ if not self.run: logger.warning("No active run, skipping log_params") return # MLflow has limits on param values, truncate if needed cleaned_params = {} for key, value in params.items(): str_value = str(value) if len(str_value) > 500: str_value = str_value[:497] + "..." cleaned_params[key] = str_value mlflow.log_params(cleaned_params) logger.debug(f"Logged {len(params)} parameters") def log_param(self, key: str, value: Any) -> None: """Log a single parameter.""" self.log_params({key: value}) def log_metrics( self, metrics: Dict[str, Union[float, int]], step: Optional[int] = None ) -> None: """ Log metrics to the current run. Args: metrics: Dictionary of metric names to values step: Optional step number for time-series metrics """ if not self.run: logger.warning("No active run, skipping log_metrics") return mlflow.log_metrics(metrics, step=step) logger.debug(f"Logged {len(metrics)} metrics") def log_metric( self, key: str, value: Union[float, int], step: Optional[int] = None ) -> None: """Log a single metric.""" self.log_metrics({key: value}, step=step) def log_artifact( self, local_path: str, artifact_path: Optional[str] = None ) -> None: """ Log an artifact file to the current run. Args: local_path: Path to the local file to log artifact_path: Optional destination path within the artifact store """ if not self.run: logger.warning("No active run, skipping log_artifact") return mlflow.log_artifact(local_path, artifact_path) logger.info(f"Logged artifact: {local_path}") def log_artifacts( self, local_dir: str, artifact_path: Optional[str] = None ) -> None: """ Log all files in a directory as artifacts. Args: local_dir: Path to the local directory artifact_path: Optional destination path within the artifact store """ if not self.run: logger.warning("No active run, skipping log_artifacts") return mlflow.log_artifacts(local_dir, artifact_path) logger.info(f"Logged artifacts from: {local_dir}") def log_dict( self, data: Dict[str, Any], filename: str, artifact_path: Optional[str] = None ) -> None: """ Log a dictionary as a JSON artifact. Args: data: Dictionary to log filename: Name for the JSON file artifact_path: Optional destination path """ if not self.run: logger.warning("No active run, skipping log_dict") return # Ensure .json extension if not filename.endswith(".json"): filename += ".json" mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename) logger.debug(f"Logged dict as: {filename}") def log_model_info( self, model_type: str, model_name: str, model_path: Optional[str] = None, framework: str = "pytorch", extra_info: Optional[Dict[str, Any]] = None ) -> None: """ Log model information as parameters and tags. Args: model_type: Type of model (e.g., "llm", "embedding", "stt") model_name: Name/identifier of the model model_path: Path to model weights framework: ML framework used extra_info: Additional model information """ params = { "model.type": model_type, "model.name": model_name, "model.framework": framework, } if model_path: params["model.path"] = model_path if extra_info: for key, value in extra_info.items(): params[f"model.{key}"] = value self.log_params(params) # Also set as tags for easier filtering mlflow.set_tag("model.type", model_type) mlflow.set_tag("model.name", model_name) def log_dataset_info( self, name: str, source: str, size: Optional[int] = None, extra_info: Optional[Dict[str, Any]] = None ) -> None: """ Log dataset information. Args: name: Dataset name source: Dataset source (URL, path, etc.) size: Number of samples extra_info: Additional dataset information """ params = { "dataset.name": name, "dataset.source": source, } if size is not None: params["dataset.size"] = size if extra_info: for key, value in extra_info.items(): params[f"dataset.{key}"] = value self.log_params(params) def set_tag(self, key: str, value: str) -> None: """Set a single tag on the run.""" if self.run: mlflow.set_tag(key, value) def set_tags(self, tags: Dict[str, str]) -> None: """Set multiple tags on the run.""" if self.run: mlflow.set_tags(tags) @property def artifact_uri(self) -> Optional[str]: """Get the artifact URI for the current run.""" if self.run: return self.run.info.artifact_uri return None @property def experiment_id(self) -> Optional[str]: """Get the experiment ID for the current run.""" if self.run: return self.run.info.experiment_id return None