fix: resolve all ruff lint errors
Some checks failed
CI / Test (push) Successful in 1m46s
CI / Lint (push) Failing after 1m49s
CI / Publish (push) Has been skipped
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions

View File

@@ -18,15 +18,15 @@ Usage:
get_best_run,
promotion_recommendation,
)
analyzer = ExperimentAnalyzer("chat-inference")
# Compare last N runs
comparison = analyzer.compare_recent_runs(n=5)
# Find best performing model
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
# Get promotion recommendation
rec = analyzer.promotion_recommendation(
model_name="whisper-finetuned",
@@ -35,19 +35,15 @@ Usage:
)
"""
import os
import json
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List, Tuple, Union
from dataclasses import dataclass, field
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import Run, Experiment
from mlflow.entities import Experiment, Run
from .client import get_mlflow_client, MLflowConfig
from .client import get_mlflow_client
logger = logging.getLogger(__name__)
@@ -57,21 +53,21 @@ class RunComparison:
"""Comparison result for multiple MLflow runs."""
run_ids: List[str]
experiment_name: str
# Metric comparisons (metric_name -> {run_id -> value})
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
# Parameter differences
params: Dict[str, Dict[str, str]] = field(default_factory=dict)
# Run metadata
run_names: Dict[str, str] = field(default_factory=dict)
start_times: Dict[str, datetime] = field(default_factory=dict)
durations: Dict[str, float] = field(default_factory=dict)
# Best performers by metric
best_by_metric: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
@@ -82,22 +78,22 @@ class RunComparison:
"run_names": self.run_names,
"best_by_metric": self.best_by_metric,
}
def summary_table(self) -> str:
"""Generate a text summary table of the comparison."""
if not self.run_ids:
return "No runs to compare"
lines = []
lines.append(f"Experiment: {self.experiment_name}")
lines.append(f"Comparing {len(self.run_ids)} runs")
lines.append("")
# Header
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
lines.append(" | ".join(header))
lines.append("-" * (len(lines[-1]) + 10))
# Metrics
for metric_name, values in sorted(self.metrics.items()):
row = [metric_name]
@@ -108,7 +104,7 @@ class RunComparison:
else:
row.append("N/A")
lines.append(" | ".join(row))
return "\n".join(lines)
@@ -121,7 +117,7 @@ class PromotionRecommendation:
reasons: List[str]
metrics_summary: Dict[str, float]
comparison_with_production: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
@@ -136,20 +132,20 @@ class PromotionRecommendation:
class ExperimentAnalyzer:
"""
Analyze MLflow experiments for model comparison and promotion decisions.
Example:
analyzer = ExperimentAnalyzer("chat-inference")
# Get metrics summary for last 24 hours
summary = analyzer.get_metrics_summary(hours=24)
# Compare models by accuracy
best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
# Analyze inference latency trends
trends = analyzer.get_metric_trends("total_latency_mean", days=7)
"""
def __init__(
self,
experiment_name: str,
@@ -157,7 +153,7 @@ class ExperimentAnalyzer:
):
"""
Initialize the experiment analyzer.
Args:
experiment_name: Name of the MLflow experiment to analyze
tracking_uri: Override default tracking URI
@@ -166,14 +162,14 @@ class ExperimentAnalyzer:
self.tracking_uri = tracking_uri
self.client = get_mlflow_client(tracking_uri=tracking_uri)
self._experiment: Optional[Experiment] = None
@property
def experiment(self) -> Optional[Experiment]:
    """Return the MLflow experiment for this analyzer, resolving it lazily.

    The experiment is looked up by name via the tracking client on first
    access and memoized on the instance; subsequent accesses return the
    cached object. May return None when no experiment with the configured
    name exists on the tracking server.
    """
    cached = self._experiment
    if cached is None:
        cached = self.client.get_experiment_by_name(self.experiment_name)
        self._experiment = cached
    return cached
def search_runs(
self,
filter_string: str = "",
@@ -183,29 +179,29 @@ class ExperimentAnalyzer:
) -> List[Run]:
"""
Search for runs matching criteria.
Args:
filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
max_results: Maximum runs to return
run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
Returns:
List of matching Run objects
"""
if not self.experiment:
logger.warning(f"Experiment '{self.experiment_name}' not found")
return []
runs = self.client.search_runs(
experiment_ids=[self.experiment.experiment_id],
filter_string=filter_string,
order_by=order_by or ["start_time DESC"],
max_results=max_results,
)
return runs
def get_recent_runs(
self,
n: int = 10,
@@ -213,11 +209,11 @@ class ExperimentAnalyzer:
) -> List[Run]:
"""
Get the most recent runs.
Args:
n: Number of runs to return
hours: Only include runs from the last N hours
Returns:
List of Run objects
"""
@@ -226,13 +222,13 @@ class ExperimentAnalyzer:
cutoff = datetime.now() - timedelta(hours=hours)
cutoff_ms = int(cutoff.timestamp() * 1000)
filter_string = f"attributes.start_time >= {cutoff_ms}"
return self.search_runs(
filter_string=filter_string,
order_by=["start_time DESC"],
max_results=n,
)
def compare_runs(
self,
run_ids: Optional[List[str]] = None,
@@ -240,11 +236,11 @@ class ExperimentAnalyzer:
) -> RunComparison:
"""
Compare multiple runs side by side.
Args:
run_ids: Specific run IDs to compare, or None for recent runs
n_recent: If run_ids is None, compare this many recent runs
Returns:
RunComparison object with detailed comparison
"""
@@ -252,18 +248,18 @@ class ExperimentAnalyzer:
runs = [self.client.get_run(rid) for rid in run_ids]
else:
runs = self.get_recent_runs(n=n_recent)
comparison = RunComparison(
run_ids=[r.info.run_id for r in runs],
experiment_name=self.experiment_name,
)
# Collect all metrics and find best performers
all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
for run in runs:
run_id = run.info.run_id
# Metadata
comparison.run_names[run_id] = run.info.run_name or run_id[:8]
comparison.start_times[run_id] = datetime.fromtimestamp(
@@ -273,39 +269,39 @@ class ExperimentAnalyzer:
comparison.durations[run_id] = (
run.info.end_time - run.info.start_time
) / 1000
# Metrics
for key, value in run.data.metrics.items():
all_metrics[key][run_id] = value
# Params
for key, value in run.data.params.items():
if key not in comparison.params:
comparison.params[key] = {}
comparison.params[key][run_id] = value
comparison.metrics = dict(all_metrics)
# Find best performers for each metric
for metric_name, values in all_metrics.items():
if not values:
continue
# Determine if lower is better based on metric name
minimize = any(
term in metric_name.lower()
for term in ["latency", "error", "loss", "time"]
)
if minimize:
best_id = min(values.keys(), key=lambda k: values[k])
else:
best_id = max(values.keys(), key=lambda k: values[k])
comparison.best_by_metric[metric_name] = best_id
return comparison
def get_best_run(
self,
metric: str,
@@ -315,32 +311,32 @@ class ExperimentAnalyzer:
) -> Optional[Run]:
"""
Get the best run by a specific metric.
Args:
metric: Metric name to optimize
minimize: If True, find minimum; if False, find maximum
filter_string: Additional filter criteria
max_results: Maximum runs to consider
Returns:
Best Run object, or None if no runs found
"""
direction = "ASC" if minimize else "DESC"
runs = self.search_runs(
filter_string=filter_string,
order_by=[f"metrics.{metric} {direction}"],
max_results=max_results,
)
# Filter to only runs that have the metric
runs_with_metric = [
r for r in runs
if metric in r.data.metrics
]
return runs_with_metric[0] if runs_with_metric else None
def get_metrics_summary(
self,
hours: Optional[int] = None,
@@ -348,45 +344,45 @@ class ExperimentAnalyzer:
) -> Dict[str, Dict[str, float]]:
"""
Get summary statistics for metrics.
Args:
hours: Only include runs from the last N hours
metrics: Specific metrics to summarize (None for all)
Returns:
Dict mapping metric names to {mean, min, max, count}
"""
import statistics
runs = self.get_recent_runs(n=1000, hours=hours)
# Collect all metric values
metric_values: Dict[str, List[float]] = defaultdict(list)
for run in runs:
for key, value in run.data.metrics.items():
if metrics is None or key in metrics:
metric_values[key].append(value)
# Calculate statistics
summary = {}
for metric_name, values in metric_values.items():
if not values:
continue
summary[metric_name] = {
"mean": statistics.mean(values),
"min": min(values),
"max": max(values),
"count": len(values),
}
if len(values) >= 2:
summary[metric_name]["stdev"] = statistics.stdev(values)
summary[metric_name]["median"] = statistics.median(values)
return summary
def get_metric_trends(
self,
metric: str,
@@ -395,30 +391,30 @@ class ExperimentAnalyzer:
) -> List[Dict[str, Any]]:
"""
Get metric trends over time.
Args:
metric: Metric name to track
days: Number of days to look back
granularity_hours: Time bucket size in hours
Returns:
List of {timestamp, mean, min, max, count} dicts
"""
import statistics
runs = self.get_recent_runs(n=10000, hours=days * 24)
# Group runs by time bucket
buckets: Dict[int, List[float]] = defaultdict(list)
bucket_size_ms = granularity_hours * 3600 * 1000
for run in runs:
if metric not in run.data.metrics:
continue
bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
buckets[bucket].append(run.data.metrics[metric])
# Calculate statistics per bucket
trends = []
for bucket_ts, values in sorted(buckets.items()):
@@ -432,9 +428,9 @@ class ExperimentAnalyzer:
if len(values) >= 2:
trend["stdev"] = statistics.stdev(values)
trends.append(trend)
return trends
def get_runs_by_tag(
self,
tag_key: str,
@@ -443,12 +439,12 @@ class ExperimentAnalyzer:
) -> List[Run]:
"""
Get runs with a specific tag.
Args:
tag_key: Tag key to filter by
tag_value: Tag value to match
max_results: Maximum runs to return
Returns:
List of matching Run objects
"""
@@ -456,7 +452,7 @@ class ExperimentAnalyzer:
filter_string=f"tags.{tag_key} = '{tag_value}'",
max_results=max_results,
)
def get_model_runs(
self,
model_name: str,
@@ -464,11 +460,11 @@ class ExperimentAnalyzer:
) -> List[Run]:
"""
Get runs for a specific model.
Args:
model_name: Model name to filter by
max_results: Maximum runs to return
Returns:
List of matching Run objects
"""
@@ -477,14 +473,14 @@ class ExperimentAnalyzer:
filter_string=f"tags.`model.name` = '{model_name}'",
max_results=max_results,
)
if not runs:
# Try params
runs = self.search_runs(
filter_string=f"params.model_name = '{model_name}'",
max_results=max_results,
)
return runs
@@ -495,23 +491,23 @@ def compare_experiments(
) -> Dict[str, Dict[str, float]]:
"""
Compare metrics across multiple experiments.
Args:
experiment_names: Names of experiments to compare
metric: Metric to compare
tracking_uri: Override default tracking URI
Returns:
Dict mapping experiment names to metric statistics
"""
results = {}
for exp_name in experiment_names:
analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
summary = analyzer.get_metrics_summary(metrics=[metric])
if metric in summary:
results[exp_name] = summary[metric]
return results
@@ -523,7 +519,7 @@ def promotion_recommendation(
) -> PromotionRecommendation:
"""
Generate a recommendation for model promotion.
Args:
model_name: Name of the model to evaluate
experiment_name: Experiment containing evaluation runs
@@ -531,15 +527,15 @@ def promotion_recommendation(
comparison is one of: ">=", "<=", ">", "<"
e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
tracking_uri: Override default tracking URI
Returns:
PromotionRecommendation with decision and reasons
"""
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
# Get model runs
runs = analyzer.get_model_runs(model_name, max_results=10)
if not runs:
return PromotionRecommendation(
model_name=model_name,
@@ -548,41 +544,41 @@ def promotion_recommendation(
reasons=["No runs found for this model"],
metrics_summary={},
)
# Get the most recent run
latest_run = runs[0]
metrics = latest_run.data.metrics
# Evaluate criteria
reasons = []
passed = True
comparisons = {
">=": lambda a, b: a >= b,
"<=": lambda a, b: a <= b,
">": lambda a, b: a > b,
"<": lambda a, b: a < b,
}
for metric_name, (comparison, threshold) in criteria.items():
if metric_name not in metrics:
reasons.append(f"Metric '{metric_name}' not found")
passed = False
continue
value = metrics[metric_name]
compare_fn = comparisons.get(comparison)
if compare_fn is None:
reasons.append(f"Invalid comparison operator: {comparison}")
continue
if compare_fn(value, threshold):
reasons.append(f"{metric_name}: {value:.4f} {comparison} {threshold}")
else:
reasons.append(f"{metric_name}: {value:.4f} NOT {comparison} {threshold}")
passed = False
# Extract version from tags if available
version = None
if "mlflow.version" in latest_run.data.tags:
@@ -590,7 +586,7 @@ def promotion_recommendation(
version = int(latest_run.data.tags["mlflow.version"])
except ValueError:
pass
return PromotionRecommendation(
model_name=model_name,
version=version,
@@ -607,21 +603,21 @@ def get_inference_performance_report(
) -> Dict[str, Any]:
"""
Generate an inference performance report for a service.
Args:
service_name: Service name (chat-handler, voice-assistant)
hours: Hours of data to analyze
tracking_uri: Override default tracking URI
Returns:
Performance report dictionary
"""
experiment_name = f"{service_name.replace('-', '')}-inference"
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
# Get summary metrics
summary = analyzer.get_metrics_summary(hours=hours)
# Key latency metrics
latency_metrics = [
"total_latency_mean",
@@ -631,7 +627,7 @@ def get_inference_performance_report(
"embedding_latency_mean",
"rag_search_latency_mean",
]
report = {
"service": service_name,
"period_hours": hours,
@@ -641,24 +637,24 @@ def get_inference_performance_report(
"rag": {},
"errors": {},
}
# Latency section
for metric in latency_metrics:
if metric in summary:
report["latency"][metric] = summary[metric]
# Throughput
if "total_requests" in summary:
report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
# RAG usage
rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
for metric in rag_metrics:
if metric in summary:
report["rag"][metric] = summary[metric]
# Error rate
if "error_rate" in summary:
report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
return report