feat: Add ML training and batch inference workflows
Workflows added:
- batch-inference: LLM inference with optional RAG
- qlora-training: QLoRA adapter fine-tuning from Milvus
- hybrid-ml-training: Multi-GPU distributed training
- coqui-voice-training: XTTS voice cloning
- document-ingestion: Ingest documents to Milvus
- eventsource-kfp: Argo Events / Kubeflow integration
- kfp-integration: Bridge between Argo and Kubeflow
This commit is contained in:
555
hybrid-ml-training.yaml
Normal file
555
hybrid-ml-training.yaml
Normal file
@@ -0,0 +1,555 @@
# Hybrid ML Training Workflow
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: hybrid-ml-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: hybrid-ml-training
    app.kubernetes.io/part-of: llm-workflows
  annotations:
    description: |
      Demonstrates hybrid Argo+KFP workflow:
      - Argo handles orchestration, branching, retry logic
      - KFP pipelines handle ML-specific operations (with caching)
      - NATS for status updates to frontends
spec:
  entrypoint: hybrid-training
  serviceAccountName: argo-workflow

  # Artifact repository for model checkpoints
  artifactRepositoryRef:
    configMap: artifact-repository
    key: default

  arguments:
    parameters:
      - name: collection-name
        description: "Milvus collection to pull training data from"
        value: "dnd_text_embeddings"
      - name: model-name
        description: "Base model for fine-tuning"
        value: "mistralai/Mistral-7B-v0.3"
      - name: lora-rank
        description: "LoRA rank (higher = more params)"
        value: "16"
      - name: epochs
        description: "Training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: output-path
        description: "S3 path for model output"
        value: "s3://models/lora-adapters"
      - name: notify-nats
        description: "Publish status to NATS"
        value: "true"

  # Volumes for GPU caching
  volumes:
    - name: model-cache
      persistentVolumeClaim:
        claimName: model-cache-pvc
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 16Gi

  templates:
    # Main DAG orchestrating the workflow
    - name: hybrid-training
      dag:
        tasks:
          - name: notify-start
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "started", "pipeline": "hybrid-ml-training"}'

          - name: prepare-data
            template: extract-training-data
            arguments:
              parameters:
                - name: collection-name
                  value: "{{workflow.parameters.collection-name}}"

          - name: validate-data
            template: validate-dataset
            dependencies: [prepare-data]
            arguments:
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"

          # KFP Pipeline: Run embedding generation if needed
          - name: generate-embeddings
            template: trigger-kfp
            dependencies: [validate-data]
            when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
            arguments:
              parameters:
                - name: pipeline-id
                  value: "embedding-generation"
                - name: params
                  value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'

          # Training step (runs on GPU). A skipped generate-embeddings still
          # satisfies this dependency (Argo treats Skipped as satisfied for
          # the default `dependencies` semantics).
          - name: train-lora
            template: lora-training
            dependencies: [validate-data, generate-embeddings]
            arguments:
              parameters:
                - name: model-name
                  value: "{{workflow.parameters.model-name}}"
                - name: lora-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: epochs
                  value: "{{workflow.parameters.epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"

          # Evaluate model
          - name: evaluate
            template: evaluate-model
            dependencies: [train-lora]
            arguments:
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"

          # Branch based on evaluation results
          - name: check-quality
            template: quality-gate
            dependencies: [evaluate]
            arguments:
              parameters:
                - name: eval-score
                  value: "{{tasks.evaluate.outputs.parameters.score}}"
                - name: threshold
                  value: "0.7"

          # If quality is good, upload to S3
          - name: upload-model
            template: upload-to-s3
            dependencies: [check-quality]
            when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
            arguments:
              parameters:
                - name: output-path
                  value: "{{workflow.parameters.output-path}}"
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"

          # If quality is poor, trigger retraining with different params
          - name: retry-training
            template: adjust-and-retry
            dependencies: [check-quality]
            when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
            arguments:
              parameters:
                - name: current-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: current-epochs
                  value: "{{workflow.parameters.epochs}}"

          # NOTE(review): this fires only on the quality-passed (upload) path;
          # the retry path ends with no completion notification — confirm that
          # is intended.
          - name: notify-complete
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            dependencies: [upload-model]
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'

    # Extract training data from Milvus
    - name: extract-training-data
      inputs:
        parameters:
          - name: collection-name
      outputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
        parameters:
          - name: data-path
            valueFrom:
              path: /tmp/data-path
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"])

          from pymilvus import connections, Collection
          import pandas as pd
          from pathlib import Path

          connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)

          collection = Collection("{{inputs.parameters.collection-name}}")
          collection.load()

          # Query all training samples (hard cap of 10k rows per run)
          results = collection.query(
              expr="source != ''",
              output_fields=["text", "source", "metadata"],
              limit=10000
          )

          # Convert to training format
          df = pd.DataFrame(results)
          output_dir = Path("/tmp/dataset")
          output_dir.mkdir(parents=True, exist_ok=True)

          # Save as parquet for efficient loading
          df.to_parquet(output_dir / "train.parquet")

          print(f"Extracted {len(df)} samples")
          with open("/tmp/data-path", "w") as f:
              f.write(str(output_dir / "train.parquet"))

    # Validate dataset
    - name: validate-dataset
      inputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
      outputs:
        parameters:
          - name: needs-embeddings
            valueFrom:
              path: /tmp/needs-embeddings
          - name: sample-count
            valueFrom:
              path: /tmp/sample-count
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"])

          import pandas as pd
          from pathlib import Path

          dataset_dir = Path("/tmp/dataset")
          parquet_files = list(dataset_dir.glob("*.parquet"))

          if not parquet_files:
              raise ValueError("No parquet files found in dataset")

          df = pd.read_parquet(parquet_files[0])
          sample_count = len(df)
          print(f"Dataset contains {sample_count} samples")

          # Check if embeddings column exists
          needs_embeddings = "embedding" not in df.columns

          # Argo `when` compares against lowercase true/false tokens
          with open("/tmp/needs-embeddings", "w") as f:
              f.write(str(needs_embeddings).lower())

          with open("/tmp/sample-count", "w") as f:
              f.write(str(sample_count))

    # Trigger KFP pipeline by pipeline ID
    - name: trigger-kfp
      inputs:
        parameters:
          - name: pipeline-id
          - name: params
      outputs:
        parameters:
          - name: run-id
            valueFrom:
              path: /tmp/run-id
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          import json
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])

          from kfp import Client

          client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
          params = json.loads('''{{inputs.parameters.params}}''')

          # create_run_from_pipeline_func requires an actual pipeline function
          # and cannot start an uploaded pipeline by ID; run_pipeline is the
          # correct API for that, and it needs an experiment to attach the run to.
          experiment = client.create_experiment(name="argo-triggered")
          run = client.run_pipeline(
              experiment_id=experiment.experiment_id,
              job_name="{{workflow.name}}-{{inputs.parameters.pipeline-id}}",
              pipeline_id="{{inputs.parameters.pipeline-id}}",
              params=params
          )

          print(f"Triggered KFP pipeline: {run.run_id}")
          with open("/tmp/run-id", "w") as f:
              f.write(run.run_id)

    # LoRA training (GPU)
    - name: lora-training
      inputs:
        parameters:
          - name: model-name
          - name: lora-rank
          - name: epochs
          - name: batch-size
        artifacts:
          - name: dataset
            path: /data/dataset
      outputs:
        artifacts:
          - name: adapter
            path: /output/adapter
          - name: logs
            path: /output/logs
      podSpecPatch: |
        containers:
          - name: main
            resources:
              requests:
                amd.com/gpu: 1
              limits:
                amd.com/gpu: 1
      script:
        image: ghcr.io/billy-davies-2/lora-trainer:latest
        command: [python]
        env:
          - name: HF_HOME
            value: /cache/huggingface
          # NOTE(review): TRANSFORMERS_CACHE is deprecated in recent
          # transformers releases; HF_HOME above covers it — kept for
          # backward compatibility with older library versions.
          - name: TRANSFORMERS_CACHE
            value: /cache/huggingface
        volumeMounts:
          - name: model-cache
            mountPath: /cache
          - name: shm
            mountPath: /dev/shm
        source: |
          import os
          import sys
          from pathlib import Path

          # Training script
          from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
          from peft import LoraConfig, get_peft_model
          from datasets import load_dataset
          from trl import SFTTrainer

          model_name = "{{inputs.parameters.model-name}}"
          lora_rank = int("{{inputs.parameters.lora-rank}}")
          epochs = int("{{inputs.parameters.epochs}}")
          batch_size = int("{{inputs.parameters.batch-size}}")

          # Load dataset
          dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")

          # Load model
          model = AutoModelForCausalLM.from_pretrained(
              model_name,
              torch_dtype="auto",
              device_map="auto"
          )
          tokenizer = AutoTokenizer.from_pretrained(model_name)

          # Configure LoRA (alpha = 2*r is a common heuristic)
          lora_config = LoraConfig(
              r=lora_rank,
              lora_alpha=lora_rank * 2,
              target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
              lora_dropout=0.05,
              bias="none",
              task_type="CAUSAL_LM"
          )

          model = get_peft_model(model, lora_config)

          # Training
          training_args = TrainingArguments(
              output_dir="/output/adapter",
              num_train_epochs=epochs,
              per_device_train_batch_size=batch_size,
              gradient_accumulation_steps=4,
              learning_rate=2e-4,
              logging_dir="/output/logs",
              save_strategy="epoch"
          )

          trainer = SFTTrainer(
              model=model,
              train_dataset=dataset,
              tokenizer=tokenizer,
              args=training_args,
              dataset_text_field="text"
          )

          trainer.train()
          trainer.save_model("/output/adapter")
          print("Training complete!")

    # Evaluate model
    - name: evaluate-model
      inputs:
        artifacts:
          - name: adapter
            path: /input/adapter
      outputs:
        parameters:
          - name: score
            valueFrom:
              path: /tmp/score
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])

          import httpx
          import json

          # Run evaluation using vLLM with the adapter
          test_prompts = [
              "What is the capital of France?",
              "Explain machine learning in simple terms.",
              "Write a haiku about coding."
          ]

          scores = []
          with httpx.Client(timeout=120.0) as client:
              for prompt in test_prompts:
                  response = client.post(
                      "http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
                      json={
                          "model": "local-adapter",
                          "messages": [{"role": "user", "content": prompt}],
                          "max_tokens": 200
                      }
                  )
                  # Simple scoring based on response coherence
                  result = response.json()
                  content = result["choices"][0]["message"]["content"]
                  score = min(1.0, len(content) / 100)  # Placeholder scoring
                  scores.append(score)

          avg_score = sum(scores) / len(scores)
          print(f"Average evaluation score: {avg_score}")

          with open("/tmp/score", "w") as f:
              f.write(str(round(avg_score, 3)))

    # Quality gate
    - name: quality-gate
      inputs:
        parameters:
          - name: eval-score
          - name: threshold
      outputs:
        parameters:
          - name: passed
            valueFrom:
              path: /tmp/passed
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          score = float("{{inputs.parameters.eval-score}}")
          threshold = float("{{inputs.parameters.threshold}}")

          passed = score >= threshold
          print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")

          # Lowercase so downstream `when` clauses can compare against true/false
          with open("/tmp/passed", "w") as f:
              f.write(str(passed).lower())

    # Upload to S3
    - name: upload-to-s3
      inputs:
        parameters:
          - name: output-path
        artifacts:
          - name: adapter
            path: /input/adapter
      script:
        image: amazon/aws-cli:latest
        command: [bash]
        env:
          - name: AWS_ACCESS_KEY_ID
            valueFrom:
              secretKeyRef:
                name: s3-credentials
                key: access-key
          - name: AWS_SECRET_ACCESS_KEY
            valueFrom:
              secretKeyRef:
                name: s3-credentials
                key: secret-key
          - name: AWS_ENDPOINT_URL
            value: "https://quobjects.billy.davies.cloud"
        source: |
          # Fail the step if the copy fails, rather than reporting success
          set -euo pipefail
          aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
          echo "Uploaded adapter to {{inputs.parameters.output-path}}"

    # Adjust parameters and retry
    - name: adjust-and-retry
      inputs:
        parameters:
          - name: current-rank
          - name: current-epochs
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          current_rank = int("{{inputs.parameters.current-rank}}")
          current_epochs = int("{{inputs.parameters.current-epochs}}")

          # Increase rank (capped at 64) and epochs for next attempt
          new_rank = min(64, current_rank * 2)
          new_epochs = current_epochs + 2

          print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
          print("TODO: Trigger new workflow with adjusted parameters")

    # NATS notification
    - name: nats-notify
      inputs:
        parameters:
          - name: subject
          - name: message
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])

          import asyncio
          import nats

          async def notify():
              nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
              await nc.publish(
                  "{{inputs.parameters.subject}}",
                  b'''{{inputs.parameters.message}}'''
              )
              await nc.close()

          asyncio.run(notify())
Reference in New Issue
Block a user