diff --git a/README.md b/README.md
index 990b8bd..c1d110e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,129 @@
-# mlflow
+# MLflow Utils
+
+MLflow integration utilities for the DaviesTechLabs AI/ML platform.
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+Or from Gitea:
+```bash
+pip install git+https://git.daviestechlabs.io/daviestechlabs/mlflow.git
+```
+
+## Modules
+
+| Module | Description |
+|--------|-------------|
+| `client.py` | MLflow client configuration and helpers |
+| `tracker.py` | General MLflowTracker for experiments |
+| `inference_tracker.py` | Async inference metrics for NATS handlers |
+| `model_registry.py` | Model Registry with KServe metadata |
+| `kfp_components.py` | Kubeflow Pipeline MLflow components |
+| `experiment_comparison.py` | Compare experiments and runs |
+| `cli.py` | Command-line interface |
+
+## Quick Start
+
+```python
+from mlflow_utils import get_mlflow_client, MLflowTracker
+
+# Simple tracking
+with MLflowTracker(experiment_name="my-experiment") as tracker:
+    tracker.log_params({"learning_rate": 0.001})
+    tracker.log_metrics({"accuracy": 0.95})
+```
+
+## Inference Tracking
+
+For NATS handlers (chat-handler, voice-assistant):
+
+```python
+from mlflow_utils import InferenceMetricsTracker
+from mlflow_utils.inference_tracker import InferenceMetrics
+
+tracker = InferenceMetricsTracker(
+    service_name="voice-assistant",
+    experiment_name="voice-assistant-prod",
+    batch_size=100,  # Batch metrics before logging
+)
+await tracker.start()
+
+# During request handling
+metrics = InferenceMetrics(
+    request_id="uuid",
+    total_latency=1.5,
+    llm_latency=0.8,
+    input_tokens=150,
+    output_tokens=200,
+)
+await tracker.log_inference(metrics)
+```
+
+## Model Registry
+
+Register models with KServe deployment metadata:
+
+```python
+from mlflow_utils.model_registry import register_model_for_kserve
+
+register_model_for_kserve(
+    model_name="my-qlora-adapter",
+    model_uri="runs:/abc123/model",
+    kserve_runtime="kserve-vllm",
+    gpu_type="amd-strixhalo",
+)
+```
+
+## Kubeflow Components
+
+Use in KFP pipelines:
+
+```python
+from mlflow_utils.kfp_components import (
+    create_mlflow_run,
+    log_metrics_component,
+    end_mlflow_run,
+)
+```
+
+## CLI
+
+```bash
+# List experiments
+python -m mlflow_utils.cli list-experiments
+
+# Compare recent runs
+python -m mlflow_utils.cli compare --experiment "qlora-training" --runs 5
+
+# Generate a performance report
+python -m mlflow_utils.cli report --service chat-handler --hours 24
+```
+
+## Configuration
+
+| Environment Variable | Default | Description |
+|---------------------|---------|-------------|
+| `MLFLOW_TRACKING_URI` | `http://mlflow.mlflow.svc.cluster.local:80` | MLflow server (in-cluster) |
+| `MLFLOW_EXTERNAL_URI` | `https://mlflow.lab.daviestechlabs.io` | MLflow server (external) |
+| `MLFLOW_DEFAULT_EXPERIMENT` | `llm-workflows` | Default experiment |
+| `MLFLOW_ARTIFACT_LOCATION` | `/mlflow/artifacts` | Artifact storage (NFS PVC) |
+
+## Module Structure
+
+```
+mlflow_utils/
+├── __init__.py               # Public API
+├── client.py                 # Connection management
+├── tracker.py                # General experiment tracker
+├── inference_tracker.py      # Async inference metrics
+├── model_registry.py         # Model registration + KServe
+├── kfp_components.py         # Kubeflow components
+├── experiment_comparison.py  # Run comparison tools
+└── cli.py                    # Command-line interface
+```
+
+## Related
+
+- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Uses inference tracker
+- [kubeflow](https://git.daviestechlabs.io/daviestechlabs/kubeflow) - KFP components
+- [argo](https://git.daviestechlabs.io/daviestechlabs/argo) - Training workflows
+- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - 
Architecture docs
diff --git a/mlflow_utils/__init__.py b/mlflow_utils/__init__.py
new file mode 100644
index 0000000..0ec37ef
--- /dev/null
+++ b/mlflow_utils/__init__.py
@@ -0,0 +1,40 @@
+"""
+MLflow Integration Utilities for LLM Workflows
+
+This module provides MLflow integration for:
+- Kubeflow Pipelines experiment tracking
+- Model Registry with KServe deployment metadata
+- Inference metrics logging from NATS handlers
+- Experiment comparison and analysis
+
+Configuration:
+    Set MLFLOW_TRACKING_URI environment variable or use defaults:
+    - In-cluster: http://mlflow.mlflow.svc.cluster.local:80
+    - External: https://mlflow.lab.daviestechlabs.io
+
+Usage:
+    from mlflow_utils import get_mlflow_client, MLflowTracker
+    from mlflow_utils.kfp_components import create_mlflow_run, log_metrics_component
+    from mlflow_utils.model_registry import register_model_for_kserve
+    from mlflow_utils.inference_tracker import InferenceMetricsTracker
+"""
+
+from .client import (
+    get_mlflow_client,
+    get_tracking_uri,
+    ensure_experiment,
+    MLflowConfig,
+)
+from .tracker import MLflowTracker
+from .inference_tracker import InferenceMetricsTracker
+
+__all__ = [
+    "get_mlflow_client",
+    "get_tracking_uri",
+    "ensure_experiment",
+    "MLflowConfig",
+    "MLflowTracker",
+    "InferenceMetricsTracker",
+]
+
+__version__ = "1.0.0"
diff --git a/mlflow_utils/cli.py b/mlflow_utils/cli.py
new file mode 100644
index 0000000..c535a5a
--- /dev/null
+++ b/mlflow_utils/cli.py
@@ -0,0 +1,371 @@
+#!/usr/bin/env python3
+"""
+MLflow Experiment CLI
+
+Command-line interface for querying and comparing MLflow experiments.
+
+Usage:
+    # Compare recent runs in an experiment
+    python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
+
+    # Get best run by metric
+    python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
+
+    # Generate performance report
+    python -m mlflow_utils.cli report --service chat-handler --hours 24
+
+    # Check model promotion criteria
+    python -m mlflow_utils.cli promote --model whisper-finetuned \\
+        --experiment voice-evaluation \\
+        --criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
+
+    # List experiments
+    python -m mlflow_utils.cli list-experiments
+
+    # Query runs
+    python -m mlflow_utils.cli query --experiment chat-inference \\
+        --filter "metrics.total_latency_mean < 1.0" --limit 10
+"""
+
+import argparse
+import json
+import sys
+from typing import Optional
+
+from .client import get_mlflow_client, health_check
+from .experiment_comparison import (
+    ExperimentAnalyzer,
+    compare_experiments,
+    promotion_recommendation,
+    get_inference_performance_report,
+)
+from .model_registry import (
+    list_model_versions,
+    get_production_model,
+    generate_kserve_yaml,
+)
+
+
+def cmd_health(args):
+    """Check MLflow connectivity."""
+    result = health_check()
+    print(json.dumps(result, indent=2))
+    sys.exit(0 if result["connected"] else 1)
+
+
+def cmd_list_experiments(args):
+    """List all experiments."""
+    client = get_mlflow_client(tracking_uri=args.tracking_uri)
+    experiments = client.search_experiments()
+
+    print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
+    print("-" * 80)
+    for exp in experiments:
+        print(f"{exp.experiment_id:<10} {exp.name:<40} {exp.artifact_location}")
+
+
+def cmd_compare(args):
+    """Compare recent runs in an experiment."""
+    analyzer = ExperimentAnalyzer(
+        args.experiment,
+        tracking_uri=args.tracking_uri
+    )
+
+    if args.run_ids:
+        run_ids = args.run_ids.split(",")
+        comparison = analyzer.compare_runs(run_ids=run_ids)
+    else:
+        comparison = 
analyzer.compare_runs(n_recent=args.runs) + + if args.json: + print(json.dumps(comparison.to_dict(), indent=2, default=str)) + else: + print(comparison.summary_table()) + + +def cmd_best(args): + """Find the best run by a metric.""" + analyzer = ExperimentAnalyzer( + args.experiment, + tracking_uri=args.tracking_uri + ) + + best_run = analyzer.get_best_run( + metric=args.metric, + minimize=args.minimize, + filter_string=args.filter or "", + ) + + if not best_run: + print(f"No runs found with metric '{args.metric}'") + sys.exit(1) + + result = { + "run_id": best_run.info.run_id, + "run_name": best_run.info.run_name, + "metric_value": best_run.data.metrics.get(args.metric), + "all_metrics": dict(best_run.data.metrics), + "params": dict(best_run.data.params), + } + + if args.json: + print(json.dumps(result, indent=2)) + else: + print(f"Best Run: {best_run.info.run_name or best_run.info.run_id}") + print(f" {args.metric}: {result['metric_value']}") + print(f" Run ID: {best_run.info.run_id}") + + +def cmd_summary(args): + """Get metrics summary for an experiment.""" + analyzer = ExperimentAnalyzer( + args.experiment, + tracking_uri=args.tracking_uri + ) + + summary = analyzer.get_metrics_summary( + hours=args.hours, + metrics=args.metrics.split(",") if args.metrics else None, + ) + + if args.json: + print(json.dumps(summary, indent=2)) + else: + print(f"Metrics Summary for '{args.experiment}' (last {args.hours}h)") + print("=" * 60) + for metric, stats in sorted(summary.items()): + print(f"\n{metric}:") + for stat, value in stats.items(): + print(f" {stat}: {value:.4f}") + + +def cmd_report(args): + """Generate an inference performance report.""" + report = get_inference_performance_report( + service_name=args.service, + hours=args.hours, + tracking_uri=args.tracking_uri, + ) + + if args.json: + print(json.dumps(report, indent=2)) + else: + print(f"Performance Report: {report['service']}") + print(f"Period: Last {report['period_hours']} hours") + print(f"Generated: {report['generated_at']}") + print() + + if report["latency"]: + print("Latency Metrics:") + for metric, stats in report["latency"].items(): + if "mean" in stats: + print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})") + + if report["rag"]: + print("\nRAG Usage:") + for metric, stats in report["rag"].items(): + print(f" {metric}: {stats.get('mean', 'N/A')}") + + if report["errors"]: + print("\nError Rates:") + for metric, stats in report["errors"].items(): + print(f" {metric}: {stats:.2f}%") + + +def cmd_promote(args): + """Check model promotion criteria.""" + # Parse criteria + criteria = {} + for criterion in args.criteria.split(","): + # Parse "metric>=value" or "metric<=value" etc. 
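+    # NOTE: the two-character operators must come before ">" and "<" in the
+    # list below; otherwise "accuracy>=0.9" would be split on ">" and the
+    # leftover "=0.9" would fail float() parsing.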
+ for op in [">=", "<=", ">", "<"]: + if op in criterion: + metric, value = criterion.split(op) + criteria[metric.strip()] = (op, float(value.strip())) + break + + rec = promotion_recommendation( + model_name=args.model, + experiment_name=args.experiment, + criteria=criteria, + tracking_uri=args.tracking_uri, + ) + + if args.json: + print(json.dumps(rec.to_dict(), indent=2)) + else: + status = "✓ RECOMMENDED" if rec.recommended else "✗ NOT RECOMMENDED" + print(f"Model: {args.model}") + print(f"Status: {status}") + print("\nCriteria Evaluation:") + for reason in rec.reasons: + print(f" {reason}") + + +def cmd_query(args): + """Query runs with a filter.""" + analyzer = ExperimentAnalyzer( + args.experiment, + tracking_uri=args.tracking_uri + ) + + runs = analyzer.search_runs( + filter_string=args.filter or "", + max_results=args.limit, + ) + + if args.json: + result = [ + { + "run_id": r.info.run_id, + "run_name": r.info.run_name, + "status": r.info.status, + "metrics": dict(r.data.metrics), + "params": dict(r.data.params), + } + for r in runs + ] + print(json.dumps(result, indent=2)) + else: + print(f"Found {len(runs)} runs") + for run in runs: + print(f"\n{run.info.run_name or run.info.run_id}") + print(f" ID: {run.info.run_id}") + print(f" Status: {run.info.status}") + + +def cmd_models(args): + """List registered models.""" + client = get_mlflow_client(tracking_uri=args.tracking_uri) + + if args.model: + versions = list_model_versions(args.model, tracking_uri=args.tracking_uri) + + if args.json: + print(json.dumps(versions, indent=2, default=str)) + else: + print(f"Model: {args.model}") + for v in versions: + print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}") + else: + # List all models + models = client.search_registered_models() + + if args.json: + result = [{"name": m.name, "description": m.description} for m in models] + print(json.dumps(result, indent=2)) + else: + print(f"{'Model Name':<40} Description") + print("-" * 80) + for model in models: + desc = (model.description or "")[:35] + print(f"{model.name:<40} {desc}") + + +def cmd_kserve(args): + """Generate KServe manifest for a model.""" + yaml_str = generate_kserve_yaml( + model_name=args.model, + version=args.version, + namespace=args.namespace, + output_path=args.output, + tracking_uri=args.tracking_uri, + ) + + if not args.output: + print(yaml_str) + + +def main(): + parser = argparse.ArgumentParser( + description="MLflow Experiment CLI", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--tracking-uri", + default=None, + help="MLflow tracking URI (default: from env or in-cluster)", + ) + parser.add_argument( + "--json", + action="store_true", + help="Output as JSON", + ) + + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # health + health_parser = subparsers.add_parser("health", help="Check MLflow connectivity") + health_parser.set_defaults(func=cmd_health) + + # list-experiments + list_parser = subparsers.add_parser("list-experiments", help="List experiments") + list_parser.set_defaults(func=cmd_list_experiments) + + # compare + compare_parser = subparsers.add_parser("compare", help="Compare runs") + compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") + compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs") + compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare") + 
compare_parser.set_defaults(func=cmd_compare) + + # best + best_parser = subparsers.add_parser("best", help="Find best run by metric") + best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") + best_parser.add_argument("--metric", "-m", required=True, help="Metric to optimize") + best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)") + best_parser.add_argument("--filter", "-f", help="Filter string") + best_parser.set_defaults(func=cmd_best) + + # summary + summary_parser = subparsers.add_parser("summary", help="Get metrics summary") + summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") + summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data") + summary_parser.add_argument("--metrics", help="Comma-separated metric names") + summary_parser.set_defaults(func=cmd_summary) + + # report + report_parser = subparsers.add_parser("report", help="Generate performance report") + report_parser.add_argument("--service", "-s", required=True, help="Service name") + report_parser.add_argument("--hours", type=int, default=24, help="Hours of data") + report_parser.set_defaults(func=cmd_report) + + # promote + promote_parser = subparsers.add_parser("promote", help="Check promotion criteria") + promote_parser.add_argument("--model", "-m", required=True, help="Model name") + promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") + promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')") + promote_parser.set_defaults(func=cmd_promote) + + # query + query_parser = subparsers.add_parser("query", help="Query runs") + query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name") + query_parser.add_argument("--filter", "-f", help="MLflow filter string") + query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results") + query_parser.set_defaults(func=cmd_query) + + # models + models_parser = subparsers.add_parser("models", help="List registered models") + models_parser.add_argument("--model", "-m", help="Specific model name") + models_parser.set_defaults(func=cmd_models) + + # kserve + kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest") + kserve_parser.add_argument("--model", "-m", required=True, help="Model name") + kserve_parser.add_argument("--version", "-v", type=int, help="Model version") + kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace") + kserve_parser.add_argument("--output", "-o", help="Output file path") + kserve_parser.set_defaults(func=cmd_kserve) + + args = parser.parse_args() + + if not args.command: + parser.print_help() + sys.exit(1) + + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/mlflow_utils/client.py b/mlflow_utils/client.py new file mode 100644 index 0000000..b448d99 --- /dev/null +++ b/mlflow_utils/client.py @@ -0,0 +1,209 @@ +""" +MLflow Client Configuration and Initialization + +Provides a configured MLflow client for all integrations in the LLM workflows. +Supports both in-cluster and external access patterns. 
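+
+Example (a minimal sketch, assuming the in-cluster defaults):
+
+    from mlflow_utils.client import get_mlflow_client, ensure_experiment
+
+    client = get_mlflow_client()          # also sets the global tracking URI
+    exp_id = ensure_experiment("my-experiment")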
+""" + +import os +import logging +from dataclasses import dataclass, field +from typing import Optional, Dict, Any + +import mlflow +from mlflow.tracking import MlflowClient + +logger = logging.getLogger(__name__) + + +@dataclass +class MLflowConfig: + """Configuration for MLflow integration.""" + + # Tracking server URIs + tracking_uri: str = field( + default_factory=lambda: os.environ.get( + "MLFLOW_TRACKING_URI", + "http://mlflow.mlflow.svc.cluster.local:80" + ) + ) + external_uri: str = field( + default_factory=lambda: os.environ.get( + "MLFLOW_EXTERNAL_URI", + "https://mlflow.lab.daviestechlabs.io" + ) + ) + + # Artifact storage (NFS PVC mount) + artifact_location: str = field( + default_factory=lambda: os.environ.get( + "MLFLOW_ARTIFACT_LOCATION", + "/mlflow/artifacts" + ) + ) + + # Default experiment settings + default_experiment: str = field( + default_factory=lambda: os.environ.get( + "MLFLOW_DEFAULT_EXPERIMENT", + "llm-workflows" + ) + ) + + # Service identification + service_name: str = field( + default_factory=lambda: os.environ.get( + "OTEL_SERVICE_NAME", + "unknown-service" + ) + ) + + # Additional tags to add to all runs + default_tags: Dict[str, str] = field(default_factory=dict) + + def __post_init__(self): + """Add default tags based on environment.""" + env_tags = { + "environment": os.environ.get("DEPLOYMENT_ENV", "production"), + "hostname": os.environ.get("HOSTNAME", "unknown"), + "namespace": os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml"), + } + self.default_tags = {**env_tags, **self.default_tags} + + +def get_tracking_uri(external: bool = False) -> str: + """ + Get the appropriate MLflow tracking URI. + + Args: + external: If True, return the external URI for outside-cluster access + + Returns: + The MLflow tracking URI string + """ + config = MLflowConfig() + return config.external_uri if external else config.tracking_uri + + +def get_mlflow_client( + tracking_uri: Optional[str] = None, + configure_global: bool = True +) -> MlflowClient: + """ + Get a configured MLflow client. + + Args: + tracking_uri: Override the default tracking URI + configure_global: If True, also set mlflow.set_tracking_uri() + + Returns: + Configured MlflowClient instance + """ + uri = tracking_uri or get_tracking_uri() + + if configure_global: + mlflow.set_tracking_uri(uri) + logger.info(f"MLflow tracking URI set to: {uri}") + + client = MlflowClient(tracking_uri=uri) + return client + + +def ensure_experiment( + experiment_name: str, + artifact_location: Optional[str] = None, + tags: Optional[Dict[str, str]] = None +) -> str: + """ + Ensure an experiment exists, creating it if necessary. 
+ + Args: + experiment_name: Name of the experiment + artifact_location: Override default artifact location + tags: Additional tags for the experiment + + Returns: + The experiment ID + """ + config = MLflowConfig() + client = get_mlflow_client() + + # Check if experiment exists + experiment = client.get_experiment_by_name(experiment_name) + + if experiment is None: + # Create the experiment + artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}" + experiment_id = client.create_experiment( + name=experiment_name, + artifact_location=artifact_loc, + tags=tags or {} + ) + logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}") + else: + experiment_id = experiment.experiment_id + logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}") + + return experiment_id + + +def get_or_create_registered_model( + model_name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None +) -> str: + """ + Get or create a registered model in the Model Registry. + + Args: + model_name: Name of the model to register + description: Model description + tags: Tags for the model + + Returns: + The registered model name + """ + client = get_mlflow_client() + + try: + # Check if model exists + client.get_registered_model(model_name) + logger.debug(f"Using existing registered model: {model_name}") + except mlflow.exceptions.MlflowException: + # Create the model + client.create_registered_model( + name=model_name, + description=description or f"Model for {model_name}", + tags=tags or {} + ) + logger.info(f"Created registered model: {model_name}") + + return model_name + + +def health_check() -> Dict[str, Any]: + """ + Check MLflow server connectivity and return status. + + Returns: + Dictionary with health status information + """ + config = MLflowConfig() + result = { + "tracking_uri": config.tracking_uri, + "external_uri": config.external_uri, + "connected": False, + "error": None, + } + + try: + client = get_mlflow_client(configure_global=False) + # Try to list experiments as a health check + experiments = client.search_experiments(max_results=1) + result["connected"] = True + result["experiment_count"] = len(experiments) + except Exception as e: + result["error"] = str(e) + logger.error(f"MLflow health check failed: {e}") + + return result diff --git a/mlflow_utils/experiment_comparison.py b/mlflow_utils/experiment_comparison.py new file mode 100644 index 0000000..f66d164 --- /dev/null +++ b/mlflow_utils/experiment_comparison.py @@ -0,0 +1,664 @@ +""" +Experiment Comparison and Analysis Utilities + +Provides tools for comparing model versions, querying experiments, +and making data-driven decisions about model promotion to production. 
+
+Features:
+- Compare multiple runs/experiments side by side
+- Query experiments by tags, metrics, or parameters
+- Analyze inference metrics from NATS handlers
+- Generate promotion recommendations
+- Export comparison reports
+
+Usage:
+    from mlflow_utils.experiment_comparison import (
+        ExperimentAnalyzer,
+        compare_experiments,
+        promotion_recommendation,
+    )
+
+    analyzer = ExperimentAnalyzer("chat-inference")
+
+    # Compare last N runs
+    comparison = analyzer.compare_runs(n_recent=5)
+
+    # Find best performing model
+    best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
+
+    # Get promotion recommendation
+    rec = promotion_recommendation(
+        model_name="whisper-finetuned",
+        experiment_name="voice-evaluation",
+        criteria={"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)},
+    )
+"""
+
+import os
+import json
+import logging
+from datetime import datetime, timedelta
+from typing import Optional, Dict, Any, List, Tuple, Union
+from dataclasses import dataclass, field
+from collections import defaultdict
+
+import mlflow
+from mlflow.tracking import MlflowClient
+from mlflow.entities import Run, Experiment
+
+from .client import get_mlflow_client, MLflowConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RunComparison:
+    """Comparison result for multiple MLflow runs."""
+    run_ids: List[str]
+    experiment_name: str
+
+    # Metric comparisons (metric_name -> {run_id -> value})
+    metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
+
+    # Parameter differences
+    params: Dict[str, Dict[str, str]] = field(default_factory=dict)
+
+    # Run metadata
+    run_names: Dict[str, str] = field(default_factory=dict)
+    start_times: Dict[str, datetime] = field(default_factory=dict)
+    durations: Dict[str, float] = field(default_factory=dict)
+
+    # Best performers by metric
+    best_by_metric: Dict[str, str] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "run_ids": self.run_ids,
+            "experiment_name": self.experiment_name,
+            "metrics": self.metrics,
+            "params": self.params,
+            "run_names": self.run_names,
+            "best_by_metric": self.best_by_metric,
+        }
+
+    def summary_table(self) -> str:
+        """Generate a text summary table of the comparison."""
+        if not self.run_ids:
+            return "No runs to compare"
+
+        lines = []
+        lines.append(f"Experiment: {self.experiment_name}")
+        lines.append(f"Comparing {len(self.run_ids)} runs")
+        lines.append("")
+
+        # Header
+        header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
+        lines.append(" | ".join(header))
+        lines.append("-" * (len(lines[-1]) + 10))
+
+        # Metrics
+        for metric_name, values in sorted(self.metrics.items()):
+            row = [metric_name]
+            for run_id in self.run_ids:
+                value = values.get(run_id)
+                if value is not None:
+                    row.append(f"{value:.4f}")
+                else:
+                    row.append("N/A")
+            lines.append(" | ".join(row))
+
+        return "\n".join(lines)
+
+
+@dataclass
+class PromotionRecommendation:
+    """Recommendation for model promotion."""
+    model_name: str
+    version: Optional[int]
+    recommended: bool
+    reasons: List[str]
+    metrics_summary: Dict[str, float]
+    comparison_with_production: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "model_name": self.model_name,
+            "version": self.version,
+            "recommended": self.recommended,
+            "reasons": self.reasons,
+            "metrics_summary": self.metrics_summary,
+            "comparison_with_production": self.comparison_with_production,
+        }
+
+
+class ExperimentAnalyzer:
+    """
+    Analyze MLflow experiments for model
comparison and promotion decisions. + + Example: + analyzer = ExperimentAnalyzer("chat-inference") + + # Get metrics summary for last 24 hours + summary = analyzer.get_metrics_summary(hours=24) + + # Compare models by accuracy + best = analyzer.get_best_run(metric="eval.accuracy", minimize=False) + + # Analyze inference latency trends + trends = analyzer.get_metric_trends("total_latency_mean", days=7) + """ + + def __init__( + self, + experiment_name: str, + tracking_uri: Optional[str] = None, + ): + """ + Initialize the experiment analyzer. + + Args: + experiment_name: Name of the MLflow experiment to analyze + tracking_uri: Override default tracking URI + """ + self.experiment_name = experiment_name + self.tracking_uri = tracking_uri + self.client = get_mlflow_client(tracking_uri=tracking_uri) + self._experiment: Optional[Experiment] = None + + @property + def experiment(self) -> Optional[Experiment]: + """Get the experiment object, fetching if needed.""" + if self._experiment is None: + self._experiment = self.client.get_experiment_by_name(self.experiment_name) + return self._experiment + + def search_runs( + self, + filter_string: str = "", + order_by: Optional[List[str]] = None, + max_results: int = 100, + run_view_type: str = "ACTIVE_ONLY", + ) -> List[Run]: + """ + Search for runs matching criteria. + + Args: + filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9") + order_by: List of order clauses (e.g., ["metrics.accuracy DESC"]) + max_results: Maximum runs to return + run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL + + Returns: + List of matching Run objects + """ + if not self.experiment: + logger.warning(f"Experiment '{self.experiment_name}' not found") + return [] + + runs = self.client.search_runs( + experiment_ids=[self.experiment.experiment_id], + filter_string=filter_string, + order_by=order_by or ["start_time DESC"], + max_results=max_results, + ) + + return runs + + def get_recent_runs( + self, + n: int = 10, + hours: Optional[int] = None, + ) -> List[Run]: + """ + Get the most recent runs. + + Args: + n: Number of runs to return + hours: Only include runs from the last N hours + + Returns: + List of Run objects + """ + filter_string = "" + if hours: + cutoff = datetime.now() - timedelta(hours=hours) + cutoff_ms = int(cutoff.timestamp() * 1000) + filter_string = f"attributes.start_time >= {cutoff_ms}" + + return self.search_runs( + filter_string=filter_string, + order_by=["start_time DESC"], + max_results=n, + ) + + def compare_runs( + self, + run_ids: Optional[List[str]] = None, + n_recent: int = 5, + ) -> RunComparison: + """ + Compare multiple runs side by side. 
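+
+        Example:
+            comparison = analyzer.compare_runs(n_recent=5)
+            print(comparison.summary_table())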
+ + Args: + run_ids: Specific run IDs to compare, or None for recent runs + n_recent: If run_ids is None, compare this many recent runs + + Returns: + RunComparison object with detailed comparison + """ + if run_ids: + runs = [self.client.get_run(rid) for rid in run_ids] + else: + runs = self.get_recent_runs(n=n_recent) + + comparison = RunComparison( + run_ids=[r.info.run_id for r in runs], + experiment_name=self.experiment_name, + ) + + # Collect all metrics and find best performers + all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict) + + for run in runs: + run_id = run.info.run_id + + # Metadata + comparison.run_names[run_id] = run.info.run_name or run_id[:8] + comparison.start_times[run_id] = datetime.fromtimestamp( + run.info.start_time / 1000 + ) + if run.info.end_time: + comparison.durations[run_id] = ( + run.info.end_time - run.info.start_time + ) / 1000 + + # Metrics + for key, value in run.data.metrics.items(): + all_metrics[key][run_id] = value + + # Params + for key, value in run.data.params.items(): + if key not in comparison.params: + comparison.params[key] = {} + comparison.params[key][run_id] = value + + comparison.metrics = dict(all_metrics) + + # Find best performers for each metric + for metric_name, values in all_metrics.items(): + if not values: + continue + + # Determine if lower is better based on metric name + minimize = any( + term in metric_name.lower() + for term in ["latency", "error", "loss", "time"] + ) + + if minimize: + best_id = min(values.keys(), key=lambda k: values[k]) + else: + best_id = max(values.keys(), key=lambda k: values[k]) + + comparison.best_by_metric[metric_name] = best_id + + return comparison + + def get_best_run( + self, + metric: str, + minimize: bool = True, + filter_string: str = "", + max_results: int = 100, + ) -> Optional[Run]: + """ + Get the best run by a specific metric. + + Args: + metric: Metric name to optimize + minimize: If True, find minimum; if False, find maximum + filter_string: Additional filter criteria + max_results: Maximum runs to consider + + Returns: + Best Run object, or None if no runs found + """ + direction = "ASC" if minimize else "DESC" + + runs = self.search_runs( + filter_string=filter_string, + order_by=[f"metrics.{metric} {direction}"], + max_results=max_results, + ) + + # Filter to only runs that have the metric + runs_with_metric = [ + r for r in runs + if metric in r.data.metrics + ] + + return runs_with_metric[0] if runs_with_metric else None + + def get_metrics_summary( + self, + hours: Optional[int] = None, + metrics: Optional[List[str]] = None, + ) -> Dict[str, Dict[str, float]]: + """ + Get summary statistics for metrics. 
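+
+        Example:
+            # e.g. {"total_latency": {"mean": 1.2, "min": 0.4, "max": 3.1, "count": 42, ...}}
+            summary = analyzer.get_metrics_summary(hours=24)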
+ + Args: + hours: Only include runs from the last N hours + metrics: Specific metrics to summarize (None for all) + + Returns: + Dict mapping metric names to {mean, min, max, count} + """ + import statistics + + runs = self.get_recent_runs(n=1000, hours=hours) + + # Collect all metric values + metric_values: Dict[str, List[float]] = defaultdict(list) + + for run in runs: + for key, value in run.data.metrics.items(): + if metrics is None or key in metrics: + metric_values[key].append(value) + + # Calculate statistics + summary = {} + for metric_name, values in metric_values.items(): + if not values: + continue + + summary[metric_name] = { + "mean": statistics.mean(values), + "min": min(values), + "max": max(values), + "count": len(values), + } + + if len(values) >= 2: + summary[metric_name]["stdev"] = statistics.stdev(values) + summary[metric_name]["median"] = statistics.median(values) + + return summary + + def get_metric_trends( + self, + metric: str, + days: int = 7, + granularity_hours: int = 1, + ) -> List[Dict[str, Any]]: + """ + Get metric trends over time. + + Args: + metric: Metric name to track + days: Number of days to look back + granularity_hours: Time bucket size in hours + + Returns: + List of {timestamp, mean, min, max, count} dicts + """ + import statistics + + runs = self.get_recent_runs(n=10000, hours=days * 24) + + # Group runs by time bucket + buckets: Dict[int, List[float]] = defaultdict(list) + bucket_size_ms = granularity_hours * 3600 * 1000 + + for run in runs: + if metric not in run.data.metrics: + continue + + bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms + buckets[bucket].append(run.data.metrics[metric]) + + # Calculate statistics per bucket + trends = [] + for bucket_ts, values in sorted(buckets.items()): + trend = { + "timestamp": datetime.fromtimestamp(bucket_ts / 1000).isoformat(), + "count": len(values), + "mean": statistics.mean(values), + "min": min(values), + "max": max(values), + } + if len(values) >= 2: + trend["stdev"] = statistics.stdev(values) + trends.append(trend) + + return trends + + def get_runs_by_tag( + self, + tag_key: str, + tag_value: str, + max_results: int = 100, + ) -> List[Run]: + """ + Get runs with a specific tag. + + Args: + tag_key: Tag key to filter by + tag_value: Tag value to match + max_results: Maximum runs to return + + Returns: + List of matching Run objects + """ + return self.search_runs( + filter_string=f"tags.{tag_key} = '{tag_value}'", + max_results=max_results, + ) + + def get_model_runs( + self, + model_name: str, + max_results: int = 100, + ) -> List[Run]: + """ + Get runs for a specific model. + + Args: + model_name: Model name to filter by + max_results: Maximum runs to return + + Returns: + List of matching Run objects + """ + # Try different tag conventions + runs = self.search_runs( + filter_string=f"tags.`model.name` = '{model_name}'", + max_results=max_results, + ) + + if not runs: + # Try params + runs = self.search_runs( + filter_string=f"params.model_name = '{model_name}'", + max_results=max_results, + ) + + return runs + + +def compare_experiments( + experiment_names: List[str], + metric: str, + tracking_uri: Optional[str] = None, +) -> Dict[str, Dict[str, float]]: + """ + Compare metrics across multiple experiments. 
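+
+    Example (experiment names are illustrative):
+        results = compare_experiments(
+            ["chat-inference", "voice-inference"],
+            metric="total_latency_mean",
+        )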
+ + Args: + experiment_names: Names of experiments to compare + metric: Metric to compare + tracking_uri: Override default tracking URI + + Returns: + Dict mapping experiment names to metric statistics + """ + results = {} + + for exp_name in experiment_names: + analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri) + summary = analyzer.get_metrics_summary(metrics=[metric]) + if metric in summary: + results[exp_name] = summary[metric] + + return results + + +def promotion_recommendation( + model_name: str, + experiment_name: str, + criteria: Dict[str, Tuple[str, float]], + tracking_uri: Optional[str] = None, +) -> PromotionRecommendation: + """ + Generate a recommendation for model promotion. + + Args: + model_name: Name of the model to evaluate + experiment_name: Experiment containing evaluation runs + criteria: Dict of {metric: (comparison, threshold)} + comparison is one of: ">=", "<=", ">", "<" + e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)} + tracking_uri: Override default tracking URI + + Returns: + PromotionRecommendation with decision and reasons + """ + analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri) + + # Get model runs + runs = analyzer.get_model_runs(model_name, max_results=10) + + if not runs: + return PromotionRecommendation( + model_name=model_name, + version=None, + recommended=False, + reasons=["No runs found for this model"], + metrics_summary={}, + ) + + # Get the most recent run + latest_run = runs[0] + metrics = latest_run.data.metrics + + # Evaluate criteria + reasons = [] + passed = True + + comparisons = { + ">=": lambda a, b: a >= b, + "<=": lambda a, b: a <= b, + ">": lambda a, b: a > b, + "<": lambda a, b: a < b, + } + + for metric_name, (comparison, threshold) in criteria.items(): + if metric_name not in metrics: + reasons.append(f"Metric '{metric_name}' not found") + passed = False + continue + + value = metrics[metric_name] + compare_fn = comparisons.get(comparison) + + if compare_fn is None: + reasons.append(f"Invalid comparison operator: {comparison}") + continue + + if compare_fn(value, threshold): + reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}") + else: + reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}") + passed = False + + # Extract version from tags if available + version = None + if "mlflow.version" in latest_run.data.tags: + try: + version = int(latest_run.data.tags["mlflow.version"]) + except ValueError: + pass + + return PromotionRecommendation( + model_name=model_name, + version=version, + recommended=passed, + reasons=reasons, + metrics_summary=dict(metrics), + ) + + +def get_inference_performance_report( + service_name: str = "chat-handler", + hours: int = 24, + tracking_uri: Optional[str] = None, +) -> Dict[str, Any]: + """ + Generate an inference performance report for a service. 
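+
+    Example:
+        report = get_inference_performance_report("chat-handler", hours=24)
+        print(json.dumps(report, indent=2))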
+
+    Args:
+        service_name: Service name (chat-handler, voice-assistant)
+        hours: Hours of data to analyze
+        tracking_uri: Override default tracking URI
+
+    Returns:
+        Performance report dictionary
+    """
+    # Mirror InferenceMetricsTracker's default experiment naming:
+    # "<service_name>-inference"
+    experiment_name = f"{service_name}-inference"
+    analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
+
+    # Get summary metrics
+    summary = analyzer.get_metrics_summary(hours=hours)
+
+    # Key latency metrics
+    latency_metrics = [
+        "total_latency_mean",
+        "total_latency_p50",
+        "total_latency_p95",
+        "llm_latency_mean",
+        "embedding_latency_mean",
+        "rag_search_latency_mean",
+    ]
+
+    report = {
+        "service": service_name,
+        "period_hours": hours,
+        "generated_at": datetime.now().isoformat(),
+        "latency": {},
+        "throughput": {},
+        "rag": {},
+        "errors": {},
+    }
+
+    # Latency section
+    for metric in latency_metrics:
+        if metric in summary:
+            report["latency"][metric] = summary[metric]
+
+    # Throughput
+    if "total_requests" in summary:
+        report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
+
+    # RAG usage
+    rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
+    for metric in rag_metrics:
+        if metric in summary:
+            report["rag"][metric] = summary[metric]
+
+    # Error rate
+    if "error_rate" in summary:
+        report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
+
+    return report
diff --git a/mlflow_utils/inference_tracker.py b/mlflow_utils/inference_tracker.py
new file mode 100644
index 0000000..820db8a
--- /dev/null
+++ b/mlflow_utils/inference_tracker.py
@@ -0,0 +1,431 @@
+"""
+Inference Metrics Tracker for NATS Handlers
+
+Provides async-compatible MLflow logging for real-time inference
+metrics from chat-handler and voice-assistant services.
+
+Designed to integrate with the existing OpenTelemetry setup and
+complement OTel metrics with MLflow experiment tracking for
+longer-term analysis and model comparison.
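+
+Metrics are buffered in memory and flushed to MLflow in batches on a
+background thread pool, so logging from an async NATS handler does not
+block the event loop.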
+""" + +import os +import time +import asyncio +import logging +from typing import Optional, Dict, Any, List +from dataclasses import dataclass, field +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from functools import partial + +import mlflow +from mlflow.tracking import MlflowClient + +from .client import get_mlflow_client, ensure_experiment, MLflowConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class InferenceMetrics: + """Metrics collected during an inference request.""" + request_id: str + user_id: Optional[str] = None + session_id: Optional[str] = None + + # Timing metrics (in seconds) + total_latency: float = 0.0 + embedding_latency: float = 0.0 + rag_search_latency: float = 0.0 + rerank_latency: float = 0.0 + llm_latency: float = 0.0 + tts_latency: float = 0.0 + stt_latency: float = 0.0 + + # Token/size metrics + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 + prompt_length: int = 0 + response_length: int = 0 + + # RAG metrics + rag_enabled: bool = False + rag_documents_retrieved: int = 0 + rag_documents_used: int = 0 + reranker_enabled: bool = False + + # Quality indicators + is_streaming: bool = False + is_premium: bool = False + has_error: bool = False + error_message: Optional[str] = None + + # Model information + model_name: Optional[str] = None + model_endpoint: Optional[str] = None + + # Timestamps + timestamp: float = field(default_factory=time.time) + + def as_metrics_dict(self) -> Dict[str, float]: + """Convert numeric fields to a metrics dictionary.""" + return { + "total_latency": self.total_latency, + "embedding_latency": self.embedding_latency, + "rag_search_latency": self.rag_search_latency, + "rerank_latency": self.rerank_latency, + "llm_latency": self.llm_latency, + "tts_latency": self.tts_latency, + "stt_latency": self.stt_latency, + "input_tokens": float(self.input_tokens), + "output_tokens": float(self.output_tokens), + "total_tokens": float(self.total_tokens), + "prompt_length": float(self.prompt_length), + "response_length": float(self.response_length), + "rag_documents_retrieved": float(self.rag_documents_retrieved), + "rag_documents_used": float(self.rag_documents_used), + } + + def as_params_dict(self) -> Dict[str, str]: + """Convert configuration fields to a params dictionary.""" + params = { + "rag_enabled": str(self.rag_enabled), + "reranker_enabled": str(self.reranker_enabled), + "is_streaming": str(self.is_streaming), + "is_premium": str(self.is_premium), + } + if self.model_name: + params["model_name"] = self.model_name + if self.model_endpoint: + params["model_endpoint"] = self.model_endpoint + return params + + +class InferenceMetricsTracker: + """ + Async-compatible MLflow tracker for inference metrics. + + Uses batching and a background thread pool to avoid blocking + the async event loop during MLflow calls. + + Example usage in chat-handler: + + class ChatHandler: + def __init__(self): + self.mlflow_tracker = InferenceMetricsTracker( + service_name="chat-handler", + experiment_name="chat-inference" + ) + + async def setup(self): + await self.mlflow_tracker.start() + + async def process_request(self, msg): + metrics = InferenceMetrics(request_id=request_id) + + # Track timing + start = time.time() + # ... do embedding ... + metrics.embedding_latency = time.time() - start + + # ... more processing ... 
+ + # Log metrics (non-blocking) + await self.mlflow_tracker.log_inference(metrics) + + async def shutdown(self): + await self.mlflow_tracker.stop() + """ + + def __init__( + self, + service_name: str, + experiment_name: Optional[str] = None, + tracking_uri: Optional[str] = None, + batch_size: int = 50, + flush_interval_seconds: float = 60.0, + enable_batching: bool = True, + max_workers: int = 2, + ): + """ + Initialize the inference metrics tracker. + + Args: + service_name: Name of the service (e.g., "chat-handler") + experiment_name: MLflow experiment name (defaults to service_name) + tracking_uri: Override default tracking URI + batch_size: Number of metrics to batch before flushing + flush_interval_seconds: Maximum time between flushes + enable_batching: If False, log each request immediately + max_workers: Number of thread pool workers for MLflow calls + """ + self.service_name = service_name + self.experiment_name = experiment_name or f"{service_name}-inference" + self.tracking_uri = tracking_uri + self.batch_size = batch_size + self.flush_interval = flush_interval_seconds + self.enable_batching = enable_batching + + self.config = MLflowConfig() + self._batch: List[InferenceMetrics] = [] + self._batch_lock = asyncio.Lock() + self._flush_task: Optional[asyncio.Task] = None + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._running = False + self._client: Optional[MlflowClient] = None + self._experiment_id: Optional[str] = None + + # Aggregate metrics for periodic logging + self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list) + self._request_count = 0 + self._error_count = 0 + + async def start(self) -> None: + """Start the tracker and initialize MLflow connection.""" + if self._running: + return + + self._running = True + + # Initialize MLflow in thread pool to avoid blocking + loop = asyncio.get_event_loop() + await loop.run_in_executor( + self._executor, + self._init_mlflow + ) + + if self.enable_batching: + self._flush_task = asyncio.create_task(self._periodic_flush()) + + logger.info( + f"InferenceMetricsTracker started for {self.service_name} " + f"(experiment: {self.experiment_name})" + ) + + def _init_mlflow(self) -> None: + """Initialize MLflow client and experiment (runs in thread pool).""" + self._client = get_mlflow_client( + tracking_uri=self.tracking_uri, + configure_global=True + ) + self._experiment_id = ensure_experiment( + self.experiment_name, + tags={ + "service": self.service_name, + "type": "inference-metrics", + } + ) + + async def stop(self) -> None: + """Stop the tracker and flush remaining metrics.""" + if not self._running: + return + + self._running = False + + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + + # Final flush + await self._flush_batch() + + self._executor.shutdown(wait=True) + logger.info(f"InferenceMetricsTracker stopped for {self.service_name}") + + async def log_inference(self, metrics: InferenceMetrics) -> None: + """ + Log inference metrics (non-blocking). 
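+
+        Example:
+            metrics = InferenceMetrics(request_id=request_id, total_latency=1.2)
+            await tracker.log_inference(metrics)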
+ + Args: + metrics: InferenceMetrics object with request data + """ + if not self._running: + logger.warning("Tracker not running, skipping metrics") + return + + self._request_count += 1 + if metrics.has_error: + self._error_count += 1 + + # Update aggregates + for key, value in metrics.as_metrics_dict().items(): + if value > 0: + self._aggregate_metrics[key].append(value) + + if self.enable_batching: + async with self._batch_lock: + self._batch.append(metrics) + if len(self._batch) >= self.batch_size: + asyncio.create_task(self._flush_batch()) + else: + # Immediate logging in thread pool + loop = asyncio.get_event_loop() + await loop.run_in_executor( + self._executor, + partial(self._log_single_inference, metrics) + ) + + async def _periodic_flush(self) -> None: + """Periodically flush batched metrics.""" + while self._running: + await asyncio.sleep(self.flush_interval) + await self._flush_batch() + + async def _flush_batch(self) -> None: + """Flush the current batch of metrics to MLflow.""" + async with self._batch_lock: + if not self._batch: + return + + batch = self._batch + self._batch = [] + + # Log in thread pool + loop = asyncio.get_event_loop() + await loop.run_in_executor( + self._executor, + partial(self._log_batch, batch) + ) + + def _log_single_inference(self, metrics: InferenceMetrics) -> None: + """Log a single inference request to MLflow (runs in thread pool).""" + try: + with mlflow.start_run( + experiment_id=self._experiment_id, + run_name=f"inference-{metrics.request_id}", + tags={ + "service": self.service_name, + "request_id": metrics.request_id, + "type": "single-inference", + } + ): + mlflow.log_params(metrics.as_params_dict()) + mlflow.log_metrics(metrics.as_metrics_dict()) + + if metrics.user_id: + mlflow.set_tag("user_id", metrics.user_id) + if metrics.session_id: + mlflow.set_tag("session_id", metrics.session_id) + if metrics.has_error: + mlflow.set_tag("has_error", "true") + if metrics.error_message: + mlflow.set_tag("error_message", metrics.error_message[:250]) + except Exception as e: + logger.error(f"Failed to log inference metrics: {e}") + + def _log_batch(self, batch: List[InferenceMetrics]) -> None: + """Log a batch of inference metrics as aggregate statistics.""" + if not batch: + return + + try: + # Calculate aggregates + aggregates = self._calculate_aggregates(batch) + + run_name = f"batch-{self.service_name}-{int(time.time())}" + + with mlflow.start_run( + experiment_id=self._experiment_id, + run_name=run_name, + tags={ + "service": self.service_name, + "type": "batch-inference", + "batch_size": str(len(batch)), + } + ): + # Log aggregate metrics + mlflow.log_metrics(aggregates) + + # Log batch info + mlflow.log_param("batch_size", len(batch)) + mlflow.log_param("time_window_start", min(m.timestamp for m in batch)) + mlflow.log_param("time_window_end", max(m.timestamp for m in batch)) + + # Log configuration breakdown + rag_enabled_count = sum(1 for m in batch if m.rag_enabled) + streaming_count = sum(1 for m in batch if m.is_streaming) + premium_count = sum(1 for m in batch if m.is_premium) + error_count = sum(1 for m in batch if m.has_error) + + mlflow.log_metrics({ + "rag_enabled_pct": rag_enabled_count / len(batch) * 100, + "streaming_pct": streaming_count / len(batch) * 100, + "premium_pct": premium_count / len(batch) * 100, + "error_rate": error_count / len(batch) * 100, + }) + + # Log model distribution + model_counts: Dict[str, int] = defaultdict(int) + for m in batch: + if m.model_name: + model_counts[m.model_name] += 1 + + if model_counts: 
+ mlflow.log_dict( + {"models": dict(model_counts)}, + "model_distribution.json" + ) + + logger.info(f"Logged batch of {len(batch)} inference metrics") + + except Exception as e: + logger.error(f"Failed to log batch metrics: {e}") + + def _calculate_aggregates( + self, + batch: List[InferenceMetrics] + ) -> Dict[str, float]: + """Calculate aggregate statistics from a batch of metrics.""" + import statistics + + aggregates = {} + + # Collect all numeric metrics + metric_values: Dict[str, List[float]] = defaultdict(list) + for m in batch: + for key, value in m.as_metrics_dict().items(): + if value > 0: + metric_values[key].append(value) + + # Calculate statistics for each metric + for key, values in metric_values.items(): + if not values: + continue + + aggregates[f"{key}_mean"] = statistics.mean(values) + aggregates[f"{key}_min"] = min(values) + aggregates[f"{key}_max"] = max(values) + + if len(values) >= 2: + aggregates[f"{key}_p50"] = statistics.median(values) + aggregates[f"{key}_stdev"] = statistics.stdev(values) + + if len(values) >= 4: + sorted_vals = sorted(values) + p95_idx = int(len(sorted_vals) * 0.95) + p99_idx = int(len(sorted_vals) * 0.99) + aggregates[f"{key}_p95"] = sorted_vals[p95_idx] + aggregates[f"{key}_p99"] = sorted_vals[p99_idx] + + # Add counts + aggregates["total_requests"] = float(len(batch)) + + return aggregates + + def get_stats(self) -> Dict[str, Any]: + """Get current tracker statistics.""" + return { + "service_name": self.service_name, + "experiment_name": self.experiment_name, + "running": self._running, + "total_requests": self._request_count, + "error_count": self._error_count, + "pending_batch_size": len(self._batch), + "aggregate_metrics_count": len(self._aggregate_metrics), + } diff --git a/mlflow_utils/kfp_components.py b/mlflow_utils/kfp_components.py new file mode 100644 index 0000000..f3cdf18 --- /dev/null +++ b/mlflow_utils/kfp_components.py @@ -0,0 +1,513 @@ +""" +Kubeflow Pipeline Components with MLflow Tracking + +Provides reusable KFP components that integrate MLflow experiment +tracking into Kubeflow Pipelines. These components can be used +directly in pipelines or as wrappers around existing pipeline steps. + +Usage in a Kubeflow Pipeline: + + from mlflow_utils.kfp_components import ( + create_mlflow_run, + log_metrics_component, + log_model_artifact, + end_mlflow_run, + ) + + @dsl.pipeline(name="my-pipeline") + def my_pipeline(): + # Start MLflow run + run_info = create_mlflow_run( + experiment_name="my-experiment", + run_name="training-run-1" + ) + + # ... your pipeline steps ... + + # Log metrics + log_step = log_metrics_component( + run_id=run_info.outputs["run_id"], + metrics={"accuracy": 0.95, "loss": 0.05} + ) + + # End run + end_mlflow_run(run_id=run_info.outputs["run_id"]) +""" + +from kfp import dsl +from typing import Dict, Any, List, Optional, NamedTuple + + +# MLflow component image with all required dependencies +MLFLOW_IMAGE = "python:3.13-slim" +MLFLOW_PACKAGES = [ + "mlflow>=2.10.0", + "boto3", # For S3 artifact storage if needed + "psycopg2-binary", # For PostgreSQL backend +] + + +@dsl.component( + base_image=MLFLOW_IMAGE, + packages_to_install=MLFLOW_PACKAGES +) +def create_mlflow_run( + experiment_name: str, + run_name: str, + mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", + tags: Dict[str, str] = None, + params: Dict[str, str] = None, +) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]): + """ + Create a new MLflow run for the pipeline. 
+ + This should be called at the start of a pipeline to initialize + tracking. The returned run_id should be passed to subsequent + components for logging. + + Args: + experiment_name: Name of the MLflow experiment + run_name: Name for this specific run + mlflow_tracking_uri: MLflow tracking server URI + tags: Optional tags to add to the run + params: Optional parameters to log + + Returns: + NamedTuple with run_id, experiment_id, and artifact_uri + """ + import os + import mlflow + from mlflow.tracking import MlflowClient + from collections import namedtuple + + # Set tracking URI + mlflow.set_tracking_uri(mlflow_tracking_uri) + + client = MlflowClient() + + # Get or create experiment + experiment = client.get_experiment_by_name(experiment_name) + if experiment is None: + experiment_id = client.create_experiment( + name=experiment_name, + artifact_location=f"/mlflow/artifacts/{experiment_name}" + ) + else: + experiment_id = experiment.experiment_id + + # Create default tags + default_tags = { + "pipeline.type": "kubeflow", + "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"), + "kfp.pod_name": os.environ.get("HOSTNAME", "unknown"), + } + if tags: + default_tags.update(tags) + + # Start run + run = mlflow.start_run( + experiment_id=experiment_id, + run_name=run_name, + tags=default_tags, + ) + + # Log initial params + if params: + mlflow.log_params(params) + + run_id = run.info.run_id + artifact_uri = run.info.artifact_uri + + # End run (KFP components are isolated, we'll resume in other components) + mlflow.end_run() + + RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri']) + return RunInfo(run_id, experiment_id, artifact_uri) + + +@dsl.component( + base_image=MLFLOW_IMAGE, + packages_to_install=MLFLOW_PACKAGES +) +def log_params_component( + run_id: str, + params: Dict[str, str], + mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", +) -> str: + """ + Log parameters to an existing MLflow run. + + Args: + run_id: The MLflow run ID to log to + params: Dictionary of parameters to log + mlflow_tracking_uri: MLflow tracking server URI + + Returns: + The run_id for chaining + """ + import mlflow + from mlflow.tracking import MlflowClient + + mlflow.set_tracking_uri(mlflow_tracking_uri) + client = MlflowClient() + + for key, value in params.items(): + client.log_param(run_id, key, str(value)[:500]) + + return run_id + + +@dsl.component( + base_image=MLFLOW_IMAGE, + packages_to_install=MLFLOW_PACKAGES +) +def log_metrics_component( + run_id: str, + metrics: Dict[str, float], + step: int = 0, + mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", +) -> str: + """ + Log metrics to an existing MLflow run. + + Args: + run_id: The MLflow run ID to log to + metrics: Dictionary of metrics to log + step: Step number for time-series metrics + mlflow_tracking_uri: MLflow tracking server URI + + Returns: + The run_id for chaining + """ + import mlflow + from mlflow.tracking import MlflowClient + + mlflow.set_tracking_uri(mlflow_tracking_uri) + client = MlflowClient() + + for key, value in metrics.items(): + client.log_metric(run_id, key, float(value), step=step) + + return run_id + + +@dsl.component( + base_image=MLFLOW_IMAGE, + packages_to_install=MLFLOW_PACKAGES +) +def log_artifact_component( + run_id: str, + artifact_path: str, + artifact_name: str = "", + mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", +) -> str: + """ + Log an artifact file to an existing MLflow run. 
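+
+    Example (inside a pipeline definition; the path is illustrative):
+        log_artifact_component(
+            run_id=run_info.outputs["run_id"],
+            artifact_path="/tmp/eval_results.json",
+        )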
+
+    Args:
+        run_id: The MLflow run ID to log to
+        artifact_path: Path to the artifact file
+        artifact_name: Optional destination name in artifact store
+        mlflow_tracking_uri: MLflow tracking server URI
+
+    Returns:
+        The run_id for chaining
+    """
+    import mlflow
+    from mlflow.tracking import MlflowClient
+
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+    client = MlflowClient()
+
+    client.log_artifact(run_id, artifact_path, artifact_name or None)
+
+    return run_id
+
+
+@dsl.component(
+    base_image=MLFLOW_IMAGE,
+    packages_to_install=MLFLOW_PACKAGES
+)
+def log_dict_artifact(
+    run_id: str,
+    data: Dict[str, Any],
+    filename: str,
+    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
+) -> str:
+    """
+    Log a dictionary as a JSON artifact.
+
+    Args:
+        run_id: The MLflow run ID to log to
+        data: Dictionary to save as JSON
+        filename: Name for the JSON file
+        mlflow_tracking_uri: MLflow tracking server URI
+
+    Returns:
+        The run_id for chaining
+    """
+    import json
+    import tempfile
+    import mlflow
+    from mlflow.tracking import MlflowClient
+    from pathlib import Path
+
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+    client = MlflowClient()
+
+    # Ensure .json extension
+    if not filename.endswith('.json'):
+        filename += '.json'
+
+    # Write to temp file and log
+    with tempfile.TemporaryDirectory() as tmpdir:
+        filepath = Path(tmpdir) / filename
+        with open(filepath, 'w') as f:
+            json.dump(data, f, indent=2)
+        client.log_artifact(run_id, str(filepath))
+
+    return run_id
+
+
+@dsl.component(
+    base_image=MLFLOW_IMAGE,
+    packages_to_install=MLFLOW_PACKAGES
+)
+def end_mlflow_run(
+    run_id: str,
+    status: str = "FINISHED",
+    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
+) -> str:
+    """
+    End an MLflow run with the specified status.
+
+    Args:
+        run_id: The MLflow run ID to end
+        status: Run status (FINISHED, FAILED, KILLED)
+        mlflow_tracking_uri: MLflow tracking server URI
+
+    Returns:
+        The run_id
+    """
+    import mlflow
+    from mlflow.tracking import MlflowClient
+
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+    client = MlflowClient()
+
+    # MlflowClient.set_terminated expects the status as a string
+    # (e.g. "FINISHED"), not a RunStatus enum value.
+    valid_statuses = {"FINISHED", "FAILED", "KILLED"}
+    run_status = status.upper() if status.upper() in valid_statuses else "FINISHED"
+    client.set_terminated(run_id, status=run_status)
+
+    return run_id
+
+
+@dsl.component(
+    base_image=MLFLOW_IMAGE,
+    packages_to_install=MLFLOW_PACKAGES + ["httpx"]
+)
+def log_training_metrics(
+    run_id: str,
+    model_type: str,
+    training_config: Dict[str, Any],
+    final_metrics: Dict[str, float],
+    model_path: str = "",
+    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
+) -> str:
+    """
+    Log comprehensive training metrics for ML models.
+
+    Designed for use with QLoRA training, voice training, and other
+    ML training pipelines in the llm-workflows repository.
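+
+    Example (values are illustrative):
+        log_training_metrics(
+            run_id=run_info.outputs["run_id"],
+            model_type="llm",
+            training_config={"lora_rank": 16, "epochs": 3},
+            final_metrics={"train_loss": 0.82, "eval_loss": 0.91},
+        )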
+ + Args: + run_id: The MLflow run ID to log to + model_type: Type of model (llm, stt, tts, embeddings) + training_config: Training configuration dict + final_metrics: Final training metrics + model_path: Path to saved model (if applicable) + mlflow_tracking_uri: MLflow tracking server URI + + Returns: + The run_id for chaining + """ + import json + import tempfile + import mlflow + from mlflow.tracking import MlflowClient + from pathlib import Path + + mlflow.set_tracking_uri(mlflow_tracking_uri) + client = MlflowClient() + + # Log training config as params + flat_config = {} + for key, value in training_config.items(): + if isinstance(value, (dict, list)): + flat_config[f"config.{key}"] = json.dumps(value)[:500] + else: + flat_config[f"config.{key}"] = str(value)[:500] + + for key, value in flat_config.items(): + client.log_param(run_id, key, value) + + # Log model type tag + client.set_tag(run_id, "model.type", model_type) + + # Log metrics + for key, value in final_metrics.items(): + client.log_metric(run_id, key, float(value)) + + # Log full config as artifact + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "training_config.json" + with open(config_path, 'w') as f: + json.dump(training_config, f, indent=2) + client.log_artifact(run_id, str(config_path)) + + # Log model path if provided + if model_path: + client.log_param(run_id, "model.path", model_path) + client.set_tag(run_id, "model.saved", "true") + + return run_id + + +@dsl.component( + base_image=MLFLOW_IMAGE, + packages_to_install=MLFLOW_PACKAGES +) +def log_document_ingestion_metrics( + run_id: str, + source_url: str, + collection_name: str, + chunks_created: int, + documents_processed: int, + processing_time_seconds: float, + embeddings_model: str = "bge-small-en-v1.5", + chunk_size: int = 500, + chunk_overlap: int = 50, + mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", +) -> str: + """ + Log document ingestion pipeline metrics. + + Designed for use with the document_ingestion_pipeline. 
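+
+    Example (illustrative; the source URL and collection name are
+    placeholders):
+
+        log_document_ingestion_metrics(
+            run_id=run_id,
+            source_url="https://example.com/docs/guide.html",
+            collection_name="knowledge_base",
+            chunks_created=2500,
+            documents_processed=100,
+            processing_time_seconds=120.5,
+        )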
+ + Args: + run_id: The MLflow run ID to log to + source_url: URL of the source document + collection_name: Milvus collection name + chunks_created: Number of chunks created + documents_processed: Number of documents processed + processing_time_seconds: Total processing time + embeddings_model: Embeddings model used + chunk_size: Chunk size in tokens + chunk_overlap: Chunk overlap in tokens + mlflow_tracking_uri: MLflow tracking server URI + + Returns: + The run_id for chaining + """ + import mlflow + from mlflow.tracking import MlflowClient + + mlflow.set_tracking_uri(mlflow_tracking_uri) + client = MlflowClient() + + # Log params + params = { + "source_url": source_url[:500], + "collection_name": collection_name, + "embeddings_model": embeddings_model, + "chunk_size": str(chunk_size), + "chunk_overlap": str(chunk_overlap), + } + for key, value in params.items(): + client.log_param(run_id, key, value) + + # Log metrics + metrics = { + "chunks_created": chunks_created, + "documents_processed": documents_processed, + "processing_time_seconds": processing_time_seconds, + "chunks_per_second": chunks_created / processing_time_seconds if processing_time_seconds > 0 else 0, + } + for key, value in metrics.items(): + client.log_metric(run_id, key, float(value)) + + # Set pipeline type tag + client.set_tag(run_id, "pipeline.type", "document-ingestion") + client.set_tag(run_id, "milvus.collection", collection_name) + + return run_id + + +@dsl.component( + base_image=MLFLOW_IMAGE, + packages_to_install=MLFLOW_PACKAGES +) +def log_evaluation_results( + run_id: str, + model_name: str, + dataset_name: str, + metrics: Dict[str, float], + sample_results: List[Dict[str, Any]] = None, + mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80", +) -> str: + """ + Log model evaluation results. + + Designed for use with the evaluation_pipeline. + + Args: + run_id: The MLflow run ID to log to + model_name: Name of the evaluated model + dataset_name: Name of the evaluation dataset + metrics: Evaluation metrics (accuracy, etc.) 
+
+        sample_results: Optional sample predictions
+        mlflow_tracking_uri: MLflow tracking server URI
+
+    Returns:
+        The run_id for chaining
+    """
+    import json
+    import tempfile
+    import mlflow
+    from mlflow.tracking import MlflowClient
+    from pathlib import Path
+
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+    client = MlflowClient()
+
+    # Log params
+    client.log_param(run_id, "eval.model_name", model_name)
+    client.log_param(run_id, "eval.dataset", dataset_name)
+
+    # Log metrics
+    for key, value in metrics.items():
+        client.log_metric(run_id, f"eval.{key}", float(value))
+
+    # Log sample results as artifact
+    if sample_results:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            results_path = Path(tmpdir) / "evaluation_results.json"
+            with open(results_path, 'w') as f:
+                json.dump(sample_results, f, indent=2)
+            client.log_artifact(run_id, str(results_path))
+
+    # Set tags
+    client.set_tag(run_id, "pipeline.type", "evaluation")
+    client.set_tag(run_id, "model.name", model_name)
+
+    # Determine pass/fail: honor an explicit "pass" metric if provided,
+    # otherwise fall back to an accuracy >= 0.7 threshold; coerce to bool
+    # so the tag is always "True"/"False" rather than a raw metric value
+    passed = bool(metrics.get("pass", metrics.get("accuracy", 0.0) >= 0.7))
+    client.set_tag(run_id, "eval.passed", str(passed))
+
+    return run_id
diff --git a/mlflow_utils/model_registry.py b/mlflow_utils/model_registry.py
new file mode 100644
index 0000000..ccc7674
--- /dev/null
+++ b/mlflow_utils/model_registry.py
@@ -0,0 +1,545 @@
+"""
+MLflow Model Registry Integration for KServe
+
+Provides utilities for registering trained models in MLflow Model Registry
+with metadata needed for deployment to KServe InferenceServices.
+
+This module bridges the gap between Kubeflow training pipelines and
+KServe model serving by:
+1. Registering models with proper versioning
+2. Adding KServe-specific metadata (runtime, protocol, resources)
+3. Managing model stage transitions (Staging → Production)
+4. Generating KServe InferenceService manifests from registered models
+
+Usage:
+    from mlflow_utils.model_registry import (
+        KServeConfig,
+        register_model_for_kserve,
+        promote_model_to_production,
+        generate_kserve_manifest,
+    )
+
+    # Register a new model version
+    model_version = register_model_for_kserve(
+        model_name="whisper-finetuned",
+        model_uri="s3://models/whisper-v2",
+        model_type="stt",
+        kserve_config=KServeConfig(
+            runtime="kserve-huggingface",
+            container_image="ghcr.io/my-org/whisper:v2",
+        ),
+    )
+
+    # Generate KServe manifest for deployment
+    manifest = generate_kserve_manifest(
+        model_name="whisper-finetuned",
+        version=model_version.version,
+    )
+"""
+
+import os
+import json
+import yaml
+import logging
+from typing import Optional, Dict, Any, List
+from dataclasses import dataclass, field
+
+import mlflow
+from mlflow.tracking import MlflowClient
+from mlflow.entities.model_registry import ModelVersion
+
+from .client import get_mlflow_client, MLflowConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class KServeConfig:
+    """Configuration for KServe deployment."""
+
+    # Runtime/container configuration
+    runtime: str = "kserve-huggingface"  # kserve-huggingface, kserve-custom, etc.
+ container_image: Optional[str] = None + container_port: int = 8080 + + # Protocol configuration + protocol: str = "v2" # v1, v2, grpc + + # Resource requests/limits + cpu_request: str = "1" + cpu_limit: str = "4" + memory_request: str = "4Gi" + memory_limit: str = "16Gi" + gpu_count: int = 0 + gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm + + # Storage configuration + storage_uri: Optional[str] = None # s3://, pvc://, gs:// + + # Scaling configuration + min_replicas: int = 1 + max_replicas: int = 1 + scale_target: int = 10 # Target concurrent requests for scaling + + # Serving configuration + timeout_seconds: int = 300 + batch_size: int = 1 + + # Additional environment variables + env_vars: Dict[str, str] = field(default_factory=dict) + + def as_dict(self) -> Dict[str, Any]: + """Convert to dictionary for MLflow tags.""" + return { + "kserve.runtime": self.runtime, + "kserve.protocol": self.protocol, + "kserve.cpu_request": self.cpu_request, + "kserve.memory_request": self.memory_request, + "kserve.gpu_count": str(self.gpu_count), + "kserve.min_replicas": str(self.min_replicas), + "kserve.max_replicas": str(self.max_replicas), + "kserve.storage_uri": self.storage_uri or "", + "kserve.container_image": self.container_image or "", + } + + +# Pre-configured KServe configurations for common model types +KSERVE_PRESETS: Dict[str, KServeConfig] = { + "llm": KServeConfig( + runtime="kserve-huggingface", + cpu_request="2", + cpu_limit="8", + memory_request="16Gi", + memory_limit="64Gi", + gpu_count=1, + timeout_seconds=600, + ), + "stt": KServeConfig( + runtime="kserve-custom", + cpu_request="2", + cpu_limit="4", + memory_request="8Gi", + memory_limit="16Gi", + gpu_count=1, + timeout_seconds=120, + ), + "tts": KServeConfig( + runtime="kserve-custom", + cpu_request="2", + cpu_limit="4", + memory_request="8Gi", + memory_limit="16Gi", + gpu_count=1, + timeout_seconds=60, + ), + "embeddings": KServeConfig( + runtime="kserve-huggingface", + cpu_request="1", + cpu_limit="4", + memory_request="4Gi", + memory_limit="16Gi", + gpu_count=0, + timeout_seconds=30, + batch_size=32, + ), + "reranker": KServeConfig( + runtime="kserve-huggingface", + cpu_request="1", + cpu_limit="4", + memory_request="4Gi", + memory_limit="16Gi", + gpu_count=0, + timeout_seconds=30, + ), +} + + +def register_model_for_kserve( + model_name: str, + model_uri: str, + model_type: str, + run_id: Optional[str] = None, + description: Optional[str] = None, + kserve_config: Optional[KServeConfig] = None, + tags: Optional[Dict[str, str]] = None, + tracking_uri: Optional[str] = None, +) -> ModelVersion: + """ + Register a model in MLflow Model Registry with KServe metadata. 
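+
+    Example (a sketch mirroring the module docstring; the model name, URI,
+    and container image are placeholders):
+
+        version = register_model_for_kserve(
+            model_name="whisper-finetuned",
+            model_uri="runs:/abc123/model",
+            model_type="stt",
+            kserve_config=KServeConfig(
+                runtime="kserve-custom",
+                container_image="ghcr.io/my-org/whisper:v2",
+                gpu_count=1,
+            ),
+        )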
+ + Args: + model_name: Name for the registered model + model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://) + model_type: Type of model (llm, stt, tts, embeddings, reranker) + run_id: Optional MLflow run ID to associate with + description: Description of the model version + kserve_config: KServe deployment configuration + tags: Additional tags for the model version + tracking_uri: Override default tracking URI + + Returns: + The created ModelVersion object + """ + client = get_mlflow_client(tracking_uri=tracking_uri) + + # Get or use preset KServe config + if kserve_config is None: + kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig()) + + # Ensure registered model exists + try: + client.get_registered_model(model_name) + except mlflow.exceptions.MlflowException: + client.create_registered_model( + name=model_name, + description=f"{model_type.upper()} model for KServe deployment", + tags={ + "model.type": model_type, + "deployment.target": "kserve", + } + ) + logger.info(f"Created registered model: {model_name}") + + # Create model version + model_version = client.create_model_version( + name=model_name, + source=model_uri, + run_id=run_id, + description=description or f"Version from {model_uri}", + tags={ + **(tags or {}), + "model.type": model_type, + **kserve_config.as_dict(), + } + ) + + logger.info( + f"Registered model version {model_version.version} " + f"for {model_name} (type: {model_type})" + ) + + return model_version + + +def promote_model_to_stage( + model_name: str, + version: int, + stage: str = "Staging", + archive_existing: bool = True, + tracking_uri: Optional[str] = None, +) -> ModelVersion: + """ + Promote a model version to a new stage. + + Args: + model_name: Name of the registered model + version: Version number to promote + stage: Target stage (Staging, Production, Archived) + archive_existing: If True, archive existing versions in target stage + tracking_uri: Override default tracking URI + + Returns: + The updated ModelVersion + """ + client = get_mlflow_client(tracking_uri=tracking_uri) + + # Transition to new stage + model_version = client.transition_model_version_stage( + name=model_name, + version=str(version), + stage=stage, + archive_existing_versions=archive_existing, + ) + + logger.info(f"Promoted {model_name} v{version} to {stage}") + + return model_version + + +def promote_model_to_production( + model_name: str, + version: int, + tracking_uri: Optional[str] = None, +) -> ModelVersion: + """ + Promote a model version directly to Production. + + Args: + model_name: Name of the registered model + version: Version number to promote + tracking_uri: Override default tracking URI + + Returns: + The updated ModelVersion + """ + return promote_model_to_stage( + model_name=model_name, + version=version, + stage="Production", + archive_existing=True, + tracking_uri=tracking_uri, + ) + + +def get_production_model( + model_name: str, + tracking_uri: Optional[str] = None, +) -> Optional[ModelVersion]: + """ + Get the current Production model version. 
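+
+    Example (illustrative):
+
+        prod = get_production_model("whisper-finetuned")
+        if prod is not None:
+            print(prod.version, prod.source)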
+ + Args: + model_name: Name of the registered model + tracking_uri: Override default tracking URI + + Returns: + The Production ModelVersion, or None if none exists + """ + client = get_mlflow_client(tracking_uri=tracking_uri) + + versions = client.get_latest_versions(model_name, stages=["Production"]) + + return versions[0] if versions else None + + +def get_model_kserve_config( + model_name: str, + version: Optional[int] = None, + tracking_uri: Optional[str] = None, +) -> KServeConfig: + """ + Get KServe configuration from a registered model version. + + Args: + model_name: Name of the registered model + version: Version number (uses Production if not specified) + tracking_uri: Override default tracking URI + + Returns: + KServeConfig populated from model tags + """ + client = get_mlflow_client(tracking_uri=tracking_uri) + + if version: + model_version = client.get_model_version(model_name, str(version)) + else: + prod_version = get_production_model(model_name, tracking_uri) + if not prod_version: + raise ValueError(f"No Production version for {model_name}") + model_version = prod_version + + tags = model_version.tags + + return KServeConfig( + runtime=tags.get("kserve.runtime", "kserve-huggingface"), + protocol=tags.get("kserve.protocol", "v2"), + cpu_request=tags.get("kserve.cpu_request", "1"), + memory_request=tags.get("kserve.memory_request", "4Gi"), + gpu_count=int(tags.get("kserve.gpu_count", "0")), + min_replicas=int(tags.get("kserve.min_replicas", "1")), + max_replicas=int(tags.get("kserve.max_replicas", "1")), + storage_uri=tags.get("kserve.storage_uri") or None, + container_image=tags.get("kserve.container_image") or None, + ) + + +def generate_kserve_manifest( + model_name: str, + version: Optional[int] = None, + namespace: str = "ai-ml", + service_name: Optional[str] = None, + extra_annotations: Optional[Dict[str, str]] = None, + tracking_uri: Optional[str] = None, +) -> Dict[str, Any]: + """ + Generate a KServe InferenceService manifest from a registered model. 
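+
+    Example (illustrative; assumes a Production version of the model
+    already exists in the registry):
+
+        manifest = generate_kserve_manifest(
+            model_name="whisper-finetuned",
+            namespace="ai-ml",
+        )
+        # manifest["spec"]["predictor"] carries the resolved runtime,
+        # resource, and scaling configuration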
+
+    Args:
+        model_name: Name of the registered model
+        version: Version number (uses Production if not specified)
+        namespace: Kubernetes namespace for deployment
+        service_name: Name for the InferenceService (defaults to model_name)
+        extra_annotations: Additional annotations for the service
+        tracking_uri: Override default tracking URI
+
+    Returns:
+        KServe InferenceService manifest as a dictionary
+    """
+    client = get_mlflow_client(tracking_uri=tracking_uri)
+
+    # Get model version
+    if version:
+        model_version = client.get_model_version(model_name, str(version))
+    else:
+        prod_version = get_production_model(model_name, tracking_uri)
+        if not prod_version:
+            raise ValueError(f"No Production version for {model_name}")
+        model_version = prod_version
+        version = int(model_version.version)
+
+    # Get KServe config
+    config = get_model_kserve_config(model_name, version, tracking_uri)
+    model_type = model_version.tags.get("model.type", "custom")
+
+    svc_name = service_name or model_name.lower().replace("_", "-")
+
+    # Build manifest
+    manifest = {
+        "apiVersion": "serving.kserve.io/v1beta1",
+        "kind": "InferenceService",
+        "metadata": {
+            "name": svc_name,
+            "namespace": namespace,
+            "labels": {
+                "mlflow.model": model_name,
+                "mlflow.version": str(version),
+                "model.type": model_type,
+            },
+            "annotations": {
+                # Reuse the client created above so any tracking_uri
+                # override is reflected in the annotation
+                "mlflow.tracking_uri": client.tracking_uri,
+                "mlflow.run_id": model_version.run_id or "",
+                **(extra_annotations or {}),
+            },
+        },
+        "spec": {
+            "predictor": {
+                "minReplicas": config.min_replicas,
+                "maxReplicas": config.max_replicas,
+                "scaleTarget": config.scale_target,
+                "timeout": config.timeout_seconds,
+            },
+        },
+    }
+
+    # Configure predictor based on runtime
+    predictor = manifest["spec"]["predictor"]
+
+    if config.container_image:
+        # Custom container
+        predictor["containers"] = [{
+            "name": "predictor",
+            "image": config.container_image,
+            "ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
+            "resources": {
+                "requests": {
+                    "cpu": config.cpu_request,
+                    "memory": config.memory_request,
+                },
+                "limits": {
+                    "cpu": config.cpu_limit,
+                    "memory": config.memory_limit,
+                },
+            },
+            "env": [
+                {"name": k, "value": v}
+                for k, v in config.env_vars.items()
+            ],
+        }]
+
+        # Add GPU if needed
+        if config.gpu_count > 0:
+            predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
+            predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
+
+    else:
+        # Standard KServe runtime
+        storage_uri = config.storage_uri or model_version.source
+
+        predictor["model"] = {
+            "modelFormat": {"name": "huggingface"},
+            "protocolVersion": config.protocol,
+            "storageUri": storage_uri,
+            "resources": {
+                "requests": {
+                    "cpu": config.cpu_request,
+                    "memory": config.memory_request,
+                },
+                "limits": {
+                    "cpu": config.cpu_limit,
+                    "memory": config.memory_limit,
+                },
+            },
+        }
+
+        if config.gpu_count > 0:
+            predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
+            predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
+
+    return manifest
+
+
+def generate_kserve_yaml(
+    model_name: str,
+    version: Optional[int] = None,
+    namespace: str = "ai-ml",
+    output_path: Optional[str] = None,
+    tracking_uri: Optional[str] = None,
+) -> str:
+    """
+    Generate a KServe InferenceService manifest as YAML.
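+
+    Example (illustrative; the output path is a placeholder):
+
+        yaml_str = generate_kserve_yaml(
+            "whisper-finetuned",
+            output_path="/tmp/inferenceservice.yaml",
+        )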
+ + Args: + model_name: Name of the registered model + version: Version number (uses Production if not specified) + namespace: Kubernetes namespace + output_path: If provided, write YAML to this path + tracking_uri: Override default tracking URI + + Returns: + YAML string of the manifest + """ + manifest = generate_kserve_manifest( + model_name=model_name, + version=version, + namespace=namespace, + tracking_uri=tracking_uri, + ) + + yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False) + + if output_path: + with open(output_path, 'w') as f: + f.write(yaml_str) + logger.info(f"Wrote KServe manifest to {output_path}") + + return yaml_str + + +def list_model_versions( + model_name: str, + stages: Optional[List[str]] = None, + tracking_uri: Optional[str] = None, +) -> List[Dict[str, Any]]: + """ + List all versions of a registered model. + + Args: + model_name: Name of the registered model + stages: Filter by stages (None for all) + tracking_uri: Override default tracking URI + + Returns: + List of model version info dictionaries + """ + client = get_mlflow_client(tracking_uri=tracking_uri) + + if stages: + versions = client.get_latest_versions(model_name, stages=stages) + else: + # Get all versions + versions = [] + for mv in client.search_model_versions(f"name='{model_name}'"): + versions.append(mv) + + return [ + { + "version": mv.version, + "stage": mv.current_stage, + "source": mv.source, + "run_id": mv.run_id, + "description": mv.description, + "tags": mv.tags, + "creation_timestamp": mv.creation_timestamp, + "last_updated_timestamp": mv.last_updated_timestamp, + } + for mv in versions + ] diff --git a/mlflow_utils/tracker.py b/mlflow_utils/tracker.py new file mode 100644 index 0000000..a92ae6b --- /dev/null +++ b/mlflow_utils/tracker.py @@ -0,0 +1,395 @@ +""" +MLflow Tracker for Kubeflow Pipelines + +Provides a high-level interface for logging experiments, parameters, +metrics, and artifacts from Kubeflow Pipeline components. +""" + +import os +import json +import time +import logging +from pathlib import Path +from typing import Optional, Dict, Any, List, Union +from contextlib import contextmanager +from dataclasses import dataclass, field + +import mlflow +from mlflow.tracking import MlflowClient + +from .client import get_mlflow_client, ensure_experiment, MLflowConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class PipelineMetadata: + """Metadata about the Kubeflow Pipeline run.""" + pipeline_name: str + run_id: str + run_name: Optional[str] = None + component_name: Optional[str] = None + namespace: str = "ai-ml" + + # KFP-specific metadata (populated from environment if available) + kfp_run_id: Optional[str] = field( + default_factory=lambda: os.environ.get("KFP_RUN_ID") + ) + kfp_pod_name: Optional[str] = field( + default_factory=lambda: os.environ.get("KFP_POD_NAME") + ) + + def as_tags(self) -> Dict[str, str]: + """Convert metadata to MLflow tags.""" + tags = { + "pipeline.name": self.pipeline_name, + "pipeline.run_id": self.run_id, + "pipeline.namespace": self.namespace, + } + if self.run_name: + tags["pipeline.run_name"] = self.run_name + if self.component_name: + tags["pipeline.component"] = self.component_name + if self.kfp_run_id: + tags["kfp.run_id"] = self.kfp_run_id + if self.kfp_pod_name: + tags["kfp.pod_name"] = self.kfp_pod_name + return tags + + +class MLflowTracker: + """ + MLflow experiment tracker for Kubeflow Pipeline components. 
+ + Example usage in a KFP component: + + from mlflow_utils import MLflowTracker + + tracker = MLflowTracker( + experiment_name="document-ingestion", + run_name="batch-ingestion-2024-01" + ) + + with tracker.start_run() as run: + tracker.log_params({ + "chunk_size": 500, + "overlap": 50, + "embeddings_model": "bge-small-en-v1.5" + }) + + # ... do work ... + + tracker.log_metrics({ + "documents_processed": 100, + "chunks_created": 2500, + "processing_time_seconds": 120.5 + }) + + tracker.log_artifact("/path/to/output.json") + """ + + def __init__( + self, + experiment_name: str, + run_name: Optional[str] = None, + pipeline_metadata: Optional[PipelineMetadata] = None, + tags: Optional[Dict[str, str]] = None, + tracking_uri: Optional[str] = None, + ): + """ + Initialize the MLflow tracker. + + Args: + experiment_name: Name of the MLflow experiment + run_name: Optional name for this run + pipeline_metadata: Metadata about the KFP pipeline + tags: Additional tags to add to the run + tracking_uri: Override default tracking URI + """ + self.config = MLflowConfig() + self.experiment_name = experiment_name + self.run_name = run_name or f"{experiment_name}-{int(time.time())}" + self.pipeline_metadata = pipeline_metadata + self.user_tags = tags or {} + self.tracking_uri = tracking_uri + + self.client: Optional[MlflowClient] = None + self.run: Optional[mlflow.ActiveRun] = None + self.run_id: Optional[str] = None + self._start_time: Optional[float] = None + + def _get_all_tags(self) -> Dict[str, str]: + """Combine all tags for the run.""" + tags = self.config.default_tags.copy() + + if self.pipeline_metadata: + tags.update(self.pipeline_metadata.as_tags()) + + tags.update(self.user_tags) + return tags + + @contextmanager + def start_run( + self, + nested: bool = False, + parent_run_id: Optional[str] = None, + ): + """ + Start an MLflow run as a context manager. + + Args: + nested: If True, create a nested run under the current active run + parent_run_id: Explicit parent run ID for nested runs + + Yields: + The MLflow run object + """ + self.client = get_mlflow_client( + tracking_uri=self.tracking_uri, + configure_global=True + ) + + # Ensure experiment exists + experiment_id = ensure_experiment(self.experiment_name) + + self._start_time = time.time() + + try: + # Start the run + self.run = mlflow.start_run( + experiment_id=experiment_id, + run_name=self.run_name, + nested=nested, + tags=self._get_all_tags(), + ) + self.run_id = self.run.info.run_id + + logger.info( + f"Started MLflow run '{self.run_name}' " + f"(ID: {self.run_id}) in experiment '{self.experiment_name}'" + ) + + yield self.run + + except Exception as e: + logger.error(f"MLflow run failed: {e}") + if self.run: + mlflow.set_tag("run.status", "failed") + mlflow.set_tag("run.error", str(e)) + raise + finally: + # Log duration + if self._start_time: + duration = time.time() - self._start_time + try: + mlflow.log_metric("run_duration_seconds", duration) + except Exception: + pass + + # End the run + mlflow.end_run() + logger.info(f"Ended MLflow run '{self.run_name}'") + + def log_params(self, params: Dict[str, Any]) -> None: + """ + Log parameters to the current run. + + Args: + params: Dictionary of parameter names to values + """ + if not self.run: + logger.warning("No active run, skipping log_params") + return + + # MLflow has limits on param values, truncate if needed + cleaned_params = {} + for key, value in params.items(): + str_value = str(value) + if len(str_value) > 500: + str_value = str_value[:497] + "..." 
+ cleaned_params[key] = str_value + + mlflow.log_params(cleaned_params) + logger.debug(f"Logged {len(params)} parameters") + + def log_param(self, key: str, value: Any) -> None: + """Log a single parameter.""" + self.log_params({key: value}) + + def log_metrics( + self, + metrics: Dict[str, Union[float, int]], + step: Optional[int] = None + ) -> None: + """ + Log metrics to the current run. + + Args: + metrics: Dictionary of metric names to values + step: Optional step number for time-series metrics + """ + if not self.run: + logger.warning("No active run, skipping log_metrics") + return + + mlflow.log_metrics(metrics, step=step) + logger.debug(f"Logged {len(metrics)} metrics") + + def log_metric( + self, + key: str, + value: Union[float, int], + step: Optional[int] = None + ) -> None: + """Log a single metric.""" + self.log_metrics({key: value}, step=step) + + def log_artifact( + self, + local_path: str, + artifact_path: Optional[str] = None + ) -> None: + """ + Log an artifact file to the current run. + + Args: + local_path: Path to the local file to log + artifact_path: Optional destination path within the artifact store + """ + if not self.run: + logger.warning("No active run, skipping log_artifact") + return + + mlflow.log_artifact(local_path, artifact_path) + logger.info(f"Logged artifact: {local_path}") + + def log_artifacts( + self, + local_dir: str, + artifact_path: Optional[str] = None + ) -> None: + """ + Log all files in a directory as artifacts. + + Args: + local_dir: Path to the local directory + artifact_path: Optional destination path within the artifact store + """ + if not self.run: + logger.warning("No active run, skipping log_artifacts") + return + + mlflow.log_artifacts(local_dir, artifact_path) + logger.info(f"Logged artifacts from: {local_dir}") + + def log_dict( + self, + data: Dict[str, Any], + filename: str, + artifact_path: Optional[str] = None + ) -> None: + """ + Log a dictionary as a JSON artifact. + + Args: + data: Dictionary to log + filename: Name for the JSON file + artifact_path: Optional destination path + """ + if not self.run: + logger.warning("No active run, skipping log_dict") + return + + # Ensure .json extension + if not filename.endswith(".json"): + filename += ".json" + + mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename) + logger.debug(f"Logged dict as: {filename}") + + def log_model_info( + self, + model_type: str, + model_name: str, + model_path: Optional[str] = None, + framework: str = "pytorch", + extra_info: Optional[Dict[str, Any]] = None + ) -> None: + """ + Log model information as parameters and tags. + + Args: + model_type: Type of model (e.g., "llm", "embedding", "stt") + model_name: Name/identifier of the model + model_path: Path to model weights + framework: ML framework used + extra_info: Additional model information + """ + params = { + "model.type": model_type, + "model.name": model_name, + "model.framework": framework, + } + if model_path: + params["model.path"] = model_path + if extra_info: + for key, value in extra_info.items(): + params[f"model.{key}"] = value + + self.log_params(params) + + # Also set as tags for easier filtering + mlflow.set_tag("model.type", model_type) + mlflow.set_tag("model.name", model_name) + + def log_dataset_info( + self, + name: str, + source: str, + size: Optional[int] = None, + extra_info: Optional[Dict[str, Any]] = None + ) -> None: + """ + Log dataset information. + + Args: + name: Dataset name + source: Dataset source (URL, path, etc.) 
+ size: Number of samples + extra_info: Additional dataset information + """ + params = { + "dataset.name": name, + "dataset.source": source, + } + if size is not None: + params["dataset.size"] = size + if extra_info: + for key, value in extra_info.items(): + params[f"dataset.{key}"] = value + + self.log_params(params) + + def set_tag(self, key: str, value: str) -> None: + """Set a single tag on the run.""" + if self.run: + mlflow.set_tag(key, value) + + def set_tags(self, tags: Dict[str, str]) -> None: + """Set multiple tags on the run.""" + if self.run: + mlflow.set_tags(tags) + + @property + def artifact_uri(self) -> Optional[str]: + """Get the artifact URI for the current run.""" + if self.run: + return self.run.info.artifact_uri + return None + + @property + def experiment_id(self) -> Optional[str]: + """Get the experiment ID for the current run.""" + if self.run: + return self.run.info.experiment_id + return None diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1828f85 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +# MLflow Utils Module Requirements +# Core MLflow +mlflow>=2.10.0 + +# Database backends +psycopg2-binary>=2.9.0 # PostgreSQL (CNPG) +boto3>=1.34.0 # S3-compatible artifact storage (optional) + +# For async tracking +aiohttp>=3.9.0 + +# YAML generation for KServe manifests +PyYAML>=6.0 + +# Already in chat-handler/voice-assistant requirements: +# httpx (for health checks) +# Used but typically installed with mlflow: +# numpy +# pandas