fix: resolve all ruff lint errors
Some checks failed
CI / Test (push) Successful in 1m46s
CI / Lint (push) Failing after 1m49s
CI / Publish (push) Has been skipped
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions

View File

@@ -5,10 +5,10 @@ Provides a configured MLflow client for all integrations in the LLM workflows.
Supports both in-cluster and external access patterns.
"""
import os
import logging
import os
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
import mlflow
from mlflow.tracking import MlflowClient
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
@dataclass
class MLflowConfig:
"""Configuration for MLflow integration."""
# Tracking server URIs
tracking_uri: str = field(
default_factory=lambda: os.environ.get(
@@ -33,7 +33,7 @@ class MLflowConfig:
"https://mlflow.lab.daviestechlabs.io"
)
)
# Artifact storage (NFS PVC mount)
artifact_location: str = field(
default_factory=lambda: os.environ.get(
@@ -41,7 +41,7 @@ class MLflowConfig:
"/mlflow/artifacts"
)
)
# Default experiment settings
default_experiment: str = field(
default_factory=lambda: os.environ.get(
@@ -49,7 +49,7 @@ class MLflowConfig:
"llm-workflows"
)
)
# Service identification
service_name: str = field(
default_factory=lambda: os.environ.get(
@@ -57,10 +57,10 @@ class MLflowConfig:
"unknown-service"
)
)
# Additional tags to add to all runs
default_tags: Dict[str, str] = field(default_factory=dict)
def __post_init__(self):
"""Add default tags based on environment."""
env_tags = {
@@ -74,10 +74,10 @@ class MLflowConfig:
def get_tracking_uri(external: bool = False) -> str:
"""
Get the appropriate MLflow tracking URI.
Args:
external: If True, return the external URI for outside-cluster access
Returns:
The MLflow tracking URI string
"""
@@ -91,20 +91,20 @@ def get_mlflow_client(
) -> MlflowClient:
"""
Get a configured MLflow client.
Args:
tracking_uri: Override the default tracking URI
configure_global: If True, also set mlflow.set_tracking_uri()
Returns:
Configured MlflowClient instance
"""
uri = tracking_uri or get_tracking_uri()
if configure_global:
mlflow.set_tracking_uri(uri)
logger.info(f"MLflow tracking URI set to: {uri}")
client = MlflowClient(tracking_uri=uri)
return client
@@ -116,21 +116,21 @@ def ensure_experiment(
) -> str:
"""
Ensure an experiment exists, creating it if necessary.
Args:
experiment_name: Name of the experiment
artifact_location: Override default artifact location
tags: Additional tags for the experiment
Returns:
The experiment ID
"""
config = MLflowConfig()
client = get_mlflow_client()
# Check if experiment exists
experiment = client.get_experiment_by_name(experiment_name)
if experiment is None:
# Create the experiment
artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}"
@@ -143,7 +143,7 @@ def ensure_experiment(
else:
experiment_id = experiment.experiment_id
logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")
return experiment_id
@@ -154,17 +154,17 @@ def get_or_create_registered_model(
) -> str:
"""
Get or create a registered model in the Model Registry.
Args:
model_name: Name of the model to register
description: Model description
tags: Tags for the model
Returns:
The registered model name
"""
client = get_mlflow_client()
try:
# Check if model exists
client.get_registered_model(model_name)
@@ -177,14 +177,14 @@ def get_or_create_registered_model(
tags=tags or {}
)
logger.info(f"Created registered model: {model_name}")
return model_name
def health_check() -> Dict[str, Any]:
"""
Check MLflow server connectivity and return status.
Returns:
Dictionary with health status information
"""
@@ -195,7 +195,7 @@ def health_check() -> Dict[str, Any]:
"connected": False,
"error": None,
}
try:
client = get_mlflow_client(configure_global=False)
# Try to list experiments as a health check
@@ -205,5 +205,5 @@ def health_check() -> Dict[str, Any]:
except Exception as e:
result["error"] = str(e)
logger.error(f"MLflow health check failed: {e}")
return result