style: apply ruff format to all files
This commit is contained in:
@@ -47,17 +47,14 @@ MLFLOW_PACKAGES = [
|
||||
]
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def create_mlflow_run(
|
||||
experiment_name: str,
|
||||
run_name: str,
|
||||
mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
|
||||
tags: Dict[str, str] = None,
|
||||
params: Dict[str, str] = None,
|
||||
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
|
||||
) -> NamedTuple("RunInfo", [("run_id", str), ("experiment_id", str), ("artifact_uri", str)]):
|
||||
"""
|
||||
Create a new MLflow run for the pipeline.
|
||||
|
||||
@@ -90,8 +87,7 @@ def create_mlflow_run(
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
if experiment is None:
|
||||
experiment_id = client.create_experiment(
|
||||
name=experiment_name,
|
||||
artifact_location=f"/mlflow/artifacts/{experiment_name}"
|
||||
name=experiment_name, artifact_location=f"/mlflow/artifacts/{experiment_name}"
|
||||
)
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
@@ -122,14 +118,11 @@ def create_mlflow_run(
|
||||
# End run (KFP components are isolated, we'll resume in other components)
|
||||
mlflow.end_run()
|
||||
|
||||
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
|
||||
RunInfo = namedtuple("RunInfo", ["run_id", "experiment_id", "artifact_uri"])
|
||||
return RunInfo(run_id, experiment_id, artifact_uri)
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def log_params_component(
|
||||
run_id: str,
|
||||
params: Dict[str, str],
|
||||
@@ -158,10 +151,7 @@ def log_params_component(
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def log_metrics_component(
|
||||
run_id: str,
|
||||
metrics: Dict[str, float],
|
||||
@@ -192,10 +182,7 @@ def log_metrics_component(
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def log_artifact_component(
|
||||
run_id: str,
|
||||
artifact_path: str,
|
||||
@@ -225,10 +212,7 @@ def log_artifact_component(
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def log_dict_artifact(
|
||||
run_id: str,
|
||||
data: Dict[str, Any],
|
||||
@@ -258,23 +242,20 @@ def log_dict_artifact(
|
||||
client = MlflowClient()
|
||||
|
||||
# Ensure .json extension
|
||||
if not filename.endswith('.json'):
|
||||
filename += '.json'
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
|
||||
# Write to temp file and log
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filepath = Path(tmpdir) / filename
|
||||
with open(filepath, 'w') as f:
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
client.log_artifact(run_id, str(filepath))
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def end_mlflow_run(
|
||||
run_id: str,
|
||||
status: str = "FINISHED",
|
||||
@@ -310,10 +291,7 @@ def end_mlflow_run(
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES + ["httpx"]
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES + ["httpx"])
|
||||
def log_training_metrics(
|
||||
run_id: str,
|
||||
model_type: str,
|
||||
@@ -370,7 +348,7 @@ def log_training_metrics(
|
||||
# Log full config as artifact
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "training_config.json"
|
||||
with open(config_path, 'w') as f:
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(training_config, f, indent=2)
|
||||
client.log_artifact(run_id, str(config_path))
|
||||
|
||||
@@ -382,10 +360,7 @@ def log_training_metrics(
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def log_document_ingestion_metrics(
|
||||
run_id: str,
|
||||
source_url: str,
|
||||
@@ -452,10 +427,7 @@ def log_document_ingestion_metrics(
|
||||
return run_id
|
||||
|
||||
|
||||
@dsl.component(
|
||||
base_image=MLFLOW_IMAGE,
|
||||
packages_to_install=MLFLOW_PACKAGES
|
||||
)
|
||||
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
|
||||
def log_evaluation_results(
|
||||
run_id: str,
|
||||
model_name: str,
|
||||
@@ -502,7 +474,7 @@ def log_evaluation_results(
|
||||
if sample_results:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
results_path = Path(tmpdir) / "evaluation_results.json"
|
||||
with open(results_path, 'w') as f:
|
||||
with open(results_path, "w") as f:
|
||||
json.dump(sample_results, f, indent=2)
|
||||
client.log_artifact(run_id, str(results_path))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user