feat: Add MLflow integration utilities
- client: Connection management and helpers
- tracker: General experiment tracking
- inference_tracker: Async metrics for NATS handlers
- model_registry: Model registration with KServe metadata
- kfp_components: Kubeflow Pipeline components
- experiment_comparison: Run comparison tools
- cli: Command-line interface
This commit is contained in:
513
mlflow_utils/kfp_components.py
Normal file
513
mlflow_utils/kfp_components.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""
|
||||
Kubeflow Pipeline Components with MLflow Tracking
|
||||
|
||||
Provides reusable KFP components that integrate MLflow experiment
|
||||
tracking into Kubeflow Pipelines. These components can be used
|
||||
directly in pipelines or as wrappers around existing pipeline steps.
|
||||
|
||||
Usage in a Kubeflow Pipeline:
|
||||
|
||||
from mlflow_utils.kfp_components import (
|
||||
create_mlflow_run,
|
||||
log_metrics_component,
|
||||
log_model_artifact,
|
||||
end_mlflow_run,
|
||||
)
|
||||
|
||||
@dsl.pipeline(name="my-pipeline")
|
||||
def my_pipeline():
|
||||
# Start MLflow run
|
||||
run_info = create_mlflow_run(
|
||||
experiment_name="my-experiment",
|
||||
run_name="training-run-1"
|
||||
)
|
||||
|
||||
# ... your pipeline steps ...
|
||||
|
||||
# Log metrics
|
||||
log_step = log_metrics_component(
|
||||
run_id=run_info.outputs["run_id"],
|
||||
metrics={"accuracy": 0.95, "loss": 0.05}
|
||||
)
|
||||
|
||||
# End run
|
||||
end_mlflow_run(run_id=run_info.outputs["run_id"])
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
from typing import Dict, Any, List, Optional, NamedTuple
|
||||
|
||||
|
||||
# MLflow component image with all required dependencies
|
||||
MLFLOW_IMAGE = "python:3.13-slim"
|
||||
MLFLOW_PACKAGES = [
|
||||
"mlflow>=2.10.0",
|
||||
"boto3", # For S3 artifact storage if needed
|
||||
"psycopg2-binary", # For PostgreSQL backend
|
||||
]
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def create_mlflow_run(
    experiment_name: str,
    run_name: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    tags: Optional[Dict[str, str]] = None,
    params: Optional[Dict[str, str]] = None,
) -> NamedTuple('RunInfo', [('run_id', str), ('experiment_id', str), ('artifact_uri', str)]):
    """
    Create a new MLflow run for the pipeline.

    This should be called at the start of a pipeline to initialize
    tracking. The returned run_id should be passed to subsequent
    components for logging.

    Args:
        experiment_name: Name of the MLflow experiment
        run_name: Name for this specific run
        mlflow_tracking_uri: MLflow tracking server URI
        tags: Optional tags to add to the run (merged over the defaults)
        params: Optional parameters to log

    Returns:
        NamedTuple with run_id, experiment_id, and artifact_uri
    """
    import os
    import mlflow
    from mlflow.tracking import MlflowClient
    from collections import namedtuple

    # Point the client at the in-cluster tracking server
    mlflow.set_tracking_uri(mlflow_tracking_uri)

    client = MlflowClient()

    # Get or create the experiment by name
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(
            name=experiment_name,
            artifact_location=f"/mlflow/artifacts/{experiment_name}"
        )
    else:
        experiment_id = experiment.experiment_id

    # Default tags identify the originating Kubeflow run/pod;
    # caller-supplied tags take precedence on key collision.
    default_tags = {
        "pipeline.type": "kubeflow",
        "kfp.run_id": os.environ.get("KFP_RUN_ID", "unknown"),
        "kfp.pod_name": os.environ.get("HOSTNAME", "unknown"),
    }
    if tags:
        default_tags.update(tags)

    # Start run
    run = mlflow.start_run(
        experiment_id=experiment_id,
        run_name=run_name,
        tags=default_tags,
    )

    # Log initial params while the run is active
    if params:
        mlflow.log_params(params)

    run_id = run.info.run_id
    artifact_uri = run.info.artifact_uri

    # End the run here: KFP components execute in isolated pods, so
    # downstream components log against run_id via MlflowClient instead
    # of resuming this process-local active run.
    mlflow.end_run()

    RunInfo = namedtuple('RunInfo', ['run_id', 'experiment_id', 'artifact_uri'])
    return RunInfo(run_id, experiment_id, artifact_uri)
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_params_component(
    run_id: str,
    params: Dict[str, str],
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log parameters to an existing MLflow run.

    Each value is stringified and truncated to 500 characters before
    logging, since MLflow enforces a length limit on param values.

    Args:
        run_id: The MLflow run ID to log to
        params: Dictionary of parameters to log
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    for param_name, raw_value in params.items():
        tracking_client.log_param(run_id, param_name, str(raw_value)[:500])

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_metrics_component(
    run_id: str,
    metrics: Dict[str, float],
    step: int = 0,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log metrics to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        metrics: Dictionary of metric names to numeric values
        step: Step number for time-series metrics
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Coerce every value to float; MLflow metrics are numeric only.
    for metric_name, metric_value in metrics.items():
        tracking_client.log_metric(run_id, metric_name, float(metric_value), step=step)

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_artifact_component(
    run_id: str,
    artifact_path: str,
    artifact_name: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log an artifact file to an existing MLflow run.

    Args:
        run_id: The MLflow run ID to log to
        artifact_path: Local path of the file to upload
        artifact_name: Optional destination directory within the run's
            artifact store (forwarded as ``MlflowClient.log_artifact``'s
            ``artifact_path`` argument; the uploaded file keeps its
            original basename)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # An empty string means "artifact store root" -> pass None through.
    destination = artifact_name if artifact_name else None
    tracking_client.log_artifact(run_id, artifact_path, destination)

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_dict_artifact(
    run_id: str,
    data: Dict[str, Any],
    filename: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log a dictionary as a JSON artifact.

    Args:
        run_id: The MLflow run ID to log to
        data: Dictionary to serialize as JSON
        filename: Name for the JSON file (".json" is appended if missing)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    import mlflow
    from mlflow.tracking import MlflowClient
    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Normalize the destination filename to a .json extension.
    if not filename.endswith('.json'):
        filename = f"{filename}.json"

    # Serialize into a throwaway directory, then upload the file.
    with tempfile.TemporaryDirectory() as scratch_dir:
        json_file = Path(scratch_dir) / filename
        json_file.write_text(json.dumps(data, indent=2))
        tracking_client.log_artifact(run_id, str(json_file))

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def end_mlflow_run(
    run_id: str,
    status: str = "FINISHED",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    End an MLflow run with the specified status.

    Args:
        run_id: The MLflow run ID to end
        status: Run status (FINISHED, FAILED, KILLED); any other value
            falls back to FINISHED
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # MlflowClient.set_terminated expects a RunStatus *string* (e.g.
    # "FINISHED"), not the RunStatus enum value; passing the enum int
    # fails inside the tracking store. Validate and pass the string.
    valid_statuses = {"FINISHED", "FAILED", "KILLED"}
    normalized_status = status.upper()
    if normalized_status not in valid_statuses:
        normalized_status = "FINISHED"

    client.set_terminated(run_id, status=normalized_status)

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES + ["httpx"]
)
def log_training_metrics(
    run_id: str,
    model_type: str,
    training_config: Dict[str, Any],
    final_metrics: Dict[str, float],
    model_path: str = "",
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log comprehensive training metrics for ML models.

    Designed for use with QLoRA training, voice training, and other
    ML training pipelines in the llm-workflows repository.

    Args:
        run_id: The MLflow run ID to log to
        model_type: Type of model (llm, stt, tts, embeddings)
        training_config: Training configuration dict; values are
            flattened to "config.<key>" params truncated to 500 chars,
            and the full config is also stored as a JSON artifact
        final_metrics: Final training metrics
        model_path: Path to saved model (if applicable)
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    import mlflow
    from mlflow.tracking import MlflowClient
    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Flatten config values to truncated strings and log them as params.
    for config_key, config_value in training_config.items():
        if isinstance(config_value, (dict, list)):
            serialized = json.dumps(config_value)
        else:
            serialized = str(config_value)
        tracking_client.log_param(run_id, f"config.{config_key}", serialized[:500])

    # Tag the run with the model family
    tracking_client.set_tag(run_id, "model.type", model_type)

    # Final metrics
    for metric_name, metric_value in final_metrics.items():
        tracking_client.log_metric(run_id, metric_name, float(metric_value))

    # Preserve the complete (untruncated) config as a JSON artifact.
    with tempfile.TemporaryDirectory() as scratch_dir:
        config_file = Path(scratch_dir) / "training_config.json"
        config_file.write_text(json.dumps(training_config, indent=2))
        tracking_client.log_artifact(run_id, str(config_file))

    # Record where the model was saved, when provided
    if model_path:
        tracking_client.log_param(run_id, "model.path", model_path)
        tracking_client.set_tag(run_id, "model.saved", "true")

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_document_ingestion_metrics(
    run_id: str,
    source_url: str,
    collection_name: str,
    chunks_created: int,
    documents_processed: int,
    processing_time_seconds: float,
    embeddings_model: str = "bge-small-en-v1.5",
    chunk_size: int = 500,
    chunk_overlap: int = 50,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log document ingestion pipeline metrics.

    Designed for use with the document_ingestion_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        source_url: URL of the source document (truncated to 500 chars)
        collection_name: Milvus collection name
        chunks_created: Number of chunks created
        documents_processed: Number of documents processed
        processing_time_seconds: Total processing time
        embeddings_model: Embeddings model used
        chunk_size: Chunk size in tokens
        chunk_overlap: Chunk overlap in tokens
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    tracking_client = MlflowClient()

    # Params identify the source document and chunking configuration.
    for param_name, param_value in (
        ("source_url", source_url[:500]),
        ("collection_name", collection_name),
        ("embeddings_model", embeddings_model),
        ("chunk_size", str(chunk_size)),
        ("chunk_overlap", str(chunk_overlap)),
    ):
        tracking_client.log_param(run_id, param_name, param_value)

    # Throughput is guarded against a zero-duration run.
    throughput = chunks_created / processing_time_seconds if processing_time_seconds > 0 else 0
    for metric_name, metric_value in (
        ("chunks_created", chunks_created),
        ("documents_processed", documents_processed),
        ("processing_time_seconds", processing_time_seconds),
        ("chunks_per_second", throughput),
    ):
        tracking_client.log_metric(run_id, metric_name, float(metric_value))

    # Tags make ingestion runs searchable by pipeline type and collection.
    tracking_client.set_tag(run_id, "pipeline.type", "document-ingestion")
    tracking_client.set_tag(run_id, "milvus.collection", collection_name)

    return run_id
|
||||
|
||||
|
||||
@dsl.component(
    base_image=MLFLOW_IMAGE,
    packages_to_install=MLFLOW_PACKAGES
)
def log_evaluation_results(
    run_id: str,
    model_name: str,
    dataset_name: str,
    metrics: Dict[str, float],
    sample_results: Optional[List[Dict[str, Any]]] = None,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
) -> str:
    """
    Log model evaluation results.

    Designed for use with the evaluation_pipeline.

    Args:
        run_id: The MLflow run ID to log to
        model_name: Name of the evaluated model
        dataset_name: Name of the evaluation dataset
        metrics: Evaluation metrics (accuracy, etc.); each is logged
            under an "eval." prefix
        sample_results: Optional sample predictions, stored as a JSON
            artifact
        mlflow_tracking_uri: MLflow tracking server URI

    Returns:
        The run_id for chaining
    """
    import json
    import tempfile
    import mlflow
    from mlflow.tracking import MlflowClient
    from pathlib import Path

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    client = MlflowClient()

    # Log evaluation identity as params
    client.log_param(run_id, "eval.model_name", model_name)
    client.log_param(run_id, "eval.dataset", dataset_name)

    # Log metrics under the "eval." namespace
    for key, value in metrics.items():
        client.log_metric(run_id, f"eval.{key}", float(value))

    # Persist sample predictions as a JSON artifact
    if sample_results:
        with tempfile.TemporaryDirectory() as tmpdir:
            results_path = Path(tmpdir) / "evaluation_results.json"
            with open(results_path, 'w') as f:
                json.dump(sample_results, f, indent=2)
            client.log_artifact(run_id, str(results_path))

    # Tags for searchability
    client.set_tag(run_id, "pipeline.type", "evaluation")
    client.set_tag(run_id, "model.name", model_name)

    # Pass/fail: an explicit "pass" metric wins; otherwise require
    # accuracy >= 0.7. Normalize to bool so the tag is always
    # "True"/"False" (the raw metric value would stringify as "1.0").
    if "pass" in metrics:
        passed = bool(metrics["pass"])
    else:
        passed = metrics.get("accuracy", 0) >= 0.7
    client.set_tag(run_id, "eval.passed", str(passed))

    return run_id
|
||||
Reference in New Issue
Block a user