feat: Add MLflow integration utilities
- client: Connection management and helpers - tracker: General experiment tracking - inference_tracker: Async metrics for NATS handlers - model_registry: Model registration with KServe metadata - kfp_components: Kubeflow Pipeline components - experiment_comparison: Run comparison tools - cli: Command-line interface
This commit is contained in:
664
mlflow_utils/experiment_comparison.py
Normal file
664
mlflow_utils/experiment_comparison.py
Normal file
@@ -0,0 +1,664 @@
|
||||
"""
|
||||
Experiment Comparison and Analysis Utilities
|
||||
|
||||
Provides tools for comparing model versions, querying experiments,
|
||||
and making data-driven decisions about model promotion to production.
|
||||
|
||||
Features:
|
||||
- Compare multiple runs/experiments side by side
|
||||
- Query experiments by tags, metrics, or parameters
|
||||
- Analyze inference metrics from NATS handlers
|
||||
- Generate promotion recommendations
|
||||
- Export comparison reports
|
||||
|
||||
Usage:
|
||||
    from mlflow_utils.experiment_comparison import (
        ExperimentAnalyzer,
        compare_experiments,
        promotion_recommendation,
        get_inference_performance_report,
    )
|
||||
|
||||
analyzer = ExperimentAnalyzer("chat-inference")
|
||||
|
||||
    # Compare last N runs
    comparison = analyzer.compare_runs(n_recent=5)
|
||||
|
||||
# Find best performing model
|
||||
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
|
||||
|
||||
    # Get promotion recommendation
    rec = promotion_recommendation(
        model_name="whisper-finetuned",
        experiment_name="chat-inference",
        criteria={
            "eval.accuracy": (">=", 0.9),
            "total_latency_p95": ("<=", 2.0),
        },
    )
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities import Run, Experiment
|
||||
|
||||
from .client import get_mlflow_client, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RunComparison:
    """Comparison result for multiple MLflow runs."""
    # IDs of the runs being compared, in display order.
    run_ids: List[str]
    experiment_name: str

    # Metric comparisons (metric_name -> {run_id -> value})
    metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)

    # Parameter differences (param_name -> {run_id -> value})
    params: Dict[str, Dict[str, str]] = field(default_factory=dict)

    # Run metadata
    run_names: Dict[str, str] = field(default_factory=dict)
    start_times: Dict[str, datetime] = field(default_factory=dict)
    durations: Dict[str, float] = field(default_factory=dict)

    # Best performer per metric (metric_name -> run_id)
    best_by_metric: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        start_times and durations are intentionally omitted: datetime
        values are not directly JSON-serializable.
        """
        return {
            "run_ids": self.run_ids,
            "experiment_name": self.experiment_name,
            "metrics": self.metrics,
            "params": self.params,
            "run_names": self.run_names,
            "best_by_metric": self.best_by_metric,
        }

    def summary_table(self) -> str:
        """Generate a text summary table of the comparison."""
        if not self.run_ids:
            return "No runs to compare"

        # Column headers: friendly run names, falling back to short IDs.
        display_names = [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
        header_row = " | ".join(["Metric"] + display_names)

        lines = [
            f"Experiment: {self.experiment_name}",
            f"Comparing {len(self.run_ids)} runs",
            "",
            header_row,
            "-" * (len(header_row) + 10),
        ]

        # One row per metric, sorted by name; "N/A" where a run lacks it.
        for metric_name in sorted(self.metrics):
            per_run = self.metrics[metric_name]
            cells = [metric_name]
            for rid in self.run_ids:
                value = per_run.get(rid)
                cells.append("N/A" if value is None else f"{value:.4f}")
            lines.append(" | ".join(cells))

        return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass
class PromotionRecommendation:
    """Recommendation for model promotion."""
    model_name: str
    version: Optional[int]
    # Overall go/no-go decision.
    recommended: bool
    # Human-readable pass/fail explanations for each criterion.
    reasons: List[str]
    metrics_summary: Dict[str, float]
    comparison_with_production: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the recommendation as a plain dictionary."""
        field_names = (
            "model_name",
            "version",
            "recommended",
            "reasons",
            "metrics_summary",
            "comparison_with_production",
        )
        return {name: getattr(self, name) for name in field_names}
|
||||
|
||||
|
||||
class ExperimentAnalyzer:
    """
    Analyze MLflow experiments for model comparison and promotion decisions.

    Example:
        analyzer = ExperimentAnalyzer("chat-inference")

        # Get metrics summary for last 24 hours
        summary = analyzer.get_metrics_summary(hours=24)

        # Compare models by accuracy
        best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)

        # Analyze inference latency trends
        trends = analyzer.get_metric_trends("total_latency_mean", days=7)
    """

    def __init__(
        self,
        experiment_name: str,
        tracking_uri: Optional[str] = None,
    ):
        """
        Initialize the experiment analyzer.

        Args:
            experiment_name: Name of the MLflow experiment to analyze
            tracking_uri: Override default tracking URI
        """
        self.experiment_name = experiment_name
        self.tracking_uri = tracking_uri
        self.client = get_mlflow_client(tracking_uri=tracking_uri)
        # Lazily-populated cache for the Experiment entity (see property below).
        self._experiment: Optional[Experiment] = None

    @property
    def experiment(self) -> Optional[Experiment]:
        """Get the experiment object, fetching (and caching) it if needed.

        Stays None — and is re-queried on the next access — when the
        experiment does not exist on the tracking server.
        """
        if self._experiment is None:
            self._experiment = self.client.get_experiment_by_name(self.experiment_name)
        return self._experiment

    def search_runs(
        self,
        filter_string: str = "",
        order_by: Optional[List[str]] = None,
        max_results: int = 100,
        run_view_type: str = "ACTIVE_ONLY",
    ) -> List[Run]:
        """
        Search for runs matching criteria.

        Args:
            filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
            order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
            max_results: Maximum runs to return
            run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL

        Returns:
            List of matching Run objects (empty if the experiment is missing)
        """
        if not self.experiment:
            logger.warning(f"Experiment '{self.experiment_name}' not found")
            return []

        # Local import mirrors this module's style (cf. `import statistics`
        # inside the summary methods).
        from mlflow.entities import ViewType

        # BUG FIX: run_view_type was previously accepted and documented but
        # never forwarded, so DELETED_ONLY / ALL silently behaved like
        # ACTIVE_ONLY. Unknown names fall back to ACTIVE_ONLY.
        view_type = getattr(ViewType, run_view_type.upper(), ViewType.ACTIVE_ONLY)

        runs = self.client.search_runs(
            experiment_ids=[self.experiment.experiment_id],
            filter_string=filter_string,
            run_view_type=view_type,
            order_by=order_by or ["start_time DESC"],
            max_results=max_results,
        )

        return runs

    def get_recent_runs(
        self,
        n: int = 10,
        hours: Optional[int] = None,
    ) -> List[Run]:
        """
        Get the most recent runs.

        Args:
            n: Number of runs to return
            hours: Only include runs from the last N hours

        Returns:
            List of Run objects, newest first
        """
        filter_string = ""
        if hours:
            # MLflow stores start_time as epoch milliseconds.
            cutoff = datetime.now() - timedelta(hours=hours)
            cutoff_ms = int(cutoff.timestamp() * 1000)
            filter_string = f"attributes.start_time >= {cutoff_ms}"

        return self.search_runs(
            filter_string=filter_string,
            order_by=["start_time DESC"],
            max_results=n,
        )

    def compare_runs(
        self,
        run_ids: Optional[List[str]] = None,
        n_recent: int = 5,
    ) -> RunComparison:
        """
        Compare multiple runs side by side.

        Args:
            run_ids: Specific run IDs to compare, or None for recent runs
            n_recent: If run_ids is None, compare this many recent runs

        Returns:
            RunComparison object with detailed comparison
        """
        if run_ids:
            runs = [self.client.get_run(rid) for rid in run_ids]
        else:
            runs = self.get_recent_runs(n=n_recent)

        comparison = RunComparison(
            run_ids=[r.info.run_id for r in runs],
            experiment_name=self.experiment_name,
        )

        # Collect all metrics across runs: metric_name -> {run_id -> value}.
        all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)

        for run in runs:
            run_id = run.info.run_id

            # Metadata (timestamps are epoch ms; durations in seconds).
            comparison.run_names[run_id] = run.info.run_name or run_id[:8]
            comparison.start_times[run_id] = datetime.fromtimestamp(
                run.info.start_time / 1000
            )
            if run.info.end_time:
                comparison.durations[run_id] = (
                    run.info.end_time - run.info.start_time
                ) / 1000

            # Metrics
            for key, value in run.data.metrics.items():
                all_metrics[key][run_id] = value

            # Params
            for key, value in run.data.params.items():
                if key not in comparison.params:
                    comparison.params[key] = {}
                comparison.params[key][run_id] = value

        comparison.metrics = dict(all_metrics)

        # Find the best performer for each metric.
        for metric_name, values in all_metrics.items():
            if not values:
                continue

            # Heuristic: metrics whose name suggests a cost (latency, error,
            # loss, time) are minimized; everything else is maximized.
            minimize = any(
                term in metric_name.lower()
                for term in ["latency", "error", "loss", "time"]
            )

            if minimize:
                best_id = min(values.keys(), key=lambda k: values[k])
            else:
                best_id = max(values.keys(), key=lambda k: values[k])

            comparison.best_by_metric[metric_name] = best_id

        return comparison

    def get_best_run(
        self,
        metric: str,
        minimize: bool = True,
        filter_string: str = "",
        max_results: int = 100,
    ) -> Optional[Run]:
        """
        Get the best run by a specific metric.

        Args:
            metric: Metric name to optimize
            minimize: If True, find minimum; if False, find maximum
            filter_string: Additional filter criteria
            max_results: Maximum runs to consider

        Returns:
            Best Run object, or None if no runs found
        """
        direction = "ASC" if minimize else "DESC"

        runs = self.search_runs(
            filter_string=filter_string,
            order_by=[f"metrics.{metric} {direction}"],
            max_results=max_results,
        )

        # Keep only runs that actually logged the metric; the server
        # ordering may include runs without it.
        runs_with_metric = [
            r for r in runs
            if metric in r.data.metrics
        ]

        return runs_with_metric[0] if runs_with_metric else None

    def get_metrics_summary(
        self,
        hours: Optional[int] = None,
        metrics: Optional[List[str]] = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Get summary statistics for metrics.

        Args:
            hours: Only include runs from the last N hours
            metrics: Specific metrics to summarize (None for all)

        Returns:
            Dict mapping metric names to {mean, min, max, count}
            (plus stdev/median when at least two samples exist)
        """
        import statistics

        # NOTE: capped at 1000 runs — summaries over larger windows are
        # computed from the most recent 1000 only.
        runs = self.get_recent_runs(n=1000, hours=hours)

        # Collect all metric values.
        metric_values: Dict[str, List[float]] = defaultdict(list)

        for run in runs:
            for key, value in run.data.metrics.items():
                if metrics is None or key in metrics:
                    metric_values[key].append(value)

        # Calculate statistics.
        summary = {}
        for metric_name, values in metric_values.items():
            if not values:
                continue

            summary[metric_name] = {
                "mean": statistics.mean(values),
                "min": min(values),
                "max": max(values),
                "count": len(values),
            }

            # stdev requires at least two samples.
            if len(values) >= 2:
                summary[metric_name]["stdev"] = statistics.stdev(values)
                summary[metric_name]["median"] = statistics.median(values)

        return summary

    def get_metric_trends(
        self,
        metric: str,
        days: int = 7,
        granularity_hours: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Get metric trends over time.

        Args:
            metric: Metric name to track
            days: Number of days to look back
            granularity_hours: Time bucket size in hours

        Returns:
            List of {timestamp, mean, min, max, count} dicts in
            chronological order (stdev included for buckets with >= 2 runs)
        """
        import statistics

        runs = self.get_recent_runs(n=10000, hours=days * 24)

        # Group run metric values into fixed-size time buckets keyed by
        # the bucket's start timestamp (epoch ms).
        buckets: Dict[int, List[float]] = defaultdict(list)
        bucket_size_ms = granularity_hours * 3600 * 1000

        for run in runs:
            if metric not in run.data.metrics:
                continue

            # Floor the start time to the bucket boundary.
            bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
            buckets[bucket].append(run.data.metrics[metric])

        # Calculate statistics per bucket.
        trends = []
        for bucket_ts, values in sorted(buckets.items()):
            trend = {
                "timestamp": datetime.fromtimestamp(bucket_ts / 1000).isoformat(),
                "count": len(values),
                "mean": statistics.mean(values),
                "min": min(values),
                "max": max(values),
            }
            if len(values) >= 2:
                trend["stdev"] = statistics.stdev(values)
            trends.append(trend)

        return trends

    def get_runs_by_tag(
        self,
        tag_key: str,
        tag_value: str,
        max_results: int = 100,
    ) -> List[Run]:
        """
        Get runs with a specific tag.

        Args:
            tag_key: Tag key to filter by
            tag_value: Tag value to match
            max_results: Maximum runs to return

        Returns:
            List of matching Run objects
        """
        # NOTE(review): tag_value is interpolated into the filter string;
        # values containing a single quote will break the query. Assumed to
        # come from trusted callers — confirm before exposing to user input.
        return self.search_runs(
            filter_string=f"tags.{tag_key} = '{tag_value}'",
            max_results=max_results,
        )

    def get_model_runs(
        self,
        model_name: str,
        max_results: int = 100,
    ) -> List[Run]:
        """
        Get runs for a specific model.

        Args:
            model_name: Model name to filter by
            max_results: Maximum runs to return

        Returns:
            List of matching Run objects
        """
        # Try different tag conventions: a `model.name` tag first
        # (backticks escape the dot in MLflow filter syntax) ...
        runs = self.search_runs(
            filter_string=f"tags.`model.name` = '{model_name}'",
            max_results=max_results,
        )

        if not runs:
            # ... then fall back to a `model_name` param.
            runs = self.search_runs(
                filter_string=f"params.model_name = '{model_name}'",
                max_results=max_results,
            )

        return runs
|
||||
|
||||
|
||||
def compare_experiments(
    experiment_names: List[str],
    metric: str,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Compare metrics across multiple experiments.

    Args:
        experiment_names: Names of experiments to compare
        metric: Metric to compare
        tracking_uri: Override default tracking URI

    Returns:
        Dict mapping experiment names to metric statistics; experiments
        with no data for the metric are omitted.
    """
    results: Dict[str, Dict[str, float]] = {}

    for name in experiment_names:
        stats = (
            ExperimentAnalyzer(name, tracking_uri=tracking_uri)
            .get_metrics_summary(metrics=[metric])
            .get(metric)
        )
        if stats is not None:
            results[name] = stats

    return results
|
||||
|
||||
|
||||
def promotion_recommendation(
    model_name: str,
    experiment_name: str,
    criteria: Dict[str, Tuple[str, float]],
    tracking_uri: Optional[str] = None,
) -> PromotionRecommendation:
    """
    Generate a recommendation for model promotion.

    Args:
        model_name: Name of the model to evaluate
        experiment_name: Experiment containing evaluation runs
        criteria: Dict of {metric: (comparison, threshold)}
            comparison is one of: ">=", "<=", ">", "<"
            e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
        tracking_uri: Override default tracking URI

    Returns:
        PromotionRecommendation with decision and reasons. The decision is
        based on the model's most recent run; promotion is recommended only
        if every criterion passes.
    """
    analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)

    # Get model runs (newest first, per search_runs default ordering).
    runs = analyzer.get_model_runs(model_name, max_results=10)

    if not runs:
        return PromotionRecommendation(
            model_name=model_name,
            version=None,
            recommended=False,
            reasons=["No runs found for this model"],
            metrics_summary={},
        )

    # Evaluate against the most recent run only.
    latest_run = runs[0]
    metrics = latest_run.data.metrics

    reasons: List[str] = []
    passed = True

    comparisons = {
        ">=": lambda a, b: a >= b,
        "<=": lambda a, b: a <= b,
        ">": lambda a, b: a > b,
        "<": lambda a, b: a < b,
    }

    for metric_name, (comparison, threshold) in criteria.items():
        if metric_name not in metrics:
            reasons.append(f"Metric '{metric_name}' not found")
            passed = False
            continue

        value = metrics[metric_name]
        compare_fn = comparisons.get(comparison)

        if compare_fn is None:
            # BUG FIX: an unknown operator previously left `passed`
            # untouched, so a model could be recommended even though one of
            # its criteria was never evaluated. Treat it as a failure.
            reasons.append(f"Invalid comparison operator: {comparison}")
            passed = False
            continue

        if compare_fn(value, threshold):
            reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}")
        else:
            reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}")
            passed = False

    # Extract version from tags if available.
    # NOTE(review): "mlflow.version" is not a standard MLflow system tag —
    # confirm the producer side actually sets it before relying on `version`.
    version = None
    if "mlflow.version" in latest_run.data.tags:
        try:
            version = int(latest_run.data.tags["mlflow.version"])
        except ValueError:
            # Non-numeric tag value: leave version as None.
            pass

    return PromotionRecommendation(
        model_name=model_name,
        version=version,
        recommended=passed,
        reasons=reasons,
        metrics_summary=dict(metrics),
    )
|
||||
|
||||
|
||||
def get_inference_performance_report(
    service_name: str = "chat-handler",
    hours: int = 24,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate an inference performance report for a service.

    Args:
        service_name: Service name (chat-handler, voice-assistant)
        hours: Hours of data to analyze
        tracking_uri: Override default tracking URI

    Returns:
        Performance report dictionary with latency, throughput, rag and
        errors sections (a section is left empty when no matching metrics
        were logged in the window)
    """
    # Naming convention: "chat-handler" -> experiment "chathandler-inference".
    experiment_name = f"{service_name.replace('-', '')}-inference"
    analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)

    # Summary statistics over the requested window.
    summary = analyzer.get_metrics_summary(hours=hours)

    # Key latency metrics reported by the inference handlers.
    latency_metrics = (
        "total_latency_mean",
        "total_latency_p50",
        "total_latency_p95",
        "llm_latency_mean",
        "embedding_latency_mean",
        "rag_search_latency_mean",
    )
    rag_metrics = (
        "rag_enabled_pct",
        "rag_documents_retrieved_mean",
        "rag_documents_used_mean",
    )

    report: Dict[str, Any] = {
        "service": service_name,
        "period_hours": hours,
        "generated_at": datetime.now().isoformat(),
        "latency": {name: summary[name] for name in latency_metrics if name in summary},
        "throughput": {},
        "rag": {name: summary[name] for name in rag_metrics if name in summary},
        "errors": {},
    }

    # Throughput: mean request count per run in the window.
    if "total_requests" in summary:
        report["throughput"]["total_requests"] = summary["total_requests"]["mean"]

    # Error rate (already a percentage at the source).
    if "error_rate" in summary:
        report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]

    return report
|
||||
Reference in New Issue
Block a user