feat: Add MLflow integration utilities

- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
commit 2df3f27af7
parent 944bd8bc06
2026-02-01 20:43:13 -05:00

10 changed files with 3315 additions and 1 deletion

README.md

@@ -1,2 +1,129 @@
# MLflow Utils
MLflow integration utilities for the DaviesTechLabs AI/ML platform.
## Installation
```bash
pip install -r requirements.txt
```
Or from Gitea:
```bash
pip install git+https://git.daviestechlabs.io/daviestechlabs/mlflow.git
```
## Modules
| Module | Description |
|--------|-------------|
| `client.py` | MLflow client configuration and helpers |
| `tracker.py` | General MLflowTracker for experiments |
| `inference_tracker.py` | Async inference metrics for NATS handlers |
| `model_registry.py` | Model Registry with KServe metadata |
| `kfp_components.py` | Kubeflow Pipeline MLflow components |
| `experiment_comparison.py` | Compare experiments and runs |
| `cli.py` | Command-line interface |
## Quick Start
```python
from mlflow_utils import get_mlflow_client, MLflowTracker

# Simple tracking
with MLflowTracker(experiment_name="my-experiment") as tracker:
    tracker.log_params({"learning_rate": 0.001})
    tracker.log_metrics({"accuracy": 0.95})
```
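For direct queries against the tracking server, the package also exposes the configured low-level client. A minimal sketch using the exported helpers (the experiment name here is illustrative):

```python
from mlflow_utils import ensure_experiment, get_mlflow_client

client = get_mlflow_client()  # honors MLFLOW_TRACKING_URI, falls back to the in-cluster URI
experiment_id = ensure_experiment("my-experiment")  # creates the experiment if missing

for run in client.search_runs(experiment_ids=[experiment_id], max_results=5):
    print(run.info.run_id, run.data.metrics)
```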
## Inference Tracking
For NATS handlers (chat-handler, voice-assistant):
```python
from mlflow_utils import InferenceMetricsTracker
from mlflow_utils.inference_tracker import InferenceMetrics

tracker = InferenceMetricsTracker(
    service_name="voice-assistant",
    experiment_name="voice-assistant-prod",
    batch_size=100,  # Batch metrics before logging
)
await tracker.start()

# During request handling
metrics = InferenceMetrics(
    request_id="uuid",
    total_latency=1.5,
    llm_latency=0.8,
    input_tokens=150,
    output_tokens=200,
)
await tracker.log_inference(metrics)
```
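`log_inference` only buffers the metrics; batches are flushed by a background task (or immediately once `batch_size` is reached). Call `await tracker.stop()` during shutdown to flush the final batch.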
## Model Registry
Register models with KServe deployment metadata:
```python
from mlflow_utils.model_registry import register_model_for_kserve
register_model_for_kserve(
    model_name="my-qlora-adapter",
    model_uri="runs:/abc123/model",
    kserve_runtime="kserve-vllm",
    gpu_type="amd-strixhalo",
)
```
## Kubeflow Components
Use in KFP pipelines:
```python
from mlflow_utils.kfp_components import (
    create_mlflow_run,
    log_metrics_component,
    end_mlflow_run,
)
```
## CLI
```bash
# List experiments
python -m mlflow_utils.cli list-experiments

# Compare recent runs
python -m mlflow_utils.cli compare --experiment "qlora-training"

# Find the best run by a metric
python -m mlflow_utils.cli best --experiment "qlora-training" --metric eval.accuracy --json
```
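All subcommands honor the global `--json` flag for machine-readable output and `--tracking-uri` to target a different MLflow server; run `python -m mlflow_utils.cli --help` for the full command list.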
## Configuration
| Environment Variable | Default | Description |
|---------------------|---------|-------------|
| `MLFLOW_TRACKING_URI` | `http://mlflow.mlflow.svc.cluster.local:80` | MLflow server |
| `MLFLOW_DEFAULT_EXPERIMENT` | `llm-workflows` | Default experiment |
| `MLFLOW_ENABLE_ASYNC` | `true` | Async logging for handlers |
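For example, to run the CLI against the external endpoint from outside the cluster (URI taken from `client.py`; adjust for your environment):

```bash
export MLFLOW_TRACKING_URI="https://mlflow.lab.daviestechlabs.io"
python -m mlflow_utils.cli health
```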
## Module Structure
```
mlflow_utils/
├── __init__.py # Public API
├── client.py # Connection management
├── tracker.py # General experiment tracker
├── inference_tracker.py # Async inference metrics
├── model_registry.py # Model registration + KServe
├── kfp_components.py # Kubeflow components
├── experiment_comparison.py # Run comparison tools
└── cli.py # Command-line interface
```
## Related
- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Uses inference tracker
- [kubeflow](https://git.daviestechlabs.io/daviestechlabs/kubeflow) - KFP components
- [argo](https://git.daviestechlabs.io/daviestechlabs/argo) - Training workflows
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs

mlflow_utils/__init__.py

@@ -0,0 +1,40 @@
"""
MLflow Integration Utilities for LLM Workflows
This module provides MLflow integration for:
- Kubeflow Pipelines experiment tracking
- Model Registry with KServe deployment metadata
- Inference metrics logging from NATS handlers
- Experiment comparison and analysis
Configuration:
Set MLFLOW_TRACKING_URI environment variable or use defaults:
- In-cluster: http://mlflow.mlflow.svc.cluster.local:80
- External: https://mlflow.lab.daviestechlabs.io
Usage:
from mlflow_utils import get_mlflow_client, MLflowTracker
from mlflow_utils.kfp_components import log_experiment_component
from mlflow_utils.model_registry import register_model_for_kserve
from mlflow_utils.inference_tracker import InferenceMetricsTracker
"""
from .client import (
get_mlflow_client,
get_tracking_uri,
ensure_experiment,
MLflowConfig,
)
from .tracker import MLflowTracker
from .inference_tracker import InferenceMetricsTracker
__all__ = [
"get_mlflow_client",
"get_tracking_uri",
"ensure_experiment",
"MLflowConfig",
"MLflowTracker",
"InferenceMetricsTracker",
]
__version__ = "1.0.0"

mlflow_utils/cli.py

@@ -0,0 +1,371 @@
#!/usr/bin/env python3
"""
MLflow Experiment CLI
Command-line interface for querying and comparing MLflow experiments.
Usage:
# Compare recent runs in an experiment
python -m mlflow_utils.cli compare --experiment chat-inference --runs 5
# Get best run by metric
python -m mlflow_utils.cli best --experiment evaluation --metric eval.accuracy
# Generate performance report
python -m mlflow_utils.cli report --service chat-handler --hours 24
# Check model promotion criteria
python -m mlflow_utils.cli promote --model whisper-finetuned \\
--experiment voice-evaluation \\
--criteria "eval.accuracy>=0.9,total_latency_p95<=2.0"
# List experiments
python -m mlflow_utils.cli list-experiments
# Query runs
python -m mlflow_utils.cli query --experiment chat-inference \\
--filter "metrics.total_latency_mean < 1.0" --limit 10
"""
import argparse
import json
import sys
from typing import Optional
from .client import get_mlflow_client, health_check
from .experiment_comparison import (
ExperimentAnalyzer,
compare_experiments,
promotion_recommendation,
get_inference_performance_report,
)
from .model_registry import (
list_model_versions,
get_production_model,
generate_kserve_yaml,
)
def cmd_health(args):
"""Check MLflow connectivity."""
result = health_check()
print(json.dumps(result, indent=2))
sys.exit(0 if result["connected"] else 1)
def cmd_list_experiments(args):
"""List all experiments."""
client = get_mlflow_client(tracking_uri=args.tracking_uri)
experiments = client.search_experiments()
print(f"{'ID':<10} {'Name':<40} {'Artifact Location'}")
print("-" * 80)
for exp in experiments:
print(f"{exp.experiment_id:<10} {exp.name:<40} {exp.artifact_location}")
def cmd_compare(args):
"""Compare recent runs in an experiment."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
if args.run_ids:
run_ids = args.run_ids.split(",")
comparison = analyzer.compare_runs(run_ids=run_ids)
else:
comparison = analyzer.compare_runs(n_recent=args.runs)
if args.json:
print(json.dumps(comparison.to_dict(), indent=2, default=str))
else:
print(comparison.summary_table())
def cmd_best(args):
"""Find the best run by a metric."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
best_run = analyzer.get_best_run(
metric=args.metric,
minimize=args.minimize,
filter_string=args.filter or "",
)
if not best_run:
print(f"No runs found with metric '{args.metric}'")
sys.exit(1)
result = {
"run_id": best_run.info.run_id,
"run_name": best_run.info.run_name,
"metric_value": best_run.data.metrics.get(args.metric),
"all_metrics": dict(best_run.data.metrics),
"params": dict(best_run.data.params),
}
if args.json:
print(json.dumps(result, indent=2))
else:
print(f"Best Run: {best_run.info.run_name or best_run.info.run_id}")
print(f" {args.metric}: {result['metric_value']}")
print(f" Run ID: {best_run.info.run_id}")
def cmd_summary(args):
"""Get metrics summary for an experiment."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
summary = analyzer.get_metrics_summary(
hours=args.hours,
metrics=args.metrics.split(",") if args.metrics else None,
)
if args.json:
print(json.dumps(summary, indent=2))
else:
print(f"Metrics Summary for '{args.experiment}' (last {args.hours}h)")
print("=" * 60)
for metric, stats in sorted(summary.items()):
print(f"\n{metric}:")
for stat, value in stats.items():
print(f" {stat}: {value:.4f}")
def cmd_report(args):
"""Generate an inference performance report."""
report = get_inference_performance_report(
service_name=args.service,
hours=args.hours,
tracking_uri=args.tracking_uri,
)
if args.json:
print(json.dumps(report, indent=2))
else:
print(f"Performance Report: {report['service']}")
print(f"Period: Last {report['period_hours']} hours")
print(f"Generated: {report['generated_at']}")
print()
if report["latency"]:
print("Latency Metrics:")
for metric, stats in report["latency"].items():
if "mean" in stats:
print(f" {metric}: {stats['mean']:.4f}s (p50: {stats.get('median', 'N/A')})")
if report["rag"]:
print("\nRAG Usage:")
for metric, stats in report["rag"].items():
print(f" {metric}: {stats.get('mean', 'N/A')}")
if report["errors"]:
print("\nError Rates:")
for metric, stats in report["errors"].items():
print(f" {metric}: {stats:.2f}%")
def cmd_promote(args):
"""Check model promotion criteria."""
# Parse criteria
criteria = {}
for criterion in args.criteria.split(","):
# Parse "metric>=value" or "metric<=value" etc.
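        # Try two-character operators first so ">=" is not consumed by a bare ">"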
for op in [">=", "<=", ">", "<"]:
if op in criterion:
metric, value = criterion.split(op)
criteria[metric.strip()] = (op, float(value.strip()))
break
rec = promotion_recommendation(
model_name=args.model,
experiment_name=args.experiment,
criteria=criteria,
tracking_uri=args.tracking_uri,
)
if args.json:
print(json.dumps(rec.to_dict(), indent=2))
else:
status = "✓ RECOMMENDED" if rec.recommended else "✗ NOT RECOMMENDED"
print(f"Model: {args.model}")
print(f"Status: {status}")
print("\nCriteria Evaluation:")
for reason in rec.reasons:
print(f" {reason}")
def cmd_query(args):
"""Query runs with a filter."""
analyzer = ExperimentAnalyzer(
args.experiment,
tracking_uri=args.tracking_uri
)
runs = analyzer.search_runs(
filter_string=args.filter or "",
max_results=args.limit,
)
if args.json:
result = [
{
"run_id": r.info.run_id,
"run_name": r.info.run_name,
"status": r.info.status,
"metrics": dict(r.data.metrics),
"params": dict(r.data.params),
}
for r in runs
]
print(json.dumps(result, indent=2))
else:
print(f"Found {len(runs)} runs")
for run in runs:
print(f"\n{run.info.run_name or run.info.run_id}")
print(f" ID: {run.info.run_id}")
print(f" Status: {run.info.status}")
def cmd_models(args):
"""List registered models."""
client = get_mlflow_client(tracking_uri=args.tracking_uri)
if args.model:
versions = list_model_versions(args.model, tracking_uri=args.tracking_uri)
if args.json:
print(json.dumps(versions, indent=2, default=str))
else:
print(f"Model: {args.model}")
for v in versions:
print(f" v{v['version']} ({v['stage']}): {v['description'][:50] if v['description'] else 'No description'}")
else:
# List all models
models = client.search_registered_models()
if args.json:
result = [{"name": m.name, "description": m.description} for m in models]
print(json.dumps(result, indent=2))
else:
print(f"{'Model Name':<40} Description")
print("-" * 80)
for model in models:
desc = (model.description or "")[:35]
print(f"{model.name:<40} {desc}")
def cmd_kserve(args):
"""Generate KServe manifest for a model."""
yaml_str = generate_kserve_yaml(
model_name=args.model,
version=args.version,
namespace=args.namespace,
output_path=args.output,
tracking_uri=args.tracking_uri,
)
if not args.output:
print(yaml_str)
def main():
parser = argparse.ArgumentParser(
description="MLflow Experiment CLI",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--tracking-uri",
default=None,
help="MLflow tracking URI (default: from env or in-cluster)",
)
parser.add_argument(
"--json",
action="store_true",
help="Output as JSON",
)
subparsers = parser.add_subparsers(dest="command", help="Commands")
# health
health_parser = subparsers.add_parser("health", help="Check MLflow connectivity")
health_parser.set_defaults(func=cmd_health)
# list-experiments
list_parser = subparsers.add_parser("list-experiments", help="List experiments")
list_parser.set_defaults(func=cmd_list_experiments)
# compare
compare_parser = subparsers.add_parser("compare", help="Compare runs")
compare_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
compare_parser.add_argument("--runs", "-n", type=int, default=5, help="Number of recent runs")
compare_parser.add_argument("--run-ids", help="Comma-separated run IDs to compare")
compare_parser.set_defaults(func=cmd_compare)
# best
best_parser = subparsers.add_parser("best", help="Find best run by metric")
best_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
best_parser.add_argument("--metric", "-m", required=True, help="Metric to optimize")
best_parser.add_argument("--minimize", action="store_true", help="Minimize metric (default: maximize)")
best_parser.add_argument("--filter", "-f", help="Filter string")
best_parser.set_defaults(func=cmd_best)
# summary
summary_parser = subparsers.add_parser("summary", help="Get metrics summary")
summary_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
summary_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
summary_parser.add_argument("--metrics", help="Comma-separated metric names")
summary_parser.set_defaults(func=cmd_summary)
# report
report_parser = subparsers.add_parser("report", help="Generate performance report")
report_parser.add_argument("--service", "-s", required=True, help="Service name")
report_parser.add_argument("--hours", type=int, default=24, help="Hours of data")
report_parser.set_defaults(func=cmd_report)
# promote
promote_parser = subparsers.add_parser("promote", help="Check promotion criteria")
promote_parser.add_argument("--model", "-m", required=True, help="Model name")
promote_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
promote_parser.add_argument("--criteria", "-c", required=True, help="Criteria (e.g., 'accuracy>=0.9,latency<=2.0')")
promote_parser.set_defaults(func=cmd_promote)
# query
query_parser = subparsers.add_parser("query", help="Query runs")
query_parser.add_argument("--experiment", "-e", required=True, help="Experiment name")
query_parser.add_argument("--filter", "-f", help="MLflow filter string")
query_parser.add_argument("--limit", "-l", type=int, default=20, help="Max results")
query_parser.set_defaults(func=cmd_query)
# models
models_parser = subparsers.add_parser("models", help="List registered models")
models_parser.add_argument("--model", "-m", help="Specific model name")
models_parser.set_defaults(func=cmd_models)
# kserve
kserve_parser = subparsers.add_parser("kserve", help="Generate KServe manifest")
kserve_parser.add_argument("--model", "-m", required=True, help="Model name")
kserve_parser.add_argument("--version", "-v", type=int, help="Model version")
kserve_parser.add_argument("--namespace", "-n", default="ai-ml", help="K8s namespace")
kserve_parser.add_argument("--output", "-o", help="Output file path")
kserve_parser.set_defaults(func=cmd_kserve)
args = parser.parse_args()
if not args.command:
parser.print_help()
sys.exit(1)
args.func(args)
if __name__ == "__main__":
main()

mlflow_utils/client.py

@@ -0,0 +1,209 @@
"""
MLflow Client Configuration and Initialization
Provides a configured MLflow client for all integrations in the LLM workflows.
Supports both in-cluster and external access patterns.
"""
import os
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
import mlflow
from mlflow.tracking import MlflowClient
logger = logging.getLogger(__name__)
@dataclass
class MLflowConfig:
"""Configuration for MLflow integration."""
# Tracking server URIs
tracking_uri: str = field(
default_factory=lambda: os.environ.get(
"MLFLOW_TRACKING_URI",
"http://mlflow.mlflow.svc.cluster.local:80"
)
)
external_uri: str = field(
default_factory=lambda: os.environ.get(
"MLFLOW_EXTERNAL_URI",
"https://mlflow.lab.daviestechlabs.io"
)
)
# Artifact storage (NFS PVC mount)
artifact_location: str = field(
default_factory=lambda: os.environ.get(
"MLFLOW_ARTIFACT_LOCATION",
"/mlflow/artifacts"
)
)
# Default experiment settings
default_experiment: str = field(
default_factory=lambda: os.environ.get(
"MLFLOW_DEFAULT_EXPERIMENT",
"llm-workflows"
)
)
# Service identification
service_name: str = field(
default_factory=lambda: os.environ.get(
"OTEL_SERVICE_NAME",
"unknown-service"
)
)
# Additional tags to add to all runs
default_tags: Dict[str, str] = field(default_factory=dict)
def __post_init__(self):
"""Add default tags based on environment."""
env_tags = {
"environment": os.environ.get("DEPLOYMENT_ENV", "production"),
"hostname": os.environ.get("HOSTNAME", "unknown"),
"namespace": os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml"),
}
self.default_tags = {**env_tags, **self.default_tags}
def get_tracking_uri(external: bool = False) -> str:
"""
Get the appropriate MLflow tracking URI.
Args:
external: If True, return the external URI for outside-cluster access
Returns:
The MLflow tracking URI string
"""
config = MLflowConfig()
return config.external_uri if external else config.tracking_uri
def get_mlflow_client(
tracking_uri: Optional[str] = None,
configure_global: bool = True
) -> MlflowClient:
"""
Get a configured MLflow client.
Args:
tracking_uri: Override the default tracking URI
configure_global: If True, also set mlflow.set_tracking_uri()
Returns:
Configured MlflowClient instance
"""
uri = tracking_uri or get_tracking_uri()
if configure_global:
mlflow.set_tracking_uri(uri)
logger.info(f"MLflow tracking URI set to: {uri}")
client = MlflowClient(tracking_uri=uri)
return client
def ensure_experiment(
experiment_name: str,
artifact_location: Optional[str] = None,
tags: Optional[Dict[str, str]] = None
) -> str:
"""
Ensure an experiment exists, creating it if necessary.
Args:
experiment_name: Name of the experiment
artifact_location: Override default artifact location
tags: Additional tags for the experiment
Returns:
The experiment ID
"""
config = MLflowConfig()
client = get_mlflow_client()
# Check if experiment exists
experiment = client.get_experiment_by_name(experiment_name)
if experiment is None:
# Create the experiment
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
experiment_id = client.create_experiment(
name=experiment_name,
artifact_location=artifact_loc,
tags=tags or {}
)
logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}")
else:
experiment_id = experiment.experiment_id
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
return experiment_id
def get_or_create_registered_model(
model_name: str,
description: Optional[str] = None,
tags: Optional[Dict[str, str]] = None
) -> str:
"""
Get or create a registered model in the Model Registry.
Args:
model_name: Name of the model to register
description: Model description
tags: Tags for the model
Returns:
The registered model name
"""
client = get_mlflow_client()
try:
# Check if model exists
client.get_registered_model(model_name)
logger.debug(f"Using existing registered model: {model_name}")
except mlflow.exceptions.MlflowException:
# Create the model
client.create_registered_model(
name=model_name,
description=description or f"Model for {model_name}",
tags=tags or {}
)
logger.info(f"Created registered model: {model_name}")
return model_name
def health_check() -> Dict[str, Any]:
"""
Check MLflow server connectivity and return status.
Returns:
Dictionary with health status information
"""
config = MLflowConfig()
result = {
"tracking_uri": config.tracking_uri,
"external_uri": config.external_uri,
"connected": False,
"error": None,
}
try:
client = get_mlflow_client(configure_global=False)
# Try to list experiments as a health check
experiments = client.search_experiments(max_results=1)
result["connected"] = True
result["experiment_count"] = len(experiments)
except Exception as e:
result["error"] = str(e)
logger.error(f"MLflow health check failed: {e}")
return result

mlflow_utils/experiment_comparison.py

@@ -0,0 +1,664 @@
"""
Experiment Comparison and Analysis Utilities
Provides tools for comparing model versions, querying experiments,
and making data-driven decisions about model promotion to production.
Features:
- Compare multiple runs/experiments side by side
- Query experiments by tags, metrics, or parameters
- Analyze inference metrics from NATS handlers
- Generate promotion recommendations
- Export comparison reports
Usage:
    from mlflow_utils.experiment_comparison import (
        ExperimentAnalyzer,
        compare_experiments,
        promotion_recommendation,
    )
    analyzer = ExperimentAnalyzer("chat-inference")
    # Compare last N runs
    comparison = analyzer.compare_runs(n_recent=5)
    # Find best performing model
    best = analyzer.get_best_run(metric="total_latency_mean", minimize=True)
    # Get promotion recommendation
    rec = promotion_recommendation(
        model_name="whisper-finetuned",
        experiment_name="voice-evaluation",
        criteria={"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)},
    )
"""
import os
import json
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List, Tuple, Union
from dataclasses import dataclass, field
from collections import defaultdict
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import Run, Experiment, ViewType
from .client import get_mlflow_client, MLflowConfig
logger = logging.getLogger(__name__)
@dataclass
class RunComparison:
"""Comparison result for multiple MLflow runs."""
run_ids: List[str]
experiment_name: str
# Metric comparisons (metric_name -> {run_id -> value})
metrics: Dict[str, Dict[str, float]] = field(default_factory=dict)
# Parameter differences
params: Dict[str, Dict[str, str]] = field(default_factory=dict)
# Run metadata
run_names: Dict[str, str] = field(default_factory=dict)
start_times: Dict[str, datetime] = field(default_factory=dict)
durations: Dict[str, float] = field(default_factory=dict)
# Best performers by metric
best_by_metric: Dict[str, str] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"run_ids": self.run_ids,
"experiment_name": self.experiment_name,
"metrics": self.metrics,
"params": self.params,
"run_names": self.run_names,
"best_by_metric": self.best_by_metric,
}
def summary_table(self) -> str:
"""Generate a text summary table of the comparison."""
if not self.run_ids:
return "No runs to compare"
lines = []
lines.append(f"Experiment: {self.experiment_name}")
lines.append(f"Comparing {len(self.run_ids)} runs")
lines.append("")
# Header
header = ["Metric"] + [self.run_names.get(rid, rid[:8]) for rid in self.run_ids]
lines.append(" | ".join(header))
lines.append("-" * (len(lines[-1]) + 10))
# Metrics
for metric_name, values in sorted(self.metrics.items()):
row = [metric_name]
for run_id in self.run_ids:
value = values.get(run_id)
if value is not None:
row.append(f"{value:.4f}")
else:
row.append("N/A")
lines.append(" | ".join(row))
return "\n".join(lines)
@dataclass
class PromotionRecommendation:
"""Recommendation for model promotion."""
model_name: str
version: Optional[int]
recommended: bool
reasons: List[str]
metrics_summary: Dict[str, float]
comparison_with_production: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
return {
"model_name": self.model_name,
"version": self.version,
"recommended": self.recommended,
"reasons": self.reasons,
"metrics_summary": self.metrics_summary,
"comparison_with_production": self.comparison_with_production,
}
class ExperimentAnalyzer:
"""
Analyze MLflow experiments for model comparison and promotion decisions.
Example:
analyzer = ExperimentAnalyzer("chat-inference")
# Get metrics summary for last 24 hours
summary = analyzer.get_metrics_summary(hours=24)
# Compare models by accuracy
best = analyzer.get_best_run(metric="eval.accuracy", minimize=False)
# Analyze inference latency trends
trends = analyzer.get_metric_trends("total_latency_mean", days=7)
"""
def __init__(
self,
experiment_name: str,
tracking_uri: Optional[str] = None,
):
"""
Initialize the experiment analyzer.
Args:
experiment_name: Name of the MLflow experiment to analyze
tracking_uri: Override default tracking URI
"""
self.experiment_name = experiment_name
self.tracking_uri = tracking_uri
self.client = get_mlflow_client(tracking_uri=tracking_uri)
self._experiment: Optional[Experiment] = None
@property
def experiment(self) -> Optional[Experiment]:
"""Get the experiment object, fetching if needed."""
if self._experiment is None:
self._experiment = self.client.get_experiment_by_name(self.experiment_name)
return self._experiment
def search_runs(
self,
filter_string: str = "",
order_by: Optional[List[str]] = None,
max_results: int = 100,
run_view_type: str = "ACTIVE_ONLY",
) -> List[Run]:
"""
Search for runs matching criteria.
Args:
filter_string: MLflow search filter (e.g., "metrics.accuracy > 0.9")
order_by: List of order clauses (e.g., ["metrics.accuracy DESC"])
max_results: Maximum runs to return
run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL
Returns:
List of matching Run objects
"""
if not self.experiment:
logger.warning(f"Experiment '{self.experiment_name}' not found")
return []
        runs = self.client.search_runs(
            experiment_ids=[self.experiment.experiment_id],
            filter_string=filter_string,
            run_view_type=getattr(ViewType, run_view_type, ViewType.ACTIVE_ONLY),
            order_by=order_by or ["start_time DESC"],
            max_results=max_results,
        )
return runs
def get_recent_runs(
self,
n: int = 10,
hours: Optional[int] = None,
) -> List[Run]:
"""
Get the most recent runs.
Args:
n: Number of runs to return
hours: Only include runs from the last N hours
Returns:
List of Run objects
"""
filter_string = ""
if hours:
cutoff = datetime.now() - timedelta(hours=hours)
cutoff_ms = int(cutoff.timestamp() * 1000)
filter_string = f"attributes.start_time >= {cutoff_ms}"
return self.search_runs(
filter_string=filter_string,
order_by=["start_time DESC"],
max_results=n,
)
def compare_runs(
self,
run_ids: Optional[List[str]] = None,
n_recent: int = 5,
) -> RunComparison:
"""
Compare multiple runs side by side.
Args:
run_ids: Specific run IDs to compare, or None for recent runs
n_recent: If run_ids is None, compare this many recent runs
Returns:
RunComparison object with detailed comparison
"""
if run_ids:
runs = [self.client.get_run(rid) for rid in run_ids]
else:
runs = self.get_recent_runs(n=n_recent)
comparison = RunComparison(
run_ids=[r.info.run_id for r in runs],
experiment_name=self.experiment_name,
)
# Collect all metrics and find best performers
all_metrics: Dict[str, Dict[str, float]] = defaultdict(dict)
for run in runs:
run_id = run.info.run_id
# Metadata
comparison.run_names[run_id] = run.info.run_name or run_id[:8]
comparison.start_times[run_id] = datetime.fromtimestamp(
run.info.start_time / 1000
)
if run.info.end_time:
comparison.durations[run_id] = (
run.info.end_time - run.info.start_time
) / 1000
# Metrics
for key, value in run.data.metrics.items():
all_metrics[key][run_id] = value
# Params
for key, value in run.data.params.items():
if key not in comparison.params:
comparison.params[key] = {}
comparison.params[key][run_id] = value
comparison.metrics = dict(all_metrics)
# Find best performers for each metric
for metric_name, values in all_metrics.items():
if not values:
continue
# Determine if lower is better based on metric name
minimize = any(
term in metric_name.lower()
for term in ["latency", "error", "loss", "time"]
)
if minimize:
best_id = min(values.keys(), key=lambda k: values[k])
else:
best_id = max(values.keys(), key=lambda k: values[k])
comparison.best_by_metric[metric_name] = best_id
return comparison
def get_best_run(
self,
metric: str,
minimize: bool = True,
filter_string: str = "",
max_results: int = 100,
) -> Optional[Run]:
"""
Get the best run by a specific metric.
Args:
metric: Metric name to optimize
minimize: If True, find minimum; if False, find maximum
filter_string: Additional filter criteria
max_results: Maximum runs to consider
Returns:
Best Run object, or None if no runs found
"""
direction = "ASC" if minimize else "DESC"
runs = self.search_runs(
filter_string=filter_string,
order_by=[f"metrics.{metric} {direction}"],
max_results=max_results,
)
# Filter to only runs that have the metric
runs_with_metric = [
r for r in runs
if metric in r.data.metrics
]
return runs_with_metric[0] if runs_with_metric else None
def get_metrics_summary(
self,
hours: Optional[int] = None,
metrics: Optional[List[str]] = None,
) -> Dict[str, Dict[str, float]]:
"""
Get summary statistics for metrics.
Args:
hours: Only include runs from the last N hours
metrics: Specific metrics to summarize (None for all)
Returns:
Dict mapping metric names to {mean, min, max, count}
"""
import statistics
runs = self.get_recent_runs(n=1000, hours=hours)
# Collect all metric values
metric_values: Dict[str, List[float]] = defaultdict(list)
for run in runs:
for key, value in run.data.metrics.items():
if metrics is None or key in metrics:
metric_values[key].append(value)
# Calculate statistics
summary = {}
for metric_name, values in metric_values.items():
if not values:
continue
summary[metric_name] = {
"mean": statistics.mean(values),
"min": min(values),
"max": max(values),
"count": len(values),
}
if len(values) >= 2:
summary[metric_name]["stdev"] = statistics.stdev(values)
summary[metric_name]["median"] = statistics.median(values)
return summary
def get_metric_trends(
self,
metric: str,
days: int = 7,
granularity_hours: int = 1,
) -> List[Dict[str, Any]]:
"""
Get metric trends over time.
Args:
metric: Metric name to track
days: Number of days to look back
granularity_hours: Time bucket size in hours
Returns:
List of {timestamp, mean, min, max, count} dicts
"""
import statistics
runs = self.get_recent_runs(n=10000, hours=days * 24)
# Group runs by time bucket
buckets: Dict[int, List[float]] = defaultdict(list)
bucket_size_ms = granularity_hours * 3600 * 1000
for run in runs:
if metric not in run.data.metrics:
continue
bucket = (run.info.start_time // bucket_size_ms) * bucket_size_ms
buckets[bucket].append(run.data.metrics[metric])
# Calculate statistics per bucket
trends = []
for bucket_ts, values in sorted(buckets.items()):
trend = {
"timestamp": datetime.fromtimestamp(bucket_ts / 1000).isoformat(),
"count": len(values),
"mean": statistics.mean(values),
"min": min(values),
"max": max(values),
}
if len(values) >= 2:
trend["stdev"] = statistics.stdev(values)
trends.append(trend)
return trends
def get_runs_by_tag(
self,
tag_key: str,
tag_value: str,
max_results: int = 100,
) -> List[Run]:
"""
Get runs with a specific tag.
Args:
tag_key: Tag key to filter by
tag_value: Tag value to match
max_results: Maximum runs to return
Returns:
List of matching Run objects
"""
return self.search_runs(
filter_string=f"tags.{tag_key} = '{tag_value}'",
max_results=max_results,
)
def get_model_runs(
self,
model_name: str,
max_results: int = 100,
) -> List[Run]:
"""
Get runs for a specific model.
Args:
model_name: Model name to filter by
max_results: Maximum runs to return
Returns:
List of matching Run objects
"""
# Try different tag conventions
runs = self.search_runs(
filter_string=f"tags.`model.name` = '{model_name}'",
max_results=max_results,
)
if not runs:
# Try params
runs = self.search_runs(
filter_string=f"params.model_name = '{model_name}'",
max_results=max_results,
)
return runs
def compare_experiments(
experiment_names: List[str],
metric: str,
tracking_uri: Optional[str] = None,
) -> Dict[str, Dict[str, float]]:
"""
Compare metrics across multiple experiments.
Args:
experiment_names: Names of experiments to compare
metric: Metric to compare
tracking_uri: Override default tracking URI
Returns:
Dict mapping experiment names to metric statistics
"""
results = {}
for exp_name in experiment_names:
analyzer = ExperimentAnalyzer(exp_name, tracking_uri=tracking_uri)
summary = analyzer.get_metrics_summary(metrics=[metric])
if metric in summary:
results[exp_name] = summary[metric]
return results
def promotion_recommendation(
model_name: str,
experiment_name: str,
criteria: Dict[str, Tuple[str, float]],
tracking_uri: Optional[str] = None,
) -> PromotionRecommendation:
"""
Generate a recommendation for model promotion.
Args:
model_name: Name of the model to evaluate
experiment_name: Experiment containing evaluation runs
criteria: Dict of {metric: (comparison, threshold)}
comparison is one of: ">=", "<=", ">", "<"
e.g., {"eval.accuracy": (">=", 0.9), "total_latency_p95": ("<=", 2.0)}
tracking_uri: Override default tracking URI
Returns:
PromotionRecommendation with decision and reasons
"""
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
# Get model runs
runs = analyzer.get_model_runs(model_name, max_results=10)
if not runs:
return PromotionRecommendation(
model_name=model_name,
version=None,
recommended=False,
reasons=["No runs found for this model"],
metrics_summary={},
)
# Get the most recent run
latest_run = runs[0]
metrics = latest_run.data.metrics
# Evaluate criteria
reasons = []
passed = True
comparisons = {
">=": lambda a, b: a >= b,
"<=": lambda a, b: a <= b,
">": lambda a, b: a > b,
"<": lambda a, b: a < b,
}
for metric_name, (comparison, threshold) in criteria.items():
if metric_name not in metrics:
reasons.append(f"Metric '{metric_name}' not found")
passed = False
continue
value = metrics[metric_name]
compare_fn = comparisons.get(comparison)
if compare_fn is None:
reasons.append(f"Invalid comparison operator: {comparison}")
continue
if compare_fn(value, threshold):
reasons.append(f"{metric_name}: {value:.4f} {comparison} {threshold}")
else:
reasons.append(f"{metric_name}: {value:.4f} NOT {comparison} {threshold}")
passed = False
# Extract version from tags if available
version = None
if "mlflow.version" in latest_run.data.tags:
try:
version = int(latest_run.data.tags["mlflow.version"])
except ValueError:
pass
return PromotionRecommendation(
model_name=model_name,
version=version,
recommended=passed,
reasons=reasons,
metrics_summary=dict(metrics),
)
def get_inference_performance_report(
service_name: str = "chat-handler",
hours: int = 24,
tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
"""
Generate an inference performance report for a service.
Args:
service_name: Service name (chat-handler, voice-assistant)
hours: Hours of data to analyze
tracking_uri: Override default tracking URI
Returns:
Performance report dictionary
"""
experiment_name = f"{service_name.replace('-', '')}-inference"
analyzer = ExperimentAnalyzer(experiment_name, tracking_uri=tracking_uri)
# Get summary metrics
summary = analyzer.get_metrics_summary(hours=hours)
# Key latency metrics
latency_metrics = [
"total_latency_mean",
"total_latency_p50",
"total_latency_p95",
"llm_latency_mean",
"embedding_latency_mean",
"rag_search_latency_mean",
]
report = {
"service": service_name,
"period_hours": hours,
"generated_at": datetime.now().isoformat(),
"latency": {},
"throughput": {},
"rag": {},
"errors": {},
}
# Latency section
for metric in latency_metrics:
if metric in summary:
report["latency"][metric] = summary[metric]
# Throughput
if "total_requests" in summary:
report["throughput"]["total_requests"] = summary["total_requests"]["mean"]
# RAG usage
rag_metrics = ["rag_enabled_pct", "rag_documents_retrieved_mean", "rag_documents_used_mean"]
for metric in rag_metrics:
if metric in summary:
report["rag"][metric] = summary[metric]
# Error rate
if "error_rate" in summary:
report["errors"]["error_rate_pct"] = summary["error_rate"]["mean"]
return report

mlflow_utils/inference_tracker.py

@@ -0,0 +1,431 @@
"""
Inference Metrics Tracker for NATS Handlers
Provides async-compatible MLflow logging for real-time inference
metrics from chat-handler and voice-assistant services.
Designed to integrate with the existing OpenTelemetry setup and
complement OTel metrics with MLflow experiment tracking for
longer-term analysis and model comparison.
"""
import os
import time
import asyncio
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import mlflow
from mlflow.tracking import MlflowClient
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
logger = logging.getLogger(__name__)
@dataclass
class InferenceMetrics:
"""Metrics collected during an inference request."""
request_id: str
user_id: Optional[str] = None
session_id: Optional[str] = None
# Timing metrics (in seconds)
total_latency: float = 0.0
embedding_latency: float = 0.0
rag_search_latency: float = 0.0
rerank_latency: float = 0.0
llm_latency: float = 0.0
tts_latency: float = 0.0
stt_latency: float = 0.0
# Token/size metrics
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
prompt_length: int = 0
response_length: int = 0
# RAG metrics
rag_enabled: bool = False
rag_documents_retrieved: int = 0
rag_documents_used: int = 0
reranker_enabled: bool = False
# Quality indicators
is_streaming: bool = False
is_premium: bool = False
has_error: bool = False
error_message: Optional[str] = None
# Model information
model_name: Optional[str] = None
model_endpoint: Optional[str] = None
# Timestamps
timestamp: float = field(default_factory=time.time)
def as_metrics_dict(self) -> Dict[str, float]:
"""Convert numeric fields to a metrics dictionary."""
return {
"total_latency": self.total_latency,
"embedding_latency": self.embedding_latency,
"rag_search_latency": self.rag_search_latency,
"rerank_latency": self.rerank_latency,
"llm_latency": self.llm_latency,
"tts_latency": self.tts_latency,
"stt_latency": self.stt_latency,
"input_tokens": float(self.input_tokens),
"output_tokens": float(self.output_tokens),
"total_tokens": float(self.total_tokens),
"prompt_length": float(self.prompt_length),
"response_length": float(self.response_length),
"rag_documents_retrieved": float(self.rag_documents_retrieved),
"rag_documents_used": float(self.rag_documents_used),
}
def as_params_dict(self) -> Dict[str, str]:
"""Convert configuration fields to a params dictionary."""
params = {
"rag_enabled": str(self.rag_enabled),
"reranker_enabled": str(self.reranker_enabled),
"is_streaming": str(self.is_streaming),
"is_premium": str(self.is_premium),
}
if self.model_name:
params["model_name"] = self.model_name
if self.model_endpoint:
params["model_endpoint"] = self.model_endpoint
return params
class InferenceMetricsTracker:
"""
Async-compatible MLflow tracker for inference metrics.
Uses batching and a background thread pool to avoid blocking
the async event loop during MLflow calls.
Example usage in chat-handler:
class ChatHandler:
def __init__(self):
self.mlflow_tracker = InferenceMetricsTracker(
service_name="chat-handler",
experiment_name="chat-inference"
)
async def setup(self):
await self.mlflow_tracker.start()
async def process_request(self, msg):
metrics = InferenceMetrics(request_id=request_id)
# Track timing
start = time.time()
# ... do embedding ...
metrics.embedding_latency = time.time() - start
# ... more processing ...
# Log metrics (non-blocking)
await self.mlflow_tracker.log_inference(metrics)
async def shutdown(self):
await self.mlflow_tracker.stop()
"""
def __init__(
self,
service_name: str,
experiment_name: Optional[str] = None,
tracking_uri: Optional[str] = None,
batch_size: int = 50,
flush_interval_seconds: float = 60.0,
enable_batching: bool = True,
max_workers: int = 2,
):
"""
Initialize the inference metrics tracker.
Args:
service_name: Name of the service (e.g., "chat-handler")
experiment_name: MLflow experiment name (defaults to service_name)
tracking_uri: Override default tracking URI
batch_size: Number of metrics to batch before flushing
flush_interval_seconds: Maximum time between flushes
enable_batching: If False, log each request immediately
max_workers: Number of thread pool workers for MLflow calls
"""
self.service_name = service_name
self.experiment_name = experiment_name or f"{service_name}-inference"
self.tracking_uri = tracking_uri
self.batch_size = batch_size
self.flush_interval = flush_interval_seconds
self.enable_batching = enable_batching
self.config = MLflowConfig()
self._batch: List[InferenceMetrics] = []
self._batch_lock = asyncio.Lock()
self._flush_task: Optional[asyncio.Task] = None
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._running = False
self._client: Optional[MlflowClient] = None
self._experiment_id: Optional[str] = None
# Aggregate metrics for periodic logging
self._aggregate_metrics: Dict[str, List[float]] = defaultdict(list)
self._request_count = 0
self._error_count = 0
async def start(self) -> None:
"""Start the tracker and initialize MLflow connection."""
if self._running:
return
self._running = True
# Initialize MLflow in thread pool to avoid blocking
        loop = asyncio.get_running_loop()
await loop.run_in_executor(
self._executor,
self._init_mlflow
)
if self.enable_batching:
self._flush_task = asyncio.create_task(self._periodic_flush())
logger.info(
f"InferenceMetricsTracker started for {self.service_name} "
f"(experiment: {self.experiment_name})"
)
def _init_mlflow(self) -> None:
"""Initialize MLflow client and experiment (runs in thread pool)."""
self._client = get_mlflow_client(
tracking_uri=self.tracking_uri,
configure_global=True
)
self._experiment_id = ensure_experiment(
self.experiment_name,
tags={
"service": self.service_name,
"type": "inference-metrics",
}
)
async def stop(self) -> None:
"""Stop the tracker and flush remaining metrics."""
if not self._running:
return
self._running = False
if self._flush_task:
self._flush_task.cancel()
try:
await self._flush_task
except asyncio.CancelledError:
pass
# Final flush
await self._flush_batch()
self._executor.shutdown(wait=True)
logger.info(f"InferenceMetricsTracker stopped for {self.service_name}")
async def log_inference(self, metrics: InferenceMetrics) -> None:
"""
Log inference metrics (non-blocking).
Args:
metrics: InferenceMetrics object with request data
"""
if not self._running:
logger.warning("Tracker not running, skipping metrics")
return
self._request_count += 1
if metrics.has_error:
self._error_count += 1
# Update aggregates
for key, value in metrics.as_metrics_dict().items():
if value > 0:
self._aggregate_metrics[key].append(value)
if self.enable_batching:
async with self._batch_lock:
self._batch.append(metrics)
if len(self._batch) >= self.batch_size:
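                    # Flush in a background task so the request path never blocks on MLflow I/O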
asyncio.create_task(self._flush_batch())
else:
# Immediate logging in thread pool
            loop = asyncio.get_running_loop()
await loop.run_in_executor(
self._executor,
partial(self._log_single_inference, metrics)
)
async def _periodic_flush(self) -> None:
"""Periodically flush batched metrics."""
while self._running:
await asyncio.sleep(self.flush_interval)
await self._flush_batch()
async def _flush_batch(self) -> None:
"""Flush the current batch of metrics to MLflow."""
async with self._batch_lock:
if not self._batch:
return
batch = self._batch
self._batch = []
# Log in thread pool
        loop = asyncio.get_running_loop()
await loop.run_in_executor(
self._executor,
partial(self._log_batch, batch)
)
def _log_single_inference(self, metrics: InferenceMetrics) -> None:
"""Log a single inference request to MLflow (runs in thread pool)."""
try:
with mlflow.start_run(
experiment_id=self._experiment_id,
run_name=f"inference-{metrics.request_id}",
tags={
"service": self.service_name,
"request_id": metrics.request_id,
"type": "single-inference",
}
):
mlflow.log_params(metrics.as_params_dict())
mlflow.log_metrics(metrics.as_metrics_dict())
if metrics.user_id:
mlflow.set_tag("user_id", metrics.user_id)
if metrics.session_id:
mlflow.set_tag("session_id", metrics.session_id)
if metrics.has_error:
mlflow.set_tag("has_error", "true")
if metrics.error_message:
mlflow.set_tag("error_message", metrics.error_message[:250])
except Exception as e:
logger.error(f"Failed to log inference metrics: {e}")
def _log_batch(self, batch: List[InferenceMetrics]) -> None:
"""Log a batch of inference metrics as aggregate statistics."""
if not batch:
return
try:
# Calculate aggregates
aggregates = self._calculate_aggregates(batch)
run_name = f"batch-{self.service_name}-{int(time.time())}"
with mlflow.start_run(
experiment_id=self._experiment_id,
run_name=run_name,
tags={
"service": self.service_name,
"type": "batch-inference",
"batch_size": str(len(batch)),
}
):
# Log aggregate metrics
mlflow.log_metrics(aggregates)
# Log batch info
mlflow.log_param("batch_size", len(batch))
mlflow.log_param("time_window_start", min(m.timestamp for m in batch))
mlflow.log_param("time_window_end", max(m.timestamp for m in batch))
# Log configuration breakdown
rag_enabled_count = sum(1 for m in batch if m.rag_enabled)
streaming_count = sum(1 for m in batch if m.is_streaming)
premium_count = sum(1 for m in batch if m.is_premium)
error_count = sum(1 for m in batch if m.has_error)
mlflow.log_metrics({
"rag_enabled_pct": rag_enabled_count / len(batch) * 100,
"streaming_pct": streaming_count / len(batch) * 100,
"premium_pct": premium_count / len(batch) * 100,
"error_rate": error_count / len(batch) * 100,
})
# Log model distribution
model_counts: Dict[str, int] = defaultdict(int)
for m in batch:
if m.model_name:
model_counts[m.model_name] += 1
if model_counts:
mlflow.log_dict(
{"models": dict(model_counts)},
"model_distribution.json"
)
logger.info(f"Logged batch of {len(batch)} inference metrics")
except Exception as e:
logger.error(f"Failed to log batch metrics: {e}")
def _calculate_aggregates(
self,
batch: List[InferenceMetrics]
) -> Dict[str, float]:
"""Calculate aggregate statistics from a batch of metrics."""
import statistics
aggregates = {}
# Collect all numeric metrics
metric_values: Dict[str, List[float]] = defaultdict(list)
for m in batch:
for key, value in m.as_metrics_dict().items():
if value > 0:
metric_values[key].append(value)
# Calculate statistics for each metric
for key, values in metric_values.items():
if not values:
continue
aggregates[f"{key}_mean"] = statistics.mean(values)
aggregates[f"{key}_min"] = min(values)
aggregates[f"{key}_max"] = max(values)
if len(values) >= 2:
aggregates[f"{key}_p50"] = statistics.median(values)
aggregates[f"{key}_stdev"] = statistics.stdev(values)
if len(values) >= 4:
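                # Nearest-rank percentiles over the sorted batch sample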
sorted_vals = sorted(values)
p95_idx = int(len(sorted_vals) * 0.95)
p99_idx = int(len(sorted_vals) * 0.99)
aggregates[f"{key}_p95"] = sorted_vals[p95_idx]
aggregates[f"{key}_p99"] = sorted_vals[p99_idx]
# Add counts
aggregates["total_requests"] = float(len(batch))
return aggregates
def get_stats(self) -> Dict[str, Any]:
"""Get current tracker statistics."""
return {
"service_name": self.service_name,
"experiment_name": self.experiment_name,
"running": self._running,
"total_requests": self._request_count,
"error_count": self._error_count,
"pending_batch_size": len(self._batch),
"aggregate_metrics_count": len(self._aggregate_metrics),
}

mlflow_utils/kfp_components.py

@@ -0,0 +1,513 @@
"""
Kubeflow Pipeline Components with MLflow Tracking
Provides reusable KFP components that integrate MLflow experiment
tracking into Kubeflow Pipelines. These components can be used
directly in pipelines or as wrappers around existing pipeline steps.
Usage in a Kubeflow Pipeline:
from mlflow_utils.kfp_components import (
create_mlflow_run,
log_metrics_component,
        log_artifact_component,
end_mlflow_run,
)
@dsl.pipeline(name="my-pipeline")
def my_pipeline():
# Start MLflow run
run_info = create_mlflow_run(
experiment_name="my-experiment",
run_name="training-run-1"
)
# ... your pipeline steps ...
# Log metrics
log_step = log_metrics_component(
run_id=run_info.outputs["run_id"],
metrics={"accuracy": 0.95, "loss": 0.05}
)
# End run
end_mlflow_run(run_id=run_info.outputs["run_id"])
"""
from kfp import dsl
from typing import Dict, Any, List, Optional, NamedTuple
# MLflow component image with all required dependencies
MLFLOW_IMAGE = "python:3.13-slim"
MLFLOW_PACKAGES = [
"mlflow>=2.10.0",
"boto3", # For S3 artifact storage if needed
"psycopg2-binary", # For PostgreSQL backend
]
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def create_mlflow_run(
experiment_name: str,
run_name: str,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    tags: Optional[Dict[str, str]] = None,
    params: Optional[Dict[str, str]] = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
"""
Create a new MLflow run for the pipeline.
This should be called at the start of a pipeline to initialize
tracking. The returned run_id should be passed to subsequent
components for logging.
Args:
experiment_name: Name of the MLflow experiment
run_name: Name for this specific run
mlflow_tracking_uri: MLflow tracking server URI
tags: Optional tags to add to the run
params: Optional parameters to log
Returns:
NamedTuple with run_id, experiment_id, and artifact_uri
"""
import os
import mlflow
from mlflow.tracking import MlflowClient
from collections import namedtuple
# Set tracking URI
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Get or create experiment
experiment = client.get_experiment_by_name(experiment_name)
if experiment is None:
experiment_id = client.create_experiment(
name=experiment_name,
artifact_location=f"/mlflow/artifacts/{experiment_name}"
)
else:
experiment_id = experiment.experiment_id
# Create default tags
default_tags = {
"pipeline.type": "kubeflow",
"kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
"kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
}
if tags:
default_tags.update(tags)
# Start run
run = mlflow.start_run(
experiment_id=experiment_id,
run_name=run_name,
tags=default_tags,
)
# Log initial params
if params:
mlflow.log_params(params)
run_id = run.info.run_id
artifact_uri = run.info.artifact_uri
# End run (KFP components are isolated, we'll resume in other components)
mlflow.end_run()
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
return RunInfo(run_id, experiment_id, artifact_uri)
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_params_component(
run_id: str,
params: Dict[str, str],
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log parameters to an existing MLflow run.
Args:
run_id: The MLflow run ID to log to
params: Dictionary of parameters to log
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
for key, value in params.items():
client.log_param(run_id, key, str(value)[:500])
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_metrics_component(
run_id: str,
metrics: Dict[str, float],
step: int = 0,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log metrics to an existing MLflow run.
Args:
run_id: The MLflow run ID to log to
metrics: Dictionary of metrics to log
step: Step number for time-series metrics
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
for key, value in metrics.items():
client.log_metric(run_id, key, float(value), step=step)
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_artifact_component(
run_id: str,
artifact_path: str,
artifact_name: str = "",
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log an artifact file to an existing MLflow run.
Args:
run_id: The MLflow run ID to log to
artifact_path: Path to the artifact file
artifact_name: Optional destination name in artifact store
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
client.log_artifact(run_id, artifact_path, artifact_name or None)
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_dict_artifact(
run_id: str,
data: Dict[str, Any],
filename: str,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log a dictionary as a JSON artifact.
Args:
run_id: The MLflow run ID to log to
data: Dictionary to save as JSON
filename: Name for the JSON file
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import json
import tempfile
import mlflow
from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Ensure .json extension
if not filename.endswith('.json'):
filename += '.json'
# Write to temp file and log
with tempfile.TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / filename
with open(filepath, 'w') as f:
json.dump(data, f, indent=2)
client.log_artifact(run_id, str(filepath))
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def end_mlflow_run(
run_id: str,
status: str = "FINISHED",
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
End an MLflow run with the specified status.
Args:
run_id: The MLflow run ID to end
status: Run status (FINISHED, FAILED, KILLED)
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id
"""
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import RunStatus
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
status_map = {
"FINISHED": RunStatus.FINISHED,
"FAILED": RunStatus.FAILED,
"KILLED": RunStatus.KILLED,
}
run_status = status_map.get(status.upper(), RunStatus.FINISHED)
client.set_terminated(run_id, status=run_status)
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES + ["httpx"]
)
def log_training_metrics(
run_id: str,
model_type: str,
training_config: Dict[str, Any],
final_metrics: Dict[str, float],
model_path: str = "",
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log comprehensive training metrics for ML models.
Designed for use with QLoRA training, voice training, and other
ML training pipelines in the llm-workflows repository.
Args:
run_id: The MLflow run ID to log to
model_type: Type of model (llm, stt, tts, embeddings)
training_config: Training configuration dict
final_metrics: Final training metrics
model_path: Path to saved model (if applicable)
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import json
import tempfile
import mlflow
from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Log training config as params
flat_config = {}
for key, value in training_config.items():
if isinstance(value, (dict, list)):
flat_config[f"config.{key}"] = json.dumps(value)[:500]
else:
flat_config[f"config.{key}"] = str(value)[:500]
for key, value in flat_config.items():
client.log_param(run_id, key, value)
# Log model type tag
client.set_tag(run_id, "model.type", model_type)
# Log metrics
for key, value in final_metrics.items():
client.log_metric(run_id, key, float(value))
# Log full config as artifact
with tempfile.TemporaryDirectory() as tmpdir:
config_path = Path(tmpdir) / "training_config.json"
with open(config_path, 'w') as f:
json.dump(training_config, f, indent=2)
client.log_artifact(run_id, str(config_path))
# Log model path if provided
if model_path:
client.log_param(run_id, "model.path", model_path)
client.set_tag(run_id, "model.saved", "true")
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_document_ingestion_metrics(
run_id: str,
source_url: str,
collection_name: str,
chunks_created: int,
documents_processed: int,
processing_time_seconds: float,
embeddings_model: str = "bge-small-en-v1.5",
chunk_size: int = 500,
chunk_overlap: int = 50,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log document ingestion pipeline metrics.
Designed for use with the document_ingestion_pipeline.
Args:
run_id: The MLflow run ID to log to
source_url: URL of the source document
collection_name: Milvus collection name
chunks_created: Number of chunks created
documents_processed: Number of documents processed
processing_time_seconds: Total processing time
embeddings_model: Embeddings model used
chunk_size: Chunk size in tokens
chunk_overlap: Chunk overlap in tokens
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Log params
params = {
"source_url": source_url[:500],
"collection_name": collection_name,
"embeddings_model": embeddings_model,
"chunk_size": str(chunk_size),
"chunk_overlap": str(chunk_overlap),
}
for key, value in params.items():
client.log_param(run_id, key, value)
# Log metrics
metrics = {
"chunks_created": chunks_created,
"documents_processed": documents_processed,
"processing_time_seconds": processing_time_seconds,
"chunks_per_second": chunks_created / processing_time_seconds if processing_time_seconds > 0 else 0,
}
for key, value in metrics.items():
client.log_metric(run_id, key, float(value))
# Set pipeline type tag
client.set_tag(run_id, "pipeline.type", "document-ingestion")
client.set_tag(run_id, "milvus.collection", collection_name)
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
def log_evaluation_results(
run_id: str,
model_name: str,
dataset_name: str,
metrics: Dict[str, float],
sample_results: List[Dict[str, Any]] = None,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
"""
Log model evaluation results.
Designed for use with the evaluation_pipeline.
Args:
run_id: The MLflow run ID to log to
model_name: Name of the evaluated model
dataset_name: Name of the evaluation dataset
metrics: Evaluation metrics (accuracy, etc.)
sample_results: Optional sample predictions
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import json
import tempfile
import mlflow
from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Log params
client.log_param(run_id, "eval.model_name", model_name)
client.log_param(run_id, "eval.dataset", dataset_name)
# Log metrics
for key, value in metrics.items():
client.log_metric(run_id, f"eval.{key}", float(value))
# Log sample results as artifact
if sample_results:
with tempfile.TemporaryDirectory() as tmpdir:
results_path = Path(tmpdir) / "evaluation_results.json"
with open(results_path, 'w') as f:
json.dump(sample_results, f, indent=2)
client.log_artifact(run_id, str(results_path))
# Set tags
client.set_tag(run_id, "pipeline.type", "evaluation")
client.set_tag(run_id, "model.name", model_name)
    # Determine pass/fail: honor an explicit "pass" metric when supplied,
    # otherwise fall back to an accuracy >= 0.7 threshold
    passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
client.set_tag(run_id, "eval.passed", str(passed))
return run_id
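
# --- Usage sketch (illustrative, not part of this module) ---
# A minimal KFP v2 pipeline showing how these components chain. It assumes
# the MLflow run_id is created upstream (e.g. by a start-run component not
# shown in this excerpt) and passed in as a pipeline parameter.
@dsl.pipeline(name="training-with-mlflow-tracking")
def tracked_training_pipeline(run_id: str, model_path: str = ""):
    log_step = log_training_metrics(
        run_id=run_id,
        model_type="llm",
        training_config={"lora_rank": 16, "learning_rate": 2e-4},
        final_metrics={"train_loss": 0.42, "eval_loss": 0.51},
        model_path=model_path,
    )
    # Consuming log_step.output makes closing the run wait for logging
    end_mlflow_run(run_id=log_step.output, status="FINISHED")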

545
mlflow_utils/model_registry.py Normal file
View File

@@ -0,0 +1,545 @@
"""
MLflow Model Registry Integration for KServe
Provides utilities for registering trained models in MLflow Model Registry
with metadata needed for deployment to KServe InferenceServices.
This module bridges the gap between Kubeflow training pipelines and
KServe model serving by:
1. Registering models with proper versioning
2. Adding KServe-specific metadata (runtime, protocol, resources)
3. Managing model stage transitions (Staging → Production)
4. Generating KServe InferenceService manifests from registered models
Usage:
from mlflow_utils.model_registry import (
        KServeConfig,
        register_model_for_kserve,
promote_model_to_production,
generate_kserve_manifest,
)
# Register a new model version
model_version = register_model_for_kserve(
model_name="whisper-finetuned",
model_uri="s3://models/whisper-v2",
model_type="stt",
        kserve_config=KServeConfig(
            runtime="kserve-huggingface",
            container_image="ghcr.io/my-org/whisper:v2",
        ),
    )
# Generate KServe manifest for deployment
manifest = generate_kserve_manifest(
model_name="whisper-finetuned",
model_version=model_version.version,
)
"""
import os
import json
import yaml
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities.model_registry import ModelVersion
from .client import get_mlflow_client, MLflowConfig
logger = logging.getLogger(__name__)
@dataclass
class KServeConfig:
"""Configuration for KServe deployment."""
# Runtime/container configuration
runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc.
container_image: Optional[str] = None
container_port: int = 8080
# Protocol configuration
protocol: str = "v2" # v1, v2, grpc
# Resource requests/limits
cpu_request: str = "1"
cpu_limit: str = "4"
memory_request: str = "4Gi"
memory_limit: str = "16Gi"
gpu_count: int = 0
gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm
# Storage configuration
storage_uri: Optional[str] = None # s3://, pvc://, gs://
# Scaling configuration
min_replicas: int = 1
max_replicas: int = 1
scale_target: int = 10 # Target concurrent requests for scaling
# Serving configuration
timeout_seconds: int = 300
batch_size: int = 1
# Additional environment variables
env_vars: Dict[str, str] = field(default_factory=dict)
def as_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for MLflow tags."""
return {
"kserve.runtime": self.runtime,
"kserve.protocol": self.protocol,
"kserve.cpu_request": self.cpu_request,
"kserve.memory_request": self.memory_request,
"kserve.gpu_count": str(self.gpu_count),
"kserve.min_replicas": str(self.min_replicas),
"kserve.max_replicas": str(self.max_replicas),
"kserve.storage_uri": self.storage_uri or "",
"kserve.container_image": self.container_image or "",
}
# Pre-configured KServe configurations for common model types
KSERVE_PRESETS: Dict[str, KServeConfig] = {
"llm": KServeConfig(
runtime="kserve-huggingface",
cpu_request="2",
cpu_limit="8",
memory_request="16Gi",
memory_limit="64Gi",
gpu_count=1,
timeout_seconds=600,
),
"stt": KServeConfig(
runtime="kserve-custom",
cpu_request="2",
cpu_limit="4",
memory_request="8Gi",
memory_limit="16Gi",
gpu_count=1,
timeout_seconds=120,
),
"tts": KServeConfig(
runtime="kserve-custom",
cpu_request="2",
cpu_limit="4",
memory_request="8Gi",
memory_limit="16Gi",
gpu_count=1,
timeout_seconds=60,
),
"embeddings": KServeConfig(
runtime="kserve-huggingface",
cpu_request="1",
cpu_limit="4",
memory_request="4Gi",
memory_limit="16Gi",
gpu_count=0,
timeout_seconds=30,
batch_size=32,
),
"reranker": KServeConfig(
runtime="kserve-huggingface",
cpu_request="1",
cpu_limit="4",
memory_request="4Gi",
memory_limit="16Gi",
gpu_count=0,
timeout_seconds=30,
),
}
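# Example (illustrative, not used elsewhere): derive a variant from a preset
# without mutating the shared preset instance — dataclasses.replace returns
# a copy, so other callers still see the original defaults.
def _example_rocm_llm_config() -> KServeConfig:
    from dataclasses import replace
    return replace(
        KSERVE_PRESETS["llm"],
        gpu_type="amd.com/gpu",  # target ROCm nodes instead of the NVIDIA default
        max_replicas=2,
    )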
def register_model_for_kserve(
model_name: str,
model_uri: str,
model_type: str,
run_id: Optional[str] = None,
description: Optional[str] = None,
kserve_config: Optional[KServeConfig] = None,
tags: Optional[Dict[str, str]] = None,
tracking_uri: Optional[str] = None,
) -> ModelVersion:
"""
Register a model in MLflow Model Registry with KServe metadata.
Args:
model_name: Name for the registered model
model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
model_type: Type of model (llm, stt, tts, embeddings, reranker)
run_id: Optional MLflow run ID to associate with
description: Description of the model version
kserve_config: KServe deployment configuration
tags: Additional tags for the model version
tracking_uri: Override default tracking URI
Returns:
The created ModelVersion object
"""
client = get_mlflow_client(tracking_uri=tracking_uri)
# Get or use preset KServe config
if kserve_config is None:
kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())
# Ensure registered model exists
try:
client.get_registered_model(model_name)
except mlflow.exceptions.MlflowException:
client.create_registered_model(
name=model_name,
description=f"{model_type.upper()} model for KServe deployment",
tags={
"model.type": model_type,
"deployment.target": "kserve",
}
)
logger.info(f"Created registered model: {model_name}")
# Create model version
model_version = client.create_model_version(
name=model_name,
source=model_uri,
run_id=run_id,
description=description or f"Version from {model_uri}",
tags={
**(tags or {}),
"model.type": model_type,
**kserve_config.as_dict(),
}
)
logger.info(
f"Registered model version {model_version.version} "
f"for {model_name} (type: {model_type})"
)
return model_version
def promote_model_to_stage(
model_name: str,
version: int,
stage: str = "Staging",
archive_existing: bool = True,
tracking_uri: Optional[str] = None,
) -> ModelVersion:
"""
Promote a model version to a new stage.
Args:
model_name: Name of the registered model
version: Version number to promote
stage: Target stage (Staging, Production, Archived)
archive_existing: If True, archive existing versions in target stage
tracking_uri: Override default tracking URI
Returns:
The updated ModelVersion
"""
client = get_mlflow_client(tracking_uri=tracking_uri)
    # Transition to the new stage. Note: registry stages are deprecated in
    # favor of model version aliases from MLflow 2.9, but remain supported.
model_version = client.transition_model_version_stage(
name=model_name,
version=str(version),
stage=stage,
archive_existing_versions=archive_existing,
)
logger.info(f"Promoted {model_name} v{version} to {stage}")
return model_version
def promote_model_to_production(
model_name: str,
version: int,
tracking_uri: Optional[str] = None,
) -> ModelVersion:
"""
Promote a model version directly to Production.
Args:
model_name: Name of the registered model
version: Version number to promote
tracking_uri: Override default tracking URI
Returns:
The updated ModelVersion
"""
return promote_model_to_stage(
model_name=model_name,
version=version,
stage="Production",
archive_existing=True,
tracking_uri=tracking_uri,
)
def get_production_model(
model_name: str,
tracking_uri: Optional[str] = None,
) -> Optional[ModelVersion]:
"""
Get the current Production model version.
Args:
model_name: Name of the registered model
tracking_uri: Override default tracking URI
Returns:
The Production ModelVersion, or None if none exists
"""
client = get_mlflow_client(tracking_uri=tracking_uri)
versions = client.get_latest_versions(model_name, stages=["Production"])
return versions[0] if versions else None
def get_model_kserve_config(
model_name: str,
version: Optional[int] = None,
tracking_uri: Optional[str] = None,
) -> KServeConfig:
"""
Get KServe configuration from a registered model version.
Args:
model_name: Name of the registered model
version: Version number (uses Production if not specified)
tracking_uri: Override default tracking URI
Returns:
KServeConfig populated from model tags
"""
client = get_mlflow_client(tracking_uri=tracking_uri)
if version:
model_version = client.get_model_version(model_name, str(version))
else:
prod_version = get_production_model(model_name, tracking_uri)
if not prod_version:
raise ValueError(f"No Production version for {model_name}")
model_version = prod_version
tags = model_version.tags
return KServeConfig(
runtime=tags.get("kserve.runtime", "kserve-huggingface"),
protocol=tags.get("kserve.protocol", "v2"),
cpu_request=tags.get("kserve.cpu_request", "1"),
memory_request=tags.get("kserve.memory_request", "4Gi"),
gpu_count=int(tags.get("kserve.gpu_count", "0")),
min_replicas=int(tags.get("kserve.min_replicas", "1")),
max_replicas=int(tags.get("kserve.max_replicas", "1")),
storage_uri=tags.get("kserve.storage_uri") or None,
container_image=tags.get("kserve.container_image") or None,
)
def generate_kserve_manifest(
model_name: str,
version: Optional[int] = None,
namespace: str = "ai-ml",
service_name: Optional[str] = None,
extra_annotations: Optional[Dict[str, str]] = None,
tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
"""
Generate a KServe InferenceService manifest from a registered model.
Args:
model_name: Name of the registered model
version: Version number (uses Production if not specified)
namespace: Kubernetes namespace for deployment
service_name: Name for the InferenceService (defaults to model_name)
extra_annotations: Additional annotations for the service
tracking_uri: Override default tracking URI
Returns:
KServe InferenceService manifest as a dictionary
"""
client = get_mlflow_client(tracking_uri=tracking_uri)
# Get model version
if version:
model_version = client.get_model_version(model_name, str(version))
else:
prod_version = get_production_model(model_name, tracking_uri)
if not prod_version:
raise ValueError(f"No Production version for {model_name}")
model_version = prod_version
version = int(model_version.version)
# Get KServe config
config = get_model_kserve_config(model_name, version, tracking_uri)
model_type = model_version.tags.get("model.type", "custom")
svc_name = service_name or model_name.lower().replace("_", "-")
# Build manifest
manifest = {
"apiVersion": "serving.kserve.io/v1beta1",
"kind": "InferenceService",
"metadata": {
"name": svc_name,
"namespace": namespace,
"labels": {
"mlflow.model": model_name,
"mlflow.version": str(version),
"model.type": model_type,
},
"annotations": {
"mlflow.tracking_uri": get_mlflow_client().tracking_uri,
"mlflow.run_id": model_version.run_id or "",
**(extra_annotations or {}),
},
},
"spec": {
"predictor": {
"minReplicas": config.min_replicas,
"maxReplicas": config.max_replicas,
"scaleTarget": config.scale_target,
"timeout": config.timeout_seconds,
},
},
}
# Configure predictor based on runtime
predictor = manifest["spec"]["predictor"]
if config.container_image:
# Custom container
predictor["containers"] = [{
"name": "predictor",
"image": config.container_image,
"ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
"resources": {
"requests": {
"cpu": config.cpu_request,
"memory": config.memory_request,
},
"limits": {
"cpu": config.cpu_limit,
"memory": config.memory_limit,
},
},
"env": [
{"name": k, "value": v}
for k, v in config.env_vars.items()
],
}]
# Add GPU if needed
if config.gpu_count > 0:
predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
else:
# Standard KServe runtime
storage_uri = config.storage_uri or model_version.source
predictor["model"] = {
"modelFormat": {"name": "huggingface"},
"protocolVersion": config.protocol,
"storageUri": storage_uri,
"resources": {
"requests": {
"cpu": config.cpu_request,
"memory": config.memory_request,
},
"limits": {
"cpu": config.cpu_limit,
"memory": config.memory_limit,
},
},
}
if config.gpu_count > 0:
predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
return manifest
def generate_kserve_yaml(
model_name: str,
version: Optional[int] = None,
namespace: str = "ai-ml",
output_path: Optional[str] = None,
tracking_uri: Optional[str] = None,
) -> str:
"""
Generate a KServe InferenceService manifest as YAML.
Args:
model_name: Name of the registered model
version: Version number (uses Production if not specified)
namespace: Kubernetes namespace
output_path: If provided, write YAML to this path
tracking_uri: Override default tracking URI
Returns:
YAML string of the manifest
"""
manifest = generate_kserve_manifest(
model_name=model_name,
version=version,
namespace=namespace,
tracking_uri=tracking_uri,
)
yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
if output_path:
with open(output_path, 'w') as f:
f.write(yaml_str)
logger.info(f"Wrote KServe manifest to {output_path}")
return yaml_str
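# Example (illustrative, not called anywhere): render the current Production
# model's manifest and apply it. Assumes kubectl is on PATH and the caller
# has RBAC to create InferenceServices in the target namespace.
def _example_deploy_production(model_name: str, namespace: str = "ai-ml") -> None:
    import subprocess
    path = f"/tmp/{model_name}-inferenceservice.yaml"
    generate_kserve_yaml(model_name=model_name, namespace=namespace, output_path=path)
    subprocess.run(["kubectl", "apply", "-f", path], check=True)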
def list_model_versions(
model_name: str,
stages: Optional[List[str]] = None,
tracking_uri: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""
List all versions of a registered model.
Args:
model_name: Name of the registered model
        stages: If given, return only the latest version in each of these
            stages; None returns every version of the model
tracking_uri: Override default tracking URI
Returns:
List of model version info dictionaries
"""
client = get_mlflow_client(tracking_uri=tracking_uri)
if stages:
versions = client.get_latest_versions(model_name, stages=stages)
    else:
        # Get every version, not just the latest per stage
        versions = list(client.search_model_versions(f"name='{model_name}'"))
return [
{
"version": mv.version,
"stage": mv.current_stage,
"source": mv.source,
"run_id": mv.run_id,
"description": mv.description,
"tags": mv.tags,
"creation_timestamp": mv.creation_timestamp,
"last_updated_timestamp": mv.last_updated_timestamp,
}
for mv in versions
]
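
# --- End-to-end sketch (illustrative; names are hypothetical) ---
# Register a new version, stage it for validation, promote it, and render
# the deployable manifest. An evaluation gate would normally sit between
# the Staging and Production transitions.
def _example_release_flow(model_name: str, model_uri: str) -> str:
    mv = register_model_for_kserve(
        model_name=model_name,
        model_uri=model_uri,
        model_type="stt",
    )
    promote_model_to_stage(model_name, int(mv.version), stage="Staging")
    # ... run evaluation against the Staging version here ...
    promote_model_to_production(model_name, int(mv.version))
    return generate_kserve_yaml(model_name)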

395
mlflow_utils/tracker.py Normal file
View File

@@ -0,0 +1,395 @@
"""
MLflow Tracker for Kubeflow Pipelines
Provides a high-level interface for logging experiments, parameters,
metrics, and artifacts from Kubeflow Pipeline components.
"""
import os
import json
import time
import logging
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
from contextlib import contextmanager
from dataclasses import dataclass, field
import mlflow
from mlflow.tracking import MlflowClient
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
logger = logging.getLogger(__name__)
@dataclass
class PipelineMetadata:
"""Metadata about the Kubeflow Pipeline run."""
pipeline_name: str
run_id: str
run_name: Optional[str] = None
component_name: Optional[str] = None
namespace: str = "ai-ml"
# KFP-specific metadata (populated from environment if available)
kfp_run_id: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_RUN_ID")
)
kfp_pod_name: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_POD_NAME")
)
def as_tags(self) -> Dict[str, str]:
"""Convert metadata to MLflow tags."""
tags = {
"pipeline.name": self.pipeline_name,
"pipeline.run_id": self.run_id,
"pipeline.namespace": self.namespace,
}
if self.run_name:
tags["pipeline.run_name"] = self.run_name
if self.component_name:
tags["pipeline.component"] = self.component_name
if self.kfp_run_id:
tags["kfp.run_id"] = self.kfp_run_id
if self.kfp_pod_name:
tags["kfp.pod_name"] = self.kfp_pod_name
return tags
class MLflowTracker:
"""
MLflow experiment tracker for Kubeflow Pipeline components.
Example usage in a KFP component:
from mlflow_utils import MLflowTracker
tracker = MLflowTracker(
experiment_name="document-ingestion",
run_name="batch-ingestion-2024-01"
)
with tracker.start_run() as run:
tracker.log_params({
"chunk_size": 500,
"overlap": 50,
"embeddings_model": "bge-small-en-v1.5"
})
# ... do work ...
tracker.log_metrics({
"documents_processed": 100,
"chunks_created": 2500,
"processing_time_seconds": 120.5
})
tracker.log_artifact("/path/to/output.json")
"""
def __init__(
self,
experiment_name: str,
run_name: Optional[str] = None,
pipeline_metadata: Optional[PipelineMetadata] = None,
tags: Optional[Dict[str, str]] = None,
tracking_uri: Optional[str] = None,
):
"""
Initialize the MLflow tracker.
Args:
experiment_name: Name of the MLflow experiment
run_name: Optional name for this run
pipeline_metadata: Metadata about the KFP pipeline
tags: Additional tags to add to the run
tracking_uri: Override default tracking URI
"""
self.config = MLflowConfig()
self.experiment_name = experiment_name
self.run_name = run_name or f"{experiment_name}-{int(time.time())}"
self.pipeline_metadata = pipeline_metadata
self.user_tags = tags or {}
self.tracking_uri = tracking_uri
self.client: Optional[MlflowClient] = None
self.run: Optional[mlflow.ActiveRun] = None
self.run_id: Optional[str] = None
self._start_time: Optional[float] = None
def _get_all_tags(self) -> Dict[str, str]:
"""Combine all tags for the run."""
tags = self.config.default_tags.copy()
if self.pipeline_metadata:
tags.update(self.pipeline_metadata.as_tags())
tags.update(self.user_tags)
return tags
@contextmanager
def start_run(
self,
nested: bool = False,
parent_run_id: Optional[str] = None,
):
"""
Start an MLflow run as a context manager.
Args:
nested: If True, create a nested run under the current active run
parent_run_id: Explicit parent run ID for nested runs
Yields:
The MLflow run object
"""
self.client = get_mlflow_client(
tracking_uri=self.tracking_uri,
configure_global=True
)
# Ensure experiment exists
experiment_id = ensure_experiment(self.experiment_name)
self._start_time = time.time()
        failed = False
        try:
            # Start the run
            self.run = mlflow.start_run(
                experiment_id=experiment_id,
                run_name=self.run_name,
                nested=nested,
                parent_run_id=parent_run_id,  # honored by mlflow>=2.10
                tags=self._get_all_tags(),
            )
            self.run_id = self.run.info.run_id
            logger.info(
                f"Started MLflow run '{self.run_name}' "
                f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
            )
            yield self.run
        except Exception as e:
            logger.error(f"MLflow run failed: {e}")
            if self.run:
                mlflow.set_tag("run.status", "failed")
                mlflow.set_tag("run.error", str(e))
            failed = True
            raise
        finally:
            # Log duration
            if self._start_time:
                duration = time.time() - self._start_time
                try:
                    mlflow.log_metric("run_duration_seconds", duration)
                except Exception:
                    pass
            # End the run with a terminal status that reflects what happened,
            # so failures show as FAILED in the MLflow UI, not FINISHED
            mlflow.end_run(status="FAILED" if failed else "FINISHED")
            logger.info(f"Ended MLflow run '{self.run_name}'")
def log_params(self, params: Dict[str, Any]) -> None:
"""
Log parameters to the current run.
Args:
params: Dictionary of parameter names to values
"""
if not self.run:
logger.warning("No active run, skipping log_params")
return
# MLflow has limits on param values, truncate if needed
cleaned_params = {}
for key, value in params.items():
str_value = str(value)
if len(str_value) > 500:
str_value = str_value[:497] + "..."
cleaned_params[key] = str_value
mlflow.log_params(cleaned_params)
logger.debug(f"Logged {len(params)} parameters")
def log_param(self, key: str, value: Any) -> None:
"""Log a single parameter."""
self.log_params({key: value})
def log_metrics(
self,
metrics: Dict[str, Union[float, int]],
step: Optional[int] = None
) -> None:
"""
Log metrics to the current run.
Args:
metrics: Dictionary of metric names to values
step: Optional step number for time-series metrics
"""
if not self.run:
logger.warning("No active run, skipping log_metrics")
return
mlflow.log_metrics(metrics, step=step)
logger.debug(f"Logged {len(metrics)} metrics")
def log_metric(
self,
key: str,
value: Union[float, int],
step: Optional[int] = None
) -> None:
"""Log a single metric."""
self.log_metrics({key: value}, step=step)
def log_artifact(
self,
local_path: str,
artifact_path: Optional[str] = None
) -> None:
"""
Log an artifact file to the current run.
Args:
local_path: Path to the local file to log
artifact_path: Optional destination path within the artifact store
"""
if not self.run:
logger.warning("No active run, skipping log_artifact")
return
mlflow.log_artifact(local_path, artifact_path)
logger.info(f"Logged artifact: {local_path}")
def log_artifacts(
self,
local_dir: str,
artifact_path: Optional[str] = None
) -> None:
"""
Log all files in a directory as artifacts.
Args:
local_dir: Path to the local directory
artifact_path: Optional destination path within the artifact store
"""
if not self.run:
logger.warning("No active run, skipping log_artifacts")
return
mlflow.log_artifacts(local_dir, artifact_path)
logger.info(f"Logged artifacts from: {local_dir}")
def log_dict(
self,
data: Dict[str, Any],
filename: str,
artifact_path: Optional[str] = None
) -> None:
"""
Log a dictionary as a JSON artifact.
Args:
data: Dictionary to log
filename: Name for the JSON file
artifact_path: Optional destination path
"""
if not self.run:
logger.warning("No active run, skipping log_dict")
return
# Ensure .json extension
if not filename.endswith(".json"):
filename += ".json"
mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
logger.debug(f"Logged dict as: {filename}")
def log_model_info(
self,
model_type: str,
model_name: str,
model_path: Optional[str] = None,
framework: str = "pytorch",
extra_info: Optional[Dict[str, Any]] = None
) -> None:
"""
Log model information as parameters and tags.
Args:
model_type: Type of model (e.g., "llm", "embedding", "stt")
model_name: Name/identifier of the model
model_path: Path to model weights
framework: ML framework used
extra_info: Additional model information
"""
params = {
"model.type": model_type,
"model.name": model_name,
"model.framework": framework,
}
if model_path:
params["model.path"] = model_path
if extra_info:
for key, value in extra_info.items():
params[f"model.{key}"] = value
self.log_params(params)
# Also set as tags for easier filtering
mlflow.set_tag("model.type", model_type)
mlflow.set_tag("model.name", model_name)
def log_dataset_info(
self,
name: str,
source: str,
size: Optional[int] = None,
extra_info: Optional[Dict[str, Any]] = None
) -> None:
"""
Log dataset information.
Args:
name: Dataset name
source: Dataset source (URL, path, etc.)
size: Number of samples
extra_info: Additional dataset information
"""
params = {
"dataset.name": name,
"dataset.source": source,
}
if size is not None:
params["dataset.size"] = size
if extra_info:
for key, value in extra_info.items():
params[f"dataset.{key}"] = value
self.log_params(params)
def set_tag(self, key: str, value: str) -> None:
"""Set a single tag on the run."""
if self.run:
mlflow.set_tag(key, value)
def set_tags(self, tags: Dict[str, str]) -> None:
"""Set multiple tags on the run."""
if self.run:
mlflow.set_tags(tags)
@property
def artifact_uri(self) -> Optional[str]:
"""Get the artifact URI for the current run."""
if self.run:
return self.run.info.artifact_uri
return None
@property
def experiment_id(self) -> Optional[str]:
"""Get the experiment ID for the current run."""
if self.run:
return self.run.info.experiment_id
return None
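
# --- Usage sketch (illustrative, not used elsewhere) ---
# Attach KFP pipeline metadata so runs can be filtered by pipeline name,
# run ID, or component in the MLflow UI. Names below are hypothetical.
def _example_component_run() -> None:
    meta = PipelineMetadata(
        pipeline_name="document-ingestion",
        run_id="kfp-run-123",  # hypothetical pipeline run ID
        component_name="chunk-and-embed",
    )
    tracker = MLflowTracker(
        experiment_name="document-ingestion",
        pipeline_metadata=meta,
    )
    with tracker.start_run():
        tracker.log_params({"chunk_size": 500, "chunk_overlap": 50})
        tracker.log_metrics({"chunks_created": 2500, "documents_processed": 100})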

19
requirements.txt Normal file
View File

@@ -0,0 +1,19 @@
# MLflow Utils Module Requirements
# Core MLflow
mlflow>=2.10.0
# Backend store (PostgreSQL) and artifact store (S3-compatible)
psycopg2-binary>=2.9.0 # PostgreSQL (CNPG)
boto3>=1.34.0 # S3-compatible artifact storage (optional)
# For async tracking
aiohttp>=3.9.0
# YAML generation for KServe manifests
PyYAML>=6.0
# Already in chat-handler/voice-assistant requirements:
# httpx (for health checks)
# Used but typically installed with mlflow:
# numpy
# pandas