"""
|
|
MLflow Model Registry Integration for KServe
|
|
|
|
Provides utilities for registering trained models in MLflow Model Registry
|
|
with metadata needed for deployment to KServe InferenceServices.
|
|
|
|
This module bridges the gap between Kubeflow training pipelines and
|
|
KServe model serving by:
|
|
1. Registering models with proper versioning
|
|
2. Adding KServe-specific metadata (runtime, protocol, resources)
|
|
3. Managing model stage transitions (Staging → Production)
|
|
4. Generating KServe InferenceService manifests from registered models
|
|
|
|
Usage:
|
|
from mlflow_utils.model_registry import (
|
|
register_model_for_kserve,
|
|
promote_model_to_production,
|
|
generate_kserve_manifest,
|
|
)
|
|
|
|
# Register a new model version
|
|
model_version = register_model_for_kserve(
|
|
model_name="whisper-finetuned",
|
|
model_uri="s3://models/whisper-v2",
|
|
model_type="stt",
|
|
kserve_config={
|
|
"runtime": "kserve-huggingface",
|
|
"container_image": "ghcr.io/my-org/whisper:v2",
|
|
}
|
|
)
|
|
|
|
# Generate KServe manifest for deployment
|
|
manifest = generate_kserve_manifest(
|
|
model_name="whisper-finetuned",
|
|
model_version=model_version.version,
|
|
)
|
|
"""
|
|
|
|
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import mlflow
import yaml
from mlflow.entities.model_registry import ModelVersion

from .client import get_mlflow_client

logger = logging.getLogger(__name__)


@dataclass
class KServeConfig:
    """Configuration for KServe deployment."""

    # Runtime/container configuration
    runtime: str = "kserve-huggingface"  # kserve-huggingface, kserve-custom, etc.
    container_image: Optional[str] = None
    container_port: int = 8080

    # Protocol configuration
    protocol: str = "v2"  # v1, v2, grpc

    # Resource requests/limits
    cpu_request: str = "1"
    cpu_limit: str = "4"
    memory_request: str = "4Gi"
    memory_limit: str = "16Gi"
    gpu_count: int = 0
    gpu_type: str = "nvidia.com/gpu"  # or amd.com/gpu for ROCm

    # Storage configuration
    storage_uri: Optional[str] = None  # s3://, pvc://, gs://

    # Scaling configuration
    min_replicas: int = 1
    max_replicas: int = 1
    scale_target: int = 10  # Target concurrent requests for scaling

    # Serving configuration
    timeout_seconds: int = 300
    batch_size: int = 1

    # Additional environment variables
    env_vars: Dict[str, str] = field(default_factory=dict)

    def as_dict(self) -> Dict[str, Any]:
        """Convert to a flat dictionary of MLflow tags.

        Emits every field that get_model_kserve_config() reads back, so a
        config survives the register/deploy round trip intact.
        """
        return {
            "kserve.runtime": self.runtime,
            "kserve.protocol": self.protocol,
            "kserve.cpu_request": self.cpu_request,
            "kserve.cpu_limit": self.cpu_limit,
            "kserve.memory_request": self.memory_request,
            "kserve.memory_limit": self.memory_limit,
            "kserve.gpu_count": str(self.gpu_count),
            "kserve.gpu_type": self.gpu_type,
            "kserve.min_replicas": str(self.min_replicas),
            "kserve.max_replicas": str(self.max_replicas),
            "kserve.scale_target": str(self.scale_target),
            "kserve.timeout_seconds": str(self.timeout_seconds),
            "kserve.storage_uri": self.storage_uri or "",
            "kserve.container_image": self.container_image or "",
        }


# Pre-configured KServe configurations for common model types
KSERVE_PRESETS: Dict[str, KServeConfig] = {
    "llm": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="2",
        cpu_limit="8",
        memory_request="16Gi",
        memory_limit="64Gi",
        gpu_count=1,
        timeout_seconds=600,
    ),
    "stt": KServeConfig(
        runtime="kserve-custom",
        cpu_request="2",
        cpu_limit="4",
        memory_request="8Gi",
        memory_limit="16Gi",
        gpu_count=1,
        timeout_seconds=120,
    ),
    "tts": KServeConfig(
        runtime="kserve-custom",
        cpu_request="2",
        cpu_limit="4",
        memory_request="8Gi",
        memory_limit="16Gi",
        gpu_count=1,
        timeout_seconds=60,
    ),
    "embeddings": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="1",
        cpu_limit="4",
        memory_request="4Gi",
        memory_limit="16Gi",
        gpu_count=0,
        timeout_seconds=30,
        batch_size=32,
    ),
    "reranker": KServeConfig(
        runtime="kserve-huggingface",
        cpu_request="1",
        cpu_limit="4",
        memory_request="4Gi",
        memory_limit="16Gi",
        gpu_count=0,
        timeout_seconds=30,
    ),
}

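# Presets are plain dataclasses, so callers can start from one and override
# individual fields per deployment rather than building a KServeConfig from
# scratch. A minimal sketch (field values are illustrative):
#
#     import dataclasses
#
#     big_llm = dataclasses.replace(KSERVE_PRESETS["llm"], gpu_count=2, max_replicas=4)

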
def register_model_for_kserve(
    model_name: str,
    model_uri: str,
    model_type: str,
    run_id: Optional[str] = None,
    description: Optional[str] = None,
    kserve_config: Optional[KServeConfig] = None,
    tags: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Register a model in MLflow Model Registry with KServe metadata.

    Args:
        model_name: Name for the registered model
        model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
        model_type: Type of model (llm, stt, tts, embeddings, reranker)
        run_id: Optional MLflow run ID to associate with
        description: Description of the model version
        kserve_config: KServe deployment configuration
        tags: Additional tags for the model version
        tracking_uri: Override default tracking URI

    Returns:
        The created ModelVersion object
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get or use preset KServe config
    if kserve_config is None:
        kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())

    # Ensure registered model exists
    try:
        client.get_registered_model(model_name)
    except mlflow.exceptions.MlflowException:
        client.create_registered_model(
            name=model_name,
            description=f"{model_type.upper()} model for KServe deployment",
            tags={
                "model.type": model_type,
                "deployment.target": "kserve",
            },
        )
        logger.info(f"Created registered model: {model_name}")

    # Create model version
    model_version = client.create_model_version(
        name=model_name,
        source=model_uri,
        run_id=run_id,
        description=description or f"Version from {model_uri}",
        tags={
            **(tags or {}),
            "model.type": model_type,
            **kserve_config.as_dict(),
        },
    )

    logger.info(f"Registered model version {model_version.version} for {model_name} (type: {model_type})")

    return model_version

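# A typical call site inside a training pipeline registers whatever the active
# run just logged. A minimal sketch (model and artifact names are illustrative;
# assumes the model was logged under the "model" artifact path):
#
#     with mlflow.start_run() as run:
#         ...  # training + mlflow logging
#         register_model_for_kserve(
#             model_name="qwen-finetuned",
#             model_uri=f"runs:/{run.info.run_id}/model",
#             model_type="llm",
#             run_id=run.info.run_id,
#         )

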
def promote_model_to_stage(
    model_name: str,
    version: int,
    stage: str = "Staging",
    archive_existing: bool = True,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Promote a model version to a new stage.

    Note: newer MLflow releases deprecate stage-based workflows in favor of
    model version aliases; this helper targets registries still using stages.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        stage: Target stage (Staging, Production, Archived)
        archive_existing: If True, archive existing versions in target stage
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Transition to new stage
    model_version = client.transition_model_version_stage(
        name=model_name,
        version=str(version),
        stage=stage,
        archive_existing_versions=archive_existing,
    )

    logger.info(f"Promoted {model_name} v{version} to {stage}")

    return model_version

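# On MLflow releases that have moved to aliases, the equivalent promotion is a
# sketch like the following (the alias name is a convention, not an API
# requirement):
#
#     client = get_mlflow_client()
#     client.set_registered_model_alias("whisper-finetuned", "production", version=3)
#     client.get_model_version_by_alias("whisper-finetuned", "production")

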
def promote_model_to_production(
    model_name: str,
    version: int,
    tracking_uri: Optional[str] = None,
) -> ModelVersion:
    """
    Promote a model version directly to Production.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    return promote_model_to_stage(
        model_name=model_name,
        version=version,
        stage="Production",
        archive_existing=True,
        tracking_uri=tracking_uri,
    )


def get_production_model(
    model_name: str,
    tracking_uri: Optional[str] = None,
) -> Optional[ModelVersion]:
    """
    Get the current Production model version.

    Args:
        model_name: Name of the registered model
        tracking_uri: Override default tracking URI

    Returns:
        The Production ModelVersion, or None if none exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    versions = client.get_latest_versions(model_name, stages=["Production"])

    return versions[0] if versions else None


def get_model_kserve_config(
    model_name: str,
    version: Optional[int] = None,
    tracking_uri: Optional[str] = None,
) -> KServeConfig:
    """
    Get KServe configuration from a registered model version.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        tracking_uri: Override default tracking URI

    Returns:
        KServeConfig populated from model tags
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if version is not None:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version

    tags = model_version.tags
    defaults = KServeConfig()

    return KServeConfig(
        runtime=tags.get("kserve.runtime", defaults.runtime),
        protocol=tags.get("kserve.protocol", defaults.protocol),
        cpu_request=tags.get("kserve.cpu_request", defaults.cpu_request),
        cpu_limit=tags.get("kserve.cpu_limit", defaults.cpu_limit),
        memory_request=tags.get("kserve.memory_request", defaults.memory_request),
        memory_limit=tags.get("kserve.memory_limit", defaults.memory_limit),
        gpu_count=int(tags.get("kserve.gpu_count", defaults.gpu_count)),
        gpu_type=tags.get("kserve.gpu_type", defaults.gpu_type),
        min_replicas=int(tags.get("kserve.min_replicas", defaults.min_replicas)),
        max_replicas=int(tags.get("kserve.max_replicas", defaults.max_replicas)),
        scale_target=int(tags.get("kserve.scale_target", defaults.scale_target)),
        timeout_seconds=int(tags.get("kserve.timeout_seconds", defaults.timeout_seconds)),
        storage_uri=tags.get("kserve.storage_uri") or None,
        container_image=tags.get("kserve.container_image") or None,
    )


def generate_kserve_manifest(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    service_name: Optional[str] = None,
    extra_annotations: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate a KServe InferenceService manifest from a registered model.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace for deployment
        service_name: Name for the InferenceService (defaults to model_name)
        extra_annotations: Additional annotations for the service
        tracking_uri: Override default tracking URI

    Returns:
        KServe InferenceService manifest as a dictionary
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get model version
    if version is not None:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version
        version = int(model_version.version)

    # Get KServe config
    config = get_model_kserve_config(model_name, version, tracking_uri)
    model_type = model_version.tags.get("model.type", "custom")

    svc_name = service_name or model_name.lower().replace("_", "-")

    # Build manifest
    manifest = {
        "apiVersion": "serving.kserve.io/v1beta1",
        "kind": "InferenceService",
        "metadata": {
            "name": svc_name,
            "namespace": namespace,
            "labels": {
                "mlflow.model": model_name,
                "mlflow.version": str(version),
                "model.type": model_type,
            },
            "annotations": {
                # Reuse the client created above so a tracking_uri override
                # is reflected in the annotation.
                "mlflow.tracking_uri": client.tracking_uri,
                "mlflow.run_id": model_version.run_id or "",
                **(extra_annotations or {}),
            },
        },
        "spec": {
            "predictor": {
                "minReplicas": config.min_replicas,
                "maxReplicas": config.max_replicas,
                "scaleTarget": config.scale_target,
                "timeout": config.timeout_seconds,
            },
        },
    }

    # Configure predictor based on runtime
    predictor = manifest["spec"]["predictor"]

    if config.container_image:
        # Custom container
        predictor["containers"] = [
            {
                "name": "predictor",
                "image": config.container_image,
                "ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
                "resources": {
                    "requests": {
                        "cpu": config.cpu_request,
                        "memory": config.memory_request,
                    },
                    "limits": {
                        "cpu": config.cpu_limit,
                        "memory": config.memory_limit,
                    },
                },
                "env": [{"name": k, "value": v} for k, v in config.env_vars.items()],
            }
        ]

        # Add GPU if needed
        if config.gpu_count > 0:
            resources = predictor["containers"][0]["resources"]
            resources["requests"][config.gpu_type] = str(config.gpu_count)
            resources["limits"][config.gpu_type] = str(config.gpu_count)

    else:
        # Standard KServe runtime
        storage_uri = config.storage_uri or model_version.source

        predictor["model"] = {
            "modelFormat": {"name": "huggingface"},
            "protocolVersion": config.protocol,
            "storageUri": storage_uri,
            "resources": {
                "requests": {
                    "cpu": config.cpu_request,
                    "memory": config.memory_request,
                },
                "limits": {
                    "cpu": config.cpu_limit,
                    "memory": config.memory_limit,
                },
            },
        }

        if config.gpu_count > 0:
            resources = predictor["model"]["resources"]
            resources["requests"][config.gpu_type] = str(config.gpu_count)
            resources["limits"][config.gpu_type] = str(config.gpu_count)

    return manifest

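# Applying the manifest is left to the caller. A minimal sketch using the
# official kubernetes client (an assumed extra dependency, not imported by
# this module):
#
#     from kubernetes import client as k8s, config as k8s_config
#
#     k8s_config.load_kube_config()
#     k8s.CustomObjectsApi().create_namespaced_custom_object(
#         group="serving.kserve.io",
#         version="v1beta1",
#         namespace=manifest["metadata"]["namespace"],
#         plural="inferenceservices",
#         body=manifest,
#     )

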
def generate_kserve_yaml(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    output_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> str:
    """
    Generate a KServe InferenceService manifest as YAML.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace
        output_path: If provided, write YAML to this path
        tracking_uri: Override default tracking URI

    Returns:
        YAML string of the manifest
    """
    manifest = generate_kserve_manifest(
        model_name=model_name,
        version=version,
        namespace=namespace,
        tracking_uri=tracking_uri,
    )

    yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)

    if output_path:
        with open(output_path, "w") as f:
            f.write(yaml_str)
        logger.info(f"Wrote KServe manifest to {output_path}")

    return yaml_str

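# The YAML form suits GitOps flows: write the manifest into a deploy repo and
# let kubectl or a cluster reconciler pick it up (paths are illustrative):
#
#     generate_kserve_yaml("whisper-finetuned", output_path="deploy/whisper.yaml")
#     # then: kubectl apply -f deploy/whisper.yaml

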
def list_model_versions(
    model_name: str,
    stages: Optional[List[str]] = None,
    tracking_uri: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    List versions of a registered model.

    Args:
        model_name: Name of the registered model
        stages: If given, return only the latest version in each of these
            stages (get_latest_versions semantics); None lists all versions
        tracking_uri: Override default tracking URI

    Returns:
        List of model version info dictionaries
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if stages:
        versions = client.get_latest_versions(model_name, stages=stages)
    else:
        # Get all versions
        versions = list(client.search_model_versions(f"name='{model_name}'"))

    return [
        {
            "version": mv.version,
            "stage": mv.current_stage,
            "source": mv.source,
            "run_id": mv.run_id,
            "description": mv.description,
            "tags": mv.tags,
            "creation_timestamp": mv.creation_timestamp,
            "last_updated_timestamp": mv.last_updated_timestamp,
        }
        for mv in versions
    ]


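# Manual smoke test (a sketch, not part of the library API): print the
# generated manifest for a model name passed on the command line. Assumes a
# reachable MLflow tracking server and a Production version of that model;
# nothing is applied to any cluster.
if __name__ == "__main__":
    import sys

    print(generate_kserve_yaml(model_name=sys.argv[1]))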