fix: resolve all ruff lint errors
Some checks failed
CI / Test (push) Successful in 1m46s
CI / Lint (push) Failing after 1m49s
CI / Publish (push) Has been skipped
CI / Notify (push) Successful in 2s

This commit is contained in:
2026-02-13 10:57:57 -05:00
parent 6bcf84549c
commit 1c841729a0
9 changed files with 456 additions and 464 deletions

View File

@@ -21,22 +21,22 @@ Usage in a Kubeflow Pipeline:
experiment_name="my-experiment",
run_name="training-run-1"
)
# ... your pipeline steps ...
# Log metrics
log_step = log_metrics_component(
run_id=run_info.outputs["run_id"],
metrics={"accuracy": 0.95, "loss": 0.05}
)
# End run
end_mlflow_run(run_id=run_info.outputs["run_id"])
"""
from kfp import dsl
from typing import Dict, Any, List, Optional, NamedTuple
from typing import Any, Dict, List, NamedTuple
from kfp import dsl
# MLflow component image with all required dependencies
MLFLOW_IMAGE = "python:3.13-slim"
@@ -60,31 +60,32 @@ def create_mlflow_run(
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
"""
Create a new MLflow run for the pipeline.
This should be called at the start of a pipeline to initialize
tracking. The returned run_id should be passed to subsequent
components for logging.
Args:
experiment_name: Name of the MLflow experiment
run_name: Name for this specific run
mlflow_tracking_uri: MLflow tracking server URI
tags: Optional tags to add to the run
params: Optional parameters to log
Returns:
NamedTuple with run_id, experiment_id, and artifact_uri
"""
import os
from collections import namedtuple
import mlflow
from mlflow.tracking import MlflowClient
from collections import namedtuple
# Set tracking URI
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Get or create experiment
experiment = client.get_experiment_by_name(experiment_name)
if experiment is None:
@@ -94,7 +95,7 @@ def create_mlflow_run(
)
else:
experiment_id = experiment.experiment_id
# Create default tags
default_tags = {
"pipeline.type": "kubeflow",
@@ -103,24 +104,24 @@ def create_mlflow_run(
}
if tags:
default_tags.update(tags)
# Start run
run = mlflow.start_run(
experiment_id=experiment_id,
run_name=run_name,
tags=default_tags,
)
# Log initial params
if params:
mlflow.log_params(params)
run_id = run.info.run_id
artifact_uri = run.info.artifact_uri
# End run (KFP components are isolated, we'll resume in other components)
mlflow.end_run()
RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
return RunInfo(run_id, experiment_id, artifact_uri)
@@ -136,24 +137,24 @@ def log_params_component(
) -> str:
"""
Log parameters to an existing MLflow run.
Args:
run_id: The MLflow run ID to log to
params: Dictionary of parameters to log
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
for key, value in params.items():
client.log_param(run_id, key, str(value)[:500])
return run_id
@@ -169,25 +170,25 @@ def log_metrics_component(
) -> str:
"""
Log metrics to an existing MLflow run.
Args:
run_id: The MLflow run ID to log to
metrics: Dictionary of metrics to log
step: Step number for time-series metrics
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
for key, value in metrics.items():
client.log_metric(run_id, key, float(value), step=step)
return run_id
@@ -203,24 +204,24 @@ def log_artifact_component(
) -> str:
"""
Log an artifact file to an existing MLflow run.
Args:
run_id: The MLflow run ID to log to
artifact_path: Path to the artifact file
artifact_name: Optional destination name in artifact store
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
client.log_artifact(run_id, artifact_path, artifact_name or None)
return run_id
@@ -236,36 +237,37 @@ def log_dict_artifact(
) -> str:
"""
Log a dictionary as a JSON artifact.
Args:
run_id: The MLflow run ID to log to
data: Dictionary to save as JSON
filename: Name for the JSON file
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import json
import tempfile
from pathlib import Path
import mlflow
from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Ensure .json extension
if not filename.endswith('.json'):
filename += '.json'
# Write to temp file and log
with tempfile.TemporaryDirectory() as tmpdir:
filepath = Path(tmpdir) / filename
with open(filepath, 'w') as f:
json.dump(data, f, indent=2)
client.log_artifact(run_id, str(filepath))
return run_id
@@ -280,31 +282,31 @@ def end_mlflow_run(
) -> str:
"""
End an MLflow run with the specified status.
Args:
run_id: The MLflow run ID to end
status: Run status (FINISHED, FAILED, KILLED)
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id
"""
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
status_map = {
"FINISHED": RunStatus.FINISHED,
"FAILED": RunStatus.FAILED,
"KILLED": RunStatus.KILLED,
}
run_status = status_map.get(status.upper(), RunStatus.FINISHED)
client.set_terminated(run_id, status=run_status)
return run_id
@@ -322,10 +324,10 @@ def log_training_metrics(
) -> str:
"""
Log comprehensive training metrics for ML models.
Designed for use with QLoRA training, voice training, and other
ML training pipelines in the llm-workflows repository.
Args:
run_id: The MLflow run ID to log to
model_type: Type of model (llm, stt, tts, embeddings)
@@ -333,19 +335,20 @@ def log_training_metrics(
final_metrics: Final training metrics
model_path: Path to saved model (if applicable)
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import json
import tempfile
from pathlib import Path
import mlflow
from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Log training config as params
flat_config = {}
for key, value in training_config.items():
@@ -353,29 +356,29 @@ def log_training_metrics(
flat_config[f"config.{key}"] = json.dumps(value)[:500]
else:
flat_config[f"config.{key}"] = str(value)[:500]
for key, value in flat_config.items():
client.log_param(run_id, key, value)
# Log model type tag
client.set_tag(run_id, "model.type", model_type)
# Log metrics
for key, value in final_metrics.items():
client.log_metric(run_id, key, float(value))
# Log full config as artifact
with tempfile.TemporaryDirectory() as tmpdir:
config_path = Path(tmpdir) / "training_config.json"
with open(config_path, 'w') as f:
json.dump(training_config, f, indent=2)
client.log_artifact(run_id, str(config_path))
# Log model path if provided
if model_path:
client.log_param(run_id, "model.path", model_path)
client.set_tag(run_id, "model.saved", "true")
return run_id
@@ -397,9 +400,9 @@ def log_document_ingestion_metrics(
) -> str:
"""
Log document ingestion pipeline metrics.
Designed for use with the document_ingestion_pipeline.
Args:
run_id: The MLflow run ID to log to
source_url: URL of the source document
@@ -411,16 +414,16 @@ def log_document_ingestion_metrics(
chunk_size: Chunk size in tokens
chunk_overlap: Chunk overlap in tokens
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import mlflow
from mlflow.tracking import MlflowClient
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Log params
params = {
"source_url": source_url[:500],
@@ -431,7 +434,7 @@ def log_document_ingestion_metrics(
}
for key, value in params.items():
client.log_param(run_id, key, value)
# Log metrics
metrics = {
"chunks_created": chunks_created,
@@ -441,11 +444,11 @@ def log_document_ingestion_metrics(
}
for key, value in metrics.items():
client.log_metric(run_id, key, float(value))
# Set pipeline type tag
client.set_tag(run_id, "pipeline.type", "document-ingestion")
client.set_tag(run_id, "milvus.collection", collection_name)
return run_id
@@ -463,9 +466,9 @@ def log_evaluation_results(
) -> str:
"""
Log model evaluation results.
Designed for use with the evaluation_pipeline.
Args:
run_id: The MLflow run ID to log to
model_name: Name of the evaluated model
@@ -473,27 +476,28 @@ def log_evaluation_results(
metrics: Evaluation metrics (accuracy, etc.)
sample_results: Optional sample predictions
mlflow_tracking_uri: MLflow tracking server URI
Returns:
The run_id for chaining
"""
import json
import tempfile
from pathlib import Path
import mlflow
from mlflow.tracking import MlflowClient
from pathlib import Path
mlflow.set_tracking_uri(mlflow_tracking_uri)
client = MlflowClient()
# Log params
client.log_param(run_id, "eval.model_name", model_name)
client.log_param(run_id, "eval.dataset", dataset_name)
# Log metrics
for key, value in metrics.items():
client.log_metric(run_id, f"eval.{key}", float(value))
# Log sample results as artifact
if sample_results:
with tempfile.TemporaryDirectory() as tmpdir:
@@ -501,13 +505,13 @@ def log_evaluation_results(
with open(results_path, 'w') as f:
json.dump(sample_results, f, indent=2)
client.log_artifact(run_id, str(results_path))
# Set tags
client.set_tag(run_id, "pipeline.type", "evaluation")
client.set_tag(run_id, "model.name", model_name)
# Determine if passed
passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
client.set_tag(run_id, "eval.passed", str(passed))
return run_id