# mlflow/mlflow_utils/model_registry.py
"""
MLflow Model Registry Integration for KServe
Provides utilities for registering trained models in MLflow Model Registry
with metadata needed for deployment to KServe InferenceServices.
This module bridges the gap between Kubeflow training pipelines and
KServe model serving by:
1. Registering models with proper versioning
2. Adding KServe-specific metadata (runtime, protocol, resources)
3. Managing model stage transitions (Staging → Production)
4. Generating KServe InferenceService manifests from registered models
Usage:
from mlflow_utils.model_registry import (
register_model_for_kserve,
promote_model_to_production,
generate_kserve_manifest,
)
# Register a new model version
model_version = register_model_for_kserve(
model_name="whisper-finetuned",
model_uri="s3://models/whisper-v2",
model_type="stt",
kserve_config={
"runtime": "kserve-huggingface",
"container_image": "ghcr.io/my-org/whisper:v2",
}
)
# Generate KServe manifest for deployment
manifest = generate_kserve_manifest(
model_name="whisper-finetuned",
model_version=model_version.version,
)
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import mlflow
import yaml
from mlflow.entities.model_registry import ModelVersion
from .client import get_mlflow_client
logger = logging.getLogger(__name__)
@dataclass
class KServeConfig:
    """Configuration for deploying a model to KServe.

    Captures runtime/container choice, request protocol, compute resources,
    storage location, autoscaling bounds, and serving options.  ``as_dict``
    flattens the configuration into string-valued MLflow tags so the same
    settings can be round-tripped through the Model Registry.
    """

    # Runtime/container configuration
    runtime: str = "kserve-huggingface"  # kserve-huggingface, kserve-custom, etc.
    container_image: Optional[str] = None
    container_port: int = 8080
    # Protocol configuration
    protocol: str = "v2"  # v1, v2, grpc
    # Resource requests/limits
    cpu_request: str = "1"
    cpu_limit: str = "4"
    memory_request: str = "4Gi"
    memory_limit: str = "16Gi"
    gpu_count: int = 0
    gpu_type: str = "nvidia.com/gpu"  # or amd.com/gpu for ROCm
    # Storage configuration
    storage_uri: Optional[str] = None  # s3://, pvc://, gs://
    # Scaling configuration
    min_replicas: int = 1
    max_replicas: int = 1
    scale_target: int = 10  # Target concurrent requests for scaling
    # Serving configuration
    timeout_seconds: int = 300
    batch_size: int = 1
    # Additional environment variables
    env_vars: Dict[str, str] = field(default_factory=dict)

    def as_dict(self) -> Dict[str, Any]:
        """Convert to string-valued MLflow tags.

        All scalar fields are serialized so a config read back from the
        registry is not lossy (previously cpu/memory limits, gpu_type, and
        serving options were dropped).  ``env_vars`` is intentionally
        excluded: MLflow tags are flat strings, not nested mappings.
        """
        return {
            "kserve.runtime": self.runtime,
            "kserve.protocol": self.protocol,
            "kserve.cpu_request": self.cpu_request,
            "kserve.cpu_limit": self.cpu_limit,
            "kserve.memory_request": self.memory_request,
            "kserve.memory_limit": self.memory_limit,
            "kserve.gpu_count": str(self.gpu_count),
            "kserve.gpu_type": self.gpu_type,
            "kserve.container_port": str(self.container_port),
            "kserve.min_replicas": str(self.min_replicas),
            "kserve.max_replicas": str(self.max_replicas),
            "kserve.scale_target": str(self.scale_target),
            "kserve.timeout_seconds": str(self.timeout_seconds),
            "kserve.batch_size": str(self.batch_size),
            "kserve.storage_uri": self.storage_uri or "",
            "kserve.container_image": self.container_image or "",
        }
# Pre-configured KServe configurations for common model types.
# These are starting points tuned per workload class:
#   llm        — GPU-backed, large memory headroom, long timeout for generation
#   stt / tts  — GPU-backed audio models with moderate memory needs
#   embeddings — CPU-only, batched (batch_size=32) for throughput
#   reranker   — CPU-only, low-latency scoring
# Unlisted model types fall back to KServeConfig() defaults in
# register_model_for_kserve.
KSERVE_PRESETS: Dict[str, KServeConfig] = {
    "llm": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="2",
        cpu_limit="8",
        memory_request="16Gi",
        memory_limit="64Gi",
        gpu_count=1,
        timeout_seconds=600,
    ),
    "stt": KServeConfig(
        runtime="kserve-custom",
        cpu_request="2",
        cpu_limit="4",
        memory_request="8Gi",
        memory_limit="16Gi",
        gpu_count=1,
        timeout_seconds=120,
    ),
    "tts": KServeConfig(
        runtime="kserve-custom",
        cpu_request="2",
        cpu_limit="4",
        memory_request="8Gi",
        memory_limit="16Gi",
        gpu_count=1,
        timeout_seconds=60,
    ),
    "embeddings": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="1",
        cpu_limit="4",
        memory_request="4Gi",
        memory_limit="16Gi",
        gpu_count=0,
        timeout_seconds=30,
        batch_size=32,
    ),
    "reranker": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="1",
        cpu_limit="4",
        memory_request="4Gi",
        memory_limit="16Gi",
        gpu_count=0,
        timeout_seconds=30,
    ),
}
def register_model_for_kserve(
    model_name: str,
    model_uri: str,
    model_type: str,
    run_id: Optional[str] = None,
    description: Optional[str] = None,
    kserve_config: Optional[KServeConfig] = None,
    tags: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Register a model in MLflow Model Registry with KServe metadata.

    Args:
        model_name: Name for the registered model
        model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
        model_type: Type of model (llm, stt, tts, embeddings, reranker)
        run_id: Optional MLflow run ID to associate with
        description: Description of the model version
        kserve_config: KServe deployment configuration; defaults to the
            KSERVE_PRESETS entry for ``model_type`` (or a plain KServeConfig
            for unknown types)
        tags: Additional tags for the model version
        tracking_uri: Override default tracking URI

    Returns:
        The created ModelVersion object
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Fall back to the per-type preset; unknown types get plain defaults.
    if kserve_config is None:
        kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())

    # Create the registered model on first use (EAFP: probe, create on miss).
    # NOTE(review): this treats any MlflowException as "does not exist";
    # transient errors would also trigger a create attempt — confirm acceptable.
    try:
        client.get_registered_model(model_name)
    except mlflow.exceptions.MlflowException:
        client.create_registered_model(
            name=model_name,
            description=f"{model_type.upper()} model for KServe deployment",
            tags={
                "model.type": model_type,
                "deployment.target": "kserve",
            },
        )
        logger.info("Created registered model: %s", model_name)

    # Create the model version; caller tags are merged first so KServe
    # metadata and model.type cannot be accidentally overridden.
    model_version = client.create_model_version(
        name=model_name,
        source=model_uri,
        run_id=run_id,
        description=description or f"Version from {model_uri}",
        tags={
            **(tags or {}),
            "model.type": model_type,
            **kserve_config.as_dict(),
        },
    )
    logger.info(
        "Registered model version %s for %s (type: %s)",
        model_version.version,
        model_name,
        model_type,
    )
    return model_version
def promote_model_to_stage(
    model_name: str,
    version: int,
    stage: str = "Staging",
    archive_existing: bool = True,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Promote a model version to a new stage.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        stage: Target stage (Staging, Production, Archived)
        archive_existing: If True, archive existing versions in target stage
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    registry = get_mlflow_client(tracking_uri=tracking_uri)

    # Perform the stage transition; MLflow expects the version as a string.
    promoted = registry.transition_model_version_stage(
        name=model_name,
        version=str(version),
        stage=stage,
        archive_existing_versions=archive_existing,
    )
    logger.info(f"Promoted {model_name} v{version} to {stage}")
    return promoted
def promote_model_to_production(
    model_name: str,
    version: int,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Promote a model version directly to Production.

    Convenience wrapper around promote_model_to_stage that targets the
    Production stage and archives any version currently in it.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    return promote_model_to_stage(
        model_name=model_name,
        version=version,
        stage="Production",
        archive_existing=True,
        tracking_uri=tracking_uri,
    )
def get_production_model(
    model_name: str,
    tracking_uri: Optional[str] = None,
) -> Optional[ModelVersion]:
    """
    Get the current Production model version.

    Args:
        model_name: Name of the registered model
        tracking_uri: Override default tracking URI

    Returns:
        The Production ModelVersion, or None if none exists
    """
    registry = get_mlflow_client(tracking_uri=tracking_uri)
    matches = registry.get_latest_versions(model_name, stages=["Production"])
    if not matches:
        return None
    return matches[0]
def get_model_kserve_config(
    model_name: str,
    version: Optional[int] = None,
    tracking_uri: Optional[str] = None,
) -> KServeConfig:
    """
    Get KServe configuration from a registered model version.

    Reads back every ``kserve.*`` tag that register_model_for_kserve writes;
    any tag that is missing falls back to the KServeConfig default, so
    versions registered by older code still load correctly.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        tracking_uri: Override default tracking URI

    Returns:
        KServeConfig populated from model tags

    Raises:
        ValueError: If no version is given and no Production version exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)
    # `is not None` rather than truthiness, so an explicit 0 is not silently
    # treated as "use Production" (MLflow versions start at 1, but be safe).
    if version is not None:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version

    tags = model_version.tags
    defaults = KServeConfig()  # single source of truth for fallback values
    return KServeConfig(
        runtime=tags.get("kserve.runtime", defaults.runtime),
        protocol=tags.get("kserve.protocol", defaults.protocol),
        cpu_request=tags.get("kserve.cpu_request", defaults.cpu_request),
        cpu_limit=tags.get("kserve.cpu_limit", defaults.cpu_limit),
        memory_request=tags.get("kserve.memory_request", defaults.memory_request),
        memory_limit=tags.get("kserve.memory_limit", defaults.memory_limit),
        gpu_count=int(tags.get("kserve.gpu_count", defaults.gpu_count)),
        gpu_type=tags.get("kserve.gpu_type", defaults.gpu_type),
        container_port=int(tags.get("kserve.container_port", defaults.container_port)),
        min_replicas=int(tags.get("kserve.min_replicas", defaults.min_replicas)),
        max_replicas=int(tags.get("kserve.max_replicas", defaults.max_replicas)),
        scale_target=int(tags.get("kserve.scale_target", defaults.scale_target)),
        timeout_seconds=int(tags.get("kserve.timeout_seconds", defaults.timeout_seconds)),
        batch_size=int(tags.get("kserve.batch_size", defaults.batch_size)),
        storage_uri=tags.get("kserve.storage_uri") or None,
        container_image=tags.get("kserve.container_image") or None,
    )
def _build_resource_spec(config: KServeConfig) -> Dict[str, Any]:
    """Build a Kubernetes resources block (requests/limits, incl. GPUs) from config."""
    resources: Dict[str, Any] = {
        "requests": {
            "cpu": config.cpu_request,
            "memory": config.memory_request,
        },
        "limits": {
            "cpu": config.cpu_limit,
            "memory": config.memory_limit,
        },
    }
    if config.gpu_count > 0:
        gpu = str(config.gpu_count)
        resources["limits"][config.gpu_type] = gpu
        resources["requests"][config.gpu_type] = gpu
    return resources


def generate_kserve_manifest(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    service_name: Optional[str] = None,
    extra_annotations: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate a KServe InferenceService manifest from a registered model.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace for deployment
        service_name: Name for the InferenceService (defaults to model_name)
        extra_annotations: Additional annotations for the service
        tracking_uri: Override default tracking URI

    Returns:
        KServe InferenceService manifest as a dictionary

    Raises:
        ValueError: If no version is given and no Production version exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Resolve the model version: explicit version wins, else current Production.
    if version is not None:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version
        version = int(model_version.version)

    config = get_model_kserve_config(model_name, version, tracking_uri)
    model_type = model_version.tags.get("model.type", "custom")
    # Kubernetes object names must be DNS-1123 compliant: lowercase, no underscores.
    svc_name = service_name or model_name.lower().replace("_", "-")

    manifest: Dict[str, Any] = {
        "apiVersion": "serving.kserve.io/v1beta1",
        "kind": "InferenceService",
        "metadata": {
            "name": svc_name,
            "namespace": namespace,
            "labels": {
                "mlflow.model": model_name,
                "mlflow.version": str(version),
                "model.type": model_type,
            },
            "annotations": {
                # BUG FIX: reuse the client created with the caller's
                # tracking_uri; previously this built a fresh default client,
                # silently ignoring the tracking_uri override.
                "mlflow.tracking_uri": client.tracking_uri,
                "mlflow.run_id": model_version.run_id or "",
                **(extra_annotations or {}),
            },
        },
        "spec": {
            "predictor": {
                "minReplicas": config.min_replicas,
                "maxReplicas": config.max_replicas,
                "scaleTarget": config.scale_target,
                "timeout": config.timeout_seconds,
            },
        },
    }

    predictor = manifest["spec"]["predictor"]
    if config.container_image:
        # Custom container deployment: run the provided image directly.
        predictor["containers"] = [
            {
                "name": "predictor",
                "image": config.container_image,
                "ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
                "resources": _build_resource_spec(config),
                "env": [{"name": k, "value": v} for k, v in config.env_vars.items()],
            }
        ]
    else:
        # Standard KServe runtime backed by a storage URI; fall back to the
        # model version's source artifact location.
        predictor["model"] = {
            "modelFormat": {"name": "huggingface"},
            "protocolVersion": config.protocol,
            "storageUri": config.storage_uri or model_version.source,
            "resources": _build_resource_spec(config),
        }

    return manifest
def generate_kserve_yaml(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    output_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> str:
    """
    Generate a KServe InferenceService manifest as YAML.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace
        output_path: If provided, write YAML to this path
        tracking_uri: Override default tracking URI

    Returns:
        YAML string of the manifest
    """
    manifest = generate_kserve_manifest(
        model_name=model_name,
        version=version,
        namespace=namespace,
        tracking_uri=tracking_uri,
    )
    # Preserve insertion order so the manifest reads top-down like hand-written YAML.
    yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)
    if output_path:
        # Explicit encoding: the default is platform-dependent, and manifests
        # may contain non-ASCII annotation values.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(yaml_str)
        logger.info("Wrote KServe manifest to %s", output_path)
    return yaml_str
def list_model_versions(
    model_name: str,
    stages: Optional[List[str]] = None,
    tracking_uri: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    List all versions of a registered model.

    Args:
        model_name: Name of the registered model
        stages: Filter by stages (None for all). Note: when stages are given,
            only the *latest* version per stage is returned (MLflow semantics).
        tracking_uri: Override default tracking URI

    Returns:
        List of model version info dictionaries
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)
    if stages:
        versions = client.get_latest_versions(model_name, stages=stages)
    else:
        # All versions, not just the latest per stage.
        versions = list(client.search_model_versions(f"name='{model_name}'"))
    return [
        {
            "version": mv.version,
            "stage": mv.current_stage,
            "source": mv.source,
            "run_id": mv.run_id,
            "description": mv.description,
            "tags": mv.tags,
            "creation_timestamp": mv.creation_timestamp,
            "last_updated_timestamp": mv.last_updated_timestamp,
        }
        for mv in versions
    ]