# Hybrid ML Training Workflow
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: hybrid-ml-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: hybrid-ml-training
    app.kubernetes.io/part-of: ai-ml-pipelines
  annotations:
    description: |
      Demonstrates hybrid Argo+KFP workflow:
      - Argo handles orchestration, branching, retry logic
      - KFP pipelines handle ML-specific operations (with caching)
      - NATS for status updates to frontends
spec:
  entrypoint: hybrid-training
  serviceAccountName: argo-workflow

  # Artifact repository for model checkpoints
  artifactRepositoryRef:
    configMap: artifact-repository
    key: default

  arguments:
    parameters:
      - name: collection-name
        description: "Milvus collection to pull training data from"
        value: "dnd_text_embeddings"
      - name: model-name
        description: "Base model for fine-tuning"
        value: "mistralai/Mistral-7B-v0.3"
      - name: lora-rank
        description: "LoRA rank (higher = more params)"
        value: "16"
      - name: epochs
        description: "Training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: output-path
        description: "S3 path for model output"
        value: "s3://models/lora-adapters"
      - name: notify-nats
        description: "Publish status to NATS"
        value: "true"

  # Volumes for GPU caching
  volumes:
    - name: model-cache
      persistentVolumeClaim:
        claimName: model-cache-pvc
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 16Gi

  templates:
    # ------------------------------------------------------------------
    # Main DAG orchestrating the workflow.
    #
    # All tasks use `depends` (never `dependencies`): Argo forbids mixing
    # the two in one DAG template, and `depends` is required so that tasks
    # downstream of the conditional `generate-embeddings` still run when it
    # is skipped. With plain `dependencies`, a skipped upstream task causes
    # every dependent task to be omitted — which would silently skip
    # training whenever the dataset already contains embeddings.
    # ------------------------------------------------------------------
    - name: hybrid-training
      dag:
        tasks:
          - name: notify-start
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "started", "pipeline": "hybrid-ml-training"}'

          - name: prepare-data
            template: extract-training-data
            arguments:
              parameters:
                - name: collection-name
                  value: "{{workflow.parameters.collection-name}}"

          - name: validate-data
            template: validate-dataset
            depends: "prepare-data"
            arguments:
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"

          # KFP Pipeline: run embedding generation only when the validated
          # dataset lacks an "embedding" column.
          - name: generate-embeddings
            template: trigger-kfp
            depends: "validate-data"
            when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
            arguments:
              parameters:
                - name: pipeline-id
                  value: "embedding-generation"
                - name: params
                  value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'

          # Training step (runs on GPU). Must run both when embeddings were
          # generated (Succeeded) and when the step was not needed
          # (Skipped/Omitted) — hence the explicit status expression.
          - name: train-lora
            template: lora-training
            depends: "validate-data && (generate-embeddings.Succeeded || generate-embeddings.Skipped || generate-embeddings.Omitted)"
            arguments:
              parameters:
                - name: model-name
                  value: "{{workflow.parameters.model-name}}"
                - name: lora-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: epochs
                  value: "{{workflow.parameters.epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"

          # Evaluate model
          - name: evaluate
            template: evaluate-model
            depends: "train-lora"
            arguments:
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"

          # Branch based on evaluation results
          - name: check-quality
            template: quality-gate
            depends: "evaluate"
            arguments:
              parameters:
                - name: eval-score
                  value: "{{tasks.evaluate.outputs.parameters.score}}"
                - name: threshold
                  value: "0.7"

          # If quality is good, upload to S3
          - name: upload-model
            template: upload-to-s3
            depends: "check-quality"
            when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
            arguments:
              parameters:
                - name: output-path
                  value: "{{workflow.parameters.output-path}}"
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"

          # If quality is poor, trigger retraining with different params
          - name: retry-training
            template: adjust-and-retry
            depends: "check-quality"
            when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
            arguments:
              parameters:
                - name: current-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: current-epochs
                  value: "{{workflow.parameters.epochs}}"

          # Completion notice fires only on the success path: when the
          # quality gate fails, upload-model is skipped and this task is
          # omitted with it (intentional — the retry path notifies nothing).
          - name: notify-complete
            template: nats-notify
            depends: "upload-model"
            when: "{{workflow.parameters.notify-nats}} == true"
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'

    # ------------------------------------------------------------------
    # Extract training data from Milvus.
    # Outputs: `dataset` artifact (parquet dir) and `data-path` parameter
    # (absolute path of the parquet file, consumed by the KFP trigger).
    # ------------------------------------------------------------------
    - name: extract-training-data
      inputs:
        parameters:
          - name: collection-name
      outputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
        parameters:
          - name: data-path
            valueFrom:
              path: /tmp/data-path
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                                 "pymilvus", "pandas", "pyarrow"])
          from pymilvus import connections, Collection
          import pandas as pd
          from pathlib import Path

          connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)
          collection = Collection("{{inputs.parameters.collection-name}}")
          collection.load()

          # Query all training samples
          results = collection.query(
              expr="source != ''",
              output_fields=["text", "source", "metadata"],
              limit=10000
          )

          # Convert to training format
          df = pd.DataFrame(results)
          output_dir = Path("/tmp/dataset")
          output_dir.mkdir(parents=True, exist_ok=True)

          # Save as parquet for efficient loading
          df.to_parquet(output_dir / "train.parquet")
          print(f"Extracted {len(df)} samples")

          with open("/tmp/data-path", "w") as f:
              f.write(str(output_dir / "train.parquet"))

    # ------------------------------------------------------------------
    # Validate dataset: counts samples and reports whether an "embedding"
    # column is missing (drives the conditional KFP embedding step).
    # ------------------------------------------------------------------
    - name: validate-dataset
      inputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
      outputs:
        parameters:
          - name: needs-embeddings
            valueFrom:
              path: /tmp/needs-embeddings
          - name: sample-count
            valueFrom:
              path: /tmp/sample-count
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                                 "pandas", "pyarrow"])
          import pandas as pd
          from pathlib import Path

          dataset_dir = Path("/tmp/dataset")
          parquet_files = list(dataset_dir.glob("*.parquet"))
          if not parquet_files:
              raise ValueError("No parquet files found in dataset")

          df = pd.read_parquet(parquet_files[0])
          sample_count = len(df)
          print(f"Dataset contains {sample_count} samples")

          # Check if embeddings column exists
          needs_embeddings = "embedding" not in df.columns

          # Lower-cased so the Argo `when` expression compares against
          # govaluate booleans (true/false).
          with open("/tmp/needs-embeddings", "w") as f:
              f.write(str(needs_embeddings).lower())
          with open("/tmp/sample-count", "w") as f:
              f.write(str(sample_count))

    # ------------------------------------------------------------------
    # Trigger a registered KFP pipeline and wait for it to finish.
    # Waiting matters: train-lora starts as soon as this task succeeds,
    # so returning before the KFP run completes would race the embedding
    # job against training.
    # ------------------------------------------------------------------
    - name: trigger-kfp
      inputs:
        parameters:
          - name: pipeline-id
          - name: params
      outputs:
        parameters:
          - name: run-id
            valueFrom:
              path: /tmp/run-id
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          import json
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                                 "kfp==2.12.1"])
          from kfp import Client

          client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
          params = json.loads('''{{inputs.parameters.params}}''')

          # The input ("embedding-generation") looks like a display name,
          # not a UUID; resolve it. get_pipeline_id() returns None when no
          # pipeline of that name exists, in which case we assume the input
          # already is an id.
          requested = "{{inputs.parameters.pipeline-id}}"
          pipeline_id = client.get_pipeline_id(requested) or requested

          # create_run_from_pipeline_func() compiles a local Python pipeline
          # function and cannot start a registered pipeline by id; that
          # requires run_pipeline() plus an experiment to group the run
          # under. NOTE(review): omitting version_id runs the pipeline's
          # latest version — confirm that is acceptable here.
          experiment = client.create_experiment(name="argo-triggered")
          run = client.run_pipeline(
              experiment_id=experiment.experiment_id,
              job_name="argo-" + requested,
              pipeline_id=pipeline_id,
              params=params,
          )
          print(f"Triggered KFP pipeline: {run.run_id}")

          # Block until the run finishes so downstream Argo tasks see the
          # embeddings; raises on failure/timeout, failing this task.
          client.wait_for_run_completion(run.run_id, timeout=3600)

          with open("/tmp/run-id", "w") as f:
              f.write(run.run_id)

    # ------------------------------------------------------------------
    # LoRA fine-tuning (GPU). The podSpecPatch requests one AMD GPU for
    # the script's "main" container; model weights are cached on the
    # model-cache PVC and /dev/shm is backed by RAM for dataloader speed.
    # ------------------------------------------------------------------
    - name: lora-training
      inputs:
        parameters:
          - name: model-name
          - name: lora-rank
          - name: epochs
          - name: batch-size
        artifacts:
          - name: dataset
            path: /data/dataset
      outputs:
        artifacts:
          - name: adapter
            path: /output/adapter
          - name: logs
            path: /output/logs
      podSpecPatch: |
        containers:
          - name: main
            resources:
              requests:
                amd.com/gpu: 1
              limits:
                amd.com/gpu: 1
      script:
        image: ghcr.io/billy-davies-2/lora-trainer:latest
        command: [python]
        env:
          - name: HF_HOME
            value: /cache/huggingface
          # TRANSFORMERS_CACHE is deprecated in favor of HF_HOME but kept
          # for older transformers versions baked into the trainer image.
          - name: TRANSFORMERS_CACHE
            value: /cache/huggingface
        volumeMounts:
          - name: model-cache
            mountPath: /cache
          - name: shm
            mountPath: /dev/shm
        source: |
          import os
          import sys
          from pathlib import Path

          # Training script
          # NOTE(review): SFTTrainer kwargs (tokenizer, dataset_text_field)
          # match trl < 0.12; newer trl moves these into SFTConfig — confirm
          # the version pinned in the trainer image.
          from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
          from peft import LoraConfig, get_peft_model
          from datasets import load_dataset
          from trl import SFTTrainer

          model_name = "{{inputs.parameters.model-name}}"
          lora_rank = int("{{inputs.parameters.lora-rank}}")
          epochs = int("{{inputs.parameters.epochs}}")
          batch_size = int("{{inputs.parameters.batch-size}}")

          # Load dataset
          dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")

          # Load model
          model = AutoModelForCausalLM.from_pretrained(
              model_name,
              torch_dtype="auto",
              device_map="auto"
          )
          tokenizer = AutoTokenizer.from_pretrained(model_name)

          # Configure LoRA
          lora_config = LoraConfig(
              r=lora_rank,
              lora_alpha=lora_rank * 2,
              target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
              lora_dropout=0.05,
              bias="none",
              task_type="CAUSAL_LM"
          )
          model = get_peft_model(model, lora_config)

          # Training
          training_args = TrainingArguments(
              output_dir="/output/adapter",
              num_train_epochs=epochs,
              per_device_train_batch_size=batch_size,
              gradient_accumulation_steps=4,
              learning_rate=2e-4,
              logging_dir="/output/logs",
              save_strategy="epoch"
          )

          trainer = SFTTrainer(
              model=model,
              train_dataset=dataset,
              tokenizer=tokenizer,
              args=training_args,
              dataset_text_field="text"
          )

          trainer.train()
          trainer.save_model("/output/adapter")
          print("Training complete!")

    # ------------------------------------------------------------------
    # Evaluate model by prompting the serving endpoint. Scoring is an
    # acknowledged placeholder (response-length heuristic), not a real
    # quality metric.
    # ------------------------------------------------------------------
    - name: evaluate-model
      inputs:
        artifacts:
          - name: adapter
            path: /input/adapter
      outputs:
        parameters:
          - name: score
            valueFrom:
              path: /tmp/score
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])
          import httpx
          import json

          # Run evaluation using vLLM with the adapter
          # NOTE(review): the adapter artifact is downloaded here but the
          # endpoint serves model "local-adapter" — confirm the server
          # actually loads the adapter from this run.
          test_prompts = [
              "What is the capital of France?",
              "Explain machine learning in simple terms.",
              "Write a haiku about coding."
          ]

          scores = []
          with httpx.Client(timeout=120.0) as client:
              for prompt in test_prompts:
                  response = client.post(
                      "http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
                      json={
                          "model": "local-adapter",
                          "messages": [{"role": "user", "content": prompt}],
                          "max_tokens": 200
                      }
                  )
                  # Simple scoring based on response coherence
                  result = response.json()
                  content = result["choices"][0]["message"]["content"]
                  score = min(1.0, len(content) / 100)  # Placeholder scoring
                  scores.append(score)

          avg_score = sum(scores) / len(scores)
          print(f"Average evaluation score: {avg_score}")

          with open("/tmp/score", "w") as f:
              f.write(str(round(avg_score, 3)))

    # ------------------------------------------------------------------
    # Quality gate: writes "true"/"false" for the DAG's branch conditions.
    # ------------------------------------------------------------------
    - name: quality-gate
      inputs:
        parameters:
          - name: eval-score
          - name: threshold
      outputs:
        parameters:
          - name: passed
            valueFrom:
              path: /tmp/passed
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          score = float("{{inputs.parameters.eval-score}}")
          threshold = float("{{inputs.parameters.threshold}}")
          passed = score >= threshold
          print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")
          with open("/tmp/passed", "w") as f:
              f.write(str(passed).lower())

    # ------------------------------------------------------------------
    # Upload adapter to S3-compatible storage under a timestamped prefix.
    # ------------------------------------------------------------------
    - name: upload-to-s3
      inputs:
        parameters:
          - name: output-path
        artifacts:
          - name: adapter
            path: /input/adapter
      script:
        image: amazon/aws-cli:latest
        command: [bash]
        env:
          - name: AWS_ACCESS_KEY_ID
            valueFrom:
              secretKeyRef:
                name: s3-credentials
                key: access-key
          - name: AWS_SECRET_ACCESS_KEY
            valueFrom:
              secretKeyRef:
                name: s3-credentials
                key: secret-key
          - name: AWS_ENDPOINT_URL
            value: "https://quobjects.billy.davies.cloud"
        source: |
          aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
          echo "Uploaded adapter to {{inputs.parameters.output-path}}"

    # ------------------------------------------------------------------
    # Adjust parameters and retry: currently only computes and logs the
    # next attempt's hyperparameters (resubmission is a TODO).
    # ------------------------------------------------------------------
    - name: adjust-and-retry
      inputs:
        parameters:
          - name: current-rank
          - name: current-epochs
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          current_rank = int("{{inputs.parameters.current-rank}}")
          current_epochs = int("{{inputs.parameters.current-epochs}}")

          # Increase rank and epochs for next attempt; rank is capped at 64.
          new_rank = min(64, current_rank * 2)
          new_epochs = current_epochs + 2

          print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
          print("TODO: Trigger new workflow with adjusted parameters")

    # ------------------------------------------------------------------
    # NATS notification: publishes a message (templated verbatim into the
    # script, so it must be valid without triple-quote collisions).
    # ------------------------------------------------------------------
    - name: nats-notify
      inputs:
        parameters:
          - name: subject
          - name: message
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
          import asyncio
          import nats

          async def notify():
              nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
              await nc.publish(
                  "{{inputs.parameters.subject}}",
                  b'''{{inputs.parameters.message}}'''
              )
              await nc.close()

          asyncio.run(notify())