""" MLflow Client Configuration and Initialization Provides a configured MLflow client for all integrations in the LLM workflows. Supports both in-cluster and external access patterns. """ import os import logging from dataclasses import dataclass, field from typing import Optional, Dict, Any import mlflow from mlflow.tracking import MlflowClient logger = logging.getLogger(__name__) @dataclass class MLflowConfig: """Configuration for MLflow integration.""" # Tracking server URIs tracking_uri: str = field( default_factory=lambda: os.environ.get( "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80" ) ) external_uri: str = field( default_factory=lambda: os.environ.get( "MLFLOW_EXTERNAL_URI", "https://mlflow.lab.daviestechlabs.io" ) ) # Artifact storage (NFS PVC mount) artifact_location: str = field( default_factory=lambda: os.environ.get( "MLFLOW_ARTIFACT_LOCATION", "/mlflow/artifacts" ) ) # Default experiment settings default_experiment: str = field( default_factory=lambda: os.environ.get( "MLFLOW_DEFAULT_EXPERIMENT", "llm-workflows" ) ) # Service identification service_name: str = field( default_factory=lambda: os.environ.get( "OTEL_SERVICE_NAME", "unknown-service" ) ) # Additional tags to add to all runs default_tags: Dict[str, str] = field(default_factory=dict) def __post_init__(self): """Add default tags based on environment.""" env_tags = { "environment": os.environ.get("DEPLOYMENT_ENV", "production"), "hostname": os.environ.get("HOSTNAME", "unknown"), "namespace": os.environ.get("OTEL_SERVICE_NAMESPACE", "ai-ml"), } self.default_tags = {**env_tags, **self.default_tags} def get_tracking_uri(external: bool = False) -> str: """ Get the appropriate MLflow tracking URI. Args: external: If True, return the external URI for outside-cluster access Returns: The MLflow tracking URI string """ config = MLflowConfig() return config.external_uri if external else config.tracking_uri def get_mlflow_client( tracking_uri: Optional[str] = None, configure_global: bool = True ) -> MlflowClient: """ Get a configured MLflow client. Args: tracking_uri: Override the default tracking URI configure_global: If True, also set mlflow.set_tracking_uri() Returns: Configured MlflowClient instance """ uri = tracking_uri or get_tracking_uri() if configure_global: mlflow.set_tracking_uri(uri) logger.info(f"MLflow tracking URI set to: {uri}") client = MlflowClient(tracking_uri=uri) return client def ensure_experiment( experiment_name: str, artifact_location: Optional[str] = None, tags: Optional[Dict[str, str]] = None ) -> str: """ Ensure an experiment exists, creating it if necessary. Args: experiment_name: Name of the experiment artifact_location: Override default artifact location tags: Additional tags for the experiment Returns: The experiment ID """ config = MLflowConfig() client = get_mlflow_client() # Check if experiment exists experiment = client.get_experiment_by_name(experiment_name) if experiment is None: # Create the experiment artifact_loc = artifact_location or f"{config.artifact_location}/{experiment_name}" experiment_id = client.create_experiment( name=experiment_name, artifact_location=artifact_loc, tags=tags or {} ) logger.info(f"Created experiment '{experiment_name}' with ID: {experiment_id}") else: experiment_id = experiment.experiment_id logger.debug(f"Using existing experiment '{experiment_name}' with ID: {experiment_id}") return experiment_id def get_or_create_registered_model( model_name: str, description: Optional[str] = None, tags: Optional[Dict[str, str]] = None ) -> str: """ Get or create a registered model in the Model Registry. Args: model_name: Name of the model to register description: Model description tags: Tags for the model Returns: The registered model name """ client = get_mlflow_client() try: # Check if model exists client.get_registered_model(model_name) logger.debug(f"Using existing registered model: {model_name}") except mlflow.exceptions.MlflowException: # Create the model client.create_registered_model( name=model_name, description=description or f"Model for {model_name}", tags=tags or {} ) logger.info(f"Created registered model: {model_name}") return model_name def health_check() -> Dict[str, Any]: """ Check MLflow server connectivity and return status. Returns: Dictionary with health status information """ config = MLflowConfig() result = { "tracking_uri": config.tracking_uri, "external_uri": config.external_uri, "connected": False, "error": None, } try: client = get_mlflow_client(configure_global=False) # Try to list experiments as a health check experiments = client.search_experiments(max_results=1) result["connected"] = True result["experiment_count"] = len(experiments) except Exception as e: result["error"] = str(e) logger.error(f"MLflow health check failed: {e}") return result