Files
argo/qlora-training.yaml
Billy D. f3d7da9008 refactor: rename part-of label from llm-workflows to ai-ml-pipelines
llm-workflows repo has been phased out. Update labels to reflect
the new ai-ml-pipelines naming convention.
2026-02-02 17:41:27 -05:00

511 lines
19 KiB
YAML

# QLoRA Fine-tuning Workflow
# Trains QLoRA adapters from a reference model using data from Milvus vector database
# Triggered via NATS: ai.pipeline.trigger with pipeline="qlora-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: qlora-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: qlora-training
    app.kubernetes.io/part-of: ai-ml-pipelines
spec:
  entrypoint: train-qlora
  serviceAccountName: argo-workflow
  # The train step writes the finished adapter onto the claimed volume.
  # Argo's default volumeClaimGC strategy (OnWorkflowSuccess) deletes the PVC
  # as soon as the workflow succeeds, which would discard the trained model —
  # keep the claim until the workflow object itself is deleted.
  volumeClaimGC:
    strategy: OnWorkflowDeletion
  arguments:
    # Every parameter uses the same key order: name, description, value.
    parameters:
      - name: reference-model
        description: "Base model to fine-tune (HuggingFace model ID or path)"
        value: "mistralai/Mistral-7B-Instruct-v0.3"
      - name: output-name
        description: "Name for the output QLora adapter"
        value: "qlora-adapter"
      - name: milvus-collections
        description: "Comma-separated list of Milvus collections to use (empty = all available)"
        value: ""
      - name: learning-rate
        description: "Learning rate for training"
        value: "2e-4"
      - name: num-epochs
        description: "Number of training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: max-seq-length
        description: "Maximum sequence length"
        value: "2048"
      - name: lora-r
        description: "LoRA attention dimension"
        value: "64"
      - name: lora-alpha
        description: "LoRA alpha parameter"
        value: "16"
      - name: lora-dropout
        description: "LoRA dropout rate"
        value: "0.05"
  # Shared scratch/output volume mounted by the train step; ReadWriteMany so the
  # adapter stays reachable from other pods after training.
  volumeClaimTemplates:
    - metadata:
        name: model-storage
      spec:
        accessModes: ["ReadWriteMany"]
        storageClassName: nfs-slow
        resources:
          requests:
            storage: 50Gi
  templates:
- name: train-qlora
dag:
tasks:
- name: fetch-training-data
template: fetch-data
arguments:
parameters:
- name: milvus-collections
value: "{{workflow.parameters.milvus-collections}}"
- name: prepare-dataset
template: prepare-data
dependencies: [fetch-training-data]
arguments:
parameters:
- name: max-seq-length
value: "{{workflow.parameters.max-seq-length}}"
artifacts:
- name: raw-data
from: "{{tasks.fetch-training-data.outputs.artifacts.raw-data}}"
- name: train-model
template: train
dependencies: [prepare-dataset]
arguments:
parameters:
- name: reference-model
value: "{{workflow.parameters.reference-model}}"
- name: output-name
value: "{{workflow.parameters.output-name}}"
- name: learning-rate
value: "{{workflow.parameters.learning-rate}}"
- name: num-epochs
value: "{{workflow.parameters.num-epochs}}"
- name: batch-size
value: "{{workflow.parameters.batch-size}}"
- name: max-seq-length
value: "{{workflow.parameters.max-seq-length}}"
- name: lora-r
value: "{{workflow.parameters.lora-r}}"
- name: lora-alpha
value: "{{workflow.parameters.lora-alpha}}"
- name: lora-dropout
value: "{{workflow.parameters.lora-dropout}}"
artifacts:
- name: training-data
from: "{{tasks.prepare-dataset.outputs.artifacts.training-data}}"
- name: fetch-data
inputs:
parameters:
- name: milvus-collections
outputs:
artifacts:
- name: raw-data
path: /tmp/raw-data
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import subprocess
subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)
import json
from pathlib import Path
from pymilvus import connections, Collection, utility
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
MILVUS_PORT = 19530
collections_param = "{{inputs.parameters.milvus-collections}}"
output_dir = Path("/tmp/raw-data")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Connecting to Milvus at {MILVUS_HOST}:{MILVUS_PORT}")
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
# Determine which collections to use
if collections_param and collections_param.strip():
collection_names = [c.strip() for c in collections_param.split(",")]
print(f"Using specified collections: {collection_names}")
else:
# Get all available collections
collection_names = utility.list_collections()
print(f"Using all available collections: {collection_names}")
all_training_data = []
for collection_name in collection_names:
if not utility.has_collection(collection_name):
print(f"Warning: Collection {collection_name} not found, skipping")
continue
print(f"Fetching data from collection: {collection_name}")
collection = Collection(collection_name)
collection.load()
# Query all data from the collection
# Note: Adjust field names based on your schema
try:
# Get collection schema to determine fields
schema = collection.schema
field_names = [field.name for field in schema.fields if field.name != "id"]
# Query all entities (limited to reasonable batch size)
# For large collections, you may want to implement pagination
results = collection.query(
expr="id >= 0",
output_fields=field_names,
limit=100000
)
print(f" Retrieved {len(results)} records from {collection_name}")
for result in results:
# Extract text field (adjust based on your schema)
text_content = result.get("text", "")
source = result.get("source", collection_name)
if text_content:
all_training_data.append({
"text": text_content,
"source": source,
"collection": collection_name
})
except Exception as e:
print(f"Error querying collection {collection_name}: {e}")
continue
# Save all training data
output_file = output_dir / "training_data.json"
with open(output_file, "w") as f:
json.dump({"data": all_training_data}, f)
print(f"Total training samples collected: {len(all_training_data)}")
print(f"Saved to {output_file}")
connections.disconnect("default")
envFrom:
- configMapRef:
name: ai-services-config
resources:
requests:
memory: 1Gi
cpu: 500m
- name: prepare-data
inputs:
parameters:
- name: max-seq-length
artifacts:
- name: raw-data
path: /tmp/raw-data
outputs:
artifacts:
- name: training-data
path: /tmp/training-data
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import json
from pathlib import Path
max_seq_length = int("{{inputs.parameters.max-seq-length}}")
input_dir = Path("/tmp/raw-data")
output_dir = Path("/tmp/training-data")
output_dir.mkdir(parents=True, exist_ok=True)
# Load raw data
with open(input_dir / "training_data.json") as f:
data = json.load(f)
raw_samples = data["data"]
print(f"Processing {len(raw_samples)} raw samples")
# Prepare data in instruction format for fine-tuning
# Using Alpaca-style format: instruction + response
training_samples = []
for sample in raw_samples:
text = sample["text"]
source = sample.get("source", "")
# Create instruction-response pairs
# You can customize this based on your use case
training_sample = {
"instruction": f"Based on the following information from {source}, provide a comprehensive response:",
"input": text[:max_seq_length // 2], # Truncate if needed
"output": text[:max_seq_length // 2],
"source": source
}
training_samples.append(training_sample)
# Split into train/validation (90/10)
split_idx = int(len(training_samples) * 0.9)
train_data = training_samples[:split_idx]
val_data = training_samples[split_idx:]
# Save prepared datasets
with open(output_dir / "train.json", "w") as f:
json.dump(train_data, f)
with open(output_dir / "validation.json", "w") as f:
json.dump(val_data, f)
print(f"Prepared {len(train_data)} training samples")
print(f"Prepared {len(val_data)} validation samples")
print("Data preparation complete")
resources:
requests:
memory: 2Gi
cpu: 500m
- name: train
inputs:
parameters:
- name: reference-model
- name: output-name
- name: learning-rate
- name: num-epochs
- name: batch-size
- name: max-seq-length
- name: lora-r
- name: lora-alpha
- name: lora-dropout
artifacts:
- name: training-data
path: /tmp/training-data
container:
image: python:3.13-slim
command: [bash]
args:
- -c
- |
set -e
echo "Installing dependencies..."
pip install -q torch transformers peft datasets accelerate bitsandbytes scipy
echo "Starting QLoRA training..."
python << 'EOF'
import json
import os
from pathlib import Path
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
# Parameters
reference_model = "{{inputs.parameters.reference-model}}"
output_name = "{{inputs.parameters.output-name}}"
learning_rate = float("{{inputs.parameters.learning-rate}}")
num_epochs = int("{{inputs.parameters.num-epochs}}")
batch_size = int("{{inputs.parameters.batch-size}}")
max_seq_length = int("{{inputs.parameters.max-seq-length}}")
lora_r = int("{{inputs.parameters.lora-r}}")
lora_alpha = int("{{inputs.parameters.lora-alpha}}")
lora_dropout = float("{{inputs.parameters.lora-dropout}}")
data_dir = Path("/tmp/training-data")
output_dir = Path("/mnt/model-storage") / output_name
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Training configuration:")
print(f" Model: {reference_model}")
print(f" Learning rate: {learning_rate}")
print(f" Epochs: {num_epochs}")
print(f" Batch size: {batch_size}")
print(f" Max sequence length: {max_seq_length}")
print(f" LoRA r: {lora_r}, alpha: {lora_alpha}, dropout: {lora_dropout}")
# Load datasets
with open(data_dir / "train.json") as f:
train_data = json.load(f)
with open(data_dir / "validation.json") as f:
val_data = json.load(f)
print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples")
# Load tokenizer
print(f"Loading tokenizer from {reference_model}...")
tokenizer = AutoTokenizer.from_pretrained(reference_model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Prepare datasets
def format_sample(sample):
instruction = sample.get("instruction", "")
input_text = sample.get("input", "")
output = sample.get("output", "")
if input_text:
prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
else:
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
return {"text": prompt}
train_dataset = Dataset.from_list([format_sample(s) for s in train_data])
val_dataset = Dataset.from_list([format_sample(s) for s in val_data])
# Tokenize datasets
def tokenize_function(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=max_seq_length,
padding="max_length"
)
print("Tokenizing datasets...")
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# Configure quantization for QLoRA
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
# Load model
print(f"Loading model {reference_model} with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
reference_model,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
# Prepare model for training
model = prepare_model_for_kbit_training(model)
# Configure LoRA
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
# Add LoRA adapters
print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training arguments
training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"),
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=4,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
evaluation_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=100,
save_total_limit=3,
load_best_model_at_end=True,
report_to="none",
remove_unused_columns=False,
)
# Data collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=data_collator,
)
# Train
print("Starting training...")
trainer.train()
# Save final model
print(f"Saving QLora adapter to {output_dir}")
model.save_pretrained(str(output_dir / "final"))
tokenizer.save_pretrained(str(output_dir / "final"))
# Save training metadata
metadata = {
"reference_model": reference_model,
"output_name": output_name,
"training_params": {
"learning_rate": learning_rate,
"num_epochs": num_epochs,
"batch_size": batch_size,
"max_seq_length": max_seq_length,
"lora_r": lora_r,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout
},
"dataset_info": {
"train_samples": len(train_data),
"val_samples": len(val_data)
}
}
with open(output_dir / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
print("Training complete!")
print(f"QLora adapter saved to: {output_dir}")
EOF
envFrom:
- configMapRef:
name: ai-services-config
volumeMounts:
- name: model-storage
mountPath: /mnt/model-storage
resources:
requests:
memory: 16Gi
cpu: 4
nvidia.com/gpu: 1
limits:
memory: 32Gi
cpu: 8
nvidia.com/gpu: 1