fix: resolve all ruff lint errors
Some checks failed
CI / Test (push) Successful in 1m46s
CI / Lint (push) Failing after 1m49s
CI / Publish (push) Has been skipped
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions

View File

@@ -20,17 +20,17 @@ Usage:
""" """
from .client import ( from .client import (
MLflowConfig,
ensure_experiment,
get_mlflow_client, get_mlflow_client,
get_tracking_uri, get_tracking_uri,
ensure_experiment,
MLflowConfig,
) )
from .tracker import MLflowTracker
from .inference_tracker import InferenceMetricsTracker from .inference_tracker import InferenceMetricsTracker
from .tracker import MLflowTracker
__all__ = [ __all__ = [
"get_mlflow_client", "get_mlflow_client",
"get_tracking_uri", "get_tracking_uri",
"ensure_experiment", "ensure_experiment",
"MLflowConfig", "MLflowConfig",
"MLflowTracker", "MLflowTracker",

View File

@@ -7,21 +7,21 @@ Command-line interface for querying and comparing MLflow experiments.
Usage: Usage:
# Compare recent runs in an experiment # Compare recent runs in an experiment
python -m mlflow_utils.cli compare --experiment chat-inference --runs 5 python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
# Get best run by metric # Get best run by metric
python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
# Generate performance report # Generate performance report
python -m mlflow_utils.cli report --service chat-handler --hours 24 python -m mlflow_utils.cli report --service chat-handler --hours 24
# Check model promotion criteria # Check model promotion criteria
python -m mlflow_utils.cli promote --model whisper-finetuned \\ python -m mlflow_utils.cli promote --model whisper-finetuned \\
--experiment voice-evaluation \\ --experiment voice-evaluation \\
--criteria "eval.accuracy>=0.9,total_latency_p95<=2.0" --criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
# List experiments # List experiments
python -m mlflow_utils.cli list-experiments python -m mlflow_utils.cli list-experiments
# Query runs # Query runs
python -m mlflow_utils.cli query --experiment chat-inference \\ python -m mlflow_utils.cli query --experiment chat-inference \\
--filter "metrics.total_latency_mean < 1.0" --limit 10 --filter "metrics.total_latency_mean < 1.0" --limit 10
@@ -30,19 +30,16 @@ Usage:
import argparse import argparse
import json import json
import sys import sys
from typing import Optional
from .client import get_mlflow_client, health_check from .client import get_mlflow_client, health_check
from .experiment_comparison import ( from .experiment_comparison import (
ExperimentAnalyzer, ExperimentAnalyzer,
compare_experiments,
promotion_recommendation,
get_inference_performance_report, get_inference_performance_report,
promotion_recommendation,
) )
from .model_registry import ( from .model_registry import (
list_model_versions,
get_production_model,
generate_kserve_yaml, generate_kserve_yaml,
list_model_versions,
) )
@@ -57,7 +54,7 @@ def cmd_list_experiments(args):
"""List all experiments.""" """List all experiments."""
client = get_mlflow_client(tracking_uri=args.tracking_uri) client = get_mlflow_client(tracking_uri=args.tracking_uri)
experiments = client.search_experiments() experiments = client.search_experiments()
print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}") print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
print("-" * 80) print("-" * 80)
for exp in experiments: for exp in experiments:
@@ -70,13 +67,13 @@ def cmd_compare(args):
args.experiment, args.experiment,
tracking_uri=args.tracking_uri tracking_uri=args.tracking_uri
) )
if args.run_ids: if args.run_ids:
run_ids = args.run_ids.split(",") run_ids = args.run_ids.split(",")
comparison = analyzer.compare_runs(run_ids=run_ids) comparison = analyzer.compare_runs(run_ids=run_ids)
else: else:
comparison = analyzer.compare_runs(n_recent=args.runs) comparison = analyzer.compare_runs(n_recent=args.runs)
if args.json: if args.json:
print(json.dumps(comparison.to_dict(), indent=2, default=str)) print(json.dumps(comparison.to_dict(), indent=2, default=str))
else: else:
@@ -89,17 +86,17 @@ def cmd_best(args):
args.experiment, args.experiment,
tracking_uri=args.tracking_uri tracking_uri=args.tracking_uri
) )
best_run = analyzer.get_best_run( best_run = analyzer.get_best_run(
metric=args.metric, metric=args.metric,
minimize=args.minimize, minimize=args.minimize,
filter_string=args.filter or "", filter_string=args.filter or "",
) )
if not best_run: if not best_run:
print(f"No runs found with metric '{args.metric}'") print(f"No runs found with metric '{args.metric}'")
sys.exit(1) sys.exit(1)
result = { result = {
"run_id": best_run.info.run_id, "run_id": best_run.info.run_id,
"run_name": best_run.info.run_name, "run_name": best_run.info.run_name,
@@ -107,7 +104,7 @@ def cmd_best(args):
"all_metrics": dict(best_run.data.metrics), "all_metrics": dict(best_run.data.metrics),
"params": dict(best_run.data.params), "params": dict(best_run.data.params),
} }
if args.json: if args.json:
print(json.dumps(result, indent=2)) print(json.dumps(result, indent=2))
else: else:
@@ -122,12 +119,12 @@ def cmd_summary(args):
args.experiment, args.experiment,
tracking_uri=args.tracking_uri tracking_uri=args.tracking_uri
) )
summary = analyzer.get_metrics_summary( summary = analyzer.get_metrics_summary(
hours=args.hours, hours=args.hours,
metrics=args.metrics.split(",") if args.metrics else None, metrics=args.metrics.split(",") if args.metrics else None,
) )
if args.json: if args.json:
print(json.dumps(summary, indent=2)) print(json.dumps(summary, indent=2))
else: else:
@@ -146,7 +143,7 @@ def cmd_report(args):
hours=args.hours, hours=args.hours,
tracking_uri=args.tracking_uri, tracking_uri=args.tracking_uri,
) )
if args.json: if args.json:
print(json.dumps(report, indent=2)) print(json.dumps(report, indent=2))
else: else:
@@ -154,18 +151,18 @@ def cmd_report(args):
print(f"Period: Last {report['period_hours']} hours") print(f"Period: Last {report['period_hours']} hours")
print(f"Generated: {report['generated_at']}") print(f"Generated: {report['generated_at']}")
print() print()
if report["latency"]: if report["latency"]:
print("Latency Metrics:") print("Latency Metrics:")
for metric, stats in report["latency"].items(): for metric, stats in report["latency"].items():
if "mean" in stats: if "mean" in stats:
print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})") print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
if report["rag"]: if report["rag"]:
print("\nRAG Usage:") print("\nRAG Usage:")
for metric, stats in report["rag"].items(): for metric, stats in report["rag"].items():
print(f" {metric}: {stats.get('mean', 'N/A')}") print(f" {metric}: {stats.get('mean', 'N/A')}")
if report["errors"]: if report["errors"]:
print("\nError Rates:") print("\nError Rates:")
for metric, stats in report["errors"].items(): for metric, stats in report["errors"].items():
@@ -183,14 +180,14 @@ def cmd_promote(args):
metric, value = criterion.split(op) metric, value = criterion.split(op)
criteria[metric.strip()] = (op, float(value.strip())) criteria[metric.strip()] = (op, float(value.strip()))
break break
rec = promotion_recommendation( rec = promotion_recommendation(
model_name=args.model, model_name=args.model,
experiment_name=args.experiment, experiment_name=args.experiment,
criteria=criteria, criteria=criteria,
tracking_uri=args.tracking_uri, tracking_uri=args.tracking_uri,
) )
if args.json: if args.json:
print(json.dumps(rec.to_dict(), indent=2)) print(json.dumps(rec.to_dict(), indent=2))
else: else:
@@ -208,12 +205,12 @@ def cmd_query(args):
args.experiment, args.experiment,
tracking_uri=args.tracking_uri tracking_uri=args.tracking_uri
) )
runs = analyzer.search_runs( runs = analyzer.search_runs(
filter_string=args.filter or "", filter_string=args.filter or "",
max_results=args.limit, max_results=args.limit,
) )
if args.json: if args.json:
result = [ result = [
{ {
@@ -237,20 +234,21 @@ def cmd_query(args):
def cmd_models(args): def cmd_models(args):
"""List registered models.""" """List registered models."""
client = get_mlflow_client(tracking_uri=args.tracking_uri) client = get_mlflow_client(tracking_uri=args.tracking_uri)
if args.model: if args.model:
versions = list_model_versions(args.model, tracking_uri=args.tracking_uri) versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
if args.json: if args.json:
print(json.dumps(versions, indent=2, default=str)) print(json.dumps(versions, indent=2, default=str))
else: else:
print(f"Model: {args.model}") print(f"Model: {args.model}")
for v in versions: for v in versions:
print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}") desc = v["description"][:50] if v["description"] else "No description"
print(f" v{v['version']} ({v['stage']}): {desc}")
else: else:
# List all models # List all models
models = client.search_registered_models() models = client.search_registered_models()
if args.json: if args.json:
result = [{"name": m.name, "description": m.description} for m in models] result = [{"name": m.name, "description": m.description} for m in models]
print(json.dumps(result, indent=2)) print(json.dumps(result, indent=2))
@@ -271,7 +269,7 @@ def cmd_kserve(args):
output_path=args.output, output_path=args.output,
tracking_uri=args.tracking_uri, tracking_uri=args.tracking_uri,
) )
if not args.output: if not args.output:
print(yaml_str) print(yaml_str)
@@ -281,7 +279,7 @@ def main():
description="MLflow Experiment CLI", description="MLflow Experiment CLI",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
parser.add_argument( parser.add_argument(
"--tracking-uri", "--tracking-uri",
default=None, default=None,
@@ -292,24 +290,24 @@ def main():
action="store_true", action="store_true",
help="Output as JSON", help="Output as JSON",
) )
subparsers = parser.add_subparsers(dest="command", help="Commands") subparsers = parser.add_subparsers(dest="command", help="Commands")
# health # health
health_parser = subparsers.add_parser("health", help="Check MLflow connectivity") health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
health_parser.set_defaults(func=cmd_health) health_parser.set_defaults(func=cmd_health)
# list-experiments # list-experiments
list_parser = subparsers.add_parser("list-experiments", help="List experiments") list_parser = subparsers.add_parser("list-experiments", help="List experiments")
list_parser.set_defaults(func=cmd_list_experiments) list_parser.set_defaults(func=cmd_list_experiments)
# compare # compare
compare_parser = subparsers.add_parser("compare", help="Compare runs") compare_parser = subparsers.add_parser("compare", help="Compare runs")
compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs") compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare") compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
compare_parser.set_defaults(func=cmd_compare) compare_parser.set_defaults(func=cmd_compare)
# best # best
best_parser = subparsers.add_parser("best", help="Find best run by metric") best_parser = subparsers.add_parser("best", help="Find best run by metric")
best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
@@ -317,39 +315,39 @@ def main():
best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)") best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
best_parser.add_argument("--filter", "-f", help="Filter string") best_parser.add_argument("--filter", "-f", help="Filter string")
best_parser.set_defaults(func=cmd_best) best_parser.set_defaults(func=cmd_best)
# summary # summary
summary_parser = subparsers.add_parser("summary", help="Get metrics summary") summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data") summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
summary_parser.add_argument("--metrics", help="Comma-separated metric names") summary_parser.add_argument("--metrics", help="Comma-separated metric names")
summary_parser.set_defaults(func=cmd_summary) summary_parser.set_defaults(func=cmd_summary)
# report # report
report_parser = subparsers.add_parser("report", help="Generate performance report") report_parser = subparsers.add_parser("report", help="Generate performance report")
report_parser.add_argument("--service", "-s", required=True, help="Service name") report_parser.add_argument("--service", "-s", required=True, help="Service name")
report_parser.add_argument("--hours", type=int, default=24, help="Hours of data") report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
report_parser.set_defaults(func=cmd_report) report_parser.set_defaults(func=cmd_report)
# promote # promote
promote_parser = subparsers.add_parser("promote", help="Check promotion criteria") promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
promote_parser.add_argument("--model", "-m", required=True, help="Model name") promote_parser.add_argument("--model", "-m", required=True, help="Model name")
promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')") promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
promote_parser.set_defaults(func=cmd_promote) promote_parser.set_defaults(func=cmd_promote)
# query # query
query_parser = subparsers.add_parser("query", help="Query runs") query_parser = subparsers.add_parser("query", help="Query runs")
query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
query_parser.add_argument("--filter", "-f", help="MLflow filter string") query_parser.add_argument("--filter", "-f", help="MLflow filter string")
query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results") query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
query_parser.set_defaults(func=cmd_query) query_parser.set_defaults(func=cmd_query)
# models # models
models_parser = subparsers.add_parser("models", help="List registered models") models_parser = subparsers.add_parser("models", help="List registered models")
models_parser.add_argument("--model", "-m", help="Specific model name") models_parser.add_argument("--model", "-m", help="Specific model name")
models_parser.set_defaults(func=cmd_models) models_parser.set_defaults(func=cmd_models)
# kserve # kserve
kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest") kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
kserve_parser.add_argument("--model", "-m", required=True, help="Model name") kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
@@ -357,13 +355,13 @@ def main():
kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace") kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
kserve_parser.add_argument("--output", "-o", help="Output file path") kserve_parser.add_argument("--output", "-o", help="Output file path")
kserve_parser.set_defaults(func=cmd_kserve) kserve_parser.set_defaults(func=cmd_kserve)
args = parser.parse_args() args = parser.parse_args()
if not args.command: if not args.command:
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
args.func(args) args.func(args)

View File

@@ -5,10 +5,10 @@ Provides a configured MLflow client for all integrations in the LLM workflows.
Supports both in-cluster and external access patterns. Supports both in-cluster and external access patterns.
""" """
import os
import logging import logging
import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class MLflowConfig: class MLflowConfig:
"""Configuration for MLflow integration.""" """Configuration for MLflow integration."""
# Tracking server URIs # Tracking server URIs
tracking_uri: str = field( tracking_uri: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get(
@@ -33,7 +33,7 @@ class MLflowConfig:
"https://mlflow.lab.daviestechlabs.io" "https://mlflow.lab.daviestechlabs.io"
) )
) )
# Artifact storage (NFS PVC mount) # Artifact storage (NFS PVC mount)
artifact_location: str = field( artifact_location: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get(
@@ -41,7 +41,7 @@ class MLflowConfig:
"/mlflow/artifacts" "/mlflow/artifacts"
) )
) )
# Default experiment settings # Default experiment settings
default_experiment: str = field( default_experiment: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get(
@@ -49,7 +49,7 @@ class MLflowConfig:
"llm-workflows" "llm-workflows"
) )
) )
# Service identification # Service identification
service_name: str = field( service_name: str = field(
default_factory=lambda: os.environ.get( default_factory=lambda: os.environ.get(
@@ -57,10 +57,10 @@ class MLflowConfig:
"unknown-service" "unknown-service"
) )
) )
# Additional tags to add to all runs # Additional tags to add to all runs
default_tags: Dict[str, str] = field(default_factory=dict) default_tags: Dict[str, str] = field(default_factory=dict)
def __post_init__(self): def __post_init__(self):
"""Add default tags based on environment.""" """Add default tags based on environment."""
env_tags = { env_tags = {
@@ -74,10 +74,10 @@ class MLflowConfig:
def get_tracking_uri(external: bool = False) -> str: def get_tracking_uri(external: bool = False) -> str:
""" """
Get the appropriate MLflow tracking URI. Get the appropriate MLflow tracking URI.
Args: Args:
external: If True, return the external URI for outside-cluster access external: If True, return the external URI for outside-cluster access
Returns: Returns:
The MLflow tracking URI string The MLflow tracking URI string
""" """
@@ -91,20 +91,20 @@ def get_mlflow_client(
) -> MlflowClient: ) -> MlflowClient:
""" """
Get a configured MLflow client. Get a configured MLflow client.
Args: Args:
tracking_uri: Override the default tracking URI tracking_uri: Override the default tracking URI
configure_global: If True, also set mlflow.set_tracking_uri() configure_global: If True, also set mlflow.set_tracking_uri()
Returns: Returns:
Configured MlflowClient instance Configured MlflowClient instance
""" """
uri = tracking_uri or get_tracking_uri() uri = tracking_uri or get_tracking_uri()
if configure_global: if configure_global:
mlflow.set_tracking_uri(uri) mlflow.set_tracking_uri(uri)
logger.info(f"MLflow tracking URI set to: {uri}") logger.info(f"MLflow tracking URI set to: {uri}")
client = MlflowClient(tracking_uri=uri) client = MlflowClient(tracking_uri=uri)
return client return client
@@ -116,21 +116,21 @@ def ensure_experiment(
) -> str: ) -> str:
""" """
Ensure an experiment exists, creating it if necessary. Ensure an experiment exists, creating it if necessary.
Args: Args:
experiment_name: Name of the experiment experiment_name: Name of the experiment
artifact_location: Override default artifact location artifact_location: Override default artifact location
tags: Additional tags for the experiment tags: Additional tags for the experiment
Returns: Returns:
The experiment ID The experiment ID
""" """
config = MLflowConfig() config = MLflowConfig()
client = get_mlflow_client() client = get_mlflow_client()
# Check if experiment exists # Check if experiment exists
experiment = client.get_experiment_by_name(experiment_name) experiment = client.get_experiment_by_name(experiment_name)
if experiment is None: if experiment is None:
# Create the experiment # Create the experiment
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}" artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
@@ -143,7 +143,7 @@ def ensure_experiment(
else: else:
experiment_id = experiment.experiment_id experiment_id = experiment.experiment_id
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}") logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
return experiment_id return experiment_id
@@ -154,17 +154,17 @@ def get_or_create_registered_model(
) -> str: ) -> str:
""" """
Get or create a registered model in the Model Registry. Get or create a registered model in the Model Registry.
Args: Args:
model_name: Name of the model to register model_name: Name of the model to register
description: Model description description: Model description
tags: Tags for the model tags: Tags for the model
Returns: Returns:
The registered model name The registered model name
""" """
client = get_mlflow_client() client = get_mlflow_client()
try: try:
# Check if model exists # Check if model exists
client.get_registered_model(model_name) client.get_registered_model(model_name)
@@ -177,14 +177,14 @@ def get_or_create_registered_model(
tags=tags or {} tags=tags or {}
) )
logger.info(f"Created registered model: {model_name}") logger.info(f"Created registered model: {model_name}")
return model_name return model_name
def health_check() -> Dict[str, Any]: def health_check() -> Dict[str, Any]:
""" """
Check MLflow server connectivity and return status. Check MLflow server connectivity and return status.
Returns: Returns:
Dictionary with health status information Dictionary with health status information
""" """
@@ -195,7 +195,7 @@ def health_check() -> Dict[str, Any]:
"connected": False, "connected": False,
"error": None, "error": None,
} }
try: try:
client = get_mlflow_client(configure_global=False) client = get_mlflow_client(configure_global=False)
# Try to list experiments as a health check # Try to list experiments as a health check
@@ -205,5 +205,5 @@ def health_check() -> Dict[str, Any]:
except Exception as e: except Exception as e:
result["error"] = str(e) result["error"] = str(e)
logger.error(f"MLflow health check failed: {e}") logger.error(f"MLflow health check failed: {e}")
return result return result

View File

@@ -18,15 +18,15 @@ Usage:
get_best_run, get_best_run,
promotion_recommendation, promotion_recommendation,
) )
analyzer = ExperimentAnalyzer("chat-inference") analyzer = ExperimentAnalyzer("chat-inference")
# Compare last N runs # Compare last N runs
comparison = analyzer.compare_recent_runs(n=5) comparison = analyzer.compare_recent_runs(n=5)
# Find best performing model # Find best performing model
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True) best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
# Get promotion recommendation # Get promotion recommendation
rec = analyzer.promotion_recommendation( rec = analyzer.promotion_recommendation(
model_name="whisper-finetuned", model_name="whisper-finetuned",
@@ -35,19 +35,15 @@ Usage:
) )
""" """
import os
import json
import logging import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List, Tuple, Union
from dataclasses import dataclass, field
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
import mlflow from mlflow.entities import Experiment, Run
from mlflow.tracking import MlflowClient
from mlflow.entities import Run, Experiment
from .client import get_mlflow_client, MLflowConfig from .client import get_mlflow_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -57,21 +53,21 @@ class RunComparison:
"""Comparison result for multiple MLflow runs.""" """Comparison result for multiple MLflow runs."""
run_ids: List[str] run_ids: List[str]
experiment_name: str experiment_name: str
# Metric comparisons (metric_name -> {run_id -> value}) # Metric comparisons (metric_name -> {run_id -> value})
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict) metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
# Parameter differences # Parameter differences
params: Dict[str, Dict[str, str]] = field(default_factory=dict) params: Dict[str, Dict[str, str]] = field(default_factory=dict)
# Run metadata # Run metadata
run_names: Dict[str, str] = field(default_factory=dict) run_names: Dict[str, str] = field(default_factory=dict)
start_times: Dict[str, datetime] = field(default_factory=dict) start_times: Dict[str, datetime] = field(default_factory=dict)
durations: Dict[str, float] = field(default_factory=dict) durations: Dict[str, float] = field(default_factory=dict)
# Best performers by metric # Best performers by metric
best_by_metric: Dict[str, str] = field(default_factory=dict) best_by_metric: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization.""" """Convert to dictionary for serialization."""
return { return {
@@ -82,22 +78,22 @@ class RunComparison:
"run_names": self.run_names, "run_names": self.run_names,
"best_by_metric": self.best_by_metric, "best_by_metric": self.best_by_metric,
} }
def summary_table(self) -> str: def summary_table(self) -> str:
"""Generate a text summary table of the comparison.""" """Generate a text summary table of the comparison."""
if not self.run_ids: if not self.run_ids:
return "No runs to compare" return "No runs to compare"
lines = [] lines = []
lines.append(f"Experiment: {self.experiment_name}") lines.append(f"Experiment: {self.experiment_name}")
lines.append(f"Comparing {len(self.run_ids)} runs") lines.append(f"Comparing {len(self.run_ids)} runs")
lines.append("") lines.append("")
# Header # Header
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids] header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
lines.append(" | ".join(header)) lines.append(" | ".join(header))
lines.append("-" * (len(lines[-1]) + 10)) lines.append("-" * (len(lines[-1]) + 10))
# Metrics # Metrics
for metric_name, values in sorted(self.metrics.items()): for metric_name, values in sorted(self.metrics.items()):
row = [metric_name] row = [metric_name]
@@ -108,7 +104,7 @@ class RunComparison:
else: else:
row.append("N/A") row.append("N/A")
lines.append(" | ".join(row)) lines.append(" | ".join(row))
return "\n".join(lines) return "\n".join(lines)
@@ -121,7 +117,7 @@ class PromotionRecommendation:
reasons: List[str] reasons: List[str]
metrics_summary: Dict[str, float] metrics_summary: Dict[str, float]
comparison_with_production: Optional[Dict[str, Any]] = None comparison_with_production: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"model_name": self.model_name, "model_name": self.model_name,
@@ -136,20 +132,20 @@ class PromotionRecommendation:
class ExperimentAnalyzer: class ExperimentAnalyzer:
""" """
Analyze MLflow experiments for model comparison and promotion decisions. Analyze MLflow experiments for model comparison and promotion decisions.
Example: Example:
analyzer = ExperimentAnalyzer("chat-inference") analyzer = ExperimentAnalyzer("chat-inference")
# Get metrics summary for last 24 hours # Get metrics summary for last 24 hours
summary = analyzer.get_metrics_summary(hours=24) summary = analyzer.get_metrics_summary(hours=24)
# Compare models by accuracy # Compare models by accuracy
best = analyzer.get_best_run(metric="eval.accuracy", minimize=False) best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
# Analyze inference latency trends # Analyze inference latency trends
trends = analyzer.get_metric_trends("total_latency_mean", days=7) trends = analyzer.get_metric_trends("total_latency_mean", days=7)
""" """
def __init__( def __init__(
self, self,
experiment_name: str, experiment_name: str,
@@ -157,7 +153,7 @@ class ExperimentAnalyzer:
): ):
""" """
Initialize the experiment analyzer. Initialize the experiment analyzer.
Args: Args:
experiment_name: Name of the MLflow experiment to analyze experiment_name: Name of the MLflow experiment to analyze
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
@@ -166,14 +162,14 @@ class ExperimentAnalyzer:
self.tracking_uri = tracking_uri self.tracking_uri = tracking_uri
self.client = get_mlflow_client(tracking_uri=tracking_uri) self.client = get_mlflow_client(tracking_uri=tracking_uri)
self._experiment: Optional[Experiment] = None self._experiment: Optional[Experiment] = None
@property @property
def experiment(self) -> Optional[Experiment]: def experiment(self) -> Optional[Experiment]:
"""Get the experiment object, fetching if needed.""" """Get the experiment object, fetching if needed."""
if self._experiment is None: if self._experiment is None:
self._experiment = self.client.get_experiment_by_name(self.experiment_name) self._experiment = self.client.get_experiment_by_name(self.experiment_name)
return self._experiment return self._experiment
def search_runs( def search_runs(
self, self,
filter_string: str = "", filter_string: str = "",
@@ -183,29 +179,29 @@ class ExperimentAnalyzer:
) -> List[Run]: ) -> List[Run]:
""" """
Search for runs matching criteria. Search for runs matching criteria.
Args: Args:
filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9") filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
order_by: List of order clauses (e.g., ["metrics.accuracy DESC"]) order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
max_results: Maximum runs to return max_results: Maximum runs to return
run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
Returns: Returns:
List of matching Run objects List of matching Run objects
""" """
if not self.experiment: if not self.experiment:
logger.warning(f"Experiment '{self.experiment_name}' not found") logger.warning(f"Experiment '{self.experiment_name}' not found")
return [] return []
runs = self.client.search_runs( runs = self.client.search_runs(
experiment_ids=[self.experiment.experiment_id], experiment_ids=[self.experiment.experiment_id],
filter_string=filter_string, filter_string=filter_string,
order_by=order_by or ["start_time DESC"], order_by=order_by or ["start_time DESC"],
max_results=max_results, max_results=max_results,
) )
return runs return runs
def get_recent_runs( def get_recent_runs(
self, self,
n: int = 10, n: int = 10,
@@ -213,11 +209,11 @@ class ExperimentAnalyzer:
) -> List[Run]: ) -> List[Run]:
""" """
Get the most recent runs. Get the most recent runs.
Args: Args:
n: Number of runs to return n: Number of runs to return
hours: Only include runs from the last N hours hours: Only include runs from the last N hours
Returns: Returns:
List of Run objects List of Run objects
""" """
@@ -226,13 +222,13 @@ class ExperimentAnalyzer:
cutoff = datetime.now() - timedelta(hours=hours) cutoff = datetime.now() - timedelta(hours=hours)
cutoff_ms = int(cutoff.timestamp() * 1000) cutoff_ms = int(cutoff.timestamp() * 1000)
filter_string = f"attributes.start_time >= {cutoff_ms}" filter_string = f"attributes.start_time >= {cutoff_ms}"
return self.search_runs( return self.search_runs(
filter_string=filter_string, filter_string=filter_string,
order_by=["start_time DESC"], order_by=["start_time DESC"],
max_results=n, max_results=n,
) )
def compare_runs( def compare_runs(
self, self,
run_ids: Optional[List[str]] = None, run_ids: Optional[List[str]] = None,
@@ -240,11 +236,11 @@ class ExperimentAnalyzer:
) -> RunComparison: ) -> RunComparison:
""" """
Compare multiple runs side by side. Compare multiple runs side by side.
Args: Args:
run_ids: Specific run IDs to compare, or None for recent runs run_ids: Specific run IDs to compare, or None for recent runs
n_recent: If run_ids is None, compare this many recent runs n_recent: If run_ids is None, compare this many recent runs
Returns: Returns:
RunComparison object with detailed comparison RunComparison object with detailed comparison
""" """
@@ -252,18 +248,18 @@ class ExperimentAnalyzer:
runs = [self.client.get_run(rid) for rid in run_ids] runs = [self.client.get_run(rid) for rid in run_ids]
else: else:
runs = self.get_recent_runs(n=n_recent) runs = self.get_recent_runs(n=n_recent)
comparison = RunComparison( comparison = RunComparison(
run_ids=[r.info.run_id for r in runs], run_ids=[r.info.run_id for r in runs],
experiment_name=self.experiment_name, experiment_name=self.experiment_name,
) )
# Collect all metrics and find best performers # Collect all metrics and find best performers
all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict) all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
for run in runs: for run in runs:
run_id = run.info.run_id run_id = run.info.run_id
# Metadata # Metadata
comparison.run_names[run_id] = run.info.run_name or run_id[:8] comparison.run_names[run_id] = run.info.run_name or run_id[:8]
comparison.start_times[run_id] = datetime.fromtimestamp( comparison.start_times[run_id] = datetime.fromtimestamp(
@@ -273,39 +269,39 @@ class ExperimentAnalyzer:
comparison.durations[run_id] = ( comparison.durations[run_id] = (
run.info.end_time - run.info.start_time run.info.end_time - run.info.start_time
) / 1000 ) / 1000
# Metrics # Metrics
for key, value in run.data.metrics.items(): for key, value in run.data.metrics.items():
all_metrics[key][run_id] = value all_metrics[key][run_id] = value
# Params # Params
for key, value in run.data.params.items(): for key, value in run.data.params.items():
if key not in comparison.params: if key not in comparison.params:
comparison.params[key] = {} comparison.params[key] = {}
comparison.params[key][run_id] = value comparison.params[key][run_id] = value
comparison.metrics = dict(all_metrics) comparison.metrics = dict(all_metrics)
# Find best performers for each metric # Find best performers for each metric
for metric_name, values in all_metrics.items(): for metric_name, values in all_metrics.items():
if not values: if not values:
continue continue
# Determine if lower is better based on metric name # Determine if lower is better based on metric name
minimize = any( minimize = any(
term in metric_name.lower() term in metric_name.lower()
for term in ["latency", "error", "loss", "time"] for term in ["latency", "error", "loss", "time"]
) )
if minimize: if minimize:
best_id = min(values.keys(), key=lambda k: values[k]) best_id = min(values.keys(), key=lambda k: values[k])
else: else:
best_id = max(values.keys(), key=lambda k: values[k]) best_id = max(values.keys(), key=lambda k: values[k])
comparison.best_by_metric[metric_name] = best_id comparison.best_by_metric[metric_name] = best_id
return comparison return comparison
def get_best_run( def get_best_run(
self, self,
metric: str, metric: str,
@@ -315,32 +311,32 @@ class ExperimentAnalyzer:
) -> Optional[Run]: ) -> Optional[Run]:
""" """
Get the best run by a specific metric. Get the best run by a specific metric.
Args: Args:
metric: Metric name to optimize metric: Metric name to optimize
minimize: If True, find minimum; if False, find maximum minimize: If True, find minimum; if False, find maximum
filter_string: Additional filter criteria filter_string: Additional filter criteria
max_results: Maximum runs to consider max_results: Maximum runs to consider
Returns: Returns:
Best Run object, or None if no runs found Best Run object, or None if no runs found
""" """
direction = "ASC" if minimize else "DESC" direction = "ASC" if minimize else "DESC"
runs = self.search_runs( runs = self.search_runs(
filter_string=filter_string, filter_string=filter_string,
order_by=[f"metrics.{metric} {direction}"], order_by=[f"metrics.{metric} {direction}"],
max_results=max_results, max_results=max_results,
) )
# Filter to only runs that have the metric # Filter to only runs that have the metric
runs_with_metric = [ runs_with_metric = [
r for r in runs r for r in runs
if metric in r.data.metrics if metric in r.data.metrics
] ]
return runs_with_metric[0] if runs_with_metric else None return runs_with_metric[0] if runs_with_metric else None
def get_metrics_summary( def get_metrics_summary(
self, self,
hours: Optional[int] = None, hours: Optional[int] = None,
@@ -348,45 +344,45 @@ class ExperimentAnalyzer:
) -> Dict[str, Dict[str, float]]: ) -> Dict[str, Dict[str, float]]:
""" """
Get summary statistics for metrics. Get summary statistics for metrics.
Args: Args:
hours: Only include runs from the last N hours hours: Only include runs from the last N hours
metrics: Specific metrics to summarize (None for all) metrics: Specific metrics to summarize (None for all)
Returns: Returns:
Dict mapping metric names to {mean, min, max, count} Dict mapping metric names to {mean, min, max, count}
""" """
import statistics import statistics
runs = self.get_recent_runs(n=1000, hours=hours) runs = self.get_recent_runs(n=1000, hours=hours)
# Collect all metric values # Collect all metric values
metric_values: Dict[str, List[float]] = defaultdict(list) metric_values: Dict[str, List[float]] = defaultdict(list)
for run in runs: for run in runs:
for key, value in run.data.metrics.items(): for key, value in run.data.metrics.items():
if metrics is None or key in metrics: if metrics is None or key in metrics:
metric_values[key].append(value) metric_values[key].append(value)
# Calculate statistics # Calculate statistics
summary = {} summary = {}
for metric_name, values in metric_values.items(): for metric_name, values in metric_values.items():
if not values: if not values:
continue continue
summary[metric_name] = { summary[metric_name] = {
"mean": statistics.mean(values), "mean": statistics.mean(values),
"min": min(values), "min": min(values),
"max": max(values), "max": max(values),
"count": len(values), "count": len(values),
} }
if len(values) >= 2: if len(values) >= 2:
summary[metric_name]["stdev"] = statistics.stdev(values) summary[metric_name]["stdev"] = statistics.stdev(values)
summary[metric_name]["median"] = statistics.median(values) summary[metric_name]["median"] = statistics.median(values)
return summary return summary
def get_metric_trends( def get_metric_trends(
self, self,
metric: str, metric: str,
@@ -395,30 +391,30 @@ class ExperimentAnalyzer:
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Get metric trends over time. Get metric trends over time.
Args: Args:
metric: Metric name to track metric: Metric name to track
days: Number of days to look back days: Number of days to look back
granularity_hours: Time bucket size in hours granularity_hours: Time bucket size in hours
Returns: Returns:
List of {timestamp, mean, min, max, count} dicts List of {timestamp, mean, min, max, count} dicts
""" """
import statistics import statistics
runs = self.get_recent_runs(n=10000, hours=days * 24) runs = self.get_recent_runs(n=10000, hours=days * 24)
# Group runs by time bucket # Group runs by time bucket
buckets: Dict[int, List[float]] = defaultdict(list) buckets: Dict[int, List[float]] = defaultdict(list)
bucket_size_ms = granularity_hours * 3600 * 1000 bucket_size_ms = granularity_hours * 3600 * 1000
for run in runs: for run in runs:
if metric not in run.data.metrics: if metric not in run.data.metrics:
continue continue
bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
buckets[bucket].append(run.data.metrics[metric]) buckets[bucket].append(run.data.metrics[metric])
# Calculate statistics per bucket # Calculate statistics per bucket
trends = [] trends = []
for bucket_ts, values in sorted(buckets.items()): for bucket_ts, values in sorted(buckets.items()):
@@ -432,9 +428,9 @@ class ExperimentAnalyzer:
if len(values) >= 2: if len(values) >= 2:
trend["stdev"] = statistics.stdev(values) trend["stdev"] = statistics.stdev(values)
trends.append(trend) trends.append(trend)
return trends return trends
def get_runs_by_tag( def get_runs_by_tag(
self, self,
tag_key: str, tag_key: str,
@@ -443,12 +439,12 @@ class ExperimentAnalyzer:
) -> List[Run]: ) -> List[Run]:
""" """
Get runs with a specific tag. Get runs with a specific tag.
Args: Args:
tag_key: Tag key to filter by tag_key: Tag key to filter by
tag_value: Tag value to match tag_value: Tag value to match
max_results: Maximum runs to return max_results: Maximum runs to return
Returns: Returns:
List of matching Run objects List of matching Run objects
""" """
@@ -456,7 +452,7 @@ class ExperimentAnalyzer:
filter_string=f"tags.{tag_key} = '{tag_value}'", filter_string=f"tags.{tag_key} = '{tag_value}'",
max_results=max_results, max_results=max_results,
) )
def get_model_runs( def get_model_runs(
self, self,
model_name: str, model_name: str,
@@ -464,11 +460,11 @@ class ExperimentAnalyzer:
) -> List[Run]: ) -> List[Run]:
""" """
Get runs for a specific model. Get runs for a specific model.
Args: Args:
model_name: Model name to filter by model_name: Model name to filter by
max_results: Maximum runs to return max_results: Maximum runs to return
Returns: Returns:
List of matching Run objects List of matching Run objects
""" """
@@ -477,14 +473,14 @@ class ExperimentAnalyzer:
filter_string=f"tags.`model.name` = '{model_name}'", filter_string=f"tags.`model.name` = '{model_name}'",
max_results=max_results, max_results=max_results,
) )
if not runs: if not runs:
# Try params # Try params
runs = self.search_runs( runs = self.search_runs(
filter_string=f"params.model_name = '{model_name}'", filter_string=f"params.model_name = '{model_name}'",
max_results=max_results, max_results=max_results,
) )
return runs return runs
@@ -495,23 +491,23 @@ def compare_experiments(
) -> Dict[str, Dict[str, float]]: ) -> Dict[str, Dict[str, float]]:
""" """
Compare metrics across multiple experiments. Compare metrics across multiple experiments.
Args: Args:
experiment_names: Names of experiments to compare experiment_names: Names of experiments to compare
metric: Metric to compare metric: Metric to compare
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
Dict mapping experiment names to metric statistics Dict mapping experiment names to metric statistics
""" """
results = {} results = {}
for exp_name in experiment_names: for exp_name in experiment_names:
analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri) analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
summary = analyzer.get_metrics_summary(metrics=[metric]) summary = analyzer.get_metrics_summary(metrics=[metric])
if metric in summary: if metric in summary:
results[exp_name] = summary[metric] results[exp_name] = summary[metric]
return results return results
@@ -523,7 +519,7 @@ def promotion_recommendation(
) -> PromotionRecommendation: ) -> PromotionRecommendation:
""" """
Generate a recommendation for model promotion. Generate a recommendation for model promotion.
Args: Args:
model_name: Name of the model to evaluate model_name: Name of the model to evaluate
experiment_name: Experiment containing evaluation runs experiment_name: Experiment containing evaluation runs
@@ -531,15 +527,15 @@ def promotion_recommendation(
comparison is one of: ">=", "<=", ">", "<" comparison is one of: ">=", "<=", ">", "<"
e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)} e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
PromotionRecommendation with decision and reasons PromotionRecommendation with decision and reasons
""" """
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri) analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
# Get model runs # Get model runs
runs = analyzer.get_model_runs(model_name, max_results=10) runs = analyzer.get_model_runs(model_name, max_results=10)
if not runs: if not runs:
return PromotionRecommendation( return PromotionRecommendation(
model_name=model_name, model_name=model_name,
@@ -548,41 +544,41 @@ def promotion_recommendation(
reasons=["No runs found for this model"], reasons=["No runs found for this model"],
metrics_summary={}, metrics_summary={},
) )
# Get the most recent run # Get the most recent run
latest_run = runs[0] latest_run = runs[0]
metrics = latest_run.data.metrics metrics = latest_run.data.metrics
# Evaluate criteria # Evaluate criteria
reasons = [] reasons = []
passed = True passed = True
comparisons = { comparisons = {
">=": lambda a, b: a >= b, ">=": lambda a, b: a >= b,
"<=": lambda a, b: a <= b, "<=": lambda a, b: a <= b,
">": lambda a, b: a > b, ">": lambda a, b: a > b,
"<": lambda a, b: a < b, "<": lambda a, b: a < b,
} }
for metric_name, (comparison, threshold) in criteria.items(): for metric_name, (comparison, threshold) in criteria.items():
if metric_name not in metrics: if metric_name not in metrics:
reasons.append(f"Metric '{metric_name}' not found") reasons.append(f"Metric '{metric_name}' not found")
passed = False passed = False
continue continue
value = metrics[metric_name] value = metrics[metric_name]
compare_fn = comparisons.get(comparison) compare_fn = comparisons.get(comparison)
if compare_fn is None: if compare_fn is None:
reasons.append(f"Invalid comparison operator: {comparison}") reasons.append(f"Invalid comparison operator: {comparison}")
continue continue
if compare_fn(value, threshold): if compare_fn(value, threshold):
reasons.append(f"{metric_name}: {value:.4f} {comparison} {threshold}") reasons.append(f"{metric_name}: {value:.4f} {comparison} {threshold}")
else: else:
reasons.append(f"{metric_name}: {value:.4f} NOT {comparison} {threshold}") reasons.append(f"{metric_name}: {value:.4f} NOT {comparison} {threshold}")
passed = False passed = False
# Extract version from tags if available # Extract version from tags if available
version = None version = None
if "mlflow.version" in latest_run.data.tags: if "mlflow.version" in latest_run.data.tags:
@@ -590,7 +586,7 @@ def promotion_recommendation(
version = int(latest_run.data.tags["mlflow.version"]) version = int(latest_run.data.tags["mlflow.version"])
except ValueError: except ValueError:
pass pass
return PromotionRecommendation( return PromotionRecommendation(
model_name=model_name, model_name=model_name,
version=version, version=version,
@@ -607,21 +603,21 @@ def get_inference_performance_report(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate an inference performance report for a service. Generate an inference performance report for a service.
Args: Args:
service_name: Service name (chat-handler, voice-assistant) service_name: Service name (chat-handler, voice-assistant)
hours: Hours of data to analyze hours: Hours of data to analyze
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
Performance report dictionary Performance report dictionary
""" """
experiment_name = f"{service_name.replace('-', '')}-inference" experiment_name = f"{service_name.replace('-', '')}-inference"
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri) analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
# Get summary metrics # Get summary metrics
summary = analyzer.get_metrics_summary(hours=hours) summary = analyzer.get_metrics_summary(hours=hours)
# Key latency metrics # Key latency metrics
latency_metrics = [ latency_metrics = [
"total_latency_mean", "total_latency_mean",
@@ -631,7 +627,7 @@ def get_inference_performance_report(
"embedding_latency_mean", "embedding_latency_mean",
"rag_search_latency_mean", "rag_search_latency_mean",
] ]
report = { report = {
"service": service_name, "service": service_name,
"period_hours": hours, "period_hours": hours,
@@ -641,24 +637,24 @@ def get_inference_performance_report(
"rag": {}, "rag": {},
"errors": {}, "errors": {},
} }
# Latency section # Latency section
for metric in latency_metrics: for metric in latency_metrics:
if metric in summary: if metric in summary:
report["latency"][metric] = summary[metric] report["latency"][metric] = summary[metric]
# Throughput # Throughput
if "total_requests" in summary: if "total_requests" in summary:
report["throughput"]["total_requests"] = summary["total_requests"]["mean"] report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
# RAG usage # RAG usage
rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"] rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
for metric in rag_metrics: for metric in rag_metrics:
if metric in summary: if metric in summary:
report["rag"][metric] = summary[metric] report["rag"][metric] = summary[metric]
# Error rate # Error rate
if "error_rate" in summary: if "error_rate" in summary:
report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"] report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
return report return report

View File

@@ -9,20 +9,19 @@ complement OTel metrics with MLflow experiment tracking for
longer-term analysis and model comparison. longer-term analysis and model comparison.
""" """
import os
import time
import asyncio import asyncio
import logging import logging
from typing import Optional, Dict, Any, List import time
from dataclasses import dataclass, field
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from functools import partial from functools import partial
from typing import Any, Dict, List, Optional
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from .client import get_mlflow_client, ensure_experiment, MLflowConfig from .client import MLflowConfig, ensure_experiment, get_mlflow_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,7 +32,7 @@ class InferenceMetrics:
request_id: str request_id: str
user_id: Optional[str] = None user_id: Optional[str] = None
session_id: Optional[str] = None session_id: Optional[str] = None
# Timing metrics (in seconds) # Timing metrics (in seconds)
total_latency: float = 0.0 total_latency: float = 0.0
embedding_latency: float = 0.0 embedding_latency: float = 0.0
@@ -42,33 +41,33 @@ class InferenceMetrics:
llm_latency: float = 0.0 llm_latency: float = 0.0
tts_latency: float = 0.0 tts_latency: float = 0.0
stt_latency: float = 0.0 stt_latency: float = 0.0
# Token/size metrics # Token/size metrics
input_tokens: int = 0 input_tokens: int = 0
output_tokens: int = 0 output_tokens: int = 0
total_tokens: int = 0 total_tokens: int = 0
prompt_length: int = 0 prompt_length: int = 0
response_length: int = 0 response_length: int = 0
# RAG metrics # RAG metrics
rag_enabled: bool = False rag_enabled: bool = False
rag_documents_retrieved: int = 0 rag_documents_retrieved: int = 0
rag_documents_used: int = 0 rag_documents_used: int = 0
reranker_enabled: bool = False reranker_enabled: bool = False
# Quality indicators # Quality indicators
is_streaming: bool = False is_streaming: bool = False
is_premium: bool = False is_premium: bool = False
has_error: bool = False has_error: bool = False
error_message: Optional[str] = None error_message: Optional[str] = None
# Model information # Model information
model_name: Optional[str] = None model_name: Optional[str] = None
model_endpoint: Optional[str] = None model_endpoint: Optional[str] = None
# Timestamps # Timestamps
timestamp: float = field(default_factory=time.time) timestamp: float = field(default_factory=time.time)
def as_metrics_dict(self) -> Dict[str, float]: def as_metrics_dict(self) -> Dict[str, float]:
"""Convert numeric fields to a metrics dictionary.""" """Convert numeric fields to a metrics dictionary."""
return { return {
@@ -87,7 +86,7 @@ class InferenceMetrics:
"rag_documents_retrieved": float(self.rag_documents_retrieved), "rag_documents_retrieved": float(self.rag_documents_retrieved),
"rag_documents_used": float(self.rag_documents_used), "rag_documents_used": float(self.rag_documents_used),
} }
def as_params_dict(self) -> Dict[str, str]: def as_params_dict(self) -> Dict[str, str]:
"""Convert configuration fields to a params dictionary.""" """Convert configuration fields to a params dictionary."""
params = { params = {
@@ -106,39 +105,39 @@ class InferenceMetrics:
class InferenceMetricsTracker: class InferenceMetricsTracker:
""" """
Async-compatible MLflow tracker for inference metrics. Async-compatible MLflow tracker for inference metrics.
Uses batching and a background thread pool to avoid blocking Uses batching and a background thread pool to avoid blocking
the async event loop during MLflow calls. the async event loop during MLflow calls.
Example usage in chat-handler: Example usage in chat-handler:
class ChatHandler: class ChatHandler:
def __init__(self): def __init__(self):
self.mlflow_tracker = InferenceMetricsTracker( self.mlflow_tracker = InferenceMetricsTracker(
service_name="chat-handler", service_name="chat-handler",
experiment_name="chat-inference" experiment_name="chat-inference"
) )
async def setup(self): async def setup(self):
await self.mlflow_tracker.start() await self.mlflow_tracker.start()
async def process_request(self, msg): async def process_request(self, msg):
metrics = InferenceMetrics(request_id=request_id) metrics = InferenceMetrics(request_id=request_id)
# Track timing # Track timing
start = time.time() start = time.time()
# ... do embedding ... # ... do embedding ...
metrics.embedding_latency = time.time() - start metrics.embedding_latency = time.time() - start
# ... more processing ... # ... more processing ...
# Log metrics (non-blocking) # Log metrics (non-blocking)
await self.mlflow_tracker.log_inference(metrics) await self.mlflow_tracker.log_inference(metrics)
async def shutdown(self): async def shutdown(self):
await self.mlflow_tracker.stop() await self.mlflow_tracker.stop()
""" """
def __init__( def __init__(
self, self,
service_name: str, service_name: str,
@@ -151,7 +150,7 @@ class InferenceMetricsTracker:
): ):
""" """
Initialize the inference metrics tracker. Initialize the inference metrics tracker.
Args: Args:
service_name: Name of the service (e.g., "chat-handler") service_name: Name of the service (e.g., "chat-handler")
experiment_name: MLflow experiment name (defaults to service_name) experiment_name: MLflow experiment name (defaults to service_name)
@@ -167,7 +166,7 @@ class InferenceMetricsTracker:
self.batch_size = batch_size self.batch_size = batch_size
self.flush_interval = flush_interval_seconds self.flush_interval = flush_interval_seconds
self.enable_batching = enable_batching self.enable_batching = enable_batching
self.config = MLflowConfig() self.config = MLflowConfig()
self._batch: List[InferenceMetrics] = [] self._batch: List[InferenceMetrics] = []
self._batch_lock = asyncio.Lock() self._batch_lock = asyncio.Lock()
@@ -176,34 +175,34 @@ class InferenceMetricsTracker:
self._running = False self._running = False
self._client: Optional[MlflowClient] = None self._client: Optional[MlflowClient] = None
self._experiment_id: Optional[str] = None self._experiment_id: Optional[str] = None
# Aggregate metrics for periodic logging # Aggregate metrics for periodic logging
self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list) self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
self._request_count = 0 self._request_count = 0
self._error_count = 0 self._error_count = 0
async def start(self) -> None: async def start(self) -> None:
"""Start the tracker and initialize MLflow connection.""" """Start the tracker and initialize MLflow connection."""
if self._running: if self._running:
return return
self._running = True self._running = True
# Initialize MLflow in thread pool to avoid blocking # Initialize MLflow in thread pool to avoid blocking
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(
self._executor, self._executor,
self._init_mlflow self._init_mlflow
) )
if self.enable_batching: if self.enable_batching:
self._flush_task = asyncio.create_task(self._periodic_flush()) self._flush_task = asyncio.create_task(self._periodic_flush())
logger.info( logger.info(
f"InferenceMetricsTracker started for {self.service_name} " f"InferenceMetricsTracker started for {self.service_name} "
f"(experiment: {self.experiment_name})" f"(experiment: {self.experiment_name})"
) )
def _init_mlflow(self) -> None: def _init_mlflow(self) -> None:
"""Initialize MLflow client and experiment (runs in thread pool).""" """Initialize MLflow client and experiment (runs in thread pool)."""
self._client = get_mlflow_client( self._client = get_mlflow_client(
@@ -217,47 +216,47 @@ class InferenceMetricsTracker:
"type": "inference-metrics", "type": "inference-metrics",
} }
) )
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the tracker and flush remaining metrics.""" """Stop the tracker and flush remaining metrics."""
if not self._running: if not self._running:
return return
self._running = False self._running = False
if self._flush_task: if self._flush_task:
self._flush_task.cancel() self._flush_task.cancel()
try: try:
await self._flush_task await self._flush_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Final flush # Final flush
await self._flush_batch() await self._flush_batch()
self._executor.shutdown(wait=True) self._executor.shutdown(wait=True)
logger.info(f"InferenceMetricsTracker stopped for {self.service_name}") logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
async def log_inference(self, metrics: InferenceMetrics) -> None: async def log_inference(self, metrics: InferenceMetrics) -> None:
""" """
Log inference metrics (non-blocking). Log inference metrics (non-blocking).
Args: Args:
metrics: InferenceMetrics object with request data metrics: InferenceMetrics object with request data
""" """
if not self._running: if not self._running:
logger.warning("Tracker not running, skipping metrics") logger.warning("Tracker not running, skipping metrics")
return return
self._request_count += 1 self._request_count += 1
if metrics.has_error: if metrics.has_error:
self._error_count += 1 self._error_count += 1
# Update aggregates # Update aggregates
for key, value in metrics.as_metrics_dict().items(): for key, value in metrics.as_metrics_dict().items():
if value > 0: if value > 0:
self._aggregate_metrics[key].append(value) self._aggregate_metrics[key].append(value)
if self.enable_batching: if self.enable_batching:
async with self._batch_lock: async with self._batch_lock:
self._batch.append(metrics) self._batch.append(metrics)
@@ -270,29 +269,29 @@ class InferenceMetricsTracker:
self._executor, self._executor,
partial(self._log_single_inference, metrics) partial(self._log_single_inference, metrics)
) )
async def _periodic_flush(self) -> None: async def _periodic_flush(self) -> None:
"""Periodically flush batched metrics.""" """Periodically flush batched metrics."""
while self._running: while self._running:
await asyncio.sleep(self.flush_interval) await asyncio.sleep(self.flush_interval)
await self._flush_batch() await self._flush_batch()
async def _flush_batch(self) -> None: async def _flush_batch(self) -> None:
"""Flush the current batch of metrics to MLflow.""" """Flush the current batch of metrics to MLflow."""
async with self._batch_lock: async with self._batch_lock:
if not self._batch: if not self._batch:
return return
batch = self._batch batch = self._batch
self._batch = [] self._batch = []
# Log in thread pool # Log in thread pool
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor( await loop.run_in_executor(
self._executor, self._executor,
partial(self._log_batch, batch) partial(self._log_batch, batch)
) )
def _log_single_inference(self, metrics: InferenceMetrics) -> None: def _log_single_inference(self, metrics: InferenceMetrics) -> None:
"""Log a single inference request to MLflow (runs in thread pool).""" """Log a single inference request to MLflow (runs in thread pool)."""
try: try:
@@ -307,7 +306,7 @@ class InferenceMetricsTracker:
): ):
mlflow.log_params(metrics.as_params_dict()) mlflow.log_params(metrics.as_params_dict())
mlflow.log_metrics(metrics.as_metrics_dict()) mlflow.log_metrics(metrics.as_metrics_dict())
if metrics.user_id: if metrics.user_id:
mlflow.set_tag("user_id", metrics.user_id) mlflow.set_tag("user_id", metrics.user_id)
if metrics.session_id: if metrics.session_id:
@@ -318,18 +317,18 @@ class InferenceMetricsTracker:
mlflow.set_tag("error_message", metrics.error_message[:250]) mlflow.set_tag("error_message", metrics.error_message[:250])
except Exception as e: except Exception as e:
logger.error(f"Failed to log inference metrics: {e}") logger.error(f"Failed to log inference metrics: {e}")
def _log_batch(self, batch: List[InferenceMetrics]) -> None: def _log_batch(self, batch: List[InferenceMetrics]) -> None:
"""Log a batch of inference metrics as aggregate statistics.""" """Log a batch of inference metrics as aggregate statistics."""
if not batch: if not batch:
return return
try: try:
# Calculate aggregates # Calculate aggregates
aggregates = self._calculate_aggregates(batch) aggregates = self._calculate_aggregates(batch)
run_name = f"batch-{self.service_name}-{int(time.time())}" run_name = f"batch-{self.service_name}-{int(time.time())}"
with mlflow.start_run( with mlflow.start_run(
experiment_id=self._experiment_id, experiment_id=self._experiment_id,
run_name=run_name, run_name=run_name,
@@ -341,83 +340,83 @@ class InferenceMetricsTracker:
): ):
# Log aggregate metrics # Log aggregate metrics
mlflow.log_metrics(aggregates) mlflow.log_metrics(aggregates)
# Log batch info # Log batch info
mlflow.log_param("batch_size", len(batch)) mlflow.log_param("batch_size", len(batch))
mlflow.log_param("time_window_start", min(m.timestamp for m in batch)) mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
mlflow.log_param("time_window_end", max(m.timestamp for m in batch)) mlflow.log_param("time_window_end", max(m.timestamp for m in batch))
# Log configuration breakdown # Log configuration breakdown
rag_enabled_count = sum(1 for m in batch if m.rag_enabled) rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
streaming_count = sum(1 for m in batch if m.is_streaming) streaming_count = sum(1 for m in batch if m.is_streaming)
premium_count = sum(1 for m in batch if m.is_premium) premium_count = sum(1 for m in batch if m.is_premium)
error_count = sum(1 for m in batch if m.has_error) error_count = sum(1 for m in batch if m.has_error)
mlflow.log_metrics({ mlflow.log_metrics({
"rag_enabled_pct": rag_enabled_count / len(batch) * 100, "rag_enabled_pct": rag_enabled_count / len(batch) * 100,
"streaming_pct": streaming_count / len(batch) * 100, "streaming_pct": streaming_count / len(batch) * 100,
"premium_pct": premium_count / len(batch) * 100, "premium_pct": premium_count / len(batch) * 100,
"error_rate": error_count / len(batch) * 100, "error_rate": error_count / len(batch) * 100,
}) })
# Log model distribution # Log model distribution
model_counts: Dict[str, int] = defaultdict(int) model_counts: Dict[str, int] = defaultdict(int)
for m in batch: for m in batch:
if m.model_name: if m.model_name:
model_counts[m.model_name] += 1 model_counts[m.model_name] += 1
if model_counts: if model_counts:
mlflow.log_dict( mlflow.log_dict(
{"models": dict(model_counts)}, {"models": dict(model_counts)},
"model_distribution.json" "model_distribution.json"
) )
logger.info(f"Logged batch of {len(batch)} inference metrics") logger.info(f"Logged batch of {len(batch)} inference metrics")
except Exception as e: except Exception as e:
logger.error(f"Failed to log batch metrics: {e}") logger.error(f"Failed to log batch metrics: {e}")
def _calculate_aggregates( def _calculate_aggregates(
self, self,
batch: List[InferenceMetrics] batch: List[InferenceMetrics]
) -> Dict[str, float]: ) -> Dict[str, float]:
"""Calculate aggregate statistics from a batch of metrics.""" """Calculate aggregate statistics from a batch of metrics."""
import statistics import statistics
aggregates = {} aggregates = {}
# Collect all numeric metrics # Collect all numeric metrics
metric_values: Dict[str, List[float]] = defaultdict(list) metric_values: Dict[str, List[float]] = defaultdict(list)
for m in batch: for m in batch:
for key, value in m.as_metrics_dict().items(): for key, value in m.as_metrics_dict().items():
if value > 0: if value > 0:
metric_values[key].append(value) metric_values[key].append(value)
# Calculate statistics for each metric # Calculate statistics for each metric
for key, values in metric_values.items(): for key, values in metric_values.items():
if not values: if not values:
continue continue
aggregates[f"{key}_mean"] = statistics.mean(values) aggregates[f"{key}_mean"] = statistics.mean(values)
aggregates[f"{key}_min"] = min(values) aggregates[f"{key}_min"] = min(values)
aggregates[f"{key}_max"] = max(values) aggregates[f"{key}_max"] = max(values)
if len(values) >= 2: if len(values) >= 2:
aggregates[f"{key}_p50"] = statistics.median(values) aggregates[f"{key}_p50"] = statistics.median(values)
aggregates[f"{key}_stdev"] = statistics.stdev(values) aggregates[f"{key}_stdev"] = statistics.stdev(values)
if len(values) >= 4: if len(values) >= 4:
sorted_vals = sorted(values) sorted_vals = sorted(values)
p95_idx = int(len(sorted_vals) * 0.95) p95_idx = int(len(sorted_vals) * 0.95)
p99_idx = int(len(sorted_vals) * 0.99) p99_idx = int(len(sorted_vals) * 0.99)
aggregates[f"{key}_p95"] = sorted_vals[p95_idx] aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
aggregates[f"{key}_p99"] = sorted_vals[p99_idx] aggregates[f"{key}_p99"] = sorted_vals[p99_idx]
# Add counts # Add counts
aggregates["total_requests"] = float(len(batch)) aggregates["total_requests"] = float(len(batch))
return aggregates return aggregates
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get current tracker statistics.""" """Get current tracker statistics."""
return { return {

View File

@@ -21,22 +21,22 @@ Usage in a Kubeflow Pipeline:
experiment_name="my-experiment", experiment_name="my-experiment",
run_name="training-run-1" run_name="training-run-1"
) )
# ... your pipeline steps ... # ... your pipeline steps ...
# Log metrics # Log metrics
log_step = log_metrics_component( log_step = log_metrics_component(
run_id=run_info.outputs["run_id"], run_id=run_info.outputs["run_id"],
metrics={"accuracy": 0.95, "loss": 0.05} metrics={"accuracy": 0.95, "loss": 0.05}
) )
# End run # End run
end_mlflow_run(run_id=run_info.outputs["run_id"]) end_mlflow_run(run_id=run_info.outputs["run_id"])
""" """
from kfp import dsl from typing import Any, Dict, List, NamedTuple
from typing import Dict, Any, List, Optional, NamedTuple
from kfp import dsl
# MLflow component image with all required dependencies # MLflow component image with all required dependencies
MLFLOW_IMAGE = "python:3.13-slim" MLFLOW_IMAGE = "python:3.13-slim"
@@ -60,31 +60,32 @@ def create_mlflow_run(
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]): ) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
""" """
Create a new MLflow run for the pipeline. Create a new MLflow run for the pipeline.
This should be called at the start of a pipeline to initialize This should be called at the start of a pipeline to initialize
tracking. The returned run_id should be passed to subsequent tracking. The returned run_id should be passed to subsequent
components for logging. components for logging.
Args: Args:
experiment_name: Name of the MLflow experiment experiment_name: Name of the MLflow experiment
run_name: Name for this specific run run_name: Name for this specific run
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
tags: Optional tags to add to the run tags: Optional tags to add to the run
params: Optional parameters to log params: Optional parameters to log
Returns: Returns:
NamedTuple with run_id, experiment_id, and artifact_uri NamedTuple with run_id, experiment_id, and artifact_uri
""" """
import os import os
from collections import namedtuple
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from collections import namedtuple
# Set tracking URI # Set tracking URI
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
# Get or create experiment # Get or create experiment
experiment = client.get_experiment_by_name(experiment_name) experiment = client.get_experiment_by_name(experiment_name)
if experiment is None: if experiment is None:
@@ -94,7 +95,7 @@ def create_mlflow_run(
) )
else: else:
experiment_id = experiment.experiment_id experiment_id = experiment.experiment_id
# Create default tags # Create default tags
default_tags = { default_tags = {
"pipeline.type": "kubeflow", "pipeline.type": "kubeflow",
@@ -103,24 +104,24 @@ def create_mlflow_run(
} }
if tags: if tags:
default_tags.update(tags) default_tags.update(tags)
# Start run # Start run
run = mlflow.start_run( run = mlflow.start_run(
experiment_id=experiment_id, experiment_id=experiment_id,
run_name=run_name, run_name=run_name,
tags=default_tags, tags=default_tags,
) )
# Log initial params # Log initial params
if params: if params:
mlflow.log_params(params) mlflow.log_params(params)
run_id = run.info.run_id run_id = run.info.run_id
artifact_uri = run.info.artifact_uri artifact_uri = run.info.artifact_uri
# End run (KFP components are isolated, we'll resume in other components) # End run (KFP components are isolated, we'll resume in other components)
mlflow.end_run() mlflow.end_run()
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri']) RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
return RunInfo(run_id, experiment_id, artifact_uri) return RunInfo(run_id, experiment_id, artifact_uri)
@@ -136,24 +137,24 @@ def log_params_component(
) -> str: ) -> str:
""" """
Log parameters to an existing MLflow run. Log parameters to an existing MLflow run.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
params: Dictionary of parameters to log params: Dictionary of parameters to log
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
for key, value in params.items(): for key, value in params.items():
client.log_param(run_id, key, str(value)[:500]) client.log_param(run_id, key, str(value)[:500])
return run_id return run_id
@@ -169,25 +170,25 @@ def log_metrics_component(
) -> str: ) -> str:
""" """
Log metrics to an existing MLflow run. Log metrics to an existing MLflow run.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
metrics: Dictionary of metrics to log metrics: Dictionary of metrics to log
step: Step number for time-series metrics step: Step number for time-series metrics
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
for key, value in metrics.items(): for key, value in metrics.items():
client.log_metric(run_id, key, float(value), step=step) client.log_metric(run_id, key, float(value), step=step)
return run_id return run_id
@@ -203,24 +204,24 @@ def log_artifact_component(
) -> str: ) -> str:
""" """
Log an artifact file to an existing MLflow run. Log an artifact file to an existing MLflow run.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
artifact_path: Path to the artifact file artifact_path: Path to the artifact file
artifact_name: Optional destination name in artifact store artifact_name: Optional destination name in artifact store
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
client.log_artifact(run_id, artifact_path, artifact_name or None) client.log_artifact(run_id, artifact_path, artifact_name or None)
return run_id return run_id
@@ -236,36 +237,37 @@ def log_dict_artifact(
) -> str: ) -> str:
""" """
Log a dictionary as a JSON artifact. Log a dictionary as a JSON artifact.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
data: Dictionary to save as JSON data: Dictionary to save as JSON
filename: Name for the JSON file filename: Name for the JSON file
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import json import json
import tempfile import tempfile
from pathlib import Path
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
# Ensure .json extension # Ensure .json extension
if not filename.endswith('.json'): if not filename.endswith('.json'):
filename += '.json' filename += '.json'
# Write to temp file and log # Write to temp file and log
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / filename filepath = Path(tmpdir) / filename
with open(filepath, 'w') as f: with open(filepath, 'w') as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
client.log_artifact(run_id, str(filepath)) client.log_artifact(run_id, str(filepath))
return run_id return run_id
@@ -280,31 +282,31 @@ def end_mlflow_run(
) -> str: ) -> str:
""" """
End an MLflow run with the specified status. End an MLflow run with the specified status.
Args: Args:
run_id: The MLflow run ID to end run_id: The MLflow run ID to end
status: Run status (FINISHED, FAILED, KILLED) status: Run status (FINISHED, FAILED, KILLED)
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id The run_id
""" """
import mlflow import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import RunStatus from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
status_map = { status_map = {
"FINISHED": RunStatus.FINISHED, "FINISHED": RunStatus.FINISHED,
"FAILED": RunStatus.FAILED, "FAILED": RunStatus.FAILED,
"KILLED": RunStatus.KILLED, "KILLED": RunStatus.KILLED,
} }
run_status = status_map.get(status.upper(), RunStatus.FINISHED) run_status = status_map.get(status.upper(), RunStatus.FINISHED)
client.set_terminated(run_id, status=run_status) client.set_terminated(run_id, status=run_status)
return run_id return run_id
@@ -322,10 +324,10 @@ def log_training_metrics(
) -> str: ) -> str:
""" """
Log comprehensive training metrics for ML models. Log comprehensive training metrics for ML models.
Designed for use with QLoRA training, voice training, and other Designed for use with QLoRA training, voice training, and other
ML training pipelines in the llm-workflows repository. ML training pipelines in the llm-workflows repository.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
model_type: Type of model (llm, stt, tts, embeddings) model_type: Type of model (llm, stt, tts, embeddings)
@@ -333,19 +335,20 @@ def log_training_metrics(
final_metrics: Final training metrics final_metrics: Final training metrics
model_path: Path to saved model (if applicable) model_path: Path to saved model (if applicable)
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import json import json
import tempfile import tempfile
from pathlib import Path
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
# Log training config as params # Log training config as params
flat_config = {} flat_config = {}
for key, value in training_config.items(): for key, value in training_config.items():
@@ -353,29 +356,29 @@ def log_training_metrics(
flat_config[f"config.{key}"] = json.dumps(value)[:500] flat_config[f"config.{key}"] = json.dumps(value)[:500]
else: else:
flat_config[f"config.{key}"] = str(value)[:500] flat_config[f"config.{key}"] = str(value)[:500]
for key, value in flat_config.items(): for key, value in flat_config.items():
client.log_param(run_id, key, value) client.log_param(run_id, key, value)
# Log model type tag # Log model type tag
client.set_tag(run_id, "model.type", model_type) client.set_tag(run_id, "model.type", model_type)
# Log metrics # Log metrics
for key, value in final_metrics.items(): for key, value in final_metrics.items():
client.log_metric(run_id, key, float(value)) client.log_metric(run_id, key, float(value))
# Log full config as artifact # Log full config as artifact
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
config_path = Path(tmpdir) / "training_config.json" config_path = Path(tmpdir) / "training_config.json"
with open(config_path, 'w') as f: with open(config_path, 'w') as f:
json.dump(training_config, f, indent=2) json.dump(training_config, f, indent=2)
client.log_artifact(run_id, str(config_path)) client.log_artifact(run_id, str(config_path))
# Log model path if provided # Log model path if provided
if model_path: if model_path:
client.log_param(run_id, "model.path", model_path) client.log_param(run_id, "model.path", model_path)
client.set_tag(run_id, "model.saved", "true") client.set_tag(run_id, "model.saved", "true")
return run_id return run_id
@@ -397,9 +400,9 @@ def log_document_ingestion_metrics(
) -> str: ) -> str:
""" """
Log document ingestion pipeline metrics. Log document ingestion pipeline metrics.
Designed for use with the document_ingestion_pipeline. Designed for use with the document_ingestion_pipeline.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
source_url: URL of the source document source_url: URL of the source document
@@ -411,16 +414,16 @@ def log_document_ingestion_metrics(
chunk_size: Chunk size in tokens chunk_size: Chunk size in tokens
chunk_overlap: Chunk overlap in tokens chunk_overlap: Chunk overlap in tokens
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
# Log params # Log params
params = { params = {
"source_url": source_url[:500], "source_url": source_url[:500],
@@ -431,7 +434,7 @@ def log_document_ingestion_metrics(
} }
for key, value in params.items(): for key, value in params.items():
client.log_param(run_id, key, value) client.log_param(run_id, key, value)
# Log metrics # Log metrics
metrics = { metrics = {
"chunks_created": chunks_created, "chunks_created": chunks_created,
@@ -441,11 +444,11 @@ def log_document_ingestion_metrics(
} }
for key, value in metrics.items(): for key, value in metrics.items():
client.log_metric(run_id, key, float(value)) client.log_metric(run_id, key, float(value))
# Set pipeline type tag # Set pipeline type tag
client.set_tag(run_id, "pipeline.type", "document-ingestion") client.set_tag(run_id, "pipeline.type", "document-ingestion")
client.set_tag(run_id, "milvus.collection", collection_name) client.set_tag(run_id, "milvus.collection", collection_name)
return run_id return run_id
@@ -463,9 +466,9 @@ def log_evaluation_results(
) -> str: ) -> str:
""" """
Log model evaluation results. Log model evaluation results.
Designed for use with the evaluation_pipeline. Designed for use with the evaluation_pipeline.
Args: Args:
run_id: The MLflow run ID to log to run_id: The MLflow run ID to log to
model_name: Name of the evaluated model model_name: Name of the evaluated model
@@ -473,27 +476,28 @@ def log_evaluation_results(
metrics: Evaluation metrics (accuracy, etc.) metrics: Evaluation metrics (accuracy, etc.)
sample_results: Optional sample predictions sample_results: Optional sample predictions
mlflow_tracking_uri: MLflow tracking server URI mlflow_tracking_uri: MLflow tracking server URI
Returns: Returns:
The run_id for chaining The run_id for chaining
""" """
import json import json
import tempfile import tempfile
from pathlib import Path
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient() client = MlflowClient()
# Log params # Log params
client.log_param(run_id, "eval.model_name", model_name) client.log_param(run_id, "eval.model_name", model_name)
client.log_param(run_id, "eval.dataset", dataset_name) client.log_param(run_id, "eval.dataset", dataset_name)
# Log metrics # Log metrics
for key, value in metrics.items(): for key, value in metrics.items():
client.log_metric(run_id, f"eval.{key}", float(value)) client.log_metric(run_id, f"eval.{key}", float(value))
# Log sample results as artifact # Log sample results as artifact
if sample_results: if sample_results:
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
@@ -501,13 +505,13 @@ def log_evaluation_results(
with open(results_path, 'w') as f: with open(results_path, 'w') as f:
json.dump(sample_results, f, indent=2) json.dump(sample_results, f, indent=2)
client.log_artifact(run_id, str(results_path)) client.log_artifact(run_id, str(results_path))
# Set tags # Set tags
client.set_tag(run_id, "pipeline.type", "evaluation") client.set_tag(run_id, "pipeline.type", "evaluation")
client.set_tag(run_id, "model.name", model_name) client.set_tag(run_id, "model.name", model_name)
# Determine if passed # Determine if passed
passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7) passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
client.set_tag(run_id, "eval.passed", str(passed)) client.set_tag(run_id, "eval.passed", str(passed))
return run_id return run_id

View File

@@ -17,7 +17,7 @@ Usage:
promote_model_to_production, promote_model_to_production,
generate_kserve_manifest, generate_kserve_manifest,
) )
# Register a new model version # Register a new model version
model_version = register_model_for_kserve( model_version = register_model_for_kserve(
model_name="whisper-finetuned", model_name="whisper-finetuned",
@@ -28,7 +28,7 @@ Usage:
"container_image": "ghcr.io/my-org/whisper:v2", "container_image": "ghcr.io/my-org/whisper:v2",
} }
) )
# Generate KServe manifest for deployment # Generate KServe manifest for deployment
manifest = generate_kserve_manifest( manifest = generate_kserve_manifest(
model_name="whisper-finetuned", model_name="whisper-finetuned",
@@ -36,18 +36,15 @@ Usage:
) )
""" """
import os
import json
import yaml
import logging import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import mlflow import mlflow
from mlflow.tracking import MlflowClient import yaml
from mlflow.entities.model_registry import ModelVersion from mlflow.entities.model_registry import ModelVersion
from .client import get_mlflow_client, MLflowConfig from .client import get_mlflow_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -55,15 +52,15 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class KServeConfig: class KServeConfig:
"""Configuration for KServe deployment.""" """Configuration for KServe deployment."""
# Runtime/container configuration # Runtime/container configuration
runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc. runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc.
container_image: Optional[str] = None container_image: Optional[str] = None
container_port: int = 8080 container_port: int = 8080
# Protocol configuration # Protocol configuration
protocol: str = "v2" # v1, v2, grpc protocol: str = "v2" # v1, v2, grpc
# Resource requests/limits # Resource requests/limits
cpu_request: str = "1" cpu_request: str = "1"
cpu_limit: str = "4" cpu_limit: str = "4"
@@ -71,22 +68,22 @@ class KServeConfig:
memory_limit: str = "16Gi" memory_limit: str = "16Gi"
gpu_count: int = 0 gpu_count: int = 0
gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm
# Storage configuration # Storage configuration
storage_uri: Optional[str] = None # s3://, pvc://, gs:// storage_uri: Optional[str] = None # s3://, pvc://, gs://
# Scaling configuration # Scaling configuration
min_replicas: int = 1 min_replicas: int = 1
max_replicas: int = 1 max_replicas: int = 1
scale_target: int = 10 # Target concurrent requests for scaling scale_target: int = 10 # Target concurrent requests for scaling
# Serving configuration # Serving configuration
timeout_seconds: int = 300 timeout_seconds: int = 300
batch_size: int = 1 batch_size: int = 1
# Additional environment variables # Additional environment variables
env_vars: Dict[str, str] = field(default_factory=dict) env_vars: Dict[str, str] = field(default_factory=dict)
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for MLflow tags.""" """Convert to dictionary for MLflow tags."""
return { return {
@@ -165,7 +162,7 @@ def register_model_for_kserve(
) -> ModelVersion: ) -> ModelVersion:
""" """
Register a model in MLflow Model Registry with KServe metadata. Register a model in MLflow Model Registry with KServe metadata.
Args: Args:
model_name: Name for the registered model model_name: Name for the registered model
model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://) model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
@@ -175,16 +172,16 @@ def register_model_for_kserve(
kserve_config: KServe deployment configuration kserve_config: KServe deployment configuration
tags: Additional tags for the model version tags: Additional tags for the model version
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
The created ModelVersion object The created ModelVersion object
""" """
client = get_mlflow_client(tracking_uri=tracking_uri) client = get_mlflow_client(tracking_uri=tracking_uri)
# Get or use preset KServe config # Get or use preset KServe config
if kserve_config is None: if kserve_config is None:
kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig()) kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())
# Ensure registered model exists # Ensure registered model exists
try: try:
client.get_registered_model(model_name) client.get_registered_model(model_name)
@@ -198,7 +195,7 @@ def register_model_for_kserve(
} }
) )
logger.info(f"Created registered model: {model_name}") logger.info(f"Created registered model: {model_name}")
# Create model version # Create model version
model_version = client.create_model_version( model_version = client.create_model_version(
name=model_name, name=model_name,
@@ -211,12 +208,12 @@ def register_model_for_kserve(
**kserve_config.as_dict(), **kserve_config.as_dict(),
} }
) )
logger.info( logger.info(
f"Registered model version {model_version.version} " f"Registered model version {model_version.version} "
f"for {model_name} (type: {model_type})" f"for {model_name} (type: {model_type})"
) )
return model_version return model_version
@@ -229,19 +226,19 @@ def promote_model_to_stage(
) -> ModelVersion: ) -> ModelVersion:
""" """
Promote a model version to a new stage. Promote a model version to a new stage.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
version: Version number to promote version: Version number to promote
stage: Target stage (Staging, Production, Archived) stage: Target stage (Staging, Production, Archived)
archive_existing: If True, archive existing versions in target stage archive_existing: If True, archive existing versions in target stage
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
The updated ModelVersion The updated ModelVersion
""" """
client = get_mlflow_client(tracking_uri=tracking_uri) client = get_mlflow_client(tracking_uri=tracking_uri)
# Transition to new stage # Transition to new stage
model_version = client.transition_model_version_stage( model_version = client.transition_model_version_stage(
name=model_name, name=model_name,
@@ -249,9 +246,9 @@ def promote_model_to_stage(
stage=stage, stage=stage,
archive_existing_versions=archive_existing, archive_existing_versions=archive_existing,
) )
logger.info(f"Promoted {model_name} v{version} to {stage}") logger.info(f"Promoted {model_name} v{version} to {stage}")
return model_version return model_version
@@ -262,12 +259,12 @@ def promote_model_to_production(
) -> ModelVersion: ) -> ModelVersion:
""" """
Promote a model version directly to Production. Promote a model version directly to Production.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
version: Version number to promote version: Version number to promote
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
The updated ModelVersion The updated ModelVersion
""" """
@@ -286,18 +283,18 @@ def get_production_model(
) -> Optional[ModelVersion]: ) -> Optional[ModelVersion]:
""" """
Get the current Production model version. Get the current Production model version.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
The Production ModelVersion, or None if none exists The Production ModelVersion, or None if none exists
""" """
client = get_mlflow_client(tracking_uri=tracking_uri) client = get_mlflow_client(tracking_uri=tracking_uri)
versions = client.get_latest_versions(model_name, stages=["Production"]) versions = client.get_latest_versions(model_name, stages=["Production"])
return versions[0] if versions else None return versions[0] if versions else None
@@ -308,17 +305,17 @@ def get_model_kserve_config(
) -> KServeConfig: ) -> KServeConfig:
""" """
Get KServe configuration from a registered model version. Get KServe configuration from a registered model version.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
version: Version number (uses Production if not specified) version: Version number (uses Production if not specified)
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
KServeConfig populated from model tags KServeConfig populated from model tags
""" """
client = get_mlflow_client(tracking_uri=tracking_uri) client = get_mlflow_client(tracking_uri=tracking_uri)
if version: if version:
model_version = client.get_model_version(model_name, str(version)) model_version = client.get_model_version(model_name, str(version))
else: else:
@@ -326,9 +323,9 @@ def get_model_kserve_config(
if not prod_version: if not prod_version:
raise ValueError(f"No Production version for {model_name}") raise ValueError(f"No Production version for {model_name}")
model_version = prod_version model_version = prod_version
tags = model_version.tags tags = model_version.tags
return KServeConfig( return KServeConfig(
runtime=tags.get("kserve.runtime", "kserve-huggingface"), runtime=tags.get("kserve.runtime", "kserve-huggingface"),
protocol=tags.get("kserve.protocol", "v2"), protocol=tags.get("kserve.protocol", "v2"),
@@ -352,7 +349,7 @@ def generate_kserve_manifest(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate a KServe InferenceService manifest from a registered model. Generate a KServe InferenceService manifest from a registered model.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
version: Version number (uses Production if not specified) version: Version number (uses Production if not specified)
@@ -360,12 +357,12 @@ def generate_kserve_manifest(
service_name: Name for the InferenceService (defaults to model_name) service_name: Name for the InferenceService (defaults to model_name)
extra_annotations: Additional annotations for the service extra_annotations: Additional annotations for the service
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
KServe InferenceService manifest as a dictionary KServe InferenceService manifest as a dictionary
""" """
client = get_mlflow_client(tracking_uri=tracking_uri) client = get_mlflow_client(tracking_uri=tracking_uri)
# Get model version # Get model version
if version: if version:
model_version = client.get_model_version(model_name, str(version)) model_version = client.get_model_version(model_name, str(version))
@@ -375,13 +372,13 @@ def generate_kserve_manifest(
raise ValueError(f"No Production version for {model_name}") raise ValueError(f"No Production version for {model_name}")
model_version = prod_version model_version = prod_version
version = int(model_version.version) version = int(model_version.version)
# Get KServe config # Get KServe config
config = get_model_kserve_config(model_name, version, tracking_uri) config = get_model_kserve_config(model_name, version, tracking_uri)
model_type = model_version.tags.get("model.type", "custom") model_type = model_version.tags.get("model.type", "custom")
svc_name = service_name or model_name.lower().replace("_", "-") svc_name = service_name or model_name.lower().replace("_", "-")
# Build manifest # Build manifest
manifest = { manifest = {
"apiVersion": "serving.kserve.io/v1beta1", "apiVersion": "serving.kserve.io/v1beta1",
@@ -409,10 +406,10 @@ def generate_kserve_manifest(
}, },
}, },
} }
# Configure predictor based on runtime # Configure predictor based on runtime
predictor = manifest["spec"]["predictor"] predictor = manifest["spec"]["predictor"]
if config.container_image: if config.container_image:
# Custom container # Custom container
predictor["containers"] = [{ predictor["containers"] = [{
@@ -434,16 +431,16 @@ def generate_kserve_manifest(
for k, v in config.env_vars.items() for k, v in config.env_vars.items()
], ],
}] }]
# Add GPU if needed # Add GPU if needed
if config.gpu_count > 0: if config.gpu_count > 0:
predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count) predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count) predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
else: else:
# Standard KServe runtime # Standard KServe runtime
storage_uri = config.storage_uri or model_version.source storage_uri = config.storage_uri or model_version.source
predictor["model"] = { predictor["model"] = {
"modelFormat": {"name": "huggingface"}, "modelFormat": {"name": "huggingface"},
"protocolVersion": config.protocol, "protocolVersion": config.protocol,
@@ -459,11 +456,11 @@ def generate_kserve_manifest(
}, },
}, },
} }
if config.gpu_count > 0: if config.gpu_count > 0:
predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count) predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count) predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
return manifest return manifest
@@ -476,14 +473,14 @@ def generate_kserve_yaml(
) -> str: ) -> str:
""" """
Generate a KServe InferenceService manifest as YAML. Generate a KServe InferenceService manifest as YAML.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
version: Version number (uses Production if not specified) version: Version number (uses Production if not specified)
namespace: Kubernetes namespace namespace: Kubernetes namespace
output_path: If provided, write YAML to this path output_path: If provided, write YAML to this path
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
YAML string of the manifest YAML string of the manifest
""" """
@@ -493,14 +490,14 @@ def generate_kserve_yaml(
namespace=namespace, namespace=namespace,
tracking_uri=tracking_uri, tracking_uri=tracking_uri,
) )
yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False) yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
if output_path: if output_path:
with open(output_path, 'w') as f: with open(output_path, 'w') as f:
f.write(yaml_str) f.write(yaml_str)
logger.info(f"Wrote KServe manifest to {output_path}") logger.info(f"Wrote KServe manifest to {output_path}")
return yaml_str return yaml_str
@@ -511,17 +508,17 @@ def list_model_versions(
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
List all versions of a registered model. List all versions of a registered model.
Args: Args:
model_name: Name of the registered model model_name: Name of the registered model
stages: Filter by stages (None for all) stages: Filter by stages (None for all)
tracking_uri: Override default tracking URI tracking_uri: Override default tracking URI
Returns: Returns:
List of model version info dictionaries List of model version info dictionaries
""" """
client = get_mlflow_client(tracking_uri=tracking_uri) client = get_mlflow_client(tracking_uri=tracking_uri)
if stages: if stages:
versions = client.get_latest_versions(model_name, stages=stages) versions = client.get_latest_versions(model_name, stages=stages)
else: else:
@@ -529,7 +526,7 @@ def list_model_versions(
versions = [] versions = []
for mv in client.search_model_versions(f"name='{model_name}'"): for mv in client.search_model_versions(f"name='{model_name}'"):
versions.append(mv) versions.append(mv)
return [ return [
{ {
"version": mv.version, "version": mv.version,

View File

@@ -5,19 +5,17 @@ Provides a high-level interface for logging experiments, parameters,
metrics, and artifacts from Kubeflow Pipeline components. metrics, and artifacts from Kubeflow Pipeline components.
""" """
import os
import json
import time
import logging import logging
from pathlib import Path import os
from typing import Optional, Dict, Any, List, Union import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
import mlflow import mlflow
from mlflow.tracking import MlflowClient from mlflow.tracking import MlflowClient
from .client import get_mlflow_client, ensure_experiment, MLflowConfig from .client import MLflowConfig, ensure_experiment, get_mlflow_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -30,7 +28,7 @@ class PipelineMetadata:
run_name: Optional[str] = None run_name: Optional[str] = None
component_name: Optional[str] = None component_name: Optional[str] = None
namespace: str = "ai-ml" namespace: str = "ai-ml"
# KFP-specific metadata (populated from environment if available) # KFP-specific metadata (populated from environment if available)
kfp_run_id: Optional[str] = field( kfp_run_id: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_RUN_ID") default_factory=lambda: os.environ.get("KFP_RUN_ID")
@@ -38,7 +36,7 @@ class PipelineMetadata:
kfp_pod_name: Optional[str] = field( kfp_pod_name: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_POD_NAME") default_factory=lambda: os.environ.get("KFP_POD_NAME")
) )
def as_tags(self) -> Dict[str, str]: def as_tags(self) -> Dict[str, str]:
"""Convert metadata to MLflow tags.""" """Convert metadata to MLflow tags."""
tags = { tags = {
@@ -60,34 +58,34 @@ class PipelineMetadata:
class MLflowTracker: class MLflowTracker:
""" """
MLflow experiment tracker for Kubeflow Pipeline components. MLflow experiment tracker for Kubeflow Pipeline components.
Example usage in a KFP component: Example usage in a KFP component:
from mlflow_utils import MLflowTracker from mlflow_utils import MLflowTracker
tracker = MLflowTracker( tracker = MLflowTracker(
experiment_name="document-ingestion", experiment_name="document-ingestion",
run_name="batch-ingestion-2024-01" run_name="batch-ingestion-2024-01"
) )
with tracker.start_run() as run: with tracker.start_run() as run:
tracker.log_params({ tracker.log_params({
"chunk_size": 500, "chunk_size": 500,
"overlap": 50, "overlap": 50,
"embeddings_model": "bge-small-en-v1.5" "embeddings_model": "bge-small-en-v1.5"
}) })
# ... do work ... # ... do work ...
tracker.log_metrics({ tracker.log_metrics({
"documents_processed": 100, "documents_processed": 100,
"chunks_created": 2500, "chunks_created": 2500,
"processing_time_seconds": 120.5 "processing_time_seconds": 120.5
}) })
tracker.log_artifact("/path/to/output.json") tracker.log_artifact("/path/to/output.json")
""" """
def __init__( def __init__(
self, self,
experiment_name: str, experiment_name: str,
@@ -98,7 +96,7 @@ class MLflowTracker:
): ):
""" """
Initialize the MLflow tracker. Initialize the MLflow tracker.
Args: Args:
experiment_name: Name of the MLflow experiment experiment_name: Name of the MLflow experiment
run_name: Optional name for this run run_name: Optional name for this run
@@ -112,22 +110,22 @@ class MLflowTracker:
self.pipeline_metadata = pipeline_metadata self.pipeline_metadata = pipeline_metadata
self.user_tags = tags or {} self.user_tags = tags or {}
self.tracking_uri = tracking_uri self.tracking_uri = tracking_uri
self.client: Optional[MlflowClient] = None self.client: Optional[MlflowClient] = None
self.run: Optional[mlflow.ActiveRun] = None self.run: Optional[mlflow.ActiveRun] = None
self.run_id: Optional[str] = None self.run_id: Optional[str] = None
self._start_time: Optional[float] = None self._start_time: Optional[float] = None
def _get_all_tags(self) -> Dict[str, str]: def _get_all_tags(self) -> Dict[str, str]:
"""Combine all tags for the run.""" """Combine all tags for the run."""
tags = self.config.default_tags.copy() tags = self.config.default_tags.copy()
if self.pipeline_metadata: if self.pipeline_metadata:
tags.update(self.pipeline_metadata.as_tags()) tags.update(self.pipeline_metadata.as_tags())
tags.update(self.user_tags) tags.update(self.user_tags)
return tags return tags
@contextmanager @contextmanager
def start_run( def start_run(
self, self,
@@ -136,11 +134,11 @@ class MLflowTracker:
): ):
""" """
Start an MLflow run as a context manager. Start an MLflow run as a context manager.
Args: Args:
nested: If True, create a nested run under the current active run nested: If True, create a nested run under the current active run
parent_run_id: Explicit parent run ID for nested runs parent_run_id: Explicit parent run ID for nested runs
Yields: Yields:
The MLflow run object The MLflow run object
""" """
@@ -148,12 +146,12 @@ class MLflowTracker:
tracking_uri=self.tracking_uri, tracking_uri=self.tracking_uri,
configure_global=True configure_global=True
) )
# Ensure experiment exists # Ensure experiment exists
experiment_id = ensure_experiment(self.experiment_name) experiment_id = ensure_experiment(self.experiment_name)
self._start_time = time.time() self._start_time = time.time()
try: try:
# Start the run # Start the run
self.run = mlflow.start_run( self.run = mlflow.start_run(
@@ -163,14 +161,14 @@ class MLflowTracker:
tags=self._get_all_tags(), tags=self._get_all_tags(),
) )
self.run_id = self.run.info.run_id self.run_id = self.run.info.run_id
logger.info( logger.info(
f"Started MLflow run '{self.run_name}' " f"Started MLflow run '{self.run_name}' "
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'" f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
) )
yield self.run yield self.run
except Exception as e: except Exception as e:
logger.error(f"MLflow run failed: {e}") logger.error(f"MLflow run failed: {e}")
if self.run: if self.run:
@@ -185,22 +183,22 @@ class MLflowTracker:
mlflow.log_metric("run_duration_seconds", duration) mlflow.log_metric("run_duration_seconds", duration)
except Exception: except Exception:
pass pass
# End the run # End the run
mlflow.end_run() mlflow.end_run()
logger.info(f"Ended MLflow run '{self.run_name}'") logger.info(f"Ended MLflow run '{self.run_name}'")
def log_params(self, params: Dict[str, Any]) -> None: def log_params(self, params: Dict[str, Any]) -> None:
""" """
Log parameters to the current run. Log parameters to the current run.
Args: Args:
params: Dictionary of parameter names to values params: Dictionary of parameter names to values
""" """
if not self.run: if not self.run:
logger.warning("No active run, skipping log_params") logger.warning("No active run, skipping log_params")
return return
# MLflow has limits on param values, truncate if needed # MLflow has limits on param values, truncate if needed
cleaned_params = {} cleaned_params = {}
for key, value in params.items(): for key, value in params.items():
@@ -208,14 +206,14 @@ class MLflowTracker:
if len(str_value) > 500: if len(str_value) > 500:
str_value = str_value[:497] + "..." str_value = str_value[:497] + "..."
cleaned_params[key] = str_value cleaned_params[key] = str_value
mlflow.log_params(cleaned_params) mlflow.log_params(cleaned_params)
logger.debug(f"Logged {len(params)} parameters") logger.debug(f"Logged {len(params)} parameters")
def log_param(self, key: str, value: Any) -> None: def log_param(self, key: str, value: Any) -> None:
"""Log a single parameter.""" """Log a single parameter."""
self.log_params({key: value}) self.log_params({key: value})
def log_metrics( def log_metrics(
self, self,
metrics: Dict[str, Union[float, int]], metrics: Dict[str, Union[float, int]],
@@ -223,7 +221,7 @@ class MLflowTracker:
) -> None: ) -> None:
""" """
Log metrics to the current run. Log metrics to the current run.
Args: Args:
metrics: Dictionary of metric names to values metrics: Dictionary of metric names to values
step: Optional step number for time-series metrics step: Optional step number for time-series metrics
@@ -231,10 +229,10 @@ class MLflowTracker:
if not self.run: if not self.run:
logger.warning("No active run, skipping log_metrics") logger.warning("No active run, skipping log_metrics")
return return
mlflow.log_metrics(metrics, step=step) mlflow.log_metrics(metrics, step=step)
logger.debug(f"Logged {len(metrics)} metrics") logger.debug(f"Logged {len(metrics)} metrics")
def log_metric( def log_metric(
self, self,
key: str, key: str,
@@ -243,7 +241,7 @@ class MLflowTracker:
) -> None: ) -> None:
"""Log a single metric.""" """Log a single metric."""
self.log_metrics({key: value}, step=step) self.log_metrics({key: value}, step=step)
def log_artifact( def log_artifact(
self, self,
local_path: str, local_path: str,
@@ -251,7 +249,7 @@ class MLflowTracker:
) -> None: ) -> None:
""" """
Log an artifact file to the current run. Log an artifact file to the current run.
Args: Args:
local_path: Path to the local file to log local_path: Path to the local file to log
artifact_path: Optional destination path within the artifact store artifact_path: Optional destination path within the artifact store
@@ -259,10 +257,10 @@ class MLflowTracker:
if not self.run: if not self.run:
logger.warning("No active run, skipping log_artifact") logger.warning("No active run, skipping log_artifact")
return return
mlflow.log_artifact(local_path, artifact_path) mlflow.log_artifact(local_path, artifact_path)
logger.info(f"Logged artifact: {local_path}") logger.info(f"Logged artifact: {local_path}")
def log_artifacts( def log_artifacts(
self, self,
local_dir: str, local_dir: str,
@@ -270,7 +268,7 @@ class MLflowTracker:
) -> None: ) -> None:
""" """
Log all files in a directory as artifacts. Log all files in a directory as artifacts.
Args: Args:
local_dir: Path to the local directory local_dir: Path to the local directory
artifact_path: Optional destination path within the artifact store artifact_path: Optional destination path within the artifact store
@@ -278,10 +276,10 @@ class MLflowTracker:
if not self.run: if not self.run:
logger.warning("No active run, skipping log_artifacts") logger.warning("No active run, skipping log_artifacts")
return return
mlflow.log_artifacts(local_dir, artifact_path) mlflow.log_artifacts(local_dir, artifact_path)
logger.info(f"Logged artifacts from: {local_dir}") logger.info(f"Logged artifacts from: {local_dir}")
def log_dict( def log_dict(
self, self,
data: Dict[str, Any], data: Dict[str, Any],
@@ -290,7 +288,7 @@ class MLflowTracker:
) -> None: ) -> None:
""" """
Log a dictionary as a JSON artifact. Log a dictionary as a JSON artifact.
Args: Args:
data: Dictionary to log data: Dictionary to log
filename: Name for the JSON file filename: Name for the JSON file
@@ -299,14 +297,14 @@ class MLflowTracker:
if not self.run: if not self.run:
logger.warning("No active run, skipping log_dict") logger.warning("No active run, skipping log_dict")
return return
# Ensure .json extension # Ensure .json extension
if not filename.endswith(".json"): if not filename.endswith(".json"):
filename += ".json" filename += ".json"
mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename) mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
logger.debug(f"Logged dict as: {filename}") logger.debug(f"Logged dict as: {filename}")
def log_model_info( def log_model_info(
self, self,
model_type: str, model_type: str,
@@ -317,7 +315,7 @@ class MLflowTracker:
) -> None: ) -> None:
""" """
Log model information as parameters and tags. Log model information as parameters and tags.
Args: Args:
model_type: Type of model (e.g., "llm", "embedding", "stt") model_type: Type of model (e.g., "llm", "embedding", "stt")
model_name: Name/identifier of the model model_name: Name/identifier of the model
@@ -335,13 +333,13 @@ class MLflowTracker:
if extra_info: if extra_info:
for key, value in extra_info.items(): for key, value in extra_info.items():
params[f"model.{key}"] = value params[f"model.{key}"] = value
self.log_params(params) self.log_params(params)
# Also set as tags for easier filtering # Also set as tags for easier filtering
mlflow.set_tag("model.type", model_type) mlflow.set_tag("model.type", model_type)
mlflow.set_tag("model.name", model_name) mlflow.set_tag("model.name", model_name)
def log_dataset_info( def log_dataset_info(
self, self,
name: str, name: str,
@@ -351,7 +349,7 @@ class MLflowTracker:
) -> None: ) -> None:
""" """
Log dataset information. Log dataset information.
Args: Args:
name: Dataset name name: Dataset name
source: Dataset source (URL, path, etc.) source: Dataset source (URL, path, etc.)
@@ -367,26 +365,26 @@ class MLflowTracker:
if extra_info: if extra_info:
for key, value in extra_info.items(): for key, value in extra_info.items():
params[f"dataset.{key}"] = value params[f"dataset.{key}"] = value
self.log_params(params) self.log_params(params)
def set_tag(self, key: str, value: str) -> None: def set_tag(self, key: str, value: str) -> None:
"""Set a single tag on the run.""" """Set a single tag on the run."""
if self.run: if self.run:
mlflow.set_tag(key, value) mlflow.set_tag(key, value)
def set_tags(self, tags: Dict[str, str]) -> None: def set_tags(self, tags: Dict[str, str]) -> None:
"""Set multiple tags on the run.""" """Set multiple tags on the run."""
if self.run: if self.run:
mlflow.set_tags(tags) mlflow.set_tags(tags)
@property @property
def artifact_uri(self) -> Optional[str]: def artifact_uri(self) -> Optional[str]:
"""Get the artifact URI for the current run.""" """Get the artifact URI for the current run."""
if self.run: if self.run:
return self.run.info.artifact_uri return self.run.info.artifact_uri
return None return None
@property @property
def experiment_id(self) -> Optional[str]: def experiment_id(self) -> Optional[str]:
"""Get the experiment ID for the current run.""" """Get the experiment ID for the current run."""

View File

@@ -8,12 +8,12 @@ import pytest
def test_package_imports() -> None: def test_package_imports() -> None:
"""All public symbols are importable.""" """All public symbols are importable."""
from mlflow_utils import ( # noqa: F401 from mlflow_utils import ( # noqa: F401
InferenceMetricsTracker,
MLflowConfig, MLflowConfig,
MLflowTracker, MLflowTracker,
InferenceMetricsTracker, ensure_experiment,
get_mlflow_client, get_mlflow_client,
get_tracking_uri, get_tracking_uri,
ensure_experiment,
) )
@@ -48,8 +48,8 @@ def test_kfp_components_importable() -> None:
def test_model_registry_importable() -> None: def test_model_registry_importable() -> None:
from mlflow_utils.model_registry import ( # noqa: F401 from mlflow_utils.model_registry import ( # noqa: F401
register_model_for_kserve,
generate_kserve_manifest, generate_kserve_manifest,
register_model_for_kserve,
) )