fix: resolve all ruff lint errors
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions


@@ -17,7 +17,7 @@ Usage:
        promote_model_to_production,
        generate_kserve_manifest,
    )

    # Register a new model version
    model_version = register_model_for_kserve(
        model_name="whisper-finetuned",
@@ -28,7 +28,7 @@ Usage:
            "container_image": "ghcr.io/my-org/whisper:v2",
        }
    )

    # Generate KServe manifest for deployment
    manifest = generate_kserve_manifest(
        model_name="whisper-finetuned",
@@ -36,18 +36,15 @@ Usage:
    )
"""

import os
import json
-import yaml
import logging
-from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional

import mlflow
-from mlflow.tracking import MlflowClient
+import yaml
from mlflow.entities.model_registry import ModelVersion

-from .client import get_mlflow_client, MLflowConfig
+from .client import get_mlflow_client

logger = logging.getLogger(__name__)
@@ -55,15 +52,15 @@ logger = logging.getLogger(__name__)
@dataclass
class KServeConfig:
    """Configuration for KServe deployment."""

    # Runtime/container configuration
    runtime: str = "kserve-huggingface"  # kserve-huggingface, kserve-custom, etc.
    container_image: Optional[str] = None
    container_port: int = 8080

    # Protocol configuration
    protocol: str = "v2"  # v1, v2, grpc

    # Resource requests/limits
    cpu_request: str = "1"
    cpu_limit: str = "4"
@@ -71,22 +68,22 @@ class KServeConfig:
    memory_limit: str = "16Gi"
    gpu_count: int = 0
    gpu_type: str = "nvidia.com/gpu"  # or amd.com/gpu for ROCm

    # Storage configuration
    storage_uri: Optional[str] = None  # s3://, pvc://, gs://

    # Scaling configuration
    min_replicas: int = 1
    max_replicas: int = 1
    scale_target: int = 10  # Target concurrent requests for scaling

    # Serving configuration
    timeout_seconds: int = 300
    batch_size: int = 1

    # Additional environment variables
    env_vars: Dict[str, str] = field(default_factory=dict)

    def as_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for MLflow tags."""
        return {
@@ -165,7 +162,7 @@ def register_model_for_kserve(
) -> ModelVersion:
    """
    Register a model in MLflow Model Registry with KServe metadata.

    Args:
        model_name: Name for the registered model
        model_uri: URI to model artifacts (runs:/run_id/path, s3://, pvc://)
@@ -175,16 +172,16 @@ def register_model_for_kserve(
        kserve_config: KServe deployment configuration
        tags: Additional tags for the model version
        tracking_uri: Override default tracking URI

    Returns:
        The created ModelVersion object
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get or use preset KServe config
    if kserve_config is None:
        kserve_config = KSERVE_PRESETS.get(model_type, KServeConfig())

    # Ensure registered model exists
    try:
        client.get_registered_model(model_name)
@@ -198,7 +195,7 @@ def register_model_for_kserve(
            }
        )
        logger.info(f"Created registered model: {model_name}")

    # Create model version
    model_version = client.create_model_version(
        name=model_name,
@@ -211,12 +208,12 @@ def register_model_for_kserve(
            **kserve_config.as_dict(),
        }
    )

    logger.info(
        f"Registered model version {model_version.version} "
        f"for {model_name} (type: {model_type})"
    )

    return model_version
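
As a usage sketch (the run ID, image tag, and preset key below are placeholders, and the import path is assumed), registering a fine-tuned model with an explicit config might look like:

from registry import KServeConfig, register_model_for_kserve  # import path assumed

# Illustrative GPU-backed custom-container config; every value is a placeholder.
config = KServeConfig(
    runtime="kserve-custom",
    container_image="ghcr.io/my-org/whisper:v2",
    gpu_count=1,
    min_replicas=1,
    max_replicas=3,
)

mv = register_model_for_kserve(
    model_name="whisper-finetuned",
    model_uri="runs:/<run_id>/model",  # placeholder run ID
    model_type="whisper",              # assumed KSERVE_PRESETS key
    kserve_config=config,
)
print(f"Registered version {mv.version}")
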
@@ -229,19 +226,19 @@ def promote_model_to_stage(
) -> ModelVersion:
    """
    Promote a model version to a new stage.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        stage: Target stage (Staging, Production, Archived)
        archive_existing: If True, archive existing versions in target stage
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Transition to new stage
    model_version = client.transition_model_version_stage(
        name=model_name,
@@ -249,9 +246,9 @@ def promote_model_to_stage(
        stage=stage,
        archive_existing_versions=archive_existing,
    )

    logger.info(f"Promoted {model_name} v{version} to {stage}")

    return model_version
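
For a staged rollout, the two promotion helpers compose naturally; a sketch with illustrative version numbers, assuming the same import path as above:

from registry import promote_model_to_stage, promote_model_to_production  # import path assumed

# Validate in Staging first, then cut over; v3 is illustrative.
promote_model_to_stage("whisper-finetuned", version=3, stage="Staging")

# ... run smoke tests against the Staging endpoint ...

promote_model_to_production("whisper-finetuned", version=3)
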
@@ -262,12 +259,12 @@ def promote_model_to_production(
) -> ModelVersion:
    """
    Promote a model version directly to Production.

    Args:
        model_name: Name of the registered model
        version: Version number to promote
        tracking_uri: Override default tracking URI

    Returns:
        The updated ModelVersion
    """
@@ -286,18 +283,18 @@ def get_production_model(
) -> Optional[ModelVersion]:
    """
    Get the current Production model version.

    Args:
        model_name: Name of the registered model
        tracking_uri: Override default tracking URI

    Returns:
        The Production ModelVersion, or None if none exists
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    versions = client.get_latest_versions(model_name, stages=["Production"])
    return versions[0] if versions else None
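
A minimal sketch of the guard this enables in a deploy script (the model name is a placeholder):

from registry import get_production_model  # import path assumed

prod = get_production_model("whisper-finetuned")
if prod is None:
    raise SystemExit("No Production version yet; promote one first.")
print(f"Production is v{prod.version}, artifacts at {prod.source}")
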
@@ -308,17 +305,17 @@ def get_model_kserve_config(
) -> KServeConfig:
    """
    Get KServe configuration from a registered model version.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        tracking_uri: Override default tracking URI

    Returns:
        KServeConfig populated from model tags
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if version:
        model_version = client.get_model_version(model_name, str(version))
    else:
@@ -326,9 +323,9 @@ def get_model_kserve_config(
        if not prod_version:
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version

    tags = model_version.tags

    return KServeConfig(
        runtime=tags.get("kserve.runtime", "kserve-huggingface"),
        protocol=tags.get("kserve.protocol", "v2"),
@@ -352,7 +349,7 @@ def generate_kserve_manifest(
) -> Dict[str, Any]:
    """
    Generate a KServe InferenceService manifest from a registered model.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
@@ -360,12 +357,12 @@ def generate_kserve_manifest(
        service_name: Name for the InferenceService (defaults to model_name)
        extra_annotations: Additional annotations for the service
        tracking_uri: Override default tracking URI

    Returns:
        KServe InferenceService manifest as a dictionary
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    # Get model version
    if version:
        model_version = client.get_model_version(model_name, str(version))
@@ -375,13 +372,13 @@ def generate_kserve_manifest(
            raise ValueError(f"No Production version for {model_name}")
        model_version = prod_version
        version = int(model_version.version)

    # Get KServe config
    config = get_model_kserve_config(model_name, version, tracking_uri)
    model_type = model_version.tags.get("model.type", "custom")

    svc_name = service_name or model_name.lower().replace("_", "-")

    # Build manifest
    manifest = {
        "apiVersion": "serving.kserve.io/v1beta1",
@@ -409,10 +406,10 @@ def generate_kserve_manifest(
            },
        },
    }

    # Configure predictor based on runtime
    predictor = manifest["spec"]["predictor"]

    if config.container_image:
        # Custom container
        predictor["containers"] = [{
@@ -434,16 +431,16 @@ def generate_kserve_manifest(
                for k, v in config.env_vars.items()
            ],
        }]

        # Add GPU if needed
        if config.gpu_count > 0:
            predictor["containers"][0]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
            predictor["containers"][0]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)
    else:
        # Standard KServe runtime
        storage_uri = config.storage_uri or model_version.source
        predictor["model"] = {
            "modelFormat": {"name": "huggingface"},
            "protocolVersion": config.protocol,
@@ -459,11 +456,11 @@ def generate_kserve_manifest(
                },
            },
        }

        if config.gpu_count > 0:
            predictor["model"]["resources"]["limits"][config.gpu_type] = str(config.gpu_count)
            predictor["model"]["resources"]["requests"][config.gpu_type] = str(config.gpu_count)

    return manifest
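
For orientation, a sketch of inspecting the returned structure (the namespace is a placeholder); which predictor shape appears depends on whether the model's tags carry a custom container image:

from registry import generate_kserve_manifest  # import path assumed

manifest = generate_kserve_manifest("whisper-finetuned", namespace="ml-serving")

# Custom-image models populate spec.predictor.containers;
# standard runtimes populate spec.predictor.model instead.
predictor = manifest["spec"]["predictor"]
shape = "containers" if "containers" in predictor else "model"
print(f"{manifest['apiVersion']}: predictor uses {shape}")
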
@@ -476,14 +473,14 @@ def generate_kserve_yaml(
) -> str:
    """
    Generate a KServe InferenceService manifest as YAML.

    Args:
        model_name: Name of the registered model
        version: Version number (uses Production if not specified)
        namespace: Kubernetes namespace
        output_path: If provided, write YAML to this path
        tracking_uri: Override default tracking URI

    Returns:
        YAML string of the manifest
    """
@@ -493,14 +490,14 @@ def generate_kserve_yaml(
        namespace=namespace,
        tracking_uri=tracking_uri,
    )

    yaml_str = yaml.dump(manifest, default_flow_style=False, sort_keys=False)

    if output_path:
        with open(output_path, 'w') as f:
            f.write(yaml_str)
        logger.info(f"Wrote KServe manifest to {output_path}")

    return yaml_str
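
End to end, with an assumed output path and namespace, the manifest can be written once and applied out-of-band:

from registry import generate_kserve_yaml  # import path assumed

generate_kserve_yaml(
    "whisper-finetuned",
    namespace="ml-serving",      # placeholder namespace
    output_path="whisper.yaml",  # assumed path
)
# Apply separately, e.g.: kubectl apply -f whisper.yaml
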
@@ -511,17 +508,17 @@ def list_model_versions(
) -> List[Dict[str, Any]]:
    """
    List all versions of a registered model.

    Args:
        model_name: Name of the registered model
        stages: Filter by stages (None for all)
        tracking_uri: Override default tracking URI

    Returns:
        List of model version info dictionaries
    """
    client = get_mlflow_client(tracking_uri=tracking_uri)

    if stages:
        versions = client.get_latest_versions(model_name, stages=stages)
    else:
@@ -529,7 +526,7 @@ def list_model_versions(
        versions = []
        for mv in client.search_model_versions(f"name='{model_name}'"):
            versions.append(mv)

    return [
        {
            "version": mv.version,