feat: Add MLflow integration utilities

- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
This commit is contained in:
2026-02-01 20:43:13 -05:00
parent 944bd8bc06
commit 2df3f27af7
10 changed files with 3315 additions and 1 deletions

371
mlflow_utils/cli.py Normal file
View File

@@ -0,0 +1,371 @@
#!/usr/bin/env python3
"""
MLflow Experiment CLI
Command-line interface for querying and comparing MLflow experiments.
Usage:
# Compare recent runs in an experiment
python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
# Get best run by metric
python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
# Generate performance report
python -m mlflow_utils.cli report --service chat-handler --hours 24
# Check model promotion criteria
python -m mlflow_utils.cli promote --model whisper-finetuned \\
--experiment voice-evaluation \\
--criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
# List experiments
python -m mlflow_utils.cli list-experiments
# Query runs
python -m mlflow_utils.cli query --experiment chat-inference \\
--filter "metrics.total_latency_mean < 1.0" --limit 10
"""
import argparse
import json
import sys
from typing import Optional
from .client import get_mlflow_client, health_check
from .experiment_comparison import (
ExperimentAnalyzer,
compare_experiments,
promotion_recommendation,
get_inference_performance_report,
)
from .model_registry import (
list_model_versions,
get_production_model,
generate_kserve_yaml,
)
def cmd_health(args):
"""Check MLflow connectivity."""
result = health_check()
print(json.dumps(result, indent=2))
sys.exit(0 if result["connected"] else 1)
def cmd_list_experiments(args):
"""List all experiments."""
client = get_mlflow_client(tracking_uri=args.tracking_uri)
experiments = client.search_experiments()
print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
print("-" * 80)
for exp in experiments:
print(f"{exp.experiment_id:<10} {exp.name:<40} {exp.artifact_location}")
def cmd_compare(args):
"""Compare recent runs in an experiment."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
if args.run_ids:
run_ids = args.run_ids.split(",")
comparison = analyzer.compare_runs(run_ids=run_ids)
else:
comparison = analyzer.compare_runs(n_recent=args.runs)
if args.json:
print(json.dumps(comparison.to_dict(), indent=2, default=str))
else:
print(comparison.summary_table())
def cmd_best(args):
"""Find the best run by a metric."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
best_run = analyzer.get_best_run(
metric=args.metric,
minimize=args.minimize,
filter_string=args.filter or "",
)
if not best_run:
print(f"No runs found with metric '{args.metric}'")
sys.exit(1)
result = {
"run_id": best_run.info.run_id,
"run_name": best_run.info.run_name,
"metric_value": best_run.data.metrics.get(args.metric),
"all_metrics": dict(best_run.data.metrics),
"params": dict(best_run.data.params),
}
if args.json:
print(json.dumps(result, indent=2))
else:
print(f"Best Run: {best_run.info.run_name or best_run.info.run_id}")
print(f" {args.metric}: {result['metric_value']}")
print(f" Run ID: {best_run.info.run_id}")
def cmd_summary(args):
"""Get metrics summary for an experiment."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
summary = analyzer.get_metrics_summary(
hours=args.hours,
metrics=args.metrics.split(",") if args.metrics else None,
)
if args.json:
print(json.dumps(summary, indent=2))
else:
print(f"Metrics Summary for '{args.experiment}' (last {args.hours}h)")
print("=" * 60)
for metric, stats in sorted(summary.items()):
print(f"\n{metric}:")
for stat, value in stats.items():
print(f" {stat}: {value:.4f}")
def cmd_report(args):
"""Generate an inference performance report."""
report = get_inference_performance_report(
service_name=args.service,
hours=args.hours,
tracking_uri=args.tracking_uri,
)
if args.json:
print(json.dumps(report, indent=2))
else:
print(f"Performance Report: {report['service']}")
print(f"Period: Last {report['period_hours']} hours")
print(f"Generated: {report['generated_at']}")
print()
if report["latency"]:
print("Latency Metrics:")
for metric, stats in report["latency"].items():
if "mean" in stats:
print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
if report["rag"]:
print("\nRAG Usage:")
for metric, stats in report["rag"].items():
print(f" {metric}: {stats.get('mean', 'N/A')}")
if report["errors"]:
print("\nError Rates:")
for metric, stats in report["errors"].items():
print(f" {metric}: {stats:.2f}%")
def cmd_promote(args):
"""Check model promotion criteria."""
# Parse criteria
criteria = {}
for criterion in args.criteria.split(","):
# Parse "metric>=value" or "metric<=value" etc.
for op in [">=", "<=", ">", "<"]:
if op in criterion:
metric, value = criterion.split(op)
criteria[metric.strip()] = (op, float(value.strip()))
break
rec = promotion_recommendation(
model_name=args.model,
experiment_name=args.experiment,
criteria=criteria,
tracking_uri=args.tracking_uri,
)
if args.json:
print(json.dumps(rec.to_dict(), indent=2))
else:
status = "✓ RECOMMENDED" if rec.recommended else "✗ NOT RECOMMENDED"
print(f"Model: {args.model}")
print(f"Status: {status}")
print("\nCriteria Evaluation:")
for reason in rec.reasons:
print(f" {reason}")
def cmd_query(args):
"""Query runs with a filter."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
runs = analyzer.search_runs(
filter_string=args.filter or "",
max_results=args.limit,
)
if args.json:
result = [
{
"run_id": r.info.run_id,
"run_name": r.info.run_name,
"status": r.info.status,
"metrics": dict(r.data.metrics),
"params": dict(r.data.params),
}
for r in runs
]
print(json.dumps(result, indent=2))
else:
print(f"Found {len(runs)} runs")
for run in runs:
print(f"\n{run.info.run_name or run.info.run_id}")
print(f" ID: {run.info.run_id}")
print(f" Status: {run.info.status}")
def cmd_models(args):
"""List registered models."""
client = get_mlflow_client(tracking_uri=args.tracking_uri)
if args.model:
versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
if args.json:
print(json.dumps(versions, indent=2, default=str))
else:
print(f"Model: {args.model}")
for v in versions:
print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}")
else:
# List all models
models = client.search_registered_models()
if args.json:
result = [{"name": m.name, "description": m.description} for m in models]
print(json.dumps(result, indent=2))
else:
print(f"{'Model Name':<40} Description")
print("-" * 80)
for model in models:
desc = (model.description or "")[:35]
print(f"{model.name:<40} {desc}")
def cmd_kserve(args):
"""Generate KServe manifest for a model."""
yaml_str = generate_kserve_yaml(
model_name=args.model,
version=args.version,
namespace=args.namespace,
output_path=args.output,
tracking_uri=args.tracking_uri,
)
if not args.output:
print(yaml_str)
def main():
parser = argparse.ArgumentParser(
description="MLflow Experiment CLI",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--tracking-uri",
default=None,
help="MLflow tracking URI (default: from env or in-cluster)",
)
parser.add_argument(
"--json",
action="store_true",
help="Output as JSON",
)
subparsers = parser.add_subparsers(dest="command", help="Commands")
# health
health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
health_parser.set_defaults(func=cmd_health)
# list-experiments
list_parser = subparsers.add_parser("list-experiments", help="List experiments")
list_parser.set_defaults(func=cmd_list_experiments)
# compare
compare_parser = subparsers.add_parser("compare", help="Compare runs")
compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
compare_parser.set_defaults(func=cmd_compare)
# best
best_parser = subparsers.add_parser("best", help="Find best run by metric")
best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
best_parser.add_argument("--metric", "-m", required=True, help="Metric to optimize")
best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
best_parser.add_argument("--filter", "-f", help="Filter string")
best_parser.set_defaults(func=cmd_best)
# summary
summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
summary_parser.add_argument("--metrics", help="Comma-separated metric names")
summary_parser.set_defaults(func=cmd_summary)
# report
report_parser = subparsers.add_parser("report", help="Generate performance report")
report_parser.add_argument("--service", "-s", required=True, help="Service name")
report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
report_parser.set_defaults(func=cmd_report)
# promote
promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
promote_parser.add_argument("--model", "-m", required=True, help="Model name")
promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
promote_parser.set_defaults(func=cmd_promote)
# query
query_parser = subparsers.add_parser("query", help="Query runs")
query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
query_parser.add_argument("--filter", "-f", help="MLflow filter string")
query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
query_parser.set_defaults(func=cmd_query)
# models
models_parser = subparsers.add_parser("models", help="List registered models")
models_parser.add_argument("--model", "-m", help="Specific model name")
models_parser.set_defaults(func=cmd_models)
# kserve
kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
kserve_parser.add_argument("--version", "-v", type=int, help="Model version")
kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
kserve_parser.add_argument("--output", "-o", help="Output file path")
kserve_parser.set_defaults(func=cmd_kserve)
args = parser.parse_args()
if not args.command:
parser.print_help()
sys.exit(1)
args.func(args)
if __name__ == "__main__":
main()