fix: resolve all ruff lint errors
This commit is contained in:
@@ -20,17 +20,17 @@ Usage:
|
||||
"""
|
||||
|
||||
from .client import (
|
||||
MLflowConfig,
|
||||
ensure_experiment,
|
||||
get_mlflow_client,
|
||||
get_tracking_uri,
|
||||
ensure_experiment,
|
||||
MLflowConfig,
|
||||
)
|
||||
from .tracker import MLflowTracker
|
||||
from .inference_tracker import InferenceMetricsTracker
|
||||
from .tracker import MLflowTracker
|
||||
|
||||
__all__ = [
|
||||
"get_mlflow_client",
|
||||
"get_tracking_uri",
|
||||
"get_tracking_uri",
|
||||
"ensure_experiment",
|
||||
"MLflowConfig",
|
||||
"MLflowTracker",
|
||||
|
||||
@@ -7,21 +7,21 @@ Command-line interface for querying and comparing MLflow experiments.
|
||||
Usage:
|
||||
# Compare recent runs in an experiment
|
||||
python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
|
||||
|
||||
|
||||
# Get best run by metric
|
||||
python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
|
||||
|
||||
|
||||
# Generate performance report
|
||||
python -m mlflow_utils.cli report --service chat-handler --hours 24
|
||||
|
||||
|
||||
# Check model promotion criteria
|
||||
python -m mlflow_utils.cli promote --model whisper-finetuned \\
|
||||
--experiment voice-evaluation \\
|
||||
--criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
|
||||
|
||||
|
||||
# List experiments
|
||||
python -m mlflow_utils.cli list-experiments
|
||||
|
||||
|
||||
# Query runs
|
||||
python -m mlflow_utils.cli query --experiment chat-inference \\
|
||||
--filter "metrics.total_latency_mean < 1.0" --limit 10
|
||||
@@ -30,19 +30,16 @@ Usage:
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from .client import get_mlflow_client, health_check
|
||||
from .experiment_comparison import (
|
||||
ExperimentAnalyzer,
|
||||
compare_experiments,
|
||||
promotion_recommendation,
|
||||
get_inference_performance_report,
|
||||
promotion_recommendation,
|
||||
)
|
||||
from .model_registry import (
|
||||
list_model_versions,
|
||||
get_production_model,
|
||||
generate_kserve_yaml,
|
||||
list_model_versions,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,7 +54,7 @@ def cmd_list_experiments(args):
|
||||
"""List all experiments."""
|
||||
client = get_mlflow_client(tracking_uri=args.tracking_uri)
|
||||
experiments = client.search_experiments()
|
||||
|
||||
|
||||
print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
|
||||
print("-" * 80)
|
||||
for exp in experiments:
|
||||
@@ -70,13 +67,13 @@ def cmd_compare(args):
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
|
||||
if args.run_ids:
|
||||
run_ids = args.run_ids.split(",")
|
||||
comparison = analyzer.compare_runs(run_ids=run_ids)
|
||||
else:
|
||||
comparison = analyzer.compare_runs(n_recent=args.runs)
|
||||
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(comparison.to_dict(), indent=2, default=str))
|
||||
else:
|
||||
@@ -89,17 +86,17 @@ def cmd_best(args):
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
|
||||
best_run = analyzer.get_best_run(
|
||||
metric=args.metric,
|
||||
minimize=args.minimize,
|
||||
filter_string=args.filter or "",
|
||||
)
|
||||
|
||||
|
||||
if not best_run:
|
||||
print(f"No runs found with metric '{args.metric}'")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
result = {
|
||||
"run_id": best_run.info.run_id,
|
||||
"run_name": best_run.info.run_name,
|
||||
@@ -107,7 +104,7 @@ def cmd_best(args):
|
||||
"all_metrics": dict(best_run.data.metrics),
|
||||
"params": dict(best_run.data.params),
|
||||
}
|
||||
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
@@ -122,12 +119,12 @@ def cmd_summary(args):
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
|
||||
summary = analyzer.get_metrics_summary(
|
||||
hours=args.hours,
|
||||
metrics=args.metrics.split(",") if args.metrics else None,
|
||||
)
|
||||
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(summary, indent=2))
|
||||
else:
|
||||
@@ -146,7 +143,7 @@ def cmd_report(args):
|
||||
hours=args.hours,
|
||||
tracking_uri=args.tracking_uri,
|
||||
)
|
||||
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(report, indent=2))
|
||||
else:
|
||||
@@ -154,18 +151,18 @@ def cmd_report(args):
|
||||
print(f"Period: Last {report['period_hours']} hours")
|
||||
print(f"Generated: {report['generated_at']}")
|
||||
print()
|
||||
|
||||
|
||||
if report["latency"]:
|
||||
print("Latency Metrics:")
|
||||
for metric, stats in report["latency"].items():
|
||||
if "mean" in stats:
|
||||
print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
|
||||
|
||||
|
||||
if report["rag"]:
|
||||
print("\nRAG Usage:")
|
||||
for metric, stats in report["rag"].items():
|
||||
print(f" {metric}: {stats.get('mean', 'N/A')}")
|
||||
|
||||
|
||||
if report["errors"]:
|
||||
print("\nError Rates:")
|
||||
for metric, stats in report["errors"].items():
|
||||
@@ -183,14 +180,14 @@ def cmd_promote(args):
|
||||
metric, value = criterion.split(op)
|
||||
criteria[metric.strip()] = (op, float(value.strip()))
|
||||
break
|
||||
|
||||
|
||||
rec = promotion_recommendation(
|
||||
model_name=args.model,
|
||||
experiment_name=args.experiment,
|
||||
criteria=criteria,
|
||||
tracking_uri=args.tracking_uri,
|
||||
)
|
||||
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(rec.to_dict(), indent=2))
|
||||
else:
|
||||
@@ -208,12 +205,12 @@ def cmd_query(args):
|
||||
args.experiment,
|
||||
tracking_uri=args.tracking_uri
|
||||
)
|
||||
|
||||
|
||||
runs = analyzer.search_runs(
|
||||
filter_string=args.filter or "",
|
||||
max_results=args.limit,
|
||||
)
|
||||
|
||||
|
||||
if args.json:
|
||||
result = [
|
||||
{
|
||||
@@ -237,20 +234,21 @@ def cmd_query(args):
|
||||
def cmd_models(args):
|
||||
"""List registered models."""
|
||||
client = get_mlflow_client(tracking_uri=args.tracking_uri)
|
||||
|
||||
|
||||
if args.model:
|
||||
versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
|
||||
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(versions, indent=2, default=str))
|
||||
else:
|
||||
print(f"Model: {args.model}")
|
||||
for v in versions:
|
||||
print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}")
|
||||
desc = v["description"][:50] if v["description"] else "No description"
|
||||
print(f" v{v['version']} ({v['stage']}): {desc}")
|
||||
else:
|
||||
# List all models
|
||||
models = client.search_registered_models()
|
||||
|
||||
|
||||
if args.json:
|
||||
result = [{"name": m.name, "description": m.description} for m in models]
|
||||
print(json.dumps(result, indent=2))
|
||||
@@ -271,7 +269,7 @@ def cmd_kserve(args):
|
||||
output_path=args.output,
|
||||
tracking_uri=args.tracking_uri,
|
||||
)
|
||||
|
||||
|
||||
if not args.output:
|
||||
print(yaml_str)
|
||||
|
||||
@@ -281,7 +279,7 @@ def main():
|
||||
description="MLflow Experiment CLI",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--tracking-uri",
|
||||
default=None,
|
||||
@@ -292,24 +290,24 @@ def main():
|
||||
action="store_true",
|
||||
help="Output as JSON",
|
||||
)
|
||||
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Commands")
|
||||
|
||||
|
||||
# health
|
||||
health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
|
||||
health_parser.set_defaults(func=cmd_health)
|
||||
|
||||
|
||||
# list-experiments
|
||||
list_parser = subparsers.add_parser("list-experiments", help="List experiments")
|
||||
list_parser.set_defaults(func=cmd_list_experiments)
|
||||
|
||||
|
||||
# compare
|
||||
compare_parser = subparsers.add_parser("compare", help="Compare runs")
|
||||
compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
|
||||
compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
|
||||
compare_parser.set_defaults(func=cmd_compare)
|
||||
|
||||
|
||||
# best
|
||||
best_parser = subparsers.add_parser("best", help="Find best run by metric")
|
||||
best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
@@ -317,39 +315,39 @@ def main():
|
||||
best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
|
||||
best_parser.add_argument("--filter", "-f", help="Filter string")
|
||||
best_parser.set_defaults(func=cmd_best)
|
||||
|
||||
|
||||
# summary
|
||||
summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
|
||||
summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
|
||||
summary_parser.add_argument("--metrics", help="Comma-separated metric names")
|
||||
summary_parser.set_defaults(func=cmd_summary)
|
||||
|
||||
|
||||
# report
|
||||
report_parser = subparsers.add_parser("report", help="Generate performance report")
|
||||
report_parser.add_argument("--service", "-s", required=True, help="Service name")
|
||||
report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
|
||||
report_parser.set_defaults(func=cmd_report)
|
||||
|
||||
|
||||
# promote
|
||||
promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
|
||||
promote_parser.add_argument("--model", "-m", required=True, help="Model name")
|
||||
promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
|
||||
promote_parser.set_defaults(func=cmd_promote)
|
||||
|
||||
|
||||
# query
|
||||
query_parser = subparsers.add_parser("query", help="Query runs")
|
||||
query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||
query_parser.add_argument("--filter", "-f", help="MLflow filter string")
|
||||
query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
|
||||
query_parser.set_defaults(func=cmd_query)
|
||||
|
||||
|
||||
# models
|
||||
models_parser = subparsers.add_parser("models", help="List registered models")
|
||||
models_parser.add_argument("--model", "-m", help="Specific model name")
|
||||
models_parser.set_defaults(func=cmd_models)
|
||||
|
||||
|
||||
# kserve
|
||||
kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
|
||||
kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
|
||||
@@ -357,13 +355,13 @@ def main():
|
||||
kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
|
||||
kserve_parser.add_argument("--output", "-o", help="Output file path")
|
||||
kserve_parser.set_defaults(func=cmd_kserve)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
args.func(args)
|
||||
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ Provides a configured MLflow client for all integrations in the LLM workflows.
|
||||
Supports both in-cluster and external access patterns.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class MLflowConfig:
|
||||
"""Configuration for MLflow integration."""
|
||||
|
||||
|
||||
# Tracking server URIs
|
||||
tracking_uri: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -33,7 +33,7 @@ class MLflowConfig:
|
||||
"https://mlflow.lab.daviestechlabs.io"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Artifact storage (NFS PVC mount)
|
||||
artifact_location: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -41,7 +41,7 @@ class MLflowConfig:
|
||||
"/mlflow/artifacts"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Default experiment settings
|
||||
default_experiment: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -49,7 +49,7 @@ class MLflowConfig:
|
||||
"llm-workflows"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Service identification
|
||||
service_name: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -57,10 +57,10 @@ class MLflowConfig:
|
||||
"unknown-service"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Additional tags to add to all runs
|
||||
default_tags: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
"""Add default tags based on environment."""
|
||||
env_tags = {
|
||||
@@ -74,10 +74,10 @@ class MLflowConfig:
|
||||
def get_tracking_uri(external: bool = False) -> str:
|
||||
"""
|
||||
Get the appropriate MLflow tracking URI.
|
||||
|
||||
|
||||
Args:
|
||||
external: If True, return the external URI for outside-cluster access
|
||||
|
||||
|
||||
Returns:
|
||||
The MLflow tracking URI string
|
||||
"""
|
||||
@@ -91,20 +91,20 @@ def get_mlflow_client(
|
||||
) -> MlflowClient:
|
||||
"""
|
||||
Get a configured MLflow client.
|
||||
|
||||
|
||||
Args:
|
||||
tracking_uri: Override the default tracking URI
|
||||
configure_global: If True, also set mlflow.set_tracking_uri()
|
||||
|
||||
|
||||
Returns:
|
||||
Configured MlflowClient instance
|
||||
"""
|
||||
uri = tracking_uri or get_tracking_uri()
|
||||
|
||||
|
||||
if configure_global:
|
||||
mlflow.set_tracking_uri(uri)
|
||||
logger.info(f"MLflow tracking URI set to: {uri}")
|
||||
|
||||
|
||||
client = MlflowClient(tracking_uri=uri)
|
||||
return client
|
||||
|
||||
@@ -116,21 +116,21 @@ def ensure_experiment(
|
||||
) -> str:
|
||||
"""
|
||||
Ensure an experiment exists, creating it if necessary.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the experiment
|
||||
artifact_location: Override default artifact location
|
||||
tags: Additional tags for the experiment
|
||||
|
||||
|
||||
Returns:
|
||||
The experiment ID
|
||||
"""
|
||||
config = MLflowConfig()
|
||||
client = get_mlflow_client()
|
||||
|
||||
|
||||
# Check if experiment exists
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
|
||||
|
||||
if experiment is None:
|
||||
# Create the experiment
|
||||
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
|
||||
@@ -143,7 +143,7 @@ def ensure_experiment(
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
|
||||
|
||||
|
||||
return experiment_id
|
||||
|
||||
|
||||
@@ -154,17 +154,17 @@ def get_or_create_registered_model(
|
||||
) -> str:
|
||||
"""
|
||||
Get or create a registered model in the Model Registry.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to register
|
||||
description: Model description
|
||||
tags: Tags for the model
|
||||
|
||||
|
||||
Returns:
|
||||
The registered model name
|
||||
"""
|
||||
client = get_mlflow_client()
|
||||
|
||||
|
||||
try:
|
||||
# Check if model exists
|
||||
client.get_registered_model(model_name)
|
||||
@@ -177,14 +177,14 @@ def get_or_create_registered_model(
|
||||
tags=tags or {}
|
||||
)
|
||||
logger.info(f"Created registered model: {model_name}")
|
||||
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def health_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Check MLflow server connectivity and return status.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with health status information
|
||||
"""
|
||||
@@ -195,7 +195,7 @@ def health_check() -> Dict[str, Any]:
|
||||
"connected": False,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
client = get_mlflow_client(configure_global=False)
|
||||
# Try to list experiments as a health check
|
||||
@@ -205,5 +205,5 @@ def health_check() -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
logger.error(f"MLflow health check failed: {e}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
@@ -18,15 +18,15 @@ Usage:
|
||||
get_best_run,
|
||||
promotion_recommendation,
|
||||
)
|
||||
|
||||
|
||||
analyzer = ExperimentAnalyzer("chat-inference")
|
||||
|
||||
|
||||
# Compare last N runs
|
||||
comparison = analyzer.compare_recent_runs(n=5)
|
||||
|
||||
|
||||
# Find best performing model
|
||||
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
|
||||
|
||||
|
||||
# Get promotion recommendation
|
||||
rec = analyzer.promotion_recommendation(
|
||||
model_name="whisper-finetuned",
|
||||
@@ -35,19 +35,15 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities import Run, Experiment
|
||||
from mlflow.entities import Experiment, Run
|
||||
|
||||
from .client import get_mlflow_client, MLflowConfig
|
||||
from .client import get_mlflow_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,21 +53,21 @@ class RunComparison:
|
||||
"""Comparison result for multiple MLflow runs."""
|
||||
run_ids: List[str]
|
||||
experiment_name: str
|
||||
|
||||
|
||||
# Metric comparisons (metric_name -> {run_id -> value})
|
||||
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Parameter differences
|
||||
params: Dict[str, Dict[str, str]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Run metadata
|
||||
run_names: Dict[str, str] = field(default_factory=dict)
|
||||
start_times: Dict[str, datetime] = field(default_factory=dict)
|
||||
durations: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Best performers by metric
|
||||
best_by_metric: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
@@ -82,22 +78,22 @@ class RunComparison:
|
||||
"run_names": self.run_names,
|
||||
"best_by_metric": self.best_by_metric,
|
||||
}
|
||||
|
||||
|
||||
def summary_table(self) -> str:
|
||||
"""Generate a text summary table of the comparison."""
|
||||
if not self.run_ids:
|
||||
return "No runs to compare"
|
||||
|
||||
|
||||
lines = []
|
||||
lines.append(f"Experiment: {self.experiment_name}")
|
||||
lines.append(f"Comparing {len(self.run_ids)} runs")
|
||||
lines.append("")
|
||||
|
||||
|
||||
# Header
|
||||
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
|
||||
lines.append(" | ".join(header))
|
||||
lines.append("-" * (len(lines[-1]) + 10))
|
||||
|
||||
|
||||
# Metrics
|
||||
for metric_name, values in sorted(self.metrics.items()):
|
||||
row = [metric_name]
|
||||
@@ -108,7 +104,7 @@ class RunComparison:
|
||||
else:
|
||||
row.append("N/A")
|
||||
lines.append(" | ".join(row))
|
||||
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -121,7 +117,7 @@ class PromotionRecommendation:
|
||||
reasons: List[str]
|
||||
metrics_summary: Dict[str, float]
|
||||
comparison_with_production: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
@@ -136,20 +132,20 @@ class PromotionRecommendation:
|
||||
class ExperimentAnalyzer:
|
||||
"""
|
||||
Analyze MLflow experiments for model comparison and promotion decisions.
|
||||
|
||||
|
||||
Example:
|
||||
analyzer = ExperimentAnalyzer("chat-inference")
|
||||
|
||||
|
||||
# Get metrics summary for last 24 hours
|
||||
summary = analyzer.get_metrics_summary(hours=24)
|
||||
|
||||
|
||||
# Compare models by accuracy
|
||||
best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
|
||||
|
||||
|
||||
# Analyze inference latency trends
|
||||
trends = analyzer.get_metric_trends("total_latency_mean", days=7)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
@@ -157,7 +153,7 @@ class ExperimentAnalyzer:
|
||||
):
|
||||
"""
|
||||
Initialize the experiment analyzer.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment to analyze
|
||||
tracking_uri: Override default tracking URI
|
||||
@@ -166,14 +162,14 @@ class ExperimentAnalyzer:
|
||||
self.tracking_uri = tracking_uri
|
||||
self.client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
self._experiment: Optional[Experiment] = None
|
||||
|
||||
|
||||
@property
|
||||
def experiment(self) -> Optional[Experiment]:
|
||||
"""Get the experiment object, fetching if needed."""
|
||||
if self._experiment is None:
|
||||
self._experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
return self._experiment
|
||||
|
||||
|
||||
def search_runs(
|
||||
self,
|
||||
filter_string: str = "",
|
||||
@@ -183,29 +179,29 @@ class ExperimentAnalyzer:
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Search for runs matching criteria.
|
||||
|
||||
|
||||
Args:
|
||||
filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
|
||||
order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
|
||||
max_results: Maximum runs to return
|
||||
run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
|
||||
|
||||
|
||||
Returns:
|
||||
List of matching Run objects
|
||||
"""
|
||||
if not self.experiment:
|
||||
logger.warning(f"Experiment '{self.experiment_name}' not found")
|
||||
return []
|
||||
|
||||
|
||||
runs = self.client.search_runs(
|
||||
experiment_ids=[self.experiment.experiment_id],
|
||||
filter_string=filter_string,
|
||||
order_by=order_by or ["start_time DESC"],
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
def get_recent_runs(
|
||||
self,
|
||||
n: int = 10,
|
||||
@@ -213,11 +209,11 @@ class ExperimentAnalyzer:
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Get the most recent runs.
|
||||
|
||||
|
||||
Args:
|
||||
n: Number of runs to return
|
||||
hours: Only include runs from the last N hours
|
||||
|
||||
|
||||
Returns:
|
||||
List of Run objects
|
||||
"""
|
||||
@@ -226,13 +222,13 @@ class ExperimentAnalyzer:
|
||||
cutoff = datetime.now() - timedelta(hours=hours)
|
||||
cutoff_ms = int(cutoff.timestamp() * 1000)
|
||||
filter_string = f"attributes.start_time >= {cutoff_ms}"
|
||||
|
||||
|
||||
return self.search_runs(
|
||||
filter_string=filter_string,
|
||||
order_by=["start_time DESC"],
|
||||
max_results=n,
|
||||
)
|
||||
|
||||
|
||||
def compare_runs(
|
||||
self,
|
||||
run_ids: Optional[List[str]] = None,
|
||||
@@ -240,11 +236,11 @@ class ExperimentAnalyzer:
|
||||
) -> RunComparison:
|
||||
"""
|
||||
Compare multiple runs side by side.
|
||||
|
||||
|
||||
Args:
|
||||
run_ids: Specific run IDs to compare, or None for recent runs
|
||||
n_recent: If run_ids is None, compare this many recent runs
|
||||
|
||||
|
||||
Returns:
|
||||
RunComparison object with detailed comparison
|
||||
"""
|
||||
@@ -252,18 +248,18 @@ class ExperimentAnalyzer:
|
||||
runs = [self.client.get_run(rid) for rid in run_ids]
|
||||
else:
|
||||
runs = self.get_recent_runs(n=n_recent)
|
||||
|
||||
|
||||
comparison = RunComparison(
|
||||
run_ids=[r.info.run_id for r in runs],
|
||||
experiment_name=self.experiment_name,
|
||||
)
|
||||
|
||||
|
||||
# Collect all metrics and find best performers
|
||||
all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
|
||||
|
||||
|
||||
for run in runs:
|
||||
run_id = run.info.run_id
|
||||
|
||||
|
||||
# Metadata
|
||||
comparison.run_names[run_id] = run.info.run_name or run_id[:8]
|
||||
comparison.start_times[run_id] = datetime.fromtimestamp(
|
||||
@@ -273,39 +269,39 @@ class ExperimentAnalyzer:
|
||||
comparison.durations[run_id] = (
|
||||
run.info.end_time - run.info.start_time
|
||||
) / 1000
|
||||
|
||||
|
||||
# Metrics
|
||||
for key, value in run.data.metrics.items():
|
||||
all_metrics[key][run_id] = value
|
||||
|
||||
|
||||
# Params
|
||||
for key, value in run.data.params.items():
|
||||
if key not in comparison.params:
|
||||
comparison.params[key] = {}
|
||||
comparison.params[key][run_id] = value
|
||||
|
||||
|
||||
comparison.metrics = dict(all_metrics)
|
||||
|
||||
|
||||
# Find best performers for each metric
|
||||
for metric_name, values in all_metrics.items():
|
||||
if not values:
|
||||
continue
|
||||
|
||||
|
||||
# Determine if lower is better based on metric name
|
||||
minimize = any(
|
||||
term in metric_name.lower()
|
||||
for term in ["latency", "error", "loss", "time"]
|
||||
)
|
||||
|
||||
|
||||
if minimize:
|
||||
best_id = min(values.keys(), key=lambda k: values[k])
|
||||
else:
|
||||
best_id = max(values.keys(), key=lambda k: values[k])
|
||||
|
||||
|
||||
comparison.best_by_metric[metric_name] = best_id
|
||||
|
||||
|
||||
return comparison
|
||||
|
||||
|
||||
def get_best_run(
|
||||
self,
|
||||
metric: str,
|
||||
@@ -315,32 +311,32 @@ class ExperimentAnalyzer:
|
||||
) -> Optional[Run]:
|
||||
"""
|
||||
Get the best run by a specific metric.
|
||||
|
||||
|
||||
Args:
|
||||
metric: Metric name to optimize
|
||||
minimize: If True, find minimum; if False, find maximum
|
||||
filter_string: Additional filter criteria
|
||||
max_results: Maximum runs to consider
|
||||
|
||||
|
||||
Returns:
|
||||
Best Run object, or None if no runs found
|
||||
"""
|
||||
direction = "ASC" if minimize else "DESC"
|
||||
|
||||
|
||||
runs = self.search_runs(
|
||||
filter_string=filter_string,
|
||||
order_by=[f"metrics.{metric} {direction}"],
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
|
||||
# Filter to only runs that have the metric
|
||||
runs_with_metric = [
|
||||
r for r in runs
|
||||
if metric in r.data.metrics
|
||||
]
|
||||
|
||||
|
||||
return runs_with_metric[0] if runs_with_metric else None
|
||||
|
||||
|
||||
def get_metrics_summary(
|
||||
self,
|
||||
hours: Optional[int] = None,
|
||||
@@ -348,45 +344,45 @@ class ExperimentAnalyzer:
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Get summary statistics for metrics.
|
||||
|
||||
|
||||
Args:
|
||||
hours: Only include runs from the last N hours
|
||||
metrics: Specific metrics to summarize (None for all)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict mapping metric names to {mean, min, max, count}
|
||||
"""
|
||||
import statistics
|
||||
|
||||
|
||||
runs = self.get_recent_runs(n=1000, hours=hours)
|
||||
|
||||
|
||||
# Collect all metric values
|
||||
metric_values: Dict[str, List[float]] = defaultdict(list)
|
||||
|
||||
|
||||
for run in runs:
|
||||
for key, value in run.data.metrics.items():
|
||||
if metrics is None or key in metrics:
|
||||
metric_values[key].append(value)
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
summary = {}
|
||||
for metric_name, values in metric_values.items():
|
||||
if not values:
|
||||
continue
|
||||
|
||||
|
||||
summary[metric_name] = {
|
||||
"mean": statistics.mean(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"count": len(values),
|
||||
}
|
||||
|
||||
|
||||
if len(values) >= 2:
|
||||
summary[metric_name]["stdev"] = statistics.stdev(values)
|
||||
summary[metric_name]["median"] = statistics.median(values)
|
||||
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def get_metric_trends(
|
||||
self,
|
||||
metric: str,
|
||||
@@ -395,30 +391,30 @@ class ExperimentAnalyzer:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get metric trends over time.
|
||||
|
||||
|
||||
Args:
|
||||
metric: Metric name to track
|
||||
days: Number of days to look back
|
||||
granularity_hours: Time bucket size in hours
|
||||
|
||||
|
||||
Returns:
|
||||
List of {timestamp, mean, min, max, count} dicts
|
||||
"""
|
||||
import statistics
|
||||
|
||||
|
||||
runs = self.get_recent_runs(n=10000, hours=days * 24)
|
||||
|
||||
|
||||
# Group runs by time bucket
|
||||
buckets: Dict[int, List[float]] = defaultdict(list)
|
||||
bucket_size_ms = granularity_hours * 3600 * 1000
|
||||
|
||||
|
||||
for run in runs:
|
||||
if metric not in run.data.metrics:
|
||||
continue
|
||||
|
||||
|
||||
bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
|
||||
buckets[bucket].append(run.data.metrics[metric])
|
||||
|
||||
|
||||
# Calculate statistics per bucket
|
||||
trends = []
|
||||
for bucket_ts, values in sorted(buckets.items()):
|
||||
@@ -432,9 +428,9 @@ class ExperimentAnalyzer:
|
||||
if len(values) >= 2:
|
||||
trend["stdev"] = statistics.stdev(values)
|
||||
trends.append(trend)
|
||||
|
||||
|
||||
return trends
|
||||
|
||||
|
||||
def get_runs_by_tag(
|
||||
self,
|
||||
tag_key: str,
|
||||
@@ -443,12 +439,12 @@ class ExperimentAnalyzer:
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Get runs with a specific tag.
|
||||
|
||||
|
||||
Args:
|
||||
tag_key: Tag key to filter by
|
||||
tag_value: Tag value to match
|
||||
max_results: Maximum runs to return
|
||||
|
||||
|
||||
Returns:
|
||||
List of matching Run objects
|
||||
"""
|
||||
@@ -456,7 +452,7 @@ class ExperimentAnalyzer:
|
||||
filter_string=f"tags.{tag_key} = '{tag_value}'",
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
|
||||
def get_model_runs(
|
||||
self,
|
||||
model_name: str,
|
||||
@@ -464,11 +460,11 @@ class ExperimentAnalyzer:
|
||||
) -> List[Run]:
|
||||
"""
|
||||
Get runs for a specific model.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Model name to filter by
|
||||
max_results: Maximum runs to return
|
||||
|
||||
|
||||
Returns:
|
||||
List of matching Run objects
|
||||
"""
|
||||
@@ -477,14 +473,14 @@ class ExperimentAnalyzer:
|
||||
filter_string=f"tags.`model.name` = '{model_name}'",
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
|
||||
if not runs:
|
||||
# Try params
|
||||
runs = self.search_runs(
|
||||
filter_string=f"params.model_name = '{model_name}'",
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
|
||||
return runs
|
||||
|
||||
|
||||
@@ -495,23 +491,23 @@ def compare_experiments(
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Compare metrics across multiple experiments.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_names: Names of experiments to compare
|
||||
metric: Metric to compare
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
Dict mapping experiment names to metric statistics
|
||||
"""
|
||||
results = {}
|
||||
|
||||
|
||||
for exp_name in experiment_names:
|
||||
analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
|
||||
summary = analyzer.get_metrics_summary(metrics=[metric])
|
||||
if metric in summary:
|
||||
results[exp_name] = summary[metric]
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -523,7 +519,7 @@ def promotion_recommendation(
|
||||
) -> PromotionRecommendation:
|
||||
"""
|
||||
Generate a recommendation for model promotion.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to evaluate
|
||||
experiment_name: Experiment containing evaluation runs
|
||||
@@ -531,15 +527,15 @@ def promotion_recommendation(
|
||||
comparison is one of: ">=", "<=", ">", "<"
|
||||
e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
PromotionRecommendation with decision and reasons
|
||||
"""
|
||||
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
# Get model runs
|
||||
runs = analyzer.get_model_runs(model_name, max_results=10)
|
||||
|
||||
|
||||
if not runs:
|
||||
return PromotionRecommendation(
|
||||
model_name=model_name,
|
||||
@@ -548,41 +544,41 @@ def promotion_recommendation(
|
||||
reasons=["No runs found for this model"],
|
||||
metrics_summary={},
|
||||
)
|
||||
|
||||
|
||||
# Get the most recent run
|
||||
latest_run = runs[0]
|
||||
metrics = latest_run.data.metrics
|
||||
|
||||
|
||||
# Evaluate criteria
|
||||
reasons = []
|
||||
passed = True
|
||||
|
||||
|
||||
comparisons = {
|
||||
">=": lambda a, b: a >= b,
|
||||
"<=": lambda a, b: a <= b,
|
||||
">": lambda a, b: a > b,
|
||||
"<": lambda a, b: a < b,
|
||||
}
|
||||
|
||||
|
||||
for metric_name, (comparison, threshold) in criteria.items():
|
||||
if metric_name not in metrics:
|
||||
reasons.append(f"Metric '{metric_name}' not found")
|
||||
passed = False
|
||||
continue
|
||||
|
||||
|
||||
value = metrics[metric_name]
|
||||
compare_fn = comparisons.get(comparison)
|
||||
|
||||
|
||||
if compare_fn is None:
|
||||
reasons.append(f"Invalid comparison operator: {comparison}")
|
||||
continue
|
||||
|
||||
|
||||
if compare_fn(value, threshold):
|
||||
reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}")
|
||||
else:
|
||||
reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}")
|
||||
passed = False
|
||||
|
||||
|
||||
# Extract version from tags if available
|
||||
version = None
|
||||
if "mlflow.version" in latest_run.data.tags:
|
||||
@@ -590,7 +586,7 @@ def promotion_recommendation(
|
||||
version = int(latest_run.data.tags["mlflow.version"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
return PromotionRecommendation(
|
||||
model_name=model_name,
|
||||
version=version,
|
||||
@@ -607,21 +603,21 @@ def get_inference_performance_report(
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an inference performance report for a service.
|
||||
|
||||
|
||||
Args:
|
||||
service_name: Service name (chat-handler, voice-assistant)
|
||||
hours: Hours of data to analyze
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
Performance report dictionary
|
||||
"""
|
||||
experiment_name = f"{service_name.replace('-', '')}-inference"
|
||||
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
# Get summary metrics
|
||||
summary = analyzer.get_metrics_summary(hours=hours)
|
||||
|
||||
|
||||
# Key latency metrics
|
||||
latency_metrics = [
|
||||
"total_latency_mean",
|
||||
@@ -631,7 +627,7 @@ def get_inference_performance_report(
|
||||
"embedding_latency_mean",
|
||||
"rag_search_latency_mean",
|
||||
]
|
||||
|
||||
|
||||
report = {
|
||||
"service": service_name,
|
||||
"period_hours": hours,
|
||||
@@ -641,24 +637,24 @@ def get_inference_performance_report(
|
||||
"rag": {},
|
||||
"errors": {},
|
||||
}
|
||||
|
||||
|
||||
# Latency section
|
||||
for metric in latency_metrics:
|
||||
if metric in summary:
|
||||
report["latency"][metric] = summary[metric]
|
||||
|
||||
|
||||
# Throughput
|
||||
if "total_requests" in summary:
|
||||
report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
|
||||
|
||||
|
||||
# RAG usage
|
||||
rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
|
||||
for metric in rag_metrics:
|
||||
if metric in summary:
|
||||
report["rag"][metric] = summary[metric]
|
||||
|
||||
|
||||
# Error rate
|
||||
if "error_rate" in summary:
|
||||
report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
|
||||
|
||||
|
||||
return report
|
||||
|
||||
@@ -9,20 +9,19 @@ complement OTel metrics with MLflow experiment tracking for
|
||||
longer-term analysis and model comparison.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||
from .client import MLflowConfig, ensure_experiment, get_mlflow_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,7 +32,7 @@ class InferenceMetrics:
|
||||
request_id: str
|
||||
user_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
# Timing metrics (in seconds)
|
||||
total_latency: float = 0.0
|
||||
embedding_latency: float = 0.0
|
||||
@@ -42,33 +41,33 @@ class InferenceMetrics:
|
||||
llm_latency: float = 0.0
|
||||
tts_latency: float = 0.0
|
||||
stt_latency: float = 0.0
|
||||
|
||||
|
||||
# Token/size metrics
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
prompt_length: int = 0
|
||||
response_length: int = 0
|
||||
|
||||
|
||||
# RAG metrics
|
||||
rag_enabled: bool = False
|
||||
rag_documents_retrieved: int = 0
|
||||
rag_documents_used: int = 0
|
||||
reranker_enabled: bool = False
|
||||
|
||||
|
||||
# Quality indicators
|
||||
is_streaming: bool = False
|
||||
is_premium: bool = False
|
||||
has_error: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
# Model information
|
||||
model_name: Optional[str] = None
|
||||
model_endpoint: Optional[str] = None
|
||||
|
||||
|
||||
# Timestamps
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
def as_metrics_dict(self) -> Dict[str, float]:
|
||||
"""Convert numeric fields to a metrics dictionary."""
|
||||
return {
|
||||
@@ -87,7 +86,7 @@ class InferenceMetrics:
|
||||
"rag_documents_retrieved": float(self.rag_documents_retrieved),
|
||||
"rag_documents_used": float(self.rag_documents_used),
|
||||
}
|
||||
|
||||
|
||||
def as_params_dict(self) -> Dict[str, str]:
|
||||
"""Convert configuration fields to a params dictionary."""
|
||||
params = {
|
||||
@@ -106,39 +105,39 @@ class InferenceMetrics:
|
||||
class InferenceMetricsTracker:
|
||||
"""
|
||||
Async-compatible MLflow tracker for inference metrics.
|
||||
|
||||
|
||||
Uses batching and a background thread pool to avoid blocking
|
||||
the async event loop during MLflow calls.
|
||||
|
||||
|
||||
Example usage in chat-handler:
|
||||
|
||||
|
||||
class ChatHandler:
|
||||
def __init__(self):
|
||||
self.mlflow_tracker = InferenceMetricsTracker(
|
||||
service_name="chat-handler",
|
||||
experiment_name="chat-inference"
|
||||
)
|
||||
|
||||
|
||||
async def setup(self):
|
||||
await self.mlflow_tracker.start()
|
||||
|
||||
|
||||
async def process_request(self, msg):
|
||||
metrics = InferenceMetrics(request_id=request_id)
|
||||
|
||||
|
||||
# Track timing
|
||||
start = time.time()
|
||||
# ... do embedding ...
|
||||
metrics.embedding_latency = time.time() - start
|
||||
|
||||
|
||||
# ... more processing ...
|
||||
|
||||
|
||||
# Log metrics (non-blocking)
|
||||
await self.mlflow_tracker.log_inference(metrics)
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
await self.mlflow_tracker.stop()
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
@@ -151,7 +150,7 @@ class InferenceMetricsTracker:
|
||||
):
|
||||
"""
|
||||
Initialize the inference metrics tracker.
|
||||
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (e.g., "chat-handler")
|
||||
experiment_name: MLflow experiment name (defaults to service_name)
|
||||
@@ -167,7 +166,7 @@ class InferenceMetricsTracker:
|
||||
self.batch_size = batch_size
|
||||
self.flush_interval = flush_interval_seconds
|
||||
self.enable_batching = enable_batching
|
||||
|
||||
|
||||
self.config = MLflowConfig()
|
||||
self._batch: List[InferenceMetrics] = []
|
||||
self._batch_lock = asyncio.Lock()
|
||||
@@ -176,34 +175,34 @@ class InferenceMetricsTracker:
|
||||
self._running = False
|
||||
self._client: Optional[MlflowClient] = None
|
||||
self._experiment_id: Optional[str] = None
|
||||
|
||||
|
||||
# Aggregate metrics for periodic logging
|
||||
self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
|
||||
self._request_count = 0
|
||||
self._error_count = 0
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the tracker and initialize MLflow connection."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
|
||||
self._running = True
|
||||
|
||||
|
||||
# Initialize MLflow in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
self._init_mlflow
|
||||
)
|
||||
|
||||
|
||||
if self.enable_batching:
|
||||
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||
|
||||
|
||||
logger.info(
|
||||
f"InferenceMetricsTracker started for {self.service_name} "
|
||||
f"(experiment: {self.experiment_name})"
|
||||
)
|
||||
|
||||
|
||||
def _init_mlflow(self) -> None:
|
||||
"""Initialize MLflow client and experiment (runs in thread pool)."""
|
||||
self._client = get_mlflow_client(
|
||||
@@ -217,47 +216,47 @@ class InferenceMetricsTracker:
|
||||
"type": "inference-metrics",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the tracker and flush remaining metrics."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
|
||||
self._running = False
|
||||
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# Final flush
|
||||
await self._flush_batch()
|
||||
|
||||
|
||||
self._executor.shutdown(wait=True)
|
||||
logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
|
||||
|
||||
|
||||
async def log_inference(self, metrics: InferenceMetrics) -> None:
|
||||
"""
|
||||
Log inference metrics (non-blocking).
|
||||
|
||||
|
||||
Args:
|
||||
metrics: InferenceMetrics object with request data
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning("Tracker not running, skipping metrics")
|
||||
return
|
||||
|
||||
|
||||
self._request_count += 1
|
||||
if metrics.has_error:
|
||||
self._error_count += 1
|
||||
|
||||
|
||||
# Update aggregates
|
||||
for key, value in metrics.as_metrics_dict().items():
|
||||
if value > 0:
|
||||
self._aggregate_metrics[key].append(value)
|
||||
|
||||
|
||||
if self.enable_batching:
|
||||
async with self._batch_lock:
|
||||
self._batch.append(metrics)
|
||||
@@ -270,29 +269,29 @@ class InferenceMetricsTracker:
|
||||
self._executor,
|
||||
partial(self._log_single_inference, metrics)
|
||||
)
|
||||
|
||||
|
||||
async def _periodic_flush(self) -> None:
|
||||
"""Periodically flush batched metrics."""
|
||||
while self._running:
|
||||
await asyncio.sleep(self.flush_interval)
|
||||
await self._flush_batch()
|
||||
|
||||
|
||||
async def _flush_batch(self) -> None:
|
||||
"""Flush the current batch of metrics to MLflow."""
|
||||
async with self._batch_lock:
|
||||
if not self._batch:
|
||||
return
|
||||
|
||||
|
||||
batch = self._batch
|
||||
self._batch = []
|
||||
|
||||
|
||||
# Log in thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
self._executor,
|
||||
partial(self._log_batch, batch)
|
||||
)
|
||||
|
||||
|
||||
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
|
||||
"""Log a single inference request to MLflow (runs in thread pool)."""
|
||||
try:
|
||||
@@ -307,7 +306,7 @@ class InferenceMetricsTracker:
|
||||
):
|
||||
mlflow.log_params(metrics.as_params_dict())
|
||||
mlflow.log_metrics(metrics.as_metrics_dict())
|
||||
|
||||
|
||||
if metrics.user_id:
|
||||
mlflow.set_tag("user_id", metrics.user_id)
|
||||
if metrics.session_id:
|
||||
@@ -318,18 +317,18 @@ class InferenceMetricsTracker:
|
||||
mlflow.set_tag("error_message", metrics.error_message[:250])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log inference metrics: {e}")
|
||||
|
||||
|
||||
def _log_batch(self, batch: List[InferenceMetrics]) -> None:
|
||||
"""Log a batch of inference metrics as aggregate statistics."""
|
||||
if not batch:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# Calculate aggregates
|
||||
aggregates = self._calculate_aggregates(batch)
|
||||
|
||||
|
||||
run_name = f"batch-{self.service_name}-{int(time.time())}"
|
||||
|
||||
|
||||
with mlflow.start_run(
|
||||
experiment_id=self._experiment_id,
|
||||
run_name=run_name,
|
||||
@@ -341,83 +340,83 @@ class InferenceMetricsTracker:
|
||||
):
|
||||
# Log aggregate metrics
|
||||
mlflow.log_metrics(aggregates)
|
||||
|
||||
|
||||
# Log batch info
|
||||
mlflow.log_param("batch_size", len(batch))
|
||||
mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
|
||||
mlflow.log_param("time_window_end", max(m.timestamp for m in batch))
|
||||
|
||||
|
||||
# Log configuration breakdown
|
||||
rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
|
||||
streaming_count = sum(1 for m in batch if m.is_streaming)
|
||||
premium_count = sum(1 for m in batch if m.is_premium)
|
||||
error_count = sum(1 for m in batch if m.has_error)
|
||||
|
||||
|
||||
mlflow.log_metrics({
|
||||
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
|
||||
"streaming_pct": streaming_count / len(batch) * 100,
|
||||
"premium_pct": premium_count / len(batch) * 100,
|
||||
"error_rate": error_count / len(batch) * 100,
|
||||
})
|
||||
|
||||
|
||||
# Log model distribution
|
||||
model_counts: Dict[str, int] = defaultdict(int)
|
||||
for m in batch:
|
||||
if m.model_name:
|
||||
model_counts[m.model_name] += 1
|
||||
|
||||
|
||||
if model_counts:
|
||||
mlflow.log_dict(
|
||||
{"models": dict(model_counts)},
|
||||
"model_distribution.json"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Logged batch of {len(batch)} inference metrics")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log batch metrics: {e}")
|
||||
|
||||
|
||||
def _calculate_aggregates(
|
||||
self,
|
||||
batch: List[InferenceMetrics]
|
||||
) -> Dict[str, float]:
|
||||
"""Calculate aggregate statistics from a batch of metrics."""
|
||||
import statistics
|
||||
|
||||
|
||||
aggregates = {}
|
||||
|
||||
|
||||
# Collect all numeric metrics
|
||||
metric_values: Dict[str, List[float]] = defaultdict(list)
|
||||
for m in batch:
|
||||
for key, value in m.as_metrics_dict().items():
|
||||
if value > 0:
|
||||
metric_values[key].append(value)
|
||||
|
||||
|
||||
# Calculate statistics for each metric
|
||||
for key, values in metric_values.items():
|
||||
if not values:
|
||||
continue
|
||||
|
||||
|
||||
aggregates[f"{key}_mean"] = statistics.mean(values)
|
||||
aggregates[f"{key}_min"] = min(values)
|
||||
aggregates[f"{key}_max"] = max(values)
|
||||
|
||||
|
||||
if len(values) >= 2:
|
||||
aggregates[f"{key}_p50"] = statistics.median(values)
|
||||
aggregates[f"{key}_stdev"] = statistics.stdev(values)
|
||||
|
||||
|
||||
if len(values) >= 4:
|
||||
sorted_vals = sorted(values)
|
||||
p95_idx = int(len(sorted_vals) * 0.95)
|
||||
p99_idx = int(len(sorted_vals) * 0.99)
|
||||
aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
|
||||
aggregates[f"{key}_p99"] = sorted_vals[p99_idx]
|
||||
|
||||
|
||||
# Add counts
|
||||
aggregates["total_requests"] = float(len(batch))
|
||||
|
||||
|
||||
return aggregates
|
||||
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get current tracker statistics."""
|
||||
return {
|
||||
|
||||
@@ -21,22 +21,22 @@ Usage in a Kubeflow Pipeline:
|
||||
experiment_name="my-experiment",
|
||||
run_name="training-run-1"
|
||||
)
|
||||
|
||||
|
||||
# ... your pipeline steps ...
|
||||
|
||||
|
||||
# Log metrics
|
||||
log_step = log_metrics_component(
|
||||
run_id=run_info.outputs["run_id"],
|
||||
metrics={"accuracy": 0.95, "loss": 0.05}
|
||||
)
|
||||
|
||||
|
||||
# End run
|
||||
end_mlflow_run(run_id=run_info.outputs["run_id"])
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
from typing import Dict, Any, List, Optional, NamedTuple
|
||||
from typing import Any, Dict, List, NamedTuple
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
# MLflow component image with all required dependencies
|
||||
MLFLOW_IMAGE = "python:3.13-slim"
|
||||
@@ -60,31 +60,32 @@ def create_mlflow_run(
|
||||
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
|
||||
"""
|
||||
Create a new MLflow run for the pipeline.
|
||||
|
||||
|
||||
This should be called at the start of a pipeline to initialize
|
||||
tracking. The returned run_id should be passed to subsequent
|
||||
components for logging.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment
|
||||
run_name: Name for this specific run
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
tags: Optional tags to add to the run
|
||||
params: Optional parameters to log
|
||||
|
||||
|
||||
Returns:
|
||||
NamedTuple with run_id, experiment_id, and artifact_uri
|
||||
"""
|
||||
import os
|
||||
from collections import namedtuple
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
# Set tracking URI
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
|
||||
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Get or create experiment
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
if experiment is None:
|
||||
@@ -94,7 +95,7 @@ def create_mlflow_run(
|
||||
)
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
|
||||
|
||||
# Create default tags
|
||||
default_tags = {
|
||||
"pipeline.type": "kubeflow",
|
||||
@@ -103,24 +104,24 @@ def create_mlflow_run(
|
||||
}
|
||||
if tags:
|
||||
default_tags.update(tags)
|
||||
|
||||
|
||||
# Start run
|
||||
run = mlflow.start_run(
|
||||
experiment_id=experiment_id,
|
||||
run_name=run_name,
|
||||
tags=default_tags,
|
||||
)
|
||||
|
||||
|
||||
# Log initial params
|
||||
if params:
|
||||
mlflow.log_params(params)
|
||||
|
||||
|
||||
run_id = run.info.run_id
|
||||
artifact_uri = run.info.artifact_uri
|
||||
|
||||
|
||||
# End run (KFP components are isolated, we'll resume in other components)
|
||||
mlflow.end_run()
|
||||
|
||||
|
||||
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
|
||||
return RunInfo(run_id, experiment_id, artifact_uri)
|
||||
|
||||
@@ -136,24 +137,24 @@ def log_params_component(
|
||||
) -> str:
|
||||
"""
|
||||
Log parameters to an existing MLflow run.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
params: Dictionary of parameters to log
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
for key, value in params.items():
|
||||
client.log_param(run_id, key, str(value)[:500])
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -169,25 +170,25 @@ def log_metrics_component(
|
||||
) -> str:
|
||||
"""
|
||||
Log metrics to an existing MLflow run.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
metrics: Dictionary of metrics to log
|
||||
step: Step number for time-series metrics
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, key, float(value), step=step)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -203,24 +204,24 @@ def log_artifact_component(
|
||||
) -> str:
|
||||
"""
|
||||
Log an artifact file to an existing MLflow run.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
artifact_path: Path to the artifact file
|
||||
artifact_name: Optional destination name in artifact store
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
client.log_artifact(run_id, artifact_path, artifact_name or None)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -236,36 +237,37 @@ def log_dict_artifact(
|
||||
) -> str:
|
||||
"""
|
||||
Log a dictionary as a JSON artifact.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
data: Dictionary to save as JSON
|
||||
filename: Name for the JSON file
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Ensure .json extension
|
||||
if not filename.endswith('.json'):
|
||||
filename += '.json'
|
||||
|
||||
|
||||
# Write to temp file and log
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filepath = Path(tmpdir) / filename
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
client.log_artifact(run_id, str(filepath))
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -280,31 +282,31 @@ def end_mlflow_run(
|
||||
) -> str:
|
||||
"""
|
||||
End an MLflow run with the specified status.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to end
|
||||
status: Run status (FINISHED, FAILED, KILLED)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities import RunStatus
|
||||
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
status_map = {
|
||||
"FINISHED": RunStatus.FINISHED,
|
||||
"FAILED": RunStatus.FAILED,
|
||||
"KILLED": RunStatus.KILLED,
|
||||
}
|
||||
|
||||
|
||||
run_status = status_map.get(status.upper(), RunStatus.FINISHED)
|
||||
client.set_terminated(run_id, status=run_status)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -322,10 +324,10 @@ def log_training_metrics(
|
||||
) -> str:
|
||||
"""
|
||||
Log comprehensive training metrics for ML models.
|
||||
|
||||
|
||||
Designed for use with QLoRA training, voice training, and other
|
||||
ML training pipelines in the llm-workflows repository.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
model_type: Type of model (llm, stt, tts, embeddings)
|
||||
@@ -333,19 +335,20 @@ def log_training_metrics(
|
||||
final_metrics: Final training metrics
|
||||
model_path: Path to saved model (if applicable)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Log training config as params
|
||||
flat_config = {}
|
||||
for key, value in training_config.items():
|
||||
@@ -353,29 +356,29 @@ def log_training_metrics(
|
||||
flat_config[f"config.{key}"] = json.dumps(value)[:500]
|
||||
else:
|
||||
flat_config[f"config.{key}"] = str(value)[:500]
|
||||
|
||||
|
||||
for key, value in flat_config.items():
|
||||
client.log_param(run_id, key, value)
|
||||
|
||||
|
||||
# Log model type tag
|
||||
client.set_tag(run_id, "model.type", model_type)
|
||||
|
||||
|
||||
# Log metrics
|
||||
for key, value in final_metrics.items():
|
||||
client.log_metric(run_id, key, float(value))
|
||||
|
||||
|
||||
# Log full config as artifact
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "training_config.json"
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(training_config, f, indent=2)
|
||||
client.log_artifact(run_id, str(config_path))
|
||||
|
||||
|
||||
# Log model path if provided
|
||||
if model_path:
|
||||
client.log_param(run_id, "model.path", model_path)
|
||||
client.set_tag(run_id, "model.saved", "true")
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -397,9 +400,9 @@ def log_document_ingestion_metrics(
|
||||
) -> str:
|
||||
"""
|
||||
Log document ingestion pipeline metrics.
|
||||
|
||||
|
||||
Designed for use with the document_ingestion_pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
source_url: URL of the source document
|
||||
@@ -411,16 +414,16 @@ def log_document_ingestion_metrics(
|
||||
chunk_size: Chunk size in tokens
|
||||
chunk_overlap: Chunk overlap in tokens
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Log params
|
||||
params = {
|
||||
"source_url": source_url[:500],
|
||||
@@ -431,7 +434,7 @@ def log_document_ingestion_metrics(
|
||||
}
|
||||
for key, value in params.items():
|
||||
client.log_param(run_id, key, value)
|
||||
|
||||
|
||||
# Log metrics
|
||||
metrics = {
|
||||
"chunks_created": chunks_created,
|
||||
@@ -441,11 +444,11 @@ def log_document_ingestion_metrics(
|
||||
}
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, key, float(value))
|
||||
|
||||
|
||||
# Set pipeline type tag
|
||||
client.set_tag(run_id, "pipeline.type", "document-ingestion")
|
||||
client.set_tag(run_id, "milvus.collection", collection_name)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -463,9 +466,9 @@ def log_evaluation_results(
|
||||
) -> str:
|
||||
"""
|
||||
Log model evaluation results.
|
||||
|
||||
|
||||
Designed for use with the evaluation_pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
model_name: Name of the evaluated model
|
||||
@@ -473,27 +476,28 @@ def log_evaluation_results(
|
||||
metrics: Evaluation metrics (accuracy, etc.)
|
||||
sample_results: Optional sample predictions
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Log params
|
||||
client.log_param(run_id, "eval.model_name", model_name)
|
||||
client.log_param(run_id, "eval.dataset", dataset_name)
|
||||
|
||||
|
||||
# Log metrics
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, f"eval.{key}", float(value))
|
||||
|
||||
|
||||
# Log sample results as artifact
|
||||
if sample_results:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -501,13 +505,13 @@ def log_evaluation_results(
|
||||
with open(results_path, 'w') as f:
|
||||
json.dump(sample_results, f, indent=2)
|
||||
client.log_artifact(run_id, str(results_path))
|
||||
|
||||
|
||||
# Set tags
|
||||
client.set_tag(run_id, "pipeline.type", "evaluation")
|
||||
client.set_tag(run_id, "model.name", model_name)
|
||||
|
||||
|
||||
# Determine if passed
|
||||
passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
|
||||
client.set_tag(run_id, "eval.passed", str(passed))
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
@@ -17,7 +17,7 @@ Usage:
|
||||
promote_model_to_production,
|
||||
generate_kserve_manifest,
|
||||
)
|
||||
|
||||
|
||||
# Register a new model version
|
||||
model_version = register_model_for_kserve(
|
||||
model_name="whisper-finetuned",
|
||||
@@ -28,7 +28,7 @@ Usage:
|
||||
"container_image": "ghcr.io/my-org/whisper:v2",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Generate KServe manifest for deployment
|
||||
manifest = generate_kserve_manifest(
|
||||
model_name="whisper-finetuned",
|
||||
@@ -36,18 +36,15 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
import yaml
|
||||
from mlflow.entities.model_registry import ModelVersion
|
||||
|
||||
from .client import get_mlflow_client, MLflowConfig
|
||||
from .client import get_mlflow_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,15 +52,15 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class KServeConfig:
|
||||
"""Configuration for KServe deployment."""
|
||||
|
||||
|
||||
# Runtime/container configuration
|
||||
runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc.
|
||||
container_image: Optional[str] = None
|
||||
container_port: int = 8080
|
||||
|
||||
|
||||
# Protocol configuration
|
||||
protocol: str = "v2" # v1, v2, grpc
|
||||
|
||||
|
||||
# Resource requests/limits
|
||||
cpu_request: str = "1"
|
||||
cpu_limit: str = "4"
|
||||
@@ -71,22 +68,22 @@ class KServeConfig:
|
||||
memory_limit: str = "16Gi"
|
||||
gpu_count: int = 0
|
||||
gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm
|
||||
|
||||
|
||||
# Storage configuration
|
||||
storage_uri: Optional[str] = None # s3://, pvc://, gs://
|
||||
|
||||
|
||||
# Scaling configuration
|
||||
min_replicas: int = 1
|
||||
max_replicas: int = 1
|
||||
scale_target: int = 10 # Target concurrent requests for scaling
|
||||
|
||||
|
||||
# Serving configuration
|
||||
timeout_seconds: int = 300
|
||||
batch_size: int = 1
|
||||
|
||||
|
||||
# Additional environment variables
|
||||
env_vars: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for MLflow tags."""
|
||||
return {
|
||||
@@ -165,7 +162,7 @@ def register_model_for_kserve(
|
||||
) -> ModelVersion:
|
||||
"""
|
||||
Register a model in MLflow Model Registry with KServe metadata.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name for the registered model
|
||||
model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
|
||||
@@ -175,16 +172,16 @@ def register_model_for_kserve(
|
||||
kserve_config: KServe deployment configuration
|
||||
tags: Additional tags for the model version
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
The created ModelVersion object
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
# Get or use preset KServe config
|
||||
if kserve_config is None:
|
||||
kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())
|
||||
|
||||
|
||||
# Ensure registered model exists
|
||||
try:
|
||||
client.get_registered_model(model_name)
|
||||
@@ -198,7 +195,7 @@ def register_model_for_kserve(
|
||||
}
|
||||
)
|
||||
logger.info(f"Created registered model: {model_name}")
|
||||
|
||||
|
||||
# Create model version
|
||||
model_version = client.create_model_version(
|
||||
name=model_name,
|
||||
@@ -211,12 +208,12 @@ def register_model_for_kserve(
|
||||
**kserve_config.as_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Registered model version {model_version.version} "
|
||||
f"for {model_name} (type: {model_type})"
|
||||
)
|
||||
|
||||
|
||||
return model_version
|
||||
|
||||
|
||||
@@ -229,19 +226,19 @@ def promote_model_to_stage(
|
||||
) -> ModelVersion:
|
||||
"""
|
||||
Promote a model version to a new stage.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number to promote
|
||||
stage: Target stage (Staging, Production, Archived)
|
||||
archive_existing: If True, archive existing versions in target stage
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
The updated ModelVersion
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
# Transition to new stage
|
||||
model_version = client.transition_model_version_stage(
|
||||
name=model_name,
|
||||
@@ -249,9 +246,9 @@ def promote_model_to_stage(
|
||||
stage=stage,
|
||||
archive_existing_versions=archive_existing,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Promoted {model_name} v{version} to {stage}")
|
||||
|
||||
|
||||
return model_version
|
||||
|
||||
|
||||
@@ -262,12 +259,12 @@ def promote_model_to_production(
|
||||
) -> ModelVersion:
|
||||
"""
|
||||
Promote a model version directly to Production.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number to promote
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
The updated ModelVersion
|
||||
"""
|
||||
@@ -286,18 +283,18 @@ def get_production_model(
|
||||
) -> Optional[ModelVersion]:
|
||||
"""
|
||||
Get the current Production model version.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
The Production ModelVersion, or None if none exists
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
versions = client.get_latest_versions(model_name, stages=["Production"])
|
||||
|
||||
|
||||
return versions[0] if versions else None
|
||||
|
||||
|
||||
@@ -308,17 +305,17 @@ def get_model_kserve_config(
|
||||
) -> KServeConfig:
|
||||
"""
|
||||
Get KServe configuration from a registered model version.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number (uses Production if not specified)
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
KServeConfig populated from model tags
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
if version:
|
||||
model_version = client.get_model_version(model_name, str(version))
|
||||
else:
|
||||
@@ -326,9 +323,9 @@ def get_model_kserve_config(
|
||||
if not prod_version:
|
||||
raise ValueError(f"No Production version for {model_name}")
|
||||
model_version = prod_version
|
||||
|
||||
|
||||
tags = model_version.tags
|
||||
|
||||
|
||||
return KServeConfig(
|
||||
runtime=tags.get("kserve.runtime", "kserve-huggingface"),
|
||||
protocol=tags.get("kserve.protocol", "v2"),
|
||||
@@ -352,7 +349,7 @@ def generate_kserve_manifest(
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a KServe InferenceService manifest from a registered model.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number (uses Production if not specified)
|
||||
@@ -360,12 +357,12 @@ def generate_kserve_manifest(
|
||||
service_name: Name for the InferenceService (defaults to model_name)
|
||||
extra_annotations: Additional annotations for the service
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
KServe InferenceService manifest as a dictionary
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
# Get model version
|
||||
if version:
|
||||
model_version = client.get_model_version(model_name, str(version))
|
||||
@@ -375,13 +372,13 @@ def generate_kserve_manifest(
|
||||
raise ValueError(f"No Production version for {model_name}")
|
||||
model_version = prod_version
|
||||
version = int(model_version.version)
|
||||
|
||||
|
||||
# Get KServe config
|
||||
config = get_model_kserve_config(model_name, version, tracking_uri)
|
||||
model_type = model_version.tags.get("model.type", "custom")
|
||||
|
||||
|
||||
svc_name = service_name or model_name.lower().replace("_", "-")
|
||||
|
||||
|
||||
# Build manifest
|
||||
manifest = {
|
||||
"apiVersion": "serving.kserve.io/v1beta1",
|
||||
@@ -409,10 +406,10 @@ def generate_kserve_manifest(
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Configure predictor based on runtime
|
||||
predictor = manifest["spec"]["predictor"]
|
||||
|
||||
|
||||
if config.container_image:
|
||||
# Custom container
|
||||
predictor["containers"] = [{
|
||||
@@ -434,16 +431,16 @@ def generate_kserve_manifest(
|
||||
for k, v in config.env_vars.items()
|
||||
],
|
||||
}]
|
||||
|
||||
|
||||
# Add GPU if needed
|
||||
if config.gpu_count > 0:
|
||||
predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
|
||||
predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
|
||||
|
||||
|
||||
else:
|
||||
# Standard KServe runtime
|
||||
storage_uri = config.storage_uri or model_version.source
|
||||
|
||||
|
||||
predictor["model"] = {
|
||||
"modelFormat": {"name": "huggingface"},
|
||||
"protocolVersion": config.protocol,
|
||||
@@ -459,11 +456,11 @@ def generate_kserve_manifest(
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if config.gpu_count > 0:
|
||||
predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
|
||||
predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
|
||||
|
||||
|
||||
return manifest
|
||||
|
||||
|
||||
@@ -476,14 +473,14 @@ def generate_kserve_yaml(
|
||||
) -> str:
|
||||
"""
|
||||
Generate a KServe InferenceService manifest as YAML.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
version: Version number (uses Production if not specified)
|
||||
namespace: Kubernetes namespace
|
||||
output_path: If provided, write YAML to this path
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
YAML string of the manifest
|
||||
"""
|
||||
@@ -493,14 +490,14 @@ def generate_kserve_yaml(
|
||||
namespace=namespace,
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
|
||||
|
||||
yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
|
||||
|
||||
|
||||
if output_path:
|
||||
with open(output_path, 'w') as f:
|
||||
f.write(yaml_str)
|
||||
logger.info(f"Wrote KServe manifest to {output_path}")
|
||||
|
||||
|
||||
return yaml_str
|
||||
|
||||
|
||||
@@ -511,17 +508,17 @@ def list_model_versions(
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all versions of a registered model.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the registered model
|
||||
stages: Filter by stages (None for all)
|
||||
tracking_uri: Override default tracking URI
|
||||
|
||||
|
||||
Returns:
|
||||
List of model version info dictionaries
|
||||
"""
|
||||
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||
|
||||
|
||||
if stages:
|
||||
versions = client.get_latest_versions(model_name, stages=stages)
|
||||
else:
|
||||
@@ -529,7 +526,7 @@ def list_model_versions(
|
||||
versions = []
|
||||
for mv in client.search_model_versions(f"name='{model_name}'"):
|
||||
versions.append(mv)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"version": mv.version,
|
||||
|
||||
@@ -5,19 +5,17 @@ Provides a high-level interface for logging experiments, parameters,
|
||||
metrics, and artifacts from Kubeflow Pipeline components.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||
from .client import MLflowConfig, ensure_experiment, get_mlflow_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,7 +28,7 @@ class PipelineMetadata:
|
||||
run_name: Optional[str] = None
|
||||
component_name: Optional[str] = None
|
||||
namespace: str = "ai-ml"
|
||||
|
||||
|
||||
# KFP-specific metadata (populated from environment if available)
|
||||
kfp_run_id: Optional[str] = field(
|
||||
default_factory=lambda: os.environ.get("KFP_RUN_ID")
|
||||
@@ -38,7 +36,7 @@ class PipelineMetadata:
|
||||
kfp_pod_name: Optional[str] = field(
|
||||
default_factory=lambda: os.environ.get("KFP_POD_NAME")
|
||||
)
|
||||
|
||||
|
||||
def as_tags(self) -> Dict[str, str]:
|
||||
"""Convert metadata to MLflow tags."""
|
||||
tags = {
|
||||
@@ -60,34 +58,34 @@ class PipelineMetadata:
|
||||
class MLflowTracker:
|
||||
"""
|
||||
MLflow experiment tracker for Kubeflow Pipeline components.
|
||||
|
||||
|
||||
Example usage in a KFP component:
|
||||
|
||||
|
||||
from mlflow_utils import MLflowTracker
|
||||
|
||||
|
||||
tracker = MLflowTracker(
|
||||
experiment_name="document-ingestion",
|
||||
run_name="batch-ingestion-2024-01"
|
||||
)
|
||||
|
||||
|
||||
with tracker.start_run() as run:
|
||||
tracker.log_params({
|
||||
"chunk_size": 500,
|
||||
"overlap": 50,
|
||||
"embeddings_model": "bge-small-en-v1.5"
|
||||
})
|
||||
|
||||
|
||||
# ... do work ...
|
||||
|
||||
|
||||
tracker.log_metrics({
|
||||
"documents_processed": 100,
|
||||
"chunks_created": 2500,
|
||||
"processing_time_seconds": 120.5
|
||||
})
|
||||
|
||||
|
||||
tracker.log_artifact("/path/to/output.json")
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
@@ -98,7 +96,7 @@ class MLflowTracker:
|
||||
):
|
||||
"""
|
||||
Initialize the MLflow tracker.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment
|
||||
run_name: Optional name for this run
|
||||
@@ -112,22 +110,22 @@ class MLflowTracker:
|
||||
self.pipeline_metadata = pipeline_metadata
|
||||
self.user_tags = tags or {}
|
||||
self.tracking_uri = tracking_uri
|
||||
|
||||
|
||||
self.client: Optional[MlflowClient] = None
|
||||
self.run: Optional[mlflow.ActiveRun] = None
|
||||
self.run_id: Optional[str] = None
|
||||
self._start_time: Optional[float] = None
|
||||
|
||||
|
||||
def _get_all_tags(self) -> Dict[str, str]:
|
||||
"""Combine all tags for the run."""
|
||||
tags = self.config.default_tags.copy()
|
||||
|
||||
|
||||
if self.pipeline_metadata:
|
||||
tags.update(self.pipeline_metadata.as_tags())
|
||||
|
||||
|
||||
tags.update(self.user_tags)
|
||||
return tags
|
||||
|
||||
|
||||
@contextmanager
|
||||
def start_run(
|
||||
self,
|
||||
@@ -136,11 +134,11 @@ class MLflowTracker:
|
||||
):
|
||||
"""
|
||||
Start an MLflow run as a context manager.
|
||||
|
||||
|
||||
Args:
|
||||
nested: If True, create a nested run under the current active run
|
||||
parent_run_id: Explicit parent run ID for nested runs
|
||||
|
||||
|
||||
Yields:
|
||||
The MLflow run object
|
||||
"""
|
||||
@@ -148,12 +146,12 @@ class MLflowTracker:
|
||||
tracking_uri=self.tracking_uri,
|
||||
configure_global=True
|
||||
)
|
||||
|
||||
|
||||
# Ensure experiment exists
|
||||
experiment_id = ensure_experiment(self.experiment_name)
|
||||
|
||||
|
||||
self._start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Start the run
|
||||
self.run = mlflow.start_run(
|
||||
@@ -163,14 +161,14 @@ class MLflowTracker:
|
||||
tags=self._get_all_tags(),
|
||||
)
|
||||
self.run_id = self.run.info.run_id
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Started MLflow run '{self.run_name}' "
|
||||
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
|
||||
)
|
||||
|
||||
|
||||
yield self.run
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MLflow run failed: {e}")
|
||||
if self.run:
|
||||
@@ -185,22 +183,22 @@ class MLflowTracker:
|
||||
mlflow.log_metric("run_duration_seconds", duration)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# End the run
|
||||
mlflow.end_run()
|
||||
logger.info(f"Ended MLflow run '{self.run_name}'")
|
||||
|
||||
|
||||
def log_params(self, params: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log parameters to the current run.
|
||||
|
||||
|
||||
Args:
|
||||
params: Dictionary of parameter names to values
|
||||
"""
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_params")
|
||||
return
|
||||
|
||||
|
||||
# MLflow has limits on param values, truncate if needed
|
||||
cleaned_params = {}
|
||||
for key, value in params.items():
|
||||
@@ -208,14 +206,14 @@ class MLflowTracker:
|
||||
if len(str_value) > 500:
|
||||
str_value = str_value[:497] + "..."
|
||||
cleaned_params[key] = str_value
|
||||
|
||||
|
||||
mlflow.log_params(cleaned_params)
|
||||
logger.debug(f"Logged {len(params)} parameters")
|
||||
|
||||
|
||||
def log_param(self, key: str, value: Any) -> None:
|
||||
"""Log a single parameter."""
|
||||
self.log_params({key: value})
|
||||
|
||||
|
||||
def log_metrics(
|
||||
self,
|
||||
metrics: Dict[str, Union[float, int]],
|
||||
@@ -223,7 +221,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""
|
||||
Log metrics to the current run.
|
||||
|
||||
|
||||
Args:
|
||||
metrics: Dictionary of metric names to values
|
||||
step: Optional step number for time-series metrics
|
||||
@@ -231,10 +229,10 @@ class MLflowTracker:
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_metrics")
|
||||
return
|
||||
|
||||
|
||||
mlflow.log_metrics(metrics, step=step)
|
||||
logger.debug(f"Logged {len(metrics)} metrics")
|
||||
|
||||
|
||||
def log_metric(
|
||||
self,
|
||||
key: str,
|
||||
@@ -243,7 +241,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""Log a single metric."""
|
||||
self.log_metrics({key: value}, step=step)
|
||||
|
||||
|
||||
def log_artifact(
|
||||
self,
|
||||
local_path: str,
|
||||
@@ -251,7 +249,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""
|
||||
Log an artifact file to the current run.
|
||||
|
||||
|
||||
Args:
|
||||
local_path: Path to the local file to log
|
||||
artifact_path: Optional destination path within the artifact store
|
||||
@@ -259,10 +257,10 @@ class MLflowTracker:
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_artifact")
|
||||
return
|
||||
|
||||
|
||||
mlflow.log_artifact(local_path, artifact_path)
|
||||
logger.info(f"Logged artifact: {local_path}")
|
||||
|
||||
|
||||
def log_artifacts(
|
||||
self,
|
||||
local_dir: str,
|
||||
@@ -270,7 +268,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""
|
||||
Log all files in a directory as artifacts.
|
||||
|
||||
|
||||
Args:
|
||||
local_dir: Path to the local directory
|
||||
artifact_path: Optional destination path within the artifact store
|
||||
@@ -278,10 +276,10 @@ class MLflowTracker:
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_artifacts")
|
||||
return
|
||||
|
||||
|
||||
mlflow.log_artifacts(local_dir, artifact_path)
|
||||
logger.info(f"Logged artifacts from: {local_dir}")
|
||||
|
||||
|
||||
def log_dict(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
@@ -290,7 +288,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""
|
||||
Log a dictionary as a JSON artifact.
|
||||
|
||||
|
||||
Args:
|
||||
data: Dictionary to log
|
||||
filename: Name for the JSON file
|
||||
@@ -299,14 +297,14 @@ class MLflowTracker:
|
||||
if not self.run:
|
||||
logger.warning("No active run, skipping log_dict")
|
||||
return
|
||||
|
||||
|
||||
# Ensure .json extension
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
|
||||
|
||||
mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
|
||||
logger.debug(f"Logged dict as: {filename}")
|
||||
|
||||
|
||||
def log_model_info(
|
||||
self,
|
||||
model_type: str,
|
||||
@@ -317,7 +315,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""
|
||||
Log model information as parameters and tags.
|
||||
|
||||
|
||||
Args:
|
||||
model_type: Type of model (e.g., "llm", "embedding", "stt")
|
||||
model_name: Name/identifier of the model
|
||||
@@ -335,13 +333,13 @@ class MLflowTracker:
|
||||
if extra_info:
|
||||
for key, value in extra_info.items():
|
||||
params[f"model.{key}"] = value
|
||||
|
||||
|
||||
self.log_params(params)
|
||||
|
||||
|
||||
# Also set as tags for easier filtering
|
||||
mlflow.set_tag("model.type", model_type)
|
||||
mlflow.set_tag("model.name", model_name)
|
||||
|
||||
|
||||
def log_dataset_info(
|
||||
self,
|
||||
name: str,
|
||||
@@ -351,7 +349,7 @@ class MLflowTracker:
|
||||
) -> None:
|
||||
"""
|
||||
Log dataset information.
|
||||
|
||||
|
||||
Args:
|
||||
name: Dataset name
|
||||
source: Dataset source (URL, path, etc.)
|
||||
@@ -367,26 +365,26 @@ class MLflowTracker:
|
||||
if extra_info:
|
||||
for key, value in extra_info.items():
|
||||
params[f"dataset.{key}"] = value
|
||||
|
||||
|
||||
self.log_params(params)
|
||||
|
||||
|
||||
def set_tag(self, key: str, value: str) -> None:
|
||||
"""Set a single tag on the run."""
|
||||
if self.run:
|
||||
mlflow.set_tag(key, value)
|
||||
|
||||
|
||||
def set_tags(self, tags: Dict[str, str]) -> None:
|
||||
"""Set multiple tags on the run."""
|
||||
if self.run:
|
||||
mlflow.set_tags(tags)
|
||||
|
||||
|
||||
@property
|
||||
def artifact_uri(self) -> Optional[str]:
|
||||
"""Get the artifact URI for the current run."""
|
||||
if self.run:
|
||||
return self.run.info.artifact_uri
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def experiment_id(self) -> Optional[str]:
|
||||
"""Get the experiment ID for the current run."""
|
||||
|
||||
Reference in New Issue
Block a user