Files
mlflow/mlflow_utils/experiment_comparison.py
Billy D. 2df3f27af7 feat: Add MLflow integration utilities
- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
2026-02-01 20:43:13 -05:00

665 lines
20 KiB
Python

"""
Experiment Comparison and Analysis Utilities
Provides tools for comparing model versions, querying experiments,
and making data-driven decisions about model promotion to production.
Features:
- Compare multiple runs/experiments side by side
- Query experiments by tags, metrics, or parameters
- Analyze inference metrics from NATS handlers
- Generate promotion recommendations
- Export comparison reports
Usage:
from mlflow_utils.experiment_comparison import (
ExperimentAnalyzer,
compare_runs,
get_best_run,
promotion_recommendation,
)
analyzer = ExperimentAnalyzer("chat-inference")
# Compare last N runs
comparison = analyzer.compare_recent_runs(n=5)
# Find best performing model
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
# Get promotion recommendation
rec = analyzer.promotion_recommendation(
model_name="whisper-finetuned",
min_accuracy=0.9,
max_latency_p95=2.0
)
"""
import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List, Tuple, Union

import mlflow
from mlflow.entities import Experiment, Run, ViewType
from mlflow.tracking import MlflowClient

from .client import get_mlflow_client, MLflowConfig
logger = logging.getLogger(__name__)
@dataclass
class RunComparison:
"""Comparison result for multiple MLflow runs."""
run_ids: List[str]
experiment_name: str
# Metric comparisons (metric_name -> {run_id -> value})
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
# Parameter differences
params: Dict[str, Dict[str, str]] = field(default_factory=dict)
# Run metadata
run_names: Dict[str, str] = field(default_factory=dict)
start_times: Dict[str, datetime] = field(default_factory=dict)
durations: Dict[str, float] = field(default_factory=dict)
# Best performers by metric
best_by_metric: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"run_ids": self.run_ids,
"experiment_name": self.experiment_name,
"metrics": self.metrics,
"params": self.params,
"run_names": self.run_names,
"best_by_metric": self.best_by_metric,
}
def summary_table(self) -> str:
"""Generate a text summary table of the comparison."""
if not self.run_ids:
return "No runs to compare"
lines = []
lines.append(f"Experiment: {self.experiment_name}")
lines.append(f"Comparing {len(self.run_ids)} runs")
lines.append("")
# Header
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
lines.append(" | ".join(header))
lines.append("-" * (len(lines[-1]) + 10))
# Metrics
for metric_name, values in sorted(self.metrics.items()):
row = [metric_name]
for run_id in self.run_ids:
value = values.get(run_id)
if value is not None:
row.append(f"{value:.4f}")
else:
row.append("N/A")
lines.append(" | ".join(row))
return "\n".join(lines)
@dataclass
class PromotionRecommendation:
    """Outcome of a promote/hold decision for a candidate model."""

    model_name: str
    version: Optional[int]
    recommended: bool
    # Human-readable pass/fail explanations, one per evaluated criterion.
    reasons: List[str]
    metrics_summary: Dict[str, float]
    comparison_with_production: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the recommendation to a plain dictionary."""
        field_names = (
            "model_name",
            "version",
            "recommended",
            "reasons",
            "metrics_summary",
            "comparison_with_production",
        )
        return {name: getattr(self, name) for name in field_names}
class ExperimentAnalyzer:
    """
    Analyze MLflow experiments for model comparison and promotion decisions.

    Example:
        analyzer = ExperimentAnalyzer("chat-inference")

        # Get metrics summary for last 24 hours
        summary = analyzer.get_metrics_summary(hours=24)

        # Compare models by accuracy
        best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)

        # Analyze inference latency trends
        trends = analyzer.get_metric_trends("total_latency_mean", days=7)
    """

    # String names accepted by search_runs() mapped onto MLflow's ViewType enum.
    _VIEW_TYPES = {
        "ACTIVE_ONLY": ViewType.ACTIVE_ONLY,
        "DELETED_ONLY": ViewType.DELETED_ONLY,
        "ALL": ViewType.ALL,
    }

    def __init__(
        self,
        experiment_name: str,
        tracking_uri: Optional[str] = None,
    ):
        """
        Initialize the experiment analyzer.

        Args:
            experiment_name: Name of the MLflow experiment to analyze
            tracking_uri: Override default tracking URI
        """
        self.experiment_name = experiment_name
        self.tracking_uri = tracking_uri
        self.client = get_mlflow_client(tracking_uri=tracking_uri)
        # Lazily fetched and cached; see the `experiment` property.
        self._experiment: Optional[Experiment] = None

    @property
    def experiment(self) -> Optional[Experiment]:
        """Get the experiment object, fetching and caching it on first access."""
        if self._experiment is None:
            self._experiment = self.client.get_experiment_by_name(self.experiment_name)
        return self._experiment

    def search_runs(
        self,
        filter_string: str = "",
        order_by: Optional[List[str]] = None,
        max_results: int = 100,
        run_view_type: str = "ACTIVE_ONLY",
    ) -> List[Run]:
        """
        Search for runs matching criteria.

        Args:
            filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
            order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
            max_results: Maximum runs to return
            run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
                (unrecognized values fall back to ACTIVE_ONLY)

        Returns:
            List of matching Run objects, or [] if the experiment doesn't exist
        """
        if not self.experiment:
            logger.warning(f"Experiment '{self.experiment_name}' not found")
            return []
        # Bug fix: run_view_type was previously accepted but never forwarded,
        # so DELETED_ONLY / ALL were silently treated as ACTIVE_ONLY.
        view_type = self._VIEW_TYPES.get(run_view_type, ViewType.ACTIVE_ONLY)
        runs = self.client.search_runs(
            experiment_ids=[self.experiment.experiment_id],
            filter_string=filter_string,
            run_view_type=view_type,
            order_by=order_by or ["start_time DESC"],
            max_results=max_results,
        )
        return runs

    def get_recent_runs(
        self,
        n: int = 10,
        hours: Optional[int] = None,
    ) -> List[Run]:
        """
        Get the most recent runs, newest first.

        Args:
            n: Number of runs to return
            hours: Only include runs from the last N hours

        Returns:
            List of Run objects
        """
        filter_string = ""
        if hours:
            # MLflow stores start_time as epoch milliseconds.
            cutoff = datetime.now() - timedelta(hours=hours)
            cutoff_ms = int(cutoff.timestamp() * 1000)
            filter_string = f"attributes.start_time >= {cutoff_ms}"
        return self.search_runs(
            filter_string=filter_string,
            order_by=["start_time DESC"],
            max_results=n,
        )

    def compare_runs(
        self,
        run_ids: Optional[List[str]] = None,
        n_recent: int = 5,
    ) -> RunComparison:
        """
        Compare multiple runs side by side.

        Args:
            run_ids: Specific run IDs to compare, or None for recent runs
            n_recent: If run_ids is None, compare this many recent runs

        Returns:
            RunComparison object with detailed comparison
        """
        if run_ids:
            runs = [self.client.get_run(rid) for rid in run_ids]
        else:
            runs = self.get_recent_runs(n=n_recent)

        comparison = RunComparison(
            run_ids=[r.info.run_id for r in runs],
            experiment_name=self.experiment_name,
        )

        # Collect all metrics/params and per-run metadata.
        all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
        for run in runs:
            run_id = run.info.run_id
            comparison.run_names[run_id] = run.info.run_name or run_id[:8]
            comparison.start_times[run_id] = datetime.fromtimestamp(
                run.info.start_time / 1000
            )
            if run.info.end_time:
                # Duration in seconds (timestamps are epoch milliseconds).
                comparison.durations[run_id] = (
                    run.info.end_time - run.info.start_time
                ) / 1000
            for key, value in run.data.metrics.items():
                all_metrics[key][run_id] = value
            for key, value in run.data.params.items():
                comparison.params.setdefault(key, {})[run_id] = value
        comparison.metrics = dict(all_metrics)

        # Find the best performer per metric. Heuristic: names containing
        # latency/error/loss/time are minimized, everything else maximized.
        for metric_name, values in all_metrics.items():
            if not values:
                continue
            minimize = any(
                term in metric_name.lower()
                for term in ["latency", "error", "loss", "time"]
            )
            chooser = min if minimize else max
            comparison.best_by_metric[metric_name] = chooser(
                values.keys(), key=lambda k: values[k]
            )
        return comparison

    def get_best_run(
        self,
        metric: str,
        minimize: bool = True,
        filter_string: str = "",
        max_results: int = 100,
    ) -> Optional[Run]:
        """
        Get the best run by a specific metric.

        Args:
            metric: Metric name to optimize
            minimize: If True, find minimum; if False, find maximum
            filter_string: Additional filter criteria
            max_results: Maximum runs to consider

        Returns:
            Best Run object, or None if no runs found
        """
        direction = "ASC" if minimize else "DESC"
        # Bug fix: metric names with dots/special chars (e.g. "eval.accuracy")
        # must be backtick-quoted in MLflow search/order-by syntax.
        metric_ref = metric if metric.isidentifier() else f"`{metric}`"
        runs = self.search_runs(
            filter_string=filter_string,
            order_by=[f"metrics.{metric_ref} {direction}"],
            max_results=max_results,
        )
        # Runs missing the metric sort unpredictably; drop them explicitly.
        runs_with_metric = [r for r in runs if metric in r.data.metrics]
        return runs_with_metric[0] if runs_with_metric else None

    def get_metrics_summary(
        self,
        hours: Optional[int] = None,
        metrics: Optional[List[str]] = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Get summary statistics for metrics.

        Args:
            hours: Only include runs from the last N hours
            metrics: Specific metrics to summarize (None for all)

        Returns:
            Dict mapping metric names to {mean, min, max, count}, plus
            {stdev, median} when at least two samples exist
        """
        import statistics

        runs = self.get_recent_runs(n=1000, hours=hours)

        # Gather per-metric samples across all runs.
        metric_values: Dict[str, List[float]] = defaultdict(list)
        for run in runs:
            for key, value in run.data.metrics.items():
                if metrics is None or key in metrics:
                    metric_values[key].append(value)

        summary = {}
        for metric_name, values in metric_values.items():
            if not values:
                continue
            summary[metric_name] = {
                "mean": statistics.mean(values),
                "min": min(values),
                "max": max(values),
                "count": len(values),
            }
            if len(values) >= 2:
                # stdev requires >= 2 samples.
                summary[metric_name]["stdev"] = statistics.stdev(values)
                summary[metric_name]["median"] = statistics.median(values)
        return summary

    def get_metric_trends(
        self,
        metric: str,
        days: int = 7,
        granularity_hours: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Get metric trends over time, bucketed by wall-clock window.

        Args:
            metric: Metric name to track
            days: Number of days to look back
            granularity_hours: Time bucket size in hours

        Returns:
            List of {timestamp, mean, min, max, count} dicts in time order
            (stdev included for buckets with >= 2 samples)
        """
        import statistics

        runs = self.get_recent_runs(n=10000, hours=days * 24)

        # Group metric values into fixed-size time buckets (epoch-ms aligned).
        buckets: Dict[int, List[float]] = defaultdict(list)
        bucket_size_ms = granularity_hours * 3600 * 1000
        for run in runs:
            if metric not in run.data.metrics:
                continue
            bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
            buckets[bucket].append(run.data.metrics[metric])

        trends = []
        for bucket_ts, values in sorted(buckets.items()):
            trend = {
                "timestamp": datetime.fromtimestamp(bucket_ts / 1000).isoformat(),
                "count": len(values),
                "mean": statistics.mean(values),
                "min": min(values),
                "max": max(values),
            }
            if len(values) >= 2:
                trend["stdev"] = statistics.stdev(values)
            trends.append(trend)
        return trends

    def get_runs_by_tag(
        self,
        tag_key: str,
        tag_value: str,
        max_results: int = 100,
    ) -> List[Run]:
        """
        Get runs with a specific tag.

        Args:
            tag_key: Tag key to filter by
            tag_value: Tag value to match
            max_results: Maximum runs to return

        Returns:
            List of matching Run objects
        """
        return self.search_runs(
            filter_string=f"tags.{tag_key} = '{tag_value}'",
            max_results=max_results,
        )

    def get_model_runs(
        self,
        model_name: str,
        max_results: int = 100,
    ) -> List[Run]:
        """
        Get runs for a specific model.

        Args:
            model_name: Model name to filter by
            max_results: Maximum runs to return

        Returns:
            List of matching Run objects
        """
        # Try the tag convention first (backticks because the key has a dot),
        # then fall back to a params-based convention.
        runs = self.search_runs(
            filter_string=f"tags.`model.name` = '{model_name}'",
            max_results=max_results,
        )
        if not runs:
            runs = self.search_runs(
                filter_string=f"params.model_name = '{model_name}'",
                max_results=max_results,
            )
        return runs
def compare_experiments(
    experiment_names: List[str],
    metric: str,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Compare one metric's summary statistics across multiple experiments.

    Args:
        experiment_names: Names of experiments to compare
        metric: Metric to compare
        tracking_uri: Override default tracking URI

    Returns:
        Dict mapping experiment names to metric statistics; experiments
        with no data for the metric are omitted
    """
    stats_by_experiment: Dict[str, Dict[str, float]] = {}
    for name in experiment_names:
        analyzer = ExperimentAnalyzer(name, tracking_uri=tracking_uri)
        metric_stats = analyzer.get_metrics_summary(metrics=[metric]).get(metric)
        if metric_stats is not None:
            stats_by_experiment[name] = metric_stats
    return stats_by_experiment
def promotion_recommendation(
    model_name: str,
    experiment_name: str,
    criteria: Dict[str, Tuple[str, float]],
    tracking_uri: Optional[str] = None,
) -> PromotionRecommendation:
    """
    Generate a recommendation for model promotion.

    Args:
        model_name: Name of the model to evaluate
        experiment_name: Experiment containing evaluation runs
        criteria: Dict of {metric: (comparison, threshold)}
            comparison is one of: ">=", "<=", ">", "<"
            e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
        tracking_uri: Override default tracking URI

    Returns:
        PromotionRecommendation with decision and reasons; recommended is
        True only if every criterion was evaluated and passed
    """
    analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)

    runs = analyzer.get_model_runs(model_name, max_results=10)
    if not runs:
        return PromotionRecommendation(
            model_name=model_name,
            version=None,
            recommended=False,
            reasons=["No runs found for this model"],
            metrics_summary={},
        )

    # get_model_runs orders by start_time DESC, so runs[0] is the newest.
    latest_run = runs[0]
    metrics = latest_run.data.metrics

    reasons: List[str] = []
    passed = True
    comparisons = {
        ">=": lambda a, b: a >= b,
        "<=": lambda a, b: a <= b,
        ">": lambda a, b: a > b,
        "<": lambda a, b: a < b,
    }
    for metric_name, (comparison, threshold) in criteria.items():
        if metric_name not in metrics:
            reasons.append(f"Metric '{metric_name}' not found")
            passed = False
            continue
        value = metrics[metric_name]
        compare_fn = comparisons.get(comparison)
        if compare_fn is None:
            # Bug fix: an unrecognized operator previously left `passed`
            # untouched, so an unevaluable criterion could still yield a
            # positive recommendation. Fail closed instead.
            reasons.append(f"Invalid comparison operator: {comparison}")
            passed = False
            continue
        if compare_fn(value, threshold):
            reasons.append(f"{metric_name}: {value:.4f} {comparison} {threshold}")
        else:
            reasons.append(f"{metric_name}: {value:.4f} NOT {comparison} {threshold}")
            passed = False

    # Extract the model version from tags when present and numeric.
    version = None
    if "mlflow.version" in latest_run.data.tags:
        try:
            version = int(latest_run.data.tags["mlflow.version"])
        except ValueError:
            pass

    return PromotionRecommendation(
        model_name=model_name,
        version=version,
        recommended=passed,
        reasons=reasons,
        metrics_summary=dict(metrics),
    )
def get_inference_performance_report(
    service_name: str = "chat-handler",
    hours: int = 24,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate an inference performance report for a service.

    Args:
        service_name: Service name (chat-handler, voice-assistant)
        hours: Hours of data to analyze
        tracking_uri: Override default tracking URI

    Returns:
        Performance report dictionary with latency, throughput, RAG usage,
        and error-rate sections
    """
    # NOTE(review): all hyphens are stripped from the service name before the
    # "-inference" suffix ("chat-handler" -> "chathandler-inference"), while
    # the module docstring refers to "chat-inference" — confirm this matches
    # the experiment naming used on the tracking side.
    experiment_name = f"{service_name.replace('-', '')}-inference"
    analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
    summary = analyzer.get_metrics_summary(hours=hours)

    report: Dict[str, Any] = {
        "service": service_name,
        "period_hours": hours,
        "generated_at": datetime.now().isoformat(),
        "latency": {},
        "throughput": {},
        "rag": {},
        "errors": {},
    }

    # Latency section: copy only the key latency metrics that were recorded.
    latency_keys = (
        "total_latency_mean",
        "total_latency_p50",
        "total_latency_p95",
        "llm_latency_mean",
        "embedding_latency_mean",
        "rag_search_latency_mean",
    )
    report["latency"] = {k: summary[k] for k in latency_keys if k in summary}

    # Throughput section.
    if "total_requests" in summary:
        report["throughput"]["total_requests"] = summary["total_requests"]["mean"]

    # RAG usage section.
    rag_keys = (
        "rag_enabled_pct",
        "rag_documents_retrieved_mean",
        "rag_documents_used_mean",
    )
    report["rag"] = {k: summary[k] for k in rag_keys if k in summary}

    # Error-rate section.
    if "error_rate" in summary:
        report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]

    return report