"""
Kubeflow Pipeline Components with MLflow Tracking

Provides reusable KFP components that integrate MLflow experiment
tracking into Kubeflow Pipelines. These components can be used
directly in pipelines or as wrappers around existing pipeline steps.

Usage in a Kubeflow Pipeline:

    from mlflow_utils.kfp_components import (
        create_mlflow_run,
        log_metrics_component,
        log_artifact_component,
        end_mlflow_run,
    )

    @dsl.pipeline(name="my-pipeline")
    def my_pipeline():
        # Start MLflow run
        run_info = create_mlflow_run(
            experiment_name="my-experiment",
            run_name="training-run-1"
        )

        # ... your pipeline steps ...

        # Log metrics
        log_step = log_metrics_component(
            run_id=run_info.outputs["run_id"],
            metrics={"accuracy": 0.95, "loss": 0.05}
        )

        # End run
        end_mlflow_run(run_id=run_info.outputs["run_id"])
"""
|
|
|
|
from typing import Any, Dict, List, NamedTuple, Optional

from kfp import dsl
|
|
|
|
# MLflow component image with all required dependencies
MLFLOW_IMAGE = "python:3.13-slim"
# Packages installed into every component container at runtime via
# packages_to_install; kept minimal to limit cold-start time.
MLFLOW_PACKAGES = [
    "mlflow>=2.10.0",
    "boto3",  # For S3 artifact storage if needed
    "psycopg2-binary",  # For PostgreSQL backend
]
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def create_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    tags: Optional[Dict[str, str]] = None,
    params: Optional[Dict[str, str]] = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
    """
    Create a new MLflow run for the pipeline.

    This should be called at the start of a pipeline to initialize
    tracking. The returned run_id should be passed to subsequent
    components for logging.

    Args:
        experiment_name: Name of the MLflow experiment
        run_name: Name for this specific run
        mlflow_tracking_uri: MLflow tracking server URI
        tags: Optional tags to add to the run
        params: Optional parameters to log

    Returns:
        NamedTuple with run_id, experiment_id, and artifact_uri
    """
    import os
    from collections import namedtuple

    import mlflow
    from mlflow.exceptions import MlflowException
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Get or create experiment. Creation can race with a concurrently
    # starting pipeline, so on failure re-fetch before giving up.
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        try:
            experiment_id = client.create_experiment(
                name=experiment_name,
                # NOTE(review): hard-coded local artifact root — confirm this
                # path is valid/mounted on the tracking server.
                artifact_location=f"/mlflow/artifacts/{experiment_name}"
            )
        except MlflowException:
            experiment = client.get_experiment_by_name(experiment_name)
            if experiment is None:
                raise
            experiment_id = experiment.experiment_id
    else:
        experiment_id = experiment.experiment_id

    # Default tags identify the Kubeflow context; caller tags override.
    default_tags = {
        "pipeline.type": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
        "kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
    }
    if tags:
        default_tags.update(tags)

    run = mlflow.start_run(
        experiment_id=experiment_id,
        run_name=run_name,
        tags=default_tags,
    )

    # Log initial params while the run is active.
    if params:
        mlflow.log_params(params)

    run_id = run.info.run_id
    artifact_uri = run.info.artifact_uri

    # End run here: KFP components execute in isolated pods, so later
    # components log against run_id via MlflowClient instead of resuming.
    mlflow.end_run()

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
    return RunInfo(run_id, experiment_id, artifact_uri)
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_params_component(
    run_id: str,
    params: Dict[str, str],
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log parameters to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        params: Dictionary of parameters to log
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Stringify each value and truncate to 500 chars before logging.
    for name, raw_value in params.items():
        tracking_client.log_param(run_id, name, str(raw_value)[:500])

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_metrics_component(
    run_id: str,
    metrics: Dict[str, float],
    step: int = 0,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log metrics to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        metrics: Dictionary of metrics to log
        step: Step number for time-series metrics
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Coerce every value to float so non-numeric inputs fail loudly here.
    for name, raw_value in metrics.items():
        tracking_client.log_metric(run_id, name, float(raw_value), step=step)

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_artifact_component(
    run_id: str,
    artifact_path: str,
    artifact_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log an artifact file to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        artifact_path: Path to the artifact file
        artifact_name: Optional destination name in artifact store.
            NOTE(review): MlflowClient.log_artifact treats its third argument
            as a destination *directory*, not a file rename — confirm callers
            expect that.
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # An empty artifact_name means "log at the artifact root" (None).
    destination = artifact_name if artifact_name else None
    tracking_client.log_artifact(run_id, artifact_path, destination)

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_dict_artifact(
    run_id: str,
    data: Dict[str, Any],
    filename: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log a dictionary as a JSON artifact.

    Args:
        run_id: The MLflow run ID to log to
        data: Dictionary to save as JSON
        filename: Name for the JSON file
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Normalize the target name so the artifact always lands as *.json.
    target_name = filename if filename.endswith('.json') else filename + '.json'

    # Serialize into a scratch directory that is cleaned up automatically
    # once the artifact has been uploaded.
    with tempfile.TemporaryDirectory() as scratch:
        json_path = Path(scratch) / target_name
        json_path.write_text(json.dumps(data, indent=2))
        tracking_client.log_artifact(run_id, str(json_path))

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def end_mlflow_run(
    run_id: str,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    End an MLflow run with the specified status.

    Args:
        run_id: The MLflow run ID to end
        status: Run status (FINISHED, FAILED, KILLED)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # MlflowClient.set_terminated expects the *string* form of RunStatus
    # (it calls RunStatus.from_string internally); the previous code passed
    # the RunStatus enum value, which raises at runtime. Normalize the
    # input and fall back to FINISHED for unrecognized values.
    normalized = status.upper()
    if normalized not in ("FINISHED", "FAILED", "KILLED"):
        normalized = "FINISHED"
    client.set_terminated(run_id, status=normalized)

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES + ["httpx"]
)
def log_training_metrics(
    run_id: str,
    model_type: str,
    training_config: Dict[str, Any],
    final_metrics: Dict[str, float],
    model_path: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log comprehensive training metrics for ML models.

    Designed for use with QLoRA training, voice training, and other
    ML training pipelines in the llm-workflows repository.

    Args:
        run_id: The MLflow run ID to log to
        model_type: Type of model (llm, stt, tts, embeddings)
        training_config: Training configuration dict
        final_metrics: Final training metrics
        model_path: Path to saved model (if applicable)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Flatten the config into "config.<key>" params, JSON-encoding nested
    # structures; every rendered value is truncated to 500 chars.
    for cfg_key, cfg_value in training_config.items():
        if isinstance(cfg_value, (dict, list)):
            rendered = json.dumps(cfg_value)[:500]
        else:
            rendered = str(cfg_value)[:500]
        tracking_client.log_param(run_id, f"config.{cfg_key}", rendered)

    # Tag the run with the model category for later filtering.
    tracking_client.set_tag(run_id, "model.type", model_type)

    for metric_name, metric_value in final_metrics.items():
        tracking_client.log_metric(run_id, metric_name, float(metric_value))

    # Preserve the full (untruncated) config as a JSON artifact.
    with tempfile.TemporaryDirectory() as scratch:
        config_file = Path(scratch) / "training_config.json"
        config_file.write_text(json.dumps(training_config, indent=2))
        tracking_client.log_artifact(run_id, str(config_file))

    # Record where the model was saved, when a path was supplied.
    if model_path:
        tracking_client.log_param(run_id, "model.path", model_path)
        tracking_client.set_tag(run_id, "model.saved", "true")

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_document_ingestion_metrics(
    run_id: str,
    source_url: str,
    collection_name: str,
    chunks_created: int,
    documents_processed: int,
    processing_time_seconds: float,
    embeddings_model: str = "bge-small-en-v1.5",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log document ingestion pipeline metrics.

    Designed for use with the document_ingestion_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        source_url: URL of the source document
        collection_name: Milvus collection name
        chunks_created: Number of chunks created
        documents_processed: Number of documents processed
        processing_time_seconds: Total processing time
        embeddings_model: Embeddings model used
        chunk_size: Chunk size in tokens
        chunk_overlap: Chunk overlap in tokens
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Ingestion configuration as params (URL truncated to 500 chars).
    for param_key, param_value in (
        ("source_url", source_url[:500]),
        ("collection_name", collection_name),
        ("embeddings_model", embeddings_model),
        ("chunk_size", str(chunk_size)),
        ("chunk_overlap", str(chunk_overlap)),
    ):
        tracking_client.log_param(run_id, param_key, param_value)

    # Derived throughput; guard against division by zero.
    throughput = (
        chunks_created / processing_time_seconds
        if processing_time_seconds > 0
        else 0
    )
    for metric_key, metric_value in (
        ("chunks_created", chunks_created),
        ("documents_processed", documents_processed),
        ("processing_time_seconds", processing_time_seconds),
        ("chunks_per_second", throughput),
    ):
        tracking_client.log_metric(run_id, metric_key, float(metric_value))

    # Tags identify the pipeline and target collection for later filtering.
    tracking_client.set_tag(run_id, "pipeline.type", "document-ingestion")
    tracking_client.set_tag(run_id, "milvus.collection", collection_name)

    return run_id
|
|
|
|
|
|
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_evaluation_results(
    run_id: str,
    model_name: str,
    dataset_name: str,
    metrics: Dict[str, float],
    sample_results: Optional[List[Dict[str, Any]]] = None,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log model evaluation results.

    Designed for use with the evaluation_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        model_name: Name of the evaluated model
        dataset_name: Name of the evaluation dataset
        metrics: Evaluation metrics (accuracy, etc.)
        sample_results: Optional sample predictions
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    from pathlib import Path

    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log params identifying what was evaluated.
    client.log_param(run_id, "eval.model_name", model_name)
    client.log_param(run_id, "eval.dataset", dataset_name)

    # Metrics are namespaced under "eval." to avoid clashing with
    # training metrics on the same run.
    for key, value in metrics.items():
        client.log_metric(run_id, f"eval.{key}", float(value))

    # Log sample predictions as a JSON artifact, if provided.
    if sample_results:
        with tempfile.TemporaryDirectory() as tmpdir:
            results_path = Path(tmpdir) / "evaluation_results.json"
            with open(results_path, 'w') as f:
                json.dump(sample_results, f, indent=2)
            client.log_artifact(run_id, str(results_path))

    client.set_tag(run_id, "pipeline.type", "evaluation")
    client.set_tag(run_id, "model.name", model_name)

    # Pass/fail: explicit "pass" metric wins; otherwise accuracy >= 0.7.
    passed = metrics.get("pass", metrics.get("accuracy", 0) >= 0.7)
    client.set_tag(run_id, "eval.passed", str(passed))

    return run_id