Files
mlflow/mlflow_utils/tracker.py
Billy D. ca5bef9664
Some checks failed
CI / Lint (push) Successful in 1m46s
CI / Test (push) Successful in 1m44s
CI / Publish (push) Failing after 19s
CI / Notify (push) Successful in 1s
style: apply ruff format to all files
2026-02-13 11:05:26 -05:00

361 lines
11 KiB
Python

"""
MLflow Tracker for Kubeflow Pipelines
Provides a high-level interface for logging experiments, parameters,
metrics, and artifacts from Kubeflow Pipeline components.
"""
import logging
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
import mlflow
from mlflow.tracking import MlflowClient
from .client import MLflowConfig, ensure_experiment, get_mlflow_client
logger = logging.getLogger(__name__)
@dataclass
class PipelineMetadata:
    """Describes the Kubeflow Pipeline run that produced this tracking data."""

    pipeline_name: str
    run_id: str
    run_name: Optional[str] = None
    component_name: Optional[str] = None
    namespace: str = "ai-ml"
    # KFP-specific metadata (populated from environment if available)
    kfp_run_id: Optional[str] = field(default_factory=lambda: os.environ.get("KFP_RUN_ID"))
    kfp_pod_name: Optional[str] = field(default_factory=lambda: os.environ.get("KFP_POD_NAME"))

    def as_tags(self) -> Dict[str, str]:
        """Render this metadata as a flat dict of MLflow run tags.

        Required fields are always present; optional fields are included
        only when they hold a truthy value.
        """
        tags: Dict[str, str] = {
            "pipeline.name": self.pipeline_name,
            "pipeline.run_id": self.run_id,
            "pipeline.namespace": self.namespace,
        }
        optional_pairs = (
            ("pipeline.run_name", self.run_name),
            ("pipeline.component", self.component_name),
            ("kfp.run_id", self.kfp_run_id),
            ("kfp.pod_name", self.kfp_pod_name),
        )
        for tag_key, tag_value in optional_pairs:
            if tag_value:
                tags[tag_key] = tag_value
        return tags
class MLflowTracker:
"""
MLflow experiment tracker for Kubeflow Pipeline components.
Example usage in a KFP component:
from mlflow_utils import MLflowTracker
tracker = MLflowTracker(
experiment_name="document-ingestion",
run_name="batch-ingestion-2024-01"
)
with tracker.start_run() as run:
tracker.log_params({
"chunk_size": 500,
"overlap": 50,
"embeddings_model": "bge-small-en-v1.5"
})
# ... do work ...
tracker.log_metrics({
"documents_processed": 100,
"chunks_created": 2500,
"processing_time_seconds": 120.5
})
tracker.log_artifact("/path/to/output.json")
"""
def __init__(
    self,
    experiment_name: str,
    run_name: Optional[str] = None,
    pipeline_metadata: Optional[PipelineMetadata] = None,
    tags: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
):
    """Configure a tracker; no MLflow calls are made until start_run().

    Args:
        experiment_name: Name of the MLflow experiment.
        run_name: Optional run name; defaults to "<experiment>-<unix time>".
        pipeline_metadata: Metadata about the KFP pipeline, attached as tags.
        tags: Extra user tags, merged on top of config and pipeline tags.
        tracking_uri: Override for the default tracking URI.
    """
    self.experiment_name = experiment_name
    self.pipeline_metadata = pipeline_metadata
    self.user_tags = tags or {}
    self.tracking_uri = tracking_uri
    self.config = MLflowConfig()
    # Timestamped fallback keeps repeated unnamed runs distinguishable.
    self.run_name = run_name or f"{experiment_name}-{int(time.time())}"
    # Run state below is populated by start_run().
    self.client: Optional[MlflowClient] = None
    self.run: Optional[mlflow.ActiveRun] = None
    self.run_id: Optional[str] = None
    self._start_time: Optional[float] = None
def _get_all_tags(self) -> Dict[str, str]:
"""Combine all tags for the run."""
tags = self.config.default_tags.copy()
if self.pipeline_metadata:
tags.update(self.pipeline_metadata.as_tags())
tags.update(self.user_tags)
return tags
@contextmanager
def start_run(
    self,
    nested: bool = False,
    parent_run_id: Optional[str] = None,
):
    """
    Start an MLflow run as a context manager.

    On entry: resolves a client (configuring the global fluent API),
    ensures the experiment exists, and starts the run with all merged
    tags. On exit: logs the run duration (best effort) and ends the run,
    whether the body succeeded or raised.

    Args:
        nested: If True, create a nested run under the current active run
        parent_run_id: Explicit parent run ID for nested runs
            NOTE(review): currently accepted but never passed to
            mlflow.start_run — confirm whether this is intentional.

    Yields:
        The MLflow run object

    Raises:
        Re-raises any exception from the body after tagging the run
        as failed.
    """
    # configure_global=True makes the module-level mlflow.* fluent calls
    # used throughout this class target the same tracking server.
    self.client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
    # Ensure experiment exists
    experiment_id = ensure_experiment(self.experiment_name)
    self._start_time = time.time()
    try:
        # Start the run
        self.run = mlflow.start_run(
            experiment_id=experiment_id,
            run_name=self.run_name,
            nested=nested,
            tags=self._get_all_tags(),
        )
        self.run_id = self.run.info.run_id
        logger.info(
            f"Started MLflow run '{self.run_name}' (ID: {self.run_id}) in experiment '{self.experiment_name}'"
        )
        yield self.run
    except Exception as e:
        # Tag the run as failed before re-raising so the failure is
        # visible when browsing runs in the MLflow UI.
        logger.error(f"MLflow run failed: {e}")
        if self.run:
            mlflow.set_tag("run.status", "failed")
            mlflow.set_tag("run.error", str(e))
        raise
    finally:
        # Log duration
        if self._start_time:
            duration = time.time() - self._start_time
            try:
                mlflow.log_metric("run_duration_seconds", duration)
            except Exception:
                # Best effort only — never let duration logging mask the
                # body's real outcome.
                pass
        # End the run (no-op if start_run itself failed before a run existed)
        mlflow.end_run()
        logger.info(f"Ended MLflow run '{self.run_name}'")
def log_params(self, params: Dict[str, Any]) -> None:
    """Log a batch of parameters to the active run.

    Values are stringified and truncated to 500 characters (with a
    "..." marker) to respect MLflow's param-value length limit. A
    warning is emitted and nothing is logged when no run is active.

    Args:
        params: Mapping of parameter names to values.
    """
    if not self.run:
        logger.warning("No active run, skipping log_params")
        return
    # MLflow has limits on param values, truncate if needed
    safe_params = {}
    for name, raw_value in params.items():
        text = str(raw_value)
        if len(text) > 500:
            text = text[:497] + "..."
        safe_params[name] = text
    mlflow.log_params(safe_params)
    logger.debug(f"Logged {len(params)} parameters")
def log_param(self, key: str, value: Any) -> None:
    """Log one parameter by delegating to log_params()."""
    self.log_params({key: value})
def log_metrics(self, metrics: Dict[str, Union[float, int]], step: Optional[int] = None) -> None:
    """Log a batch of numeric metrics to the active run.

    Emits a warning and does nothing when no run is active.

    Args:
        metrics: Mapping of metric names to numeric values.
        step: Optional step index for time-series metrics.
    """
    if self.run:
        mlflow.log_metrics(metrics, step=step)
        logger.debug(f"Logged {len(metrics)} metrics")
    else:
        logger.warning("No active run, skipping log_metrics")
def log_metric(self, key: str, value: Union[float, int], step: Optional[int] = None) -> None:
    """Log one metric by delegating to log_metrics()."""
    self.log_metrics({key: value}, step=step)
def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
    """Upload a single local file to the run's artifact store.

    Emits a warning and does nothing when no run is active.

    Args:
        local_path: Path of the local file to upload.
        artifact_path: Optional destination subdirectory in the store.
    """
    if self.run:
        mlflow.log_artifact(local_path, artifact_path)
        logger.info(f"Logged artifact: {local_path}")
    else:
        logger.warning("No active run, skipping log_artifact")
def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
    """Upload every file under a local directory as run artifacts.

    Emits a warning and does nothing when no run is active.

    Args:
        local_dir: Path of the local directory to upload.
        artifact_path: Optional destination subdirectory in the store.
    """
    if self.run:
        mlflow.log_artifacts(local_dir, artifact_path)
        logger.info(f"Logged artifacts from: {local_dir}")
    else:
        logger.warning("No active run, skipping log_artifacts")
def log_dict(self, data: Dict[str, Any], filename: str, artifact_path: Optional[str] = None) -> None:
    """
    Log a dictionary as a JSON artifact.

    Args:
        data: Dictionary to log
        filename: Name for the JSON file (".json" is appended if missing)
        artifact_path: Optional destination path within the artifact store
    """
    if not self.run:
        logger.warning("No active run, skipping log_dict")
        return
    # Ensure .json extension
    if not filename.endswith(".json"):
        filename += ".json"
    # Bug fix: the destination must contain the actual filename — the
    # previous code interpolated a literal "(unknown)" placeholder, so
    # any call with artifact_path wrote to "<path>/(unknown)".
    destination = f"{artifact_path}/{filename}" if artifact_path else filename
    mlflow.log_dict(data, destination)
    logger.debug(f"Logged dict as: {destination}")
def log_model_info(
    self,
    model_type: str,
    model_name: str,
    model_path: Optional[str] = None,
    framework: str = "pytorch",
    extra_info: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Log model information as parameters and tags.

    Args:
        model_type: Type of model (e.g., "llm", "embedding", "stt")
        model_name: Name/identifier of the model
        model_path: Path to model weights
        framework: ML framework used
        extra_info: Additional model information, logged as "model.<key>"
    """
    params: Dict[str, Any] = {
        "model.type": model_type,
        "model.name": model_name,
        "model.framework": framework,
    }
    if model_path:
        params["model.path"] = model_path
    if extra_info:
        for key, value in extra_info.items():
            params[f"model.{key}"] = value
    self.log_params(params)
    # Also set as tags for easier filtering. Route through self.set_tag so
    # the no-active-run guard applies: the previous direct mlflow.set_tag
    # calls ran even when log_params had already bailed out, and the fluent
    # API can implicitly start a stray run in that situation.
    self.set_tag("model.type", model_type)
    self.set_tag("model.name", model_name)
def log_dataset_info(
    self, name: str, source: str, size: Optional[int] = None, extra_info: Optional[Dict[str, Any]] = None
) -> None:
    """
    Record dataset provenance as run parameters.

    Args:
        name: Dataset name
        source: Dataset source (URL, path, etc.)
        size: Number of samples (omitted when None)
        extra_info: Additional dataset information, logged as "dataset.<key>"
    """
    params: Dict[str, Any] = {
        "dataset.name": name,
        "dataset.source": source,
    }
    if size is not None:
        params["dataset.size"] = size
    for key, value in (extra_info or {}).items():
        params[f"dataset.{key}"] = value
    self.log_params(params)
def set_tag(self, key: str, value: str) -> None:
"""Set a single tag on the run."""
if self.run:
mlflow.set_tag(key, value)
def set_tags(self, tags: Dict[str, str]) -> None:
"""Set multiple tags on the run."""
if self.run:
mlflow.set_tags(tags)
@property
def artifact_uri(self) -> Optional[str]:
"""Get the artifact URI for the current run."""
if self.run:
return self.run.info.artifact_uri
return None
@property
def experiment_id(self) -> Optional[str]:
"""Get the experiment ID for the current run."""
if self.run:
return self.run.info.experiment_id
return None