Files
argo/hybrid-ml-training.yaml
Billy D. f3d7da9008 refactor: rename part-of label from llm-workflows to ai-ml-pipelines
llm-workflows repo has been phased out. Update labels to reflect
the new ai-ml-pipelines naming convention.
2026-02-02 17:41:27 -05:00

556 lines
18 KiB
YAML

# Hybrid ML Training Workflow
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: hybrid-ml-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: hybrid-ml-training
    app.kubernetes.io/part-of: ai-ml-pipelines
  annotations:
    description: |
      Demonstrates hybrid Argo+KFP workflow:
      - Argo handles orchestration, branching, retry logic
      - KFP pipelines handle ML-specific operations (with caching)
      - NATS for status updates to frontends
spec:
  entrypoint: hybrid-training
  serviceAccountName: argo-workflow
  # Artifact repository for model checkpoints
  artifactRepositoryRef:
    configMap: artifact-repository
    key: default
  # Run-level defaults. Argo parameters are untyped strings, hence the
  # quoting on numeric/boolean-looking values.
  arguments:
    parameters:
      - name: collection-name
        description: "Milvus collection to pull training data from"
        value: "dnd_text_embeddings"
      - name: model-name
        description: "Base model for fine-tuning"
        value: "mistralai/Mistral-7B-v0.3"
      - name: lora-rank
        description: "LoRA rank (higher = more params)"
        value: "16"
      - name: epochs
        description: "Training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: output-path
        description: "S3 path for model output"
        value: "s3://models/lora-adapters"
      - name: notify-nats
        description: "Publish status to NATS"
        value: "true"
  # Volumes for GPU caching
  volumes:
    # Shared PVC mounted at /cache by the training template (HF model cache).
    - name: model-cache
      persistentVolumeClaim:
        claimName: model-cache-pvc
    # Memory-backed volume mounted at /dev/shm by the training template.
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 16Gi
  templates:
    # Main DAG orchestrating the workflow:
    # notify-start -> prepare-data -> validate-data -> (generate-embeddings?)
    # -> train-lora -> evaluate -> check-quality -> upload-model | retry-training
    # -> notify-complete
    - name: hybrid-training
      dag:
        tasks:
          # Optional start notification, gated on the notify-nats parameter.
          - name: notify-start
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "started", "pipeline": "hybrid-ml-training"}'
          # Pull training samples out of Milvus into a parquet artifact.
          - name: prepare-data
            template: extract-training-data
            arguments:
              parameters:
                - name: collection-name
                  value: "{{workflow.parameters.collection-name}}"
          # Check the dataset and decide whether embeddings must be generated.
          - name: validate-data
            template: validate-dataset
            dependencies: [prepare-data]
            arguments:
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
          # KFP Pipeline: Run embedding generation if needed
          - name: generate-embeddings
            template: trigger-kfp
            dependencies: [validate-data]
            when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
            arguments:
              parameters:
                - name: pipeline-id
                  value: "embedding-generation"
                - name: params
                  value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'
          # Training step (runs on GPU)
          # NOTE(review): with legacy `dependencies`, a Skipped dependency
          # satisfies the edge, so train-lora runs even when
          # generate-embeddings is skipped — confirm that is intended.
          - name: train-lora
            template: lora-training
            dependencies: [validate-data, generate-embeddings]
            arguments:
              parameters:
                - name: model-name
                  value: "{{workflow.parameters.model-name}}"
                - name: lora-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: epochs
                  value: "{{workflow.parameters.epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
          # Evaluate model
          - name: evaluate
            template: evaluate-model
            dependencies: [train-lora]
            arguments:
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
          # Branch based on evaluation results
          - name: check-quality
            template: quality-gate
            dependencies: [evaluate]
            arguments:
              parameters:
                - name: eval-score
                  value: "{{tasks.evaluate.outputs.parameters.score}}"
                - name: threshold
                  value: "0.7"
          # If quality is good, upload to S3
          - name: upload-model
            template: upload-to-s3
            dependencies: [check-quality]
            when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
            arguments:
              parameters:
                - name: output-path
                  value: "{{workflow.parameters.output-path}}"
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
          # If quality is poor, trigger retraining with different params
          - name: retry-training
            template: adjust-and-retry
            dependencies: [check-quality]
            when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
            arguments:
              parameters:
                - name: current-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: current-epochs
                  value: "{{workflow.parameters.epochs}}"
          # NOTE(review): this depends only on upload-model; if the quality
          # gate fails and upload-model is Skipped, legacy `dependencies`
          # semantics may still fire this "completed" notification — verify.
          - name: notify-complete
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            dependencies: [upload-model]
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'
# Extract training data from Milvus
- name: extract-training-data
inputs:
parameters:
- name: collection-name
outputs:
artifacts:
- name: dataset
path: /tmp/dataset
parameters:
- name: data-path
valueFrom:
path: /tmp/data-path
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"])
from pymilvus import connections, Collection
import pandas as pd
from pathlib import Path
connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)
collection = Collection("{{inputs.parameters.collection-name}}")
collection.load()
# Query all training samples
results = collection.query(
expr="source != ''",
output_fields=["text", "source", "metadata"],
limit=10000
)
# Convert to training format
df = pd.DataFrame(results)
output_dir = Path("/tmp/dataset")
output_dir.mkdir(parents=True, exist_ok=True)
# Save as parquet for efficient loading
df.to_parquet(output_dir / "train.parquet")
print(f"Extracted {len(df)} samples")
with open("/tmp/data-path", "w") as f:
f.write(str(output_dir / "train.parquet"))
# Validate dataset
- name: validate-dataset
inputs:
artifacts:
- name: dataset
path: /tmp/dataset
outputs:
parameters:
- name: needs-embeddings
valueFrom:
path: /tmp/needs-embeddings
- name: sample-count
valueFrom:
path: /tmp/sample-count
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"])
import pandas as pd
from pathlib import Path
dataset_dir = Path("/tmp/dataset")
parquet_files = list(dataset_dir.glob("*.parquet"))
if not parquet_files:
raise ValueError("No parquet files found in dataset")
df = pd.read_parquet(parquet_files[0])
sample_count = len(df)
print(f"Dataset contains {sample_count} samples")
# Check if embeddings column exists
needs_embeddings = "embedding" not in df.columns
with open("/tmp/needs-embeddings", "w") as f:
f.write(str(needs_embeddings).lower())
with open("/tmp/sample-count", "w") as f:
f.write(str(sample_count))
# Trigger KFP pipeline
- name: trigger-kfp
inputs:
parameters:
- name: pipeline-id
- name: params
outputs:
parameters:
- name: run-id
valueFrom:
path: /tmp/run-id
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
import json
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
from kfp import Client
client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
params = json.loads('''{{inputs.parameters.params}}''')
run = client.create_run_from_pipeline_func(
pipeline_func=None,
pipeline_id="{{inputs.parameters.pipeline-id}}",
arguments=params
)
print(f"Triggered KFP pipeline: {run.run_id}")
with open("/tmp/run-id", "w") as f:
f.write(run.run_id)
    # LoRA fine-tuning of the base model on a single AMD GPU.
    - name: lora-training
      inputs:
        parameters:
          - name: model-name
          - name: lora-rank
          - name: epochs
          - name: batch-size
        artifacts:
          - name: dataset
            path: /data/dataset
      outputs:
        artifacts:
          - name: adapter
            path: /output/adapter
          - name: logs
            path: /output/logs
      # Request/limit one AMD GPU for the main container via a pod spec patch.
      podSpecPatch: |
        containers:
          - name: main
            resources:
              requests:
                amd.com/gpu: 1
              limits:
                amd.com/gpu: 1
      script:
        image: ghcr.io/billy-davies-2/lora-trainer:latest
        command: [python]
        env:
          # Cache HF downloads on the shared PVC so repeated runs reuse weights.
          - name: HF_HOME
            value: /cache/huggingface
          # NOTE(review): TRANSFORMERS_CACHE is deprecated in recent
          # transformers releases in favour of HF_HOME; kept here for the
          # pinned trainer image — confirm which version the image ships.
          - name: TRANSFORMERS_CACHE
            value: /cache/huggingface
        volumeMounts:
          - name: model-cache
            mountPath: /cache
          - name: shm
            mountPath: /dev/shm
        source: |
          import os
          import sys
          from pathlib import Path
          # Training script
          from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
          from peft import LoraConfig, get_peft_model
          from datasets import load_dataset
          from trl import SFTTrainer
          # Workflow parameters arrive as template-substituted strings.
          model_name = "{{inputs.parameters.model-name}}"
          lora_rank = int("{{inputs.parameters.lora-rank}}")
          epochs = int("{{inputs.parameters.epochs}}")
          batch_size = int("{{inputs.parameters.batch-size}}")
          # Load dataset
          dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")
          # Load model
          model = AutoModelForCausalLM.from_pretrained(
              model_name,
              torch_dtype="auto",
              device_map="auto"
          )
          tokenizer = AutoTokenizer.from_pretrained(model_name)
          # Configure LoRA
          lora_config = LoraConfig(
              r=lora_rank,
              lora_alpha=lora_rank * 2,  # heuristic: alpha = 2 * rank
              target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
              lora_dropout=0.05,
              bias="none",
              task_type="CAUSAL_LM"
          )
          model = get_peft_model(model, lora_config)
          # Training
          training_args = TrainingArguments(
              output_dir="/output/adapter",
              num_train_epochs=epochs,
              per_device_train_batch_size=batch_size,
              gradient_accumulation_steps=4,  # effective batch = batch_size * 4
              learning_rate=2e-4,
              logging_dir="/output/logs",
              save_strategy="epoch"
          )
          # NOTE(review): the dataset_text_field/tokenizer kwargs on SFTTrainer
          # are version-dependent in trl (newer releases moved them to
          # SFTConfig); assumes the pinned image ships a compatible trl —
          # confirm against the image.
          trainer = SFTTrainer(
              model=model,
              train_dataset=dataset,
              tokenizer=tokenizer,
              args=training_args,
              dataset_text_field="text"
          )
          trainer.train()
          trainer.save_model("/output/adapter")
          print("Training complete!")
# Evaluate model
- name: evaluate-model
inputs:
artifacts:
- name: adapter
path: /input/adapter
outputs:
parameters:
- name: score
valueFrom:
path: /tmp/score
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])
import httpx
import json
# Run evaluation using vLLM with the adapter
test_prompts = [
"What is the capital of France?",
"Explain machine learning in simple terms.",
"Write a haiku about coding."
]
scores = []
with httpx.Client(timeout=120.0) as client:
for prompt in test_prompts:
response = client.post(
"http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
json={
"model": "local-adapter",
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 200
}
)
# Simple scoring based on response coherence
result = response.json()
content = result["choices"][0]["message"]["content"]
score = min(1.0, len(content) / 100) # Placeholder scoring
scores.append(score)
avg_score = sum(scores) / len(scores)
print(f"Average evaluation score: {avg_score}")
with open("/tmp/score", "w") as f:
f.write(str(round(avg_score, 3)))
# Quality gate
- name: quality-gate
inputs:
parameters:
- name: eval-score
- name: threshold
outputs:
parameters:
- name: passed
valueFrom:
path: /tmp/passed
script:
image: python:3.13-slim
command: [python]
source: |
score = float("{{inputs.parameters.eval-score}}")
threshold = float("{{inputs.parameters.threshold}}")
passed = score >= threshold
print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")
with open("/tmp/passed", "w") as f:
f.write(str(passed).lower())
# Upload to S3
- name: upload-to-s3
inputs:
parameters:
- name: output-path
artifacts:
- name: adapter
path: /input/adapter
script:
image: amazon/aws-cli:latest
command: [bash]
env:
- name: AWS_ACCESS_KEY_ID
valueFrom:
secretKeyRef:
name: s3-credentials
key: access-key
- name: AWS_SECRET_ACCESS_KEY
valueFrom:
secretKeyRef:
name: s3-credentials
key: secret-key
- name: AWS_ENDPOINT_URL
value: "https://quobjects.billy.davies.cloud"
source: |
aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
echo "Uploaded adapter to {{inputs.parameters.output-path}}"
# Adjust parameters and retry
- name: adjust-and-retry
inputs:
parameters:
- name: current-rank
- name: current-epochs
script:
image: python:3.13-slim
command: [python]
source: |
current_rank = int("{{inputs.parameters.current-rank}}")
current_epochs = int("{{inputs.parameters.current-epochs}}")
# Increase rank and epochs for next attempt
new_rank = min(64, current_rank * 2)
new_epochs = current_epochs + 2
print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
print("TODO: Trigger new workflow with adjusted parameters")
# NATS notification
- name: nats-notify
inputs:
parameters:
- name: subject
- name: message
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
import asyncio
import nats
async def notify():
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
await nc.publish(
"{{inputs.parameters.subject}}",
b'''{{inputs.parameters.message}}'''
)
await nc.close()
asyncio.run(notify())