style: apply ruff format to all files
Some checks failed
CI / Lint (push) Successful in 1m46s
CI / Test (push) Successful in 1m44s
CI / Publish (push) Failing after 19s
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-13 11:05:26 -05:00
parent 1c841729a0
commit ca5bef9664
7 changed files with 89 additions and 222 deletions

View File

@@ -47,17 +47,14 @@ MLFLOW_PACKAGES = [
]
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def create_mlflow_run(
experiment_name: str,
run_name: str,
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
tags: Dict[str, str] = None,
params: Dict[str, str] = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
) -> NamedTuple("RunInfo", [("run_id", str), ("experiment_id", str), ("artifact_uri", str)]):
"""
Create a new MLflow run for the pipeline.
@@ -90,8 +87,7 @@ def create_mlflow_run(
experiment = client.get_experiment_by_name(experiment_name)
if experiment is None:
experiment_id = client.create_experiment(
name=experiment_name,
artifact_location=f"/mlflow/artifacts/{experiment_name}"
name=experiment_name, artifact_location=f"/mlflow/artifacts/{experiment_name}"
)
else:
experiment_id = experiment.experiment_id
@@ -122,14 +118,11 @@ def create_mlflow_run(
# End run (KFP components are isolated, we'll resume in other components)
mlflow.end_run()
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
RunInfo = namedtuple("RunInfo", ["run_id", "experiment_id", "artifact_uri"])
return RunInfo(run_id, experiment_id, artifact_uri)
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_params_component(
run_id: str,
params: Dict[str, str],
@@ -158,10 +151,7 @@ def log_params_component(
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_metrics_component(
run_id: str,
metrics: Dict[str, float],
@@ -192,10 +182,7 @@ def log_metrics_component(
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_artifact_component(
run_id: str,
artifact_path: str,
@@ -225,10 +212,7 @@ def log_artifact_component(
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_dict_artifact(
run_id: str,
data: Dict[str, Any],
@@ -258,23 +242,20 @@ def log_dict_artifact(
client = MlflowClient()
# Ensure .json extension
if not filename.endswith('.json'):
filename += '.json'
if not filename.endswith(".json"):
filename += ".json"
# Write to temp file and log
with tempfile.TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / filename
with open(filepath, 'w') as f:
with open(filepath, "w") as f:
json.dump(data, f, indent=2)
client.log_artifact(run_id, str(filepath))
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def end_mlflow_run(
run_id: str,
status: str = "FINISHED",
@@ -310,10 +291,7 @@ def end_mlflow_run(
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES + ["httpx"]
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES + ["httpx"])
def log_training_metrics(
run_id: str,
model_type: str,
@@ -370,7 +348,7 @@ def log_training_metrics(
# Log full config as artifact
with tempfile.TemporaryDirectory() as tmpdir:
config_path = Path(tmpdir) / "training_config.json"
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
json.dump(training_config, f, indent=2)
client.log_artifact(run_id, str(config_path))
@@ -382,10 +360,7 @@ def log_training_metrics(
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_document_ingestion_metrics(
run_id: str,
source_url: str,
@@ -452,10 +427,7 @@ def log_document_ingestion_metrics(
return run_id
@dsl.component(
base_image=MLFLOW_IMAGE,
packages_to_install=MLFLOW_PACKAGES
)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_evaluation_results(
run_id: str,
model_name: str,
@@ -502,7 +474,7 @@ def log_evaluation_results(
if sample_results:
with tempfile.TemporaryDirectory() as tmpdir:
results_path = Path(tmpdir) / "evaluation_results.json"
with open(results_path, 'w') as f:
with open(results_path, "w") as f:
json.dump(sample_results, f, indent=2)
client.log_artifact(run_id, str(results_path))