feat: Add MLflow integration utilities
- client: Connection management and helpers - tracker: General experiment tracking - inference_tracker: Async metrics for NATS handlers - model_registry: Model registration with KServe metadata - kfp_components: Kubeflow Pipeline components - experiment_comparison: Run comparison tools - cli: Command-line interface
This commit is contained in:
129
README.md
129
README.md
@@ -1,2 +1,129 @@
|
|||||||
# mlflow
|
# MLflow Utils
|
||||||
|
|
||||||
|
MLflow integration utilities for the DaviesTechLabs AI/ML platform.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from Gitea:
|
||||||
|
```bash
|
||||||
|
pip install git+https://git.daviestechlabs.io/daviestechlabs/mlflow.git
|
||||||
|
```
|
||||||
|
|
||||||
|
## Modules
|
||||||
|
|
||||||
|
| Module | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| `client.py` | MLflow client configuration and helpers |
|
||||||
|
| `tracker.py` | General MLflowTracker for experiments |
|
||||||
|
| `inference_tracker.py` | Async inference metrics for NATS handlers |
|
||||||
|
| `model_registry.py` | Model Registry with KServe metadata |
|
||||||
|
| `kfp_components.py` | Kubeflow Pipeline MLflow components |
|
||||||
|
| `experiment_comparison.py` | Compare experiments and runs |
|
||||||
|
| `cli.py` | Command-line interface |
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlflow_utils import get_mlflow_client, MLflowTracker
|
||||||
|
|
||||||
|
# Simple tracking
|
||||||
|
with MLflowTracker(experiment_name="my-experiment") as tracker:
|
||||||
|
tracker.log_params({"learning_rate": 0.001})
|
||||||
|
tracker.log_metrics({"accuracy": 0.95})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Inference Tracking
|
||||||
|
|
||||||
|
For NATS handlers (chat-handler, voice-assistant):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlflow_utils import InferenceMetricsTracker
|
||||||
|
from mlflow_utils.inference_tracker import InferenceMetrics
|
||||||
|
|
||||||
|
tracker = InferenceMetricsTracker(
|
||||||
|
experiment_name="voice-assistant-prod",
|
||||||
|
batch_size=100, # Batch metrics before logging
|
||||||
|
)
|
||||||
|
|
||||||
|
# During request handling
|
||||||
|
metrics = InferenceMetrics(
|
||||||
|
request_id="uuid",
|
||||||
|
total_latency=1.5,
|
||||||
|
llm_latency=0.8,
|
||||||
|
input_tokens=150,
|
||||||
|
output_tokens=200,
|
||||||
|
)
|
||||||
|
await tracker.log_inference(metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Registry
|
||||||
|
|
||||||
|
Register models with KServe deployment metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlflow_utils.model_registry import register_model_for_kserve
|
||||||
|
|
||||||
|
register_model_for_kserve(
|
||||||
|
model_name="my-qlora-adapter",
|
||||||
|
model_uri="runs:/abc123/model",
|
||||||
|
kserve_runtime="kserve-vllm",
|
||||||
|
gpu_type="amd-strixhalo",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Kubeflow Components
|
||||||
|
|
||||||
|
Use in KFP pipelines:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlflow_utils.kfp_components import (
|
||||||
|
log_experiment_component,
|
||||||
|
register_model_component,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List experiments
|
||||||
|
python -m mlflow_utils.cli list-experiments
|
||||||
|
|
||||||
|
# Compare runs
|
||||||
|
python -m mlflow_utils.cli compare-runs --experiment "qlora-training"
|
||||||
|
|
||||||
|
# Export metrics
|
||||||
|
python -m mlflow_utils.cli export --run-id abc123 --output metrics.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
| Environment Variable | Default | Description |
|
||||||
|
|---------------------|---------|-------------|
|
||||||
|
| `MLFLOW_TRACKING_URI` | `http://mlflow.mlflow.svc.cluster.local:80` | MLflow server |
|
||||||
|
| `MLFLOW_EXPERIMENT_NAME` | `default` | Default experiment |
|
||||||
|
| `MLFLOW_ENABLE_ASYNC` | `true` | Async logging for handlers |
|
||||||
|
|
||||||
|
## Module Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
mlflow_utils/
|
||||||
|
├── __init__.py # Public API
|
||||||
|
├── client.py # Connection management
|
||||||
|
├── tracker.py # General experiment tracker
|
||||||
|
├── inference_tracker.py # Async inference metrics
|
||||||
|
├── model_registry.py # Model registration + KServe
|
||||||
|
├── kfp_components.py # Kubeflow components
|
||||||
|
├── experiment_comparison.py # Run comparison tools
|
||||||
|
└── cli.py # Command-line interface
|
||||||
|
```
|
||||||
|
|
||||||
|
## Related
|
||||||
|
|
||||||
|
- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Uses inference tracker
|
||||||
|
- [kubeflow](https://git.daviestechlabs.io/daviestechlabs/kubeflow) - KFP components
|
||||||
|
- [argo](https://git.daviestechlabs.io/daviestechlabs/argo) - Training workflows
|
||||||
|
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
|
||||||
|
|||||||
40
mlflow_utils/__init__.py
Normal file
40
mlflow_utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
MLflow Integration Utilities for LLM Workflows
|
||||||
|
|
||||||
|
This module provides MLflow integration for:
|
||||||
|
- Kubeflow Pipelines experiment tracking
|
||||||
|
- Model Registry with KServe deployment metadata
|
||||||
|
- Inference metrics logging from NATS handlers
|
||||||
|
- Experiment comparison and analysis
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
Set MLFLOW_TRACKING_URI environment variable or use defaults:
|
||||||
|
- In-cluster: http://mlflow.mlflow.svc.cluster.local:80
|
||||||
|
- External: https://mlflow.lab.daviestechlabs.io
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from mlflow_utils import get_mlflow_client, MLflowTracker
|
||||||
|
from mlflow_utils.kfp_components import log_experiment_component
|
||||||
|
from mlflow_utils.model_registry import register_model_for_kserve
|
||||||
|
from mlflow_utils.inference_tracker import InferenceMetricsTracker
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .client import (
|
||||||
|
get_mlflow_client,
|
||||||
|
get_tracking_uri,
|
||||||
|
ensure_experiment,
|
||||||
|
MLflowConfig,
|
||||||
|
)
|
||||||
|
from .tracker import MLflowTracker
|
||||||
|
from .inference_tracker import InferenceMetricsTracker
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_mlflow_client",
|
||||||
|
"get_tracking_uri",
|
||||||
|
"ensure_experiment",
|
||||||
|
"MLflowConfig",
|
||||||
|
"MLflowTracker",
|
||||||
|
"InferenceMetricsTracker",
|
||||||
|
]
|
||||||
|
|
||||||
|
__version__ = "1.0.0"
|
||||||
371
mlflow_utils/cli.py
Normal file
371
mlflow_utils/cli.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MLflow Experiment CLI
|
||||||
|
|
||||||
|
Command-line interface for querying and comparing MLflow experiments.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Compare recent runs in an experiment
|
||||||
|
python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
|
||||||
|
|
||||||
|
# Get best run by metric
|
||||||
|
python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
|
||||||
|
|
||||||
|
# Generate performance report
|
||||||
|
python -m mlflow_utils.cli report --service chat-handler --hours 24
|
||||||
|
|
||||||
|
# Check model promotion criteria
|
||||||
|
python -m mlflow_utils.cli promote --model whisper-finetuned \\
|
||||||
|
--experiment voice-evaluation \\
|
||||||
|
--criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
|
||||||
|
|
||||||
|
# List experiments
|
||||||
|
python -m mlflow_utils.cli list-experiments
|
||||||
|
|
||||||
|
# Query runs
|
||||||
|
python -m mlflow_utils.cli query --experiment chat-inference \\
|
||||||
|
--filter "metrics.total_latency_mean < 1.0" --limit 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .client import get_mlflow_client, health_check
|
||||||
|
from .experiment_comparison import (
|
||||||
|
ExperimentAnalyzer,
|
||||||
|
compare_experiments,
|
||||||
|
promotion_recommendation,
|
||||||
|
get_inference_performance_report,
|
||||||
|
)
|
||||||
|
from .model_registry import (
|
||||||
|
list_model_versions,
|
||||||
|
get_production_model,
|
||||||
|
generate_kserve_yaml,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_health(args):
|
||||||
|
"""Check MLflow connectivity."""
|
||||||
|
result = health_check()
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
sys.exit(0 if result["connected"] else 1)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_list_experiments(args):
|
||||||
|
"""List all experiments."""
|
||||||
|
client = get_mlflow_client(tracking_uri=args.tracking_uri)
|
||||||
|
experiments = client.search_experiments()
|
||||||
|
|
||||||
|
print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
|
||||||
|
print("-" * 80)
|
||||||
|
for exp in experiments:
|
||||||
|
print(f"{exp.experiment_id:<10} {exp.name:<40} {exp.artifact_location}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_compare(args):
|
||||||
|
"""Compare recent runs in an experiment."""
|
||||||
|
analyzer = ExperimentAnalyzer(
|
||||||
|
args.experiment,
|
||||||
|
tracking_uri=args.tracking_uri
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.run_ids:
|
||||||
|
run_ids = args.run_ids.split(",")
|
||||||
|
comparison = analyzer.compare_runs(run_ids=run_ids)
|
||||||
|
else:
|
||||||
|
comparison = analyzer.compare_runs(n_recent=args.runs)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(comparison.to_dict(), indent=2, default=str))
|
||||||
|
else:
|
||||||
|
print(comparison.summary_table())
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_best(args):
|
||||||
|
"""Find the best run by a metric."""
|
||||||
|
analyzer = ExperimentAnalyzer(
|
||||||
|
args.experiment,
|
||||||
|
tracking_uri=args.tracking_uri
|
||||||
|
)
|
||||||
|
|
||||||
|
best_run = analyzer.get_best_run(
|
||||||
|
metric=args.metric,
|
||||||
|
minimize=args.minimize,
|
||||||
|
filter_string=args.filter or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not best_run:
|
||||||
|
print(f"No runs found with metric '{args.metric}'")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"run_id": best_run.info.run_id,
|
||||||
|
"run_name": best_run.info.run_name,
|
||||||
|
"metric_value": best_run.data.metrics.get(args.metric),
|
||||||
|
"all_metrics": dict(best_run.data.metrics),
|
||||||
|
"params": dict(best_run.data.params),
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
else:
|
||||||
|
print(f"Best Run: {best_run.info.run_name or best_run.info.run_id}")
|
||||||
|
print(f" {args.metric}: {result['metric_value']}")
|
||||||
|
print(f" Run ID: {best_run.info.run_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_summary(args):
|
||||||
|
"""Get metrics summary for an experiment."""
|
||||||
|
analyzer = ExperimentAnalyzer(
|
||||||
|
args.experiment,
|
||||||
|
tracking_uri=args.tracking_uri
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = analyzer.get_metrics_summary(
|
||||||
|
hours=args.hours,
|
||||||
|
metrics=args.metrics.split(",") if args.metrics else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(summary, indent=2))
|
||||||
|
else:
|
||||||
|
print(f"Metrics Summary for '{args.experiment}' (last {args.hours}h)")
|
||||||
|
print("=" * 60)
|
||||||
|
for metric, stats in sorted(summary.items()):
|
||||||
|
print(f"\n{metric}:")
|
||||||
|
for stat, value in stats.items():
|
||||||
|
print(f" {stat}: {value:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_report(args):
|
||||||
|
"""Generate an inference performance report."""
|
||||||
|
report = get_inference_performance_report(
|
||||||
|
service_name=args.service,
|
||||||
|
hours=args.hours,
|
||||||
|
tracking_uri=args.tracking_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(report, indent=2))
|
||||||
|
else:
|
||||||
|
print(f"Performance Report: {report['service']}")
|
||||||
|
print(f"Period: Last {report['period_hours']} hours")
|
||||||
|
print(f"Generated: {report['generated_at']}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if report["latency"]:
|
||||||
|
print("Latency Metrics:")
|
||||||
|
for metric, stats in report["latency"].items():
|
||||||
|
if "mean" in stats:
|
||||||
|
print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
|
||||||
|
|
||||||
|
if report["rag"]:
|
||||||
|
print("\nRAG Usage:")
|
||||||
|
for metric, stats in report["rag"].items():
|
||||||
|
print(f" {metric}: {stats.get('mean', 'N/A')}")
|
||||||
|
|
||||||
|
if report["errors"]:
|
||||||
|
print("\nError Rates:")
|
||||||
|
for metric, stats in report["errors"].items():
|
||||||
|
print(f" {metric}: {stats:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_promote(args):
|
||||||
|
"""Check model promotion criteria."""
|
||||||
|
# Parse criteria
|
||||||
|
criteria = {}
|
||||||
|
for criterion in args.criteria.split(","):
|
||||||
|
# Parse "metric>=value" or "metric<=value" etc.
|
||||||
|
for op in [">=", "<=", ">", "<"]:
|
||||||
|
if op in criterion:
|
||||||
|
metric, value = criterion.split(op)
|
||||||
|
criteria[metric.strip()] = (op, float(value.strip()))
|
||||||
|
break
|
||||||
|
|
||||||
|
rec = promotion_recommendation(
|
||||||
|
model_name=args.model,
|
||||||
|
experiment_name=args.experiment,
|
||||||
|
criteria=criteria,
|
||||||
|
tracking_uri=args.tracking_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(rec.to_dict(), indent=2))
|
||||||
|
else:
|
||||||
|
status = "✓ RECOMMENDED" if rec.recommended else "✗ NOT RECOMMENDED"
|
||||||
|
print(f"Model: {args.model}")
|
||||||
|
print(f"Status: {status}")
|
||||||
|
print("\nCriteria Evaluation:")
|
||||||
|
for reason in rec.reasons:
|
||||||
|
print(f" {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_query(args):
|
||||||
|
"""Query runs with a filter."""
|
||||||
|
analyzer = ExperimentAnalyzer(
|
||||||
|
args.experiment,
|
||||||
|
tracking_uri=args.tracking_uri
|
||||||
|
)
|
||||||
|
|
||||||
|
runs = analyzer.search_runs(
|
||||||
|
filter_string=args.filter or "",
|
||||||
|
max_results=args.limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
result = [
|
||||||
|
{
|
||||||
|
"run_id": r.info.run_id,
|
||||||
|
"run_name": r.info.run_name,
|
||||||
|
"status": r.info.status,
|
||||||
|
"metrics": dict(r.data.metrics),
|
||||||
|
"params": dict(r.data.params),
|
||||||
|
}
|
||||||
|
for r in runs
|
||||||
|
]
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
else:
|
||||||
|
print(f"Found {len(runs)} runs")
|
||||||
|
for run in runs:
|
||||||
|
print(f"\n{run.info.run_name or run.info.run_id}")
|
||||||
|
print(f" ID: {run.info.run_id}")
|
||||||
|
print(f" Status: {run.info.status}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_models(args):
|
||||||
|
"""List registered models."""
|
||||||
|
client = get_mlflow_client(tracking_uri=args.tracking_uri)
|
||||||
|
|
||||||
|
if args.model:
|
||||||
|
versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(versions, indent=2, default=str))
|
||||||
|
else:
|
||||||
|
print(f"Model: {args.model}")
|
||||||
|
for v in versions:
|
||||||
|
print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}")
|
||||||
|
else:
|
||||||
|
# List all models
|
||||||
|
models = client.search_registered_models()
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
result = [{"name": m.name, "description": m.description} for m in models]
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
else:
|
||||||
|
print(f"{'Model Name':<40} Description")
|
||||||
|
print("-" * 80)
|
||||||
|
for model in models:
|
||||||
|
desc = (model.description or "")[:35]
|
||||||
|
print(f"{model.name:<40} {desc}")
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_kserve(args):
|
||||||
|
"""Generate KServe manifest for a model."""
|
||||||
|
yaml_str = generate_kserve_yaml(
|
||||||
|
model_name=args.model,
|
||||||
|
version=args.version,
|
||||||
|
namespace=args.namespace,
|
||||||
|
output_path=args.output,
|
||||||
|
tracking_uri=args.tracking_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args.output:
|
||||||
|
print(yaml_str)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="MLflow Experiment CLI",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tracking-uri",
|
||||||
|
default=None,
|
||||||
|
help="MLflow tracking URI (default: from env or in-cluster)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--json",
|
||||||
|
action="store_true",
|
||||||
|
help="Output as JSON",
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="Commands")
|
||||||
|
|
||||||
|
# health
|
||||||
|
health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
|
||||||
|
health_parser.set_defaults(func=cmd_health)
|
||||||
|
|
||||||
|
# list-experiments
|
||||||
|
list_parser = subparsers.add_parser("list-experiments", help="List experiments")
|
||||||
|
list_parser.set_defaults(func=cmd_list_experiments)
|
||||||
|
|
||||||
|
# compare
|
||||||
|
compare_parser = subparsers.add_parser("compare", help="Compare runs")
|
||||||
|
compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||||
|
compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
|
||||||
|
compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
|
||||||
|
compare_parser.set_defaults(func=cmd_compare)
|
||||||
|
|
||||||
|
# best
|
||||||
|
best_parser = subparsers.add_parser("best", help="Find best run by metric")
|
||||||
|
best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||||
|
best_parser.add_argument("--metric", "-m", required=True, help="Metric to optimize")
|
||||||
|
best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
|
||||||
|
best_parser.add_argument("--filter", "-f", help="Filter string")
|
||||||
|
best_parser.set_defaults(func=cmd_best)
|
||||||
|
|
||||||
|
# summary
|
||||||
|
summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
|
||||||
|
summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||||
|
summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
|
||||||
|
summary_parser.add_argument("--metrics", help="Comma-separated metric names")
|
||||||
|
summary_parser.set_defaults(func=cmd_summary)
|
||||||
|
|
||||||
|
# report
|
||||||
|
report_parser = subparsers.add_parser("report", help="Generate performance report")
|
||||||
|
report_parser.add_argument("--service", "-s", required=True, help="Service name")
|
||||||
|
report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
|
||||||
|
report_parser.set_defaults(func=cmd_report)
|
||||||
|
|
||||||
|
# promote
|
||||||
|
promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
|
||||||
|
promote_parser.add_argument("--model", "-m", required=True, help="Model name")
|
||||||
|
promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||||
|
promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
|
||||||
|
promote_parser.set_defaults(func=cmd_promote)
|
||||||
|
|
||||||
|
# query
|
||||||
|
query_parser = subparsers.add_parser("query", help="Query runs")
|
||||||
|
query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
|
||||||
|
query_parser.add_argument("--filter", "-f", help="MLflow filter string")
|
||||||
|
query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
|
||||||
|
query_parser.set_defaults(func=cmd_query)
|
||||||
|
|
||||||
|
# models
|
||||||
|
models_parser = subparsers.add_parser("models", help="List registered models")
|
||||||
|
models_parser.add_argument("--model", "-m", help="Specific model name")
|
||||||
|
models_parser.set_defaults(func=cmd_models)
|
||||||
|
|
||||||
|
# kserve
|
||||||
|
kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
|
||||||
|
kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
|
||||||
|
kserve_parser.add_argument("--version", "-v", type=int, help="Model version")
|
||||||
|
kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
|
||||||
|
kserve_parser.add_argument("--output", "-o", help="Output file path")
|
||||||
|
kserve_parser.set_defaults(func=cmd_kserve)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.command:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
args.func(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
209
mlflow_utils/client.py
Normal file
209
mlflow_utils/client.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""
|
||||||
|
MLflow Client Configuration and Initialization
|
||||||
|
|
||||||
|
Provides a configured MLflow client for all integrations in the LLM workflows.
|
||||||
|
Supports both in-cluster and external access patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLflowConfig:
|
||||||
|
"""Configuration for MLflow integration."""
|
||||||
|
|
||||||
|
# Tracking server URIs
|
||||||
|
tracking_uri: str = field(
|
||||||
|
default_factory=lambda: os.environ.get(
|
||||||
|
"MLFLOW_TRACKING_URI",
|
||||||
|
"http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
external_uri: str = field(
|
||||||
|
default_factory=lambda: os.environ.get(
|
||||||
|
"MLFLOW_EXTERNAL_URI",
|
||||||
|
"https://mlflow.lab.daviestechlabs.io"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Artifact storage (NFS PVC mount)
|
||||||
|
artifact_location: str = field(
|
||||||
|
default_factory=lambda: os.environ.get(
|
||||||
|
"MLFLOW_ARTIFACT_LOCATION",
|
||||||
|
"/mlflow/artifacts"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default experiment settings
|
||||||
|
default_experiment: str = field(
|
||||||
|
default_factory=lambda: os.environ.get(
|
||||||
|
"MLFLOW_DEFAULT_EXPERIMENT",
|
||||||
|
"llm-workflows"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Service identification
|
||||||
|
service_name: str = field(
|
||||||
|
default_factory=lambda: os.environ.get(
|
||||||
|
"OTEL_SERVICE_NAME",
|
||||||
|
"unknown-service"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional tags to add to all runs
|
||||||
|
default_tags: Dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Add default tags based on environment."""
|
||||||
|
env_tags = {
|
||||||
|
"environment": os.environ.get("DEPLOYMENT_ENV", "production"),
|
||||||
|
"hostname": os.environ.get("HOSTNAME", "unknown"),
|
||||||
|
"namespace": os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml"),
|
||||||
|
}
|
||||||
|
self.default_tags = {**env_tags, **self.default_tags}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tracking_uri(external: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
Get the appropriate MLflow tracking URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
external: If True, return the external URI for outside-cluster access
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The MLflow tracking URI string
|
||||||
|
"""
|
||||||
|
config = MLflowConfig()
|
||||||
|
return config.external_uri if external else config.tracking_uri
|
||||||
|
|
||||||
|
|
||||||
|
def get_mlflow_client(
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
configure_global: bool = True
|
||||||
|
) -> MlflowClient:
|
||||||
|
"""
|
||||||
|
Get a configured MLflow client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tracking_uri: Override the default tracking URI
|
||||||
|
configure_global: If True, also set mlflow.set_tracking_uri()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured MlflowClient instance
|
||||||
|
"""
|
||||||
|
uri = tracking_uri or get_tracking_uri()
|
||||||
|
|
||||||
|
if configure_global:
|
||||||
|
mlflow.set_tracking_uri(uri)
|
||||||
|
logger.info(f"MLflow tracking URI set to: {uri}")
|
||||||
|
|
||||||
|
client = MlflowClient(tracking_uri=uri)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_experiment(
|
||||||
|
experiment_name: str,
|
||||||
|
artifact_location: Optional[str] = None,
|
||||||
|
tags: Optional[Dict[str, str]] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Ensure an experiment exists, creating it if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the experiment
|
||||||
|
artifact_location: Override default artifact location
|
||||||
|
tags: Additional tags for the experiment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The experiment ID
|
||||||
|
"""
|
||||||
|
config = MLflowConfig()
|
||||||
|
client = get_mlflow_client()
|
||||||
|
|
||||||
|
# Check if experiment exists
|
||||||
|
experiment = client.get_experiment_by_name(experiment_name)
|
||||||
|
|
||||||
|
if experiment is None:
|
||||||
|
# Create the experiment
|
||||||
|
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
|
||||||
|
experiment_id = client.create_experiment(
|
||||||
|
name=experiment_name,
|
||||||
|
artifact_location=artifact_loc,
|
||||||
|
tags=tags or {}
|
||||||
|
)
|
||||||
|
logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}")
|
||||||
|
else:
|
||||||
|
experiment_id = experiment.experiment_id
|
||||||
|
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
|
||||||
|
|
||||||
|
return experiment_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_registered_model(
|
||||||
|
model_name: str,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
tags: Optional[Dict[str, str]] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Get or create a registered model in the Model Registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model to register
|
||||||
|
description: Model description
|
||||||
|
tags: Tags for the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The registered model name
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if model exists
|
||||||
|
client.get_registered_model(model_name)
|
||||||
|
logger.debug(f"Using existing registered model: {model_name}")
|
||||||
|
except mlflow.exceptions.MlflowException:
|
||||||
|
# Create the model
|
||||||
|
client.create_registered_model(
|
||||||
|
name=model_name,
|
||||||
|
description=description or f"Model for {model_name}",
|
||||||
|
tags=tags or {}
|
||||||
|
)
|
||||||
|
logger.info(f"Created registered model: {model_name}")
|
||||||
|
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
|
def health_check() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Check MLflow server connectivity and return status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with health status information
|
||||||
|
"""
|
||||||
|
config = MLflowConfig()
|
||||||
|
result = {
|
||||||
|
"tracking_uri": config.tracking_uri,
|
||||||
|
"external_uri": config.external_uri,
|
||||||
|
"connected": False,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = get_mlflow_client(configure_global=False)
|
||||||
|
# Try to list experiments as a health check
|
||||||
|
experiments = client.search_experiments(max_results=1)
|
||||||
|
result["connected"] = True
|
||||||
|
result["experiment_count"] = len(experiments)
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = str(e)
|
||||||
|
logger.error(f"MLflow health check failed: {e}")
|
||||||
|
|
||||||
|
return result
|
||||||
664
mlflow_utils/experiment_comparison.py
Normal file
664
mlflow_utils/experiment_comparison.py
Normal file
@@ -0,0 +1,664 @@
|
|||||||
|
"""
|
||||||
|
Experiment Comparison and Analysis Utilities
|
||||||
|
|
||||||
|
Provides tools for comparing model versions, querying experiments,
|
||||||
|
and making data-driven decisions about model promotion to production.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Compare multiple runs/experiments side by side
|
||||||
|
- Query experiments by tags, metrics, or parameters
|
||||||
|
- Analyze inference metrics from NATS handlers
|
||||||
|
- Generate promotion recommendations
|
||||||
|
- Export comparison reports
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from mlflow_utils.experiment_comparison import (
|
||||||
|
ExperimentAnalyzer,
|
||||||
|
compare_runs,
|
||||||
|
get_best_run,
|
||||||
|
promotion_recommendation,
|
||||||
|
)
|
||||||
|
|
||||||
|
analyzer = ExperimentAnalyzer("chat-inference")
|
||||||
|
|
||||||
|
# Compare last N runs
|
||||||
|
comparison = analyzer.compare_recent_runs(n=5)
|
||||||
|
|
||||||
|
# Find best performing model
|
||||||
|
best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
|
||||||
|
|
||||||
|
# Get promotion recommendation
|
||||||
|
rec = analyzer.promotion_recommendation(
|
||||||
|
model_name="whisper-finetuned",
|
||||||
|
min_accuracy=0.9,
|
||||||
|
max_latency_p95=2.0
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from mlflow.entities import Run, Experiment
|
||||||
|
|
||||||
|
from .client import get_mlflow_client, MLflowConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RunComparison:
|
||||||
|
"""Comparison result for multiple MLflow runs."""
|
||||||
|
run_ids: List[str]
|
||||||
|
experiment_name: str
|
||||||
|
|
||||||
|
# Metric comparisons (metric_name -> {run_id -> value})
|
||||||
|
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Parameter differences
|
||||||
|
params: Dict[str, Dict[str, str]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Run metadata
|
||||||
|
run_names: Dict[str, str] = field(default_factory=dict)
|
||||||
|
start_times: Dict[str, datetime] = field(default_factory=dict)
|
||||||
|
durations: Dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Best performers by metric
|
||||||
|
best_by_metric: Dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"run_ids": self.run_ids,
|
||||||
|
"experiment_name": self.experiment_name,
|
||||||
|
"metrics": self.metrics,
|
||||||
|
"params": self.params,
|
||||||
|
"run_names": self.run_names,
|
||||||
|
"best_by_metric": self.best_by_metric,
|
||||||
|
}
|
||||||
|
|
||||||
|
def summary_table(self) -> str:
|
||||||
|
"""Generate a text summary table of the comparison."""
|
||||||
|
if not self.run_ids:
|
||||||
|
return "No runs to compare"
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
lines.append(f"Experiment: {self.experiment_name}")
|
||||||
|
lines.append(f"Comparing {len(self.run_ids)} runs")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Header
|
||||||
|
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
|
||||||
|
lines.append(" | ".join(header))
|
||||||
|
lines.append("-" * (len(lines[-1]) + 10))
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
for metric_name, values in sorted(self.metrics.items()):
|
||||||
|
row = [metric_name]
|
||||||
|
for run_id in self.run_ids:
|
||||||
|
value = values.get(run_id)
|
||||||
|
if value is not None:
|
||||||
|
row.append(f"{value:.4f}")
|
||||||
|
else:
|
||||||
|
row.append("N/A")
|
||||||
|
lines.append(" | ".join(row))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromotionRecommendation:
|
||||||
|
"""Recommendation for model promotion."""
|
||||||
|
model_name: str
|
||||||
|
version: Optional[int]
|
||||||
|
recommended: bool
|
||||||
|
reasons: List[str]
|
||||||
|
metrics_summary: Dict[str, float]
|
||||||
|
comparison_with_production: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"version": self.version,
|
||||||
|
"recommended": self.recommended,
|
||||||
|
"reasons": self.reasons,
|
||||||
|
"metrics_summary": self.metrics_summary,
|
||||||
|
"comparison_with_production": self.comparison_with_production,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ExperimentAnalyzer:
|
||||||
|
"""
|
||||||
|
Analyze MLflow experiments for model comparison and promotion decisions.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
analyzer = ExperimentAnalyzer("chat-inference")
|
||||||
|
|
||||||
|
# Get metrics summary for last 24 hours
|
||||||
|
summary = analyzer.get_metrics_summary(hours=24)
|
||||||
|
|
||||||
|
# Compare models by accuracy
|
||||||
|
best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
|
||||||
|
|
||||||
|
# Analyze inference latency trends
|
||||||
|
trends = analyzer.get_metric_trends("total_latency_mean", days=7)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experiment_name: str,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the experiment analyzer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the MLflow experiment to analyze
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
"""
|
||||||
|
self.experiment_name = experiment_name
|
||||||
|
self.tracking_uri = tracking_uri
|
||||||
|
self.client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
self._experiment: Optional[Experiment] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def experiment(self) -> Optional[Experiment]:
|
||||||
|
"""Get the experiment object, fetching if needed."""
|
||||||
|
if self._experiment is None:
|
||||||
|
self._experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||||
|
return self._experiment
|
||||||
|
|
||||||
|
def search_runs(
|
||||||
|
self,
|
||||||
|
filter_string: str = "",
|
||||||
|
order_by: Optional[List[str]] = None,
|
||||||
|
max_results: int = 100,
|
||||||
|
run_view_type: str = "ACTIVE_ONLY",
|
||||||
|
) -> List[Run]:
|
||||||
|
"""
|
||||||
|
Search for runs matching criteria.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
|
||||||
|
order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
|
||||||
|
max_results: Maximum runs to return
|
||||||
|
run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching Run objects
|
||||||
|
"""
|
||||||
|
if not self.experiment:
|
||||||
|
logger.warning(f"Experiment '{self.experiment_name}' not found")
|
||||||
|
return []
|
||||||
|
|
||||||
|
runs = self.client.search_runs(
|
||||||
|
experiment_ids=[self.experiment.experiment_id],
|
||||||
|
filter_string=filter_string,
|
||||||
|
order_by=order_by or ["start_time DESC"],
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
return runs
|
||||||
|
|
||||||
|
def get_recent_runs(
|
||||||
|
self,
|
||||||
|
n: int = 10,
|
||||||
|
hours: Optional[int] = None,
|
||||||
|
) -> List[Run]:
|
||||||
|
"""
|
||||||
|
Get the most recent runs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Number of runs to return
|
||||||
|
hours: Only include runs from the last N hours
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Run objects
|
||||||
|
"""
|
||||||
|
filter_string = ""
|
||||||
|
if hours:
|
||||||
|
cutoff = datetime.now() - timedelta(hours=hours)
|
||||||
|
cutoff_ms = int(cutoff.timestamp() * 1000)
|
||||||
|
filter_string = f"attributes.start_time >= {cutoff_ms}"
|
||||||
|
|
||||||
|
return self.search_runs(
|
||||||
|
filter_string=filter_string,
|
||||||
|
order_by=["start_time DESC"],
|
||||||
|
max_results=n,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compare_runs(
|
||||||
|
self,
|
||||||
|
run_ids: Optional[List[str]] = None,
|
||||||
|
n_recent: int = 5,
|
||||||
|
) -> RunComparison:
|
||||||
|
"""
|
||||||
|
Compare multiple runs side by side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_ids: Specific run IDs to compare, or None for recent runs
|
||||||
|
n_recent: If run_ids is None, compare this many recent runs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RunComparison object with detailed comparison
|
||||||
|
"""
|
||||||
|
if run_ids:
|
||||||
|
runs = [self.client.get_run(rid) for rid in run_ids]
|
||||||
|
else:
|
||||||
|
runs = self.get_recent_runs(n=n_recent)
|
||||||
|
|
||||||
|
comparison = RunComparison(
|
||||||
|
run_ids=[r.info.run_id for r in runs],
|
||||||
|
experiment_name=self.experiment_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect all metrics and find best performers
|
||||||
|
all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
|
||||||
|
|
||||||
|
for run in runs:
|
||||||
|
run_id = run.info.run_id
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
comparison.run_names[run_id] = run.info.run_name or run_id[:8]
|
||||||
|
comparison.start_times[run_id] = datetime.fromtimestamp(
|
||||||
|
run.info.start_time / 1000
|
||||||
|
)
|
||||||
|
if run.info.end_time:
|
||||||
|
comparison.durations[run_id] = (
|
||||||
|
run.info.end_time - run.info.start_time
|
||||||
|
) / 1000
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
for key, value in run.data.metrics.items():
|
||||||
|
all_metrics[key][run_id] = value
|
||||||
|
|
||||||
|
# Params
|
||||||
|
for key, value in run.data.params.items():
|
||||||
|
if key not in comparison.params:
|
||||||
|
comparison.params[key] = {}
|
||||||
|
comparison.params[key][run_id] = value
|
||||||
|
|
||||||
|
comparison.metrics = dict(all_metrics)
|
||||||
|
|
||||||
|
# Find best performers for each metric
|
||||||
|
for metric_name, values in all_metrics.items():
|
||||||
|
if not values:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine if lower is better based on metric name
|
||||||
|
minimize = any(
|
||||||
|
term in metric_name.lower()
|
||||||
|
for term in ["latency", "error", "loss", "time"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if minimize:
|
||||||
|
best_id = min(values.keys(), key=lambda k: values[k])
|
||||||
|
else:
|
||||||
|
best_id = max(values.keys(), key=lambda k: values[k])
|
||||||
|
|
||||||
|
comparison.best_by_metric[metric_name] = best_id
|
||||||
|
|
||||||
|
return comparison
|
||||||
|
|
||||||
|
def get_best_run(
|
||||||
|
self,
|
||||||
|
metric: str,
|
||||||
|
minimize: bool = True,
|
||||||
|
filter_string: str = "",
|
||||||
|
max_results: int = 100,
|
||||||
|
) -> Optional[Run]:
|
||||||
|
"""
|
||||||
|
Get the best run by a specific metric.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: Metric name to optimize
|
||||||
|
minimize: If True, find minimum; if False, find maximum
|
||||||
|
filter_string: Additional filter criteria
|
||||||
|
max_results: Maximum runs to consider
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Best Run object, or None if no runs found
|
||||||
|
"""
|
||||||
|
direction = "ASC" if minimize else "DESC"
|
||||||
|
|
||||||
|
runs = self.search_runs(
|
||||||
|
filter_string=filter_string,
|
||||||
|
order_by=[f"metrics.{metric} {direction}"],
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter to only runs that have the metric
|
||||||
|
runs_with_metric = [
|
||||||
|
r for r in runs
|
||||||
|
if metric in r.data.metrics
|
||||||
|
]
|
||||||
|
|
||||||
|
return runs_with_metric[0] if runs_with_metric else None
|
||||||
|
|
||||||
|
def get_metrics_summary(
|
||||||
|
self,
|
||||||
|
hours: Optional[int] = None,
|
||||||
|
metrics: Optional[List[str]] = None,
|
||||||
|
) -> Dict[str, Dict[str, float]]:
|
||||||
|
"""
|
||||||
|
Get summary statistics for metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hours: Only include runs from the last N hours
|
||||||
|
metrics: Specific metrics to summarize (None for all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping metric names to {mean, min, max, count}
|
||||||
|
"""
|
||||||
|
import statistics
|
||||||
|
|
||||||
|
runs = self.get_recent_runs(n=1000, hours=hours)
|
||||||
|
|
||||||
|
# Collect all metric values
|
||||||
|
metric_values: Dict[str, List[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
for run in runs:
|
||||||
|
for key, value in run.data.metrics.items():
|
||||||
|
if metrics is None or key in metrics:
|
||||||
|
metric_values[key].append(value)
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
summary = {}
|
||||||
|
for metric_name, values in metric_values.items():
|
||||||
|
if not values:
|
||||||
|
continue
|
||||||
|
|
||||||
|
summary[metric_name] = {
|
||||||
|
"mean": statistics.mean(values),
|
||||||
|
"min": min(values),
|
||||||
|
"max": max(values),
|
||||||
|
"count": len(values),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(values) >= 2:
|
||||||
|
summary[metric_name]["stdev"] = statistics.stdev(values)
|
||||||
|
summary[metric_name]["median"] = statistics.median(values)
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def get_metric_trends(
|
||||||
|
self,
|
||||||
|
metric: str,
|
||||||
|
days: int = 7,
|
||||||
|
granularity_hours: int = 1,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get metric trends over time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: Metric name to track
|
||||||
|
days: Number of days to look back
|
||||||
|
granularity_hours: Time bucket size in hours
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of {timestamp, mean, min, max, count} dicts
|
||||||
|
"""
|
||||||
|
import statistics
|
||||||
|
|
||||||
|
runs = self.get_recent_runs(n=10000, hours=days * 24)
|
||||||
|
|
||||||
|
# Group runs by time bucket
|
||||||
|
buckets: Dict[int, List[float]] = defaultdict(list)
|
||||||
|
bucket_size_ms = granularity_hours * 3600 * 1000
|
||||||
|
|
||||||
|
for run in runs:
|
||||||
|
if metric not in run.data.metrics:
|
||||||
|
continue
|
||||||
|
|
||||||
|
bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
|
||||||
|
buckets[bucket].append(run.data.metrics[metric])
|
||||||
|
|
||||||
|
# Calculate statistics per bucket
|
||||||
|
trends = []
|
||||||
|
for bucket_ts, values in sorted(buckets.items()):
|
||||||
|
trend = {
|
||||||
|
"timestamp": datetime.fromtimestamp(bucket_ts / 1000).isoformat(),
|
||||||
|
"count": len(values),
|
||||||
|
"mean": statistics.mean(values),
|
||||||
|
"min": min(values),
|
||||||
|
"max": max(values),
|
||||||
|
}
|
||||||
|
if len(values) >= 2:
|
||||||
|
trend["stdev"] = statistics.stdev(values)
|
||||||
|
trends.append(trend)
|
||||||
|
|
||||||
|
return trends
|
||||||
|
|
||||||
|
def get_runs_by_tag(
|
||||||
|
self,
|
||||||
|
tag_key: str,
|
||||||
|
tag_value: str,
|
||||||
|
max_results: int = 100,
|
||||||
|
) -> List[Run]:
|
||||||
|
"""
|
||||||
|
Get runs with a specific tag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag_key: Tag key to filter by
|
||||||
|
tag_value: Tag value to match
|
||||||
|
max_results: Maximum runs to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching Run objects
|
||||||
|
"""
|
||||||
|
return self.search_runs(
|
||||||
|
filter_string=f"tags.{tag_key} = '{tag_value}'",
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model_runs(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
max_results: int = 100,
|
||||||
|
) -> List[Run]:
|
||||||
|
"""
|
||||||
|
Get runs for a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model name to filter by
|
||||||
|
max_results: Maximum runs to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching Run objects
|
||||||
|
"""
|
||||||
|
# Try different tag conventions
|
||||||
|
runs = self.search_runs(
|
||||||
|
filter_string=f"tags.`model.name` = '{model_name}'",
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not runs:
|
||||||
|
# Try params
|
||||||
|
runs = self.search_runs(
|
||||||
|
filter_string=f"params.model_name = '{model_name}'",
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
return runs
|
||||||
|
|
||||||
|
|
||||||
|
def compare_experiments(
|
||||||
|
experiment_names: List[str],
|
||||||
|
metric: str,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> Dict[str, Dict[str, float]]:
|
||||||
|
"""
|
||||||
|
Compare metrics across multiple experiments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_names: Names of experiments to compare
|
||||||
|
metric: Metric to compare
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping experiment names to metric statistics
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for exp_name in experiment_names:
|
||||||
|
analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
|
||||||
|
summary = analyzer.get_metrics_summary(metrics=[metric])
|
||||||
|
if metric in summary:
|
||||||
|
results[exp_name] = summary[metric]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def promotion_recommendation(
|
||||||
|
model_name: str,
|
||||||
|
experiment_name: str,
|
||||||
|
criteria: Dict[str, Tuple[str, float]],
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> PromotionRecommendation:
|
||||||
|
"""
|
||||||
|
Generate a recommendation for model promotion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model to evaluate
|
||||||
|
experiment_name: Experiment containing evaluation runs
|
||||||
|
criteria: Dict of {metric: (comparison, threshold)}
|
||||||
|
comparison is one of: ">=", "<=", ">", "<"
|
||||||
|
e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PromotionRecommendation with decision and reasons
|
||||||
|
"""
|
||||||
|
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
# Get model runs
|
||||||
|
runs = analyzer.get_model_runs(model_name, max_results=10)
|
||||||
|
|
||||||
|
if not runs:
|
||||||
|
return PromotionRecommendation(
|
||||||
|
model_name=model_name,
|
||||||
|
version=None,
|
||||||
|
recommended=False,
|
||||||
|
reasons=["No runs found for this model"],
|
||||||
|
metrics_summary={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the most recent run
|
||||||
|
latest_run = runs[0]
|
||||||
|
metrics = latest_run.data.metrics
|
||||||
|
|
||||||
|
# Evaluate criteria
|
||||||
|
reasons = []
|
||||||
|
passed = True
|
||||||
|
|
||||||
|
comparisons = {
|
||||||
|
">=": lambda a, b: a >= b,
|
||||||
|
"<=": lambda a, b: a <= b,
|
||||||
|
">": lambda a, b: a > b,
|
||||||
|
"<": lambda a, b: a < b,
|
||||||
|
}
|
||||||
|
|
||||||
|
for metric_name, (comparison, threshold) in criteria.items():
|
||||||
|
if metric_name not in metrics:
|
||||||
|
reasons.append(f"Metric '{metric_name}' not found")
|
||||||
|
passed = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = metrics[metric_name]
|
||||||
|
compare_fn = comparisons.get(comparison)
|
||||||
|
|
||||||
|
if compare_fn is None:
|
||||||
|
reasons.append(f"Invalid comparison operator: {comparison}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if compare_fn(value, threshold):
|
||||||
|
reasons.append(f"✓ {metric_name}: {value:.4f} {comparison} {threshold}")
|
||||||
|
else:
|
||||||
|
reasons.append(f"✗ {metric_name}: {value:.4f} NOT {comparison} {threshold}")
|
||||||
|
passed = False
|
||||||
|
|
||||||
|
# Extract version from tags if available
|
||||||
|
version = None
|
||||||
|
if "mlflow.version" in latest_run.data.tags:
|
||||||
|
try:
|
||||||
|
version = int(latest_run.data.tags["mlflow.version"])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return PromotionRecommendation(
|
||||||
|
model_name=model_name,
|
||||||
|
version=version,
|
||||||
|
recommended=passed,
|
||||||
|
reasons=reasons,
|
||||||
|
metrics_summary=dict(metrics),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_performance_report(
|
||||||
|
service_name: str = "chat-handler",
|
||||||
|
hours: int = 24,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate an inference performance report for a service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service_name: Service name (chat-handler, voice-assistant)
|
||||||
|
hours: Hours of data to analyze
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Performance report dictionary
|
||||||
|
"""
|
||||||
|
experiment_name = f"{service_name.replace('-', '')}-inference"
|
||||||
|
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
# Get summary metrics
|
||||||
|
summary = analyzer.get_metrics_summary(hours=hours)
|
||||||
|
|
||||||
|
# Key latency metrics
|
||||||
|
latency_metrics = [
|
||||||
|
"total_latency_mean",
|
||||||
|
"total_latency_p50",
|
||||||
|
"total_latency_p95",
|
||||||
|
"llm_latency_mean",
|
||||||
|
"embedding_latency_mean",
|
||||||
|
"rag_search_latency_mean",
|
||||||
|
]
|
||||||
|
|
||||||
|
report = {
|
||||||
|
"service": service_name,
|
||||||
|
"period_hours": hours,
|
||||||
|
"generated_at": datetime.now().isoformat(),
|
||||||
|
"latency": {},
|
||||||
|
"throughput": {},
|
||||||
|
"rag": {},
|
||||||
|
"errors": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Latency section
|
||||||
|
for metric in latency_metrics:
|
||||||
|
if metric in summary:
|
||||||
|
report["latency"][metric] = summary[metric]
|
||||||
|
|
||||||
|
# Throughput
|
||||||
|
if "total_requests" in summary:
|
||||||
|
report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
|
||||||
|
|
||||||
|
# RAG usage
|
||||||
|
rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
|
||||||
|
for metric in rag_metrics:
|
||||||
|
if metric in summary:
|
||||||
|
report["rag"][metric] = summary[metric]
|
||||||
|
|
||||||
|
# Error rate
|
||||||
|
if "error_rate" in summary:
|
||||||
|
report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
|
||||||
|
|
||||||
|
return report
|
||||||
431
mlflow_utils/inference_tracker.py
Normal file
431
mlflow_utils/inference_tracker.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
"""
|
||||||
|
Inference Metrics Tracker for NATS Handlers
|
||||||
|
|
||||||
|
Provides async-compatible MLflow logging for real-time inference
|
||||||
|
metrics from chat-handler and voice-assistant services.
|
||||||
|
|
||||||
|
Designed to integrate with the existing OpenTelemetry setup and
|
||||||
|
complement OTel metrics with MLflow experiment tracking for
|
||||||
|
longer-term analysis and model comparison.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from collections import defaultdict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceMetrics:
|
||||||
|
"""Metrics collected during an inference request."""
|
||||||
|
request_id: str
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Timing metrics (in seconds)
|
||||||
|
total_latency: float = 0.0
|
||||||
|
embedding_latency: float = 0.0
|
||||||
|
rag_search_latency: float = 0.0
|
||||||
|
rerank_latency: float = 0.0
|
||||||
|
llm_latency: float = 0.0
|
||||||
|
tts_latency: float = 0.0
|
||||||
|
stt_latency: float = 0.0
|
||||||
|
|
||||||
|
# Token/size metrics
|
||||||
|
input_tokens: int = 0
|
||||||
|
output_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
prompt_length: int = 0
|
||||||
|
response_length: int = 0
|
||||||
|
|
||||||
|
# RAG metrics
|
||||||
|
rag_enabled: bool = False
|
||||||
|
rag_documents_retrieved: int = 0
|
||||||
|
rag_documents_used: int = 0
|
||||||
|
reranker_enabled: bool = False
|
||||||
|
|
||||||
|
# Quality indicators
|
||||||
|
is_streaming: bool = False
|
||||||
|
is_premium: bool = False
|
||||||
|
has_error: bool = False
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
# Model information
|
||||||
|
model_name: Optional[str] = None
|
||||||
|
model_endpoint: Optional[str] = None
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def as_metrics_dict(self) -> Dict[str, float]:
|
||||||
|
"""Convert numeric fields to a metrics dictionary."""
|
||||||
|
return {
|
||||||
|
"total_latency": self.total_latency,
|
||||||
|
"embedding_latency": self.embedding_latency,
|
||||||
|
"rag_search_latency": self.rag_search_latency,
|
||||||
|
"rerank_latency": self.rerank_latency,
|
||||||
|
"llm_latency": self.llm_latency,
|
||||||
|
"tts_latency": self.tts_latency,
|
||||||
|
"stt_latency": self.stt_latency,
|
||||||
|
"input_tokens": float(self.input_tokens),
|
||||||
|
"output_tokens": float(self.output_tokens),
|
||||||
|
"total_tokens": float(self.total_tokens),
|
||||||
|
"prompt_length": float(self.prompt_length),
|
||||||
|
"response_length": float(self.response_length),
|
||||||
|
"rag_documents_retrieved": float(self.rag_documents_retrieved),
|
||||||
|
"rag_documents_used": float(self.rag_documents_used),
|
||||||
|
}
|
||||||
|
|
||||||
|
def as_params_dict(self) -> Dict[str, str]:
|
||||||
|
"""Convert configuration fields to a params dictionary."""
|
||||||
|
params = {
|
||||||
|
"rag_enabled": str(self.rag_enabled),
|
||||||
|
"reranker_enabled": str(self.reranker_enabled),
|
||||||
|
"is_streaming": str(self.is_streaming),
|
||||||
|
"is_premium": str(self.is_premium),
|
||||||
|
}
|
||||||
|
if self.model_name:
|
||||||
|
params["model_name"] = self.model_name
|
||||||
|
if self.model_endpoint:
|
||||||
|
params["model_endpoint"] = self.model_endpoint
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceMetricsTracker:
|
||||||
|
"""
|
||||||
|
Async-compatible MLflow tracker for inference metrics.
|
||||||
|
|
||||||
|
Uses batching and a background thread pool to avoid blocking
|
||||||
|
the async event loop during MLflow calls.
|
||||||
|
|
||||||
|
Example usage in chat-handler:
|
||||||
|
|
||||||
|
class ChatHandler:
|
||||||
|
def __init__(self):
|
||||||
|
self.mlflow_tracker = InferenceMetricsTracker(
|
||||||
|
service_name="chat-handler",
|
||||||
|
experiment_name="chat-inference"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
await self.mlflow_tracker.start()
|
||||||
|
|
||||||
|
async def process_request(self, msg):
|
||||||
|
metrics = InferenceMetrics(request_id=request_id)
|
||||||
|
|
||||||
|
# Track timing
|
||||||
|
start = time.time()
|
||||||
|
# ... do embedding ...
|
||||||
|
metrics.embedding_latency = time.time() - start
|
||||||
|
|
||||||
|
# ... more processing ...
|
||||||
|
|
||||||
|
# Log metrics (non-blocking)
|
||||||
|
await self.mlflow_tracker.log_inference(metrics)
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
await self.mlflow_tracker.stop()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
service_name: str,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
batch_size: int = 50,
|
||||||
|
flush_interval_seconds: float = 60.0,
|
||||||
|
enable_batching: bool = True,
|
||||||
|
max_workers: int = 2,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the inference metrics tracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service_name: Name of the service (e.g., "chat-handler")
|
||||||
|
experiment_name: MLflow experiment name (defaults to service_name)
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
batch_size: Number of metrics to batch before flushing
|
||||||
|
flush_interval_seconds: Maximum time between flushes
|
||||||
|
enable_batching: If False, log each request immediately
|
||||||
|
max_workers: Number of thread pool workers for MLflow calls
|
||||||
|
"""
|
||||||
|
self.service_name = service_name
|
||||||
|
self.experiment_name = experiment_name or f"{service_name}-inference"
|
||||||
|
self.tracking_uri = tracking_uri
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.flush_interval = flush_interval_seconds
|
||||||
|
self.enable_batching = enable_batching
|
||||||
|
|
||||||
|
self.config = MLflowConfig()
|
||||||
|
self._batch: List[InferenceMetrics] = []
|
||||||
|
self._batch_lock = asyncio.Lock()
|
||||||
|
self._flush_task: Optional[asyncio.Task] = None
|
||||||
|
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
|
self._running = False
|
||||||
|
self._client: Optional[MlflowClient] = None
|
||||||
|
self._experiment_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Aggregate metrics for periodic logging
|
||||||
|
self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
|
||||||
|
self._request_count = 0
|
||||||
|
self._error_count = 0
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the tracker and initialize MLflow connection."""
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
# Initialize MLflow in thread pool to avoid blocking
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
self._executor,
|
||||||
|
self._init_mlflow
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.enable_batching:
|
||||||
|
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"InferenceMetricsTracker started for {self.service_name} "
|
||||||
|
f"(experiment: {self.experiment_name})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_mlflow(self) -> None:
|
||||||
|
"""Initialize MLflow client and experiment (runs in thread pool)."""
|
||||||
|
self._client = get_mlflow_client(
|
||||||
|
tracking_uri=self.tracking_uri,
|
||||||
|
configure_global=True
|
||||||
|
)
|
||||||
|
self._experiment_id = ensure_experiment(
|
||||||
|
self.experiment_name,
|
||||||
|
tags={
|
||||||
|
"service": self.service_name,
|
||||||
|
"type": "inference-metrics",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the tracker and flush remaining metrics."""
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
if self._flush_task:
|
||||||
|
self._flush_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._flush_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Final flush
|
||||||
|
await self._flush_batch()
|
||||||
|
|
||||||
|
self._executor.shutdown(wait=True)
|
||||||
|
logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
|
||||||
|
|
||||||
|
async def log_inference(self, metrics: InferenceMetrics) -> None:
|
||||||
|
"""
|
||||||
|
Log inference metrics (non-blocking).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics: InferenceMetrics object with request data
|
||||||
|
"""
|
||||||
|
if not self._running:
|
||||||
|
logger.warning("Tracker not running, skipping metrics")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._request_count += 1
|
||||||
|
if metrics.has_error:
|
||||||
|
self._error_count += 1
|
||||||
|
|
||||||
|
# Update aggregates
|
||||||
|
for key, value in metrics.as_metrics_dict().items():
|
||||||
|
if value > 0:
|
||||||
|
self._aggregate_metrics[key].append(value)
|
||||||
|
|
||||||
|
if self.enable_batching:
|
||||||
|
async with self._batch_lock:
|
||||||
|
self._batch.append(metrics)
|
||||||
|
if len(self._batch) >= self.batch_size:
|
||||||
|
asyncio.create_task(self._flush_batch())
|
||||||
|
else:
|
||||||
|
# Immediate logging in thread pool
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
self._executor,
|
||||||
|
partial(self._log_single_inference, metrics)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _periodic_flush(self) -> None:
|
||||||
|
"""Periodically flush batched metrics."""
|
||||||
|
while self._running:
|
||||||
|
await asyncio.sleep(self.flush_interval)
|
||||||
|
await self._flush_batch()
|
||||||
|
|
||||||
|
async def _flush_batch(self) -> None:
|
||||||
|
"""Flush the current batch of metrics to MLflow."""
|
||||||
|
async with self._batch_lock:
|
||||||
|
if not self._batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
batch = self._batch
|
||||||
|
self._batch = []
|
||||||
|
|
||||||
|
# Log in thread pool
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
self._executor,
|
||||||
|
partial(self._log_batch, batch)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
|
||||||
|
"""Log a single inference request to MLflow (runs in thread pool)."""
|
||||||
|
try:
|
||||||
|
with mlflow.start_run(
|
||||||
|
experiment_id=self._experiment_id,
|
||||||
|
run_name=f"inference-{metrics.request_id}",
|
||||||
|
tags={
|
||||||
|
"service": self.service_name,
|
||||||
|
"request_id": metrics.request_id,
|
||||||
|
"type": "single-inference",
|
||||||
|
}
|
||||||
|
):
|
||||||
|
mlflow.log_params(metrics.as_params_dict())
|
||||||
|
mlflow.log_metrics(metrics.as_metrics_dict())
|
||||||
|
|
||||||
|
if metrics.user_id:
|
||||||
|
mlflow.set_tag("user_id", metrics.user_id)
|
||||||
|
if metrics.session_id:
|
||||||
|
mlflow.set_tag("session_id", metrics.session_id)
|
||||||
|
if metrics.has_error:
|
||||||
|
mlflow.set_tag("has_error", "true")
|
||||||
|
if metrics.error_message:
|
||||||
|
mlflow.set_tag("error_message", metrics.error_message[:250])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to log inference metrics: {e}")
|
||||||
|
|
||||||
|
def _log_batch(self, batch: List[InferenceMetrics]) -> None:
|
||||||
|
"""Log a batch of inference metrics as aggregate statistics."""
|
||||||
|
if not batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Calculate aggregates
|
||||||
|
aggregates = self._calculate_aggregates(batch)
|
||||||
|
|
||||||
|
run_name = f"batch-{self.service_name}-{int(time.time())}"
|
||||||
|
|
||||||
|
with mlflow.start_run(
|
||||||
|
experiment_id=self._experiment_id,
|
||||||
|
run_name=run_name,
|
||||||
|
tags={
|
||||||
|
"service": self.service_name,
|
||||||
|
"type": "batch-inference",
|
||||||
|
"batch_size": str(len(batch)),
|
||||||
|
}
|
||||||
|
):
|
||||||
|
# Log aggregate metrics
|
||||||
|
mlflow.log_metrics(aggregates)
|
||||||
|
|
||||||
|
# Log batch info
|
||||||
|
mlflow.log_param("batch_size", len(batch))
|
||||||
|
mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
|
||||||
|
mlflow.log_param("time_window_end", max(m.timestamp for m in batch))
|
||||||
|
|
||||||
|
# Log configuration breakdown
|
||||||
|
rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
|
||||||
|
streaming_count = sum(1 for m in batch if m.is_streaming)
|
||||||
|
premium_count = sum(1 for m in batch if m.is_premium)
|
||||||
|
error_count = sum(1 for m in batch if m.has_error)
|
||||||
|
|
||||||
|
mlflow.log_metrics({
|
||||||
|
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
|
||||||
|
"streaming_pct": streaming_count / len(batch) * 100,
|
||||||
|
"premium_pct": premium_count / len(batch) * 100,
|
||||||
|
"error_rate": error_count / len(batch) * 100,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Log model distribution
|
||||||
|
model_counts: Dict[str, int] = defaultdict(int)
|
||||||
|
for m in batch:
|
||||||
|
if m.model_name:
|
||||||
|
model_counts[m.model_name] += 1
|
||||||
|
|
||||||
|
if model_counts:
|
||||||
|
mlflow.log_dict(
|
||||||
|
{"models": dict(model_counts)},
|
||||||
|
"model_distribution.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Logged batch of {len(batch)} inference metrics")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to log batch metrics: {e}")
|
||||||
|
|
||||||
|
def _calculate_aggregates(
|
||||||
|
self,
|
||||||
|
batch: List[InferenceMetrics]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Calculate aggregate statistics from a batch of metrics."""
|
||||||
|
import statistics
|
||||||
|
|
||||||
|
aggregates = {}
|
||||||
|
|
||||||
|
# Collect all numeric metrics
|
||||||
|
metric_values: Dict[str, List[float]] = defaultdict(list)
|
||||||
|
for m in batch:
|
||||||
|
for key, value in m.as_metrics_dict().items():
|
||||||
|
if value > 0:
|
||||||
|
metric_values[key].append(value)
|
||||||
|
|
||||||
|
# Calculate statistics for each metric
|
||||||
|
for key, values in metric_values.items():
|
||||||
|
if not values:
|
||||||
|
continue
|
||||||
|
|
||||||
|
aggregates[f"{key}_mean"] = statistics.mean(values)
|
||||||
|
aggregates[f"{key}_min"] = min(values)
|
||||||
|
aggregates[f"{key}_max"] = max(values)
|
||||||
|
|
||||||
|
if len(values) >= 2:
|
||||||
|
aggregates[f"{key}_p50"] = statistics.median(values)
|
||||||
|
aggregates[f"{key}_stdev"] = statistics.stdev(values)
|
||||||
|
|
||||||
|
if len(values) >= 4:
|
||||||
|
sorted_vals = sorted(values)
|
||||||
|
p95_idx = int(len(sorted_vals) * 0.95)
|
||||||
|
p99_idx = int(len(sorted_vals) * 0.99)
|
||||||
|
aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
|
||||||
|
aggregates[f"{key}_p99"] = sorted_vals[p99_idx]
|
||||||
|
|
||||||
|
# Add counts
|
||||||
|
aggregates["total_requests"] = float(len(batch))
|
||||||
|
|
||||||
|
return aggregates
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get current tracker statistics."""
|
||||||
|
return {
|
||||||
|
"service_name": self.service_name,
|
||||||
|
"experiment_name": self.experiment_name,
|
||||||
|
"running": self._running,
|
||||||
|
"total_requests": self._request_count,
|
||||||
|
"error_count": self._error_count,
|
||||||
|
"pending_batch_size": len(self._batch),
|
||||||
|
"aggregate_metrics_count": len(self._aggregate_metrics),
|
||||||
|
}
|
||||||
513
mlflow_utils/kfp_components.py
Normal file
513
mlflow_utils/kfp_components.py
Normal file
@@ -0,0 +1,513 @@
|
|||||||
|
"""
|
||||||
|
Kubeflow Pipeline Components with MLflow Tracking
|
||||||
|
|
||||||
|
Provides reusable KFP components that integrate MLflow experiment
|
||||||
|
tracking into Kubeflow Pipelines. These components can be used
|
||||||
|
directly in pipelines or as wrappers around existing pipeline steps.
|
||||||
|
|
||||||
|
Usage in a Kubeflow Pipeline:
|
||||||
|
|
||||||
|
from mlflow_utils.kfp_components import (
|
||||||
|
create_mlflow_run,
|
||||||
|
log_metrics_component,
|
||||||
|
log_model_artifact,
|
||||||
|
end_mlflow_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
@dsl.pipeline(name="my-pipeline")
|
||||||
|
def my_pipeline():
|
||||||
|
# Start MLflow run
|
||||||
|
run_info = create_mlflow_run(
|
||||||
|
experiment_name="my-experiment",
|
||||||
|
run_name="training-run-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ... your pipeline steps ...
|
||||||
|
|
||||||
|
# Log metrics
|
||||||
|
log_step = log_metrics_component(
|
||||||
|
run_id=run_info.outputs["run_id"],
|
||||||
|
metrics={"accuracy": 0.95, "loss": 0.05}
|
||||||
|
)
|
||||||
|
|
||||||
|
# End run
|
||||||
|
end_mlflow_run(run_id=run_info.outputs["run_id"])
|
||||||
|
"""
|
||||||
|
|
||||||
|
from kfp import dsl
|
||||||
|
from typing import Dict, Any, List, Optional, NamedTuple
|
||||||
|
|
||||||
|
|
||||||
|
# MLflow component image with all required dependencies
|
||||||
|
MLFLOW_IMAGE = "python:3.13-slim"
|
||||||
|
MLFLOW_PACKAGES = [
|
||||||
|
"mlflow>=2.10.0",
|
||||||
|
"boto3", # For S3 artifact storage if needed
|
||||||
|
"psycopg2-binary", # For PostgreSQL backend
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def create_mlflow_run(
|
||||||
|
experiment_name: str,
|
||||||
|
run_name: str,
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
tags: Dict[str, str] = None,
|
||||||
|
params: Dict[str, str] = None,
|
||||||
|
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
|
||||||
|
"""
|
||||||
|
Create a new MLflow run for the pipeline.
|
||||||
|
|
||||||
|
This should be called at the start of a pipeline to initialize
|
||||||
|
tracking. The returned run_id should be passed to subsequent
|
||||||
|
components for logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the MLflow experiment
|
||||||
|
run_name: Name for this specific run
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
tags: Optional tags to add to the run
|
||||||
|
params: Optional parameters to log
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NamedTuple with run_id, experiment_id, and artifact_uri
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
# Set tracking URI
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
# Get or create experiment
|
||||||
|
experiment = client.get_experiment_by_name(experiment_name)
|
||||||
|
if experiment is None:
|
||||||
|
experiment_id = client.create_experiment(
|
||||||
|
name=experiment_name,
|
||||||
|
artifact_location=f"/mlflow/artifacts/{experiment_name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
experiment_id = experiment.experiment_id
|
||||||
|
|
||||||
|
# Create default tags
|
||||||
|
default_tags = {
|
||||||
|
"pipeline.type": "kubeflow",
|
||||||
|
"kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
|
||||||
|
"kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
|
||||||
|
}
|
||||||
|
if tags:
|
||||||
|
default_tags.update(tags)
|
||||||
|
|
||||||
|
# Start run
|
||||||
|
run = mlflow.start_run(
|
||||||
|
experiment_id=experiment_id,
|
||||||
|
run_name=run_name,
|
||||||
|
tags=default_tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log initial params
|
||||||
|
if params:
|
||||||
|
mlflow.log_params(params)
|
||||||
|
|
||||||
|
run_id = run.info.run_id
|
||||||
|
artifact_uri = run.info.artifact_uri
|
||||||
|
|
||||||
|
# End run (KFP components are isolated, we'll resume in other components)
|
||||||
|
mlflow.end_run()
|
||||||
|
|
||||||
|
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
|
||||||
|
return RunInfo(run_id, experiment_id, artifact_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def log_params_component(
|
||||||
|
run_id: str,
|
||||||
|
params: Dict[str, str],
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log parameters to an existing MLflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
params: Dictionary of parameters to log
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
for key, value in params.items():
|
||||||
|
client.log_param(run_id, key, str(value)[:500])
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def log_metrics_component(
|
||||||
|
run_id: str,
|
||||||
|
metrics: Dict[str, float],
|
||||||
|
step: int = 0,
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log metrics to an existing MLflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
metrics: Dictionary of metrics to log
|
||||||
|
step: Step number for time-series metrics
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
for key, value in metrics.items():
|
||||||
|
client.log_metric(run_id, key, float(value), step=step)
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def log_artifact_component(
|
||||||
|
run_id: str,
|
||||||
|
artifact_path: str,
|
||||||
|
artifact_name: str = "",
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log an artifact file to an existing MLflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
artifact_path: Path to the artifact file
|
||||||
|
artifact_name: Optional destination name in artifact store
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
client.log_artifact(run_id, artifact_path, artifact_name or None)
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def log_dict_artifact(
|
||||||
|
run_id: str,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
filename: str,
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log a dictionary as a JSON artifact.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
data: Dictionary to save as JSON
|
||||||
|
filename: Name for the JSON file
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
# Ensure .json extension
|
||||||
|
if not filename.endswith('.json'):
|
||||||
|
filename += '.json'
|
||||||
|
|
||||||
|
# Write to temp file and log
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
filepath = Path(tmpdir) / filename
|
||||||
|
with open(filepath, 'w') as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
client.log_artifact(run_id, str(filepath))
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def end_mlflow_run(
|
||||||
|
run_id: str,
|
||||||
|
status: str = "FINISHED",
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
End an MLflow run with the specified status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to end
|
||||||
|
status: Run status (FINISHED, FAILED, KILLED)
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id
|
||||||
|
"""
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from mlflow.entities import RunStatus
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
status_map = {
|
||||||
|
"FINISHED": RunStatus.FINISHED,
|
||||||
|
"FAILED": RunStatus.FAILED,
|
||||||
|
"KILLED": RunStatus.KILLED,
|
||||||
|
}
|
||||||
|
|
||||||
|
run_status = status_map.get(status.upper(), RunStatus.FINISHED)
|
||||||
|
client.set_terminated(run_id, status=run_status)
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES + ["httpx"]
|
||||||
|
)
|
||||||
|
def log_training_metrics(
|
||||||
|
run_id: str,
|
||||||
|
model_type: str,
|
||||||
|
training_config: Dict[str, Any],
|
||||||
|
final_metrics: Dict[str, float],
|
||||||
|
model_path: str = "",
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log comprehensive training metrics for ML models.
|
||||||
|
|
||||||
|
Designed for use with QLoRA training, voice training, and other
|
||||||
|
ML training pipelines in the llm-workflows repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
model_type: Type of model (llm, stt, tts, embeddings)
|
||||||
|
training_config: Training configuration dict
|
||||||
|
final_metrics: Final training metrics
|
||||||
|
model_path: Path to saved model (if applicable)
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
# Log training config as params
|
||||||
|
flat_config = {}
|
||||||
|
for key, value in training_config.items():
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
flat_config[f"config.{key}"] = json.dumps(value)[:500]
|
||||||
|
else:
|
||||||
|
flat_config[f"config.{key}"] = str(value)[:500]
|
||||||
|
|
||||||
|
for key, value in flat_config.items():
|
||||||
|
client.log_param(run_id, key, value)
|
||||||
|
|
||||||
|
# Log model type tag
|
||||||
|
client.set_tag(run_id, "model.type", model_type)
|
||||||
|
|
||||||
|
# Log metrics
|
||||||
|
for key, value in final_metrics.items():
|
||||||
|
client.log_metric(run_id, key, float(value))
|
||||||
|
|
||||||
|
# Log full config as artifact
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = Path(tmpdir) / "training_config.json"
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(training_config, f, indent=2)
|
||||||
|
client.log_artifact(run_id, str(config_path))
|
||||||
|
|
||||||
|
# Log model path if provided
|
||||||
|
if model_path:
|
||||||
|
client.log_param(run_id, "model.path", model_path)
|
||||||
|
client.set_tag(run_id, "model.saved", "true")
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def log_document_ingestion_metrics(
|
||||||
|
run_id: str,
|
||||||
|
source_url: str,
|
||||||
|
collection_name: str,
|
||||||
|
chunks_created: int,
|
||||||
|
documents_processed: int,
|
||||||
|
processing_time_seconds: float,
|
||||||
|
embeddings_model: str = "bge-small-en-v1.5",
|
||||||
|
chunk_size: int = 500,
|
||||||
|
chunk_overlap: int = 50,
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log document ingestion pipeline metrics.
|
||||||
|
|
||||||
|
Designed for use with the document_ingestion_pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
source_url: URL of the source document
|
||||||
|
collection_name: Milvus collection name
|
||||||
|
chunks_created: Number of chunks created
|
||||||
|
documents_processed: Number of documents processed
|
||||||
|
processing_time_seconds: Total processing time
|
||||||
|
embeddings_model: Embeddings model used
|
||||||
|
chunk_size: Chunk size in tokens
|
||||||
|
chunk_overlap: Chunk overlap in tokens
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
# Log params
|
||||||
|
params = {
|
||||||
|
"source_url": source_url[:500],
|
||||||
|
"collection_name": collection_name,
|
||||||
|
"embeddings_model": embeddings_model,
|
||||||
|
"chunk_size": str(chunk_size),
|
||||||
|
"chunk_overlap": str(chunk_overlap),
|
||||||
|
}
|
||||||
|
for key, value in params.items():
|
||||||
|
client.log_param(run_id, key, value)
|
||||||
|
|
||||||
|
# Log metrics
|
||||||
|
metrics = {
|
||||||
|
"chunks_created": chunks_created,
|
||||||
|
"documents_processed": documents_processed,
|
||||||
|
"processing_time_seconds": processing_time_seconds,
|
||||||
|
"chunks_per_second": chunks_created / processing_time_seconds if processing_time_seconds > 0 else 0,
|
||||||
|
}
|
||||||
|
for key, value in metrics.items():
|
||||||
|
client.log_metric(run_id, key, float(value))
|
||||||
|
|
||||||
|
# Set pipeline type tag
|
||||||
|
client.set_tag(run_id, "pipeline.type", "document-ingestion")
|
||||||
|
client.set_tag(run_id, "milvus.collection", collection_name)
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
|
@dsl.component(
|
||||||
|
base_image=MLFLOW_IMAGE,
|
||||||
|
packages_to_install=MLFLOW_PACKAGES
|
||||||
|
)
|
||||||
|
def log_evaluation_results(
|
||||||
|
run_id: str,
|
||||||
|
model_name: str,
|
||||||
|
dataset_name: str,
|
||||||
|
metrics: Dict[str, float],
|
||||||
|
sample_results: List[Dict[str, Any]] = None,
|
||||||
|
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Log model evaluation results.
|
||||||
|
|
||||||
|
Designed for use with the evaluation_pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The MLflow run ID to log to
|
||||||
|
model_name: Name of the evaluated model
|
||||||
|
dataset_name: Name of the evaluation dataset
|
||||||
|
metrics: Evaluation metrics (accuracy, etc.)
|
||||||
|
sample_results: Optional sample predictions
|
||||||
|
mlflow_tracking_uri: MLflow tracking server URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run_id for chaining
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||||
|
client = MlflowClient()
|
||||||
|
|
||||||
|
# Log params
|
||||||
|
client.log_param(run_id, "eval.model_name", model_name)
|
||||||
|
client.log_param(run_id, "eval.dataset", dataset_name)
|
||||||
|
|
||||||
|
# Log metrics
|
||||||
|
for key, value in metrics.items():
|
||||||
|
client.log_metric(run_id, f"eval.{key}", float(value))
|
||||||
|
|
||||||
|
# Log sample results as artifact
|
||||||
|
if sample_results:
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
results_path = Path(tmpdir) / "evaluation_results.json"
|
||||||
|
with open(results_path, 'w') as f:
|
||||||
|
json.dump(sample_results, f, indent=2)
|
||||||
|
client.log_artifact(run_id, str(results_path))
|
||||||
|
|
||||||
|
# Set tags
|
||||||
|
client.set_tag(run_id, "pipeline.type", "evaluation")
|
||||||
|
client.set_tag(run_id, "model.name", model_name)
|
||||||
|
|
||||||
|
# Determine if passed
|
||||||
|
passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
|
||||||
|
client.set_tag(run_id, "eval.passed", str(passed))
|
||||||
|
|
||||||
|
return run_id
|
||||||
545
mlflow_utils/model_registry.py
Normal file
545
mlflow_utils/model_registry.py
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
"""
|
||||||
|
MLflow Model Registry Integration for KServe
|
||||||
|
|
||||||
|
Provides utilities for registering trained models in MLflow Model Registry
|
||||||
|
with metadata needed for deployment to KServe InferenceServices.
|
||||||
|
|
||||||
|
This module bridges the gap between Kubeflow training pipelines and
|
||||||
|
KServe model serving by:
|
||||||
|
1. Registering models with proper versioning
|
||||||
|
2. Adding KServe-specific metadata (runtime, protocol, resources)
|
||||||
|
3. Managing model stage transitions (Staging → Production)
|
||||||
|
4. Generating KServe InferenceService manifests from registered models
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from mlflow_utils.model_registry import (
|
||||||
|
register_model_for_kserve,
|
||||||
|
promote_model_to_production,
|
||||||
|
generate_kserve_manifest,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register a new model version
|
||||||
|
model_version = register_model_for_kserve(
|
||||||
|
model_name="whisper-finetuned",
|
||||||
|
model_uri="s3://models/whisper-v2",
|
||||||
|
model_type="stt",
|
||||||
|
kserve_config={
|
||||||
|
"runtime": "kserve-huggingface",
|
||||||
|
"container_image": "ghcr.io/my-org/whisper:v2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate KServe manifest for deployment
|
||||||
|
manifest = generate_kserve_manifest(
|
||||||
|
model_name="whisper-finetuned",
|
||||||
|
model_version=model_version.version,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
from mlflow.entities.model_registry import ModelVersion
|
||||||
|
|
||||||
|
from .client import get_mlflow_client, MLflowConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KServeConfig:
|
||||||
|
"""Configuration for KServe deployment."""
|
||||||
|
|
||||||
|
# Runtime/container configuration
|
||||||
|
runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc.
|
||||||
|
container_image: Optional[str] = None
|
||||||
|
container_port: int = 8080
|
||||||
|
|
||||||
|
# Protocol configuration
|
||||||
|
protocol: str = "v2" # v1, v2, grpc
|
||||||
|
|
||||||
|
# Resource requests/limits
|
||||||
|
cpu_request: str = "1"
|
||||||
|
cpu_limit: str = "4"
|
||||||
|
memory_request: str = "4Gi"
|
||||||
|
memory_limit: str = "16Gi"
|
||||||
|
gpu_count: int = 0
|
||||||
|
gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm
|
||||||
|
|
||||||
|
# Storage configuration
|
||||||
|
storage_uri: Optional[str] = None # s3://, pvc://, gs://
|
||||||
|
|
||||||
|
# Scaling configuration
|
||||||
|
min_replicas: int = 1
|
||||||
|
max_replicas: int = 1
|
||||||
|
scale_target: int = 10 # Target concurrent requests for scaling
|
||||||
|
|
||||||
|
# Serving configuration
|
||||||
|
timeout_seconds: int = 300
|
||||||
|
batch_size: int = 1
|
||||||
|
|
||||||
|
# Additional environment variables
|
||||||
|
env_vars: Dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for MLflow tags."""
|
||||||
|
return {
|
||||||
|
"kserve.runtime": self.runtime,
|
||||||
|
"kserve.protocol": self.protocol,
|
||||||
|
"kserve.cpu_request": self.cpu_request,
|
||||||
|
"kserve.memory_request": self.memory_request,
|
||||||
|
"kserve.gpu_count": str(self.gpu_count),
|
||||||
|
"kserve.min_replicas": str(self.min_replicas),
|
||||||
|
"kserve.max_replicas": str(self.max_replicas),
|
||||||
|
"kserve.storage_uri": self.storage_uri or "",
|
||||||
|
"kserve.container_image": self.container_image or "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Pre-configured KServe configurations for common model types
|
||||||
|
KSERVE_PRESETS: Dict[str, KServeConfig] = {
|
||||||
|
"llm": KServeConfig(
|
||||||
|
runtime="kserve-huggingface",
|
||||||
|
cpu_request="2",
|
||||||
|
cpu_limit="8",
|
||||||
|
memory_request="16Gi",
|
||||||
|
memory_limit="64Gi",
|
||||||
|
gpu_count=1,
|
||||||
|
timeout_seconds=600,
|
||||||
|
),
|
||||||
|
"stt": KServeConfig(
|
||||||
|
runtime="kserve-custom",
|
||||||
|
cpu_request="2",
|
||||||
|
cpu_limit="4",
|
||||||
|
memory_request="8Gi",
|
||||||
|
memory_limit="16Gi",
|
||||||
|
gpu_count=1,
|
||||||
|
timeout_seconds=120,
|
||||||
|
),
|
||||||
|
"tts": KServeConfig(
|
||||||
|
runtime="kserve-custom",
|
||||||
|
cpu_request="2",
|
||||||
|
cpu_limit="4",
|
||||||
|
memory_request="8Gi",
|
||||||
|
memory_limit="16Gi",
|
||||||
|
gpu_count=1,
|
||||||
|
timeout_seconds=60,
|
||||||
|
),
|
||||||
|
"embeddings": KServeConfig(
|
||||||
|
runtime="kserve-huggingface",
|
||||||
|
cpu_request="1",
|
||||||
|
cpu_limit="4",
|
||||||
|
memory_request="4Gi",
|
||||||
|
memory_limit="16Gi",
|
||||||
|
gpu_count=0,
|
||||||
|
timeout_seconds=30,
|
||||||
|
batch_size=32,
|
||||||
|
),
|
||||||
|
"reranker": KServeConfig(
|
||||||
|
runtime="kserve-huggingface",
|
||||||
|
cpu_request="1",
|
||||||
|
cpu_limit="4",
|
||||||
|
memory_request="4Gi",
|
||||||
|
memory_limit="16Gi",
|
||||||
|
gpu_count=0,
|
||||||
|
timeout_seconds=30,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_model_for_kserve(
|
||||||
|
model_name: str,
|
||||||
|
model_uri: str,
|
||||||
|
model_type: str,
|
||||||
|
run_id: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
kserve_config: Optional[KServeConfig] = None,
|
||||||
|
tags: Optional[Dict[str, str]] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> ModelVersion:
|
||||||
|
"""
|
||||||
|
Register a model in MLflow Model Registry with KServe metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name for the registered model
|
||||||
|
model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
|
||||||
|
model_type: Type of model (llm, stt, tts, embeddings, reranker)
|
||||||
|
run_id: Optional MLflow run ID to associate with
|
||||||
|
description: Description of the model version
|
||||||
|
kserve_config: KServe deployment configuration
|
||||||
|
tags: Additional tags for the model version
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created ModelVersion object
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
# Get or use preset KServe config
|
||||||
|
if kserve_config is None:
|
||||||
|
kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())
|
||||||
|
|
||||||
|
# Ensure registered model exists
|
||||||
|
try:
|
||||||
|
client.get_registered_model(model_name)
|
||||||
|
except mlflow.exceptions.MlflowException:
|
||||||
|
client.create_registered_model(
|
||||||
|
name=model_name,
|
||||||
|
description=f"{model_type.upper()} model for KServe deployment",
|
||||||
|
tags={
|
||||||
|
"model.type": model_type,
|
||||||
|
"deployment.target": "kserve",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.info(f"Created registered model: {model_name}")
|
||||||
|
|
||||||
|
# Create model version
|
||||||
|
model_version = client.create_model_version(
|
||||||
|
name=model_name,
|
||||||
|
source=model_uri,
|
||||||
|
run_id=run_id,
|
||||||
|
description=description or f"Version from {model_uri}",
|
||||||
|
tags={
|
||||||
|
**(tags or {}),
|
||||||
|
"model.type": model_type,
|
||||||
|
**kserve_config.as_dict(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Registered model version {model_version.version} "
|
||||||
|
f"for {model_name} (type: {model_type})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_version
|
||||||
|
|
||||||
|
|
||||||
|
def promote_model_to_stage(
|
||||||
|
model_name: str,
|
||||||
|
version: int,
|
||||||
|
stage: str = "Staging",
|
||||||
|
archive_existing: bool = True,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> ModelVersion:
|
||||||
|
"""
|
||||||
|
Promote a model version to a new stage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
version: Version number to promote
|
||||||
|
stage: Target stage (Staging, Production, Archived)
|
||||||
|
archive_existing: If True, archive existing versions in target stage
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated ModelVersion
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
# Transition to new stage
|
||||||
|
model_version = client.transition_model_version_stage(
|
||||||
|
name=model_name,
|
||||||
|
version=str(version),
|
||||||
|
stage=stage,
|
||||||
|
archive_existing_versions=archive_existing,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Promoted {model_name} v{version} to {stage}")
|
||||||
|
|
||||||
|
return model_version
|
||||||
|
|
||||||
|
|
||||||
|
def promote_model_to_production(
|
||||||
|
model_name: str,
|
||||||
|
version: int,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> ModelVersion:
|
||||||
|
"""
|
||||||
|
Promote a model version directly to Production.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
version: Version number to promote
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated ModelVersion
|
||||||
|
"""
|
||||||
|
return promote_model_to_stage(
|
||||||
|
model_name=model_name,
|
||||||
|
version=version,
|
||||||
|
stage="Production",
|
||||||
|
archive_existing=True,
|
||||||
|
tracking_uri=tracking_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_production_model(
|
||||||
|
model_name: str,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> Optional[ModelVersion]:
|
||||||
|
"""
|
||||||
|
Get the current Production model version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The Production ModelVersion, or None if none exists
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
versions = client.get_latest_versions(model_name, stages=["Production"])
|
||||||
|
|
||||||
|
return versions[0] if versions else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_kserve_config(
|
||||||
|
model_name: str,
|
||||||
|
version: Optional[int] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> KServeConfig:
|
||||||
|
"""
|
||||||
|
Get KServe configuration from a registered model version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
version: Version number (uses Production if not specified)
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KServeConfig populated from model tags
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
if version:
|
||||||
|
model_version = client.get_model_version(model_name, str(version))
|
||||||
|
else:
|
||||||
|
prod_version = get_production_model(model_name, tracking_uri)
|
||||||
|
if not prod_version:
|
||||||
|
raise ValueError(f"No Production version for {model_name}")
|
||||||
|
model_version = prod_version
|
||||||
|
|
||||||
|
tags = model_version.tags
|
||||||
|
|
||||||
|
return KServeConfig(
|
||||||
|
runtime=tags.get("kserve.runtime", "kserve-huggingface"),
|
||||||
|
protocol=tags.get("kserve.protocol", "v2"),
|
||||||
|
cpu_request=tags.get("kserve.cpu_request", "1"),
|
||||||
|
memory_request=tags.get("kserve.memory_request", "4Gi"),
|
||||||
|
gpu_count=int(tags.get("kserve.gpu_count", "0")),
|
||||||
|
min_replicas=int(tags.get("kserve.min_replicas", "1")),
|
||||||
|
max_replicas=int(tags.get("kserve.max_replicas", "1")),
|
||||||
|
storage_uri=tags.get("kserve.storage_uri") or None,
|
||||||
|
container_image=tags.get("kserve.container_image") or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_kserve_manifest(
|
||||||
|
model_name: str,
|
||||||
|
version: Optional[int] = None,
|
||||||
|
namespace: str = "ai-ml",
|
||||||
|
service_name: Optional[str] = None,
|
||||||
|
extra_annotations: Optional[Dict[str, str]] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate a KServe InferenceService manifest from a registered model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
version: Version number (uses Production if not specified)
|
||||||
|
namespace: Kubernetes namespace for deployment
|
||||||
|
service_name: Name for the InferenceService (defaults to model_name)
|
||||||
|
extra_annotations: Additional annotations for the service
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KServe InferenceService manifest as a dictionary
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
# Get model version
|
||||||
|
if version:
|
||||||
|
model_version = client.get_model_version(model_name, str(version))
|
||||||
|
else:
|
||||||
|
prod_version = get_production_model(model_name, tracking_uri)
|
||||||
|
if not prod_version:
|
||||||
|
raise ValueError(f"No Production version for {model_name}")
|
||||||
|
model_version = prod_version
|
||||||
|
version = int(model_version.version)
|
||||||
|
|
||||||
|
# Get KServe config
|
||||||
|
config = get_model_kserve_config(model_name, version, tracking_uri)
|
||||||
|
model_type = model_version.tags.get("model.type", "custom")
|
||||||
|
|
||||||
|
svc_name = service_name or model_name.lower().replace("_", "-")
|
||||||
|
|
||||||
|
# Build manifest
|
||||||
|
manifest = {
|
||||||
|
"apiVersion": "serving.kserve.io/v1beta1",
|
||||||
|
"kind": "InferenceService",
|
||||||
|
"metadata": {
|
||||||
|
"name": svc_name,
|
||||||
|
"namespace": namespace,
|
||||||
|
"labels": {
|
||||||
|
"mlflow.model": model_name,
|
||||||
|
"mlflow.version": str(version),
|
||||||
|
"model.type": model_type,
|
||||||
|
},
|
||||||
|
"annotations": {
|
||||||
|
"mlflow.tracking_uri": get_mlflow_client().tracking_uri,
|
||||||
|
"mlflow.run_id": model_version.run_id or "",
|
||||||
|
**(extra_annotations or {}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"spec": {
|
||||||
|
"predictor": {
|
||||||
|
"minReplicas": config.min_replicas,
|
||||||
|
"maxReplicas": config.max_replicas,
|
||||||
|
"scaleTarget": config.scale_target,
|
||||||
|
"timeout": config.timeout_seconds,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configure predictor based on runtime
|
||||||
|
predictor = manifest["spec"]["predictor"]
|
||||||
|
|
||||||
|
if config.container_image:
|
||||||
|
# Custom container
|
||||||
|
predictor["containers"] = [{
|
||||||
|
"name": "predictor",
|
||||||
|
"image": config.container_image,
|
||||||
|
"ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
|
||||||
|
"resources": {
|
||||||
|
"requests": {
|
||||||
|
"cpu": config.cpu_request,
|
||||||
|
"memory": config.memory_request,
|
||||||
|
},
|
||||||
|
"limits": {
|
||||||
|
"cpu": config.cpu_limit,
|
||||||
|
"memory": config.memory_limit,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"env": [
|
||||||
|
{"name": k, "value": v}
|
||||||
|
for k, v in config.env_vars.items()
|
||||||
|
],
|
||||||
|
}]
|
||||||
|
|
||||||
|
# Add GPU if needed
|
||||||
|
if config.gpu_count > 0:
|
||||||
|
predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
|
||||||
|
predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Standard KServe runtime
|
||||||
|
storage_uri = config.storage_uri or model_version.source
|
||||||
|
|
||||||
|
predictor["model"] = {
|
||||||
|
"modelFormat": {"name": "huggingface"},
|
||||||
|
"protocolVersion": config.protocol,
|
||||||
|
"storageUri": storage_uri,
|
||||||
|
"resources": {
|
||||||
|
"requests": {
|
||||||
|
"cpu": config.cpu_request,
|
||||||
|
"memory": config.memory_request,
|
||||||
|
},
|
||||||
|
"limits": {
|
||||||
|
"cpu": config.cpu_limit,
|
||||||
|
"memory": config.memory_limit,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.gpu_count > 0:
|
||||||
|
predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
|
||||||
|
predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
|
||||||
|
|
||||||
|
return manifest
|
||||||
|
|
||||||
|
|
||||||
|
def generate_kserve_yaml(
|
||||||
|
model_name: str,
|
||||||
|
version: Optional[int] = None,
|
||||||
|
namespace: str = "ai-ml",
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a KServe InferenceService manifest as YAML.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
version: Version number (uses Production if not specified)
|
||||||
|
namespace: Kubernetes namespace
|
||||||
|
output_path: If provided, write YAML to this path
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
YAML string of the manifest
|
||||||
|
"""
|
||||||
|
manifest = generate_kserve_manifest(
|
||||||
|
model_name=model_name,
|
||||||
|
version=version,
|
||||||
|
namespace=namespace,
|
||||||
|
tracking_uri=tracking_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
f.write(yaml_str)
|
||||||
|
logger.info(f"Wrote KServe manifest to {output_path}")
|
||||||
|
|
||||||
|
return yaml_str
|
||||||
|
|
||||||
|
|
||||||
|
def list_model_versions(
|
||||||
|
model_name: str,
|
||||||
|
stages: Optional[List[str]] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all versions of a registered model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the registered model
|
||||||
|
stages: Filter by stages (None for all)
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model version info dictionaries
|
||||||
|
"""
|
||||||
|
client = get_mlflow_client(tracking_uri=tracking_uri)
|
||||||
|
|
||||||
|
if stages:
|
||||||
|
versions = client.get_latest_versions(model_name, stages=stages)
|
||||||
|
else:
|
||||||
|
# Get all versions
|
||||||
|
versions = []
|
||||||
|
for mv in client.search_model_versions(f"name='{model_name}'"):
|
||||||
|
versions.append(mv)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"version": mv.version,
|
||||||
|
"stage": mv.current_stage,
|
||||||
|
"source": mv.source,
|
||||||
|
"run_id": mv.run_id,
|
||||||
|
"description": mv.description,
|
||||||
|
"tags": mv.tags,
|
||||||
|
"creation_timestamp": mv.creation_timestamp,
|
||||||
|
"last_updated_timestamp": mv.last_updated_timestamp,
|
||||||
|
}
|
||||||
|
for mv in versions
|
||||||
|
]
|
||||||
395
mlflow_utils/tracker.py
Normal file
395
mlflow_utils/tracker.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""
|
||||||
|
MLflow Tracker for Kubeflow Pipelines
|
||||||
|
|
||||||
|
Provides a high-level interface for logging experiments, parameters,
|
||||||
|
metrics, and artifacts from Kubeflow Pipeline components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any, List, Union
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineMetadata:
|
||||||
|
"""Metadata about the Kubeflow Pipeline run."""
|
||||||
|
pipeline_name: str
|
||||||
|
run_id: str
|
||||||
|
run_name: Optional[str] = None
|
||||||
|
component_name: Optional[str] = None
|
||||||
|
namespace: str = "ai-ml"
|
||||||
|
|
||||||
|
# KFP-specific metadata (populated from environment if available)
|
||||||
|
kfp_run_id: Optional[str] = field(
|
||||||
|
default_factory=lambda: os.environ.get("KFP_RUN_ID")
|
||||||
|
)
|
||||||
|
kfp_pod_name: Optional[str] = field(
|
||||||
|
default_factory=lambda: os.environ.get("KFP_POD_NAME")
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_tags(self) -> Dict[str, str]:
|
||||||
|
"""Convert metadata to MLflow tags."""
|
||||||
|
tags = {
|
||||||
|
"pipeline.name": self.pipeline_name,
|
||||||
|
"pipeline.run_id": self.run_id,
|
||||||
|
"pipeline.namespace": self.namespace,
|
||||||
|
}
|
||||||
|
if self.run_name:
|
||||||
|
tags["pipeline.run_name"] = self.run_name
|
||||||
|
if self.component_name:
|
||||||
|
tags["pipeline.component"] = self.component_name
|
||||||
|
if self.kfp_run_id:
|
||||||
|
tags["kfp.run_id"] = self.kfp_run_id
|
||||||
|
if self.kfp_pod_name:
|
||||||
|
tags["kfp.pod_name"] = self.kfp_pod_name
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
|
class MLflowTracker:
|
||||||
|
"""
|
||||||
|
MLflow experiment tracker for Kubeflow Pipeline components.
|
||||||
|
|
||||||
|
Example usage in a KFP component:
|
||||||
|
|
||||||
|
from mlflow_utils import MLflowTracker
|
||||||
|
|
||||||
|
tracker = MLflowTracker(
|
||||||
|
experiment_name="document-ingestion",
|
||||||
|
run_name="batch-ingestion-2024-01"
|
||||||
|
)
|
||||||
|
|
||||||
|
with tracker.start_run() as run:
|
||||||
|
tracker.log_params({
|
||||||
|
"chunk_size": 500,
|
||||||
|
"overlap": 50,
|
||||||
|
"embeddings_model": "bge-small-en-v1.5"
|
||||||
|
})
|
||||||
|
|
||||||
|
# ... do work ...
|
||||||
|
|
||||||
|
tracker.log_metrics({
|
||||||
|
"documents_processed": 100,
|
||||||
|
"chunks_created": 2500,
|
||||||
|
"processing_time_seconds": 120.5
|
||||||
|
})
|
||||||
|
|
||||||
|
tracker.log_artifact("/path/to/output.json")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experiment_name: str,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
|
pipeline_metadata: Optional[PipelineMetadata] = None,
|
||||||
|
tags: Optional[Dict[str, str]] = None,
|
||||||
|
tracking_uri: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the MLflow tracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experiment_name: Name of the MLflow experiment
|
||||||
|
run_name: Optional name for this run
|
||||||
|
pipeline_metadata: Metadata about the KFP pipeline
|
||||||
|
tags: Additional tags to add to the run
|
||||||
|
tracking_uri: Override default tracking URI
|
||||||
|
"""
|
||||||
|
self.config = MLflowConfig()
|
||||||
|
self.experiment_name = experiment_name
|
||||||
|
self.run_name = run_name or f"{experiment_name}-{int(time.time())}"
|
||||||
|
self.pipeline_metadata = pipeline_metadata
|
||||||
|
self.user_tags = tags or {}
|
||||||
|
self.tracking_uri = tracking_uri
|
||||||
|
|
||||||
|
self.client: Optional[MlflowClient] = None
|
||||||
|
self.run: Optional[mlflow.ActiveRun] = None
|
||||||
|
self.run_id: Optional[str] = None
|
||||||
|
self._start_time: Optional[float] = None
|
||||||
|
|
||||||
|
def _get_all_tags(self) -> Dict[str, str]:
|
||||||
|
"""Combine all tags for the run."""
|
||||||
|
tags = self.config.default_tags.copy()
|
||||||
|
|
||||||
|
if self.pipeline_metadata:
|
||||||
|
tags.update(self.pipeline_metadata.as_tags())
|
||||||
|
|
||||||
|
tags.update(self.user_tags)
|
||||||
|
return tags
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def start_run(
|
||||||
|
self,
|
||||||
|
nested: bool = False,
|
||||||
|
parent_run_id: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Start an MLflow run as a context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nested: If True, create a nested run under the current active run
|
||||||
|
parent_run_id: Explicit parent run ID for nested runs
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The MLflow run object
|
||||||
|
"""
|
||||||
|
self.client = get_mlflow_client(
|
||||||
|
tracking_uri=self.tracking_uri,
|
||||||
|
configure_global=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure experiment exists
|
||||||
|
experiment_id = ensure_experiment(self.experiment_name)
|
||||||
|
|
||||||
|
self._start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Start the run
|
||||||
|
self.run = mlflow.start_run(
|
||||||
|
experiment_id=experiment_id,
|
||||||
|
run_name=self.run_name,
|
||||||
|
nested=nested,
|
||||||
|
tags=self._get_all_tags(),
|
||||||
|
)
|
||||||
|
self.run_id = self.run.info.run_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Started MLflow run '{self.run_name}' "
|
||||||
|
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.run
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MLflow run failed: {e}")
|
||||||
|
if self.run:
|
||||||
|
mlflow.set_tag("run.status", "failed")
|
||||||
|
mlflow.set_tag("run.error", str(e))
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Log duration
|
||||||
|
if self._start_time:
|
||||||
|
duration = time.time() - self._start_time
|
||||||
|
try:
|
||||||
|
mlflow.log_metric("run_duration_seconds", duration)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# End the run
|
||||||
|
mlflow.end_run()
|
||||||
|
logger.info(f"Ended MLflow run '{self.run_name}'")
|
||||||
|
|
||||||
|
def log_params(self, params: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Log parameters to the current run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Dictionary of parameter names to values
|
||||||
|
"""
|
||||||
|
if not self.run:
|
||||||
|
logger.warning("No active run, skipping log_params")
|
||||||
|
return
|
||||||
|
|
||||||
|
# MLflow has limits on param values, truncate if needed
|
||||||
|
cleaned_params = {}
|
||||||
|
for key, value in params.items():
|
||||||
|
str_value = str(value)
|
||||||
|
if len(str_value) > 500:
|
||||||
|
str_value = str_value[:497] + "..."
|
||||||
|
cleaned_params[key] = str_value
|
||||||
|
|
||||||
|
mlflow.log_params(cleaned_params)
|
||||||
|
logger.debug(f"Logged {len(params)} parameters")
|
||||||
|
|
||||||
|
def log_param(self, key: str, value: Any) -> None:
|
||||||
|
"""Log a single parameter."""
|
||||||
|
self.log_params({key: value})
|
||||||
|
|
||||||
|
def log_metrics(
|
||||||
|
self,
|
||||||
|
metrics: Dict[str, Union[float, int]],
|
||||||
|
step: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log metrics to the current run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics: Dictionary of metric names to values
|
||||||
|
step: Optional step number for time-series metrics
|
||||||
|
"""
|
||||||
|
if not self.run:
|
||||||
|
logger.warning("No active run, skipping log_metrics")
|
||||||
|
return
|
||||||
|
|
||||||
|
mlflow.log_metrics(metrics, step=step)
|
||||||
|
logger.debug(f"Logged {len(metrics)} metrics")
|
||||||
|
|
||||||
|
def log_metric(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Union[float, int],
|
||||||
|
step: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
|
"""Log a single metric."""
|
||||||
|
self.log_metrics({key: value}, step=step)
|
||||||
|
|
||||||
|
def log_artifact(
|
||||||
|
self,
|
||||||
|
local_path: str,
|
||||||
|
artifact_path: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log an artifact file to the current run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_path: Path to the local file to log
|
||||||
|
artifact_path: Optional destination path within the artifact store
|
||||||
|
"""
|
||||||
|
if not self.run:
|
||||||
|
logger.warning("No active run, skipping log_artifact")
|
||||||
|
return
|
||||||
|
|
||||||
|
mlflow.log_artifact(local_path, artifact_path)
|
||||||
|
logger.info(f"Logged artifact: {local_path}")
|
||||||
|
|
||||||
|
def log_artifacts(
|
||||||
|
self,
|
||||||
|
local_dir: str,
|
||||||
|
artifact_path: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log all files in a directory as artifacts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_dir: Path to the local directory
|
||||||
|
artifact_path: Optional destination path within the artifact store
|
||||||
|
"""
|
||||||
|
if not self.run:
|
||||||
|
logger.warning("No active run, skipping log_artifacts")
|
||||||
|
return
|
||||||
|
|
||||||
|
mlflow.log_artifacts(local_dir, artifact_path)
|
||||||
|
logger.info(f"Logged artifacts from: {local_dir}")
|
||||||
|
|
||||||
|
def log_dict(
|
||||||
|
self,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
filename: str,
|
||||||
|
artifact_path: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log a dictionary as a JSON artifact.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary to log
|
||||||
|
filename: Name for the JSON file
|
||||||
|
artifact_path: Optional destination path
|
||||||
|
"""
|
||||||
|
if not self.run:
|
||||||
|
logger.warning("No active run, skipping log_dict")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure .json extension
|
||||||
|
if not filename.endswith(".json"):
|
||||||
|
filename += ".json"
|
||||||
|
|
||||||
|
mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
|
||||||
|
logger.debug(f"Logged dict as: {filename}")
|
||||||
|
|
||||||
|
def log_model_info(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_name: str,
|
||||||
|
model_path: Optional[str] = None,
|
||||||
|
framework: str = "pytorch",
|
||||||
|
extra_info: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log model information as parameters and tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: Type of model (e.g., "llm", "embedding", "stt")
|
||||||
|
model_name: Name/identifier of the model
|
||||||
|
model_path: Path to model weights
|
||||||
|
framework: ML framework used
|
||||||
|
extra_info: Additional model information
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"model.type": model_type,
|
||||||
|
"model.name": model_name,
|
||||||
|
"model.framework": framework,
|
||||||
|
}
|
||||||
|
if model_path:
|
||||||
|
params["model.path"] = model_path
|
||||||
|
if extra_info:
|
||||||
|
for key, value in extra_info.items():
|
||||||
|
params[f"model.{key}"] = value
|
||||||
|
|
||||||
|
self.log_params(params)
|
||||||
|
|
||||||
|
# Also set as tags for easier filtering
|
||||||
|
mlflow.set_tag("model.type", model_type)
|
||||||
|
mlflow.set_tag("model.name", model_name)
|
||||||
|
|
||||||
|
def log_dataset_info(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
source: str,
|
||||||
|
size: Optional[int] = None,
|
||||||
|
extra_info: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log dataset information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Dataset name
|
||||||
|
source: Dataset source (URL, path, etc.)
|
||||||
|
size: Number of samples
|
||||||
|
extra_info: Additional dataset information
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
"dataset.name": name,
|
||||||
|
"dataset.source": source,
|
||||||
|
}
|
||||||
|
if size is not None:
|
||||||
|
params["dataset.size"] = size
|
||||||
|
if extra_info:
|
||||||
|
for key, value in extra_info.items():
|
||||||
|
params[f"dataset.{key}"] = value
|
||||||
|
|
||||||
|
self.log_params(params)
|
||||||
|
|
||||||
|
def set_tag(self, key: str, value: str) -> None:
|
||||||
|
"""Set a single tag on the run."""
|
||||||
|
if self.run:
|
||||||
|
mlflow.set_tag(key, value)
|
||||||
|
|
||||||
|
def set_tags(self, tags: Dict[str, str]) -> None:
|
||||||
|
"""Set multiple tags on the run."""
|
||||||
|
if self.run:
|
||||||
|
mlflow.set_tags(tags)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def artifact_uri(self) -> Optional[str]:
|
||||||
|
"""Get the artifact URI for the current run."""
|
||||||
|
if self.run:
|
||||||
|
return self.run.info.artifact_uri
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def experiment_id(self) -> Optional[str]:
|
||||||
|
"""Get the experiment ID for the current run."""
|
||||||
|
if self.run:
|
||||||
|
return self.run.info.experiment_id
|
||||||
|
return None
|
||||||
19
requirements.txt
Normal file
19
requirements.txt
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# MLflow Utils Module Requirements
|
||||||
|
# Core MLflow
|
||||||
|
mlflow>=2.10.0
|
||||||
|
|
||||||
|
# Database backends
|
||||||
|
psycopg2-binary>=2.9.0 # PostgreSQL (CNPG)
|
||||||
|
boto3>=1.34.0 # S3-compatible artifact storage (optional)
|
||||||
|
|
||||||
|
# For async tracking
|
||||||
|
aiohttp>=3.9.0
|
||||||
|
|
||||||
|
# YAML generation for KServe manifests
|
||||||
|
PyYAML>=6.0
|
||||||
|
|
||||||
|
# Already in chat-handler/voice-assistant requirements:
|
||||||
|
# httpx (for health checks)
|
||||||
|
# Used but typically installed with mlflow:
|
||||||
|
# numpy
|
||||||
|
# pandas
|
||||||
Reference in New Issue
Block a user