feat: Add MLflow integration utilities
- client: Connection management and helpers - tracker: General experiment tracking - inference_tracker: Async metrics for NATS handlers - model_registry: Model registration with KServe metadata - kfp_components: Kubeflow Pipeline components - experiment_comparison: Run comparison tools - cli: Command-line interface
This commit is contained in:
40
mlflow_utils/__init__.py
Normal file
40
mlflow_utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
MLflow Integration Utilities for LLM Workflows
|
||||
|
||||
This module provides MLflow integration for:
|
||||
- Kubeflow Pipelines experiment tracking
|
||||
- Model Registry with KServe deployment metadata
|
||||
- Inference metrics logging from NATS handlers
|
||||
- Experiment comparison and analysis
|
||||
|
||||
Configuration:
|
||||
Set MLFLOW_TRACKING_URI environment variable or use defaults:
|
||||
- In-cluster: http://mlflow.mlflow.svc.cluster.local:80
|
||||
- External: https://mlflow.lab.daviestechlabs.io
|
||||
|
||||
Usage:
|
||||
from mlflow_utils import get_mlflow_client, MLflowTracker
|
||||
from mlflow_utils.kfp_components import log_experiment_component
|
||||
from mlflow_utils.model_registry import register_model_for_kserve
|
||||
from mlflow_utils.inference_tracker import InferenceMetricsTracker
|
||||
"""
|
||||
|
||||
from .client import (
|
||||
get_mlflow_client,
|
||||
get_tracking_uri,
|
||||
ensure_experiment,
|
||||
MLflowConfig,
|
||||
)
|
||||
from .tracker import MLflowTracker
|
||||
from .inference_tracker import InferenceMetricsTracker
|
||||
|
||||
__all__ = [
|
||||
"get_mlflow_client",
|
||||
"get_tracking_uri",
|
||||
"ensure_experiment",
|
||||
"MLflowConfig",
|
||||
"MLflowTracker",
|
||||
"InferenceMetricsTracker",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
371
mlflow_utils/cli.py
Normal file
371
mlflow_utils/cli.py
Normal file
@@ -0,0 +1,371 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MLflow Experiment CLI
|
||||
|
||||
Command-line interface for querying and comparing MLflow experiments.
|
||||
|
||||
Usage:
|
||||
# Compare recent runs in an experiment
|
||||
python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
|
||||
|
||||
# Get best run by metric
|
||||
python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
|
||||
|
||||
# Generate performance report
|
||||
python -m mlflow_utils.cli report --service chat-handler --hours 24
|
||||
|
||||
# Check model promotion criteria
|
||||
python -m mlflow_utils.cli promote --model whisper-finetuned \\
|
||||
--experiment voice-evaluation \\
|
||||
--criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
|
||||
|
||||
# List experiments
|
||||
python -m mlflow_utils.cli list-experiments
|
||||
|
||||
# Query runs
|
||||
python -m mlflow_utils.cli query --experiment chat-inference \\
|
||||
--filter "metrics.total_latency_mean < 1.0" --limit 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from .client import get_mlflow_client, health_check
|
||||
from .experiment_comparison import (
|
||||
ExperimentAnalyzer,
|
||||
compare_experiments,
|
||||
promotion_recommendation,
|
||||
get_inference_performance_report,
|
||||
)
|
||||
from .model_registry import (
|
||||
list_model_versions,
|
||||
get_production_model,
|
||||
generate_kserve_yaml,
|
||||
)
|
||||
|
||||
|
||||
def cmd_health(args):
|
||||
"""Check MLflow connectivity."""
|
||||
result = health_check()
|
||||
print(json.dumps(result, indent=2))
|
||||
sys.exit(0 if result["connected"] else 1)
|
||||
|
||||
|
||||
def cmd_list_experiments(args):
|
||||
"""List all experiments."""
|
||||
client = get_mlflow_client(tracking_uri=args.tracking_uri)
|
||||
experiments = client.search_experiments()
|
||||
|
||||
print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
|
||||
print("-" * 80)
|
||||
for exp in experiments:
|
||||
print(f"{exp.experiment_id:<10} {exp.name:<40} {exp.artifact_location}")
|
||||
|
||||
|
||||
def cmd_compare(args):
|
||||
"""Compare recent runs in an experiment."""
|
||||
analyzer = ExperimentAnalyzer(
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
if args.run_ids:
|
||||
run_ids = args.run_ids.split(",")
|
||||
comparison = analyzer.compare_runs(run_ids=run_ids)
|
||||
else:
|
||||
comparison = analyzer.compare_runs(n_recent=args.runs)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(comparison.to_dict(), indent=2, default=str))
|
||||
else:
|
||||
print(comparison.summary_table())
|
||||
|
||||
|
||||
def cmd_best(args):
|
||||
"""Find the best run by a metric."""
|
||||
analyzer = ExperimentAnalyzer(
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
best_run = analyzer.get_best_run(
|
||||
metric=args.metric,
|
||||
minimize=args.minimize,
|
||||
filter_string=args.filter or "",
|
||||
)
|
||||
|
||||
if not best_run:
|
||||
print(f"No runs found with metric '{args.metric}'")
|
||||
sys.exit(1)
|
||||
|
||||
result = {
|
||||
"run_id": best_run.info.run_id,
|
||||
"run_name": best_run.info.run_name,
|
||||
"metric_value": best_run.data.metrics.get(args.metric),
|
||||
"all_metrics": dict(best_run.data.metrics),
|
||||
"params": dict(best_run.data.params),
|
||||
}
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print(f"Best Run: {best_run.info.run_name or best_run.info.run_id}")
|
||||
print(f" {args.metric}: {result['metric_value']}")
|
||||
print(f" Run ID: {best_run.info.run_id}")
|
||||
|
||||
|
||||
def cmd_summary(args):
|
||||
"""Get metrics summary for an experiment."""
|
||||
analyzer = ExperimentAnalyzer(
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
summary = analyzer.get_metrics_summary(
|
||||
hours=args.hours,
|
||||
metrics=args.metrics.split(",") if args.metrics else None,
|
||||
)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(summary, indent=2))
|
||||
else:
|
||||
print(f"Metrics Summary for '{args.experiment}' (last {args.hours}h)")
|
||||
print("=" * 60)
|
||||
for metric, stats in sorted(summary.items()):
|
||||
print(f"\n{metric}:")
|
||||
for stat, value in stats.items():
|
||||
print(f" {stat}: {value:.4f}")
|
||||
|
||||
|
||||
def cmd_report(args):
|
||||
"""Generate an inference performance report."""
|
||||
report = get_inference_performance_report(
|
||||
service_name=args.service,
|
||||
hours=args.hours,
|
||||
tracking_uri=args.tracking_uri,
|
||||
)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(report, indent=2))
|
||||
else:
|
||||
print(f"Performance Report: {report['service']}")
|
||||
print(f"Period: Last {report['period_hours']} hours")
|
||||
print(f"Generated: {report['generated_at']}")
|
||||
print()
|
||||
|
||||
if report["latency"]:
|
||||
print("Latency Metrics:")
|
||||
for metric, stats in report["latency"].items():
|
||||
if "mean" in stats:
|
||||
print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
|
||||
|
||||
if report["rag"]:
|
||||
print("\nRAG Usage:")
|
||||
for metric, stats in report["rag"].items():
|
||||
print(f" {metric}: {stats.get('mean', 'N/A')}")
|
||||
|
||||
if report["errors"]:
|
||||
print("\nError Rates:")
|
||||
for metric, stats in report["errors"].items():
|
||||
print(f" {metric}: {stats:.2f}%")
|
||||
|
||||
|
||||
def cmd_promote(args):
|
||||
"""Check model promotion criteria."""
|
||||
# Parse criteria
|
||||
criteria = {}
|
||||
for criterion in args.criteria.split(","):
|
||||
# Parse "metric>=value" or "metric<=value" etc.
|
||||
for op in [">=", "<=", ">", "<"]:
|
||||
if op in criterion:
|
||||
metric, value = criterion.split(op)
|
||||
criteria[metric.strip()] = (op, float(value.strip()))
|
||||
break
|
||||
|
||||
rec = promotion_recommendation(
|
||||
model_name=args.model,
|
||||
experiment_name=args.experiment,
|
||||
criteria=criteria,
|
||||
tracking_uri=args.tracking_uri,
|
||||
)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(rec.to_dict(), indent=2))
|
||||
else:
|
||||
status = "✓ RECOMMENDED" if rec.recommended else "✗ NOT RECOMMENDED"
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Status: {status}")
|
||||
print("\nCriteria Evaluation:")
|
||||
for reason in rec.reasons:
|
||||
print(f" {reason}")
|
||||
|
||||
|
||||
def cmd_query(args):
|
||||
"""Query runs with a filter."""
|
||||
analyzer = ExperimentAnalyzer(
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
runs = analyzer.search_runs(
|
||||
filter_string=args.filter or "",
|
||||
max_results=args.limit,
|
||||
)
|
||||
|
||||
if args.json:
|
||||
result = [
|
||||
{
|
||||
"run_id": r.info.run_id,
|
||||
"run_name": r.info.run_name,
|
||||
"status": r.info.status,
|
||||
"metrics": dict(r.data.metrics),
|
||||
"params": dict(r.data.params),
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print(f"Found {len(runs)} runs")
|
||||
for run in runs:
|
||||
print(f"\n{run.info.run_name or run.info.run_id}")
|
||||
print(f" ID: {run.info.run_id}")
|
||||
print(f" Status: {run.info.status}")
|
||||
|
||||
|
||||
def cmd_models(args):
|
||||
"""List registered models."""
|
||||
client = get_mlflow_client(tracking_uri=args.tracking_uri)
|
||||
|
||||
if args.model:
|
||||
versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(versions, indent=2, default=str))
|
||||
else:
|
||||
print(f"Model: {args.model}")
|
||||
for v in versions:
|
||||
print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}")
|
||||
else:
|
||||
# List all models
|
||||
models = client.search_registered_models()
|
||||
|
||||
if args.json:
|
||||
result = [{"name": m.name, "description": m.description} for m in models]
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print(f"{'Model Name':<40} Description")
|
||||
print("-" * 80)
|
||||
for model in models:
|
||||
desc = (model.description or "")[:35]
|
||||
print(f"{model.name:<40} {desc}")
|
||||
|
||||
|
||||
def cmd_kserve(args):
|
||||
"""Generate KServe manifest for a model."""
|
||||
yaml_str = generate_kserve_yaml(
|
||||
model_name=args.model,
|
||||
version=args.version,
|
||||
namespace=args.namespace,
|
||||
output_path=args.output,
|
||||
tracking_uri=args.tracking_uri,
|
||||
)
|
||||
|
||||
if not args.output:
|
||||
print(yaml_str)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="MLflow Experiment CLI",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tracking-uri",
|
||||
default=None,
|
||||
help="MLflow tracking URI (default: from env or in-cluster)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Output as JSON",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Commands")
|
||||
|
||||
# health
|
||||
health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
|
||||
health_parser.set_defaults(func=cmd_health)
|
||||
|
||||
# list-experiments
|
||||
list_parser = subparsers.add_parser("list-experiments", help="List experiments")
|
||||
list_parser.set_defaults(func=cmd_list_experiments)
|
||||
|
||||
# compare
|
||||
compare_parser = subparsers.add_parser("compare", help="Compare runs")
|
||||
compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
|
||||
compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
|
||||
compare_parser.set_defaults(func=cmd_compare)
|
||||
|
||||
# best
|
||||
best_parser = subparsers.add_parser("best", help="Find best run by metric")
|
||||
best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
best_parser.add_argument("--metric", "-m", required=True, help="Metric to optimize")
|
||||
best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
|
||||
best_parser.add_argument("--filter", "-f", help="Filter string")
|
||||
best_parser.set_defaults(func=cmd_best)
|
||||
|
||||
# summary
|
||||
summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
|
||||
summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
|
||||
summary_parser.add_argument("--metrics", help="Comma-separated metric names")
|
||||
summary_parser.set_defaults(func=cmd_summary)
|
||||
|
||||
# report
|
||||
report_parser = subparsers.add_parser("report", help="Generate performance report")
|
||||
report_parser.add_argument("--service", "-s", required=True, help="Service name")
|
||||
report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
|
||||
report_parser.set_defaults(func=cmd_report)
|
||||
|
||||
# promote
|
||||
promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
|
||||
promote_parser.add_argument("--model", "-m", required=True, help="Model name")
|
||||
promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
|
||||
promote_parser.set_defaults(func=cmd_promote)
|
||||
|
||||
# query
|
||||
query_parser = subparsers.add_parser("query", help="Query runs")
|
||||
query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
query_parser.add_argument("--filter", "-f", help="MLflow filter string")
|
||||
query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
|
||||
query_parser.set_defaults(func=cmd_query)
|
||||
|
||||
# models
|
||||
models_parser = subparsers.add_parser("models", help="List registered models")
|
||||
models_parser.add_argument("--model", "-m", help="Specific model name")
|
||||
models_parser.set_defaults(func=cmd_models)
|
||||
|
||||
# kserve
|
||||
kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
|
||||
kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
|
||||
kserve_parser.add_argument("--version", "-v", type=int, help="Model version")
|
||||
kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
|
||||
kserve_parser.add_argument("--output", "-o", help="Output file path")
|
||||
kserve_parser.set_defaults(func=cmd_kserve)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
209
mlflow_utils/client.py
Normal file
209
mlflow_utils/client.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
MLflow Client Configuration and Initialization
|
||||
|
||||
Provides a configured MLflow client for all integrations in the LLM workflows.
|
||||
Supports both in-cluster and external access patterns.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLflowConfig:
|
||||
"""Configuration for MLflow integration."""
|
||||
|
||||
# Tracking server URIs
|
||||
tracking_uri: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
||||
)
|
||||
)
|
||||
external_uri: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"MLFLOW_EXTERNAL_URI",
|
||||
"https://mlflow.lab.daviestechlabs.io"
|
||||
)
|
||||
)
|
||||
|
||||
# Artifact storage (NFS PVC mount)
|
||||
artifact_location: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"MLFLOW_ARTIFACT_LOCATION",
|
||||
"/mlflow/artifacts"
|
||||
)
|
||||
)
|
||||
|
||||
# Default experiment settings
|
||||
default_experiment: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"MLFLOW_DEFAULT_EXPERIMENT",
|
||||
"llm-workflows"
|
||||
)
|
||||
)
|
||||
|
||||
# Service identification
|
||||
service_name: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"OTEL_SERVICE_NAME",
|
||||
"unknown-service"
|
||||
)
|
||||
)
|
||||
|
||||
# Additional tags to add to all runs
|
||||
default_tags: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Add default tags based on environment."""
|
||||
env_tags = {
|
||||
"environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||
"hostname": os.environ.get("HOSTNAME", "unknown"),
|
||||
"namespace": os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml"),
|
||||
}
|
||||
self.default_tags = {**env_tags, **self.default_tags}
|
||||
|
||||
|
||||
def get_tracking_uri(external: bool = False) -> str:
|
||||
"""
|
||||
Get the appropriate MLflow tracking URI.
|
||||
|
||||
Args:
|
||||
external: If True, return the external URI for outside-cluster access
|
||||
|
||||
Returns:
|
||||
The MLflow tracking URI string
|
||||
"""
|
||||
config = MLflowConfig()
|
||||
return config.external_uri if external else config.tracking_uri
|
||||
|
||||
|
||||
def get_mlflow_client(
|
||||
tracking_uri: Optional[str] = None,
|
||||
configure_global: bool = True
|
||||
) -> MlflowClient:
|
||||
"""
|
||||
Get a configured MLflow client.
|
||||
|
||||
Args:
|
||||
tracking_uri: Override the default tracking URI
|
||||
configure_global: If True, also set mlflow.set_tracking_uri()
|
||||
|
||||
Returns:
|
||||
Configured MlflowClient instance
|
||||
"""
|
||||
uri = tracking_uri or get_tracking_uri()
|
||||
|
||||
if configure_global:
|
||||
mlflow.set_tracking_uri(uri)
|
||||
logger.info(f"MLflow tracking URI set to: {uri}")
|
||||
|
||||
client = MlflowClient(tracking_uri=uri)
|
||||
return client
|
||||
|
||||
|
||||
def ensure_experiment(
|
||||
experiment_name: str,
|
||||
artifact_location: Optional[str] = None,
|
||||
tags: Optional[Dict[str, str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Ensure an experiment exists, creating it if necessary.
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the experiment
|
||||
artifact_location: Override default artifact location
|
||||
tags: Additional tags for the experiment
|
||||
|
||||
Returns:
|
||||
The experiment ID
|
||||
"""
|
||||
config = MLflowConfig()
|
||||
client = get_mlflow_client()
|
||||
|
||||
# Check if experiment exists
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
|
||||
if experiment is None:
|
||||
# Create the experiment
|
||||
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
|
||||
experiment_id = client.create_experiment(
|
||||
name=experiment_name,
|
||||
artifact_location=artifact_loc,
|
||||
tags=tags or {}
|
||||
)
|
||||
logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}")
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
|
||||
|
||||
return experiment_id
|
||||
|
||||
|
||||
def get_or_create_registered_model(
|
||||
model_name: str,
|
||||
description: Optional[str] = None,
|
||||
tags: Optional[Dict[str, str]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Get or create a registered model in the Model Registry.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to register
|
||||
description: Model description
|
||||
tags: Tags for the model
|
||||
|
||||
Returns:
|
||||
The registered model name
|
||||
"""
|
||||
client = get_mlflow_client()
|
||||
|
||||
try:
|
||||
# Check if model exists
|
||||
client.get_registered_model(model_name)
|
||||
logger.debug(f"Using existing registered model: {model_name}")
|
||||
except mlflow.exceptions.MlflowException:
|
||||
# Create the model
|
||||
client.create_registered_model(
|
||||
name=model_name,
|
||||
description=description or f"Model for {model_name}",
|
||||
tags=tags or {}
|
||||
)
|
||||
logger.info(f"Created registered model: {model_name}")
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def health_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Check MLflow server connectivity and return status.
|
||||
|
||||
Returns:
|
||||
Dictionary with health status information
|
||||
"""
|
||||
config = MLflowConfig()
|
||||
result = {
|
||||
"tracking_uri": config.tracking_uri,
|
||||
"external_uri": config.external_uri,
|
||||
"connected": False,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
try:
|
||||
client = get_mlflow_client(configure_global=False)
|
||||
# Try to list experiments as a health check
|
||||
experiments = client.search_experiments(max_results=1)
|
||||
result["connected"] = True
|
||||
result["experiment_count"] = len(experiments)
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
logger.error(f"MLflow health check failed: {e}")
|
||||
|
||||
return result
|
||||
664
mlflow_utils/experiment_comparison.py
Normal file
664
mlflow_utils/experiment_comparison.py
Normal file
@@ -0,0 +1,664 @@
|
||||
"""
|
||||
Experiment Comparison and Analysis Utilities
|
||||
|
||||
Provides tools for comparing model versions, querying experiments,
|
||||
and making data-driven decisions about model promotion to production.
|
||||
|
||||
Features:
|
||||
- Compare multiple runs/experiments side by side
|
||||
- Query experiments by tags, metrics, or parameters
|
||||
- Analyze inference metrics from NATS handlers
|
||||
- Generate promotion recommendations
|
||||
- Export comparison reports
|
||||
|
||||
Usage:
|
||||
from mlflow_utils.experiment_comparison import (
|
||||
ExperimentAnalyzer,
|
||||
compare_runs,
|
||||
get_best_run,
|
||||
promotion_recommendation,
|
||||
)
|
||||
|
||||
analyzer = ExperimentAnalyzer("chat-inference")
|
||||
|
||||
# Compare last N runs
|
||||
comparison = analyzer.compare_recent_runs(n=5)
|
||||
|
||||
# Find best performing model
|
||||
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
|
||||
|
||||
# Get promotion recommendation
|
||||
rec = analyzer.promotion_recommendation(
|
||||
model_name="whisper-finetuned",
|
||||
min_accuracy=0.9,
|
||||
max_latency_p95=2.0
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities import Run, Experiment
|
||||
|
||||
from .client import get_mlflow_client, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunComparison:
|
||||
"""Comparison result for multiple MLflow runs."""
|
||||
run_ids: List[str]
|
||||
experiment_name: str
|
||||
|
||||
# Metric comparisons (metric_name -> {run_id -> value})
|
||||
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
|
||||
|
||||
# Parameter differences
|
||||
params: Dict[str, Dict[str, str]] = field(default_factory=dict)
|
||||
|
||||
# Run metadata
|
||||
run_names: Dict[str, str] = field(default_factory=dict)
|
||||
start_times: Dict[str, datetime] = field(default_factory=dict)
|
||||
durations: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
# Best performers by metric
|
||||
best_by_metric: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"run_ids": self.run_ids,
|
||||
"experiment_name": self.experiment_name,
|
||||
"metrics": self.metrics,
|
||||
"params": self.params,
|
||||
"run_names": self.run_names,
|
||||
"best_by_metric": self.best_by_metric,
|
||||
}
|
||||
|
||||
def summary_table(self) -> str:
|
||||
"""Generate a text summary table of the comparison."""
|
||||
if not self.run_ids:
|
||||
return "No runs to compare"
|
||||
|
||||
lines = []
|
||||
lines.append(f"Experiment: {self.experiment_name}")
|
||||
lines.append(f"Comparing {len(self.run_ids)} runs")
|
||||
lines.append("")
|
||||
|
||||
# Header
|
||||
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
|
||||
lines.append(" | ".join(header))
|
||||
lines.append("-" * (len(lines[-1]) + 10))
|
||||
|
||||
# Metrics
|
||||
for metric_name, values in sorted(self.metrics.items()):
|
||||
row = [metric_name]
|
||||
for run_id in self.run_ids:
|
||||
value = values.get(run_id)
|
||||
if value is not None:
|
||||
row.append(f"{value:.4f}")
|
||||
else:
|
||||
row.append("N/A")
|
||||
lines.append(" | ".join(row))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromotionRecommendation:
|
||||
"""Recommendation for model promotion."""
|
||||
model_name: str
|
||||
version: Optional[int]
|
||||
recommended: bool
|
||||
reasons: List[str]
|
||||
metrics_summary: Dict[str, float]
|
||||
comparison_with_production: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"version": self.version,
|
||||
"recommended": self.recommended,
|
||||
"reasons": self.reasons,
|
||||
"metrics_summary": self.metrics_summary,
|
||||
"comparison_with_production": self.comparison_with_production,
|
||||
}
|
||||
|
||||
|
||||
class ExperimentAnalyzer:
|
||||
"""
|
||||
Analyze MLflow experiments for model comparison and promotion decisions.
|
||||
|
||||
Example:
|
||||
analyzer = ExperimentAnalyzer("chat-inference")
|
||||
|
||||
# Get metrics summary for last 24 hours
|
||||
summary = analyzer.get_metrics_summary(hours=24)
|
||||
|
||||
# Compare models by accuracy
|
||||
best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
|
||||
|
||||
# Analyze inference latency trends
|
||||
trends = analyzer.get_metric_trends("total_latency_mean", days=7)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
tracking_uri: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the experiment analyzer.
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment to analyze
|
||||
tracking_uri: Override default tracking URI
|
||||
"""
|
||||
self.experiment_name = experiment_name
|
||||
self.tracking_uri = tracking_uri
|
||||
self.client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
self._experiment: Optional[Experiment] = None
|
||||
|
||||
@property
|
||||
def experiment(self) -> Optional[Experiment]:
|
||||
"""Get the experiment object, fetching if needed."""
|
||||
if self._experiment is None:
|
||||
self._experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
return self._experiment
|
||||
|
||||
def search_runs(
|
||||
self,
|
||||
filter_string: str = "",
|
||||
order_by: Optional[List[str]] = None,
|
||||
max_results: int = 100,
|
||||
run_view_type: str = "ACTIVE_ONLY",
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Search for runs matching criteria.
|
||||
|
||||
Args:
|
||||
filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
|
||||
order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
|
||||
max_results: Maximum runs to return
|
||||
run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
|
||||
|
||||
Returns:
|
||||
List of matching Run objects
|
||||
"""
|
||||
if not self.experiment:
|
||||
logger.warning(f"Experiment '{self.experiment_name}' not found")
|
||||
return []
|
||||
|
||||
runs = self.client.search_runs(
|
||||
experiment_ids=[self.experiment.experiment_id],
|
||||
filter_string=filter_string,
|
||||
order_by=order_by or ["start_time DESC"],
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
return runs
|
||||
|
||||
def get_recent_runs(
|
||||
self,
|
||||
n: int = 10,
|
||||
hours: Optional[int] = None,
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Get the most recent runs.
|
||||
|
||||
Args:
|
||||
n: Number of runs to return
|
||||
hours: Only include runs from the last N hours
|
||||
|
||||
Returns:
|
||||
List of Run objects
|
||||
"""
|
||||
filter_string = ""
|
||||
if hours:
|
||||
cutoff = datetime.now() - timedelta(hours=hours)
|
||||
cutoff_ms = int(cutoff.timestamp() * 1000)
|
||||
filter_string = f"attributes.start_time >= {cutoff_ms}"
|
||||
|
||||
return self.search_runs(
|
||||
filter_string=filter_string,
|
||||
order_by=["start_time DESC"],
|
||||
max_results=n,
|
||||
)
|
||||
|
||||
def compare_runs(
|
||||
self,
|
||||
run_ids: Optional[List[str]] = None,
|
||||
n_recent: int = 5,
|
||||
) -> RunComparison:
|
||||
"""
|
||||
Compare multiple runs side by side.
|
||||
|
||||
Args:
|
||||
run_ids: Specific run IDs to compare, or None for recent runs
|
||||
n_recent: If run_ids is None, compare this many recent runs
|
||||
|
||||
Returns:
|
||||
RunComparison object with detailed comparison
|
||||
"""
|
||||
if run_ids:
|
||||
runs = [self.client.get_run(rid) for rid in run_ids]
|
||||
else:
|
||||
runs = self.get_recent_runs(n=n_recent)
|
||||
|
||||
comparison = RunComparison(
|
||||
run_ids=[r.info.run_id for r in runs],
|
||||
experiment_name=self.experiment_name,
|
||||
)
|
||||
|
||||
# Collect all metrics and find best performers
|
||||
all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
|
||||
|
||||
for run in runs:
|
||||
run_id = run.info.run_id
|
||||
|
||||
# Metadata
|
||||
comparison.run_names[run_id] = run.info.run_name or run_id[:8]
|
||||
comparison.start_times[run_id] = datetime.fromtimestamp(
|
||||
run.info.start_time / 1000
|
||||
)
|
||||
if run.info.end_time:
|
||||
comparison.durations[run_id] = (
|
||||
run.info.end_time - run.info.start_time
|
||||
) / 1000
|
||||
|
||||
# Metrics
|
||||
for key, value in run.data.metrics.items():
|
||||
all_metrics[key][run_id] = value
|
||||
|
||||
# Params
|
||||
for key, value in run.data.params.items():
|
||||
if key not in comparison.params:
|
||||
comparison.params[key] = {}
|
||||
comparison.params[key][run_id] = value
|
||||
|
||||
comparison.metrics = dict(all_metrics)
|
||||
|
||||
# Find best performers for each metric
|
||||
for metric_name, values in all_metrics.items():
|
||||
if not values:
|
||||
continue
|
||||
|
||||
# Determine if lower is better based on metric name
|
||||
minimize = any(
|
||||
term in metric_name.lower()
|
||||
for term in ["latency", "error", "loss", "time"]
|
||||
)
|
||||
|
||||
if minimize:
|
||||
best_id = min(values.keys(), key=lambda k: values[k])
|
||||
else:
|
||||
best_id = max(values.keys(), key=lambda k: values[k])
|
||||
|
||||
comparison.best_by_metric[metric_name] = best_id
|
||||
|
||||
return comparison
|
||||
|
||||
def get_best_run(
|
||||
self,
|
||||
metric: str,
|
||||
minimize: bool = True,
|
||||
filter_string: str = "",
|
||||
max_results: int = 100,
|
||||
) -> Optional[Run]:
|
||||
"""
|
||||
Get the best run by a specific metric.
|
||||
|
||||
Args:
|
||||
metric: Metric name to optimize
|
||||
minimize: If True, find minimum; if False, find maximum
|
||||
filter_string: Additional filter criteria
|
||||
max_results: Maximum runs to consider
|
||||
|
||||
Returns:
|
||||
Best Run object, or None if no runs found
|
||||
"""
|
||||
direction = "ASC" if minimize else "DESC"
|
||||
|
||||
runs = self.search_runs(
|
||||
filter_string=filter_string,
|
||||
order_by=[f"metrics.{metric} {direction}"],
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
# Filter to only runs that have the metric
|
||||
runs_with_metric = [
|
||||
r for r in runs
|
||||
if metric in r.data.metrics
|
||||
]
|
||||
|
||||
return runs_with_metric[0] if runs_with_metric else None
|
||||
|
||||
def get_metrics_summary(
|
||||
self,
|
||||
hours: Optional[int] = None,
|
||||
metrics: Optional[List[str]] = None,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Get summary statistics for metrics.
|
||||
|
||||
Args:
|
||||
hours: Only include runs from the last N hours
|
||||
metrics: Specific metrics to summarize (None for all)
|
||||
|
||||
Returns:
|
||||
Dict mapping metric names to {mean, min, max, count}
|
||||
"""
|
||||
import statistics
|
||||
|
||||
runs = self.get_recent_runs(n=1000, hours=hours)
|
||||
|
||||
# Collect all metric values
|
||||
metric_values: Dict[str, List[float]] = defaultdict(list)
|
||||
|
||||
for run in runs:
|
||||
for key, value in run.data.metrics.items():
|
||||
if metrics is None or key in metrics:
|
||||
metric_values[key].append(value)
|
||||
|
||||
# Calculate statistics
|
||||
summary = {}
|
||||
for metric_name, values in metric_values.items():
|
||||
if not values:
|
||||
continue
|
||||
|
||||
summary[metric_name] = {
|
||||
"mean": statistics.mean(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"count": len(values),
|
||||
}
|
||||
|
||||
if len(values) >= 2:
|
||||
summary[metric_name]["stdev"] = statistics.stdev(values)
|
||||
summary[metric_name]["median"] = statistics.median(values)
|
||||
|
||||
return summary
|
||||
|
||||
def get_metric_trends(
|
||||
self,
|
||||
metric: str,
|
||||
days: int = 7,
|
||||
granularity_hours: int = 1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get metric trends over time.
|
||||
|
||||
Args:
|
||||
metric: Metric name to track
|
||||
days: Number of days to look back
|
||||
granularity_hours: Time bucket size in hours
|
||||
|
||||
Returns:
|
||||
List of {timestamp, mean, min, max, count} dicts
|
||||
"""
|
||||
import statistics
|
||||
|
||||
runs = self.get_recent_runs(n=10000, hours=days * 24)
|
||||
|
||||
# Group runs by time bucket
|
||||
buckets: Dict[int, List[float]] = defaultdict(list)
|
||||
bucket_size_ms = granularity_hours * 3600 * 1000
|
||||
|
||||
for run in runs:
|
||||
if metric not in run.data.metrics:
|
||||
continue
|
||||
|
||||
bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
|
||||
buckets[bucket].append(run.data.metrics[metric])
|
||||
|
||||
# Calculate statistics per bucket
|
||||
trends = []
|
||||
for bucket_ts, values in sorted(buckets.items()):
|
||||
trend = {
|
||||
"timestamp": datetime.fromtimestamp(bucket_ts / 1000).isoformat(),
|
||||
"count": len(values),
|
||||
"mean": statistics.mean(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
}
|
||||
if len(values) >= 2:
|
||||
trend["stdev"] = statistics.stdev(values)
|
||||
trends.append(trend)
|
||||
|
||||
return trends
|
||||
|
||||
def get_runs_by_tag(
|
||||
self,
|
||||
tag_key: str,
|
||||
tag_value: str,
|
||||
max_results: int = 100,
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Get runs with a specific tag.
|
||||
|
||||
Args:
|
||||
tag_key: Tag key to filter by
|
||||
tag_value: Tag value to match
|
||||
max_results: Maximum runs to return
|
||||
|
||||
Returns:
|
||||
List of matching Run objects
|
||||
"""
|
||||
return self.search_runs(
|
||||
filter_string=f"tags.{tag_key} = '{tag_value}'",
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
def get_model_runs(
|
||||
self,
|
||||
model_name: str,
|
||||
max_results: int = 100,
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Get runs for a specific model.
|
||||
|
||||
Args:
|
||||
model_name: Model name to filter by
|
||||
max_results: Maximum runs to return
|
||||
|
||||
Returns:
|
||||
List of matching Run objects
|
||||
"""
|
||||
# Try different tag conventions
|
||||
runs = self.search_runs(
|
||||
filter_string=f"tags.`model.name` = '{model_name}'",
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if not runs:
|
||||
# Try params
|
||||
runs = self.search_runs(
|
||||
filter_string=f"params.model_name = '{model_name}'",
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
def compare_experiments(
|
||||
experiment_names: List[str],
|
||||
metric: str,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Compare metrics across multiple experiments.
|
||||
|
||||
Args:
|
||||
experiment_names: Names of experiments to compare
|
||||
metric: Metric to compare
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
Dict mapping experiment names to metric statistics
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for exp_name in experiment_names:
|
||||
analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
|
||||
summary = analyzer.get_metrics_summary(metrics=[metric])
|
||||
if metric in summary:
|
||||
results[exp_name] = summary[metric]
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def promotion_recommendation(
|
||||
model_name: str,
|
||||
experiment_name: str,
|
||||
criteria: Dict[str, Tuple[str, float]],
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> PromotionRecommendation:
|
||||
"""
|
||||
Generate a recommendation for model promotion.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to evaluate
|
||||
experiment_name: Experiment containing evaluation runs
|
||||
criteria: Dict of {metric: (comparison, threshold)}
|
||||
comparison is one of: ">=", "<=", ">", "<"
|
||||
e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
PromotionRecommendation with decision and reasons
|
||||
"""
|
||||
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
|
||||
|
||||
# Get model runs
|
||||
runs = analyzer.get_model_runs(model_name, max_results=10)
|
||||
|
||||
if not runs:
|
||||
return PromotionRecommendation(
|
||||
model_name=model_name,
|
||||
version=None,
|
||||
recommended=False,
|
||||
reasons=["No runs found for this model"],
|
||||
metrics_summary={},
|
||||
)
|
||||
|
||||
# Get the most recent run
|
||||
latest_run = runs[0]
|
||||
metrics = latest_run.data.metrics
|
||||
|
||||
# Evaluate criteria
|
||||
reasons = []
|
||||
passed = True
|
||||
|
||||
comparisons = {
|
||||
">=": lambda a, b: a >= b,
|
||||
"<=": lambda a, b: a <= b,
|
||||
">": lambda a, b: a > b,
|
||||
"<": lambda a, b: a < b,
|
||||
}
|
||||
|
||||
for metric_name, (comparison, threshold) in criteria.items():
|
||||
if metric_name not in metrics:
|
||||
reasons.append(f"Metric '{metric_name}' not found")
|
||||
passed = False
|
||||
continue
|
||||
|
||||
value = metrics[metric_name]
|
||||
compare_fn = comparisons.get(comparison)
|
||||
|
||||
if compare_fn is None:
|
||||
reasons.append(f"Invalid comparison operator: {comparison}")
|
||||
continue
|
||||
|
||||
if compare_fn(value, threshold):
|
||||
reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}")
|
||||
else:
|
||||
reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}")
|
||||
passed = False
|
||||
|
||||
# Extract version from tags if available
|
||||
version = None
|
||||
if "mlflow.version" in latest_run.data.tags:
|
||||
try:
|
||||
version = int(latest_run.data.tags["mlflow.version"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return PromotionRecommendation(
|
||||
model_name=model_name,
|
||||
version=version,
|
||||
recommended=passed,
|
||||
reasons=reasons,
|
||||
metrics_summary=dict(metrics),
|
||||
)
|
||||
|
||||
|
||||
def get_inference_performance_report(
|
||||
service_name: str = "chat-handler",
|
||||
hours: int = 24,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an inference performance report for a service.
|
||||
|
||||
Args:
|
||||
service_name: Service name (chat-handler, voice-assistant)
|
||||
hours: Hours of data to analyze
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
Performance report dictionary
|
||||
"""
|
||||
experiment_name = f"{service_name.replace('-', '')}-inference"
|
||||
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
|
||||
|
||||
# Get summary metrics
|
||||
summary = analyzer.get_metrics_summary(hours=hours)
|
||||
|
||||
# Key latency metrics
|
||||
latency_metrics = [
|
||||
"total_latency_mean",
|
||||
"total_latency_p50",
|
||||
"total_latency_p95",
|
||||
"llm_latency_mean",
|
||||
"embedding_latency_mean",
|
||||
"rag_search_latency_mean",
|
||||
]
|
||||
|
||||
report = {
|
||||
"service": service_name,
|
||||
"period_hours": hours,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"latency": {},
|
||||
"throughput": {},
|
||||
"rag": {},
|
||||
"errors": {},
|
||||
}
|
||||
|
||||
# Latency section
|
||||
for metric in latency_metrics:
|
||||
if metric in summary:
|
||||
report["latency"][metric] = summary[metric]
|
||||
|
||||
# Throughput
|
||||
if "total_requests" in summary:
|
||||
report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
|
||||
|
||||
# RAG usage
|
||||
rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
|
||||
for metric in rag_metrics:
|
||||
if metric in summary:
|
||||
report["rag"][metric] = summary[metric]
|
||||
|
||||
# Error rate
|
||||
if "error_rate" in summary:
|
||||
report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
|
||||
|
||||
return report
|
||||
431
mlflow_utils/inference_tracker.py
Normal file
431
mlflow_utils/inference_tracker.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Inference Metrics Tracker for NATS Handlers
|
||||
|
||||
Provides async-compatible MLflow logging for real-time inference
|
||||
metrics from chat-handler and voice-assistant services.
|
||||
|
||||
Designed to integrate with the existing OpenTelemetry setup and
|
||||
complement OTel metrics with MLflow experiment tracking for
|
||||
longer-term analysis and model comparison.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceMetrics:
|
||||
"""Metrics collected during an inference request."""
|
||||
request_id: str
|
||||
user_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
# Timing metrics (in seconds)
|
||||
total_latency: float = 0.0
|
||||
embedding_latency: float = 0.0
|
||||
rag_search_latency: float = 0.0
|
||||
rerank_latency: float = 0.0
|
||||
llm_latency: float = 0.0
|
||||
tts_latency: float = 0.0
|
||||
stt_latency: float = 0.0
|
||||
|
||||
# Token/size metrics
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
prompt_length: int = 0
|
||||
response_length: int = 0
|
||||
|
||||
# RAG metrics
|
||||
rag_enabled: bool = False
|
||||
rag_documents_retrieved: int = 0
|
||||
rag_documents_used: int = 0
|
||||
reranker_enabled: bool = False
|
||||
|
||||
# Quality indicators
|
||||
is_streaming: bool = False
|
||||
is_premium: bool = False
|
||||
has_error: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# Model information
|
||||
model_name: Optional[str] = None
|
||||
model_endpoint: Optional[str] = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
def as_metrics_dict(self) -> Dict[str, float]:
|
||||
"""Convert numeric fields to a metrics dictionary."""
|
||||
return {
|
||||
"total_latency": self.total_latency,
|
||||
"embedding_latency": self.embedding_latency,
|
||||
"rag_search_latency": self.rag_search_latency,
|
||||
"rerank_latency": self.rerank_latency,
|
||||
"llm_latency": self.llm_latency,
|
||||
"tts_latency": self.tts_latency,
|
||||
"stt_latency": self.stt_latency,
|
||||
"input_tokens": float(self.input_tokens),
|
||||
"output_tokens": float(self.output_tokens),
|
||||
"total_tokens": float(self.total_tokens),
|
||||
"prompt_length": float(self.prompt_length),
|
||||
"response_length": float(self.response_length),
|
||||
"rag_documents_retrieved": float(self.rag_documents_retrieved),
|
||||
"rag_documents_used": float(self.rag_documents_used),
|
||||
}
|
||||
|
||||
def as_params_dict(self) -> Dict[str, str]:
|
||||
"""Convert configuration fields to a params dictionary."""
|
||||
params = {
|
||||
"rag_enabled": str(self.rag_enabled),
|
||||
"reranker_enabled": str(self.reranker_enabled),
|
||||
"is_streaming": str(self.is_streaming),
|
||||
"is_premium": str(self.is_premium),
|
||||
}
|
||||
if self.model_name:
|
||||
params["model_name"] = self.model_name
|
||||
if self.model_endpoint:
|
||||
params["model_endpoint"] = self.model_endpoint
|
||||
return params
|
||||
|
||||
|
||||
class InferenceMetricsTracker:
|
||||
"""
|
||||
Async-compatible MLflow tracker for inference metrics.
|
||||
|
||||
Uses batching and a background thread pool to avoid blocking
|
||||
the async event loop during MLflow calls.
|
||||
|
||||
Example usage in chat-handler:
|
||||
|
||||
class ChatHandler:
|
||||
def __init__(self):
|
||||
self.mlflow_tracker = InferenceMetricsTracker(
|
||||
service_name="chat-handler",
|
||||
experiment_name="chat-inference"
|
||||
)
|
||||
|
||||
async def setup(self):
|
||||
await self.mlflow_tracker.start()
|
||||
|
||||
async def process_request(self, msg):
|
||||
metrics = InferenceMetrics(request_id=request_id)
|
||||
|
||||
# Track timing
|
||||
start = time.time()
|
||||
# ... do embedding ...
|
||||
metrics.embedding_latency = time.time() - start
|
||||
|
||||
# ... more processing ...
|
||||
|
||||
# Log metrics (non-blocking)
|
||||
await self.mlflow_tracker.log_inference(metrics)
|
||||
|
||||
async def shutdown(self):
|
||||
await self.mlflow_tracker.stop()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
experiment_name: Optional[str] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
batch_size: int = 50,
|
||||
flush_interval_seconds: float = 60.0,
|
||||
enable_batching: bool = True,
|
||||
max_workers: int = 2,
|
||||
):
|
||||
"""
|
||||
Initialize the inference metrics tracker.
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (e.g., "chat-handler")
|
||||
experiment_name: MLflow experiment name (defaults to service_name)
|
||||
tracking_uri: Override default tracking URI
|
||||
batch_size: Number of metrics to batch before flushing
|
||||
flush_interval_seconds: Maximum time between flushes
|
||||
enable_batching: If False, log each request immediately
|
||||
max_workers: Number of thread pool workers for MLflow calls
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.experiment_name = experiment_name or f"{service_name}-inference"
|
||||
self.tracking_uri = tracking_uri
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval_seconds
|
||||
self.enable_batching = enable_batching
|
||||
|
||||
self.config = MLflowConfig()
|
||||
self._batch: List[InferenceMetrics] = []
|
||||
self._batch_lock = asyncio.Lock()
|
||||
self._flush_task: Optional[asyncio.Task] = None
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._running = False
|
||||
self._client: Optional[MlflowClient] = None
|
||||
self._experiment_id: Optional[str] = None
|
||||
|
||||
# Aggregate metrics for periodic logging
|
||||
self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
|
||||
self._request_count = 0
|
||||
self._error_count = 0
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the tracker and initialize MLflow connection."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Initialize MLflow in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
self._init_mlflow
|
||||
)
|
||||
|
||||
if self.enable_batching:
|
||||
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||
|
||||
logger.info(
|
||||
f"InferenceMetricsTracker started for {self.service_name} "
|
||||
f"(experiment: {self.experiment_name})"
|
||||
)
|
||||
|
||||
def _init_mlflow(self) -> None:
|
||||
"""Initialize MLflow client and experiment (runs in thread pool)."""
|
||||
self._client = get_mlflow_client(
|
||||
tracking_uri=self.tracking_uri,
|
||||
configure_global=True
|
||||
)
|
||||
self._experiment_id = ensure_experiment(
|
||||
self.experiment_name,
|
||||
tags={
|
||||
"service": self.service_name,
|
||||
"type": "inference-metrics",
|
||||
}
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the tracker and flush remaining metrics."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Final flush
|
||||
await self._flush_batch()
|
||||
|
||||
self._executor.shutdown(wait=True)
|
||||
logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
|
||||
|
||||
async def log_inference(self, metrics: InferenceMetrics) -> None:
|
||||
"""
|
||||
Log inference metrics (non-blocking).
|
||||
|
||||
Args:
|
||||
metrics: InferenceMetrics object with request data
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning("Tracker not running, skipping metrics")
|
||||
return
|
||||
|
||||
self._request_count += 1
|
||||
if metrics.has_error:
|
||||
self._error_count += 1
|
||||
|
||||
# Update aggregates
|
||||
for key, value in metrics.as_metrics_dict().items():
|
||||
if value > 0:
|
||||
self._aggregate_metrics[key].append(value)
|
||||
|
||||
if self.enable_batching:
|
||||
async with self._batch_lock:
|
||||
self._batch.append(metrics)
|
||||
if len(self._batch) >= self.batch_size:
|
||||
asyncio.create_task(self._flush_batch())
|
||||
else:
|
||||
# Immediate logging in thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
partial(self._log_single_inference, metrics)
|
||||
)
|
||||
|
||||
async def _periodic_flush(self) -> None:
|
||||
"""Periodically flush batched metrics."""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self._flush_batch()
|
||||
|
||||
async def _flush_batch(self) -> None:
|
||||
"""Flush the current batch of metrics to MLflow."""
|
||||
async with self._batch_lock:
|
||||
if not self._batch:
|
||||
return
|
||||
|
||||
batch = self._batch
|
||||
self._batch = []
|
||||
|
||||
# Log in thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
partial(self._log_batch, batch)
|
||||
)
|
||||
|
||||
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
|
||||
"""Log a single inference request to MLflow (runs in thread pool)."""
|
||||
try:
|
||||
with mlflow.start_run(
|
||||
experiment_id=self._experiment_id,
|
||||
run_name=f"inference-{metrics.request_id}",
|
||||
tags={
|
||||
"service": self.service_name,
|
||||
"request_id": metrics.request_id,
|
||||
"type": "single-inference",
|
||||
}
|
||||
):
|
||||
mlflow.log_params(metrics.as_params_dict())
|
||||
mlflow.log_metrics(metrics.as_metrics_dict())
|
||||
|
||||
if metrics.user_id:
|
||||
mlflow.set_tag("user_id", metrics.user_id)
|
||||
if metrics.session_id:
|
||||
mlflow.set_tag("session_id", metrics.session_id)
|
||||
if metrics.has_error:
|
||||
mlflow.set_tag("has_error", "true")
|
||||
if metrics.error_message:
|
||||
mlflow.set_tag("error_message", metrics.error_message[:250])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log inference metrics: {e}")
|
||||
|
||||
def _log_batch(self, batch: List[InferenceMetrics]) -> None:
|
||||
"""Log a batch of inference metrics as aggregate statistics."""
|
||||
if not batch:
|
||||
return
|
||||
|
||||
try:
|
||||
# Calculate aggregates
|
||||
aggregates = self._calculate_aggregates(batch)
|
||||
|
||||
run_name = f"batch-{self.service_name}-{int(time.time())}"
|
||||
|
||||
with mlflow.start_run(
|
||||
experiment_id=self._experiment_id,
|
||||
run_name=run_name,
|
||||
tags={
|
||||
"service": self.service_name,
|
||||
"type": "batch-inference",
|
||||
"batch_size": str(len(batch)),
|
||||
}
|
||||
):
|
||||
# Log aggregate metrics
|
||||
mlflow.log_metrics(aggregates)
|
||||
|
||||
# Log batch info
|
||||
mlflow.log_param("batch_size", len(batch))
|
||||
mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
|
||||
mlflow.log_param("time_window_end", max(m.timestamp for m in batch))
|
||||
|
||||
# Log configuration breakdown
|
||||
rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
|
||||
streaming_count = sum(1 for m in batch if m.is_streaming)
|
||||
premium_count = sum(1 for m in batch if m.is_premium)
|
||||
error_count = sum(1 for m in batch if m.has_error)
|
||||
|
||||
mlflow.log_metrics({
|
||||
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
|
||||
"streaming_pct": streaming_count / len(batch) * 100,
|
||||
"premium_pct": premium_count / len(batch) * 100,
|
||||
"error_rate": error_count / len(batch) * 100,
|
||||
})
|
||||
|
||||
# Log model distribution
|
||||
model_counts: Dict[str, int] = defaultdict(int)
|
||||
for m in batch:
|
||||
if m.model_name:
|
||||
model_counts[m.model_name] += 1
|
||||
|
||||
if model_counts:
|
||||
mlflow.log_dict(
|
||||
{"models": dict(model_counts)},
|
||||
"model_distribution.json"
|
||||
)
|
||||
|
||||
logger.info(f"Logged batch of {len(batch)} inference metrics")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log batch metrics: {e}")
|
||||
|
||||
def _calculate_aggregates(
|
||||
self,
|
||||
batch: List[InferenceMetrics]
|
||||
) -> Dict[str, float]:
|
||||
"""Calculate aggregate statistics from a batch of metrics."""
|
||||
import statistics
|
||||
|
||||
aggregates = {}
|
||||
|
||||
# Collect all numeric metrics
|
||||
metric_values: Dict[str, List[float]] = defaultdict(list)
|
||||
for m in batch:
|
||||
for key, value in m.as_metrics_dict().items():
|
||||
if value > 0:
|
||||
metric_values[key].append(value)
|
||||
|
||||
# Calculate statistics for each metric
|
||||
for key, values in metric_values.items():
|
||||
if not values:
|
||||
continue
|
||||
|
||||
aggregates[f"{key}_mean"] = statistics.mean(values)
|
||||
aggregates[f"{key}_min"] = min(values)
|
||||
aggregates[f"{key}_max"] = max(values)
|
||||
|
||||
if len(values) >= 2:
|
||||
aggregates[f"{key}_p50"] = statistics.median(values)
|
||||
aggregates[f"{key}_stdev"] = statistics.stdev(values)
|
||||
|
||||
if len(values) >= 4:
|
||||
sorted_vals = sorted(values)
|
||||
p95_idx = int(len(sorted_vals) * 0.95)
|
||||
p99_idx = int(len(sorted_vals) * 0.99)
|
||||
aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
|
||||
aggregates[f"{key}_p99"] = sorted_vals[p99_idx]
|
||||
|
||||
# Add counts
|
||||
aggregates["total_requests"] = float(len(batch))
|
||||
|
||||
return aggregates
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get current tracker statistics."""
|
||||
return {
|
||||
"service_name": self.service_name,
|
||||
"experiment_name": self.experiment_name,
|
||||
"running": self._running,
|
||||
"total_requests": self._request_count,
|
||||
"error_count": self._error_count,
|
||||
"pending_batch_size": len(self._batch),
|
||||
"aggregate_metrics_count": len(self._aggregate_metrics),
|
||||
}
|
||||
513
mlflow_utils/kfp_components.py
Normal file
513
mlflow_utils/kfp_components.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
Kubeflow Pipeline Components with MLflow Tracking
|
||||
|
||||
Provides reusable KFP components that integrate MLflow experiment
|
||||
tracking into Kubeflow Pipelines. These components can be used
|
||||
directly in pipelines or as wrappers around existing pipeline steps.
|
||||
|
||||
Usage in a Kubeflow Pipeline:
|
||||
|
||||
from mlflow_utils.kfp_components import (
|
||||
create_mlflow_run,
|
||||
log_metrics_component,
|
||||
log_model_artifact,
|
||||
end_mlflow_run,
|
||||
)
|
||||
|
||||
@dsl.pipeline(name="my-pipeline")
|
||||
def my_pipeline():
|
||||
# Start MLflow run
|
||||
run_info = create_mlflow_run(
|
||||
experiment_name="my-experiment",
|
||||
run_name="training-run-1"
|
||||
)
|
||||
|
||||
# ... your pipeline steps ...
|
||||
|
||||
# Log metrics
|
||||
log_step = log_metrics_component(
|
||||
run_id=run_info.outputs["run_id"],
|
||||
metrics={"accuracy": 0.95, "loss": 0.05}
|
||||
)
|
||||
|
||||
# End run
|
||||
end_mlflow_run(run_id=run_info.outputs["run_id"])
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
from typing import Dict, Any, List, Optional, NamedTuple
|
||||
|
||||
|
||||
# MLflow component image with all required dependencies
|
||||
MLFLOW_IMAGE = "python:3.13-slim"
|
||||
MLFLOW_PACKAGES = [
|
||||
"mlflow>=2.10.0",
|
||||
"boto3", # For S3 artifact storage if needed
|
||||
"psycopg2-binary", # For PostgreSQL backend
|
||||
]
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def create_mlflow_run(
|
||||
experiment_name: str,
|
||||
run_name: str,
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
tags: Dict[str, str] = None,
|
||||
params: Dict[str, str] = None,
|
||||
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
|
||||
"""
|
||||
Create a new MLflow run for the pipeline.
|
||||
|
||||
This should be called at the start of a pipeline to initialize
|
||||
tracking. The returned run_id should be passed to subsequent
|
||||
components for logging.
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment
|
||||
run_name: Name for this specific run
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
tags: Optional tags to add to the run
|
||||
params: Optional parameters to log
|
||||
|
||||
Returns:
|
||||
NamedTuple with run_id, experiment_id, and artifact_uri
|
||||
"""
|
||||
import os
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from collections import namedtuple
|
||||
|
||||
# Set tracking URI
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
|
||||
client = MlflowClient()
|
||||
|
||||
# Get or create experiment
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
if experiment is None:
|
||||
experiment_id = client.create_experiment(
|
||||
name=experiment_name,
|
||||
artifact_location=f"/mlflow/artifacts/{experiment_name}"
|
||||
)
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
|
||||
# Create default tags
|
||||
default_tags = {
|
||||
"pipeline.type": "kubeflow",
|
||||
"kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
|
||||
"kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
|
||||
}
|
||||
if tags:
|
||||
default_tags.update(tags)
|
||||
|
||||
# Start run
|
||||
run = mlflow.start_run(
|
||||
experiment_id=experiment_id,
|
||||
run_name=run_name,
|
||||
tags=default_tags,
|
||||
)
|
||||
|
||||
# Log initial params
|
||||
if params:
|
||||
mlflow.log_params(params)
|
||||
|
||||
run_id = run.info.run_id
|
||||
artifact_uri = run.info.artifact_uri
|
||||
|
||||
# End run (KFP components are isolated, we'll resume in other components)
|
||||
mlflow.end_run()
|
||||
|
||||
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
|
||||
return RunInfo(run_id, experiment_id, artifact_uri)
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def log_params_component(
|
||||
run_id: str,
|
||||
params: Dict[str, str],
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log parameters to an existing MLflow run.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
params: Dictionary of parameters to log
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
for key, value in params.items():
|
||||
client.log_param(run_id, key, str(value)[:500])
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def log_metrics_component(
|
||||
run_id: str,
|
||||
metrics: Dict[str, float],
|
||||
step: int = 0,
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log metrics to an existing MLflow run.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
metrics: Dictionary of metrics to log
|
||||
step: Step number for time-series metrics
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, key, float(value), step=step)
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def log_artifact_component(
|
||||
run_id: str,
|
||||
artifact_path: str,
|
||||
artifact_name: str = "",
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log an artifact file to an existing MLflow run.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
artifact_path: Path to the artifact file
|
||||
artifact_name: Optional destination name in artifact store
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
client.log_artifact(run_id, artifact_path, artifact_name or None)
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def log_dict_artifact(
|
||||
run_id: str,
|
||||
data: Dict[str, Any],
|
||||
filename: str,
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log a dictionary as a JSON artifact.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
data: Dictionary to save as JSON
|
||||
filename: Name for the JSON file
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
# Ensure .json extension
|
||||
if not filename.endswith('.json'):
|
||||
filename += '.json'
|
||||
|
||||
# Write to temp file and log
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filepath = Path(tmpdir) / filename
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
client.log_artifact(run_id, str(filepath))
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def end_mlflow_run(
|
||||
run_id: str,
|
||||
status: str = "FINISHED",
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
End an MLflow run with the specified status.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to end
|
||||
status: Run status (FINISHED, FAILED, KILLED)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities import RunStatus
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
status_map = {
|
||||
"FINISHED": RunStatus.FINISHED,
|
||||
"FAILED": RunStatus.FAILED,
|
||||
"KILLED": RunStatus.KILLED,
|
||||
}
|
||||
|
||||
run_status = status_map.get(status.upper(), RunStatus.FINISHED)
|
||||
client.set_terminated(run_id, status=run_status)
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES + ["httpx"]
|
||||
)
|
||||
def log_training_metrics(
|
||||
run_id: str,
|
||||
model_type: str,
|
||||
training_config: Dict[str, Any],
|
||||
final_metrics: Dict[str, float],
|
||||
model_path: str = "",
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log comprehensive training metrics for ML models.
|
||||
|
||||
Designed for use with QLoRA training, voice training, and other
|
||||
ML training pipelines in the llm-workflows repository.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
model_type: Type of model (llm, stt, tts, embeddings)
|
||||
training_config: Training configuration dict
|
||||
final_metrics: Final training metrics
|
||||
model_path: Path to saved model (if applicable)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
# Log training config as params
|
||||
flat_config = {}
|
||||
for key, value in training_config.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
flat_config[f"config.{key}"] = json.dumps(value)[:500]
|
||||
else:
|
||||
flat_config[f"config.{key}"] = str(value)[:500]
|
||||
|
||||
for key, value in flat_config.items():
|
||||
client.log_param(run_id, key, value)
|
||||
|
||||
# Log model type tag
|
||||
client.set_tag(run_id, "model.type", model_type)
|
||||
|
||||
# Log metrics
|
||||
for key, value in final_metrics.items():
|
||||
client.log_metric(run_id, key, float(value))
|
||||
|
||||
# Log full config as artifact
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "training_config.json"
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(training_config, f, indent=2)
|
||||
client.log_artifact(run_id, str(config_path))
|
||||
|
||||
# Log model path if provided
|
||||
if model_path:
|
||||
client.log_param(run_id, "model.path", model_path)
|
||||
client.set_tag(run_id, "model.saved", "true")
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def log_document_ingestion_metrics(
|
||||
run_id: str,
|
||||
source_url: str,
|
||||
collection_name: str,
|
||||
chunks_created: int,
|
||||
documents_processed: int,
|
||||
processing_time_seconds: float,
|
||||
embeddings_model: str = "bge-small-en-v1.5",
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 50,
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log document ingestion pipeline metrics.
|
||||
|
||||
Designed for use with the document_ingestion_pipeline.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
source_url: URL of the source document
|
||||
collection_name: Milvus collection name
|
||||
chunks_created: Number of chunks created
|
||||
documents_processed: Number of documents processed
|
||||
processing_time_seconds: Total processing time
|
||||
embeddings_model: Embeddings model used
|
||||
chunk_size: Chunk size in tokens
|
||||
chunk_overlap: Chunk overlap in tokens
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
# Log params
|
||||
params = {
|
||||
"source_url": source_url[:500],
|
||||
"collection_name": collection_name,
|
||||
"embeddings_model": embeddings_model,
|
||||
"chunk_size": str(chunk_size),
|
||||
"chunk_overlap": str(chunk_overlap),
|
||||
}
|
||||
for key, value in params.items():
|
||||
client.log_param(run_id, key, value)
|
||||
|
||||
# Log metrics
|
||||
metrics = {
|
||||
"chunks_created": chunks_created,
|
||||
"documents_processed": documents_processed,
|
||||
"processing_time_seconds": processing_time_seconds,
|
||||
"chunks_per_second": chunks_created / processing_time_seconds if processing_time_seconds > 0 else 0,
|
||||
}
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, key, float(value))
|
||||
|
||||
# Set pipeline type tag
|
||||
client.set_tag(run_id, "pipeline.type", "document-ingestion")
|
||||
client.set_tag(run_id, "milvus.collection", collection_name)
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
def log_evaluation_results(
|
||||
run_id: str,
|
||||
model_name: str,
|
||||
dataset_name: str,
|
||||
metrics: Dict[str, float],
|
||||
sample_results: List[Dict[str, Any]] = None,
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
) -> str:
|
||||
"""
|
||||
Log model evaluation results.
|
||||
|
||||
Designed for use with the evaluation_pipeline.
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
model_name: Name of the evaluated model
|
||||
dataset_name: Name of the evaluation dataset
|
||||
metrics: Evaluation metrics (accuracy, etc.)
|
||||
sample_results: Optional sample predictions
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
# Log params
|
||||
client.log_param(run_id, "eval.model_name", model_name)
|
||||
client.log_param(run_id, "eval.dataset", dataset_name)
|
||||
|
||||
# Log metrics
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, f"eval.{key}", float(value))
|
||||
|
||||
# Log sample results as artifact
|
||||
if sample_results:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
results_path = Path(tmpdir) / "evaluation_results.json"
|
||||
with open(results_path, 'w') as f:
|
||||
json.dump(sample_results, f, indent=2)
|
||||
client.log_artifact(run_id, str(results_path))
|
||||
|
||||
# Set tags
|
||||
client.set_tag(run_id, "pipeline.type", "evaluation")
|
||||
client.set_tag(run_id, "model.name", model_name)
|
||||
|
||||
# Determine if passed
|
||||
passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
|
||||
client.set_tag(run_id, "eval.passed", str(passed))
|
||||
|
||||
return run_id
|
||||
545
mlflow_utils/model_registry.py
Normal file
545
mlflow_utils/model_registry.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""
|
||||
MLflow Model Registry Integration for KServe
|
||||
|
||||
Provides utilities for registering trained models in MLflow Model Registry
|
||||
with metadata needed for deployment to KServe InferenceServices.
|
||||
|
||||
This module bridges the gap between Kubeflow training pipelines and
|
||||
KServe model serving by:
|
||||
1. Registering models with proper versioning
|
||||
2. Adding KServe-specific metadata (runtime, protocol, resources)
|
||||
3. Managing model stage transitions (Staging → Production)
|
||||
4. Generating KServe InferenceService manifests from registered models
|
||||
|
||||
Usage:
|
||||
from mlflow_utils.model_registry import (
|
||||
register_model_for_kserve,
|
||||
promote_model_to_production,
|
||||
generate_kserve_manifest,
|
||||
)
|
||||
|
||||
# Register a new model version
|
||||
model_version = register_model_for_kserve(
|
||||
model_name="whisper-finetuned",
|
||||
model_uri="s3://models/whisper-v2",
|
||||
model_type="stt",
|
||||
kserve_config={
|
||||
"runtime": "kserve-huggingface",
|
||||
"container_image": "ghcr.io/my-org/whisper:v2",
|
||||
}
|
||||
)
|
||||
|
||||
# Generate KServe manifest for deployment
|
||||
manifest = generate_kserve_manifest(
|
||||
model_name="whisper-finetuned",
|
||||
model_version=model_version.version,
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities.model_registry import ModelVersion
|
||||
|
||||
from .client import get_mlflow_client, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KServeConfig:
|
||||
"""Configuration for KServe deployment."""
|
||||
|
||||
# Runtime/container configuration
|
||||
runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc.
|
||||
container_image: Optional[str] = None
|
||||
container_port: int = 8080
|
||||
|
||||
# Protocol configuration
|
||||
protocol: str = "v2" # v1, v2, grpc
|
||||
|
||||
# Resource requests/limits
|
||||
cpu_request: str = "1"
|
||||
cpu_limit: str = "4"
|
||||
memory_request: str = "4Gi"
|
||||
memory_limit: str = "16Gi"
|
||||
gpu_count: int = 0
|
||||
gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm
|
||||
|
||||
# Storage configuration
|
||||
storage_uri: Optional[str] = None # s3://, pvc://, gs://
|
||||
|
||||
# Scaling configuration
|
||||
min_replicas: int = 1
|
||||
max_replicas: int = 1
|
||||
scale_target: int = 10 # Target concurrent requests for scaling
|
||||
|
||||
# Serving configuration
|
||||
timeout_seconds: int = 300
|
||||
batch_size: int = 1
|
||||
|
||||
# Additional environment variables
|
||||
env_vars: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for MLflow tags."""
|
||||
return {
|
||||
"kserve.runtime": self.runtime,
|
||||
"kserve.protocol": self.protocol,
|
||||
"kserve.cpu_request": self.cpu_request,
|
||||
"kserve.memory_request": self.memory_request,
|
||||
"kserve.gpu_count": str(self.gpu_count),
|
||||
"kserve.min_replicas": str(self.min_replicas),
|
||||
"kserve.max_replicas": str(self.max_replicas),
|
||||
"kserve.storage_uri": self.storage_uri or "",
|
||||
"kserve.container_image": self.container_image or "",
|
||||
}
|
||||
|
||||
|
||||
# Pre-configured KServe configurations for common model types
|
||||
KSERVE_PRESETS: Dict[str, KServeConfig] = {
|
||||
"llm": KServeConfig(
|
||||
runtime="kserve-huggingface",
|
||||
cpu_request="2",
|
||||
cpu_limit="8",
|
||||
memory_request="16Gi",
|
||||
memory_limit="64Gi",
|
||||
gpu_count=1,
|
||||
timeout_seconds=600,
|
||||
),
|
||||
"stt": KServeConfig(
|
||||
runtime="kserve-custom",
|
||||
cpu_request="2",
|
||||
cpu_limit="4",
|
||||
memory_request="8Gi",
|
||||
memory_limit="16Gi",
|
||||
gpu_count=1,
|
||||
timeout_seconds=120,
|
||||
),
|
||||
"tts": KServeConfig(
|
||||
runtime="kserve-custom",
|
||||
cpu_request="2",
|
||||
cpu_limit="4",
|
||||
memory_request="8Gi",
|
||||
memory_limit="16Gi",
|
||||
gpu_count=1,
|
||||
timeout_seconds=60,
|
||||
),
|
||||
"embeddings": KServeConfig(
|
||||
runtime="kserve-huggingface",
|
||||
cpu_request="1",
|
||||
cpu_limit="4",
|
||||
memory_request="4Gi",
|
||||
memory_limit="16Gi",
|
||||
gpu_count=0,
|
||||
timeout_seconds=30,
|
||||
batch_size=32,
|
||||
),
|
||||
"reranker": KServeConfig(
|
||||
runtime="kserve-huggingface",
|
||||
cpu_request="1",
|
||||
cpu_limit="4",
|
||||
memory_request="4Gi",
|
||||
memory_limit="16Gi",
|
||||
gpu_count=0,
|
||||
timeout_seconds=30,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def register_model_for_kserve(
|
||||
model_name: str,
|
||||
model_uri: str,
|
||||
model_type: str,
|
||||
run_id: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
kserve_config: Optional[KServeConfig] = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> ModelVersion:
|
||||
"""
|
||||
Register a model in MLflow Model Registry with KServe metadata.
|
||||
|
||||
Args:
|
||||
model_name: Name for the registered model
|
||||
model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
|
||||
model_type: Type of model (llm, stt, tts, embeddings, reranker)
|
||||
run_id: Optional MLflow run ID to associate with
|
||||
description: Description of the model version
|
||||
kserve_config: KServe deployment configuration
|
||||
tags: Additional tags for the model version
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
The created ModelVersion object
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
# Get or use preset KServe config
|
||||
if kserve_config is None:
|
||||
kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())
|
||||
|
||||
# Ensure registered model exists
|
||||
try:
|
||||
client.get_registered_model(model_name)
|
||||
except mlflow.exceptions.MlflowException:
|
||||
client.create_registered_model(
|
||||
name=model_name,
|
||||
description=f"{model_type.upper()} model for KServe deployment",
|
||||
tags={
|
||||
"model.type": model_type,
|
||||
"deployment.target": "kserve",
|
||||
}
|
||||
)
|
||||
logger.info(f"Created registered model: {model_name}")
|
||||
|
||||
# Create model version
|
||||
model_version = client.create_model_version(
|
||||
name=model_name,
|
||||
source=model_uri,
|
||||
run_id=run_id,
|
||||
description=description or f"Version from {model_uri}",
|
||||
tags={
|
||||
**(tags or {}),
|
||||
"model.type": model_type,
|
||||
**kserve_config.as_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Registered model version {model_version.version} "
|
||||
f"for {model_name} (type: {model_type})"
|
||||
)
|
||||
|
||||
return model_version
|
||||
|
||||
|
||||
def promote_model_to_stage(
|
||||
model_name: str,
|
||||
version: int,
|
||||
stage: str = "Staging",
|
||||
archive_existing: bool = True,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> ModelVersion:
|
||||
"""
|
||||
Promote a model version to a new stage.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number to promote
|
||||
stage: Target stage (Staging, Production, Archived)
|
||||
archive_existing: If True, archive existing versions in target stage
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
The updated ModelVersion
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
# Transition to new stage
|
||||
model_version = client.transition_model_version_stage(
|
||||
name=model_name,
|
||||
version=str(version),
|
||||
stage=stage,
|
||||
archive_existing_versions=archive_existing,
|
||||
)
|
||||
|
||||
logger.info(f"Promoted {model_name} v{version} to {stage}")
|
||||
|
||||
return model_version
|
||||
|
||||
|
||||
def promote_model_to_production(
|
||||
model_name: str,
|
||||
version: int,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> ModelVersion:
|
||||
"""
|
||||
Promote a model version directly to Production.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number to promote
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
The updated ModelVersion
|
||||
"""
|
||||
return promote_model_to_stage(
|
||||
model_name=model_name,
|
||||
version=version,
|
||||
stage="Production",
|
||||
archive_existing=True,
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
|
||||
|
||||
def get_production_model(
|
||||
model_name: str,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> Optional[ModelVersion]:
|
||||
"""
|
||||
Get the current Production model version.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
The Production ModelVersion, or None if none exists
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
versions = client.get_latest_versions(model_name, stages=["Production"])
|
||||
|
||||
return versions[0] if versions else None
|
||||
|
||||
|
||||
def get_model_kserve_config(
|
||||
model_name: str,
|
||||
version: Optional[int] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> KServeConfig:
|
||||
"""
|
||||
Get KServe configuration from a registered model version.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number (uses Production if not specified)
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
KServeConfig populated from model tags
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
if version:
|
||||
model_version = client.get_model_version(model_name, str(version))
|
||||
else:
|
||||
prod_version = get_production_model(model_name, tracking_uri)
|
||||
if not prod_version:
|
||||
raise ValueError(f"No Production version for {model_name}")
|
||||
model_version = prod_version
|
||||
|
||||
tags = model_version.tags
|
||||
|
||||
return KServeConfig(
|
||||
runtime=tags.get("kserve.runtime", "kserve-huggingface"),
|
||||
protocol=tags.get("kserve.protocol", "v2"),
|
||||
cpu_request=tags.get("kserve.cpu_request", "1"),
|
||||
memory_request=tags.get("kserve.memory_request", "4Gi"),
|
||||
gpu_count=int(tags.get("kserve.gpu_count", "0")),
|
||||
min_replicas=int(tags.get("kserve.min_replicas", "1")),
|
||||
max_replicas=int(tags.get("kserve.max_replicas", "1")),
|
||||
storage_uri=tags.get("kserve.storage_uri") or None,
|
||||
container_image=tags.get("kserve.container_image") or None,
|
||||
)
|
||||
|
||||
|
||||
def generate_kserve_manifest(
|
||||
model_name: str,
|
||||
version: Optional[int] = None,
|
||||
namespace: str = "ai-ml",
|
||||
service_name: Optional[str] = None,
|
||||
extra_annotations: Optional[Dict[str, str]] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a KServe InferenceService manifest from a registered model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number (uses Production if not specified)
|
||||
namespace: Kubernetes namespace for deployment
|
||||
service_name: Name for the InferenceService (defaults to model_name)
|
||||
extra_annotations: Additional annotations for the service
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
KServe InferenceService manifest as a dictionary
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
# Get model version
|
||||
if version:
|
||||
model_version = client.get_model_version(model_name, str(version))
|
||||
else:
|
||||
prod_version = get_production_model(model_name, tracking_uri)
|
||||
if not prod_version:
|
||||
raise ValueError(f"No Production version for {model_name}")
|
||||
model_version = prod_version
|
||||
version = int(model_version.version)
|
||||
|
||||
# Get KServe config
|
||||
config = get_model_kserve_config(model_name, version, tracking_uri)
|
||||
model_type = model_version.tags.get("model.type", "custom")
|
||||
|
||||
svc_name = service_name or model_name.lower().replace("_", "-")
|
||||
|
||||
# Build manifest
|
||||
manifest = {
|
||||
"apiVersion": "serving.kserve.io/v1beta1",
|
||||
"kind": "InferenceService",
|
||||
"metadata": {
|
||||
"name": svc_name,
|
||||
"namespace": namespace,
|
||||
"labels": {
|
||||
"mlflow.model": model_name,
|
||||
"mlflow.version": str(version),
|
||||
"model.type": model_type,
|
||||
},
|
||||
"annotations": {
|
||||
"mlflow.tracking_uri": get_mlflow_client().tracking_uri,
|
||||
"mlflow.run_id": model_version.run_id or "",
|
||||
**(extra_annotations or {}),
|
||||
},
|
||||
},
|
||||
"spec": {
|
||||
"predictor": {
|
||||
"minReplicas": config.min_replicas,
|
||||
"maxReplicas": config.max_replicas,
|
||||
"scaleTarget": config.scale_target,
|
||||
"timeout": config.timeout_seconds,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Configure predictor based on runtime
|
||||
predictor = manifest["spec"]["predictor"]
|
||||
|
||||
if config.container_image:
|
||||
# Custom container
|
||||
predictor["containers"] = [{
|
||||
"name": "predictor",
|
||||
"image": config.container_image,
|
||||
"ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
|
||||
"resources": {
|
||||
"requests": {
|
||||
"cpu": config.cpu_request,
|
||||
"memory": config.memory_request,
|
||||
},
|
||||
"limits": {
|
||||
"cpu": config.cpu_limit,
|
||||
"memory": config.memory_limit,
|
||||
},
|
||||
},
|
||||
"env": [
|
||||
{"name": k, "value": v}
|
||||
for k, v in config.env_vars.items()
|
||||
],
|
||||
}]
|
||||
|
||||
# Add GPU if needed
|
||||
if config.gpu_count > 0:
|
||||
predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
|
||||
predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
|
||||
|
||||
else:
|
||||
# Standard KServe runtime
|
||||
storage_uri = config.storage_uri or model_version.source
|
||||
|
||||
predictor["model"] = {
|
||||
"modelFormat": {"name": "huggingface"},
|
||||
"protocolVersion": config.protocol,
|
||||
"storageUri": storage_uri,
|
||||
"resources": {
|
||||
"requests": {
|
||||
"cpu": config.cpu_request,
|
||||
"memory": config.memory_request,
|
||||
},
|
||||
"limits": {
|
||||
"cpu": config.cpu_limit,
|
||||
"memory": config.memory_limit,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if config.gpu_count > 0:
|
||||
predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
|
||||
predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
|
||||
|
||||
return manifest
|
||||
|
||||
|
||||
def generate_kserve_yaml(
|
||||
model_name: str,
|
||||
version: Optional[int] = None,
|
||||
namespace: str = "ai-ml",
|
||||
output_path: Optional[str] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a KServe InferenceService manifest as YAML.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number (uses Production if not specified)
|
||||
namespace: Kubernetes namespace
|
||||
output_path: If provided, write YAML to this path
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
YAML string of the manifest
|
||||
"""
|
||||
manifest = generate_kserve_manifest(
|
||||
model_name=model_name,
|
||||
version=version,
|
||||
namespace=namespace,
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
|
||||
yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
|
||||
|
||||
if output_path:
|
||||
with open(output_path, 'w') as f:
|
||||
f.write(yaml_str)
|
||||
logger.info(f"Wrote KServe manifest to {output_path}")
|
||||
|
||||
return yaml_str
|
||||
|
||||
|
||||
def list_model_versions(
|
||||
model_name: str,
|
||||
stages: Optional[List[str]] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all versions of a registered model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
stages: Filter by stages (None for all)
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
Returns:
|
||||
List of model version info dictionaries
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
if stages:
|
||||
versions = client.get_latest_versions(model_name, stages=stages)
|
||||
else:
|
||||
# Get all versions
|
||||
versions = []
|
||||
for mv in client.search_model_versions(f"name='{model_name}'"):
|
||||
versions.append(mv)
|
||||
|
||||
return [
|
||||
{
|
||||
"version": mv.version,
|
||||
"stage": mv.current_stage,
|
||||
"source": mv.source,
|
||||
"run_id": mv.run_id,
|
||||
"description": mv.description,
|
||||
"tags": mv.tags,
|
||||
"creation_timestamp": mv.creation_timestamp,
|
||||
"last_updated_timestamp": mv.last_updated_timestamp,
|
||||
}
|
||||
for mv in versions
|
||||
]
|
||||
395
mlflow_utils/tracker.py
Normal file
395
mlflow_utils/tracker.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
MLflow Tracker for Kubeflow Pipelines
|
||||
|
||||
Provides a high-level interface for logging experiments, parameters,
|
||||
metrics, and artifacts from Kubeflow Pipeline components.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineMetadata:
|
||||
"""Metadata about the Kubeflow Pipeline run."""
|
||||
pipeline_name: str
|
||||
run_id: str
|
||||
run_name: Optional[str] = None
|
||||
component_name: Optional[str] = None
|
||||
namespace: str = "ai-ml"
|
||||
|
||||
# KFP-specific metadata (populated from environment if available)
|
||||
kfp_run_id: Optional[str] = field(
|
||||
default_factory=lambda: os.environ.get("KFP_RUN_ID")
|
||||
)
|
||||
kfp_pod_name: Optional[str] = field(
|
||||
default_factory=lambda: os.environ.get("KFP_POD_NAME")
|
||||
)
|
||||
|
||||
def as_tags(self) -> Dict[str, str]:
|
||||
"""Convert metadata to MLflow tags."""
|
||||
tags = {
|
||||
"pipeline.name": self.pipeline_name,
|
||||
"pipeline.run_id": self.run_id,
|
||||
"pipeline.namespace": self.namespace,
|
||||
}
|
||||
if self.run_name:
|
||||
tags["pipeline.run_name"] = self.run_name
|
||||
if self.component_name:
|
||||
tags["pipeline.component"] = self.component_name
|
||||
if self.kfp_run_id:
|
||||
tags["kfp.run_id"] = self.kfp_run_id
|
||||
if self.kfp_pod_name:
|
||||
tags["kfp.pod_name"] = self.kfp_pod_name
|
||||
return tags
|
||||
|
||||
|
||||
class MLflowTracker:
|
||||
"""
|
||||
MLflow experiment tracker for Kubeflow Pipeline components.
|
||||
|
||||
Example usage in a KFP component:
|
||||
|
||||
from mlflow_utils import MLflowTracker
|
||||
|
||||
tracker = MLflowTracker(
|
||||
experiment_name="document-ingestion",
|
||||
run_name="batch-ingestion-2024-01"
|
||||
)
|
||||
|
||||
with tracker.start_run() as run:
|
||||
tracker.log_params({
|
||||
"chunk_size": 500,
|
||||
"overlap": 50,
|
||||
"embeddings_model": "bge-small-en-v1.5"
|
||||
})
|
||||
|
||||
# ... do work ...
|
||||
|
||||
tracker.log_metrics({
|
||||
"documents_processed": 100,
|
||||
"chunks_created": 2500,
|
||||
"processing_time_seconds": 120.5
|
||||
})
|
||||
|
||||
tracker.log_artifact("/path/to/output.json")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
run_name: Optional[str] = None,
|
||||
pipeline_metadata: Optional[PipelineMetadata] = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
tracking_uri: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the MLflow tracker.
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment
|
||||
run_name: Optional name for this run
|
||||
pipeline_metadata: Metadata about the KFP pipeline
|
||||
tags: Additional tags to add to the run
|
||||
tracking_uri: Override default tracking URI
|
||||
"""
|
||||
self.config = MLflowConfig()
|
||||
self.experiment_name = experiment_name
|
||||
self.run_name = run_name or f"{experiment_name}-{int(time.time())}"
|
||||
self.pipeline_metadata = pipeline_metadata
|
||||
self.user_tags = tags or {}
|
||||
self.tracking_uri = tracking_uri
|
||||
|
||||
self.client: Optional[MlflowClient] = None
|
||||
self.run: Optional[mlflow.ActiveRun] = None
|
||||
self.run_id: Optional[str] = None
|
||||
self._start_time: Optional[float] = None
|
||||
|
||||
def _get_all_tags(self) -> Dict[str, str]:
|
||||
"""Combine all tags for the run."""
|
||||
tags = self.config.default_tags.copy()
|
||||
|
||||
if self.pipeline_metadata:
|
||||
tags.update(self.pipeline_metadata.as_tags())
|
||||
|
||||
tags.update(self.user_tags)
|
||||
return tags
|
||||
|
||||
@contextmanager
|
||||
def start_run(
|
||||
self,
|
||||
nested: bool = False,
|
||||
parent_run_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Start an MLflow run as a context manager.
|
||||
|
||||
Args:
|
||||
nested: If True, create a nested run under the current active run
|
||||
parent_run_id: Explicit parent run ID for nested runs
|
||||
|
||||
Yields:
|
||||
The MLflow run object
|
||||
"""
|
||||
self.client = get_mlflow_client(
|
||||
tracking_uri=self.tracking_uri,
|
||||
configure_global=True
|
||||
)
|
||||
|
||||
# Ensure experiment exists
|
||||
experiment_id = ensure_experiment(self.experiment_name)
|
||||
|
||||
self._start_time = time.time()
|
||||
|
||||
try:
|
||||
# Start the run
|
||||
self.run = mlflow.start_run(
|
||||
experiment_id=experiment_id,
|
||||
run_name=self.run_name,
|
||||
nested=nested,
|
||||
tags=self._get_all_tags(),
|
||||
)
|
||||
self.run_id = self.run.info.run_id
|
||||
|
||||
logger.info(
|
||||
f"Started MLflow run '{self.run_name}' "
|
||||
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
|
||||
)
|
||||
|
||||
yield self.run
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MLflow run failed: {e}")
|
||||
if self.run:
|
||||
mlflow.set_tag("run.status", "failed")
|
||||
mlflow.set_tag("run.error", str(e))
|
||||
raise
|
||||
finally:
|
||||
# Log duration
|
||||
if self._start_time:
|
||||
duration = time.time() - self._start_time
|
||||
try:
|
||||
mlflow.log_metric("run_duration_seconds", duration)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# End the run
|
||||
mlflow.end_run()
|
||||
logger.info(f"Ended MLflow run '{self.run_name}'")
|
||||
|
||||
def log_params(self, params: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log parameters to the current run.
|
||||
|
||||
Args:
|
||||
params: Dictionary of parameter names to values
|
||||
"""
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_params")
|
||||
return
|
||||
|
||||
# MLflow has limits on param values, truncate if needed
|
||||
cleaned_params = {}
|
||||
for key, value in params.items():
|
||||
str_value = str(value)
|
||||
if len(str_value) > 500:
|
||||
str_value = str_value[:497] + "..."
|
||||
cleaned_params[key] = str_value
|
||||
|
||||
mlflow.log_params(cleaned_params)
|
||||
logger.debug(f"Logged {len(params)} parameters")
|
||||
|
||||
def log_param(self, key: str, value: Any) -> None:
|
||||
"""Log a single parameter."""
|
||||
self.log_params({key: value})
|
||||
|
||||
def log_metrics(
|
||||
self,
|
||||
metrics: Dict[str, Union[float, int]],
|
||||
step: Optional[int] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log metrics to the current run.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary of metric names to values
|
||||
step: Optional step number for time-series metrics
|
||||
"""
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_metrics")
|
||||
return
|
||||
|
||||
mlflow.log_metrics(metrics, step=step)
|
||||
logger.debug(f"Logged {len(metrics)} metrics")
|
||||
|
||||
def log_metric(
|
||||
self,
|
||||
key: str,
|
||||
value: Union[float, int],
|
||||
step: Optional[int] = None
|
||||
) -> None:
|
||||
"""Log a single metric."""
|
||||
self.log_metrics({key: value}, step=step)
|
||||
|
||||
def log_artifact(
|
||||
self,
|
||||
local_path: str,
|
||||
artifact_path: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log an artifact file to the current run.
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to log
|
||||
artifact_path: Optional destination path within the artifact store
|
||||
"""
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_artifact")
|
||||
return
|
||||
|
||||
mlflow.log_artifact(local_path, artifact_path)
|
||||
logger.info(f"Logged artifact: {local_path}")
|
||||
|
||||
def log_artifacts(
|
||||
self,
|
||||
local_dir: str,
|
||||
artifact_path: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log all files in a directory as artifacts.
|
||||
|
||||
Args:
|
||||
local_dir: Path to the local directory
|
||||
artifact_path: Optional destination path within the artifact store
|
||||
"""
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_artifacts")
|
||||
return
|
||||
|
||||
mlflow.log_artifacts(local_dir, artifact_path)
|
||||
logger.info(f"Logged artifacts from: {local_dir}")
|
||||
|
||||
def log_dict(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
filename: str,
|
||||
artifact_path: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log a dictionary as a JSON artifact.
|
||||
|
||||
Args:
|
||||
data: Dictionary to log
|
||||
filename: Name for the JSON file
|
||||
artifact_path: Optional destination path
|
||||
"""
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_dict")
|
||||
return
|
||||
|
||||
# Ensure .json extension
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
|
||||
mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
|
||||
logger.debug(f"Logged dict as: {filename}")
|
||||
|
||||
def log_model_info(
|
||||
self,
|
||||
model_type: str,
|
||||
model_name: str,
|
||||
model_path: Optional[str] = None,
|
||||
framework: str = "pytorch",
|
||||
extra_info: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log model information as parameters and tags.
|
||||
|
||||
Args:
|
||||
model_type: Type of model (e.g., "llm", "embedding", "stt")
|
||||
model_name: Name/identifier of the model
|
||||
model_path: Path to model weights
|
||||
framework: ML framework used
|
||||
extra_info: Additional model information
|
||||
"""
|
||||
params = {
|
||||
"model.type": model_type,
|
||||
"model.name": model_name,
|
||||
"model.framework": framework,
|
||||
}
|
||||
if model_path:
|
||||
params["model.path"] = model_path
|
||||
if extra_info:
|
||||
for key, value in extra_info.items():
|
||||
params[f"model.{key}"] = value
|
||||
|
||||
self.log_params(params)
|
||||
|
||||
# Also set as tags for easier filtering
|
||||
mlflow.set_tag("model.type", model_type)
|
||||
mlflow.set_tag("model.name", model_name)
|
||||
|
||||
def log_dataset_info(
|
||||
self,
|
||||
name: str,
|
||||
source: str,
|
||||
size: Optional[int] = None,
|
||||
extra_info: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Log dataset information.
|
||||
|
||||
Args:
|
||||
name: Dataset name
|
||||
source: Dataset source (URL, path, etc.)
|
||||
size: Number of samples
|
||||
extra_info: Additional dataset information
|
||||
"""
|
||||
params = {
|
||||
"dataset.name": name,
|
||||
"dataset.source": source,
|
||||
}
|
||||
if size is not None:
|
||||
params["dataset.size"] = size
|
||||
if extra_info:
|
||||
for key, value in extra_info.items():
|
||||
params[f"dataset.{key}"] = value
|
||||
|
||||
self.log_params(params)
|
||||
|
||||
def set_tag(self, key: str, value: str) -> None:
|
||||
"""Set a single tag on the run."""
|
||||
if self.run:
|
||||
mlflow.set_tag(key, value)
|
||||
|
||||
def set_tags(self, tags: Dict[str, str]) -> None:
|
||||
"""Set multiple tags on the run."""
|
||||
if self.run:
|
||||
mlflow.set_tags(tags)
|
||||
|
||||
@property
|
||||
def artifact_uri(self) -> Optional[str]:
|
||||
"""Get the artifact URI for the current run."""
|
||||
if self.run:
|
||||
return self.run.info.artifact_uri
|
||||
return None
|
||||
|
||||
@property
|
||||
def experiment_id(self) -> Optional[str]:
|
||||
"""Get the experiment ID for the current run."""
|
||||
if self.run:
|
||||
return self.run.info.experiment_id
|
||||
return None
|
||||
Reference in New Issue
Block a user