fix: resolve all ruff lint errors
This commit is contained in:
@@ -21,22 +21,22 @@ Usage in a Kubeflow Pipeline:
|
||||
experiment_name="my-experiment",
|
||||
run_name="training-run-1"
|
||||
)
|
||||
|
||||
|
||||
# ... your pipeline steps ...
|
||||
|
||||
|
||||
# Log metrics
|
||||
log_step = log_metrics_component(
|
||||
run_id=run_info.outputs["run_id"],
|
||||
metrics={"accuracy": 0.95, "loss": 0.05}
|
||||
)
|
||||
|
||||
|
||||
# End run
|
||||
end_mlflow_run(run_id=run_info.outputs["run_id"])
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
from typing import Dict, Any, List, Optional, NamedTuple
|
||||
from typing import Any, Dict, List, NamedTuple
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
# MLflow component image with all required dependencies
|
||||
MLFLOW_IMAGE = "python:3.13-slim"
|
||||
@@ -60,31 +60,32 @@ def create_mlflow_run(
|
||||
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
|
||||
"""
|
||||
Create a new MLflow run for the pipeline.
|
||||
|
||||
|
||||
This should be called at the start of a pipeline to initialize
|
||||
tracking. The returned run_id should be passed to subsequent
|
||||
components for logging.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Name of the MLflow experiment
|
||||
run_name: Name for this specific run
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
tags: Optional tags to add to the run
|
||||
params: Optional parameters to log
|
||||
|
||||
|
||||
Returns:
|
||||
NamedTuple with run_id, experiment_id, and artifact_uri
|
||||
"""
|
||||
import os
|
||||
from collections import namedtuple
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
# Set tracking URI
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
|
||||
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Get or create experiment
|
||||
experiment = client.get_experiment_by_name(experiment_name)
|
||||
if experiment is None:
|
||||
@@ -94,7 +95,7 @@ def create_mlflow_run(
|
||||
)
|
||||
else:
|
||||
experiment_id = experiment.experiment_id
|
||||
|
||||
|
||||
# Create default tags
|
||||
default_tags = {
|
||||
"pipeline.type": "kubeflow",
|
||||
@@ -103,24 +104,24 @@ def create_mlflow_run(
|
||||
}
|
||||
if tags:
|
||||
default_tags.update(tags)
|
||||
|
||||
|
||||
# Start run
|
||||
run = mlflow.start_run(
|
||||
experiment_id=experiment_id,
|
||||
run_name=run_name,
|
||||
tags=default_tags,
|
||||
)
|
||||
|
||||
|
||||
# Log initial params
|
||||
if params:
|
||||
mlflow.log_params(params)
|
||||
|
||||
|
||||
run_id = run.info.run_id
|
||||
artifact_uri = run.info.artifact_uri
|
||||
|
||||
|
||||
# End the run here: KFP components are isolated, so later components log to this run through its run_id
|
||||
mlflow.end_run()
|
||||
|
||||
|
||||
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
|
||||
return RunInfo(run_id, experiment_id, artifact_uri)
|
||||
|
||||
@@ -136,24 +137,24 @@ def log_params_component(
|
||||
) -> str:
|
||||
"""
|
||||
Log parameters to an existing MLflow run.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
params: Dictionary of parameters to log (each value is converted to a string and truncated to 500 characters)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
for key, value in params.items():
|
||||
client.log_param(run_id, key, str(value)[:500])
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -169,25 +170,25 @@ def log_metrics_component(
|
||||
) -> str:
|
||||
"""
|
||||
Log metrics to an existing MLflow run.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
metrics: Dictionary of metrics to log
|
||||
step: Step number for time-series metrics
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, key, float(value), step=step)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -203,24 +204,24 @@ def log_artifact_component(
|
||||
) -> str:
|
||||
"""
|
||||
Log an artifact file to an existing MLflow run.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
artifact_path: Path to the artifact file
|
||||
artifact_name: Optional destination directory within the run's artifact store (forwarded as MLflow's artifact_path; empty means the artifact root)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
client.log_artifact(run_id, artifact_path, artifact_name or None)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -236,36 +237,37 @@ def log_dict_artifact(
|
||||
) -> str:
|
||||
"""
|
||||
Log a dictionary as a JSON artifact.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
data: Dictionary to save as JSON
|
||||
filename: Name for the JSON file
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Ensure .json extension
|
||||
if not filename.endswith('.json'):
|
||||
filename += '.json'
|
||||
|
||||
|
||||
# Write to temp file and log
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filepath = Path(tmpdir) / filename
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
client.log_artifact(run_id, str(filepath))
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -280,31 +282,31 @@ def end_mlflow_run(
|
||||
) -> str:
|
||||
"""
|
||||
End an MLflow run with the specified status.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to end
|
||||
status: Run status (FINISHED, FAILED, KILLED); matched case-insensitively, and any unrecognised value falls back to FINISHED
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from mlflow.entities import RunStatus
|
||||
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
status_map = {
|
||||
"FINISHED": RunStatus.FINISHED,
|
||||
"FAILED": RunStatus.FAILED,
|
||||
"KILLED": RunStatus.KILLED,
|
||||
}
|
||||
|
||||
|
||||
run_status = status_map.get(status.upper(), RunStatus.FINISHED)
|
||||
client.set_terminated(run_id, status=run_status)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -322,10 +324,10 @@ def log_training_metrics(
|
||||
) -> str:
|
||||
"""
|
||||
Log comprehensive training metrics for ML models.
|
||||
|
||||
|
||||
Designed for use with QLoRA training, voice training, and other
|
||||
ML training pipelines in the llm-workflows repository.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
model_type: Type of model (llm, stt, tts, embeddings)
|
||||
@@ -333,19 +335,20 @@ def log_training_metrics(
|
||||
final_metrics: Final training metrics
|
||||
model_path: Path to saved model (if applicable)
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Log training config as params
|
||||
flat_config = {}
|
||||
for key, value in training_config.items():
|
||||
@@ -353,29 +356,29 @@ def log_training_metrics(
|
||||
flat_config[f"config.{key}"] = json.dumps(value)[:500]
|
||||
else:
|
||||
flat_config[f"config.{key}"] = str(value)[:500]
|
||||
|
||||
|
||||
for key, value in flat_config.items():
|
||||
client.log_param(run_id, key, value)
|
||||
|
||||
|
||||
# Log model type tag
|
||||
client.set_tag(run_id, "model.type", model_type)
|
||||
|
||||
|
||||
# Log metrics
|
||||
for key, value in final_metrics.items():
|
||||
client.log_metric(run_id, key, float(value))
|
||||
|
||||
|
||||
# Log full config as artifact
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "training_config.json"
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(training_config, f, indent=2)
|
||||
client.log_artifact(run_id, str(config_path))
|
||||
|
||||
|
||||
# Log model path if provided
|
||||
if model_path:
|
||||
client.log_param(run_id, "model.path", model_path)
|
||||
client.set_tag(run_id, "model.saved", "true")
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -397,9 +400,9 @@ def log_document_ingestion_metrics(
|
||||
) -> str:
|
||||
"""
|
||||
Log document ingestion pipeline metrics.
|
||||
|
||||
|
||||
Designed for use with the document_ingestion_pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
source_url: URL of the source document
|
||||
@@ -411,16 +414,16 @@ def log_document_ingestion_metrics(
|
||||
chunk_size: Chunk size in tokens
|
||||
chunk_overlap: Chunk overlap in tokens
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Log params
|
||||
params = {
|
||||
"source_url": source_url[:500],
|
||||
@@ -431,7 +434,7 @@ def log_document_ingestion_metrics(
|
||||
}
|
||||
for key, value in params.items():
|
||||
client.log_param(run_id, key, value)
|
||||
|
||||
|
||||
# Log metrics
|
||||
metrics = {
|
||||
"chunks_created": chunks_created,
|
||||
@@ -441,11 +444,11 @@ def log_document_ingestion_metrics(
|
||||
}
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, key, float(value))
|
||||
|
||||
|
||||
# Set pipeline type tag
|
||||
client.set_tag(run_id, "pipeline.type", "document-ingestion")
|
||||
client.set_tag(run_id, "milvus.collection", collection_name)
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
|
||||
@@ -463,9 +466,9 @@ def log_evaluation_results(
|
||||
) -> str:
|
||||
"""
|
||||
Log model evaluation results.
|
||||
|
||||
|
||||
Designed for use with the evaluation_pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
run_id: The MLflow run ID to log to
|
||||
model_name: Name of the evaluated model
|
||||
@@ -473,27 +476,28 @@ def log_evaluation_results(
|
||||
metrics: Evaluation metrics (accuracy, etc.)
|
||||
sample_results: Optional sample predictions
|
||||
mlflow_tracking_uri: MLflow tracking server URI
|
||||
|
||||
|
||||
Returns:
|
||||
The run_id for chaining
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
mlflow.set_tracking_uri(mlflow_tracking_uri)
|
||||
client = MlflowClient()
|
||||
|
||||
|
||||
# Log params
|
||||
client.log_param(run_id, "eval.model_name", model_name)
|
||||
client.log_param(run_id, "eval.dataset", dataset_name)
|
||||
|
||||
|
||||
# Log metrics
|
||||
for key, value in metrics.items():
|
||||
client.log_metric(run_id, f"eval.{key}", float(value))
|
||||
|
||||
|
||||
# Log sample results as artifact
|
||||
if sample_results:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -501,13 +505,13 @@ def log_evaluation_results(
|
||||
with open(results_path, 'w') as f:
|
||||
json.dump(sample_results, f, indent=2)
|
||||
client.log_artifact(run_id, str(results_path))
|
||||
|
||||
|
||||
# Set tags
|
||||
client.set_tag(run_id, "pipeline.type", "evaluation")
|
||||
client.set_tag(run_id, "model.name", model_name)
|
||||
|
||||
|
||||
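# Determine if passed: use the "pass" metric when present, otherwise require accuracy >= 0.7 (missing accuracy counts as 0)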
|
||||
passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
|
||||
client.set_tag(run_id, "eval.passed", str(passed))
|
||||
|
||||
|
||||
return run_id
|
||||
|
||||
Reference in New Issue
Block a user