feat: Add MLflow integration utilities
- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
This commit is contained in:
395
mlflow_utils/tracker.py
Normal file
395
mlflow_utils/tracker.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
MLflow Tracker for Kubeflow Pipelines
|
||||
|
||||
Provides a high-level interface for logging experiments, parameters,
|
||||
metrics, and artifacts from Kubeflow Pipeline components.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PipelineMetadata:
    """Describes the Kubeflow Pipeline run that produced an MLflow run."""

    pipeline_name: str
    run_id: str
    run_name: Optional[str] = None
    component_name: Optional[str] = None
    namespace: str = "ai-ml"

    # KFP-specific metadata (populated from environment if available)
    kfp_run_id: Optional[str] = field(
        default_factory=lambda: os.environ.get("KFP_RUN_ID")
    )
    kfp_pod_name: Optional[str] = field(
        default_factory=lambda: os.environ.get("KFP_POD_NAME")
    )

    def as_tags(self) -> Dict[str, str]:
        """Render this metadata as a flat MLflow tag dictionary.

        Required fields are always present; optional fields are included
        only when they carry a truthy value.
        """
        tags = {
            "pipeline.name": self.pipeline_name,
            "pipeline.run_id": self.run_id,
            "pipeline.namespace": self.namespace,
        }
        optional = {
            "pipeline.run_name": self.run_name,
            "pipeline.component": self.component_name,
            "kfp.run_id": self.kfp_run_id,
            "kfp.pod_name": self.kfp_pod_name,
        }
        tags.update({key: value for key, value in optional.items() if value})
        return tags
|
||||
|
||||
|
||||
class MLflowTracker:
    """
    MLflow experiment tracker for Kubeflow Pipeline components.

    Example usage in a KFP component:

        from mlflow_utils import MLflowTracker

        tracker = MLflowTracker(
            experiment_name="document-ingestion",
            run_name="batch-ingestion-2024-01"
        )

        with tracker.start_run() as run:
            tracker.log_params({
                "chunk_size": 500,
                "overlap": 50,
                "embeddings_model": "bge-small-en-v1.5"
            })

            # ... do work ...

            tracker.log_metrics({
                "documents_processed": 100,
                "chunks_created": 2500,
                "processing_time_seconds": 120.5
            })

            tracker.log_artifact("/path/to/output.json")
    """

    def __init__(
        self,
        experiment_name: str,
        run_name: Optional[str] = None,
        pipeline_metadata: Optional[PipelineMetadata] = None,
        tags: Optional[Dict[str, str]] = None,
        tracking_uri: Optional[str] = None,
    ):
        """
        Initialize the MLflow tracker.

        Args:
            experiment_name: Name of the MLflow experiment
            run_name: Optional name for this run
            pipeline_metadata: Metadata about the KFP pipeline
            tags: Additional tags to add to the run
            tracking_uri: Override default tracking URI
        """
        self.config = MLflowConfig()
        self.experiment_name = experiment_name
        # Fall back to a timestamped name so runs stay distinguishable.
        self.run_name = run_name or f"{experiment_name}-{int(time.time())}"
        self.pipeline_metadata = pipeline_metadata
        self.user_tags = tags or {}
        self.tracking_uri = tracking_uri

        # Populated by start_run(); None while no run is active.
        self.client: Optional[MlflowClient] = None
        self.run: Optional[mlflow.ActiveRun] = None
        self.run_id: Optional[str] = None
        self._start_time: Optional[float] = None

    def _get_all_tags(self) -> Dict[str, str]:
        """Combine config defaults, pipeline metadata, and user tags.

        Later sources win on key collisions, so user-supplied tags have
        the highest priority.
        """
        tags = self.config.default_tags.copy()

        if self.pipeline_metadata:
            tags.update(self.pipeline_metadata.as_tags())

        tags.update(self.user_tags)
        return tags

    @contextmanager
    def start_run(
        self,
        nested: bool = False,
        parent_run_id: Optional[str] = None,
    ):
        """
        Start an MLflow run as a context manager.

        Args:
            nested: If True, create a nested run under the current active run
            parent_run_id: Explicit parent run ID for nested runs
                (currently unused; reserved for future use)

        Yields:
            The MLflow run object
        """
        self.client = get_mlflow_client(
            tracking_uri=self.tracking_uri,
            configure_global=True
        )

        # Ensure experiment exists
        experiment_id = ensure_experiment(self.experiment_name)

        self._start_time = time.time()

        try:
            # Start the run
            self.run = mlflow.start_run(
                experiment_id=experiment_id,
                run_name=self.run_name,
                nested=nested,
                tags=self._get_all_tags(),
            )
            self.run_id = self.run.info.run_id

            logger.info(
                f"Started MLflow run '{self.run_name}' "
                f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
            )

            yield self.run

        except Exception as e:
            # Tag the run as failed so the failure is visible in the UI,
            # then re-raise so the caller still sees the original error.
            logger.error(f"MLflow run failed: {e}")
            if self.run:
                mlflow.set_tag("run.status", "failed")
                mlflow.set_tag("run.error", str(e))
            raise
        finally:
            # Log duration; best-effort, never mask the original outcome.
            if self._start_time:
                duration = time.time() - self._start_time
                try:
                    mlflow.log_metric("run_duration_seconds", duration)
                except Exception:
                    pass

            # End the run
            mlflow.end_run()
            logger.info(f"Ended MLflow run '{self.run_name}'")

    def log_params(self, params: Dict[str, Any]) -> None:
        """
        Log parameters to the current run.

        Values are stringified and truncated to 500 characters to stay
        within MLflow's parameter-value limit.

        Args:
            params: Dictionary of parameter names to values
        """
        if not self.run:
            logger.warning("No active run, skipping log_params")
            return

        # MLflow has limits on param values, truncate if needed
        cleaned_params = {}
        for key, value in params.items():
            str_value = str(value)
            if len(str_value) > 500:
                str_value = str_value[:497] + "..."
            cleaned_params[key] = str_value

        mlflow.log_params(cleaned_params)
        logger.debug(f"Logged {len(params)} parameters")

    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter."""
        self.log_params({key: value})

    def log_metrics(
        self,
        metrics: Dict[str, Union[float, int]],
        step: Optional[int] = None
    ) -> None:
        """
        Log metrics to the current run.

        Args:
            metrics: Dictionary of metric names to values
            step: Optional step number for time-series metrics
        """
        if not self.run:
            logger.warning("No active run, skipping log_metrics")
            return

        mlflow.log_metrics(metrics, step=step)
        logger.debug(f"Logged {len(metrics)} metrics")

    def log_metric(
        self,
        key: str,
        value: Union[float, int],
        step: Optional[int] = None
    ) -> None:
        """Log a single metric."""
        self.log_metrics({key: value}, step=step)

    def log_artifact(
        self,
        local_path: str,
        artifact_path: Optional[str] = None
    ) -> None:
        """
        Log an artifact file to the current run.

        Args:
            local_path: Path to the local file to log
            artifact_path: Optional destination path within the artifact store
        """
        if not self.run:
            logger.warning("No active run, skipping log_artifact")
            return

        mlflow.log_artifact(local_path, artifact_path)
        logger.info(f"Logged artifact: {local_path}")

    def log_artifacts(
        self,
        local_dir: str,
        artifact_path: Optional[str] = None
    ) -> None:
        """
        Log all files in a directory as artifacts.

        Args:
            local_dir: Path to the local directory
            artifact_path: Optional destination path within the artifact store
        """
        if not self.run:
            logger.warning("No active run, skipping log_artifacts")
            return

        mlflow.log_artifacts(local_dir, artifact_path)
        logger.info(f"Logged artifacts from: {local_dir}")

    def log_dict(
        self,
        data: Dict[str, Any],
        filename: str,
        artifact_path: Optional[str] = None
    ) -> None:
        """
        Log a dictionary as a JSON artifact.

        Args:
            data: Dictionary to log
            filename: Name for the JSON file
            artifact_path: Optional destination path
        """
        if not self.run:
            logger.warning("No active run, skipping log_dict")
            return

        # Ensure .json extension
        if not filename.endswith(".json"):
            filename += ".json"

        # Fix: build the artifact destination from the actual filename
        # (previously a corrupted placeholder was interpolated here).
        artifact_file = (
            f"{artifact_path}/{filename}" if artifact_path else filename
        )
        mlflow.log_dict(data, artifact_file)
        logger.debug(f"Logged dict as: {artifact_file}")

    def log_model_info(
        self,
        model_type: str,
        model_name: str,
        model_path: Optional[str] = None,
        framework: str = "pytorch",
        extra_info: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Log model information as parameters and tags.

        Args:
            model_type: Type of model (e.g., "llm", "embedding", "stt")
            model_name: Name/identifier of the model
            model_path: Path to model weights
            framework: ML framework used
            extra_info: Additional model information
        """
        params = {
            "model.type": model_type,
            "model.name": model_name,
            "model.framework": framework,
        }
        if model_path:
            params["model.path"] = model_path
        if extra_info:
            for key, value in extra_info.items():
                params[f"model.{key}"] = value

        self.log_params(params)

        # Also set as tags for easier filtering
        mlflow.set_tag("model.type", model_type)
        mlflow.set_tag("model.name", model_name)

    def log_dataset_info(
        self,
        name: str,
        source: str,
        size: Optional[int] = None,
        extra_info: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Log dataset information.

        Args:
            name: Dataset name
            source: Dataset source (URL, path, etc.)
            size: Number of samples
            extra_info: Additional dataset information
        """
        params = {
            "dataset.name": name,
            "dataset.source": source,
        }
        if size is not None:
            params["dataset.size"] = size
        if extra_info:
            for key, value in extra_info.items():
                params[f"dataset.{key}"] = value

        self.log_params(params)

    def set_tag(self, key: str, value: str) -> None:
        """Set a single tag on the run (no-op when no run is active)."""
        if self.run:
            mlflow.set_tag(key, value)

    def set_tags(self, tags: Dict[str, str]) -> None:
        """Set multiple tags on the run (no-op when no run is active)."""
        if self.run:
            mlflow.set_tags(tags)

    @property
    def artifact_uri(self) -> Optional[str]:
        """Get the artifact URI for the current run, or None if no run."""
        if self.run:
            return self.run.info.artifact_uri
        return None

    @property
    def experiment_id(self) -> Optional[str]:
        """Get the experiment ID for the current run, or None if no run."""
        if self.run:
            return self.run.info.experiment_id
        return None
|
||||
Reference in New Issue
Block a user