"""
|
|
MLflow Tracker for Kubeflow Pipelines
|
|
|
|
Provides a high-level interface for logging experiments, parameters,
|
|
metrics, and artifacts from Kubeflow Pipeline components.
|
|
"""

import logging
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union

import mlflow
from mlflow.tracking import MlflowClient

from .client import MLflowConfig, ensure_experiment, get_mlflow_client

logger = logging.getLogger(__name__)


@dataclass
class PipelineMetadata:
    """Metadata about the Kubeflow Pipeline run."""
    pipeline_name: str
    run_id: str
    run_name: Optional[str] = None
    component_name: Optional[str] = None
    namespace: str = "ai-ml"

    # KFP-specific metadata (populated from environment if available)
    kfp_run_id: Optional[str] = field(
        default_factory=lambda: os.environ.get("KFP_RUN_ID")
    )
    kfp_pod_name: Optional[str] = field(
        default_factory=lambda: os.environ.get("KFP_POD_NAME")
    )

    def as_tags(self) -> Dict[str, str]:
        """Render this metadata as a flat MLflow tag dictionary.

        The pipeline name, run ID, and namespace are always present;
        the remaining fields are included only when they hold a truthy
        value.
        """
        required = {
            "pipeline.name": self.pipeline_name,
            "pipeline.run_id": self.run_id,
            "pipeline.namespace": self.namespace,
        }
        optional = {
            "pipeline.run_name": self.run_name,
            "pipeline.component": self.component_name,
            "kfp.run_id": self.kfp_run_id,
            "kfp.pod_name": self.kfp_pod_name,
        }
        return {**required, **{key: val for key, val in optional.items() if val}}


class MLflowTracker:
    """
    MLflow experiment tracker for Kubeflow Pipeline components.

    Example usage in a KFP component:

        from mlflow_utils import MLflowTracker

        tracker = MLflowTracker(
            experiment_name="document-ingestion",
            run_name="batch-ingestion-2024-01"
        )

        with tracker.start_run() as run:
            tracker.log_params({
                "chunk_size": 500,
                "overlap": 50,
                "embeddings_model": "bge-small-en-v1.5"
            })

            # ... do work ...

            tracker.log_metrics({
                "documents_processed": 100,
                "chunks_created": 2500,
                "processing_time_seconds": 120.5
            })

            tracker.log_artifact("/path/to/output.json")
    """

    def __init__(
        self,
        experiment_name: str,
        run_name: Optional[str] = None,
        pipeline_metadata: Optional[PipelineMetadata] = None,
        tags: Optional[Dict[str, str]] = None,
        tracking_uri: Optional[str] = None,
    ):
        """
        Initialize the MLflow tracker.

        Args:
            experiment_name: Name of the MLflow experiment
            run_name: Optional name for this run
            pipeline_metadata: Metadata about the KFP pipeline
            tags: Additional tags to add to the run
            tracking_uri: Override default tracking URI
        """
        self.config = MLflowConfig()
        self.experiment_name = experiment_name
        # Fall back to a timestamped name so concurrent runs stay distinguishable.
        self.run_name = run_name or f"{experiment_name}-{int(time.time())}"
        self.pipeline_metadata = pipeline_metadata
        self.user_tags = tags or {}
        self.tracking_uri = tracking_uri

        # Populated by start_run(); None until a run is active.
        self.client: Optional[MlflowClient] = None
        self.run: Optional[mlflow.ActiveRun] = None
        self.run_id: Optional[str] = None
        self._start_time: Optional[float] = None

    def _get_all_tags(self) -> Dict[str, str]:
        """Combine config, pipeline, and user tags (later sources win on conflict)."""
        tags = self.config.default_tags.copy()
        if self.pipeline_metadata:
            tags.update(self.pipeline_metadata.as_tags())
        tags.update(self.user_tags)
        return tags

    @contextmanager
    def start_run(
        self,
        nested: bool = False,
        parent_run_id: Optional[str] = None,
    ):
        """
        Start an MLflow run as a context manager.

        On exception, tags the run as failed and re-raises; in all cases
        logs the wall-clock duration (best effort) and ends the run.

        Args:
            nested: If True, create a nested run under the current active run
            parent_run_id: Explicit parent run ID for nested runs

        Yields:
            The MLflow run object
        """
        self.client = get_mlflow_client(
            tracking_uri=self.tracking_uri,
            configure_global=True
        )

        # Ensure experiment exists
        experiment_id = ensure_experiment(self.experiment_name)

        self._start_time = time.time()

        try:
            # Start the run.
            # FIX: parent_run_id was previously accepted but never forwarded,
            # so explicit-parent nesting silently did nothing.
            self.run = mlflow.start_run(
                experiment_id=experiment_id,
                run_name=self.run_name,
                nested=nested,
                parent_run_id=parent_run_id,
                tags=self._get_all_tags(),
            )
            self.run_id = self.run.info.run_id

            logger.info(
                f"Started MLflow run '{self.run_name}' "
                f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
            )

            yield self.run

        except Exception as e:
            logger.error(f"MLflow run failed: {e}")
            if self.run:
                # Mark the run as failed so it is easy to filter in the UI.
                mlflow.set_tag("run.status", "failed")
                mlflow.set_tag("run.error", str(e))
            raise
        finally:
            # Log duration (best effort: the tracking backend may be unreachable
            # during teardown, and that must not mask the original exception).
            if self._start_time:
                duration = time.time() - self._start_time
                try:
                    mlflow.log_metric("run_duration_seconds", duration)
                except Exception:
                    pass

            # End the run
            mlflow.end_run()
            logger.info(f"Ended MLflow run '{self.run_name}'")

    def log_params(self, params: Dict[str, Any]) -> None:
        """
        Log parameters to the current run.

        Values are stringified and truncated to 500 characters to stay
        within MLflow's parameter-value limit.

        Args:
            params: Dictionary of parameter names to values
        """
        if not self.run:
            logger.warning("No active run, skipping log_params")
            return

        # MLflow has limits on param values, truncate if needed
        cleaned_params = {}
        for key, value in params.items():
            str_value = str(value)
            if len(str_value) > 500:
                str_value = str_value[:497] + "..."
            cleaned_params[key] = str_value

        mlflow.log_params(cleaned_params)
        logger.debug(f"Logged {len(params)} parameters")

    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter."""
        self.log_params({key: value})

    def log_metrics(
        self,
        metrics: Dict[str, Union[float, int]],
        step: Optional[int] = None
    ) -> None:
        """
        Log metrics to the current run.

        Args:
            metrics: Dictionary of metric names to values
            step: Optional step number for time-series metrics
        """
        if not self.run:
            logger.warning("No active run, skipping log_metrics")
            return

        mlflow.log_metrics(metrics, step=step)
        logger.debug(f"Logged {len(metrics)} metrics")

    def log_metric(
        self,
        key: str,
        value: Union[float, int],
        step: Optional[int] = None
    ) -> None:
        """Log a single metric."""
        self.log_metrics({key: value}, step=step)

    def log_artifact(
        self,
        local_path: str,
        artifact_path: Optional[str] = None
    ) -> None:
        """
        Log an artifact file to the current run.

        Args:
            local_path: Path to the local file to log
            artifact_path: Optional destination path within the artifact store
        """
        if not self.run:
            logger.warning("No active run, skipping log_artifact")
            return

        mlflow.log_artifact(local_path, artifact_path)
        logger.info(f"Logged artifact: {local_path}")

    def log_artifacts(
        self,
        local_dir: str,
        artifact_path: Optional[str] = None
    ) -> None:
        """
        Log all files in a directory as artifacts.

        Args:
            local_dir: Path to the local directory
            artifact_path: Optional destination path within the artifact store
        """
        if not self.run:
            logger.warning("No active run, skipping log_artifacts")
            return

        mlflow.log_artifacts(local_dir, artifact_path)
        logger.info(f"Logged artifacts from: {local_dir}")

    def log_dict(
        self,
        data: Dict[str, Any],
        filename: str,
        artifact_path: Optional[str] = None
    ) -> None:
        """
        Log a dictionary as a JSON artifact.

        Args:
            data: Dictionary to log
            filename: Name for the JSON file (".json" appended if missing)
            artifact_path: Optional destination path
        """
        if not self.run:
            logger.warning("No active run, skipping log_dict")
            return

        # Ensure .json extension
        if not filename.endswith(".json"):
            filename += ".json"

        # FIX: the destination previously interpolated a corrupted placeholder
        # instead of the filename; join artifact_path and filename properly.
        destination = f"{artifact_path}/{filename}" if artifact_path else filename
        mlflow.log_dict(data, destination)
        logger.debug(f"Logged dict as: {destination}")

    def log_model_info(
        self,
        model_type: str,
        model_name: str,
        model_path: Optional[str] = None,
        framework: str = "pytorch",
        extra_info: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Log model information as parameters and tags.

        Args:
            model_type: Type of model (e.g., "llm", "embedding", "stt")
            model_name: Name/identifier of the model
            model_path: Path to model weights
            framework: ML framework used
            extra_info: Additional model information
        """
        params = {
            "model.type": model_type,
            "model.name": model_name,
            "model.framework": framework,
        }
        if model_path:
            params["model.path"] = model_path
        if extra_info:
            for key, value in extra_info.items():
                params[f"model.{key}"] = value

        self.log_params(params)

        # Also set as tags for easier filtering.
        # FIX: route through self.set_tag so the no-active-run guard applies;
        # a bare mlflow.set_tag would implicitly start a stray run.
        self.set_tag("model.type", model_type)
        self.set_tag("model.name", model_name)

    def log_dataset_info(
        self,
        name: str,
        source: str,
        size: Optional[int] = None,
        extra_info: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Log dataset information.

        Args:
            name: Dataset name
            source: Dataset source (URL, path, etc.)
            size: Number of samples
            extra_info: Additional dataset information
        """
        params = {
            "dataset.name": name,
            "dataset.source": source,
        }
        if size is not None:
            params["dataset.size"] = size
        if extra_info:
            for key, value in extra_info.items():
                params[f"dataset.{key}"] = value

        self.log_params(params)

    def set_tag(self, key: str, value: str) -> None:
        """Set a single tag on the run (no-op when no run is active)."""
        if self.run:
            mlflow.set_tag(key, value)

    def set_tags(self, tags: Dict[str, str]) -> None:
        """Set multiple tags on the run (no-op when no run is active)."""
        if self.run:
            mlflow.set_tags(tags)

    @property
    def artifact_uri(self) -> Optional[str]:
        """Get the artifact URI for the current run, or None if no run is active."""
        if self.run:
            return self.run.info.artifact_uri
        return None

    @property
    def experiment_id(self) -> Optional[str]:
        """Get the experiment ID for the current run, or None if no run is active."""
        if self.run:
            return self.run.info.experiment_id
        return None