diff --git a/mlflow_utils/__init__.py b/mlflow_utils/__init__.py index 0ec37ef..a659eca 100644 --- a/mlflow_utils/__init__.py +++ b/mlflow_utils/__init__.py @@ -20,17 +20,17 @@ Usage: """ from .client import ( + MLflowConfig, + ensure_experiment, get_mlflow_client, get_tracking_uri, - ensure_experiment, - MLflowConfig, ) -from .tracker import MLflowTracker from .inference_tracker import InferenceMetricsTracker +from .tracker import MLflowTracker __all__ = [ "get_mlflow_client", - "get_tracking_uri", + "get_tracking_uri", "ensure_experiment", "MLflowConfig", "MLflowTracker", diff --git a/mlflow_utils/cli.py b/mlflow_utils/cli.py index c535a5a..fd2362e 100644 --- a/mlflow_utils/cli.py +++ b/mlflow_utils/cli.py @@ -7,21 +7,21 @@ Command-line interface for querying and comparing MLflow experiments. Usage: # Compare recent runs in an experiment python -m mlflow_utils.cli compare --experiment chat-inference --runs 5 - + # Get best run by metric python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy - + # Generate performance report python -m mlflow_utils.cli report --service chat-handler --hours 24 - + # Check model promotion criteria python -m mlflow_utils.cli promote --model whisper-finetuned \\ --experiment voice-evaluation \\ --criteria "eval.accuracy>=0.9,total_latency_p95<=2.0" - + # List experiments python -m mlflow_utils.cli list-experiments - + # Query runs python -m mlflow_utils.cli query --experiment chat-inference \\ --filter "metrics.total_latency_mean < 1.0" --limit 10 @@ -30,19 +30,16 @@ Usage: import argparse import json import sys -from typing import Optional from .client import get_mlflow_client, health_check from .experiment_comparison import ( ExperimentAnalyzer, - compare_experiments, - promotion_recommendation, get_inference_performance_report, + promotion_recommendation, ) from .model_registry import ( - list_model_versions, - get_production_model, generate_kserve_yaml, + list_model_versions, ) @@ -57,7 +54,7 @@ def cmd_list_experiments(args): """List all experiments.""" client = get_mlflow_client(tracking_uri=args.tracking_uri) experiments = client.search_experiments() - + print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}") print("-" * 80) for exp in experiments: @@ -70,13 +67,13 @@ def cmd_compare(args): args.experiment, tracking_uri=args.tracking_uri ) - + if args.run_ids: run_ids = args.run_ids.split(",") comparison = analyzer.compare_runs(run_ids=run_ids) else: comparison = analyzer.compare_runs(n_recent=args.runs) - + if args.json: print(json.dumps(comparison.to_dict(), indent=2, default=str)) else: @@ -89,17 +86,17 @@ def cmd_best(args): args.experiment, tracking_uri=args.tracking_uri ) - + best_run = analyzer.get_best_run( metric=args.metric, minimize=args.minimize, filter_string=args.filter or "", ) - + if not best_run: print(f"No runs found with metric '{args.metric}'") sys.exit(1) - + result = { "run_id": best_run.info.run_id, "run_name": best_run.info.run_name, @@ -107,7 +104,7 @@ def cmd_best(args): "all_metrics": dict(best_run.data.metrics), "params": dict(best_run.data.params), } - + if args.json: print(json.dumps(result, indent=2)) else: @@ -122,12 +119,12 @@ def cmd_summary(args): args.experiment, tracking_uri=args.tracking_uri ) - + summary = analyzer.get_metrics_summary( hours=args.hours, metrics=args.metrics.split(",") if args.metrics else None, ) - + if args.json: print(json.dumps(summary, indent=2)) else: @@ -146,7 +143,7 @@ def cmd_report(args): hours=args.hours, tracking_uri=args.tracking_uri, ) - + if args.json: print(json.dumps(report, indent=2)) else: @@ -154,18 +151,18 @@ def cmd_report(args): print(f"Period: Last {report['period_hours']} hours") print(f"Generated: {report['generated_at']}") print() - + if report["latency"]: print("Latency Metrics:") for metric, stats in report["latency"].items(): if "mean" in stats: print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})") - + if report["rag"]: print("\nRAG Usage:") for metric, stats in report["rag"].items(): print(f" {metric}: {stats.get('mean', 'N/A')}") - + if report["errors"]: print("\nError Rates:") for metric, stats in report["errors"].items(): @@ -183,14 +180,14 @@ def cmd_promote(args): metric, value = criterion.split(op) criteria[metric.strip()] = (op, float(value.strip())) break - + rec = promotion_recommendation( model_name=args.model, experiment_name=args.experiment, criteria=criteria, tracking_uri=args.tracking_uri, ) - + if args.json: print(json.dumps(rec.to_dict(), indent=2)) else: @@ -208,12 +205,12 @@ def cmd_query(args): args.experiment, tracking_uri=args.tracking_uri ) - + runs = analyzer.search_runs( filter_string=args.filter or "", max_results=args.limit, ) - + if args.json: result = [ { @@ -237,20 +234,21 @@ def cmd_query(args): def cmd_models(args): """List registered models.""" client = get_mlflow_client(tracking_uri=args.tracking_uri) - + if args.model: versions = list_model_versions(args.model, tracking_uri=args.tracking_uri) - + if args.json: print(json.dumps(versions, indent=2, default=str)) else: print(f"Model: {args.model}") for v in versions: - print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}") + desc = v["description"][:50] if v["description"] else "No description" + print(f" v{v['version']} ({v['stage']}): {desc}") else: # List all models models = client.search_registered_models() - + if args.json: result = [{"name": m.name, "description": m.description} for m in models] print(json.dumps(result, indent=2)) @@ -271,7 +269,7 @@ def cmd_kserve(args): output_path=args.output, tracking_uri=args.tracking_uri, ) - + if not args.output: print(yaml_str) @@ -281,7 +279,7 @@ def main(): description="MLflow Experiment CLI", formatter_class=argparse.RawDescriptionHelpFormatter, ) - + parser.add_argument( "--tracking-uri", default=None, @@ -292,24 +290,24 @@ def main(): action="store_true", help="Output as JSON", ) - + subparsers = parser.add_subparsers(dest="command", help="Commands") - + # health health_parser = subparsers.add_parser("health", help="Check MLflow connectivity") health_parser.set_defaults(func=cmd_health) - + # list-experiments list_parser = subparsers.add_parser("list-experiments", help="List experiments") list_parser.set_defaults(func=cmd_list_experiments) - + # compare compare_parser = subparsers.add_parser("compare", help="Compare runs") compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs") compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare") compare_parser.set_defaults(func=cmd_compare) - + # best best_parser = subparsers.add_parser("best", help="Find best run by metric") best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") @@ -317,39 +315,39 @@ def main(): best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)") best_parser.add_argument("--filter", "-f", help="Filter string") best_parser.set_defaults(func=cmd_best) - + # summary summary_parser = subparsers.add_parser("summary", help="Get metrics summary") summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data") summary_parser.add_argument("--metrics", help="Comma-separated metric names") summary_parser.set_defaults(func=cmd_summary) - + # report report_parser = subparsers.add_parser("report", help="Generate performance report") report_parser.add_argument("--service", "-s", required=True, help="Service name") report_parser.add_argument("--hours", type=int, default=24, help="Hours of data") report_parser.set_defaults(func=cmd_report) - + # promote promote_parser = subparsers.add_parser("promote", help="Check promotion criteria") promote_parser.add_argument("--model", "-m", required=True, help="Model name") promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')") promote_parser.set_defaults(func=cmd_promote) - + # query query_parser = subparsers.add_parser("query", help="Query runs") query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") query_parser.add_argument("--filter", "-f", help="MLflow filter string") query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results") query_parser.set_defaults(func=cmd_query) - + # models models_parser = subparsers.add_parser("models", help="List registered models") models_parser.add_argument("--model", "-m", help="Specific model name") models_parser.set_defaults(func=cmd_models) - + # kserve kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest") kserve_parser.add_argument("--model", "-m", required=True, help="Model name") @@ -357,13 +355,13 @@ def main(): kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace") kserve_parser.add_argument("--output", "-o", help="Output file path") kserve_parser.set_defaults(func=cmd_kserve) - + args = parser.parse_args() - + if not args.command: parser.print_help() sys.exit(1) - + args.func(args) diff --git a/mlflow_utils/client.py b/mlflow_utils/client.py index b448d99..290c87d 100644 --- a/mlflow_utils/client.py +++ b/mlflow_utils/client.py @@ -5,10 +5,10 @@ Provides a configured MLflow client for all integrations in the LLM workflows. Supports both in-cluster and external access patterns. """ -import os import logging +import os from dataclasses import dataclass, field -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional import mlflow from mlflow.tracking import MlflowClient @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) @dataclass class MLflowConfig: """Configuration for MLflow integration.""" - + # Tracking server URIs tracking_uri: str = field( default_factory=lambda: os.environ.get( @@ -33,7 +33,7 @@ class MLflowConfig: "https://mlflow.lab.daviestechlabs.io" ) ) - + # Artifact storage (NFS PVC mount) artifact_location: str = field( default_factory=lambda: os.environ.get( @@ -41,7 +41,7 @@ class MLflowConfig: "/mlflow/artifacts" ) ) - + # Default experiment settings default_experiment: str = field( default_factory=lambda: os.environ.get( @@ -49,7 +49,7 @@ class MLflowConfig: "llm-workflows" ) ) - + # Service identification service_name: str = field( default_factory=lambda: os.environ.get( @@ -57,10 +57,10 @@ class MLflowConfig: "unknown-service" ) ) - + # Additional tags to add to all runs default_tags: Dict[str, str] = field(default_factory=dict) - + def __post_init__(self): """Add default tags based on environment.""" env_tags = { @@ -74,10 +74,10 @@ class MLflowConfig: def get_tracking_uri(external: bool = False) -> str: """ Get the appropriate MLflow tracking URI. - + Args: external: If True, return the external URI for outside-cluster access - + Returns: The MLflow tracking URI string """ @@ -91,20 +91,20 @@ def get_mlflow_client( ) -> MlflowClient: """ Get a configured MLflow client. - + Args: tracking_uri: Override the default tracking URI configure_global: If True, also set mlflow.set_tracking_uri() - + Returns: Configured MlflowClient instance """ uri = tracking_uri or get_tracking_uri() - + if configure_global: mlflow.set_tracking_uri(uri) logger.info(f"MLflow tracking URI set to: {uri}") - + client = MlflowClient(tracking_uri=uri) return client @@ -116,21 +116,21 @@ def ensure_experiment( ) -> str: """ Ensure an experiment exists, creating it if necessary. - + Args: experiment_name: Name of the experiment artifact_location: Override default artifact location tags: Additional tags for the experiment - + Returns: The experiment ID """ config = MLflowConfig() client = get_mlflow_client() - + # Check if experiment exists experiment = client.get_experiment_by_name(experiment_name) - + if experiment is None: # Create the experiment artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}" @@ -143,7 +143,7 @@ def ensure_experiment( else: experiment_id = experiment.experiment_id logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}") - + return experiment_id @@ -154,17 +154,17 @@ def get_or_create_registered_model( ) -> str: """ Get or create a registered model in the Model Registry. - + Args: model_name: Name of the model to register description: Model description tags: Tags for the model - + Returns: The registered model name """ client = get_mlflow_client() - + try: # Check if model exists client.get_registered_model(model_name) @@ -177,14 +177,14 @@ def get_or_create_registered_model( tags=tags or {} ) logger.info(f"Created registered model: {model_name}") - + return model_name def health_check() -> Dict[str, Any]: """ Check MLflow server connectivity and return status. - + Returns: Dictionary with health status information """ @@ -195,7 +195,7 @@ def health_check() -> Dict[str, Any]: "connected": False, "error": None, } - + try: client = get_mlflow_client(configure_global=False) # Try to list experiments as a health check @@ -205,5 +205,5 @@ def health_check() -> Dict[str, Any]: except Exception as e: result["error"] = str(e) logger.error(f"MLflow health check failed: {e}") - + return result diff --git a/mlflow_utils/experiment_comparison.py b/mlflow_utils/experiment_comparison.py index f66d164..87dca68 100644 --- a/mlflow_utils/experiment_comparison.py +++ b/mlflow_utils/experiment_comparison.py @@ -18,15 +18,15 @@ Usage: get_best_run, promotion_recommendation, ) - + analyzer = ExperimentAnalyzer("chat-inference") - + # Compare last N runs comparison = analyzer.compare_recent_runs(n=5) - + # Find best performing model best = analyzer.get_best_run(metric="total_latency_mean", minimize=True) - + # Get promotion recommendation rec = analyzer.promotion_recommendation( model_name="whisper-finetuned", @@ -35,19 +35,15 @@ Usage: ) """ -import os -import json import logging -from datetime import datetime, timedelta -from typing import Optional, Dict, Any, List, Tuple, Union -from dataclasses import dataclass, field from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Tuple -import mlflow -from mlflow.tracking import MlflowClient -from mlflow.entities import Run, Experiment +from mlflow.entities import Experiment, Run -from .client import get_mlflow_client, MLflowConfig +from .client import get_mlflow_client logger = logging.getLogger(__name__) @@ -57,21 +53,21 @@ class RunComparison: """Comparison result for multiple MLflow runs.""" run_ids: List[str] experiment_name: str - + # Metric comparisons (metric_name -> {run_id -> value}) metrics: Dict[str, Dict[str, float]] = field(default_factory=dict) - + # Parameter differences params: Dict[str, Dict[str, str]] = field(default_factory=dict) - + # Run metadata run_names: Dict[str, str] = field(default_factory=dict) start_times: Dict[str, datetime] = field(default_factory=dict) durations: Dict[str, float] = field(default_factory=dict) - + # Best performers by metric best_by_metric: Dict[str, str] = field(default_factory=dict) - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization.""" return { @@ -82,22 +78,22 @@ class RunComparison: "run_names": self.run_names, "best_by_metric": self.best_by_metric, } - + def summary_table(self) -> str: """Generate a text summary table of the comparison.""" if not self.run_ids: return "No runs to compare" - + lines = [] lines.append(f"Experiment: {self.experiment_name}") lines.append(f"Comparing {len(self.run_ids)} runs") lines.append("") - + # Header header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids] lines.append(" | ".join(header)) lines.append("-" * (len(lines[-1]) + 10)) - + # Metrics for metric_name, values in sorted(self.metrics.items()): row = [metric_name] @@ -108,7 +104,7 @@ class RunComparison: else: row.append("N/A") lines.append(" | ".join(row)) - + return "\n".join(lines) @@ -121,7 +117,7 @@ class PromotionRecommendation: reasons: List[str] metrics_summary: Dict[str, float] comparison_with_production: Optional[Dict[str, Any]] = None - + def to_dict(self) -> Dict[str, Any]: return { "model_name": self.model_name, @@ -136,20 +132,20 @@ class PromotionRecommendation: class ExperimentAnalyzer: """ Analyze MLflow experiments for model comparison and promotion decisions. - + Example: analyzer = ExperimentAnalyzer("chat-inference") - + # Get metrics summary for last 24 hours summary = analyzer.get_metrics_summary(hours=24) - + # Compare models by accuracy best = analyzer.get_best_run(metric="eval.accuracy", minimize=False) - + # Analyze inference latency trends trends = analyzer.get_metric_trends("total_latency_mean", days=7) """ - + def __init__( self, experiment_name: str, @@ -157,7 +153,7 @@ class ExperimentAnalyzer: ): """ Initialize the experiment analyzer. - + Args: experiment_name: Name of the MLflow experiment to analyze tracking_uri: Override default tracking URI @@ -166,14 +162,14 @@ class ExperimentAnalyzer: self.tracking_uri = tracking_uri self.client = get_mlflow_client(tracking_uri=tracking_uri) self._experiment: Optional[Experiment] = None - + @property def experiment(self) -> Optional[Experiment]: """Get the experiment object, fetching if needed.""" if self._experiment is None: self._experiment = self.client.get_experiment_by_name(self.experiment_name) return self._experiment - + def search_runs( self, filter_string: str = "", @@ -183,29 +179,29 @@ class ExperimentAnalyzer: ) -> List[Run]: """ Search for runs matching criteria. - + Args: filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9") order_by: List of order clauses (e.g., ["metrics.accuracy DESC"]) max_results: Maximum runs to return run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL - + Returns: List of matching Run objects """ if not self.experiment: logger.warning(f"Experiment '{self.experiment_name}' not found") return [] - + runs = self.client.search_runs( experiment_ids=[self.experiment.experiment_id], filter_string=filter_string, order_by=order_by or ["start_time DESC"], max_results=max_results, ) - + return runs - + def get_recent_runs( self, n: int = 10, @@ -213,11 +209,11 @@ class ExperimentAnalyzer: ) -> List[Run]: """ Get the most recent runs. - + Args: n: Number of runs to return hours: Only include runs from the last N hours - + Returns: List of Run objects """ @@ -226,13 +222,13 @@ class ExperimentAnalyzer: cutoff = datetime.now() - timedelta(hours=hours) cutoff_ms = int(cutoff.timestamp() * 1000) filter_string = f"attributes.start_time >= {cutoff_ms}" - + return self.search_runs( filter_string=filter_string, order_by=["start_time DESC"], max_results=n, ) - + def compare_runs( self, run_ids: Optional[List[str]] = None, @@ -240,11 +236,11 @@ class ExperimentAnalyzer: ) -> RunComparison: """ Compare multiple runs side by side. - + Args: run_ids: Specific run IDs to compare, or None for recent runs n_recent: If run_ids is None, compare this many recent runs - + Returns: RunComparison object with detailed comparison """ @@ -252,18 +248,18 @@ class ExperimentAnalyzer: runs = [self.client.get_run(rid) for rid in run_ids] else: runs = self.get_recent_runs(n=n_recent) - + comparison = RunComparison( run_ids=[r.info.run_id for r in runs], experiment_name=self.experiment_name, ) - + # Collect all metrics and find best performers all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict) - + for run in runs: run_id = run.info.run_id - + # Metadata comparison.run_names[run_id] = run.info.run_name or run_id[:8] comparison.start_times[run_id] = datetime.fromtimestamp( @@ -273,39 +269,39 @@ class ExperimentAnalyzer: comparison.durations[run_id] = ( run.info.end_time - run.info.start_time ) / 1000 - + # Metrics for key, value in run.data.metrics.items(): all_metrics[key][run_id] = value - + # Params for key, value in run.data.params.items(): if key not in comparison.params: comparison.params[key] = {} comparison.params[key][run_id] = value - + comparison.metrics = dict(all_metrics) - + # Find best performers for each metric for metric_name, values in all_metrics.items(): if not values: continue - + # Determine if lower is better based on metric name minimize = any( term in metric_name.lower() for term in ["latency", "error", "loss", "time"] ) - + if minimize: best_id = min(values.keys(), key=lambda k: values[k]) else: best_id = max(values.keys(), key=lambda k: values[k]) - + comparison.best_by_metric[metric_name] = best_id - + return comparison - + def get_best_run( self, metric: str, @@ -315,32 +311,32 @@ class ExperimentAnalyzer: ) -> Optional[Run]: """ Get the best run by a specific metric. - + Args: metric: Metric name to optimize minimize: If True, find minimum; if False, find maximum filter_string: Additional filter criteria max_results: Maximum runs to consider - + Returns: Best Run object, or None if no runs found """ direction = "ASC" if minimize else "DESC" - + runs = self.search_runs( filter_string=filter_string, order_by=[f"metrics.{metric} {direction}"], max_results=max_results, ) - + # Filter to only runs that have the metric runs_with_metric = [ r for r in runs if metric in r.data.metrics ] - + return runs_with_metric[0] if runs_with_metric else None - + def get_metrics_summary( self, hours: Optional[int] = None, @@ -348,45 +344,45 @@ class ExperimentAnalyzer: ) -> Dict[str, Dict[str, float]]: """ Get summary statistics for metrics. - + Args: hours: Only include runs from the last N hours metrics: Specific metrics to summarize (None for all) - + Returns: Dict mapping metric names to {mean, min, max, count} """ import statistics - + runs = self.get_recent_runs(n=1000, hours=hours) - + # Collect all metric values metric_values: Dict[str, List[float]] = defaultdict(list) - + for run in runs: for key, value in run.data.metrics.items(): if metrics is None or key in metrics: metric_values[key].append(value) - + # Calculate statistics summary = {} for metric_name, values in metric_values.items(): if not values: continue - + summary[metric_name] = { "mean": statistics.mean(values), "min": min(values), "max": max(values), "count": len(values), } - + if len(values) >= 2: summary[metric_name]["stdev"] = statistics.stdev(values) summary[metric_name]["median"] = statistics.median(values) - + return summary - + def get_metric_trends( self, metric: str, @@ -395,30 +391,30 @@ class ExperimentAnalyzer: ) -> List[Dict[str, Any]]: """ Get metric trends over time. - + Args: metric: Metric name to track days: Number of days to look back granularity_hours: Time bucket size in hours - + Returns: List of {timestamp, mean, min, max, count} dicts """ import statistics - + runs = self.get_recent_runs(n=10000, hours=days * 24) - + # Group runs by time bucket buckets: Dict[int, List[float]] = defaultdict(list) bucket_size_ms = granularity_hours * 3600 * 1000 - + for run in runs: if metric not in run.data.metrics: continue - + bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms buckets[bucket].append(run.data.metrics[metric]) - + # Calculate statistics per bucket trends = [] for bucket_ts, values in sorted(buckets.items()): @@ -432,9 +428,9 @@ class ExperimentAnalyzer: if len(values) >= 2: trend["stdev"] = statistics.stdev(values) trends.append(trend) - + return trends - + def get_runs_by_tag( self, tag_key: str, @@ -443,12 +439,12 @@ class ExperimentAnalyzer: ) -> List[Run]: """ Get runs with a specific tag. - + Args: tag_key: Tag key to filter by tag_value: Tag value to match max_results: Maximum runs to return - + Returns: List of matching Run objects """ @@ -456,7 +452,7 @@ class ExperimentAnalyzer: filter_string=f"tags.{tag_key} = '{tag_value}'", max_results=max_results, ) - + def get_model_runs( self, model_name: str, @@ -464,11 +460,11 @@ class ExperimentAnalyzer: ) -> List[Run]: """ Get runs for a specific model. - + Args: model_name: Model name to filter by max_results: Maximum runs to return - + Returns: List of matching Run objects """ @@ -477,14 +473,14 @@ class ExperimentAnalyzer: filter_string=f"tags.`model.name` = '{model_name}'", max_results=max_results, ) - + if not runs: # Try params runs = self.search_runs( filter_string=f"params.model_name = '{model_name}'", max_results=max_results, ) - + return runs @@ -495,23 +491,23 @@ def compare_experiments( ) -> Dict[str, Dict[str, float]]: """ Compare metrics across multiple experiments. - + Args: experiment_names: Names of experiments to compare metric: Metric to compare tracking_uri: Override default tracking URI - + Returns: Dict mapping experiment names to metric statistics """ results = {} - + for exp_name in experiment_names: analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri) summary = analyzer.get_metrics_summary(metrics=[metric]) if metric in summary: results[exp_name] = summary[metric] - + return results @@ -523,7 +519,7 @@ def promotion_recommendation( ) -> PromotionRecommendation: """ Generate a recommendation for model promotion. - + Args: model_name: Name of the model to evaluate experiment_name: Experiment containing evaluation runs @@ -531,15 +527,15 @@ def promotion_recommendation( comparison is one of: ">=", "<=", ">", "<" e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)} tracking_uri: Override default tracking URI - + Returns: PromotionRecommendation with decision and reasons """ analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri) - + # Get model runs runs = analyzer.get_model_runs(model_name, max_results=10) - + if not runs: return PromotionRecommendation( model_name=model_name, @@ -548,41 +544,41 @@ def promotion_recommendation( reasons=["No runs found for this model"], metrics_summary={}, ) - + # Get the most recent run latest_run = runs[0] metrics = latest_run.data.metrics - + # Evaluate criteria reasons = [] passed = True - + comparisons = { ">=": lambda a, b: a >= b, "<=": lambda a, b: a <= b, ">": lambda a, b: a > b, "<": lambda a, b: a < b, } - + for metric_name, (comparison, threshold) in criteria.items(): if metric_name not in metrics: reasons.append(f"Metric '{metric_name}' not found") passed = False continue - + value = metrics[metric_name] compare_fn = comparisons.get(comparison) - + if compare_fn is None: reasons.append(f"Invalid comparison operator: {comparison}") continue - + if compare_fn(value, threshold): reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}") else: reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}") passed = False - + # Extract version from tags if available version = None if "mlflow.version" in latest_run.data.tags: @@ -590,7 +586,7 @@ def promotion_recommendation( version = int(latest_run.data.tags["mlflow.version"]) except ValueError: pass - + return PromotionRecommendation( model_name=model_name, version=version, @@ -607,21 +603,21 @@ def get_inference_performance_report( ) -> Dict[str, Any]: """ Generate an inference performance report for a service. - + Args: service_name: Service name (chat-handler, voice-assistant) hours: Hours of data to analyze tracking_uri: Override default tracking URI - + Returns: Performance report dictionary """ experiment_name = f"{service_name.replace('-', '')}-inference" analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri) - + # Get summary metrics summary = analyzer.get_metrics_summary(hours=hours) - + # Key latency metrics latency_metrics = [ "total_latency_mean", @@ -631,7 +627,7 @@ def get_inference_performance_report( "embedding_latency_mean", "rag_search_latency_mean", ] - + report = { "service": service_name, "period_hours": hours, @@ -641,24 +637,24 @@ def get_inference_performance_report( "rag": {}, "errors": {}, } - + # Latency section for metric in latency_metrics: if metric in summary: report["latency"][metric] = summary[metric] - + # Throughput if "total_requests" in summary: report["throughput"]["total_requests"] = summary["total_requests"]["mean"] - + # RAG usage rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"] for metric in rag_metrics: if metric in summary: report["rag"][metric] = summary[metric] - + # Error rate if "error_rate" in summary: report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"] - + return report diff --git a/mlflow_utils/inference_tracker.py b/mlflow_utils/inference_tracker.py index 820db8a..9feb871 100644 --- a/mlflow_utils/inference_tracker.py +++ b/mlflow_utils/inference_tracker.py @@ -9,20 +9,19 @@ complement OTel metrics with MLflow experiment tracking for longer-term analysis and model comparison. """ -import os -import time import asyncio import logging -from typing import Optional, Dict, Any, List -from dataclasses import dataclass, field +import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from functools import partial +from typing import Any, Dict, List, Optional import mlflow from mlflow.tracking import MlflowClient -from .client import get_mlflow_client, ensure_experiment, MLflowConfig +from .client import MLflowConfig, ensure_experiment, get_mlflow_client logger = logging.getLogger(__name__) @@ -33,7 +32,7 @@ class InferenceMetrics: request_id: str user_id: Optional[str] = None session_id: Optional[str] = None - + # Timing metrics (in seconds) total_latency: float = 0.0 embedding_latency: float = 0.0 @@ -42,33 +41,33 @@ class InferenceMetrics: llm_latency: float = 0.0 tts_latency: float = 0.0 stt_latency: float = 0.0 - + # Token/size metrics input_tokens: int = 0 output_tokens: int = 0 total_tokens: int = 0 prompt_length: int = 0 response_length: int = 0 - + # RAG metrics rag_enabled: bool = False rag_documents_retrieved: int = 0 rag_documents_used: int = 0 reranker_enabled: bool = False - + # Quality indicators is_streaming: bool = False is_premium: bool = False has_error: bool = False error_message: Optional[str] = None - + # Model information model_name: Optional[str] = None model_endpoint: Optional[str] = None - + # Timestamps timestamp: float = field(default_factory=time.time) - + def as_metrics_dict(self) -> Dict[str, float]: """Convert numeric fields to a metrics dictionary.""" return { @@ -87,7 +86,7 @@ class InferenceMetrics: "rag_documents_retrieved": float(self.rag_documents_retrieved), "rag_documents_used": float(self.rag_documents_used), } - + def as_params_dict(self) -> Dict[str, str]: """Convert configuration fields to a params dictionary.""" params = { @@ -106,39 +105,39 @@ class InferenceMetrics: class InferenceMetricsTracker: """ Async-compatible MLflow tracker for inference metrics. - + Uses batching and a background thread pool to avoid blocking the async event loop during MLflow calls. - + Example usage in chat-handler: - + class ChatHandler: def __init__(self): self.mlflow_tracker = InferenceMetricsTracker( service_name="chat-handler", experiment_name="chat-inference" ) - + async def setup(self): await self.mlflow_tracker.start() - + async def process_request(self, msg): metrics = InferenceMetrics(request_id=request_id) - + # Track timing start = time.time() # ... do embedding ... metrics.embedding_latency = time.time() - start - + # ... more processing ... - + # Log metrics (non-blocking) await self.mlflow_tracker.log_inference(metrics) - + async def shutdown(self): await self.mlflow_tracker.stop() """ - + def __init__( self, service_name: str, @@ -151,7 +150,7 @@ class InferenceMetricsTracker: ): """ Initialize the inference metrics tracker. - + Args: service_name: Name of the service (e.g., "chat-handler") experiment_name: MLflow experiment name (defaults to service_name) @@ -167,7 +166,7 @@ class InferenceMetricsTracker: self.batch_size = batch_size self.flush_interval = flush_interval_seconds self.enable_batching = enable_batching - + self.config = MLflowConfig() self._batch: List[InferenceMetrics] = [] self._batch_lock = asyncio.Lock() @@ -176,34 +175,34 @@ class InferenceMetricsTracker: self._running = False self._client: Optional[MlflowClient] = None self._experiment_id: Optional[str] = None - + # Aggregate metrics for periodic logging self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list) self._request_count = 0 self._error_count = 0 - + async def start(self) -> None: """Start the tracker and initialize MLflow connection.""" if self._running: return - + self._running = True - + # Initialize MLflow in thread pool to avoid blocking loop = asyncio.get_event_loop() await loop.run_in_executor( self._executor, self._init_mlflow ) - + if self.enable_batching: self._flush_task = asyncio.create_task(self._periodic_flush()) - + logger.info( f"InferenceMetricsTracker started for {self.service_name} " f"(experiment: {self.experiment_name})" ) - + def _init_mlflow(self) -> None: """Initialize MLflow client and experiment (runs in thread pool).""" self._client = get_mlflow_client( @@ -217,47 +216,47 @@ class InferenceMetricsTracker: "type": "inference-metrics", } ) - + async def stop(self) -> None: """Stop the tracker and flush remaining metrics.""" if not self._running: return - + self._running = False - + if self._flush_task: self._flush_task.cancel() try: await self._flush_task except asyncio.CancelledError: pass - + # Final flush await self._flush_batch() - + self._executor.shutdown(wait=True) logger.info(f"InferenceMetricsTracker stopped for {self.service_name}") - + async def log_inference(self, metrics: InferenceMetrics) -> None: """ Log inference metrics (non-blocking). - + Args: metrics: InferenceMetrics object with request data """ if not self._running: logger.warning("Tracker not running, skipping metrics") return - + self._request_count += 1 if metrics.has_error: self._error_count += 1 - + # Update aggregates for key, value in metrics.as_metrics_dict().items(): if value > 0: self._aggregate_metrics[key].append(value) - + if self.enable_batching: async with self._batch_lock: self._batch.append(metrics) @@ -270,29 +269,29 @@ class InferenceMetricsTracker: self._executor, partial(self._log_single_inference, metrics) ) - + async def _periodic_flush(self) -> None: """Periodically flush batched metrics.""" while self._running: await asyncio.sleep(self.flush_interval) await self._flush_batch() - + async def _flush_batch(self) -> None: """Flush the current batch of metrics to MLflow.""" async with self._batch_lock: if not self._batch: return - + batch = self._batch self._batch = [] - + # Log in thread pool loop = asyncio.get_event_loop() await loop.run_in_executor( self._executor, partial(self._log_batch, batch) ) - + def _log_single_inference(self, metrics: InferenceMetrics) -> None: """Log a single inference request to MLflow (runs in thread pool).""" try: @@ -307,7 +306,7 @@ class InferenceMetricsTracker: ): mlflow.log_params(metrics.as_params_dict()) mlflow.log_metrics(metrics.as_metrics_dict()) - + if metrics.user_id: mlflow.set_tag("user_id", metrics.user_id) if metrics.session_id: @@ -318,18 +317,18 @@ class InferenceMetricsTracker: mlflow.set_tag("error_message", metrics.error_message[:250]) except Exception as e: logger.error(f"Failed to log inference metrics: {e}") - + def _log_batch(self, batch: List[InferenceMetrics]) -> None: """Log a batch of inference metrics as aggregate statistics.""" if not batch: return - + try: # Calculate aggregates aggregates = self._calculate_aggregates(batch) - + run_name = f"batch-{self.service_name}-{int(time.time())}" - + with mlflow.start_run( experiment_id=self._experiment_id, run_name=run_name, @@ -341,83 +340,83 @@ class InferenceMetricsTracker: ): # Log aggregate metrics mlflow.log_metrics(aggregates) - + # Log batch info mlflow.log_param("batch_size", len(batch)) mlflow.log_param("time_window_start", min(m.timestamp for m in batch)) mlflow.log_param("time_window_end", max(m.timestamp for m in batch)) - + # Log configuration breakdown rag_enabled_count = sum(1 for m in batch if m.rag_enabled) streaming_count = sum(1 for m in batch if m.is_streaming) premium_count = sum(1 for m in batch if m.is_premium) error_count = sum(1 for m in batch if m.has_error) - + mlflow.log_metrics({ "rag_enabled_pct": rag_enabled_count / len(batch) * 100, "streaming_pct": streaming_count / len(batch) * 100, "premium_pct": premium_count / len(batch) * 100, "error_rate": error_count / len(batch) * 100, }) - + # Log model distribution model_counts: Dict[str, int] = defaultdict(int) for m in batch: if m.model_name: model_counts[m.model_name] += 1 - + if model_counts: mlflow.log_dict( {"models": dict(model_counts)}, "model_distribution.json" ) - + logger.info(f"Logged batch of {len(batch)} inference metrics") - + except Exception as e: logger.error(f"Failed to log batch metrics: {e}") - + def _calculate_aggregates( self, batch: List[InferenceMetrics] ) -> Dict[str, float]: """Calculate aggregate statistics from a batch of metrics.""" import statistics - + aggregates = {} - + # Collect all numeric metrics metric_values: Dict[str, List[float]] = defaultdict(list) for m in batch: for key, value in m.as_metrics_dict().items(): if value > 0: metric_values[key].append(value) - + # Calculate statistics for each metric for key, values in metric_values.items(): if not values: continue - + aggregates[f"{key}_mean"] = statistics.mean(values) aggregates[f"{key}_min"] = min(values) aggregates[f"{key}_max"] = max(values) - + if len(values) >= 2: aggregates[f"{key}_p50"] = statistics.median(values) aggregates[f"{key}_stdev"] = statistics.stdev(values) - + if len(values) >= 4: sorted_vals = sorted(values) p95_idx = int(len(sorted_vals) * 0.95) p99_idx = int(len(sorted_vals) * 0.99) aggregates[f"{key}_p95"] = sorted_vals[p95_idx] aggregates[f"{key}_p99"] = sorted_vals[p99_idx] - + # Add counts aggregates["total_requests"] = float(len(batch)) - + return aggregates - + def get_stats(self) -> Dict[str, Any]: """Get current tracker statistics.""" return { diff --git a/mlflow_utils/kfp_components.py b/mlflow_utils/kfp_components.py index f3cdf18..7048f55 100644 --- a/mlflow_utils/kfp_components.py +++ b/mlflow_utils/kfp_components.py @@ -21,22 +21,22 @@ Usage in a Kubeflow Pipeline: experiment_name="my-experiment", run_name="training-run-1" ) - + # ... your pipeline steps ... - + # Log metrics log_step = log_metrics_component( run_id=run_info.outputs["run_id"], metrics={"accuracy": 0.95, "loss": 0.05} ) - + # End run end_mlflow_run(run_id=run_info.outputs["run_id"]) """ -from kfp import dsl -from typing import Dict, Any, List, Optional, NamedTuple +from typing import Any, Dict, List, NamedTuple +from kfp import dsl # MLflow component image with all required dependencies MLFLOW_IMAGE = "python:3.13-slim" @@ -60,31 +60,32 @@ def create_mlflow_run( ) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]): """ Create a new MLflow run for the pipeline. - + This should be called at the start of a pipeline to initialize tracking. The returned run_id should be passed to subsequent components for logging. - + Args: experiment_name: Name of the MLflow experiment run_name: Name for this specific run mlflow_tracking_uri: MLflow tracking server URI tags: Optional tags to add to the run params: Optional parameters to log - + Returns: NamedTuple with run_id, experiment_id, and artifact_uri """ import os + from collections import namedtuple + import mlflow from mlflow.tracking import MlflowClient - from collections import namedtuple - + # Set tracking URI mlflow.set_tracking_uri(mlflow_tracking_uri) - + client = MlflowClient() - + # Get or create experiment experiment = client.get_experiment_by_name(experiment_name) if experiment is None: @@ -94,7 +95,7 @@ def create_mlflow_run( ) else: experiment_id = experiment.experiment_id - + # Create default tags default_tags = { "pipeline.type": "kubeflow", @@ -103,24 +104,24 @@ def create_mlflow_run( } if tags: default_tags.update(tags) - + # Start run run = mlflow.start_run( experiment_id=experiment_id, run_name=run_name, tags=default_tags, ) - + # Log initial params if params: mlflow.log_params(params) - + run_id = run.info.run_id artifact_uri = run.info.artifact_uri - + # End run (KFP components are isolated, we'll resume in other components) mlflow.end_run() - + RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri']) return RunInfo(run_id, experiment_id, artifact_uri) @@ -136,24 +137,24 @@ def log_params_component( ) -> str: """ Log parameters to an existing MLflow run. - + Args: run_id: The MLflow run ID to log to params: Dictionary of parameters to log mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import mlflow from mlflow.tracking import MlflowClient - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + for key, value in params.items(): client.log_param(run_id, key, str(value)[:500]) - + return run_id @@ -169,25 +170,25 @@ def log_metrics_component( ) -> str: """ Log metrics to an existing MLflow run. - + Args: run_id: The MLflow run ID to log to metrics: Dictionary of metrics to log step: Step number for time-series metrics mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import mlflow from mlflow.tracking import MlflowClient - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + for key, value in metrics.items(): client.log_metric(run_id, key, float(value), step=step) - + return run_id @@ -203,24 +204,24 @@ def log_artifact_component( ) -> str: """ Log an artifact file to an existing MLflow run. - + Args: run_id: The MLflow run ID to log to artifact_path: Path to the artifact file artifact_name: Optional destination name in artifact store mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import mlflow from mlflow.tracking import MlflowClient - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + client.log_artifact(run_id, artifact_path, artifact_name or None) - + return run_id @@ -236,36 +237,37 @@ def log_dict_artifact( ) -> str: """ Log a dictionary as a JSON artifact. - + Args: run_id: The MLflow run ID to log to data: Dictionary to save as JSON filename: Name for the JSON file mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import json import tempfile + from pathlib import Path + import mlflow from mlflow.tracking import MlflowClient - from pathlib import Path - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + # Ensure .json extension if not filename.endswith('.json'): filename += '.json' - + # Write to temp file and log with tempfile.TemporaryDirectory() as tmpdir: filepath = Path(tmpdir) / filename with open(filepath, 'w') as f: json.dump(data, f, indent=2) client.log_artifact(run_id, str(filepath)) - + return run_id @@ -280,31 +282,31 @@ def end_mlflow_run( ) -> str: """ End an MLflow run with the specified status. - + Args: run_id: The MLflow run ID to end status: Run status (FINISHED, FAILED, KILLED) mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id """ import mlflow - from mlflow.tracking import MlflowClient from mlflow.entities import RunStatus - + from mlflow.tracking import MlflowClient + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + status_map = { "FINISHED": RunStatus.FINISHED, "FAILED": RunStatus.FAILED, "KILLED": RunStatus.KILLED, } - + run_status = status_map.get(status.upper(), RunStatus.FINISHED) client.set_terminated(run_id, status=run_status) - + return run_id @@ -322,10 +324,10 @@ def log_training_metrics( ) -> str: """ Log comprehensive training metrics for ML models. - + Designed for use with QLoRA training, voice training, and other ML training pipelines in the llm-workflows repository. - + Args: run_id: The MLflow run ID to log to model_type: Type of model (llm, stt, tts, embeddings) @@ -333,19 +335,20 @@ def log_training_metrics( final_metrics: Final training metrics model_path: Path to saved model (if applicable) mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import json import tempfile + from pathlib import Path + import mlflow from mlflow.tracking import MlflowClient - from pathlib import Path - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + # Log training config as params flat_config = {} for key, value in training_config.items(): @@ -353,29 +356,29 @@ def log_training_metrics( flat_config[f"config.{key}"] = json.dumps(value)[:500] else: flat_config[f"config.{key}"] = str(value)[:500] - + for key, value in flat_config.items(): client.log_param(run_id, key, value) - + # Log model type tag client.set_tag(run_id, "model.type", model_type) - + # Log metrics for key, value in final_metrics.items(): client.log_metric(run_id, key, float(value)) - + # Log full config as artifact with tempfile.TemporaryDirectory() as tmpdir: config_path = Path(tmpdir) / "training_config.json" with open(config_path, 'w') as f: json.dump(training_config, f, indent=2) client.log_artifact(run_id, str(config_path)) - + # Log model path if provided if model_path: client.log_param(run_id, "model.path", model_path) client.set_tag(run_id, "model.saved", "true") - + return run_id @@ -397,9 +400,9 @@ def log_document_ingestion_metrics( ) -> str: """ Log document ingestion pipeline metrics. - + Designed for use with the document_ingestion_pipeline. - + Args: run_id: The MLflow run ID to log to source_url: URL of the source document @@ -411,16 +414,16 @@ def log_document_ingestion_metrics( chunk_size: Chunk size in tokens chunk_overlap: Chunk overlap in tokens mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import mlflow from mlflow.tracking import MlflowClient - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + # Log params params = { "source_url": source_url[:500], @@ -431,7 +434,7 @@ def log_document_ingestion_metrics( } for key, value in params.items(): client.log_param(run_id, key, value) - + # Log metrics metrics = { "chunks_created": chunks_created, @@ -441,11 +444,11 @@ def log_document_ingestion_metrics( } for key, value in metrics.items(): client.log_metric(run_id, key, float(value)) - + # Set pipeline type tag client.set_tag(run_id, "pipeline.type", "document-ingestion") client.set_tag(run_id, "milvus.collection", collection_name) - + return run_id @@ -463,9 +466,9 @@ def log_evaluation_results( ) -> str: """ Log model evaluation results. - + Designed for use with the evaluation_pipeline. - + Args: run_id: The MLflow run ID to log to model_name: Name of the evaluated model @@ -473,27 +476,28 @@ def log_evaluation_results( metrics: Evaluation metrics (accuracy, etc.) sample_results: Optional sample predictions mlflow_tracking_uri: MLflow tracking server URI - + Returns: The run_id for chaining """ import json import tempfile + from pathlib import Path + import mlflow from mlflow.tracking import MlflowClient - from pathlib import Path - + mlflow.set_tracking_uri(mlflow_tracking_uri) client = MlflowClient() - + # Log params client.log_param(run_id, "eval.model_name", model_name) client.log_param(run_id, "eval.dataset", dataset_name) - + # Log metrics for key, value in metrics.items(): client.log_metric(run_id, f"eval.{key}", float(value)) - + # Log sample results as artifact if sample_results: with tempfile.TemporaryDirectory() as tmpdir: @@ -501,13 +505,13 @@ def log_evaluation_results( with open(results_path, 'w') as f: json.dump(sample_results, f, indent=2) client.log_artifact(run_id, str(results_path)) - + # Set tags client.set_tag(run_id, "pipeline.type", "evaluation") client.set_tag(run_id, "model.name", model_name) - + # Determine if passed passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7) client.set_tag(run_id, "eval.passed", str(passed)) - + return run_id diff --git a/mlflow_utils/model_registry.py b/mlflow_utils/model_registry.py index ccc7674..b229a85 100644 --- a/mlflow_utils/model_registry.py +++ b/mlflow_utils/model_registry.py @@ -17,7 +17,7 @@ Usage: promote_model_to_production, generate_kserve_manifest, ) - + # Register a new model version model_version = register_model_for_kserve( model_name="whisper-finetuned", @@ -28,7 +28,7 @@ Usage: "container_image": "ghcr.io/my-org/whisper:v2", } ) - + # Generate KServe manifest for deployment manifest = generate_kserve_manifest( model_name="whisper-finetuned", @@ -36,18 +36,15 @@ Usage: ) """ -import os -import json -import yaml import logging -from typing import Optional, Dict, Any, List from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional import mlflow -from mlflow.tracking import MlflowClient +import yaml from mlflow.entities.model_registry import ModelVersion -from .client import get_mlflow_client, MLflowConfig +from .client import get_mlflow_client logger = logging.getLogger(__name__) @@ -55,15 +52,15 @@ logger = logging.getLogger(__name__) @dataclass class KServeConfig: """Configuration for KServe deployment.""" - + # Runtime/container configuration runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc. container_image: Optional[str] = None container_port: int = 8080 - + # Protocol configuration protocol: str = "v2" # v1, v2, grpc - + # Resource requests/limits cpu_request: str = "1" cpu_limit: str = "4" @@ -71,22 +68,22 @@ class KServeConfig: memory_limit: str = "16Gi" gpu_count: int = 0 gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm - + # Storage configuration storage_uri: Optional[str] = None # s3://, pvc://, gs:// - + # Scaling configuration min_replicas: int = 1 max_replicas: int = 1 scale_target: int = 10 # Target concurrent requests for scaling - + # Serving configuration timeout_seconds: int = 300 batch_size: int = 1 - + # Additional environment variables env_vars: Dict[str, str] = field(default_factory=dict) - + def as_dict(self) -> Dict[str, Any]: """Convert to dictionary for MLflow tags.""" return { @@ -165,7 +162,7 @@ def register_model_for_kserve( ) -> ModelVersion: """ Register a model in MLflow Model Registry with KServe metadata. - + Args: model_name: Name for the registered model model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://) @@ -175,16 +172,16 @@ def register_model_for_kserve( kserve_config: KServe deployment configuration tags: Additional tags for the model version tracking_uri: Override default tracking URI - + Returns: The created ModelVersion object """ client = get_mlflow_client(tracking_uri=tracking_uri) - + # Get or use preset KServe config if kserve_config is None: kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig()) - + # Ensure registered model exists try: client.get_registered_model(model_name) @@ -198,7 +195,7 @@ def register_model_for_kserve( } ) logger.info(f"Created registered model: {model_name}") - + # Create model version model_version = client.create_model_version( name=model_name, @@ -211,12 +208,12 @@ def register_model_for_kserve( **kserve_config.as_dict(), } ) - + logger.info( f"Registered model version {model_version.version} " f"for {model_name} (type: {model_type})" ) - + return model_version @@ -229,19 +226,19 @@ def promote_model_to_stage( ) -> ModelVersion: """ Promote a model version to a new stage. - + Args: model_name: Name of the registered model version: Version number to promote stage: Target stage (Staging, Production, Archived) archive_existing: If True, archive existing versions in target stage tracking_uri: Override default tracking URI - + Returns: The updated ModelVersion """ client = get_mlflow_client(tracking_uri=tracking_uri) - + # Transition to new stage model_version = client.transition_model_version_stage( name=model_name, @@ -249,9 +246,9 @@ def promote_model_to_stage( stage=stage, archive_existing_versions=archive_existing, ) - + logger.info(f"Promoted {model_name} v{version} to {stage}") - + return model_version @@ -262,12 +259,12 @@ def promote_model_to_production( ) -> ModelVersion: """ Promote a model version directly to Production. - + Args: model_name: Name of the registered model version: Version number to promote tracking_uri: Override default tracking URI - + Returns: The updated ModelVersion """ @@ -286,18 +283,18 @@ def get_production_model( ) -> Optional[ModelVersion]: """ Get the current Production model version. - + Args: model_name: Name of the registered model tracking_uri: Override default tracking URI - + Returns: The Production ModelVersion, or None if none exists """ client = get_mlflow_client(tracking_uri=tracking_uri) - + versions = client.get_latest_versions(model_name, stages=["Production"]) - + return versions[0] if versions else None @@ -308,17 +305,17 @@ def get_model_kserve_config( ) -> KServeConfig: """ Get KServe configuration from a registered model version. - + Args: model_name: Name of the registered model version: Version number (uses Production if not specified) tracking_uri: Override default tracking URI - + Returns: KServeConfig populated from model tags """ client = get_mlflow_client(tracking_uri=tracking_uri) - + if version: model_version = client.get_model_version(model_name, str(version)) else: @@ -326,9 +323,9 @@ def get_model_kserve_config( if not prod_version: raise ValueError(f"No Production version for {model_name}") model_version = prod_version - + tags = model_version.tags - + return KServeConfig( runtime=tags.get("kserve.runtime", "kserve-huggingface"), protocol=tags.get("kserve.protocol", "v2"), @@ -352,7 +349,7 @@ def generate_kserve_manifest( ) -> Dict[str, Any]: """ Generate a KServe InferenceService manifest from a registered model. - + Args: model_name: Name of the registered model version: Version number (uses Production if not specified) @@ -360,12 +357,12 @@ def generate_kserve_manifest( service_name: Name for the InferenceService (defaults to model_name) extra_annotations: Additional annotations for the service tracking_uri: Override default tracking URI - + Returns: KServe InferenceService manifest as a dictionary """ client = get_mlflow_client(tracking_uri=tracking_uri) - + # Get model version if version: model_version = client.get_model_version(model_name, str(version)) @@ -375,13 +372,13 @@ def generate_kserve_manifest( raise ValueError(f"No Production version for {model_name}") model_version = prod_version version = int(model_version.version) - + # Get KServe config config = get_model_kserve_config(model_name, version, tracking_uri) model_type = model_version.tags.get("model.type", "custom") - + svc_name = service_name or model_name.lower().replace("_", "-") - + # Build manifest manifest = { "apiVersion": "serving.kserve.io/v1beta1", @@ -409,10 +406,10 @@ def generate_kserve_manifest( }, }, } - + # Configure predictor based on runtime predictor = manifest["spec"]["predictor"] - + if config.container_image: # Custom container predictor["containers"] = [{ @@ -434,16 +431,16 @@ def generate_kserve_manifest( for k, v in config.env_vars.items() ], }] - + # Add GPU if needed if config.gpu_count > 0: predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count) predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count) - + else: # Standard KServe runtime storage_uri = config.storage_uri or model_version.source - + predictor["model"] = { "modelFormat": {"name": "huggingface"}, "protocolVersion": config.protocol, @@ -459,11 +456,11 @@ def generate_kserve_manifest( }, }, } - + if config.gpu_count > 0: predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count) predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count) - + return manifest @@ -476,14 +473,14 @@ def generate_kserve_yaml( ) -> str: """ Generate a KServe InferenceService manifest as YAML. - + Args: model_name: Name of the registered model version: Version number (uses Production if not specified) namespace: Kubernetes namespace output_path: If provided, write YAML to this path tracking_uri: Override default tracking URI - + Returns: YAML string of the manifest """ @@ -493,14 +490,14 @@ def generate_kserve_yaml( namespace=namespace, tracking_uri=tracking_uri, ) - + yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False) - + if output_path: with open(output_path, 'w') as f: f.write(yaml_str) logger.info(f"Wrote KServe manifest to {output_path}") - + return yaml_str @@ -511,17 +508,17 @@ def list_model_versions( ) -> List[Dict[str, Any]]: """ List all versions of a registered model. - + Args: model_name: Name of the registered model stages: Filter by stages (None for all) tracking_uri: Override default tracking URI - + Returns: List of model version info dictionaries """ client = get_mlflow_client(tracking_uri=tracking_uri) - + if stages: versions = client.get_latest_versions(model_name, stages=stages) else: @@ -529,7 +526,7 @@ def list_model_versions( versions = [] for mv in client.search_model_versions(f"name='{model_name}'"): versions.append(mv) - + return [ { "version": mv.version, diff --git a/mlflow_utils/tracker.py b/mlflow_utils/tracker.py index a92ae6b..3438ebb 100644 --- a/mlflow_utils/tracker.py +++ b/mlflow_utils/tracker.py @@ -5,19 +5,17 @@ Provides a high-level interface for logging experiments, parameters, metrics, and artifacts from Kubeflow Pipeline components. """ -import os -import json -import time import logging -from pathlib import Path -from typing import Optional, Dict, Any, List, Union +import os +import time from contextlib import contextmanager from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Union import mlflow from mlflow.tracking import MlflowClient -from .client import get_mlflow_client, ensure_experiment, MLflowConfig +from .client import MLflowConfig, ensure_experiment, get_mlflow_client logger = logging.getLogger(__name__) @@ -30,7 +28,7 @@ class PipelineMetadata: run_name: Optional[str] = None component_name: Optional[str] = None namespace: str = "ai-ml" - + # KFP-specific metadata (populated from environment if available) kfp_run_id: Optional[str] = field( default_factory=lambda: os.environ.get("KFP_RUN_ID") @@ -38,7 +36,7 @@ class PipelineMetadata: kfp_pod_name: Optional[str] = field( default_factory=lambda: os.environ.get("KFP_POD_NAME") ) - + def as_tags(self) -> Dict[str, str]: """Convert metadata to MLflow tags.""" tags = { @@ -60,34 +58,34 @@ class PipelineMetadata: class MLflowTracker: """ MLflow experiment tracker for Kubeflow Pipeline components. - + Example usage in a KFP component: - + from mlflow_utils import MLflowTracker - + tracker = MLflowTracker( experiment_name="document-ingestion", run_name="batch-ingestion-2024-01" ) - + with tracker.start_run() as run: tracker.log_params({ "chunk_size": 500, "overlap": 50, "embeddings_model": "bge-small-en-v1.5" }) - + # ... do work ... - + tracker.log_metrics({ "documents_processed": 100, "chunks_created": 2500, "processing_time_seconds": 120.5 }) - + tracker.log_artifact("/path/to/output.json") """ - + def __init__( self, experiment_name: str, @@ -98,7 +96,7 @@ class MLflowTracker: ): """ Initialize the MLflow tracker. - + Args: experiment_name: Name of the MLflow experiment run_name: Optional name for this run @@ -112,22 +110,22 @@ class MLflowTracker: self.pipeline_metadata = pipeline_metadata self.user_tags = tags or {} self.tracking_uri = tracking_uri - + self.client: Optional[MlflowClient] = None self.run: Optional[mlflow.ActiveRun] = None self.run_id: Optional[str] = None self._start_time: Optional[float] = None - + def _get_all_tags(self) -> Dict[str, str]: """Combine all tags for the run.""" tags = self.config.default_tags.copy() - + if self.pipeline_metadata: tags.update(self.pipeline_metadata.as_tags()) - + tags.update(self.user_tags) return tags - + @contextmanager def start_run( self, @@ -136,11 +134,11 @@ class MLflowTracker: ): """ Start an MLflow run as a context manager. - + Args: nested: If True, create a nested run under the current active run parent_run_id: Explicit parent run ID for nested runs - + Yields: The MLflow run object """ @@ -148,12 +146,12 @@ class MLflowTracker: tracking_uri=self.tracking_uri, configure_global=True ) - + # Ensure experiment exists experiment_id = ensure_experiment(self.experiment_name) - + self._start_time = time.time() - + try: # Start the run self.run = mlflow.start_run( @@ -163,14 +161,14 @@ class MLflowTracker: tags=self._get_all_tags(), ) self.run_id = self.run.info.run_id - + logger.info( f"Started MLflow run '{self.run_name}' " f"(ID: {self.run_id}) in experiment '{self.experiment_name}'" ) - + yield self.run - + except Exception as e: logger.error(f"MLflow run failed: {e}") if self.run: @@ -185,22 +183,22 @@ class MLflowTracker: mlflow.log_metric("run_duration_seconds", duration) except Exception: pass - + # End the run mlflow.end_run() logger.info(f"Ended MLflow run '{self.run_name}'") - + def log_params(self, params: Dict[str, Any]) -> None: """ Log parameters to the current run. - + Args: params: Dictionary of parameter names to values """ if not self.run: logger.warning("No active run, skipping log_params") return - + # MLflow has limits on param values, truncate if needed cleaned_params = {} for key, value in params.items(): @@ -208,14 +206,14 @@ class MLflowTracker: if len(str_value) > 500: str_value = str_value[:497] + "..." cleaned_params[key] = str_value - + mlflow.log_params(cleaned_params) logger.debug(f"Logged {len(params)} parameters") - + def log_param(self, key: str, value: Any) -> None: """Log a single parameter.""" self.log_params({key: value}) - + def log_metrics( self, metrics: Dict[str, Union[float, int]], @@ -223,7 +221,7 @@ class MLflowTracker: ) -> None: """ Log metrics to the current run. - + Args: metrics: Dictionary of metric names to values step: Optional step number for time-series metrics @@ -231,10 +229,10 @@ class MLflowTracker: if not self.run: logger.warning("No active run, skipping log_metrics") return - + mlflow.log_metrics(metrics, step=step) logger.debug(f"Logged {len(metrics)} metrics") - + def log_metric( self, key: str, @@ -243,7 +241,7 @@ class MLflowTracker: ) -> None: """Log a single metric.""" self.log_metrics({key: value}, step=step) - + def log_artifact( self, local_path: str, @@ -251,7 +249,7 @@ class MLflowTracker: ) -> None: """ Log an artifact file to the current run. - + Args: local_path: Path to the local file to log artifact_path: Optional destination path within the artifact store @@ -259,10 +257,10 @@ class MLflowTracker: if not self.run: logger.warning("No active run, skipping log_artifact") return - + mlflow.log_artifact(local_path, artifact_path) logger.info(f"Logged artifact: {local_path}") - + def log_artifacts( self, local_dir: str, @@ -270,7 +268,7 @@ class MLflowTracker: ) -> None: """ Log all files in a directory as artifacts. - + Args: local_dir: Path to the local directory artifact_path: Optional destination path within the artifact store @@ -278,10 +276,10 @@ class MLflowTracker: if not self.run: logger.warning("No active run, skipping log_artifacts") return - + mlflow.log_artifacts(local_dir, artifact_path) logger.info(f"Logged artifacts from: {local_dir}") - + def log_dict( self, data: Dict[str, Any], @@ -290,7 +288,7 @@ class MLflowTracker: ) -> None: """ Log a dictionary as a JSON artifact. - + Args: data: Dictionary to log filename: Name for the JSON file @@ -299,14 +297,14 @@ class MLflowTracker: if not self.run: logger.warning("No active run, skipping log_dict") return - + # Ensure .json extension if not filename.endswith(".json"): filename += ".json" - + mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename) logger.debug(f"Logged dict as: {filename}") - + def log_model_info( self, model_type: str, @@ -317,7 +315,7 @@ class MLflowTracker: ) -> None: """ Log model information as parameters and tags. - + Args: model_type: Type of model (e.g., "llm", "embedding", "stt") model_name: Name/identifier of the model @@ -335,13 +333,13 @@ class MLflowTracker: if extra_info: for key, value in extra_info.items(): params[f"model.{key}"] = value - + self.log_params(params) - + # Also set as tags for easier filtering mlflow.set_tag("model.type", model_type) mlflow.set_tag("model.name", model_name) - + def log_dataset_info( self, name: str, @@ -351,7 +349,7 @@ class MLflowTracker: ) -> None: """ Log dataset information. - + Args: name: Dataset name source: Dataset source (URL, path, etc.) @@ -367,26 +365,26 @@ class MLflowTracker: if extra_info: for key, value in extra_info.items(): params[f"dataset.{key}"] = value - + self.log_params(params) - + def set_tag(self, key: str, value: str) -> None: """Set a single tag on the run.""" if self.run: mlflow.set_tag(key, value) - + def set_tags(self, tags: Dict[str, str]) -> None: """Set multiple tags on the run.""" if self.run: mlflow.set_tags(tags) - + @property def artifact_uri(self) -> Optional[str]: """Get the artifact URI for the current run.""" if self.run: return self.run.info.artifact_uri return None - + @property def experiment_id(self) -> Optional[str]: """Get the experiment ID for the current run.""" diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 2f9a50a..75116ae 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -8,12 +8,12 @@ import pytest def test_package_imports() -> None: """All public symbols are importable.""" from mlflow_utils import ( # noqa: F401 + InferenceMetricsTracker, MLflowConfig, MLflowTracker, - InferenceMetricsTracker, + ensure_experiment, get_mlflow_client, get_tracking_uri, - ensure_experiment, ) @@ -48,8 +48,8 @@ def test_kfp_components_importable() -> None: def test_model_registry_importable() -> None: from mlflow_utils.model_registry import ( # noqa: F401 - register_model_for_kserve, generate_kserve_manifest, + register_model_for_kserve, )