style: apply ruff format to all files
Some checks failed
CI / Lint (push) Successful in 1m46s
CI / Test (push) Successful in 1m44s
CI / Publish (push) Failing after 19s
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-13 11:05:26 -05:00
parent 1c841729a0
commit ca5bef9664
7 changed files with 89 additions and 222 deletions

View File

@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
@dataclass
class PipelineMetadata:
"""Metadata about the Kubeflow Pipeline run."""
pipeline_name: str
run_id: str
run_name: Optional[str] = None
@@ -30,12 +31,8 @@ class PipelineMetadata:
namespace: str = "ai-ml"
# KFP-specific metadata (populated from environment if available)
kfp_run_id: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_RUN_ID")
)
kfp_pod_name: Optional[str] = field(
default_factory=lambda: os.environ.get("KFP_POD_NAME")
)
kfp_run_id: Optional[str] = field(default_factory=lambda: os.environ.get("KFP_RUN_ID"))
kfp_pod_name: Optional[str] = field(default_factory=lambda: os.environ.get("KFP_POD_NAME"))
def as_tags(self) -> Dict[str, str]:
"""Convert metadata to MLflow tags."""
@@ -142,10 +139,7 @@ class MLflowTracker:
Yields:
The MLflow run object
"""
self.client = get_mlflow_client(
tracking_uri=self.tracking_uri,
configure_global=True
)
self.client = get_mlflow_client(tracking_uri=self.tracking_uri, configure_global=True)
# Ensure experiment exists
experiment_id = ensure_experiment(self.experiment_name)
@@ -163,8 +157,7 @@ class MLflowTracker:
self.run_id = self.run.info.run_id
logger.info(
f"Started MLflow run '{self.run_name}' "
f"(ID: {self.run_id}) in experiment '{self.experiment_name}'"
f"Started MLflow run '{self.run_name}' (ID: {self.run_id}) in experiment '{self.experiment_name}'"
)
yield self.run
@@ -214,11 +207,7 @@ class MLflowTracker:
"""Log a single parameter."""
self.log_params({key: value})
def log_metrics(
self,
metrics: Dict[str, Union[float, int]],
step: Optional[int] = None
) -> None:
def log_metrics(self, metrics: Dict[str, Union[float, int]], step: Optional[int] = None) -> None:
"""
Log metrics to the current run.
@@ -233,20 +222,11 @@ class MLflowTracker:
mlflow.log_metrics(metrics, step=step)
logger.debug(f"Logged {len(metrics)} metrics")
def log_metric(
self,
key: str,
value: Union[float, int],
step: Optional[int] = None
) -> None:
def log_metric(self, key: str, value: Union[float, int], step: Optional[int] = None) -> None:
"""Log a single metric."""
self.log_metrics({key: value}, step=step)
def log_artifact(
self,
local_path: str,
artifact_path: Optional[str] = None
) -> None:
def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
"""
Log an artifact file to the current run.
@@ -261,11 +241,7 @@ class MLflowTracker:
mlflow.log_artifact(local_path, artifact_path)
logger.info(f"Logged artifact: {local_path}")
def log_artifacts(
self,
local_dir: str,
artifact_path: Optional[str] = None
) -> None:
def log_artifacts(self, local_dir: str, artifact_path: Optional[str] = None) -> None:
"""
Log all files in a directory as artifacts.
@@ -280,12 +256,7 @@ class MLflowTracker:
mlflow.log_artifacts(local_dir, artifact_path)
logger.info(f"Logged artifacts from: {local_dir}")
def log_dict(
self,
data: Dict[str, Any],
filename: str,
artifact_path: Optional[str] = None
) -> None:
def log_dict(self, data: Dict[str, Any], filename: str, artifact_path: Optional[str] = None) -> None:
"""
Log a dictionary as a JSON artifact.
@@ -311,7 +282,7 @@ class MLflowTracker:
model_name: str,
model_path: Optional[str] = None,
framework: str = "pytorch",
extra_info: Optional[Dict[str, Any]] = None
extra_info: Optional[Dict[str, Any]] = None,
) -> None:
"""
Log model information as parameters and tags.
@@ -341,11 +312,7 @@ class MLflowTracker:
mlflow.set_tag("model.name", model_name)
def log_dataset_info(
self,
name: str,
source: str,
size: Optional[int] = None,
extra_info: Optional[Dict[str, Any]] = None
self, name: str, source: str, size: Optional[int] = None, extra_info: Optional[Dict[str, Any]] = None
) -> None:
"""
Log dataset information.