fix: resolve all ruff lint errors
This commit is contained in:
@@ -5,10 +5,10 @@ Provides a configured MLflow client for all integrations in the LLM workflows.
|
||||
Supports both in-cluster and external access patterns.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class MLflowConfig:
|
||||
"""Configuration for MLflow integration."""
|
||||
|
||||
|
||||
# Tracking server URIs
|
||||
tracking_uri: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -33,7 +33,7 @@ class MLflowConfig:
|
||||
"https://mlflow.lab.daviestechlabs.io"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Artifact storage (NFS PVC mount)
|
||||
artifact_location: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -41,7 +41,7 @@ class MLflowConfig:
|
||||
"/mlflow/artifacts"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Default experiment settings
|
||||
default_experiment: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -49,7 +49,7 @@ class MLflowConfig:
|
||||
"llm-workflows"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Service identification
|
||||
service_name: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
@@ -57,10 +57,10 @@ class MLflowConfig:
|
||||
"unknown-service"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Additional tags to add to all runs
|
||||
default_tags: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
"""Add default tags based on environment."""
|
||||
env_tags = {
|
||||
@@ -74,10 +74,10 @@ class MLflowConfig:
|
||||
def get_tracking_uri(external: bool = False) -> str:
|
||||
"""
|
||||
Get the appropriate MLflow tracking URI.
|
||||
|
||||
|
||||
Args:
|
||||
external: If True, return the external URI for outside-cluster access
|
||||
|
||||
|
||||
Returns:
|
||||
The MLflow tracking URI string
|
||||
"""
|
||||
@@ -91,20 +91,20 @@ def get_mlflow_client(
|
||||
) -> MlflowClient:
|
||||
"""
|
||||
Get a configured MLflow client.
|
||||
|
||||
|
||||
Args:
|
||||
tracking_uri: Override the default tracking URI
|
||||
configure_global: If True, also set mlflow.set_tracking_uri()
|
||||
|
||||
|
||||
Returns:
|
||||
Configured MlflowClient instance
|
||||
"""
|
||||
uri = tracking_uri or get_tracking_uri()
|
||||
|
||||
|
||||
if configure_global:
|
||||
mlflow.set_tracking_uri(uri)
|
||||
logger.info(f"MLflow tracking URI set to: {uri}")
|
||||
|
||||
|
||||
client = MlflowClient(tracking_uri=uri)
|
||||
return client
|
||||
|
||||
@@ -116,21 +116,21 @@ def ensure_experiment(
|
||||
) -> str:
|
||||
"""
|
||||
Ensure an experiment exists, creating it if necessary.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the experiment
|
||||
artifact_location: Override default artifact location
|
||||
tags: Additional tags for the experiment
|
||||
|
||||
|
||||
Returns:
|
||||
The experiment ID
|
||||
"""
|
||||
config = MLflowConfig()
|
||||
client = get_mlflow_client()
|
||||
|
||||
|
||||
# Check if experiment exists
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
|
||||
|
||||
if experiment is None:
|
||||
# Create the experiment
|
||||
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
|
||||
@@ -143,7 +143,7 @@ def ensure_experiment(
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
|
||||
|
||||
|
||||
return experiment_id
|
||||
|
||||
|
||||
@@ -154,17 +154,17 @@ def get_or_create_registered_model(
|
||||
) -> str:
|
||||
"""
|
||||
Get or create a registered model in the Model Registry.
|
||||
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to register
|
||||
description: Model description
|
||||
tags: Tags for the model
|
||||
|
||||
|
||||
Returns:
|
||||
The registered model name
|
||||
"""
|
||||
client = get_mlflow_client()
|
||||
|
||||
|
||||
try:
|
||||
# Check if model exists
|
||||
client.get_registered_model(model_name)
|
||||
@@ -177,14 +177,14 @@ def get_or_create_registered_model(
|
||||
tags=tags or {}
|
||||
)
|
||||
logger.info(f"Created registered model: {model_name}")
|
||||
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def health_check() -> Dict[str, Any]:
|
||||
"""
|
||||
Check MLflow server connectivity and return status.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with health status information
|
||||
"""
|
||||
@@ -195,7 +195,7 @@ def health_check() -> Dict[str, Any]:
|
||||
"connected": False,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
client = get_mlflow_client(configure_global=False)
|
||||
# Try to list experiments as a health check
|
||||
@@ -205,5 +205,5 @@ def health_check() -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
logger.error(f"MLflow health check failed: {e}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user