180 lines
5.4 KiB
Python
180 lines
5.4 KiB
Python
"""
MLflow Client Configuration and Initialization

Provides a configured MLflow client for all integrations in the LLM workflows.
Supports both in-cluster and external access patterns.
"""
|
|
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, Optional
|
|
|
|
import mlflow
|
|
from mlflow.tracking import MlflowClient
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _env_field(variable: str, fallback: str) -> Any:
    """Return a dataclass field whose default is looked up in the environment.

    The lookup runs each time the owning dataclass is instantiated (not at
    import time); ``fallback`` is used when ``variable`` is not set.
    """

    def _read() -> str:
        return os.environ.get(variable, fallback)

    return field(default_factory=_read)


@dataclass
class MLflowConfig:
    """Settings for the MLflow integration.

    Every string setting defaults to the value of an environment variable,
    read when the instance is created, with a hard-coded fallback when the
    variable is not set. Any value can be overridden by passing it to the
    constructor.
    """

    # Tracking server URIs
    tracking_uri: str = _env_field("MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80")
    external_uri: str = _env_field("MLFLOW_EXTERNAL_URI", "https://mlflow.lab.daviestechlabs.io")

    # Artifact storage (NFS PVC mount)
    artifact_location: str = _env_field("MLFLOW_ARTIFACT_LOCATION", "/mlflow/artifacts")

    # Default experiment settings
    default_experiment: str = _env_field("MLFLOW_DEFAULT_EXPERIMENT", "llm-workflows")

    # Service identification
    service_name: str = _env_field("OTEL_SERVICE_NAME", "unknown-service")

    # Additional tags to add to all runs
    default_tags: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Merge environment-derived tags underneath the caller-supplied ones."""
        lookups = (
            ("environment", "DEPLOYMENT_ENV", "production"),
            ("hostname", "HOSTNAME", "unknown"),
            ("namespace", "OTEL_SERVICE_NAMESPACE", "ai-ml"),
        )
        from_environment = {tag: os.environ.get(variable, fallback) for tag, variable, fallback in lookups}
        # Caller-supplied tags come last so they win on key collisions.
        self.default_tags = {**from_environment, **self.default_tags}
|
|
|
|
|
|
def get_tracking_uri(external: bool = False) -> str:
    """
    Resolve which MLflow tracking URI to use.

    A fresh ``MLflowConfig`` is built on every call, so the current
    environment is re-read each time.

    Args:
        external: When True, pick the URI meant for access from outside the
            cluster; otherwise pick the default tracking URI.

    Returns:
        The selected tracking URI string
    """
    settings = MLflowConfig()
    if external:
        return settings.external_uri
    return settings.tracking_uri
|
|
|
|
|
|
def get_mlflow_client(tracking_uri: Optional[str] = None, configure_global: bool = True) -> MlflowClient:
    """
    Build an ``MlflowClient`` bound to the resolved tracking URI.

    Args:
        tracking_uri: Explicit tracking URI; when omitted (or empty) the URI
            from ``get_tracking_uri()`` is used instead
        configure_global: When True, the same URI is also passed to
            ``mlflow.set_tracking_uri()`` and the choice is logged

    Returns:
        A new MlflowClient instance for that URI
    """
    if tracking_uri:
        uri = tracking_uri
    else:
        uri = get_tracking_uri()

    if configure_global:
        # Keep the module-level mlflow API pointed at the same server as the client.
        mlflow.set_tracking_uri(uri)
        logger.info(f"MLflow tracking URI set to: {uri}")

    return MlflowClient(tracking_uri=uri)
|
|
|
|
|
|
def ensure_experiment(
    experiment_name: str, artifact_location: Optional[str] = None, tags: Optional[Dict[str, str]] = None
) -> str:
    """
    Ensure a usable experiment exists, creating or restoring it if necessary.

    Looks the experiment up by name. If none exists it is created; if one
    exists but has been soft-deleted it is restored, because a deleted
    experiment cannot accept runs and its name cannot be reused for a new
    one. An existing active experiment is returned unchanged.

    Args:
        experiment_name: Name of the experiment
        artifact_location: Artifact location to use when the experiment is
            created. Defaults to ``<config.artifact_location>/<experiment_name>``.
            Ignored when the experiment already exists.
        tags: Tags to attach when the experiment is created. Ignored when the
            experiment already exists.

    Returns:
        The experiment ID
    """
    # Local import: only this function needs the lifecycle constants.
    from mlflow.entities import LifecycleStage

    config = MLflowConfig()
    client = get_mlflow_client()

    experiment = client.get_experiment_by_name(experiment_name)

    if experiment is None:
        if artifact_location:
            artifact_loc = artifact_location
        else:
            # Strip a trailing slash so a configured root like "/mlflow/artifacts/"
            # does not yield a "//" in the derived path.
            artifact_root = config.artifact_location.rstrip("/")
            artifact_loc = f"{artifact_root}/{experiment_name}"
        experiment_id = client.create_experiment(name=experiment_name, artifact_location=artifact_loc, tags=tags or {})
        logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}")
        return experiment_id

    experiment_id = experiment.experiment_id

    if experiment.lifecycle_stage == LifecycleStage.DELETED:
        # The name lookup also returns soft-deleted experiments; restore it so
        # the returned ID can actually be used to start runs.
        client.restore_experiment(experiment_id)
        logger.warning(f"Restored deleted experiment '{experiment_name}' with ID: {experiment_id}")
    else:
        logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}")

    return experiment_id
|
|
|
|
|
|
def get_or_create_registered_model(
    model_name: str, description: Optional[str] = None, tags: Optional[Dict[str, str]] = None
) -> str:
    """
    Get or create a registered model in the Model Registry.

    The model is created only when the registry reports that it does not
    exist. Any other failure of the lookup (connectivity, permissions,
    server error) is re-raised instead of being mistaken for "not found".

    Args:
        model_name: Name of the model to register
        description: Description used when the model is created. Defaults to
            ``"Model for <model_name>"``. Ignored when the model already exists.
        tags: Tags used when the model is created. Ignored when the model
            already exists.

    Returns:
        The registered model name

    Raises:
        mlflow.exceptions.MlflowException: If the lookup fails for a reason
            other than the model not existing, or if creation fails.
    """
    client = get_mlflow_client()

    try:
        client.get_registered_model(model_name)
    except mlflow.exceptions.MlflowException as exc:
        # Only a missing model should trigger creation; surface everything else.
        if exc.error_code != "RESOURCE_DOES_NOT_EXIST":
            raise
        client.create_registered_model(
            name=model_name, description=description or f"Model for {model_name}", tags=tags or {}
        )
        logger.info(f"Created registered model: {model_name}")
    else:
        logger.debug(f"Using existing registered model: {model_name}")

    return model_name
|
|
|
|
|
|
def health_check() -> Dict[str, Any]:
    """
    Probe the MLflow tracking server and report the outcome.

    The probe asks the server for at most one experiment. Failures are
    caught, logged, and reported in the returned dictionary rather than
    raised.

    Returns:
        Dictionary with the keys ``tracking_uri``, ``external_uri``,
        ``connected`` and ``error``; ``experiment_count`` is added when the
        probe succeeds
    """
    settings = MLflowConfig()
    status: Dict[str, Any] = {}
    status["tracking_uri"] = settings.tracking_uri
    status["external_uri"] = settings.external_uri
    status["connected"] = False
    status["error"] = None

    try:
        # Do not touch the global mlflow tracking URI just to run a probe.
        # Try to list experiments as a health check
        page = get_mlflow_client(configure_global=False).search_experiments(max_results=1)
        status["connected"] = True
        status["experiment_count"] = len(page)
    except Exception as e:
        status["error"] = str(e)
        logger.error(f"MLflow health check failed: {e}")

    return status
|