fix: resolve all ruff lint errors
Some checks failed
CI / Test (push) Successful in 1m46s
CI / Lint (push) Failing after 1m49s
CI / Publish (push) Has been skipped
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions

View File

@@ -5,19 +5,17 @@ Provides a high-level interface for logging experiments, parameters,
metrics, and artifacts from Kubeflow Pipeline components.
"""
import os
import json
import time
import logging
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
import mlflow
from mlflow.tracking import MlflowClient
from .client import get_mlflow_client, ensure_experiment, MLflowConfig
from .client import MLflowConfig, ensure_experiment, get_mlflow_client
logger = logging.getLogger(__name__)
@@ -30,7 +28,7 @@ class PipelineMetadata:
run_name: Optional[str] = None
component_name: Optional[str] = None
namespace: str = "ai-ml"
# KFP-specific metadata (populated from environment if available)
kfp_run_id: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_RUN_ID")
@@ -38,7 +36,7 @@ class PipelineMetadata:
kfp_pod_name: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_POD_NAME")
)
def as_tags(self) -> Dict[str, str]:
"""Convert metadata to MLflow tags."""
tags = {
@@ -60,34 +58,34 @@ class PipelineMetadata:
class MLflowTracker:
"""
MLflow experiment tracker for Kubeflow Pipeline components.
Example usage in a KFP component:
from mlflow_utils import MLflowTracker
tracker = MLflowTracker(
experiment_name="document-ingestion",
run_name="batch-ingestion-2024-01"
)
with tracker.start_run() as run:
tracker.log_params({
"chunk_size": 500,
"overlap": 50,
"embeddings_model": "bge-small-en-v1.5"
})
# ... do work ...
tracker.log_metrics({
"documents_processed": 100,
"chunks_created": 2500,
"processing_time_seconds": 120.5
})
tracker.log_artifact("/path/to/output.json")
"""
def __init__(
self,
experiment_name: str,
@@ -98,7 +96,7 @@ class MLflowTracker:
):
"""
Initialize the MLflow tracker.
Args:
experiment_name: Name of the MLflow experiment
run_name: Optional name for this run
@@ -112,22 +110,22 @@ class MLflowTracker:
self.pipeline_metadata = pipeline_metadata
self.user_tags = tags or {}
self.tracking_uri = tracking_uri
self.client: Optional[MlflowClient] = None
self.run: Optional[mlflow.ActiveRun] = None
self.run_id: Optional[str] = None
self._start_time: Optional[float] = None
def _get_all_tags(self) -> Dict[str, str]:
"""Combine all tags for the run."""
tags = self.config.default_tags.copy()
if self.pipeline_metadata:
tags.update(self.pipeline_metadata.as_tags())
tags.update(self.user_tags)
return tags
@contextmanager
def start_run(
self,
@@ -136,11 +134,11 @@ class MLflowTracker:
):
"""
Start an MLflow run as a context manager.
Args:
nested: If True, create a nested run under the current active run
parent_run_id: Explicit parent run ID for nested runs
Yields:
The MLflow run object
"""
@@ -148,12 +146,12 @@ class MLflowTracker:
tracking_uri=self.tracking_uri,
configure_global=True
)
# Ensure experiment exists
experiment_id = ensure_experiment(self.experiment_name)
self._start_time = time.time()
try:
# Start the run
self.run = mlflow.start_run(
@@ -163,14 +161,14 @@ class MLflowTracker:
tags=self._get_all_tags(),
)
self.run_id = self.run.info.run_id
logger.info(
f"Started MLflow run '{self.run_name}' "
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
)
yield self.run
except Exception as e:
logger.error(f"MLflow run failed: {e}")
if self.run:
@@ -185,22 +183,22 @@ class MLflowTracker:
mlflow.log_metric("run_duration_seconds", duration)
except Exception:
pass
# End the run
mlflow.end_run()
logger.info(f"Ended MLflow run '{self.run_name}'")
def log_params(self, params: Dict[str, Any]) -> None:
"""
Log parameters to the current run.
Args:
params: Dictionary of parameter names to values
"""
if not self.run:
logger.warning("No active run, skipping log_params")
return
# MLflow has limits on param values, truncate if needed
cleaned_params = {}
for key, value in params.items():
@@ -208,14 +206,14 @@ class MLflowTracker:
if len(str_value) > 500:
str_value = str_value[:497] + "..."
cleaned_params[key] = str_value
mlflow.log_params(cleaned_params)
logger.debug(f"Logged {len(params)} parameters")
def log_param(self, key: str, value: Any) -> None:
"""Log a single parameter."""
self.log_params({key: value})
def log_metrics(
self,
metrics: Dict[str, Union[float, int]],
@@ -223,7 +221,7 @@ class MLflowTracker:
) -> None:
"""
Log metrics to the current run.
Args:
metrics: Dictionary of metric names to values
step: Optional step number for time-series metrics
@@ -231,10 +229,10 @@ class MLflowTracker:
if not self.run:
logger.warning("No active run, skipping log_metrics")
return
mlflow.log_metrics(metrics, step=step)
logger.debug(f"Logged {len(metrics)} metrics")
def log_metric(
self,
key: str,
@@ -243,7 +241,7 @@ class MLflowTracker:
) -> None:
"""Log a single metric."""
self.log_metrics({key: value}, step=step)
def log_artifact(
self,
local_path: str,
@@ -251,7 +249,7 @@ class MLflowTracker:
) -> None:
"""
Log an artifact file to the current run.
Args:
local_path: Path to the local file to log
artifact_path: Optional destination path within the artifact store
@@ -259,10 +257,10 @@ class MLflowTracker:
if not self.run:
logger.warning("No active run, skipping log_artifact")
return
mlflow.log_artifact(local_path, artifact_path)
logger.info(f"Logged artifact: {local_path}")
def log_artifacts(
self,
local_dir: str,
@@ -270,7 +268,7 @@ class MLflowTracker:
) -> None:
"""
Log all files in a directory as artifacts.
Args:
local_dir: Path to the local directory
artifact_path: Optional destination path within the artifact store
@@ -278,10 +276,10 @@ class MLflowTracker:
if not self.run:
logger.warning("No active run, skipping log_artifacts")
return
mlflow.log_artifacts(local_dir, artifact_path)
logger.info(f"Logged artifacts from: {local_dir}")
def log_dict(
self,
data: Dict[str, Any],
@@ -290,7 +288,7 @@ class MLflowTracker:
) -> None:
"""
Log a dictionary as a JSON artifact.
Args:
data: Dictionary to log
filename: Name for the JSON file
@@ -299,14 +297,14 @@ class MLflowTracker:
if not self.run:
logger.warning("No active run, skipping log_dict")
return
# Ensure .json extension
if not filename.endswith(".json"):
filename += ".json"
mlflow.log_dict(data, f"{artifact_path}/{filename}" if artifact_path else filename)
logger.debug(f"Logged dict as: {filename}")
def log_model_info(
self,
model_type: str,
@@ -317,7 +315,7 @@ class MLflowTracker:
) -> None:
"""
Log model information as parameters and tags.
Args:
model_type: Type of model (e.g., "llm", "embedding", "stt")
model_name: Name/identifier of the model
@@ -335,13 +333,13 @@ class MLflowTracker:
if extra_info:
for key, value in extra_info.items():
params[f"model.{key}"] = value
self.log_params(params)
# Also set as tags for easier filtering
mlflow.set_tag("model.type", model_type)
mlflow.set_tag("model.name", model_name)
def log_dataset_info(
self,
name: str,
@@ -351,7 +349,7 @@ class MLflowTracker:
) -> None:
"""
Log dataset information.
Args:
name: Dataset name
source: Dataset source (URL, path, etc.)
@@ -367,26 +365,26 @@ class MLflowTracker:
if extra_info:
for key, value in extra_info.items():
params[f"dataset.{key}"] = value
self.log_params(params)
def set_tag(self, key: str, value: str) -> None:
"""Set a single tag on the run."""
if self.run:
mlflow.set_tag(key, value)
def set_tags(self, tags: Dict[str, str]) -> None:
"""Set multiple tags on the run."""
if self.run:
mlflow.set_tags(tags)
@property
def artifact_uri(self) -> Optional[str]:
"""Get the artifact URI for the current run."""
if self.run:
return self.run.info.artifact_uri
return None
@property
def experiment_id(self) -> Optional[str]:
"""Get the experiment ID for the current run."""