fix: resolve all ruff lint errors
@@ -17,7 +17,7 @@ Usage:
promote_model_to_production,
generate_kserve_manifest,
)

# Register a new model version
model_version = register_model_for_kserve(
model_name="whisper-finetuned",
@@ -28,7 +28,7 @@ Usage:
"container_image": "ghcr.io/my-org/whisper:v2",
}
)

# Generate KServe manifest for deployment
manifest = generate_kserve_manifest(
model_name="whisper-finetuned",
@@ -36,18 +36,15 @@ Usage:
)
"""

-import os
-import json
-import yaml
import logging
-from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional

import mlflow
-from mlflow.tracking import MlflowClient
+import yaml
from mlflow.entities.model_registry import ModelVersion

-from .client import get_mlflow_client, MLflowConfig
+from .client import get_mlflow_client

logger = logging.getLogger(__name__)
@@ -55,15 +52,15 @@ logger = logging.getLogger(__name__)
@dataclass
class KServeConfig:
"""Configuration for KServe deployment."""

# Runtime/container configuration
runtime: str = "kserve-huggingface" # kserve-huggingface, kserve-custom, etc.
container_image: Optional[str] = None
container_port: int = 8080

# Protocol configuration
protocol: str = "v2" # v1, v2, grpc

# Resource requests/limits
cpu_request: str = "1"
cpu_limit: str = "4"
@@ -71,22 +68,22 @@ class KServeConfig:
memory_limit: str = "16Gi"
gpu_count: int = 0
gpu_type: str = "nvidia.com/gpu" # or amd.com/gpu for ROCm

# Storage configuration
storage_uri: Optional[str] = None # s3://, pvc://, gs://

# Scaling configuration
min_replicas: int = 1
max_replicas: int = 1
scale_target: int = 10 # Target concurrent requests for scaling

# Serving configuration
timeout_seconds: int = 300
batch_size: int = 1

# Additional environment variables
env_vars: Dict[str, str] = field(default_factory=dict)

def as_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for MLflow tags."""
return {
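As an aside, a minimal sketch of how this dataclass might be filled in for a GPU-backed deployment. It uses only fields visible in this diff; the values are illustrative, and the exact key names produced by as_dict() are truncated in the hunk above, so they are not assumed here.

    # Hypothetical usage of the KServeConfig shown above; all values are illustrative.
    config = KServeConfig(
        container_image="ghcr.io/my-org/whisper:v2",  # custom-container deployment
        protocol="v2",
        cpu_request="2",
        cpu_limit="4",
        memory_limit="16Gi",
        gpu_count=1,            # the manifest code below adds nvidia.com/gpu when > 0
        min_replicas=1,
        max_replicas=3,
        env_vars={"HF_HUB_OFFLINE": "1"},  # example environment variable
    )
    tags = config.as_dict()     # dictionary of tags stored on the model version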
@@ -165,7 +162,7 @@ def register_model_for_kserve(
) -> ModelVersion:
"""
Register a model in MLflow Model Registry with KServe metadata.

Args:
model_name: Name for the registered model
model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
@@ -175,16 +172,16 @@ def register_model_for_kserve(
kserve_config: KServe deployment configuration
tags: Additional tags for the model version
tracking_uri: Override default tracking URI

Returns:
The created ModelVersion object
"""
client = get_mlflow_client(tracking_uri=tracking_uri)

# Get or use preset KServe config
if kserve_config is None:
kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())

# Ensure registered model exists
try:
client.get_registered_model(model_name)
@@ -198,7 +195,7 @@ def register_model_for_kserve(
}
)
logger.info(f"Created registered model: {model_name}")

# Create model version
model_version = client.create_model_version(
name=model_name,
@@ -211,12 +208,12 @@ def register_model_for_kserve(
**kserve_config.as_dict(),
}
)

logger.info(
f"Registered model version {model_version.version} "
f"for {model_name} (type: {model_type})"
)

return model_version
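For context, a hedged sketch of calling this function along the lines of the module docstring. The import path, run ID, and the model_type preset key are assumptions that this diff does not confirm.

    # Assumed package path; the real module location is not shown in this diff.
    from my_project.registry import register_model_for_kserve

    model_version = register_model_for_kserve(
        model_name="whisper-finetuned",
        model_uri="runs:/<run_id>/model",  # placeholder run ID
        model_type="whisper",              # assumed key in KSERVE_PRESETS
        tags={"container_image": "ghcr.io/my-org/whisper:v2"},
    )
    print(model_version.version)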
@@ -229,19 +226,19 @@ def promote_model_to_stage(
) -> ModelVersion:
"""
Promote a model version to a new stage.

Args:
model_name: Name of the registered model
version: Version number to promote
stage: Target stage (Staging, Production, Archived)
archive_existing: If True, archive existing versions in target stage
tracking_uri: Override default tracking URI

Returns:
The updated ModelVersion
"""
client = get_mlflow_client(tracking_uri=tracking_uri)

# Transition to new stage
model_version = client.transition_model_version_stage(
name=model_name,
@@ -249,9 +246,9 @@ def promote_model_to_stage(
stage=stage,
archive_existing_versions=archive_existing,
)

logger.info(f"Promoted {model_name} v{version} to {stage}")

return model_version
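A short, hedged illustration of the promotion helpers above. Stage names follow the docstring (Staging, Production, Archived); the model name and version are illustrative.

    # Promote to Staging first, then straight to Production.
    promote_model_to_stage(
        model_name="whisper-finetuned",
        version=3,
        stage="Staging",
        archive_existing=True,
    )
    promote_model_to_production(model_name="whisper-finetuned", version=3)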
@@ -262,12 +259,12 @@ def promote_model_to_production(
) -> ModelVersion:
"""
Promote a model version directly to Production.

Args:
model_name: Name of the registered model
version: Version number to promote
tracking_uri: Override default tracking URI

Returns:
The updated ModelVersion
"""
@@ -286,18 +283,18 @@ def get_production_model(
) -> Optional[ModelVersion]:
"""
Get the current Production model version.

Args:
model_name: Name of the registered model
tracking_uri: Override default tracking URI

Returns:
The Production ModelVersion, or None if none exists
"""
client = get_mlflow_client(tracking_uri=tracking_uri)

versions = client.get_latest_versions(model_name, stages=["Production"])

return versions[0] if versions else None
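A hedged sketch of reading back the Production version; the model name is illustrative, and the attributes used are standard MLflow ModelVersion fields (version and source, both referenced elsewhere in this module).

    prod = get_production_model(model_name="whisper-finetuned")
    if prod is None:
        raise RuntimeError("No Production version registered yet")
    print(prod.version, prod.source)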
@@ -308,17 +305,17 @@ def get_model_kserve_config(
) -> KServeConfig:
"""
Get KServe configuration from a registered model version.

Args:
model_name: Name of the registered model
version: Version number (uses Production if not specified)
tracking_uri: Override default tracking URI

Returns:
KServeConfig populated from model tags
"""
client = get_mlflow_client(tracking_uri=tracking_uri)

if version:
model_version = client.get_model_version(model_name, str(version))
else:
@@ -326,9 +323,9 @@ def get_model_kserve_config(
if not prod_version:
raise ValueError(f"No Production version for {model_name}")
model_version = prod_version

tags = model_version.tags

return KServeConfig(
runtime=tags.get("kserve.runtime", "kserve-huggingface"),
protocol=tags.get("kserve.protocol", "v2"),
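Hedged example of round-tripping the configuration back out of the registry tags; the returned values depend entirely on what was stored at registration time.

    config = get_model_kserve_config(model_name="whisper-finetuned", version=3)
    print(config.runtime, config.protocol, config.gpu_count)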
@@ -352,7 +349,7 @@ def generate_kserve_manifest(
) -> Dict[str, Any]:
"""
Generate a KServe InferenceService manifest from a registered model.

Args:
model_name: Name of the registered model
version: Version number (uses Production if not specified)
@@ -360,12 +357,12 @@ def generate_kserve_manifest(
service_name: Name for the InferenceService (defaults to model_name)
extra_annotations: Additional annotations for the service
tracking_uri: Override default tracking URI

Returns:
KServe InferenceService manifest as a dictionary
"""
client = get_mlflow_client(tracking_uri=tracking_uri)

# Get model version
if version:
model_version = client.get_model_version(model_name, str(version))
@@ -375,13 +372,13 @@ def generate_kserve_manifest(
raise ValueError(f"No Production version for {model_name}")
model_version = prod_version
version = int(model_version.version)

# Get KServe config
config = get_model_kserve_config(model_name, version, tracking_uri)
model_type = model_version.tags.get("model.type", "custom")

svc_name = service_name or model_name.lower().replace("_", "-")

# Build manifest
manifest = {
"apiVersion": "serving.kserve.io/v1beta1",
@@ -409,10 +406,10 @@ def generate_kserve_manifest(
},
},
}

# Configure predictor based on runtime
predictor = manifest["spec"]["predictor"]

if config.container_image:
# Custom container
predictor["containers"] = [{
@@ -434,16 +431,16 @@ def generate_kserve_manifest(
for k, v in config.env_vars.items()
],
}]

# Add GPU if needed
if config.gpu_count > 0:
predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)

else:
# Standard KServe runtime
storage_uri = config.storage_uri or model_version.source

predictor["model"] = {
"modelFormat": {"name": "huggingface"},
"protocolVersion": config.protocol,
@@ -459,11 +456,11 @@ def generate_kserve_manifest(
},
},
}

if config.gpu_count > 0:
predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)

return manifest
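A hedged sketch of generating and inspecting the manifest. The namespace and service name are illustrative; the predictor layout follows the two branches in the code above (custom container vs. standard runtime), and only keys that appear in this diff are accessed.

    manifest = generate_kserve_manifest(
        model_name="whisper-finetuned",
        namespace="ml-serving",        # illustrative namespace
        service_name="whisper-finetuned",
    )
    predictor = manifest["spec"]["predictor"]
    print(list(predictor.keys()))      # "containers" or "model", per the branches above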
@@ -476,14 +473,14 @@ def generate_kserve_yaml(
) -> str:
"""
Generate a KServe InferenceService manifest as YAML.

Args:
model_name: Name of the registered model
version: Version number (uses Production if not specified)
namespace: Kubernetes namespace
output_path: If provided, write YAML to this path
tracking_uri: Override default tracking URI

Returns:
YAML string of the manifest
"""
@@ -493,14 +490,14 @@ def generate_kserve_yaml(
namespace=namespace,
tracking_uri=tracking_uri,
)

yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)

if output_path:
with open(output_path, 'w') as f:
f.write(yaml_str)
logger.info(f"Wrote KServe manifest to {output_path}")

return yaml_str
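And a hedged usage note for the YAML variant; the output path is illustrative, and the resulting file would typically be applied to the cluster with kubectl.

    yaml_str = generate_kserve_yaml(
        model_name="whisper-finetuned",
        namespace="ml-serving",
        output_path="deploy/whisper-inferenceservice.yaml",  # illustrative path
    )
    # then: kubectl apply -f deploy/whisper-inferenceservice.yaml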
@@ -511,17 +508,17 @@ def list_model_versions(
) -> List[Dict[str, Any]]:
"""
List all versions of a registered model.

Args:
model_name: Name of the registered model
stages: Filter by stages (None for all)
tracking_uri: Override default tracking URI

Returns:
List of model version info dictionaries
"""
client = get_mlflow_client(tracking_uri=tracking_uri)

if stages:
versions = client.get_latest_versions(model_name, stages=stages)
else:
@@ -529,7 +526,7 @@ def list_model_versions(
versions = []
for mv in client.search_model_versions(f"name='{model_name}'"):
versions.append(mv)

return [
{
"version": mv.version,
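Finally, a hedged sketch of listing versions; the dictionary keys beyond "version" are cut off in the hunk above, so only that key is assumed here.

    for info in list_model_versions(model_name="whisper-finetuned", stages=["Staging", "Production"]):
        print(info["version"])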