feat: Add MLflow integration utilities
- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
This commit is contained in:
545
mlflow_utils/model_registry.py
Normal file
545
mlflow_utils/model_registry.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""
|
||||
MLflow Model Registry Integration for KServe
|
||||
|
||||
Provides utilities for registering trained models in MLflow Model Registry
|
||||
with metadata needed for deployment to KServe InferenceServices.
|
||||
|
||||
This module bridges the gap between Kubeflow training pipelines and
|
||||
KServe model serving by:
|
||||
1. Registering models with proper versioning
|
||||
2. Adding KServe-specific metadata (runtime, protocol, resources)
|
||||
3. Managing model stage transitions (Staging → Production)
|
||||
4. Generating KServe InferenceService manifests from registered models
|
||||
|
||||
Usage:
|
||||
from mlflow_utils.model_registry import (
|
||||
register_model_for_kserve,
|
||||
promote_model_to_production,
|
||||
generate_kserve_manifest,
|
||||
)
|
||||
|
||||
# Register a new model version
|
||||
model_version = register_model_for_kserve(
|
||||
model_name="whisper-finetuned",
|
||||
model_uri="s3://models/whisper-v2",
|
||||
model_type="stt",
|
||||
kserve_config={
|
||||
"runtime": "kserve-huggingface",
|
||||
"container_image": "ghcr.io/my-org/whisper:v2",
|
||||
}
|
||||
)
|
||||
|
||||
# Generate KServe manifest for deployment
|
||||
manifest = generate_kserve_manifest(
|
||||
model_name="whisper-finetuned",
|
||||
model_version=model_version.version,
|
||||
)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import yaml
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities.model_registry import ModelVersion
|
||||
|
||||
from .client import get_mlflow_client, MLflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class KServeConfig:
    """Deployment configuration for a KServe InferenceService.

    Instances are flattened into MLflow model-version tags via
    :meth:`as_dict` and reconstructed from those tags by
    ``get_model_kserve_config``, so every field that must survive the
    round trip needs a ``kserve.*`` tag entry.
    """

    # Runtime/container configuration
    runtime: str = "kserve-huggingface"  # kserve-huggingface, kserve-custom, etc.
    container_image: Optional[str] = None  # custom image; None -> standard runtime
    container_port: int = 8080

    # Protocol configuration
    protocol: str = "v2"  # v1, v2, grpc

    # Resource requests/limits
    cpu_request: str = "1"
    cpu_limit: str = "4"
    memory_request: str = "4Gi"
    memory_limit: str = "16Gi"
    gpu_count: int = 0
    gpu_type: str = "nvidia.com/gpu"  # or amd.com/gpu for ROCm

    # Storage configuration
    storage_uri: Optional[str] = None  # s3://, pvc://, gs://

    # Scaling configuration
    min_replicas: int = 1
    max_replicas: int = 1
    scale_target: int = 10  # Target concurrent requests for scaling

    # Serving configuration
    timeout_seconds: int = 300
    batch_size: int = 1

    # Additional environment variables
    env_vars: Dict[str, str] = field(default_factory=dict)

    def as_dict(self) -> Dict[str, Any]:
        """Convert to a flat string-valued dictionary for MLflow tags.

        All values are strings because MLflow tags are string-typed;
        ``None`` URIs/images are stored as "" and mapped back to ``None``
        on reconstruction.

        BUGFIX: previously omitted cpu/memory *limits*, gpu_type,
        container_port and timeout_seconds, so configs reconstructed from
        tags silently fell back to defaults for those fields.
        """
        return {
            "kserve.runtime": self.runtime,
            "kserve.protocol": self.protocol,
            "kserve.cpu_request": self.cpu_request,
            "kserve.cpu_limit": self.cpu_limit,
            "kserve.memory_request": self.memory_request,
            "kserve.memory_limit": self.memory_limit,
            "kserve.gpu_count": str(self.gpu_count),
            "kserve.gpu_type": self.gpu_type,
            "kserve.container_port": str(self.container_port),
            "kserve.timeout_seconds": str(self.timeout_seconds),
            "kserve.min_replicas": str(self.min_replicas),
            "kserve.max_replicas": str(self.max_replicas),
            "kserve.storage_uri": self.storage_uri or "",
            "kserve.container_image": self.container_image or "",
        }
|
||||
|
||||
|
||||
# Pre-configured KServe configurations for common model types.
# Used by register_model_for_kserve() as the fallback when no explicit
# KServeConfig is supplied; unknown model types get a default KServeConfig.
KSERVE_PRESETS: Dict[str, KServeConfig] = {
    # Large language models: GPU-backed HuggingFace runtime with generous
    # memory headroom and a long timeout for slow generations.
    "llm": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="2",
        cpu_limit="8",
        memory_request="16Gi",
        memory_limit="64Gi",
        gpu_count=1,
        timeout_seconds=600,
    ),
    # Speech-to-text: custom container runtime, GPU-accelerated.
    "stt": KServeConfig(
        runtime="kserve-custom",
        cpu_request="2",
        cpu_limit="4",
        memory_request="8Gi",
        memory_limit="16Gi",
        gpu_count=1,
        timeout_seconds=120,
    ),
    # Text-to-speech: same footprint as STT but a tighter timeout.
    "tts": KServeConfig(
        runtime="kserve-custom",
        cpu_request="2",
        cpu_limit="4",
        memory_request="8Gi",
        memory_limit="16Gi",
        gpu_count=1,
        timeout_seconds=60,
    ),
    # Embedding models: CPU-only, batched for throughput.
    "embeddings": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="1",
        cpu_limit="4",
        memory_request="4Gi",
        memory_limit="16Gi",
        gpu_count=0,
        timeout_seconds=30,
        batch_size=32,
    ),
    # Rerankers: CPU-only, latency-sensitive (no batching).
    "reranker": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="1",
        cpu_limit="4",
        memory_request="4Gi",
        memory_limit="16Gi",
        gpu_count=0,
        timeout_seconds=30,
    ),
}
|
||||
|
||||
|
||||
def register_model_for_kserve(
    model_name: str,
    model_uri: str,
    model_type: str,
    run_id: Optional[str] = None,
    description: Optional[str] = None,
    kserve_config: Optional[KServeConfig] = None,
    tags: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Register a model in MLflow Model Registry with KServe metadata.

    Args:
        model_name: Name for the registered model
        model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
        model_type: Type of model (llm, stt, tts, embeddings, reranker)
        run_id: Optional MLflow run ID to associate with
        description: Description of the model version
        kserve_config: KServe deployment configuration
        tags: Additional tags for the model version
        tracking_uri: Override default tracking URI

    Returns:
        The created ModelVersion object

    Raises:
        mlflow.exceptions.MlflowException: on registry errors other than
            the registered model already existing.
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Fall back to the preset for this model type; unknown types get
    # generic defaults.
    if kserve_config is None:
        kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())

    # Ensure the registered model exists.  Create-then-ignore-conflict is
    # race-free, unlike the previous get-then-create pattern, which also
    # masked unrelated failures (auth, connectivity) by attempting a
    # create on ANY MlflowException from the get.
    try:
        client.create_registered_model(
            name=model_name,
            description=f"{model_type.upper()} model for KServe deployment",
            tags={
                "model.type": model_type,
                "deployment.target": "kserve",
            }
        )
        logger.info(f"Created registered model: {model_name}")
    except mlflow.exceptions.MlflowException as exc:
        # Already-exists is expected on re-registration; re-raise anything else.
        if getattr(exc, "error_code", None) != "RESOURCE_ALREADY_EXISTS":
            raise

    # Create the new version, merging caller tags with KServe metadata.
    # KServe tags win on key collision so deployment metadata stays accurate.
    model_version = client.create_model_version(
        name=model_name,
        source=model_uri,
        run_id=run_id,
        description=description or f"Version from {model_uri}",
        tags={
            **(tags or {}),
            "model.type": model_type,
            **kserve_config.as_dict(),
        }
    )

    logger.info(
        f"Registered model version {model_version.version} "
        f"for {model_name} (type: {model_type})"
    )

    return model_version
|
||||
|
||||
|
||||
def promote_model_to_stage(
    model_name: str,
    version: int,
    stage: str = "Staging",
    archive_existing: bool = True,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Transition a registered model version into a new lifecycle stage.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        stage: Target stage (Staging, Production, Archived)
        archive_existing: If True, archive existing versions in target stage
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # MLflow's registry API takes the version as a string.
    updated = client.transition_model_version_stage(
        name=model_name,
        version=str(version),
        stage=stage,
        archive_existing_versions=archive_existing,
    )
    logger.info(f"Promoted {model_name} v{version} to {stage}")
    return updated
|
||||
|
||||
|
||||
def promote_model_to_production(
    model_name: str,
    version: int,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Promote a model version straight to the Production stage.

    Convenience wrapper around :func:`promote_model_to_stage` that always
    archives whatever was previously in Production.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    return promote_model_to_stage(
        model_name=model_name,
        version=version,
        stage="Production",
        archive_existing=True,
        tracking_uri=tracking_uri,
    )
|
||||
|
||||
|
||||
def get_production_model(
    model_name: str,
    tracking_uri: Optional[str] = None,
) -> Optional[ModelVersion]:
    """
    Look up the model version currently in the Production stage.

    Args:
        model_name: Name of the registered model
        tracking_uri: Override default tracking URI

    Returns:
        The Production ModelVersion, or None if none exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)
    prod_versions = client.get_latest_versions(model_name, stages=["Production"])
    if not prod_versions:
        return None
    return prod_versions[0]
|
||||
|
||||
|
||||
def get_model_kserve_config(
    model_name: str,
    version: Optional[int] = None,
    tracking_uri: Optional[str] = None,
) -> KServeConfig:
    """
    Reconstruct a KServeConfig from a registered model version's tags.

    Missing tags fall back to the KServeConfig field defaults, so configs
    registered before a tag was introduced still load.

    BUGFIX: previously ignored the cpu/memory *limit*, gpu_type,
    container_port and timeout tags entirely, so reconstructed configs
    always used the defaults for those fields even when the tags were set.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        tracking_uri: Override default tracking URI

    Returns:
        KServeConfig populated from model tags

    Raises:
        ValueError: if no version is given and no Production version exists.
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Resolve the version: explicit number, or the current Production one.
    if version:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version

    tags = model_version.tags

    return KServeConfig(
        runtime=tags.get("kserve.runtime", "kserve-huggingface"),
        protocol=tags.get("kserve.protocol", "v2"),
        cpu_request=tags.get("kserve.cpu_request", "1"),
        cpu_limit=tags.get("kserve.cpu_limit", "4"),
        memory_request=tags.get("kserve.memory_request", "4Gi"),
        memory_limit=tags.get("kserve.memory_limit", "16Gi"),
        gpu_count=int(tags.get("kserve.gpu_count", "0")),
        gpu_type=tags.get("kserve.gpu_type", "nvidia.com/gpu"),
        container_port=int(tags.get("kserve.container_port", "8080")),
        timeout_seconds=int(tags.get("kserve.timeout_seconds", "300")),
        min_replicas=int(tags.get("kserve.min_replicas", "1")),
        max_replicas=int(tags.get("kserve.max_replicas", "1")),
        # Empty-string tags map back to None (as_dict stores None as "").
        storage_uri=tags.get("kserve.storage_uri") or None,
        container_image=tags.get("kserve.container_image") or None,
    )
|
||||
|
||||
|
||||
def _kserve_resources(config: KServeConfig) -> Dict[str, Any]:
    """Build the Kubernetes resources section (requests/limits, incl. GPUs)."""
    resources: Dict[str, Any] = {
        "requests": {
            "cpu": config.cpu_request,
            "memory": config.memory_request,
        },
        "limits": {
            "cpu": config.cpu_limit,
            "memory": config.memory_limit,
        },
    }
    # GPU resources must appear in both requests and limits for Kubernetes
    # extended resources.
    if config.gpu_count > 0:
        gpu = str(config.gpu_count)
        resources["requests"][config.gpu_type] = gpu
        resources["limits"][config.gpu_type] = gpu
    return resources


def generate_kserve_manifest(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    service_name: Optional[str] = None,
    extra_annotations: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate a KServe InferenceService manifest from a registered model.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace for deployment
        service_name: Name for the InferenceService (defaults to model_name)
        extra_annotations: Additional annotations for the service
        tracking_uri: Override default tracking URI

    Returns:
        KServe InferenceService manifest as a dictionary

    Raises:
        ValueError: if no version is given and no Production version exists.
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Resolve the model version: explicit number, or current Production.
    if version:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version
        version = int(model_version.version)

    # Deployment configuration stored in the model version's tags.
    config = get_model_kserve_config(model_name, version, tracking_uri)
    model_type = model_version.tags.get("model.type", "custom")

    # Kubernetes object names must be lowercase DNS-1123 labels.
    svc_name = service_name or model_name.lower().replace("_", "-")

    manifest: Dict[str, Any] = {
        "apiVersion": "serving.kserve.io/v1beta1",
        "kind": "InferenceService",
        "metadata": {
            "name": svc_name,
            "namespace": namespace,
            "labels": {
                "mlflow.model": model_name,
                "mlflow.version": str(version),
                "model.type": model_type,
            },
            "annotations": {
                # BUGFIX: previously called get_mlflow_client() WITHOUT the
                # tracking_uri override, so the annotation could record the
                # default server instead of the one actually queried.
                "mlflow.tracking_uri": client.tracking_uri,
                "mlflow.run_id": model_version.run_id or "",
                **(extra_annotations or {}),
            },
        },
        "spec": {
            "predictor": {
                "minReplicas": config.min_replicas,
                "maxReplicas": config.max_replicas,
                "scaleTarget": config.scale_target,
                "timeout": config.timeout_seconds,
            },
        },
    }

    predictor = manifest["spec"]["predictor"]

    if config.container_image:
        # Custom container deployment: explicit image, port and env vars.
        predictor["containers"] = [{
            "name": "predictor",
            "image": config.container_image,
            "ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
            "resources": _kserve_resources(config),
            "env": [
                {"name": k, "value": v}
                for k, v in config.env_vars.items()
            ],
        }]
    else:
        # Standard KServe runtime; fall back to the MLflow artifact source
        # when no explicit storage URI was registered.
        predictor["model"] = {
            "modelFormat": {"name": "huggingface"},
            "protocolVersion": config.protocol,
            "storageUri": config.storage_uri or model_version.source,
            "resources": _kserve_resources(config),
        }

    return manifest
|
||||
|
||||
|
||||
def generate_kserve_yaml(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    output_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> str:
    """
    Generate a KServe InferenceService manifest as YAML.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace
        output_path: If provided, write YAML to this path
        tracking_uri: Override default tracking URI

    Returns:
        YAML string of the manifest
    """
    manifest = generate_kserve_manifest(
        model_name=model_name,
        version=version,
        namespace=namespace,
        tracking_uri=tracking_uri,
    )

    # Keep insertion order (apiVersion/kind/metadata/spec) for readability.
    yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)

    if output_path:
        # BUGFIX: explicit UTF-8 — the previous open() used the platform
        # default encoding, which can corrupt non-ASCII values on Windows.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(yaml_str)
        logger.info(f"Wrote KServe manifest to {output_path}")

    return yaml_str
|
||||
|
||||
|
||||
def list_model_versions(
    model_name: str,
    stages: Optional[List[str]] = None,
    tracking_uri: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    List versions of a registered model as plain dictionaries.

    Args:
        model_name: Name of the registered model
        stages: Filter by stages (None for all). NOTE: when stages are
            given, MLflow's get_latest_versions returns only the *latest*
            version per stage, not every version in those stages.
        tracking_uri: Override default tracking URI

    Returns:
        List of model version info dictionaries
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if stages:
        versions = client.get_latest_versions(model_name, stages=stages)
    else:
        # All versions; list() replaces the previous manual append loop.
        versions = list(client.search_model_versions(f"name='{model_name}'"))

    return [
        {
            "version": mv.version,
            "stage": mv.current_stage,
            "source": mv.source,
            "run_id": mv.run_id,
            "description": mv.description,
            "tags": mv.tags,
            "creation_timestamp": mv.creation_timestamp,
            "last_updated_timestamp": mv.last_updated_timestamp,
        }
        for mv in versions
    ]
|
||||
Reference in New Issue
Block a user