# QLoRA Fine-tuning Workflow
# Trains QLoRA adapters from a reference model using data from Milvus vector database
# Triggered via NATS: ai.pipeline.trigger with pipeline="qlora-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: qlora-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: qlora-training
    app.kubernetes.io/part-of: ai-ml-pipelines
spec:
  entrypoint: train-qlora
  serviceAccountName: argo-workflow
  arguments:
    parameters:
      - name: reference-model
        description: "Base model to fine-tune (HuggingFace model ID or path)"
        value: "mistralai/Mistral-7B-Instruct-v0.3"
      - name: output-name
        description: "Name for the output QLora adapter"
        value: "qlora-adapter"
      - name: milvus-collections
        description: "Comma-separated list of Milvus collections to use (empty = all available)"
        value: ""
      - name: learning-rate
        description: "Learning rate for training"
        value: "2e-4"
      - name: num-epochs
        description: "Number of training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: max-seq-length
        description: "Maximum sequence length"
        value: "2048"
      - name: lora-r
        description: "LoRA attention dimension"
        value: "64"
      - name: lora-alpha
        description: "LoRA alpha parameter"
        value: "16"
      - name: lora-dropout
        description: "LoRA dropout rate"
        value: "0.05"

  # Shared PVC for the trained adapter output; ReadWriteMany so the artifact
  # can be consumed by downstream workflows on other nodes.
  volumeClaimTemplates:
    - metadata:
        name: model-storage
      spec:
        accessModes: ["ReadWriteMany"]
        storageClassName: nfs-slow
        resources:
          requests:
            storage: 50Gi

  templates:
    # DAG entrypoint: fetch -> prepare -> train, passing data via artifacts.
    - name: train-qlora
      dag:
        tasks:
          - name: fetch-training-data
            template: fetch-data
            arguments:
              parameters:
                - name: milvus-collections
                  value: "{{workflow.parameters.milvus-collections}}"
          - name: prepare-dataset
            template: prepare-data
            dependencies: [fetch-training-data]
            arguments:
              parameters:
                - name: max-seq-length
                  value: "{{workflow.parameters.max-seq-length}}"
              artifacts:
                - name: raw-data
                  from: "{{tasks.fetch-training-data.outputs.artifacts.raw-data}}"
          - name: train-model
            template: train
            dependencies: [prepare-dataset]
            arguments:
              parameters:
                - name: reference-model
                  value: "{{workflow.parameters.reference-model}}"
                - name: output-name
                  value: "{{workflow.parameters.output-name}}"
                - name: learning-rate
                  value: "{{workflow.parameters.learning-rate}}"
                - name: num-epochs
                  value: "{{workflow.parameters.num-epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
                - name: max-seq-length
                  value: "{{workflow.parameters.max-seq-length}}"
                - name: lora-r
                  value: "{{workflow.parameters.lora-r}}"
                - name: lora-alpha
                  value: "{{workflow.parameters.lora-alpha}}"
                - name: lora-dropout
                  value: "{{workflow.parameters.lora-dropout}}"
              artifacts:
                - name: training-data
                  from: "{{tasks.prepare-dataset.outputs.artifacts.training-data}}"

    # Step 1: pull raw text records out of Milvus and emit them as a JSON artifact.
    - name: fetch-data
      inputs:
        parameters:
          - name: milvus-collections
      outputs:
        artifacts:
          - name: raw-data
            path: /tmp/raw-data
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            # NOTE(review): unpinned install at runtime; consider pinning pymilvus
            # (or baking it into the image) for reproducible runs.
            subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)
            import json
            from pathlib import Path
            from pymilvus import connections, Collection, utility

            MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
            MILVUS_PORT = 19530
            collections_param = "{{inputs.parameters.milvus-collections}}"
            output_dir = Path("/tmp/raw-data")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Connecting to Milvus at {MILVUS_HOST}:{MILVUS_PORT}")
            connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

            # Determine which collections to use
            if collections_param and collections_param.strip():
                collection_names = [c.strip() for c in collections_param.split(",")]
                print(f"Using specified collections: {collection_names}")
            else:
                # Get all available collections
                collection_names = utility.list_collections()
                print(f"Using all available collections: {collection_names}")

            all_training_data = []
            for collection_name in collection_names:
                if not utility.has_collection(collection_name):
                    print(f"Warning: Collection {collection_name} not found, skipping")
                    continue
                print(f"Fetching data from collection: {collection_name}")
                collection = Collection(collection_name)
                collection.load()
                # Query all data from the collection
                # Note: Adjust field names based on your schema
                try:
                    # Get collection schema to determine fields
                    schema = collection.schema
                    field_names = [field.name for field in schema.fields if field.name != "id"]
                    # Query all entities (limited to reasonable batch size)
                    # For large collections, you may want to implement pagination
                    results = collection.query(
                        expr="id >= 0",
                        output_fields=field_names,
                        limit=100000
                    )
                    print(f" Retrieved {len(results)} records from {collection_name}")
                    for result in results:
                        # Extract text field (adjust based on your schema)
                        text_content = result.get("text", "")
                        source = result.get("source", collection_name)
                        if text_content:
                            all_training_data.append({
                                "text": text_content,
                                "source": source,
                                "collection": collection_name
                            })
                except Exception as e:
                    print(f"Error querying collection {collection_name}: {e}")
                    continue

            # Save all training data
            output_file = output_dir / "training_data.json"
            with open(output_file, "w") as f:
                json.dump({"data": all_training_data}, f)
            print(f"Total training samples collected: {len(all_training_data)}")
            print(f"Saved to {output_file}")
            connections.disconnect("default")
        envFrom:
          - configMapRef:
              name: ai-services-config
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    # Step 2: convert raw records into Alpaca-style instruction samples and
    # split train/validation 90/10.
    - name: prepare-data
      inputs:
        parameters:
          - name: max-seq-length
        artifacts:
          - name: raw-data
            path: /tmp/raw-data
      outputs:
        artifacts:
          - name: training-data
            path: /tmp/training-data
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import json
            from pathlib import Path

            max_seq_length = int("{{inputs.parameters.max-seq-length}}")
            input_dir = Path("/tmp/raw-data")
            output_dir = Path("/tmp/training-data")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Load raw data
            with open(input_dir / "training_data.json") as f:
                data = json.load(f)
            raw_samples = data["data"]
            print(f"Processing {len(raw_samples)} raw samples")

            # Prepare data in instruction format for fine-tuning
            # Using Alpaca-style format: instruction + response
            training_samples = []
            for sample in raw_samples:
                text = sample["text"]
                source = sample.get("source", "")
                # Create instruction-response pairs
                # You can customize this based on your use case
                # NOTE(review): "input" and "output" are the same truncated text,
                # so the model is trained to echo its input — confirm this is the
                # intended objective and not a placeholder.
                training_sample = {
                    "instruction": f"Based on the following information from {source}, provide a comprehensive response:",
                    "input": text[:max_seq_length // 2],  # Truncate if needed
                    "output": text[:max_seq_length // 2],
                    "source": source
                }
                training_samples.append(training_sample)

            # Split into train/validation (90/10)
            split_idx = int(len(training_samples) * 0.9)
            train_data = training_samples[:split_idx]
            val_data = training_samples[split_idx:]

            # Save prepared datasets
            with open(output_dir / "train.json", "w") as f:
                json.dump(train_data, f)
            with open(output_dir / "validation.json", "w") as f:
                json.dump(val_data, f)
            print(f"Prepared {len(train_data)} training samples")
            print(f"Prepared {len(val_data)} validation samples")
            print("Data preparation complete")
        resources:
          requests:
            memory: 2Gi
            cpu: 500m

    # Step 3: 4-bit QLoRA fine-tuning; writes the adapter + metadata to the
    # shared model-storage PVC.
    # NOTE(review): python:3.13-slim has no CUDA runtime; bitsandbytes 4-bit
    # training needs a CUDA-enabled image on a GPU node — confirm the cluster
    # injects CUDA, otherwise switch to an nvidia/cuda or pytorch base image.
    - name: train
      inputs:
        parameters:
          - name: reference-model
          - name: output-name
          - name: learning-rate
          - name: num-epochs
          - name: batch-size
          - name: max-seq-length
          - name: lora-r
          - name: lora-alpha
          - name: lora-dropout
        artifacts:
          - name: training-data
            path: /tmp/training-data
      container:
        image: python:3.13-slim
        command: [bash]
        args:
          - -c
          - |
            set -e
            echo "Installing dependencies..."
            # NOTE(review): unpinned versions — pin these for reproducible training runs.
            pip install -q torch transformers peft datasets accelerate bitsandbytes scipy
            echo "Starting QLoRA training..."
            python << 'EOF'
            import json
            import os
            from pathlib import Path
            from datasets import Dataset
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                BitsAndBytesConfig,
                TrainingArguments,
                Trainer,
                DataCollatorForLanguageModeling
            )
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
            import torch

            # Parameters (substituted by Argo before the script runs)
            reference_model = "{{inputs.parameters.reference-model}}"
            output_name = "{{inputs.parameters.output-name}}"
            learning_rate = float("{{inputs.parameters.learning-rate}}")
            num_epochs = int("{{inputs.parameters.num-epochs}}")
            batch_size = int("{{inputs.parameters.batch-size}}")
            max_seq_length = int("{{inputs.parameters.max-seq-length}}")
            lora_r = int("{{inputs.parameters.lora-r}}")
            lora_alpha = int("{{inputs.parameters.lora-alpha}}")
            lora_dropout = float("{{inputs.parameters.lora-dropout}}")

            data_dir = Path("/tmp/training-data")
            output_dir = Path("/mnt/model-storage") / output_name
            output_dir.mkdir(parents=True, exist_ok=True)

            print("Training configuration:")
            print(f" Model: {reference_model}")
            print(f" Learning rate: {learning_rate}")
            print(f" Epochs: {num_epochs}")
            print(f" Batch size: {batch_size}")
            print(f" Max sequence length: {max_seq_length}")
            print(f" LoRA r: {lora_r}, alpha: {lora_alpha}, dropout: {lora_dropout}")

            # Load datasets
            with open(data_dir / "train.json") as f:
                train_data = json.load(f)
            with open(data_dir / "validation.json") as f:
                val_data = json.load(f)
            print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples")

            # Load tokenizer
            print(f"Loading tokenizer from {reference_model}...")
            tokenizer = AutoTokenizer.from_pretrained(reference_model, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Prepare datasets: render each sample as an Alpaca-style prompt
            def format_sample(sample):
                instruction = sample.get("instruction", "")
                input_text = sample.get("input", "")
                output = sample.get("output", "")
                if input_text:
                    prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                else:
                    prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
                return {"text": prompt}

            train_dataset = Dataset.from_list([format_sample(s) for s in train_data])
            val_dataset = Dataset.from_list([format_sample(s) for s in val_data])

            # Tokenize datasets
            def tokenize_function(examples):
                return tokenizer(
                    examples["text"],
                    truncation=True,
                    max_length=max_seq_length,
                    padding="max_length"
                )

            print("Tokenizing datasets...")
            train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
            val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

            # Configure quantization for QLoRA
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

            # Load model
            print(f"Loading model {reference_model} with 4-bit quantization...")
            model = AutoModelForCausalLM.from_pretrained(
                reference_model,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True
            )

            # Prepare model for training
            model = prepare_model_for_kbit_training(model)

            # Configure LoRA
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )

            # Add LoRA adapters
            print("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

            # Training arguments
            training_args = TrainingArguments(
                output_dir=str(output_dir / "checkpoints"),
                num_train_epochs=num_epochs,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                gradient_accumulation_steps=4,
                learning_rate=learning_rate,
                fp16=True,
                logging_steps=10,
                # "evaluation_strategy" was renamed "eval_strategy" in
                # transformers v4.41 and the old name was later removed; the
                # unpinned pip install above pulls a current release.
                eval_strategy="steps",
                eval_steps=50,
                save_strategy="steps",
                save_steps=100,  # must stay a multiple of eval_steps for load_best_model_at_end
                save_total_limit=3,
                load_best_model_at_end=True,
                report_to="none",
                remove_unused_columns=False,
            )

            # Data collator (mlm=False -> causal-LM labels)
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False
            )

            # Trainer
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=data_collator,
            )

            # Train
            print("Starting training...")
            trainer.train()

            # Save final model
            print(f"Saving QLora adapter to {output_dir}")
            model.save_pretrained(str(output_dir / "final"))
            tokenizer.save_pretrained(str(output_dir / "final"))

            # Save training metadata
            metadata = {
                "reference_model": reference_model,
                "output_name": output_name,
                "training_params": {
                    "learning_rate": learning_rate,
                    "num_epochs": num_epochs,
                    "batch_size": batch_size,
                    "max_seq_length": max_seq_length,
                    "lora_r": lora_r,
                    "lora_alpha": lora_alpha,
                    "lora_dropout": lora_dropout
                },
                "dataset_info": {
                    "train_samples": len(train_data),
                    "val_samples": len(val_data)
                }
            }
            with open(output_dir / "metadata.json", "w") as f:
                json.dump(metadata, f, indent=2)

            print("Training complete!")
            print(f"QLora adapter saved to: {output_dir}")
            EOF
        envFrom:
          - configMapRef:
              name: ai-services-config
        volumeMounts:
          - name: model-storage
            mountPath: /mnt/model-storage
        resources:
          requests:
            memory: 16Gi
            cpu: 4
            nvidia.com/gpu: 1
          limits:
            memory: 32Gi
            cpu: 8
            nvidia.com/gpu: 1