fix: resolve all ruff lint errors
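Mechanical lint cleanup across the mlflow_utils package: import sorting, unused-import removal, one long-line split, and whitespace fixes. Hunks below that show no visible change were, in all likelihood, trailing-whitespace cleanups that the rendering cannot display. The per-file headers were lost in rendering; judging by their content, the hunks cover the package __init__, the CLI module, the client module, and the experiment-comparison module, in that order. Assuming ruff is configured for this repository, the same class of fixes is normally reproduced with:

    ruff check --fix .

The rule codes cited in the notes below (I001 for import order, F401 for unused imports, E501 for line length) are inferred from the changes, not taken from ruff's output.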
@@ -20,17 +20,17 @@ Usage:
 """
 
 from .client import (
+    MLflowConfig,
+    ensure_experiment,
     get_mlflow_client,
     get_tracking_uri,
-    ensure_experiment,
-    MLflowConfig,
 )
-from .tracker import MLflowTracker
 from .inference_tracker import InferenceMetricsTracker
+from .tracker import MLflowTracker
 
 __all__ = [
     "get_mlflow_client",
     "get_tracking_uri",
     "ensure_experiment",
     "MLflowConfig",
     "MLflowTracker",
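Note: the new ordering is consistent with ruff's default isort member sorting, which places the CamelCase class MLflowConfig ahead of the lowercase functions and then sorts those alphabetically; the two submodule imports are likewise re-sorted so .inference_tracker precedes .tracker.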
@@ -7,21 +7,21 @@ Command-line interface for querying and comparing MLflow experiments.
 Usage:
     # Compare recent runs in an experiment
     python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
 
     # Get best run by metric
     python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
 
     # Generate performance report
     python -m mlflow_utils.cli report --service chat-handler --hours 24
 
     # Check model promotion criteria
     python -m mlflow_utils.cli promote --model whisper-finetuned \\
         --experiment voice-evaluation \\
         --criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
 
     # List experiments
     python -m mlflow_utils.cli list-experiments
 
     # Query runs
     python -m mlflow_utils.cli query --experiment chat-inference \\
         --filter "metrics.total_latency_mean < 1.0" --limit 10
@@ -30,19 +30,16 @@ Usage:
 import argparse
 import json
 import sys
-from typing import Optional
 
 from .client import get_mlflow_client, health_check
 from .experiment_comparison import (
     ExperimentAnalyzer,
-    compare_experiments,
-    promotion_recommendation,
     get_inference_performance_report,
+    promotion_recommendation,
 )
 from .model_registry import (
-    list_model_versions,
-    get_production_model,
     generate_kserve_yaml,
+    list_model_versions,
 )
 
 
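Note: Optional, compare_experiments, and get_production_model were imported but never used in this module (F401) and are dropped outright; promotion_recommendation and list_model_versions are still used and simply move to their sorted positions.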
@@ -57,7 +54,7 @@ def cmd_list_experiments(args):
     """List all experiments."""
     client = get_mlflow_client(tracking_uri=args.tracking_uri)
     experiments = client.search_experiments()
 
     print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
     print("-" * 80)
     for exp in experiments:
@@ -70,13 +67,13 @@ def cmd_compare(args):
         args.experiment,
         tracking_uri=args.tracking_uri
     )
 
     if args.run_ids:
         run_ids = args.run_ids.split(",")
         comparison = analyzer.compare_runs(run_ids=run_ids)
     else:
         comparison = analyzer.compare_runs(n_recent=args.runs)
 
     if args.json:
         print(json.dumps(comparison.to_dict(), indent=2, default=str))
     else:
@@ -89,17 +86,17 @@ def cmd_best(args):
         args.experiment,
         tracking_uri=args.tracking_uri
     )
 
     best_run = analyzer.get_best_run(
         metric=args.metric,
         minimize=args.minimize,
         filter_string=args.filter or "",
     )
 
     if not best_run:
         print(f"No runs found with metric '{args.metric}'")
         sys.exit(1)
 
     result = {
         "run_id": best_run.info.run_id,
         "run_name": best_run.info.run_name,
@@ -107,7 +104,7 @@ def cmd_best(args):
         "all_metrics": dict(best_run.data.metrics),
         "params": dict(best_run.data.params),
     }
 
     if args.json:
         print(json.dumps(result, indent=2))
     else:
@@ -122,12 +119,12 @@ def cmd_summary(args):
         args.experiment,
         tracking_uri=args.tracking_uri
     )
 
     summary = analyzer.get_metrics_summary(
         hours=args.hours,
         metrics=args.metrics.split(",") if args.metrics else None,
     )
 
     if args.json:
         print(json.dumps(summary, indent=2))
     else:
@@ -146,7 +143,7 @@ def cmd_report(args):
         hours=args.hours,
         tracking_uri=args.tracking_uri,
     )
 
     if args.json:
         print(json.dumps(report, indent=2))
     else:
@@ -154,18 +151,18 @@ def cmd_report(args):
     print(f"Period: Last {report['period_hours']} hours")
     print(f"Generated: {report['generated_at']}")
     print()
 
     if report["latency"]:
         print("Latency Metrics:")
         for metric, stats in report["latency"].items():
             if "mean" in stats:
                 print(f"  {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
 
     if report["rag"]:
         print("\nRAG Usage:")
         for metric, stats in report["rag"].items():
             print(f"  {metric}: {stats.get('mean', 'N/A')}")
 
     if report["errors"]:
         print("\nError Rates:")
         for metric, stats in report["errors"].items():
@@ -183,14 +180,14 @@ def cmd_promote(args):
             metric, value = criterion.split(op)
             criteria[metric.strip()] = (op, float(value.strip()))
             break
 
     rec = promotion_recommendation(
         model_name=args.model,
         experiment_name=args.experiment,
         criteria=criteria,
         tracking_uri=args.tracking_uri,
     )
 
     if args.json:
         print(json.dumps(rec.to_dict(), indent=2))
     else:
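Note: the --criteria string is parsed by trying each operator in turn and splitting on the first match. A standalone sketch of the same parsing (the operator tuple and outer loop shape are inferred from the visible lines; the enclosing loop headers fall outside this hunk):

    # Two-character operators must be tried before one-character ones,
    # or "eval.accuracy>=0.9" would split on ">" and leave "=0.9" behind.
    def parse_criteria(spec: str) -> dict:
        criteria = {}
        for criterion in spec.split(","):
            for op in (">=", "<=", ">", "<"):
                if op in criterion:
                    metric, value = criterion.split(op)
                    criteria[metric.strip()] = (op, float(value.strip()))
                    break
        return criteria

    # parse_criteria("eval.accuracy>=0.9,total_latency_p95<=2.0")
    # -> {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}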
@@ -208,12 +205,12 @@ def cmd_query(args):
         args.experiment,
         tracking_uri=args.tracking_uri
     )
 
     runs = analyzer.search_runs(
         filter_string=args.filter or "",
         max_results=args.limit,
     )
 
     if args.json:
         result = [
             {
@@ -237,20 +234,21 @@ def cmd_query(args):
 def cmd_models(args):
     """List registered models."""
     client = get_mlflow_client(tracking_uri=args.tracking_uri)
 
     if args.model:
         versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
 
         if args.json:
             print(json.dumps(versions, indent=2, default=str))
         else:
             print(f"Model: {args.model}")
             for v in versions:
-                print(f"  v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}")
+                desc = v["description"][:50] if v["description"] else "No description"
+                print(f"  v{v['version']} ({v['stage']}): {desc}")
     else:
         # List all models
         models = client.search_registered_models()
 
         if args.json:
             result = [{"name": m.name, "description": m.description} for m in models]
             print(json.dumps(result, indent=2))
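Note: this is the only non-import, non-whitespace edit in the CLI, and it is behavior-preserving: the over-long f-string (E501) is split by hoisting the description fallback into a desc variable, so the printed output is unchanged.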
@@ -271,7 +269,7 @@ def cmd_kserve(args):
         output_path=args.output,
         tracking_uri=args.tracking_uri,
     )
 
     if not args.output:
         print(yaml_str)
 
@@ -281,7 +279,7 @@ def main():
         description="MLflow Experiment CLI",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
 
     parser.add_argument(
         "--tracking-uri",
         default=None,
@@ -292,24 +290,24 @@ def main():
         action="store_true",
         help="Output as JSON",
     )
 
     subparsers = parser.add_subparsers(dest="command", help="Commands")
 
     # health
     health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
     health_parser.set_defaults(func=cmd_health)
 
     # list-experiments
     list_parser = subparsers.add_parser("list-experiments", help="List experiments")
     list_parser.set_defaults(func=cmd_list_experiments)
 
     # compare
     compare_parser = subparsers.add_parser("compare", help="Compare runs")
     compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
     compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
     compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
     compare_parser.set_defaults(func=cmd_compare)
 
     # best
     best_parser = subparsers.add_parser("best", help="Find best run by metric")
     best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
@@ -317,39 +315,39 @@ def main():
     best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
     best_parser.add_argument("--filter", "-f", help="Filter string")
     best_parser.set_defaults(func=cmd_best)
 
     # summary
     summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
     summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
     summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
     summary_parser.add_argument("--metrics", help="Comma-separated metric names")
     summary_parser.set_defaults(func=cmd_summary)
 
     # report
     report_parser = subparsers.add_parser("report", help="Generate performance report")
     report_parser.add_argument("--service", "-s", required=True, help="Service name")
     report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
     report_parser.set_defaults(func=cmd_report)
 
     # promote
     promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
     promote_parser.add_argument("--model", "-m", required=True, help="Model name")
     promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
     promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
     promote_parser.set_defaults(func=cmd_promote)
 
     # query
     query_parser = subparsers.add_parser("query", help="Query runs")
     query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
     query_parser.add_argument("--filter", "-f", help="MLflow filter string")
     query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
     query_parser.set_defaults(func=cmd_query)
 
     # models
     models_parser = subparsers.add_parser("models", help="List registered models")
     models_parser.add_argument("--model", "-m", help="Specific model name")
     models_parser.set_defaults(func=cmd_models)
 
     # kserve
     kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
     kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
@@ -357,13 +355,13 @@ def main():
     kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
     kserve_parser.add_argument("--output", "-o", help="Output file path")
     kserve_parser.set_defaults(func=cmd_kserve)
 
     args = parser.parse_args()
 
     if not args.command:
         parser.print_help()
         sys.exit(1)
 
     args.func(args)
 
 
@@ -5,10 +5,10 @@ Provides a configured MLflow client for all integrations in the LLM workflows.
 Supports both in-cluster and external access patterns.
 """
 
-import os
 import logging
+import os
 from dataclasses import dataclass, field
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional
 
 import mlflow
 from mlflow.tracking import MlflowClient
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class MLflowConfig:
     """Configuration for MLflow integration."""
 
     # Tracking server URIs
     tracking_uri: str = field(
         default_factory=lambda: os.environ.get(
@@ -33,7 +33,7 @@ class MLflowConfig:
             "https://mlflow.lab.daviestechlabs.io"
         )
     )
 
     # Artifact storage (NFS PVC mount)
     artifact_location: str = field(
         default_factory=lambda: os.environ.get(
@@ -41,7 +41,7 @@ class MLflowConfig:
             "/mlflow/artifacts"
         )
     )
 
     # Default experiment settings
     default_experiment: str = field(
         default_factory=lambda: os.environ.get(
@@ -49,7 +49,7 @@ class MLflowConfig:
             "llm-workflows"
         )
    )
 
     # Service identification
     service_name: str = field(
         default_factory=lambda: os.environ.get(
@@ -57,10 +57,10 @@ class MLflowConfig:
             "unknown-service"
         )
     )
 
     # Additional tags to add to all runs
     default_tags: Dict[str, str] = field(default_factory=dict)
 
     def __post_init__(self):
         """Add default tags based on environment."""
         env_tags = {
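Note: each field reads its default from the environment at instantiation time via default_factory; the environment-variable names themselves sit between the hunks shown here, so the sketch below uses a hypothetical name:

    import os
    from dataclasses import dataclass, field

    @dataclass
    class ConfigSketch:
        # "EXAMPLE_TRACKING_URI" is a stand-in; the real variable name
        # is not visible in this diff.
        tracking_uri: str = field(
            default_factory=lambda: os.environ.get(
                "EXAMPLE_TRACKING_URI",
                "http://fallback:5000",
            )
        )

    cfg = ConfigSketch()  # env is consulted here, not at class definition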
@@ -74,10 +74,10 @@ class MLflowConfig:
 def get_tracking_uri(external: bool = False) -> str:
     """
     Get the appropriate MLflow tracking URI.
 
     Args:
         external: If True, return the external URI for outside-cluster access
 
     Returns:
         The MLflow tracking URI string
     """
@@ -91,20 +91,20 @@ def get_mlflow_client(
 ) -> MlflowClient:
     """
     Get a configured MLflow client.
 
     Args:
         tracking_uri: Override the default tracking URI
         configure_global: If True, also set mlflow.set_tracking_uri()
 
     Returns:
         Configured MlflowClient instance
     """
     uri = tracking_uri or get_tracking_uri()
 
     if configure_global:
         mlflow.set_tracking_uri(uri)
         logger.info(f"MLflow tracking URI set to: {uri}")
 
     client = MlflowClient(tracking_uri=uri)
     return client
 
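Note: get_mlflow_client is the single construction point the other modules import. A minimal usage sketch (the URI is illustrative, not the deployment's real endpoint):

    client = get_mlflow_client(tracking_uri="http://mlflow.example:5000")
    for exp in client.search_experiments():
        print(exp.experiment_id, exp.name)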
@@ -116,21 +116,21 @@ def ensure_experiment(
 ) -> str:
     """
     Ensure an experiment exists, creating it if necessary.
 
     Args:
         experiment_name: Name of the experiment
         artifact_location: Override default artifact location
         tags: Additional tags for the experiment
 
     Returns:
         The experiment ID
     """
     config = MLflowConfig()
     client = get_mlflow_client()
 
     # Check if experiment exists
     experiment = client.get_experiment_by_name(experiment_name)
 
     if experiment is None:
         # Create the experiment
         artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
@@ -143,7 +143,7 @@ def ensure_experiment(
     else:
         experiment_id = experiment.experiment_id
         logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
 
     return experiment_id
 
 
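Note: ensure_experiment is idempotent: the first call creates the experiment under the configured artifact root and later calls return the existing ID, so services can invoke it unconditionally at startup. A minimal sketch:

    exp_id = ensure_experiment("chat-inference")
    assert ensure_experiment("chat-inference") == exp_id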
@@ -154,17 +154,17 @@ def get_or_create_registered_model(
 ) -> str:
     """
     Get or create a registered model in the Model Registry.
 
     Args:
         model_name: Name of the model to register
         description: Model description
         tags: Tags for the model
 
     Returns:
         The registered model name
     """
     client = get_mlflow_client()
 
     try:
         # Check if model exists
         client.get_registered_model(model_name)
@@ -177,14 +177,14 @@ def get_or_create_registered_model(
             tags=tags or {}
         )
         logger.info(f"Created registered model: {model_name}")
 
     return model_name
 
 
 def health_check() -> Dict[str, Any]:
     """
     Check MLflow server connectivity and return status.
 
     Returns:
         Dictionary with health status information
     """
@@ -195,7 +195,7 @@ def health_check() -> Dict[str, Any]:
         "connected": False,
         "error": None,
     }
 
     try:
         client = get_mlflow_client(configure_global=False)
         # Try to list experiments as a health check
@@ -205,5 +205,5 @@ def health_check() -> Dict[str, Any]:
     except Exception as e:
         result["error"] = str(e)
         logger.error(f"MLflow health check failed: {e}")
 
     return result
@@ -18,15 +18,15 @@ Usage:
         get_best_run,
         promotion_recommendation,
     )
 
     analyzer = ExperimentAnalyzer("chat-inference")
 
     # Compare last N runs
     comparison = analyzer.compare_recent_runs(n=5)
 
     # Find best performing model
     best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
 
     # Get promotion recommendation
     rec = analyzer.promotion_recommendation(
         model_name="whisper-finetuned",
@@ -35,19 +35,15 @@ Usage:
     )
 """
 
-import os
-import json
 import logging
-from datetime import datetime, timedelta
-from typing import Optional, Dict, Any, List, Tuple, Union
-from dataclasses import dataclass, field
 from collections import defaultdict
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional, Tuple
 
-import mlflow
-from mlflow.tracking import MlflowClient
-from mlflow.entities import Run, Experiment
+from mlflow.entities import Experiment, Run
 
-from .client import get_mlflow_client, MLflowConfig
+from .client import get_mlflow_client
 
 logger = logging.getLogger(__name__)
 
@@ -57,21 +53,21 @@ class RunComparison:
     """Comparison result for multiple MLflow runs."""
     run_ids: List[str]
     experiment_name: str
 
     # Metric comparisons (metric_name -> {run_id -> value})
     metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
 
     # Parameter differences
     params: Dict[str, Dict[str, str]] = field(default_factory=dict)
 
     # Run metadata
     run_names: Dict[str, str] = field(default_factory=dict)
     start_times: Dict[str, datetime] = field(default_factory=dict)
     durations: Dict[str, float] = field(default_factory=dict)
 
     # Best performers by metric
     best_by_metric: Dict[str, str] = field(default_factory=dict)
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert to dictionary for serialization."""
         return {
@@ -82,22 +78,22 @@ class RunComparison:
             "run_names": self.run_names,
             "best_by_metric": self.best_by_metric,
         }
 
     def summary_table(self) -> str:
         """Generate a text summary table of the comparison."""
         if not self.run_ids:
             return "No runs to compare"
 
         lines = []
         lines.append(f"Experiment: {self.experiment_name}")
         lines.append(f"Comparing {len(self.run_ids)} runs")
         lines.append("")
 
         # Header
         header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
         lines.append(" | ".join(header))
         lines.append("-" * (len(lines[-1]) + 10))
 
         # Metrics
         for metric_name, values in sorted(self.metrics.items()):
             row = [metric_name]
@@ -108,7 +104,7 @@ class RunComparison:
             else:
                 row.append("N/A")
             lines.append(" | ".join(row))
 
         return "\n".join(lines)
 
 
@@ -121,7 +117,7 @@ class PromotionRecommendation:
     reasons: List[str]
     metrics_summary: Dict[str, float]
     comparison_with_production: Optional[Dict[str, Any]] = None
 
     def to_dict(self) -> Dict[str, Any]:
         return {
             "model_name": self.model_name,
@@ -136,20 +132,20 @@ class PromotionRecommendation:
 class ExperimentAnalyzer:
     """
     Analyze MLflow experiments for model comparison and promotion decisions.
 
     Example:
         analyzer = ExperimentAnalyzer("chat-inference")
 
         # Get metrics summary for last 24 hours
         summary = analyzer.get_metrics_summary(hours=24)
 
         # Compare models by accuracy
         best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
 
         # Analyze inference latency trends
         trends = analyzer.get_metric_trends("total_latency_mean", days=7)
     """
 
     def __init__(
         self,
         experiment_name: str,
@@ -157,7 +153,7 @@ class ExperimentAnalyzer:
     ):
         """
         Initialize the experiment analyzer.
 
         Args:
             experiment_name: Name of the MLflow experiment to analyze
             tracking_uri: Override default tracking URI
@@ -166,14 +162,14 @@ class ExperimentAnalyzer:
         self.tracking_uri = tracking_uri
         self.client = get_mlflow_client(tracking_uri=tracking_uri)
         self._experiment: Optional[Experiment] = None
 
     @property
     def experiment(self) -> Optional[Experiment]:
         """Get the experiment object, fetching if needed."""
         if self._experiment is None:
             self._experiment = self.client.get_experiment_by_name(self.experiment_name)
         return self._experiment
 
     def search_runs(
         self,
         filter_string: str = "",
@@ -183,29 +179,29 @@ class ExperimentAnalyzer:
     ) -> List[Run]:
         """
         Search for runs matching criteria.
 
         Args:
             filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
             order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
             max_results: Maximum runs to return
             run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
 
         Returns:
             List of matching Run objects
         """
         if not self.experiment:
             logger.warning(f"Experiment '{self.experiment_name}' not found")
             return []
 
         runs = self.client.search_runs(
             experiment_ids=[self.experiment.experiment_id],
             filter_string=filter_string,
             order_by=order_by or ["start_time DESC"],
             max_results=max_results,
         )
 
         return runs
 
     def get_recent_runs(
         self,
         n: int = 10,
@@ -213,11 +209,11 @@ class ExperimentAnalyzer:
     ) -> List[Run]:
         """
         Get the most recent runs.
 
         Args:
             n: Number of runs to return
             hours: Only include runs from the last N hours
 
         Returns:
             List of Run objects
         """
@@ -226,13 +222,13 @@ class ExperimentAnalyzer:
             cutoff = datetime.now() - timedelta(hours=hours)
             cutoff_ms = int(cutoff.timestamp() * 1000)
             filter_string = f"attributes.start_time >= {cutoff_ms}"
 
         return self.search_runs(
             filter_string=filter_string,
             order_by=["start_time DESC"],
             max_results=n,
         )
 
     def compare_runs(
         self,
         run_ids: Optional[List[str]] = None,
||||||
@@ -240,11 +236,11 @@ class ExperimentAnalyzer:
     ) -> RunComparison:
         """
         Compare multiple runs side by side.
 
         Args:
             run_ids: Specific run IDs to compare, or None for recent runs
             n_recent: If run_ids is None, compare this many recent runs
 
         Returns:
             RunComparison object with detailed comparison
         """
@@ -252,18 +248,18 @@ class ExperimentAnalyzer:
             runs = [self.client.get_run(rid) for rid in run_ids]
         else:
             runs = self.get_recent_runs(n=n_recent)
 
         comparison = RunComparison(
             run_ids=[r.info.run_id for r in runs],
             experiment_name=self.experiment_name,
         )
 
         # Collect all metrics and find best performers
         all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
 
         for run in runs:
             run_id = run.info.run_id
 
             # Metadata
             comparison.run_names[run_id] = run.info.run_name or run_id[:8]
             comparison.start_times[run_id] = datetime.fromtimestamp(
@@ -273,39 +269,39 @@ class ExperimentAnalyzer:
             comparison.durations[run_id] = (
                 run.info.end_time - run.info.start_time
             ) / 1000
 
             # Metrics
             for key, value in run.data.metrics.items():
                 all_metrics[key][run_id] = value
 
             # Params
             for key, value in run.data.params.items():
                 if key not in comparison.params:
                     comparison.params[key] = {}
                 comparison.params[key][run_id] = value
 
         comparison.metrics = dict(all_metrics)
 
         # Find best performers for each metric
         for metric_name, values in all_metrics.items():
             if not values:
                 continue
 
             # Determine if lower is better based on metric name
             minimize = any(
                 term in metric_name.lower()
                 for term in ["latency", "error", "loss", "time"]
             )
 
             if minimize:
                 best_id = min(values.keys(), key=lambda k: values[k])
             else:
                 best_id = max(values.keys(), key=lambda k: values[k])
 
             comparison.best_by_metric[metric_name] = best_id
 
         return comparison
 
     def get_best_run(
         self,
         metric: str,
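Note: compare_runs infers the optimization direction from the metric name, so total_latency_mean is treated as lower-is-better while eval.accuracy is higher-is-better. The heuristic in isolation:

    def lower_is_better(metric_name: str) -> bool:
        terms = ["latency", "error", "loss", "time"]
        return any(t in metric_name.lower() for t in terms)

    assert lower_is_better("total_latency_mean")
    assert not lower_is_better("eval.accuracy")
    # Caveat: substring matching also flags names like "uptime" (contains
    # "time"); get_best_run takes an explicit minimize flag for such cases.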
@@ -315,32 +311,32 @@ class ExperimentAnalyzer:
     ) -> Optional[Run]:
         """
         Get the best run by a specific metric.
 
         Args:
             metric: Metric name to optimize
             minimize: If True, find minimum; if False, find maximum
             filter_string: Additional filter criteria
             max_results: Maximum runs to consider
 
         Returns:
             Best Run object, or None if no runs found
         """
         direction = "ASC" if minimize else "DESC"
 
         runs = self.search_runs(
             filter_string=filter_string,
             order_by=[f"metrics.{metric} {direction}"],
             max_results=max_results,
         )
 
         # Filter to only runs that have the metric
         runs_with_metric = [
             r for r in runs
             if metric in r.data.metrics
         ]
 
         return runs_with_metric[0] if runs_with_metric else None
 
     def get_metrics_summary(
         self,
         hours: Optional[int] = None,
@@ -348,45 +344,45 @@ class ExperimentAnalyzer:
     ) -> Dict[str, Dict[str, float]]:
         """
         Get summary statistics for metrics.
 
         Args:
             hours: Only include runs from the last N hours
             metrics: Specific metrics to summarize (None for all)
 
         Returns:
             Dict mapping metric names to {mean, min, max, count}
         """
         import statistics
 
         runs = self.get_recent_runs(n=1000, hours=hours)
 
         # Collect all metric values
         metric_values: Dict[str, List[float]] = defaultdict(list)
 
         for run in runs:
             for key, value in run.data.metrics.items():
                 if metrics is None or key in metrics:
                     metric_values[key].append(value)
 
         # Calculate statistics
         summary = {}
         for metric_name, values in metric_values.items():
             if not values:
                 continue
 
             summary[metric_name] = {
                 "mean": statistics.mean(values),
                 "min": min(values),
                 "max": max(values),
                 "count": len(values),
             }
 
             if len(values) >= 2:
                 summary[metric_name]["stdev"] = statistics.stdev(values)
                 summary[metric_name]["median"] = statistics.median(values)
 
         return summary
 
     def get_metric_trends(
         self,
         metric: str,
@@ -395,30 +391,30 @@ class ExperimentAnalyzer:
     ) -> List[Dict[str, Any]]:
         """
         Get metric trends over time.
 
         Args:
             metric: Metric name to track
             days: Number of days to look back
             granularity_hours: Time bucket size in hours
 
         Returns:
             List of {timestamp, mean, min, max, count} dicts
         """
         import statistics
 
         runs = self.get_recent_runs(n=10000, hours=days * 24)
 
         # Group runs by time bucket
         buckets: Dict[int, List[float]] = defaultdict(list)
         bucket_size_ms = granularity_hours * 3600 * 1000
 
         for run in runs:
             if metric not in run.data.metrics:
                 continue
 
             bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
             buckets[bucket].append(run.data.metrics[metric])
 
         # Calculate statistics per bucket
         trends = []
         for bucket_ts, values in sorted(buckets.items()):
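Note: bucketing floors each run's start time (epoch milliseconds) to the bucket boundary with integer division. A worked example at granularity_hours=1:

    bucket_size_ms = 1 * 3600 * 1000          # 3,600,000 ms per bucket
    start_time = 7_654_321                     # example run start, in ms
    bucket = (start_time // bucket_size_ms) * bucket_size_ms
    assert bucket == 7_200_000                 # floored to the 2h boundary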
@@ -432,9 +428,9 @@ class ExperimentAnalyzer:
             if len(values) >= 2:
                 trend["stdev"] = statistics.stdev(values)
             trends.append(trend)
 
         return trends
 
     def get_runs_by_tag(
         self,
         tag_key: str,
@@ -443,12 +439,12 @@ class ExperimentAnalyzer:
     ) -> List[Run]:
         """
         Get runs with a specific tag.
 
         Args:
             tag_key: Tag key to filter by
             tag_value: Tag value to match
             max_results: Maximum runs to return
 
         Returns:
             List of matching Run objects
         """
@@ -456,7 +452,7 @@ class ExperimentAnalyzer:
             filter_string=f"tags.{tag_key} = '{tag_value}'",
             max_results=max_results,
         )
 
     def get_model_runs(
         self,
         model_name: str,
@@ -464,11 +460,11 @@ class ExperimentAnalyzer:
     ) -> List[Run]:
         """
         Get runs for a specific model.
 
         Args:
             model_name: Model name to filter by
             max_results: Maximum runs to return
 
         Returns:
             List of matching Run objects
         """
@@ -477,14 +473,14 @@ class ExperimentAnalyzer:
             filter_string=f"tags.`model.name` = '{model_name}'",
             max_results=max_results,
         )
 
         if not runs:
             # Try params
             runs = self.search_runs(
                 filter_string=f"params.model_name = '{model_name}'",
                 max_results=max_results,
             )
 
         return runs
 
 
@@ -495,23 +491,23 @@ def compare_experiments(
 ) -> Dict[str, Dict[str, float]]:
     """
     Compare metrics across multiple experiments.
 
     Args:
         experiment_names: Names of experiments to compare
         metric: Metric to compare
         tracking_uri: Override default tracking URI
 
     Returns:
         Dict mapping experiment names to metric statistics
     """
     results = {}
 
     for exp_name in experiment_names:
         analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
         summary = analyzer.get_metrics_summary(metrics=[metric])
         if metric in summary:
             results[exp_name] = summary[metric]
 
     return results
 
 
@@ -523,7 +519,7 @@ def promotion_recommendation(
 ) -> PromotionRecommendation:
     """
     Generate a recommendation for model promotion.
 
     Args:
         model_name: Name of the model to evaluate
         experiment_name: Experiment containing evaluation runs
@@ -531,15 +527,15 @@ def promotion_recommendation(
             comparison is one of: ">=", "<=", ">", "<"
             e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
         tracking_uri: Override default tracking URI
 
     Returns:
         PromotionRecommendation with decision and reasons
     """
     analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
 
     # Get model runs
     runs = analyzer.get_model_runs(model_name, max_results=10)
 
     if not runs:
         return PromotionRecommendation(
             model_name=model_name,
@@ -548,41 +544,41 @@ def promotion_recommendation(
|
|||||||
reasons=["No runs found for this model"],
|
reasons=["No runs found for this model"],
|
||||||
metrics_summary={},
|
metrics_summary={},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the most recent run
|
# Get the most recent run
|
||||||
latest_run = runs[0]
|
latest_run = runs[0]
|
||||||
metrics = latest_run.data.metrics
|
metrics = latest_run.data.metrics
|
||||||
|
|
||||||
# Evaluate criteria
|
# Evaluate criteria
|
||||||
reasons = []
|
reasons = []
|
||||||
passed = True
|
passed = True
|
||||||
|
|
||||||
comparisons = {
|
comparisons = {
|
||||||
">=": lambda a, b: a >= b,
|
">=": lambda a, b: a >= b,
|
||||||
"<=": lambda a, b: a <= b,
|
"<=": lambda a, b: a <= b,
|
||||||
">": lambda a, b: a > b,
|
">": lambda a, b: a > b,
|
||||||
"<": lambda a, b: a < b,
|
"<": lambda a, b: a < b,
|
||||||
}
|
}
|
||||||
|
|
||||||
for metric_name, (comparison, threshold) in criteria.items():
|
for metric_name, (comparison, threshold) in criteria.items():
|
||||||
if metric_name not in metrics:
|
if metric_name not in metrics:
|
||||||
reasons.append(f"Metric '{metric_name}' not found")
|
reasons.append(f"Metric '{metric_name}' not found")
|
||||||
passed = False
|
passed = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
value = metrics[metric_name]
|
value = metrics[metric_name]
|
||||||
compare_fn = comparisons.get(comparison)
|
compare_fn = comparisons.get(comparison)
|
||||||
|
|
||||||
if compare_fn is None:
|
if compare_fn is None:
|
||||||
reasons.append(f"Invalid comparison operator: {comparison}")
|
reasons.append(f"Invalid comparison operator: {comparison}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if compare_fn(value, threshold):
|
if compare_fn(value, threshold):
|
||||||
reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}")
|
reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}")
|
||||||
else:
|
else:
|
||||||
reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}")
|
reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}")
|
||||||
passed = False
|
passed = False
|
||||||
|
|
||||||
# Extract version from tags if available
|
# Extract version from tags if available
|
||||||
version = None
|
version = None
|
||||||
if "mlflow.version" in latest_run.data.tags:
|
if "mlflow.version" in latest_run.data.tags:
|
||||||
@@ -590,7 +586,7 @@ def promotion_recommendation(
|
|||||||
version = int(latest_run.data.tags["mlflow.version"])
|
version = int(latest_run.data.tags["mlflow.version"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return PromotionRecommendation(
|
return PromotionRecommendation(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
version=version,
|
version=version,
|
||||||
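For illustration (not part of this commit), a sketch of calling promotion_recommendation with the criteria format documented above; the thresholds mirror the CLI example in the module docstring:

# Illustrative sketch only.
rec = promotion_recommendation(
    model_name="whisper-finetuned",
    experiment_name="voice-evaluation",
    criteria={"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)},
)
for reason in rec.reasons:  # one ✓/✗ line per criterion evaluated above
    print(reason)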
@@ -607,21 +603,21 @@ def get_inference_performance_report(
) -> Dict[str, Any]:
    """
    Generate an inference performance report for a service.

    Args:
        service_name: Service name (chat-handler, voice-assistant)
        hours: Hours of data to analyze
        tracking_uri: Override default tracking URI

    Returns:
        Performance report dictionary
    """
    experiment_name = f"{service_name.replace('-', '')}-inference"
    analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)

    # Get summary metrics
    summary = analyzer.get_metrics_summary(hours=hours)

    # Key latency metrics
    latency_metrics = [
        "total_latency_mean",
@@ -631,7 +627,7 @@ def get_inference_performance_report(
        "embedding_latency_mean",
        "rag_search_latency_mean",
    ]

    report = {
        "service": service_name,
        "period_hours": hours,
@@ -641,24 +637,24 @@ def get_inference_performance_report(
        "rag": {},
        "errors": {},
    }

    # Latency section
    for metric in latency_metrics:
        if metric in summary:
            report["latency"][metric] = summary[metric]

    # Throughput
    if "total_requests" in summary:
        report["throughput"]["total_requests"] = summary["total_requests"]["mean"]

    # RAG usage
    rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
    for metric in rag_metrics:
        if metric in summary:
            report["rag"][metric] = summary[metric]

    # Error rate
    if "error_rate" in summary:
        report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]

    return report
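For illustration (not part of this commit), a sketch of consuming the report assembled above:

# Illustrative sketch only.
import json

report = get_inference_performance_report("chat-handler", hours=24)
# Sections filled in above: service, period_hours, latency, throughput, rag, errors
print(json.dumps(report, indent=2))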
@@ -9,20 +9,19 @@ complement OTel metrics with MLflow experiment tracking for
longer-term analysis and model comparison.
"""

-import os
-import time
import asyncio
import logging
-from typing import Optional, Dict, Any, List
-from dataclasses import dataclass, field
+import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
from functools import partial
+from typing import Any, Dict, List, Optional

import mlflow
from mlflow.tracking import MlflowClient

-from .client import get_mlflow_client, ensure_experiment, MLflowConfig
+from .client import MLflowConfig, ensure_experiment, get_mlflow_client

logger = logging.getLogger(__name__)

@@ -33,7 +32,7 @@ class InferenceMetrics:
    request_id: str
    user_id: Optional[str] = None
    session_id: Optional[str] = None

    # Timing metrics (in seconds)
    total_latency: float = 0.0
    embedding_latency: float = 0.0
@@ -42,33 +41,33 @@ class InferenceMetrics:
    llm_latency: float = 0.0
    tts_latency: float = 0.0
    stt_latency: float = 0.0

    # Token/size metrics
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    prompt_length: int = 0
    response_length: int = 0

    # RAG metrics
    rag_enabled: bool = False
    rag_documents_retrieved: int = 0
    rag_documents_used: int = 0
    reranker_enabled: bool = False

    # Quality indicators
    is_streaming: bool = False
    is_premium: bool = False
    has_error: bool = False
    error_message: Optional[str] = None

    # Model information
    model_name: Optional[str] = None
    model_endpoint: Optional[str] = None

    # Timestamps
    timestamp: float = field(default_factory=time.time)

    def as_metrics_dict(self) -> Dict[str, float]:
        """Convert numeric fields to a metrics dictionary."""
        return {
@@ -87,7 +86,7 @@ class InferenceMetrics:
            "rag_documents_retrieved": float(self.rag_documents_retrieved),
            "rag_documents_used": float(self.rag_documents_used),
        }

    def as_params_dict(self) -> Dict[str, str]:
        """Convert configuration fields to a params dictionary."""
        params = {
@@ -106,39 +105,39 @@ class InferenceMetricsTracker:
class InferenceMetricsTracker:
    """
    Async-compatible MLflow tracker for inference metrics.

    Uses batching and a background thread pool to avoid blocking
    the async event loop during MLflow calls.

    Example usage in chat-handler:

        class ChatHandler:
            def __init__(self):
                self.mlflow_tracker = InferenceMetricsTracker(
                    service_name="chat-handler",
                    experiment_name="chat-inference"
                )

            async def setup(self):
                await self.mlflow_tracker.start()

            async def process_request(self, msg):
                metrics = InferenceMetrics(request_id=request_id)

                # Track timing
                start = time.time()
                # ... do embedding ...
                metrics.embedding_latency = time.time() - start

                # ... more processing ...

                # Log metrics (non-blocking)
                await self.mlflow_tracker.log_inference(metrics)

            async def shutdown(self):
                await self.mlflow_tracker.stop()
    """

    def __init__(
        self,
        service_name: str,
@@ -151,7 +150,7 @@ class InferenceMetricsTracker:
    ):
        """
        Initialize the inference metrics tracker.

        Args:
            service_name: Name of the service (e.g., "chat-handler")
            experiment_name: MLflow experiment name (defaults to service_name)
@@ -167,7 +166,7 @@ class InferenceMetricsTracker:
        self.batch_size = batch_size
        self.flush_interval = flush_interval_seconds
        self.enable_batching = enable_batching

        self.config = MLflowConfig()
        self._batch: List[InferenceMetrics] = []
        self._batch_lock = asyncio.Lock()
@@ -176,34 +175,34 @@ class InferenceMetricsTracker:
        self._running = False
        self._client: Optional[MlflowClient] = None
        self._experiment_id: Optional[str] = None

        # Aggregate metrics for periodic logging
        self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
        self._request_count = 0
        self._error_count = 0

    async def start(self) -> None:
        """Start the tracker and initialize MLflow connection."""
        if self._running:
            return

        self._running = True

        # Initialize MLflow in thread pool to avoid blocking
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(
            self._executor,
            self._init_mlflow
        )

        if self.enable_batching:
            self._flush_task = asyncio.create_task(self._periodic_flush())

        logger.info(
            f"InferenceMetricsTracker started for {self.service_name} "
            f"(experiment: {self.experiment_name})"
        )

    def _init_mlflow(self) -> None:
        """Initialize MLflow client and experiment (runs in thread pool)."""
        self._client = get_mlflow_client(
@@ -217,47 +216,47 @@ class InferenceMetricsTracker:
                "type": "inference-metrics",
            }
        )

    async def stop(self) -> None:
        """Stop the tracker and flush remaining metrics."""
        if not self._running:
            return

        self._running = False

        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Final flush
        await self._flush_batch()

        self._executor.shutdown(wait=True)
        logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")

    async def log_inference(self, metrics: InferenceMetrics) -> None:
        """
        Log inference metrics (non-blocking).

        Args:
            metrics: InferenceMetrics object with request data
        """
        if not self._running:
            logger.warning("Tracker not running, skipping metrics")
            return

        self._request_count += 1
        if metrics.has_error:
            self._error_count += 1

        # Update aggregates
        for key, value in metrics.as_metrics_dict().items():
            if value > 0:
                self._aggregate_metrics[key].append(value)

        if self.enable_batching:
            async with self._batch_lock:
                self._batch.append(metrics)
@@ -270,29 +269,29 @@ class InferenceMetricsTracker:
                self._executor,
                partial(self._log_single_inference, metrics)
            )

    async def _periodic_flush(self) -> None:
        """Periodically flush batched metrics."""
        while self._running:
            await asyncio.sleep(self.flush_interval)
            await self._flush_batch()

    async def _flush_batch(self) -> None:
        """Flush the current batch of metrics to MLflow."""
        async with self._batch_lock:
            if not self._batch:
                return

            batch = self._batch
            self._batch = []

        # Log in thread pool
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(
            self._executor,
            partial(self._log_batch, batch)
        )

    def _log_single_inference(self, metrics: InferenceMetrics) -> None:
        """Log a single inference request to MLflow (runs in thread pool)."""
        try:
@@ -307,7 +306,7 @@ class InferenceMetricsTracker:
            ):
                mlflow.log_params(metrics.as_params_dict())
                mlflow.log_metrics(metrics.as_metrics_dict())

                if metrics.user_id:
                    mlflow.set_tag("user_id", metrics.user_id)
                if metrics.session_id:
@@ -318,18 +317,18 @@ class InferenceMetricsTracker:
                    mlflow.set_tag("error_message", metrics.error_message[:250])
        except Exception as e:
            logger.error(f"Failed to log inference metrics: {e}")

    def _log_batch(self, batch: List[InferenceMetrics]) -> None:
        """Log a batch of inference metrics as aggregate statistics."""
        if not batch:
            return

        try:
            # Calculate aggregates
            aggregates = self._calculate_aggregates(batch)

            run_name = f"batch-{self.service_name}-{int(time.time())}"

            with mlflow.start_run(
                experiment_id=self._experiment_id,
                run_name=run_name,
@@ -341,83 +340,83 @@ class InferenceMetricsTracker:
            ):
                # Log aggregate metrics
                mlflow.log_metrics(aggregates)

                # Log batch info
                mlflow.log_param("batch_size", len(batch))
                mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
                mlflow.log_param("time_window_end", max(m.timestamp for m in batch))

                # Log configuration breakdown
                rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
                streaming_count = sum(1 for m in batch if m.is_streaming)
                premium_count = sum(1 for m in batch if m.is_premium)
                error_count = sum(1 for m in batch if m.has_error)

                mlflow.log_metrics({
                    "rag_enabled_pct": rag_enabled_count / len(batch) * 100,
                    "streaming_pct": streaming_count / len(batch) * 100,
                    "premium_pct": premium_count / len(batch) * 100,
                    "error_rate": error_count / len(batch) * 100,
                })

                # Log model distribution
                model_counts: Dict[str, int] = defaultdict(int)
                for m in batch:
                    if m.model_name:
                        model_counts[m.model_name] += 1

                if model_counts:
                    mlflow.log_dict(
                        {"models": dict(model_counts)},
                        "model_distribution.json"
                    )

            logger.info(f"Logged batch of {len(batch)} inference metrics")

        except Exception as e:
            logger.error(f"Failed to log batch metrics: {e}")

    def _calculate_aggregates(
        self,
        batch: List[InferenceMetrics]
    ) -> Dict[str, float]:
        """Calculate aggregate statistics from a batch of metrics."""
        import statistics

        aggregates = {}

        # Collect all numeric metrics
        metric_values: Dict[str, List[float]] = defaultdict(list)
        for m in batch:
            for key, value in m.as_metrics_dict().items():
                if value > 0:
                    metric_values[key].append(value)

        # Calculate statistics for each metric
        for key, values in metric_values.items():
            if not values:
                continue

            aggregates[f"{key}_mean"] = statistics.mean(values)
            aggregates[f"{key}_min"] = min(values)
            aggregates[f"{key}_max"] = max(values)

            if len(values) >= 2:
                aggregates[f"{key}_p50"] = statistics.median(values)
                aggregates[f"{key}_stdev"] = statistics.stdev(values)

            if len(values) >= 4:
                sorted_vals = sorted(values)
                p95_idx = int(len(sorted_vals) * 0.95)
                p99_idx = int(len(sorted_vals) * 0.99)
                aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
                aggregates[f"{key}_p99"] = sorted_vals[p99_idx]

        # Add counts
        aggregates["total_requests"] = float(len(batch))

        return aggregates

    def get_stats(self) -> Dict[str, Any]:
        """Get current tracker statistics."""
        return {
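A quick worked example (not part of this commit) of the index-based p95/p99 in _calculate_aggregates above; int(n * 0.95) is a nearest-rank-style index, so small batches resolve to the largest sample:

# Illustrative sketch only.
values = sorted([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
p95_idx = int(len(values) * 0.95)  # int(10 * 0.95) == 9
assert values[p95_idx] == 1.0      # p95 of 10 samples is the maximum here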
@@ -21,22 +21,22 @@ Usage in a Kubeflow Pipeline:
        experiment_name="my-experiment",
        run_name="training-run-1"
    )

    # ... your pipeline steps ...

    # Log metrics
    log_step = log_metrics_component(
        run_id=run_info.outputs["run_id"],
        metrics={"accuracy": 0.95, "loss": 0.05}
    )

    # End run
    end_mlflow_run(run_id=run_info.outputs["run_id"])
"""

-from kfp import dsl
-from typing import Dict, Any, List, Optional, NamedTuple
+from typing import Any, Dict, List, NamedTuple

+from kfp import dsl

# MLflow component image with all required dependencies
MLFLOW_IMAGE = "python:3.13-slim"
@@ -60,31 +60,32 @@ def create_mlflow_run(
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
    """
    Create a new MLflow run for the pipeline.

    This should be called at the start of a pipeline to initialize
    tracking. The returned run_id should be passed to subsequent
    components for logging.

    Args:
        experiment_name: Name of the MLflow experiment
        run_name: Name for this specific run
        mlflow_tracking_uri: MLflow tracking server URI
        tags: Optional tags to add to the run
        params: Optional parameters to log

    Returns:
        NamedTuple with run_id, experiment_id, and artifact_uri
    """
    import os
+    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient
-    from collections import namedtuple

    # Set tracking URI
    mlflow.set_tracking_uri(mlflow_tracking_uri)

    client = MlflowClient()

    # Get or create experiment
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
@@ -94,7 +95,7 @@ def create_mlflow_run(
        )
    else:
        experiment_id = experiment.experiment_id

    # Create default tags
    default_tags = {
        "pipeline.type": "kubeflow",
@@ -103,24 +104,24 @@ def create_mlflow_run(
    }
    if tags:
        default_tags.update(tags)

    # Start run
    run = mlflow.start_run(
        experiment_id=experiment_id,
        run_name=run_name,
        tags=default_tags,
    )

    # Log initial params
    if params:
        mlflow.log_params(params)

    run_id = run.info.run_id
    artifact_uri = run.info.artifact_uri

    # End run (KFP components are isolated, we'll resume in other components)
    mlflow.end_run()

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
    return RunInfo(run_id, experiment_id, artifact_uri)

@@ -136,24 +137,24 @@ def log_params_component(
) -> str:
    """
    Log parameters to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        params: Dictionary of parameters to log
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    for key, value in params.items():
        client.log_param(run_id, key, str(value)[:500])

    return run_id

@@ -169,25 +170,25 @@ def log_metrics_component(
) -> str:
    """
    Log metrics to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        metrics: Dictionary of metrics to log
        step: Step number for time-series metrics
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    for key, value in metrics.items():
        client.log_metric(run_id, key, float(value), step=step)

    return run_id

@@ -203,24 +204,24 @@ def log_artifact_component(
) -> str:
    """
    Log an artifact file to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        artifact_path: Path to the artifact file
        artifact_name: Optional destination name in artifact store
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    client.log_artifact(run_id, artifact_path, artifact_name or None)

    return run_id

@@ -236,36 +237,37 @@ def log_dict_artifact(
) -> str:
    """
    Log a dictionary as a JSON artifact.

    Args:
        run_id: The MLflow run ID to log to
        data: Dictionary to save as JSON
        filename: Name for the JSON file
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
+    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient
-    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Ensure .json extension
    if not filename.endswith('.json'):
        filename += '.json'

    # Write to temp file and log
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = Path(tmpdir) / filename
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)
        client.log_artifact(run_id, str(filepath))

    return run_id

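For illustration (not part of this commit), a sketch of calling the component above with a hypothetical payload; note the '.json' suffix is appended automatically:

# Illustrative sketch only -- run_id comes from create_mlflow_run; data is hypothetical.
log_dict_artifact(
    run_id=run_id,
    data={"chunks_created": 42, "collection": "documents"},
    filename="ingestion_summary",  # stored as ingestion_summary.json
)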
@@ -280,31 +282,31 @@ def end_mlflow_run(
) -> str:
    """
    End an MLflow run with the specified status.

    Args:
        run_id: The MLflow run ID to end
        status: Run status (FINISHED, FAILED, KILLED)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id
    """
    import mlflow
-    from mlflow.tracking import MlflowClient
    from mlflow.entities import RunStatus
+    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    status_map = {
        "FINISHED": RunStatus.FINISHED,
        "FAILED": RunStatus.FAILED,
        "KILLED": RunStatus.KILLED,
    }

    run_status = status_map.get(status.upper(), RunStatus.FINISHED)
    client.set_terminated(run_id, status=run_status)

    return run_id

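For illustration (not part of this commit), a sketch of the status handling above; the status string is upper-cased and unrecognized values fall back to FINISHED:

# Illustrative sketch only -- run_id comes from create_mlflow_run.
end_mlflow_run(run_id=run_id, status="failed")     # maps to RunStatus.FAILED
end_mlflow_run(run_id=run_id, status="cancelled")  # unknown, falls back to RunStatus.FINISHED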
@@ -322,10 +324,10 @@ def log_training_metrics(
) -> str:
    """
    Log comprehensive training metrics for ML models.

    Designed for use with QLoRA training, voice training, and other
    ML training pipelines in the llm-workflows repository.

    Args:
        run_id: The MLflow run ID to log to
        model_type: Type of model (llm, stt, tts, embeddings)
@@ -333,19 +335,20 @@ def log_training_metrics(
        final_metrics: Final training metrics
        model_path: Path to saved model (if applicable)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
+    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient
-    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log training config as params
    flat_config = {}
    for key, value in training_config.items():
@@ -353,29 +356,29 @@ def log_training_metrics(
            flat_config[f"config.{key}"] = json.dumps(value)[:500]
        else:
            flat_config[f"config.{key}"] = str(value)[:500]

    for key, value in flat_config.items():
        client.log_param(run_id, key, value)

    # Log model type tag
    client.set_tag(run_id, "model.type", model_type)

    # Log metrics
    for key, value in final_metrics.items():
        client.log_metric(run_id, key, float(value))

    # Log full config as artifact
    with tempfile.TemporaryDirectory() as tmpdir:
        config_path = Path(tmpdir) / "training_config.json"
        with open(config_path, 'w') as f:
            json.dump(training_config, f, indent=2)
        client.log_artifact(run_id, str(config_path))

    # Log model path if provided
    if model_path:
        client.log_param(run_id, "model.path", model_path)
        client.set_tag(run_id, "model.saved", "true")

    return run_id

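For illustration (not part of this commit), a sketch of invoking the helper above; the config keys, metric names, and model path are hypothetical, and the nested optimizer dict exercises the json.dumps branch of the flattening loop:

# Illustrative sketch only -- config keys, metric names, and path are hypothetical.
log_training_metrics(
    run_id=run_id,
    model_type="llm",  # one of: llm, stt, tts, embeddings
    training_config={"lora_rank": 16, "epochs": 3, "optimizer": {"lr": 2e-4}},
    final_metrics={"train_loss": 0.12, "eval_loss": 0.18},
    model_path="pvc://models/qlora-llm/v1",
)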
@@ -397,9 +400,9 @@ def log_document_ingestion_metrics(
) -> str:
    """
    Log document ingestion pipeline metrics.

    Designed for use with the document_ingestion_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        source_url: URL of the source document
@@ -411,16 +414,16 @@ def log_document_ingestion_metrics(
        chunk_size: Chunk size in tokens
        chunk_overlap: Chunk overlap in tokens
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log params
    params = {
        "source_url": source_url[:500],
@@ -431,7 +434,7 @@ def log_document_ingestion_metrics(
    }
    for key, value in params.items():
        client.log_param(run_id, key, value)

    # Log metrics
    metrics = {
        "chunks_created": chunks_created,
@@ -441,11 +444,11 @@ def log_document_ingestion_metrics(
    }
    for key, value in metrics.items():
        client.log_metric(run_id, key, float(value))

    # Set pipeline type tag
    client.set_tag(run_id, "pipeline.type", "document-ingestion")
    client.set_tag(run_id, "milvus.collection", collection_name)

    return run_id

@@ -463,9 +466,9 @@ def log_evaluation_results(
) -> str:
    """
    Log model evaluation results.

    Designed for use with the evaluation_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        model_name: Name of the evaluated model
@@ -473,27 +476,28 @@ def log_evaluation_results(
        metrics: Evaluation metrics (accuracy, etc.)
        sample_results: Optional sample predictions
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
+    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient
-    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log params
    client.log_param(run_id, "eval.model_name", model_name)
    client.log_param(run_id, "eval.dataset", dataset_name)

    # Log metrics
    for key, value in metrics.items():
        client.log_metric(run_id, f"eval.{key}", float(value))

    # Log sample results as artifact
    if sample_results:
        with tempfile.TemporaryDirectory() as tmpdir:
@@ -501,13 +505,13 @@ def log_evaluation_results(
            with open(results_path, 'w') as f:
                json.dump(sample_results, f, indent=2)
            client.log_artifact(run_id, str(results_path))

    # Set tags
    client.set_tag(run_id, "pipeline.type", "evaluation")
    client.set_tag(run_id, "model.name", model_name)

    # Determine if passed
    passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
    client.set_tag(run_id, "eval.passed", str(passed))

    return run_id
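For illustration (not part of this commit), a sketch of the pass/fail fallback above; without an explicit "pass" metric, accuracy >= 0.7 decides the eval.passed tag:

# Illustrative sketch only -- the dataset name is hypothetical.
log_evaluation_results(
    run_id=run_id,
    model_name="whisper-finetuned",
    dataset_name="voice-eval-set",
    metrics={"accuracy": 0.93},  # no "pass" key, so passed = (0.93 >= 0.7) = True
)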
@@ -17,7 +17,7 @@ Usage:
        promote_model_to_production,
        generate_kserve_manifest,
    )

    # Register a new model version
    model_version = register_model_for_kserve(
        model_name="whisper-finetuned",
@@ -28,7 +28,7 @@ Usage:
            "container_image": "ghcr.io/my-org/whisper:v2",
        }
    )

    # Generate KServe manifest for deployment
    manifest = generate_kserve_manifest(
        model_name="whisper-finetuned",
@@ -36,18 +36,15 @@ Usage:
    )
"""

-import os
-import json
-import yaml
import logging
-from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional

import mlflow
-from mlflow.tracking import MlflowClient
+import yaml
from mlflow.entities.model_registry import ModelVersion

-from .client import get_mlflow_client, MLflowConfig
+from .client import get_mlflow_client

logger = logging.getLogger(__name__)

@@ -55,15 +52,15 @@ logger = logging.getLogger(__name__)
@dataclass
class KServeConfig:
    """Configuration for KServe deployment."""

    # Runtime/container configuration
    runtime: str = "kserve-huggingface"  # kserve-huggingface, kserve-custom, etc.
    container_image: Optional[str] = None
    container_port: int = 8080

    # Protocol configuration
    protocol: str = "v2"  # v1, v2, grpc

    # Resource requests/limits
    cpu_request: str = "1"
    cpu_limit: str = "4"
@@ -71,22 +68,22 @@ class KServeConfig:
    memory_limit: str = "16Gi"
    gpu_count: int = 0
    gpu_type: str = "nvidia.com/gpu"  # or amd.com/gpu for ROCm

    # Storage configuration
    storage_uri: Optional[str] = None  # s3://, pvc://, gs://

    # Scaling configuration
    min_replicas: int = 1
    max_replicas: int = 1
    scale_target: int = 10  # Target concurrent requests for scaling

    # Serving configuration
    timeout_seconds: int = 300
    batch_size: int = 1

    # Additional environment variables
    env_vars: Dict[str, str] = field(default_factory=dict)

    def as_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for MLflow tags."""
        return {
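For illustration (not part of this commit), a sketch of building a GPU-backed config from the dataclass above; all field values here are hypothetical:

# Illustrative sketch only -- values are hypothetical.
llm_config = KServeConfig(
    gpu_count=1,
    memory_limit="24Gi",
    storage_uri="pvc://models/llm",
    env_vars={"HF_HOME": "/cache"},
)
tags = llm_config.as_dict()  # flattened for use as MLflow model-version tags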
@@ -165,7 +162,7 @@ def register_model_for_kserve(
) -> ModelVersion:
    """
    Register a model in MLflow Model Registry with KServe metadata.

    Args:
        model_name: Name for the registered model
        model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
@@ -175,16 +172,16 @@ def register_model_for_kserve(
        kserve_config: KServe deployment configuration
        tags: Additional tags for the model version
        tracking_uri: Override default tracking URI

    Returns:
        The created ModelVersion object
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get or use preset KServe config
    if kserve_config is None:
        kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())

    # Ensure registered model exists
    try:
        client.get_registered_model(model_name)
@@ -198,7 +195,7 @@ def register_model_for_kserve(
            }
        )
        logger.info(f"Created registered model: {model_name}")

    # Create model version
    model_version = client.create_model_version(
        name=model_name,
@@ -211,12 +208,12 @@ def register_model_for_kserve(
            **kserve_config.as_dict(),
        }
    )

    logger.info(
        f"Registered model version {model_version.version} "
        f"for {model_name} (type: {model_type})"
    )

    return model_version

@@ -229,19 +226,19 @@ def promote_model_to_stage(
) -> ModelVersion:
    """
    Promote a model version to a new stage.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        stage: Target stage (Staging, Production, Archived)
        archive_existing: If True, archive existing versions in target stage
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Transition to new stage
    model_version = client.transition_model_version_stage(
        name=model_name,
@@ -249,9 +246,9 @@ def promote_model_to_stage(
        stage=stage,
        archive_existing_versions=archive_existing,
    )

    logger.info(f"Promoted {model_name} v{version} to {stage}")

    return model_version


@@ -262,12 +259,12 @@ def promote_model_to_production(
) -> ModelVersion:
    """
    Promote a model version directly to Production.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
@@ -286,18 +283,18 @@ def get_production_model(
) -> Optional[ModelVersion]:
    """
    Get the current Production model version.

    Args:
        model_name: Name of the registered model
        tracking_uri: Override default tracking URI

    Returns:
        The Production ModelVersion, or None if none exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    versions = client.get_latest_versions(model_name, stages=["Production"])

    return versions[0] if versions else None

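Since get_latest_versions returns an empty list when nothing occupies the stage, callers can treat the None return as "nothing promoted yet"; a sketch:

    from mlflow_utils.model_registry import get_production_model

    prod = get_production_model("whisper-finetuned")
    if prod is None:
        raise RuntimeError("whisper-finetuned has no Production version")
    print(prod.version, prod.source)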
@@ -308,17 +305,17 @@ def get_model_kserve_config(
) -> KServeConfig:
    """
    Get KServe configuration from a registered model version.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        tracking_uri: Override default tracking URI

    Returns:
        KServeConfig populated from model tags
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if version:
        model_version = client.get_model_version(model_name, str(version))
    else:
@@ -326,9 +323,9 @@ def get_model_kserve_config(
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version

    tags = model_version.tags

    return KServeConfig(
        runtime=tags.get("kserve.runtime", "kserve-huggingface"),
        protocol=tags.get("kserve.protocol", "v2"),
@@ -352,7 +349,7 @@ def generate_kserve_manifest(
) -> Dict[str, Any]:
    """
    Generate a KServe InferenceService manifest from a registered model.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
@@ -360,12 +357,12 @@ def generate_kserve_manifest(
        service_name: Name for the InferenceService (defaults to model_name)
        extra_annotations: Additional annotations for the service
        tracking_uri: Override default tracking URI

    Returns:
        KServe InferenceService manifest as a dictionary
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get model version
    if version:
        model_version = client.get_model_version(model_name, str(version))
@@ -375,13 +372,13 @@ def generate_kserve_manifest(
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version
        version = int(model_version.version)

    # Get KServe config
    config = get_model_kserve_config(model_name, version, tracking_uri)
    model_type = model_version.tags.get("model.type", "custom")

    svc_name = service_name or model_name.lower().replace("_", "-")

    # Build manifest
    manifest = {
        "apiVersion": "serving.kserve.io/v1beta1",
@@ -409,10 +406,10 @@ def generate_kserve_manifest(
            },
        },
    }

    # Configure predictor based on runtime
    predictor = manifest["spec"]["predictor"]

    if config.container_image:
        # Custom container
        predictor["containers"] = [{
@@ -434,16 +431,16 @@ def generate_kserve_manifest(
                for k, v in config.env_vars.items()
            ],
        }]

        # Add GPU if needed
        if config.gpu_count > 0:
            predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
            predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)

    else:
        # Standard KServe runtime
        storage_uri = config.storage_uri or model_version.source

        predictor["model"] = {
            "modelFormat": {"name": "huggingface"},
            "protocolVersion": config.protocol,
@@ -459,11 +456,11 @@ def generate_kserve_manifest(
                },
            },
        }

        if config.gpu_count > 0:
            predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
            predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)

    return manifest

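A sketch of driving the manifest builder and inspecting which predictor shape it chose; manifest keys other than those visible in the hunks above are assumptions.

    import json

    from mlflow_utils.model_registry import generate_kserve_manifest

    manifest = generate_kserve_manifest("whisper-finetuned", version=3)
    # "containers" for a custom image, "model" for the standard runtime
    print(json.dumps(manifest["spec"]["predictor"], indent=2))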
@@ -476,14 +473,14 @@ def generate_kserve_yaml(
) -> str:
    """
    Generate a KServe InferenceService manifest as YAML.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace
        output_path: If provided, write YAML to this path
        tracking_uri: Override default tracking URI

    Returns:
        YAML string of the manifest
    """
@@ -493,14 +490,14 @@ def generate_kserve_yaml(
        namespace=namespace,
        tracking_uri=tracking_uri,
    )

    yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)

    if output_path:
        with open(output_path, 'w') as f:
            f.write(yaml_str)
        logger.info(f"Wrote KServe manifest to {output_path}")

    return yaml_str

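The YAML wrapper then reduces deployment to one call plus kubectl; the path and namespace here are illustrative.

    from mlflow_utils.model_registry import generate_kserve_yaml

    # Writes the manifest to disk and returns the same YAML string
    generate_kserve_yaml(
        "whisper-finetuned",
        namespace="ai-ml",
        output_path="/tmp/whisper-isvc.yaml",
    )

The resulting file can be applied with kubectl apply -f /tmp/whisper-isvc.yaml.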
@@ -511,17 +508,17 @@ def list_model_versions(
) -> List[Dict[str, Any]]:
    """
    List all versions of a registered model.

    Args:
        model_name: Name of the registered model
        stages: Filter by stages (None for all)
        tracking_uri: Override default tracking URI

    Returns:
        List of model version info dictionaries
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if stages:
        versions = client.get_latest_versions(model_name, stages=stages)
    else:
@@ -529,7 +526,7 @@ def list_model_versions(
        versions = []
        for mv in client.search_model_versions(f"name='{model_name}'"):
            versions.append(mv)

    return [
        {
            "version": mv.version,
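The f-string filter passed to search_model_versions uses MLflow's search syntax; a sketch of both branches, where the returned dictionaries carry at least the "version" key shown above:

    from mlflow_utils.model_registry import list_model_versions

    all_versions = list_model_versions("whisper-finetuned")
    staged = list_model_versions("whisper-finetuned", stages=["Staging", "Production"])
    for info in staged:
        print(info["version"])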
@@ -5,19 +5,17 @@ Provides a high-level interface for logging experiments, parameters,
 metrics, and artifacts from Kubeflow Pipeline components.
 """

-import os
-import json
-import time
 import logging
-from pathlib import Path
-from typing import Optional, Dict, Any, List, Union
+import os
+import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from typing import Any, Dict, Optional, Union

 import mlflow
 from mlflow.tracking import MlflowClient

-from .client import get_mlflow_client, ensure_experiment, MLflowConfig
+from .client import MLflowConfig, ensure_experiment, get_mlflow_client

 logger = logging.getLogger(__name__)

@@ -30,7 +28,7 @@ class PipelineMetadata:
    run_name: Optional[str] = None
    component_name: Optional[str] = None
    namespace: str = "ai-ml"

    # KFP-specific metadata (populated from environment if available)
    kfp_run_id: Optional[str] = field(
        default_factory=lambda: os.environ.get("KFP_RUN_ID")
@@ -38,7 +36,7 @@ class PipelineMetadata:
    kfp_pod_name: Optional[str] = field(
        default_factory=lambda: os.environ.get("KFP_POD_NAME")
    )

    def as_tags(self) -> Dict[str, str]:
        """Convert metadata to MLflow tags."""
        tags = {
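A sketch of building the metadata object outside of KFP, where the KFP_* environment variables are absent and those fields default to None; this assumes the fields not shown in these hunks also have defaults.

    meta = PipelineMetadata(
        run_name="batch-ingestion-2024-01",
        component_name="chunker",
        namespace="ai-ml",
    )
    print(meta.as_tags())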
@@ -60,34 +58,34 @@ class PipelineMetadata:
class MLflowTracker:
    """
    MLflow experiment tracker for Kubeflow Pipeline components.

    Example usage in a KFP component:

        from mlflow_utils import MLflowTracker

        tracker = MLflowTracker(
            experiment_name="document-ingestion",
            run_name="batch-ingestion-2024-01"
        )

        with tracker.start_run() as run:
            tracker.log_params({
                "chunk_size": 500,
                "overlap": 50,
                "embeddings_model": "bge-small-en-v1.5"
            })

            # ... do work ...

            tracker.log_metrics({
                "documents_processed": 100,
                "chunks_created": 2500,
                "processing_time_seconds": 120.5
            })

            tracker.log_artifact("/path/to/output.json")
    """

    def __init__(
        self,
        experiment_name: str,
@@ -98,7 +96,7 @@ class MLflowTracker:
    ):
        """
        Initialize the MLflow tracker.

        Args:
            experiment_name: Name of the MLflow experiment
            run_name: Optional name for this run
@@ -112,22 +110,22 @@ class MLflowTracker:
        self.pipeline_metadata = pipeline_metadata
        self.user_tags = tags or {}
        self.tracking_uri = tracking_uri

        self.client: Optional[MlflowClient] = None
        self.run: Optional[mlflow.ActiveRun] = None
        self.run_id: Optional[str] = None
        self._start_time: Optional[float] = None

    def _get_all_tags(self) -> Dict[str, str]:
        """Combine all tags for the run."""
        tags = self.config.default_tags.copy()

        if self.pipeline_metadata:
            tags.update(self.pipeline_metadata.as_tags())

        tags.update(self.user_tags)
        return tags

    @contextmanager
    def start_run(
        self,
@@ -136,11 +134,11 @@ class MLflowTracker:
    ):
        """
        Start an MLflow run as a context manager.

        Args:
            nested: If True, create a nested run under the current active run
            parent_run_id: Explicit parent run ID for nested runs

        Yields:
            The MLflow run object
        """
@@ -148,12 +146,12 @@ class MLflowTracker:
            tracking_uri=self.tracking_uri,
            configure_global=True
        )

        # Ensure experiment exists
        experiment_id = ensure_experiment(self.experiment_name)

        self._start_time = time.time()

        try:
            # Start the run
            self.run = mlflow.start_run(
@@ -163,14 +161,14 @@ class MLflowTracker:
                tags=self._get_all_tags(),
            )
            self.run_id = self.run.info.run_id

            logger.info(
                f"Started MLflow run '{self.run_name}' "
                f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
            )

            yield self.run

        except Exception as e:
            logger.error(f"MLflow run failed: {e}")
            if self.run:
@@ -185,22 +183,22 @@ class MLflowTracker:
                    mlflow.log_metric("run_duration_seconds", duration)
                except Exception:
                    pass

            # End the run
            mlflow.end_run()
            logger.info(f"Ended MLflow run '{self.run_name}'")

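Note the cleanup path above: run_duration_seconds is logged and mlflow.end_run() is called even when the body raises, so a failing component still leaves a properly closed run behind. A sketch, assuming (as the truncated except branch suggests) that the error is re-raised after being recorded:

    from mlflow_utils import MLflowTracker

    tracker = MLflowTracker(experiment_name="document-ingestion")

    try:
        with tracker.start_run():
            raise RuntimeError("simulated component failure")
    except RuntimeError:
        pass  # the run was already ended, with run_duration_seconds recorded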
    def log_params(self, params: Dict[str, Any]) -> None:
        """
        Log parameters to the current run.

        Args:
            params: Dictionary of parameter names to values
        """
        if not self.run:
            logger.warning("No active run, skipping log_params")
            return

        # MLflow has limits on param values, truncate if needed
        cleaned_params = {}
        for key, value in params.items():
@@ -208,14 +206,14 @@ class MLflowTracker:
            if len(str_value) > 500:
                str_value = str_value[:497] + "..."
            cleaned_params[key] = str_value

        mlflow.log_params(cleaned_params)
        logger.debug(f"Logged {len(params)} parameters")

    def log_param(self, key: str, value: Any) -> None:
        """Log a single parameter."""
        self.log_params({key: value})

    def log_metrics(
        self,
        metrics: Dict[str, Union[float, int]],
@@ -223,7 +221,7 @@ class MLflowTracker:
    ) -> None:
        """
        Log metrics to the current run.

        Args:
            metrics: Dictionary of metric names to values
            step: Optional step number for time-series metrics
@@ -231,10 +229,10 @@ class MLflowTracker:
        if not self.run:
            logger.warning("No active run, skipping log_metrics")
            return

        mlflow.log_metrics(metrics, step=step)
        logger.debug(f"Logged {len(metrics)} metrics")

    def log_metric(
        self,
        key: str,
@@ -243,7 +241,7 @@ class MLflowTracker:
    ) -> None:
        """Log a single metric."""
        self.log_metrics({key: value}, step=step)

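The step argument makes log_metrics suitable for time-series curves, e.g. per-epoch values plotted by the MLflow UI; a sketch:

    from mlflow_utils import MLflowTracker

    tracker = MLflowTracker(experiment_name="evaluation")

    with tracker.start_run():
        for epoch in range(3):
            tracker.log_metrics({"train_loss": 1.0 / (epoch + 1)}, step=epoch)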
    def log_artifact(
        self,
        local_path: str,
@@ -251,7 +249,7 @@ class MLflowTracker:
    ) -> None:
        """
        Log an artifact file to the current run.

        Args:
            local_path: Path to the local file to log
            artifact_path: Optional destination path within the artifact store
@@ -259,10 +257,10 @@ class MLflowTracker:
        if not self.run:
            logger.warning("No active run, skipping log_artifact")
            return

        mlflow.log_artifact(local_path, artifact_path)
        logger.info(f"Logged artifact: {local_path}")

    def log_artifacts(
        self,
        local_dir: str,
@@ -270,7 +268,7 @@ class MLflowTracker:
    ) -> None:
        """
        Log all files in a directory as artifacts.

        Args:
            local_dir: Path to the local directory
            artifact_path: Optional destination path within the artifact store
@@ -278,10 +276,10 @@ class MLflowTracker:
        if not self.run:
            logger.warning("No active run, skipping log_artifacts")
            return

        mlflow.log_artifacts(local_dir, artifact_path)
        logger.info(f"Logged artifacts from: {local_dir}")

    def log_dict(
        self,
        data: Dict[str, Any],
@@ -290,7 +288,7 @@ class MLflowTracker:
    ) -> None:
        """
        Log a dictionary as a JSON artifact.

        Args:
            data: Dictionary to log
            filename: Name for the JSON file
@@ -299,14 +297,14 @@ class MLflowTracker:
        if not self.run:
            logger.warning("No active run, skipping log_dict")
            return

        # Ensure .json extension
        if not filename.endswith(".json"):
            filename += ".json"

        mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
        logger.debug(f"Logged dict as: {filename}")

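Because log_dict appends the .json suffix when missing, both calls below land as JSON artifacts; the artifact_path variant nests the file under a subdirectory. A sketch:

    from mlflow_utils import MLflowTracker

    tracker = MLflowTracker(experiment_name="document-ingestion")

    with tracker.start_run():
        # stored as ingestion-stats.json at the artifact root
        tracker.log_dict({"chunks": 2500}, "ingestion-stats")
        # stored under the reports/ artifact subdirectory
        tracker.log_dict({"chunks": 2500}, "stats.json", artifact_path="reports")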
    def log_model_info(
        self,
        model_type: str,
@@ -317,7 +315,7 @@ class MLflowTracker:
    ) -> None:
        """
        Log model information as parameters and tags.

        Args:
            model_type: Type of model (e.g., "llm", "embedding", "stt")
            model_name: Name/identifier of the model
@@ -335,13 +333,13 @@ class MLflowTracker:
        if extra_info:
            for key, value in extra_info.items():
                params[f"model.{key}"] = value

        self.log_params(params)

        # Also set as tags for easier filtering
        mlflow.set_tag("model.type", model_type)
        mlflow.set_tag("model.name", model_name)

    def log_dataset_info(
        self,
        name: str,
@@ -351,7 +349,7 @@ class MLflowTracker:
    ) -> None:
        """
        Log dataset information.

        Args:
            name: Dataset name
            source: Dataset source (URL, path, etc.)
@@ -367,26 +365,26 @@ class MLflowTracker:
        if extra_info:
            for key, value in extra_info.items():
                params[f"dataset.{key}"] = value

        self.log_params(params)

    def set_tag(self, key: str, value: str) -> None:
        """Set a single tag on the run."""
        if self.run:
            mlflow.set_tag(key, value)

    def set_tags(self, tags: Dict[str, str]) -> None:
        """Set multiple tags on the run."""
        if self.run:
            mlflow.set_tags(tags)

    @property
    def artifact_uri(self) -> Optional[str]:
        """Get the artifact URI for the current run."""
        if self.run:
            return self.run.info.artifact_uri
        return None

    @property
    def experiment_id(self) -> Optional[str]:
        """Get the experiment ID for the current run."""
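A sketch of the model/dataset logging pair inside an evaluation component; the extra_info keys and names are illustrative:

    from mlflow_utils import MLflowTracker

    tracker = MLflowTracker(experiment_name="voice-evaluation")

    with tracker.start_run():
        tracker.log_model_info(
            model_type="stt",
            model_name="whisper-finetuned",
            extra_info={"quantization": "int8"},
        )
        tracker.log_dataset_info(
            name="voice-eval-set",
            source="s3://datasets/voice-eval",
        )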
@@ -8,12 +8,12 @@ import pytest
 def test_package_imports() -> None:
     """All public symbols are importable."""
     from mlflow_utils import ( # noqa: F401
+        InferenceMetricsTracker,
         MLflowConfig,
         MLflowTracker,
-        InferenceMetricsTracker,
+        ensure_experiment,
         get_mlflow_client,
         get_tracking_uri,
-        ensure_experiment,
     )


@@ -48,8 +48,8 @@ def test_kfp_components_importable() -> None:

 def test_model_registry_importable() -> None:
     from mlflow_utils.model_registry import ( # noqa: F401
-        register_model_for_kserve,
         generate_kserve_manifest,
+        register_model_for_kserve,
     )
