""" MLflow Model Registry Integration for KServe Provides utilities for registering trained models in MLflow Model Registry with metadata needed for deployment to KServe InferenceServices. This module bridges the gap between Kubeflow training pipelines and KServe model serving by: 1. Registering models with proper versioning 2. Adding KServe-specific metadata (runtime, protocol, resources) 3. Managing model stage transitions (Staging → Production) 4. Generating KServe InferenceService manifests from registered models Usage: from mlflow_utils.model_registry import ( register_model_for_kserve, promote_model_to_production, generate_kserve_manifest, ) # Register a new model version model_version = register_model_for_kserve( model_name="whisper-finetuned", model_uri="s3://models/whisper-v2", model_type="stt", kserve_config={ "runtime": "kserve-huggingface", "container_image": "ghcr.io/my-org/whisper:v2", } ) # Generate KServe manifest for deployment manifest = generate_kserve_manifest( model_name="whisper-finetuned", model_version=model_version.version, ) """ import logging from dataclasses import dataclass, field from typing import Any, Dict, List, Optional import mlflow import yaml from mlflow.entities.model_registry import ModelVersion from .client import get_mlflow_client logger = logging.getLogger(__name__) @dataclass class KServeConfig: """Configuration for KServe deployment.""" # Runtime/container configuration runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc. container_image: Optional[str] = None container_port: int = 8080 # Protocol configuration protocol: str = "v2" # v1, v2, grpc # Resource requests/limits cpu_request: str = "1" cpu_limit: str = "4" memory_request: str = "4Gi" memory_limit: str = "16Gi" gpu_count: int = 0 gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm # Storage configuration storage_uri: Optional[str] = None # s3://, pvc://, gs:// # Scaling configuration min_replicas: int = 1 max_replicas: int = 1 scale_target: int = 10 # Target concurrent requests for scaling # Serving configuration timeout_seconds: int = 300 batch_size: int = 1 # Additional environment variables env_vars: Dict[str, str] = field(default_factory=dict) def as_dict(self) -> Dict[str, Any]: """Convert to dictionary for MLflow tags.""" return { "kserve.runtime": self.runtime, "kserve.protocol": self.protocol, "kserve.cpu_request": self.cpu_request, "kserve.memory_request": self.memory_request, "kserve.gpu_count": str(self.gpu_count), "kserve.min_replicas": str(self.min_replicas), "kserve.max_replicas": str(self.max_replicas), "kserve.storage_uri": self.storage_uri or "", "kserve.container_image": self.container_image or "", } # Pre-configured KServe configurations for common model types KSERVE_PRESETS: Dict[str, KServeConfig] = { "llm": KServeConfig( runtime="kserve-huggingface", cpu_request="2", cpu_limit="8", memory_request="16Gi", memory_limit="64Gi", gpu_count=1, timeout_seconds=600, ), "stt": KServeConfig( runtime="kserve-custom", cpu_request="2", cpu_limit="4", memory_request="8Gi", memory_limit="16Gi", gpu_count=1, timeout_seconds=120, ), "tts": KServeConfig( runtime="kserve-custom", cpu_request="2", cpu_limit="4", memory_request="8Gi", memory_limit="16Gi", gpu_count=1, timeout_seconds=60, ), "embeddings": KServeConfig( runtime="kserve-huggingface", cpu_request="1", cpu_limit="4", memory_request="4Gi", memory_limit="16Gi", gpu_count=0, timeout_seconds=30, batch_size=32, ), "reranker": KServeConfig( runtime="kserve-huggingface", cpu_request="1", 
cpu_limit="4", memory_request="4Gi", memory_limit="16Gi", gpu_count=0, timeout_seconds=30, ), } def register_model_for_kserve( model_name: str, model_uri: str, model_type: str, run_id: Optional[str] = None, description: Optional[str] = None, kserve_config: Optional[KServeConfig] = None, tags: Optional[Dict[str, str]] = None, tracking_uri: Optional[str] = None, ) -> ModelVersion: """ Register a model in MLflow Model Registry with KServe metadata. Args: model_name: Name for the registered model model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://) model_type: Type of model (llm, stt, tts, embeddings, reranker) run_id: Optional MLflow run ID to associate with description: Description of the model version kserve_config: KServe deployment configuration tags: Additional tags for the model version tracking_uri: Override default tracking URI Returns: The created ModelVersion object """ client = get_mlflow_client(tracking_uri=tracking_uri) # Get or use preset KServe config if kserve_config is None: kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig()) # Ensure registered model exists try: client.get_registered_model(model_name) except mlflow.exceptions.MlflowException: client.create_registered_model( name=model_name, description=f"{model_type.upper()} model for KServe deployment", tags={ "model.type": model_type, "deployment.target": "kserve", } ) logger.info(f"Created registered model: {model_name}") # Create model version model_version = client.create_model_version( name=model_name, source=model_uri, run_id=run_id, description=description or f"Version from {model_uri}", tags={ **(tags or {}), "model.type": model_type, **kserve_config.as_dict(), } ) logger.info( f"Registered model version {model_version.version} " f"for {model_name} (type: {model_type})" ) return model_version def promote_model_to_stage( model_name: str, version: int, stage: str = "Staging", archive_existing: bool = True, tracking_uri: Optional[str] = None, ) -> ModelVersion: """ Promote a model version to a new stage. Args: model_name: Name of the registered model version: Version number to promote stage: Target stage (Staging, Production, Archived) archive_existing: If True, archive existing versions in target stage tracking_uri: Override default tracking URI Returns: The updated ModelVersion """ client = get_mlflow_client(tracking_uri=tracking_uri) # Transition to new stage model_version = client.transition_model_version_stage( name=model_name, version=str(version), stage=stage, archive_existing_versions=archive_existing, ) logger.info(f"Promoted {model_name} v{version} to {stage}") return model_version def promote_model_to_production( model_name: str, version: int, tracking_uri: Optional[str] = None, ) -> ModelVersion: """ Promote a model version directly to Production. Args: model_name: Name of the registered model version: Version number to promote tracking_uri: Override default tracking URI Returns: The updated ModelVersion """ return promote_model_to_stage( model_name=model_name, version=version, stage="Production", archive_existing=True, tracking_uri=tracking_uri, ) def get_production_model( model_name: str, tracking_uri: Optional[str] = None, ) -> Optional[ModelVersion]: """ Get the current Production model version. 

def get_production_model(
    model_name: str,
    tracking_uri: Optional[str] = None,
) -> Optional[ModelVersion]:
    """
    Get the current Production model version.

    Args:
        model_name: Name of the registered model
        tracking_uri: Override default tracking URI

    Returns:
        The Production ModelVersion, or None if none exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)
    versions = client.get_latest_versions(model_name, stages=["Production"])
    return versions[0] if versions else None


def get_model_kserve_config(
    model_name: str,
    version: Optional[int] = None,
    tracking_uri: Optional[str] = None,
) -> KServeConfig:
    """
    Get KServe configuration from a registered model version.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        tracking_uri: Override default tracking URI

    Returns:
        KServeConfig populated from model tags
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if version:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version

    tags = model_version.tags
    return KServeConfig(
        runtime=tags.get("kserve.runtime", "kserve-huggingface"),
        protocol=tags.get("kserve.protocol", "v2"),
        cpu_request=tags.get("kserve.cpu_request", "1"),
        cpu_limit=tags.get("kserve.cpu_limit", "4"),
        memory_request=tags.get("kserve.memory_request", "4Gi"),
        memory_limit=tags.get("kserve.memory_limit", "16Gi"),
        gpu_count=int(tags.get("kserve.gpu_count", "0")),
        min_replicas=int(tags.get("kserve.min_replicas", "1")),
        max_replicas=int(tags.get("kserve.max_replicas", "1")),
        storage_uri=tags.get("kserve.storage_uri") or None,
        container_image=tags.get("kserve.container_image") or None,
    )
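
# Round-trip sketch: the tags written by KServeConfig.as_dict() at
# registration time are exactly what get_model_kserve_config() reads back,
# so the config below should mirror what was registered (the model name is
# illustrative and assumes a Production version exists):
#
#     cfg = get_model_kserve_config("whisper-finetuned")
#     print(f"{cfg.runtime}: {cfg.cpu_request}/{cfg.cpu_limit} CPU, "
#           f"{cfg.memory_request}/{cfg.memory_limit} memory")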

def generate_kserve_manifest(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    service_name: Optional[str] = None,
    extra_annotations: Optional[Dict[str, str]] = None,
    tracking_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate a KServe InferenceService manifest from a registered model.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace for deployment
        service_name: Name for the InferenceService (defaults to model_name)
        extra_annotations: Additional annotations for the service
        tracking_uri: Override default tracking URI

    Returns:
        KServe InferenceService manifest as a dictionary
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get model version (explicit version, or current Production)
    if version:
        model_version = client.get_model_version(model_name, str(version))
    else:
        prod_version = get_production_model(model_name, tracking_uri)
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version
        version = int(model_version.version)

    # Get KServe config from the version's tags
    config = get_model_kserve_config(model_name, version, tracking_uri)
    model_type = model_version.tags.get("model.type", "custom")
    svc_name = service_name or model_name.lower().replace("_", "-")

    # Build manifest
    manifest = {
        "apiVersion": "serving.kserve.io/v1beta1",
        "kind": "InferenceService",
        "metadata": {
            "name": svc_name,
            "namespace": namespace,
            "labels": {
                "mlflow.model": model_name,
                "mlflow.version": str(version),
                "model.type": model_type,
            },
            "annotations": {
                "mlflow.tracking_uri": client.tracking_uri,
                "mlflow.run_id": model_version.run_id or "",
                **(extra_annotations or {}),
            },
        },
        "spec": {
            "predictor": {
                "minReplicas": config.min_replicas,
                "maxReplicas": config.max_replicas,
                "scaleTarget": config.scale_target,
                "timeout": config.timeout_seconds,
            },
        },
    }

    # Configure predictor based on runtime
    predictor = manifest["spec"]["predictor"]
    if config.container_image:
        # Custom container
        predictor["containers"] = [{
            "name": "predictor",
            "image": config.container_image,
            "ports": [{"containerPort": config.container_port, "protocol": "TCP"}],
            "resources": {
                "requests": {
                    "cpu": config.cpu_request,
                    "memory": config.memory_request,
                },
                "limits": {
                    "cpu": config.cpu_limit,
                    "memory": config.memory_limit,
                },
            },
            "env": [
                {"name": k, "value": v} for k, v in config.env_vars.items()
            ],
        }]
        # Add GPU if needed
        if config.gpu_count > 0:
            resources = predictor["containers"][0]["resources"]
            resources["limits"][config.gpu_type] = str(config.gpu_count)
            resources["requests"][config.gpu_type] = str(config.gpu_count)
    else:
        # Standard KServe runtime
        storage_uri = config.storage_uri or model_version.source
        predictor["model"] = {
            "modelFormat": {"name": "huggingface"},
            "protocolVersion": config.protocol,
            "storageUri": storage_uri,
            "resources": {
                "requests": {
                    "cpu": config.cpu_request,
                    "memory": config.memory_request,
                },
                "limits": {
                    "cpu": config.cpu_limit,
                    "memory": config.memory_limit,
                },
            },
        }
        if config.gpu_count > 0:
            resources = predictor["model"]["resources"]
            resources["limits"][config.gpu_type] = str(config.gpu_count)
            resources["requests"][config.gpu_type] = str(config.gpu_count)

    return manifest
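
# Shape of the manifest this produces for a registry-backed model with no
# custom container_image (so the "model" branch above is taken); values are
# illustrative and depend entirely on the registered tags, and labels and
# annotations are omitted for brevity:
#
#     apiVersion: serving.kserve.io/v1beta1
#     kind: InferenceService
#     metadata:
#       name: whisper-finetuned
#       namespace: ai-ml
#     spec:
#       predictor:
#         minReplicas: 1
#         maxReplicas: 1
#         model:
#           modelFormat: {name: huggingface}
#           protocolVersion: v2
#           storageUri: s3://models/whisper-v2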

def generate_kserve_yaml(
    model_name: str,
    version: Optional[int] = None,
    namespace: str = "ai-ml",
    output_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> str:
    """
    Generate a KServe InferenceService manifest as YAML.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace
        output_path: If provided, write YAML to this path
        tracking_uri: Override default tracking URI

    Returns:
        YAML string of the manifest
    """
    manifest = generate_kserve_manifest(
        model_name=model_name,
        version=version,
        namespace=namespace,
        tracking_uri=tracking_uri,
    )
    yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)

    if output_path:
        with open(output_path, "w") as f:
            f.write(yaml_str)
        logger.info(f"Wrote KServe manifest to {output_path}")

    return yaml_str


def list_model_versions(
    model_name: str,
    stages: Optional[List[str]] = None,
    tracking_uri: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    List versions of a registered model.

    Args:
        model_name: Name of the registered model
        stages: If given, return only the latest version in each of these
            stages; if None, return all versions
        tracking_uri: Override default tracking URI

    Returns:
        List of model version info dictionaries
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if stages:
        # Latest version per requested stage
        versions = client.get_latest_versions(model_name, stages=stages)
    else:
        # All versions of the model
        versions = list(client.search_model_versions(f"name='{model_name}'"))

    return [
        {
            "version": mv.version,
            "stage": mv.current_stage,
            "source": mv.source,
            "run_id": mv.run_id,
            "description": mv.description,
            "tags": mv.tags,
            "creation_timestamp": mv.creation_timestamp,
            "last_updated_timestamp": mv.last_updated_timestamp,
        }
        for mv in versions
    ]
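
# End-to-end sketch (paths are illustrative; assumes a Production version of
# the model in the registry and kubectl access to the target cluster):
#
#     yaml_str = generate_kserve_yaml(
#         "whisper-finetuned",
#         namespace="ai-ml",
#         output_path="/tmp/whisper-isvc.yaml",
#     )
#     # Then, outside this module:
#     #   kubectl apply -f /tmp/whisper-isvc.yaml
#
#     for info in list_model_versions("whisper-finetuned"):
#         print(info["version"], info["stage"], info["source"])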