# Hybrid ML Training Workflow
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: hybrid-ml-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: hybrid-ml-training
    app.kubernetes.io/part-of: ai-ml-pipelines
  annotations:
    description: |
      Demonstrates hybrid Argo+KFP workflow:
      - Argo handles orchestration, branching, retry logic
      - KFP pipelines handle ML-specific operations (with caching)
      - NATS for status updates to frontends
spec:
  entrypoint: hybrid-training
  serviceAccountName: argo-workflow

  # Artifact repository for model checkpoints
  artifactRepositoryRef:
    configMap: artifact-repository
    key: default

  # Workflow-level parameters; all values are strings (Argo parameter convention),
  # numeric ones are parsed inside the consuming scripts.
  arguments:
    parameters:
      - name: collection-name
        description: "Milvus collection to pull training data from"
        value: "dnd_text_embeddings"
      - name: model-name
        description: "Base model for fine-tuning"
        value: "mistralai/Mistral-7B-v0.3"
      - name: lora-rank
        description: "LoRA rank (higher = more params)"
        value: "16"
      - name: epochs
        description: "Training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: output-path
        description: "S3 path for model output"
        value: "s3://models/lora-adapters"
      - name: notify-nats
        description: "Publish status to NATS"
        value: "true"

  # Volumes for GPU caching: a PVC for the HuggingFace model cache and a
  # memory-backed /dev/shm for dataloader shared memory.
  volumes:
    - name: model-cache
      persistentVolumeClaim:
        claimName: model-cache-pvc
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 16Gi
templates:
|
|
# Main DAG orchestrating the workflow
|
|
- name: hybrid-training
|
|
dag:
|
|
tasks:
|
|
- name: notify-start
|
|
template: nats-notify
|
|
when: "{{workflow.parameters.notify-nats}} == true"
|
|
arguments:
|
|
parameters:
|
|
- name: subject
|
|
value: "ai.pipeline.status.{{workflow.name}}"
|
|
- name: message
|
|
value: '{"status": "started", "pipeline": "hybrid-ml-training"}'
|
|
|
|
- name: prepare-data
|
|
template: extract-training-data
|
|
arguments:
|
|
parameters:
|
|
- name: collection-name
|
|
value: "{{workflow.parameters.collection-name}}"
|
|
|
|
- name: validate-data
|
|
template: validate-dataset
|
|
dependencies: [prepare-data]
|
|
arguments:
|
|
artifacts:
|
|
- name: dataset
|
|
from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
|
|
|
|
# KFP Pipeline: Run embedding generation if needed
|
|
- name: generate-embeddings
|
|
template: trigger-kfp
|
|
dependencies: [validate-data]
|
|
when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
|
|
arguments:
|
|
parameters:
|
|
- name: pipeline-id
|
|
value: "embedding-generation"
|
|
- name: params
|
|
value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'
|
|
|
|
# Training step (runs on GPU)
|
|
- name: train-lora
|
|
template: lora-training
|
|
dependencies: [validate-data, generate-embeddings]
|
|
arguments:
|
|
parameters:
|
|
- name: model-name
|
|
value: "{{workflow.parameters.model-name}}"
|
|
- name: lora-rank
|
|
value: "{{workflow.parameters.lora-rank}}"
|
|
- name: epochs
|
|
value: "{{workflow.parameters.epochs}}"
|
|
- name: batch-size
|
|
value: "{{workflow.parameters.batch-size}}"
|
|
artifacts:
|
|
- name: dataset
|
|
from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
|
|
|
|
# Evaluate model
|
|
- name: evaluate
|
|
template: evaluate-model
|
|
dependencies: [train-lora]
|
|
arguments:
|
|
artifacts:
|
|
- name: adapter
|
|
from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
|
|
|
|
# Branch based on evaluation results
|
|
- name: check-quality
|
|
template: quality-gate
|
|
dependencies: [evaluate]
|
|
arguments:
|
|
parameters:
|
|
- name: eval-score
|
|
value: "{{tasks.evaluate.outputs.parameters.score}}"
|
|
- name: threshold
|
|
value: "0.7"
|
|
|
|
# If quality is good, upload to S3
|
|
- name: upload-model
|
|
template: upload-to-s3
|
|
dependencies: [check-quality]
|
|
when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
|
|
arguments:
|
|
parameters:
|
|
- name: output-path
|
|
value: "{{workflow.parameters.output-path}}"
|
|
artifacts:
|
|
- name: adapter
|
|
from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
|
|
|
|
# If quality is poor, trigger retraining with different params
|
|
- name: retry-training
|
|
template: adjust-and-retry
|
|
dependencies: [check-quality]
|
|
when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
|
|
arguments:
|
|
parameters:
|
|
- name: current-rank
|
|
value: "{{workflow.parameters.lora-rank}}"
|
|
- name: current-epochs
|
|
value: "{{workflow.parameters.epochs}}"
|
|
|
|
- name: notify-complete
|
|
template: nats-notify
|
|
when: "{{workflow.parameters.notify-nats}} == true"
|
|
dependencies: [upload-model]
|
|
arguments:
|
|
parameters:
|
|
- name: subject
|
|
value: "ai.pipeline.status.{{workflow.name}}"
|
|
- name: message
|
|
value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'
|
|
|
|
# Extract training data from Milvus
|
|
- name: extract-training-data
|
|
inputs:
|
|
parameters:
|
|
- name: collection-name
|
|
outputs:
|
|
artifacts:
|
|
- name: dataset
|
|
path: /tmp/dataset
|
|
parameters:
|
|
- name: data-path
|
|
valueFrom:
|
|
path: /tmp/data-path
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
import subprocess
|
|
import sys
|
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"])
|
|
|
|
from pymilvus import connections, Collection
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
|
|
connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)
|
|
|
|
collection = Collection("{{inputs.parameters.collection-name}}")
|
|
collection.load()
|
|
|
|
# Query all training samples
|
|
results = collection.query(
|
|
expr="source != ''",
|
|
output_fields=["text", "source", "metadata"],
|
|
limit=10000
|
|
)
|
|
|
|
# Convert to training format
|
|
df = pd.DataFrame(results)
|
|
output_dir = Path("/tmp/dataset")
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Save as parquet for efficient loading
|
|
df.to_parquet(output_dir / "train.parquet")
|
|
|
|
print(f"Extracted {len(df)} samples")
|
|
with open("/tmp/data-path", "w") as f:
|
|
f.write(str(output_dir / "train.parquet"))
|
|
|
|
# Validate dataset
|
|
- name: validate-dataset
|
|
inputs:
|
|
artifacts:
|
|
- name: dataset
|
|
path: /tmp/dataset
|
|
outputs:
|
|
parameters:
|
|
- name: needs-embeddings
|
|
valueFrom:
|
|
path: /tmp/needs-embeddings
|
|
- name: sample-count
|
|
valueFrom:
|
|
path: /tmp/sample-count
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
import subprocess
|
|
import sys
|
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"])
|
|
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
|
|
dataset_dir = Path("/tmp/dataset")
|
|
parquet_files = list(dataset_dir.glob("*.parquet"))
|
|
|
|
if not parquet_files:
|
|
raise ValueError("No parquet files found in dataset")
|
|
|
|
df = pd.read_parquet(parquet_files[0])
|
|
sample_count = len(df)
|
|
print(f"Dataset contains {sample_count} samples")
|
|
|
|
# Check if embeddings column exists
|
|
needs_embeddings = "embedding" not in df.columns
|
|
|
|
with open("/tmp/needs-embeddings", "w") as f:
|
|
f.write(str(needs_embeddings).lower())
|
|
|
|
with open("/tmp/sample-count", "w") as f:
|
|
f.write(str(sample_count))
|
|
|
|
# Trigger KFP pipeline
|
|
- name: trigger-kfp
|
|
inputs:
|
|
parameters:
|
|
- name: pipeline-id
|
|
- name: params
|
|
outputs:
|
|
parameters:
|
|
- name: run-id
|
|
valueFrom:
|
|
path: /tmp/run-id
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
import subprocess
|
|
import sys
|
|
import json
|
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
|
|
|
|
from kfp import Client
|
|
|
|
client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
|
|
params = json.loads('''{{inputs.parameters.params}}''')
|
|
|
|
run = client.create_run_from_pipeline_func(
|
|
pipeline_func=None,
|
|
pipeline_id="{{inputs.parameters.pipeline-id}}",
|
|
arguments=params
|
|
)
|
|
|
|
print(f"Triggered KFP pipeline: {run.run_id}")
|
|
with open("/tmp/run-id", "w") as f:
|
|
f.write(run.run_id)
|
|
|
|
# LoRA training (GPU)
|
|
- name: lora-training
|
|
inputs:
|
|
parameters:
|
|
- name: model-name
|
|
- name: lora-rank
|
|
- name: epochs
|
|
- name: batch-size
|
|
artifacts:
|
|
- name: dataset
|
|
path: /data/dataset
|
|
outputs:
|
|
artifacts:
|
|
- name: adapter
|
|
path: /output/adapter
|
|
- name: logs
|
|
path: /output/logs
|
|
podSpecPatch: |
|
|
containers:
|
|
- name: main
|
|
resources:
|
|
requests:
|
|
amd.com/gpu: 1
|
|
limits:
|
|
amd.com/gpu: 1
|
|
script:
|
|
image: ghcr.io/billy-davies-2/lora-trainer:latest
|
|
command: [python]
|
|
env:
|
|
- name: HF_HOME
|
|
value: /cache/huggingface
|
|
- name: TRANSFORMERS_CACHE
|
|
value: /cache/huggingface
|
|
volumeMounts:
|
|
- name: model-cache
|
|
mountPath: /cache
|
|
- name: shm
|
|
mountPath: /dev/shm
|
|
source: |
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Training script
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
|
from peft import LoraConfig, get_peft_model
|
|
from datasets import load_dataset
|
|
from trl import SFTTrainer
|
|
|
|
model_name = "{{inputs.parameters.model-name}}"
|
|
lora_rank = int("{{inputs.parameters.lora-rank}}")
|
|
epochs = int("{{inputs.parameters.epochs}}")
|
|
batch_size = int("{{inputs.parameters.batch-size}}")
|
|
|
|
# Load dataset
|
|
dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")
|
|
|
|
# Load model
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_name,
|
|
torch_dtype="auto",
|
|
device_map="auto"
|
|
)
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
# Configure LoRA
|
|
lora_config = LoraConfig(
|
|
r=lora_rank,
|
|
lora_alpha=lora_rank * 2,
|
|
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
|
|
lora_dropout=0.05,
|
|
bias="none",
|
|
task_type="CAUSAL_LM"
|
|
)
|
|
|
|
model = get_peft_model(model, lora_config)
|
|
|
|
# Training
|
|
training_args = TrainingArguments(
|
|
output_dir="/output/adapter",
|
|
num_train_epochs=epochs,
|
|
per_device_train_batch_size=batch_size,
|
|
gradient_accumulation_steps=4,
|
|
learning_rate=2e-4,
|
|
logging_dir="/output/logs",
|
|
save_strategy="epoch"
|
|
)
|
|
|
|
trainer = SFTTrainer(
|
|
model=model,
|
|
train_dataset=dataset,
|
|
tokenizer=tokenizer,
|
|
args=training_args,
|
|
dataset_text_field="text"
|
|
)
|
|
|
|
trainer.train()
|
|
trainer.save_model("/output/adapter")
|
|
print("Training complete!")
|
|
|
|
# Evaluate model
|
|
- name: evaluate-model
|
|
inputs:
|
|
artifacts:
|
|
- name: adapter
|
|
path: /input/adapter
|
|
outputs:
|
|
parameters:
|
|
- name: score
|
|
valueFrom:
|
|
path: /tmp/score
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
import subprocess
|
|
import sys
|
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])
|
|
|
|
import httpx
|
|
import json
|
|
|
|
# Run evaluation using vLLM with the adapter
|
|
test_prompts = [
|
|
"What is the capital of France?",
|
|
"Explain machine learning in simple terms.",
|
|
"Write a haiku about coding."
|
|
]
|
|
|
|
scores = []
|
|
with httpx.Client(timeout=120.0) as client:
|
|
for prompt in test_prompts:
|
|
response = client.post(
|
|
"http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
|
|
json={
|
|
"model": "local-adapter",
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": 200
|
|
}
|
|
)
|
|
# Simple scoring based on response coherence
|
|
result = response.json()
|
|
content = result["choices"][0]["message"]["content"]
|
|
score = min(1.0, len(content) / 100) # Placeholder scoring
|
|
scores.append(score)
|
|
|
|
avg_score = sum(scores) / len(scores)
|
|
print(f"Average evaluation score: {avg_score}")
|
|
|
|
with open("/tmp/score", "w") as f:
|
|
f.write(str(round(avg_score, 3)))
|
|
|
|
# Quality gate
|
|
- name: quality-gate
|
|
inputs:
|
|
parameters:
|
|
- name: eval-score
|
|
- name: threshold
|
|
outputs:
|
|
parameters:
|
|
- name: passed
|
|
valueFrom:
|
|
path: /tmp/passed
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
score = float("{{inputs.parameters.eval-score}}")
|
|
threshold = float("{{inputs.parameters.threshold}}")
|
|
|
|
passed = score >= threshold
|
|
print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")
|
|
|
|
with open("/tmp/passed", "w") as f:
|
|
f.write(str(passed).lower())
|
|
|
|
# Upload to S3
|
|
- name: upload-to-s3
|
|
inputs:
|
|
parameters:
|
|
- name: output-path
|
|
artifacts:
|
|
- name: adapter
|
|
path: /input/adapter
|
|
script:
|
|
image: amazon/aws-cli:latest
|
|
command: [bash]
|
|
env:
|
|
- name: AWS_ACCESS_KEY_ID
|
|
valueFrom:
|
|
secretKeyRef:
|
|
name: s3-credentials
|
|
key: access-key
|
|
- name: AWS_SECRET_ACCESS_KEY
|
|
valueFrom:
|
|
secretKeyRef:
|
|
name: s3-credentials
|
|
key: secret-key
|
|
- name: AWS_ENDPOINT_URL
|
|
value: "https://quobjects.billy.davies.cloud"
|
|
source: |
|
|
aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
|
|
echo "Uploaded adapter to {{inputs.parameters.output-path}}"
|
|
|
|
# Adjust parameters and retry
|
|
- name: adjust-and-retry
|
|
inputs:
|
|
parameters:
|
|
- name: current-rank
|
|
- name: current-epochs
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
current_rank = int("{{inputs.parameters.current-rank}}")
|
|
current_epochs = int("{{inputs.parameters.current-epochs}}")
|
|
|
|
# Increase rank and epochs for next attempt
|
|
new_rank = min(64, current_rank * 2)
|
|
new_epochs = current_epochs + 2
|
|
|
|
print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
|
|
print("TODO: Trigger new workflow with adjusted parameters")
|
|
|
|
# NATS notification
|
|
- name: nats-notify
|
|
inputs:
|
|
parameters:
|
|
- name: subject
|
|
- name: message
|
|
script:
|
|
image: python:3.13-slim
|
|
command: [python]
|
|
source: |
|
|
import subprocess
|
|
import sys
|
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
|
|
|
|
import asyncio
|
|
import nats
|
|
|
|
async def notify():
|
|
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
|
|
await nc.publish(
|
|
"{{inputs.parameters.subject}}",
|
|
b'''{{inputs.parameters.message}}'''
|
|
)
|
|
await nc.close()
|
|
|
|
asyncio.run(notify())
|