Files
mlflow/mlflow_utils/kfp_components.py
Billy D. ca5bef9664
Some checks failed
CI / Lint (push) Successful in 1m46s
CI / Test (push) Successful in 1m44s
CI / Publish (push) Failing after 19s
CI / Notify (push) Successful in 1s
style: apply ruff format to all files
2026-02-13 11:05:26 -05:00

490 lines
14 KiB
Python

"""
Kubeflow Pipeline Components with MLflow Tracking
Provides reusable KFP components that integrate MLflow experiment
tracking into Kubeflow Pipelines. These components can be used
directly in pipelines or as wrappers around existing pipeline steps.
Usage in a Kubeflow Pipeline:
from mlflow_utils.kfp_components import (
create_mlflow_run,
log_metrics_component,
log_model_artifact,
end_mlflow_run,
)
@dsl.pipeline(name="my-pipeline")
def my_pipeline():
# Start MLflow run
run_info = create_mlflow_run(
experiment_name="my-experiment",
run_name="training-run-1"
)
# ... your pipeline steps ...
# Log metrics
log_step = log_metrics_component(
run_id=run_info.outputs["run_id"],
metrics={"accuracy": 0.95, "loss": 0.05}
)
# End run
end_mlflow_run(run_id=run_info.outputs["run_id"])
"""
from typing import Any, Dict, List, NamedTuple, Optional

from kfp import dsl
# MLflow component image with all required dependencies
# Base container image used for every @dsl.component in this module.
MLFLOW_IMAGE = "python:3.13-slim"
# Pip packages installed into each component container at runtime via
# the decorator's packages_to_install argument.
MLFLOW_PACKAGES = [
    "mlflow>=2.10.0",
    "boto3",  # For S3 artifact storage if needed
    "psycopg2-binary",  # For PostgreSQL backend
]
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def create_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    tags: Optional[Dict[str, str]] = None,
    params: Optional[Dict[str, str]] = None,
) -> NamedTuple("RunInfo", [("run_id", str), ("experiment_id", str), ("artifact_uri", str)]):
    """
    Create a new MLflow run for the pipeline.

    This should be called at the start of a pipeline to initialize
    tracking. The returned run_id should be passed to subsequent
    components for logging.

    Args:
        experiment_name: Name of the MLflow experiment
        run_name: Name for this specific run
        mlflow_tracking_uri: MLflow tracking server URI
        tags: Optional tags to add to the run (merged over the defaults)
        params: Optional parameters to log

    Returns:
        NamedTuple with run_id, experiment_id, and artifact_uri
    """
    import os
    from collections import namedtuple

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Get or create the experiment by name.
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(
            name=experiment_name, artifact_location=f"/mlflow/artifacts/{experiment_name}"
        )
    else:
        experiment_id = experiment.experiment_id

    # Default tags record the Kubeflow context; caller tags may override them.
    default_tags = {
        "pipeline.type": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
        "kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
    }
    if tags:
        default_tags.update(tags)

    run = mlflow.start_run(
        experiment_id=experiment_id,
        run_name=run_name,
        tags=default_tags,
    )
    if params:
        mlflow.log_params(params)

    run_id = run.info.run_id
    artifact_uri = run.info.artifact_uri
    # End the run here: KFP components execute in isolated pods, so later
    # components log to this run via MlflowClient using the returned run_id.
    mlflow.end_run()

    RunInfo = namedtuple("RunInfo", ["run_id", "experiment_id", "artifact_uri"])
    return RunInfo(run_id, experiment_id, artifact_uri)
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_params_component(
    run_id: str,
    params: Dict[str, str],
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log parameters to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        params: Dictionary of parameters to log
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    # MLflow caps param values at 500 characters; truncate defensively.
    for name, raw in params.items():
        client.log_param(run_id, name, str(raw)[:500])
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_metrics_component(
    run_id: str,
    metrics: Dict[str, float],
    step: int = 0,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log metrics to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        metrics: Dictionary of metrics to log
        step: Step number for time-series metrics
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    # Coerce to float so ints (and numeric strings) are accepted uniformly.
    for metric_name, metric_value in metrics.items():
        client.log_metric(run_id, metric_name, float(metric_value), step=step)
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_artifact_component(
    run_id: str,
    artifact_path: str,
    artifact_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log an artifact file to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        artifact_path: Path to the artifact file
        artifact_name: Optional destination name in artifact store
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    # An empty artifact_name means "store at the artifact root" (None).
    destination = artifact_name if artifact_name else None
    client.log_artifact(run_id, artifact_path, destination)
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_dict_artifact(
    run_id: str,
    data: Dict[str, Any],
    filename: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log a dictionary as a JSON artifact.

    Args:
        run_id: The MLflow run ID to log to
        data: Dictionary to save as JSON
        filename: Name for the JSON file
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Normalize the filename so the artifact always has a .json extension.
    if not filename.endswith(".json"):
        filename = f"{filename}.json"

    # Serialize into a throwaway directory, upload, then let cleanup happen.
    with tempfile.TemporaryDirectory() as workdir:
        target = Path(workdir) / filename
        target.write_text(json.dumps(data, indent=2))
        client.log_artifact(run_id, str(target))
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def end_mlflow_run(
    run_id: str,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    End an MLflow run with the specified status.

    Args:
        run_id: The MLflow run ID to end
        status: Run status (FINISHED, FAILED, KILLED); anything else
            falls back to FINISHED
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()
    # BUG FIX: MlflowClient.set_terminated expects a *string* RunStatus name
    # and calls RunStatus.from_string() on it internally. The previous code
    # passed the RunStatus int enum, which from_string() rejects. Validate
    # and pass the string directly instead.
    valid_statuses = {"FINISHED", "FAILED", "KILLED"}
    run_status = status.upper()
    if run_status not in valid_statuses:
        run_status = "FINISHED"
    client.set_terminated(run_id, status=run_status)
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES + ["httpx"])
def log_training_metrics(
    run_id: str,
    model_type: str,
    training_config: Dict[str, Any],
    final_metrics: Dict[str, float],
    model_path: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log comprehensive training metrics for ML models.

    Designed for use with QLoRA training, voice training, and other
    ML training pipelines in the llm-workflows repository.

    Args:
        run_id: The MLflow run ID to log to
        model_type: Type of model (llm, stt, tts, embeddings)
        training_config: Training configuration dict
        final_metrics: Final training metrics
        model_path: Path to saved model (if applicable)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Flatten the config into "config.<key>" params. Nested dicts/lists are
    # JSON-encoded; everything is truncated to MLflow's 500-char param limit.
    flat_config = {
        f"config.{key}": (
            json.dumps(value) if isinstance(value, (dict, list)) else str(value)
        )[:500]
        for key, value in training_config.items()
    }
    for param_name, param_value in flat_config.items():
        client.log_param(run_id, param_name, param_value)

    client.set_tag(run_id, "model.type", model_type)

    for metric_name, metric_value in final_metrics.items():
        client.log_metric(run_id, metric_name, float(metric_value))

    # Preserve the untruncated config as a JSON artifact.
    with tempfile.TemporaryDirectory() as workdir:
        config_file = Path(workdir) / "training_config.json"
        config_file.write_text(json.dumps(training_config, indent=2))
        client.log_artifact(run_id, str(config_file))

    # Record where the model was saved, when a path was provided.
    if model_path:
        client.log_param(run_id, "model.path", model_path)
        client.set_tag(run_id, "model.saved", "true")
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_document_ingestion_metrics(
    run_id: str,
    source_url: str,
    collection_name: str,
    chunks_created: int,
    documents_processed: int,
    processing_time_seconds: float,
    embeddings_model: str = "bge-small-en-v1.5",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log document ingestion pipeline metrics.

    Designed for use with the document_ingestion_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        source_url: URL of the source document
        collection_name: Milvus collection name
        chunks_created: Number of chunks created
        documents_processed: Number of documents processed
        processing_time_seconds: Total processing time
        embeddings_model: Embeddings model used
        chunk_size: Chunk size in tokens
        chunk_overlap: Chunk overlap in tokens
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Parameters (URL truncated to MLflow's 500-char param limit).
    ingestion_params = {
        "source_url": source_url[:500],
        "collection_name": collection_name,
        "embeddings_model": embeddings_model,
        "chunk_size": str(chunk_size),
        "chunk_overlap": str(chunk_overlap),
    }
    for param_name, param_value in ingestion_params.items():
        client.log_param(run_id, param_name, param_value)

    # Derived throughput; guard against a zero elapsed time.
    throughput = 0
    if processing_time_seconds > 0:
        throughput = chunks_created / processing_time_seconds

    ingestion_metrics = {
        "chunks_created": chunks_created,
        "documents_processed": documents_processed,
        "processing_time_seconds": processing_time_seconds,
        "chunks_per_second": throughput,
    }
    for metric_name, metric_value in ingestion_metrics.items():
        client.log_metric(run_id, metric_name, float(metric_value))

    # Tag the run so it is discoverable by pipeline type and collection.
    client.set_tag(run_id, "pipeline.type", "document-ingestion")
    client.set_tag(run_id, "milvus.collection", collection_name)
    return run_id
@dsl.component(base_image=MLFLOW_IMAGE, packages_to_install=MLFLOW_PACKAGES)
def log_evaluation_results(
    run_id: str,
    model_name: str,
    dataset_name: str,
    metrics: Dict[str, float],
    sample_results: Optional[List[Dict[str, Any]]] = None,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log model evaluation results.

    Designed for use with the evaluation_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        model_name: Name of the evaluated model
        dataset_name: Name of the evaluation dataset
        metrics: Evaluation metrics (accuracy, etc.)
        sample_results: Optional sample predictions
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log params
    client.log_param(run_id, "eval.model_name", model_name)
    client.log_param(run_id, "eval.dataset", dataset_name)

    # Log metrics under an "eval." prefix
    for key, value in metrics.items():
        client.log_metric(run_id, f"eval.{key}", float(value))

    # Log sample predictions as a JSON artifact, if provided
    if sample_results:
        with tempfile.TemporaryDirectory() as tmpdir:
            results_path = Path(tmpdir) / "evaluation_results.json"
            with open(results_path, "w") as f:
                json.dump(sample_results, f, indent=2)
            client.log_artifact(run_id, str(results_path))

    # Set tags
    client.set_tag(run_id, "pipeline.type", "evaluation")
    client.set_tag(run_id, "model.name", model_name)

    # An explicit "pass" metric wins; otherwise fall back to an accuracy
    # threshold. FIX: coerce to bool so the tag is always "True"/"False"
    # (the raw metric value used to be stringified as e.g. "1.0").
    if "pass" in metrics:
        passed = bool(metrics["pass"])
    else:
        passed = metrics.get("accuracy", 0) >= 0.7
    client.set_tag(run_id, "eval.passed", str(passed))
    return run_id