From 7104698eeed929930e93b41951ebbd1a582b079a Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Sun, 1 Feb 2026 20:39:42 -0500 Subject: [PATCH] feat: Add ML training and batch inference workflows - batch-inference: LLM inference with optional RAG - qlora-training: QLoRA adapter fine-tuning from Milvus - hybrid-ml-training: Multi-GPU distributed training - coqui-voice-training: XTTS voice cloning - document-ingestion: Ingest documents to Milvus - eventsource-kfp: Argo Events / Kubeflow integration - kfp-integration: Bridge between Argo and Kubeflow --- README.md | 128 ++++- batch-inference.yaml | 328 +++++++++++++ coqui-voice-training.yaml | 969 ++++++++++++++++++++++++++++++++++++++ document-ingestion.yaml | 369 +++++++++++++++ eventsource-kfp.yaml | 270 +++++++++++ hybrid-ml-training.yaml | 555 ++++++++++++++++++++++ kfp-integration.yaml | 237 ++++++++++ qlora-training.yaml | 510 ++++++++++++++++++++ 8 files changed, 3365 insertions(+), 1 deletion(-) create mode 100644 batch-inference.yaml create mode 100644 coqui-voice-training.yaml create mode 100644 document-ingestion.yaml create mode 100644 eventsource-kfp.yaml create mode 100644 hybrid-ml-training.yaml create mode 100644 kfp-integration.yaml create mode 100644 qlora-training.yaml diff --git a/README.md b/README.md index cd76262..735925b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,128 @@ -# argo +# Argo Workflows +ML training and batch inference workflows for the DaviesTechLabs AI/ML platform. + +## Workflows + +| Workflow | Description | Trigger | +|----------|-------------|---------| +| `batch-inference` | Run LLM inference on batch inputs | `ai.pipeline.trigger` (pipeline="batch-inference") | +| `qlora-training` | Train QLoRA adapters from Milvus data | `ai.pipeline.trigger` (pipeline="qlora-training") | +| `hybrid-ml-training` | Multi-GPU distributed training | `ai.pipeline.trigger` (pipeline="hybrid-ml-training") | +| `coqui-voice-training` | XTTS voice cloning/training | `ai.pipeline.trigger` (pipeline="coqui-voice-training") | +| `document-ingestion` | Ingest documents into Milvus | `ai.pipeline.trigger` (pipeline="document-ingestion") | + +## Integration + +| File | Description | +|------|-------------| +| `eventsource-kfp.yaml` | Argo Events source for Kubeflow Pipelines integration | +| `kfp-integration.yaml` | Bridge workflows between Argo and Kubeflow | + +## Architecture + +``` +NATS (ai.pipeline.trigger) + │ + ▼ +┌─────────────────┐ +│ Argo Events │ +│ EventSource │ +└─────────────────┘ + │ + ▼ +┌─────────────────┐ +│ Argo Sensor │ +└─────────────────┘ + │ + ▼ +┌─────────────────┐ +│ WorkflowTemplate│ +│ (batch-inf, │ +│ qlora, etc) │ +└─────────────────┘ + │ + ├──▶ GPU Pods (AMD ROCm / NVIDIA CUDA) + ├──▶ Milvus Vector DB + ├──▶ vLLM / Ray Serve + └──▶ MLflow Tracking +``` + +## Workflow Details + +### batch-inference + +Batch LLM inference with optional RAG: + +```bash +argo submit batch-inference.yaml \ + -p input-url="s3://bucket/inputs.json" \ + -p output-url="s3://bucket/outputs.json" \ + -p use-rag="true" \ + -p max-tokens="500" +``` + +### qlora-training + +Fine-tune QLoRA adapters from Milvus knowledge: + +```bash +argo submit qlora-training.yaml \ + -p reference-model="mistralai/Mistral-7B-Instruct-v0.3" \ + -p output-name="my-adapter" \ + -p milvus-collections="docs,wiki" \ + -p num-epochs="3" +``` + +### coqui-voice-training + +Train XTTS voice models: + +```bash +argo submit coqui-voice-training.yaml \ + -p voice-name="my-voice" \ + -p audio-samples-url="s3://bucket/samples/" +``` + +### 
document-ingestion + +Ingest documents into Milvus: + +```bash +argo submit document-ingestion.yaml \ + -p source-url="s3://bucket/docs/" \ + -p collection="knowledge_base" \ + -p chunk-size="512" +``` + +## NATS Trigger Format + +Workflows are triggered via NATS `ai.pipeline.trigger`: + +```json +{ + "pipeline": "qlora-training", + "parameters": { + "reference-model": "mistralai/Mistral-7B-Instruct-v0.3", + "output-name": "custom-adapter", + "num-epochs": "5" + } +} +``` + +## GPU Scheduling + +Workflows use node affinity for GPU allocation: + +| Node | GPU | Best For | +|------|-----|----------| +| khelben | AMD Strix Halo 64GB | Large model training, vLLM | +| elminster | NVIDIA RTX 2070 | Whisper, XTTS | +| drizzt | AMD Radeon 680M | Embeddings | +| danilo | Intel Arc | Reranker | + +## Related + +- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs +- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images +- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Handler library diff --git a/batch-inference.yaml b/batch-inference.yaml new file mode 100644 index 0000000..605033f --- /dev/null +++ b/batch-inference.yaml @@ -0,0 +1,328 @@ +# Batch Inference Workflow +# Runs LLM inference on a batch of inputs +# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference" +--- +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: batch-inference + namespace: ai-ml + labels: + app.kubernetes.io/name: batch-inference + app.kubernetes.io/part-of: llm-workflows +spec: + entrypoint: batch-inference + serviceAccountName: argo-workflow + + arguments: + parameters: + - name: input-url + description: "URL to JSON file with inference requests" + - name: output-url + description: "URL to store results (S3 path)" + value: "" + - name: use-rag + value: "true" + description: "Whether to use RAG for context" + - name: max-tokens + value: "500" + description: "Maximum tokens per response" + - name: temperature + value: "0.7" + description: "LLM temperature" + + templates: + - name: batch-inference + dag: + tasks: + - name: fetch-inputs + template: fetch-input-data + arguments: + parameters: + - name: input-url + value: "{{workflow.parameters.input-url}}" + + - name: run-inference + template: inference + dependencies: [fetch-inputs] + arguments: + parameters: + - name: use-rag + value: "{{workflow.parameters.use-rag}}" + - name: max-tokens + value: "{{workflow.parameters.max-tokens}}" + - name: temperature + value: "{{workflow.parameters.temperature}}" + artifacts: + - name: inputs + from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}" + + - name: upload-results + template: upload-output + dependencies: [run-inference] + when: "{{workflow.parameters.output-url}} != ''" + arguments: + parameters: + - name: output-url + value: "{{workflow.parameters.output-url}}" + artifacts: + - name: results + from: "{{tasks.run-inference.outputs.artifacts.results}}" + + - name: fetch-input-data + inputs: + parameters: + - name: input-url + outputs: + artifacts: + - name: inputs + path: /tmp/inputs + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import json + import urllib.request + from pathlib import Path + + input_url = "{{inputs.parameters.input-url}}" + output_dir = Path("/tmp/inputs") + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Fetching inputs from: {input_url}") + + if input_url.startswith("s3://"): + import subprocess + 
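                # boto3 is pulled in at run time so the slim base image stays small;
                # baking it into a custom image would avoid the per-run PyPI fetch
                # (assumes the pod has egress to PyPI or an internal mirror).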
subprocess.run(["pip", "install", "boto3", "-q"], check=True) + import boto3 + s3 = boto3.client("s3") + bucket, key = input_url[5:].split("/", 1) + s3.download_file(bucket, key, str(output_dir / "inputs.json")) + elif input_url.startswith("http"): + urllib.request.urlretrieve(input_url, output_dir / "inputs.json") + else: + print(f"Unsupported URL scheme: {input_url}") + exit(1) + + # Validate JSON structure + with open(output_dir / "inputs.json") as f: + data = json.load(f) + + if "requests" not in data: + print("Error: JSON must contain 'requests' array") + exit(1) + + print(f"Loaded {len(data['requests'])} inference requests") + resources: + requests: + memory: 256Mi + cpu: 100m + + - name: inference + inputs: + parameters: + - name: use-rag + - name: max-tokens + - name: temperature + artifacts: + - name: inputs + path: /tmp/inputs + outputs: + artifacts: + - name: results + path: /tmp/results + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import subprocess + subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True) + + import json + import httpx + from pathlib import Path + from typing import List, Dict + + # Configuration + VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000" + EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local" + RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local" + MILVUS_HOST = "milvus.ai-ml.svc.cluster.local" + LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3" + + use_rag = "{{inputs.parameters.use-rag}}" == "true" + max_tokens = int("{{inputs.parameters.max-tokens}}") + temperature = float("{{inputs.parameters.temperature}}") + + input_dir = Path("/tmp/inputs") + output_dir = Path("/tmp/results") + output_dir.mkdir(parents=True, exist_ok=True) + + # Load inputs + with open(input_dir / "inputs.json") as f: + data = json.load(f) + requests = data["requests"] + + print(f"Processing {len(requests)} requests (RAG: {use_rag})") + + # Initialize Milvus if using RAG + collection = None + if use_rag: + try: + from pymilvus import connections, Collection, utility + connections.connect(host=MILVUS_HOST, port=19530) + if utility.has_collection("knowledge_base"): + collection = Collection("knowledge_base") + collection.load() + print("Milvus connected") + except Exception as e: + print(f"Milvus connection failed: {e}") + use_rag = False + + def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]: + response = client.post( + f"{EMBEDDINGS_URL}/embeddings", + json={"input": texts, "model": "bge"} + ) + result = response.json() + return [d["embedding"] for d in result.get("data", [])] + + def search_milvus(embedding: List[float]) -> List[Dict]: + results = collection.search( + data=[embedding], + anns_field="embedding", + param={"metric_type": "COSINE", "params": {"ef": 64}}, + limit=5, + output_fields=["text", "source"] + ) + docs = [] + for hits in results: + for hit in hits: + docs.append({ + "text": hit.entity.get("text", ""), + "source": hit.entity.get("source", ""), + "score": hit.score + }) + return docs + + def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]: + response = client.post( + f"{RERANKER_URL}/v1/rerank", + json={"query": query, "documents": documents} + ) + return response.json().get("results", []) + + # Process requests + results = [] + with httpx.Client(timeout=120.0) as client: + for i, req in enumerate(requests): + query = req.get("text", req.get("query", "")) + req_id = req.get("id", str(i)) + + 
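                # Per request: embed the query, pull the top-5 neighbours from Milvus,
                # rerank and keep the best 3 chunks, then pass them as context to the
                # OpenAI-compatible /v1/chat/completions endpoint served by vLLM.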
print(f"Processing {i+1}/{len(requests)}: {query[:50]}...") + + context = "" + rag_sources = [] + + if use_rag and collection: + try: + # Get embeddings and search + embeddings = get_embeddings([query], client) + if embeddings: + docs = search_milvus(embeddings[0]) + if docs: + doc_texts = [d["text"] for d in docs] + reranked = rerank(query, doc_texts, client) + sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3] + context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs]) + rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs] + except Exception as e: + print(f" RAG failed: {e}") + + # Generate response + try: + messages = [{"role": "system", "content": "You are a helpful AI assistant."}] + if context: + messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}) + else: + messages.append({"role": "user", "content": query}) + + response = client.post( + f"{VLLM_URL}/v1/chat/completions", + json={ + "model": LLM_MODEL, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature + } + ) + result = response.json() + answer = result["choices"][0]["message"]["content"] + except Exception as e: + answer = f"Error: {e}" + + results.append({ + "id": req_id, + "query": query, + "response": answer, + "used_rag": bool(context), + "rag_sources": rag_sources + }) + + # Save results + with open(output_dir / "results.json", "w") as f: + json.dump({"results": results}, f, indent=2) + + print(f"Completed {len(results)} inferences") + + if collection: + from pymilvus import connections + connections.disconnect("default") + envFrom: + - configMapRef: + name: ai-services-config + resources: + requests: + memory: 1Gi + cpu: 500m + + - name: upload-output + inputs: + parameters: + - name: output-url + artifacts: + - name: results + path: /tmp/results + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import subprocess + subprocess.run(["pip", "install", "boto3", "-q"], check=True) + + import boto3 + from pathlib import Path + + output_url = "{{inputs.parameters.output-url}}" + results_file = Path("/tmp/results/results.json") + + print(f"Uploading results to: {output_url}") + + if output_url.startswith("s3://"): + s3 = boto3.client("s3") + bucket, key = output_url[5:].split("/", 1) + s3.upload_file(str(results_file), bucket, key) + print("Upload complete") + else: + print(f"Unsupported URL scheme: {output_url}") + exit(1) + resources: + requests: + memory: 256Mi + cpu: 100m diff --git a/coqui-voice-training.yaml b/coqui-voice-training.yaml new file mode 100644 index 0000000..02c85c3 --- /dev/null +++ b/coqui-voice-training.yaml @@ -0,0 +1,969 @@ +# Coqui TTS Voice Training Workflow +# Trains a custom voice model using Coqui TTS from audio samples +# Triggered via NATS: ai.pipeline.trigger with pipeline="coqui-voice-training" +--- +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: coqui-voice-training + namespace: ai-ml + labels: + app.kubernetes.io/name: coqui-voice-training + app.kubernetes.io/part-of: llm-workflows +spec: + entrypoint: train-voice + serviceAccountName: argo-workflow + + arguments: + parameters: + - name: audio-source + description: "URL to audio files (S3 bucket, HTTP, or NFS path with .wav/.mp3 files)" + - name: transcripts-source + description: "URL to transcripts file (CSV with audio_file,transcript columns) - leave empty to auto-transcribe" + value: "" + - name: voice-name + description: "Name for the trained voice 
model" + value: "custom-voice" + - name: base-model + description: "Base TTS model to fine-tune from" + value: "tts_models/en/ljspeech/vits" + - name: language + description: "Language code (e.g., en, de, fr, es)" + value: "en" + - name: num-epochs + description: "Number of training epochs" + value: "100" + - name: batch-size + description: "Training batch size" + value: "16" + - name: learning-rate + description: "Learning rate for training" + value: "0.0001" + - name: sample-rate + description: "Target sample rate for audio (Hz)" + value: "22050" + - name: output-path + description: "Path to store the trained model (S3 or NFS)" + value: "/models/tts/custom" + + volumeClaimTemplates: + - metadata: + name: training-workspace + spec: + accessModes: ["ReadWriteMany"] + storageClassName: nfs-slow + resources: + requests: + storage: 50Gi + + templates: + - name: train-voice + dag: + tasks: + - name: fetch-audio + template: fetch-audio-files + arguments: + parameters: + - name: audio-source + value: "{{workflow.parameters.audio-source}}" + + - name: fetch-transcripts + template: fetch-transcript-file + arguments: + parameters: + - name: transcripts-source + value: "{{workflow.parameters.transcripts-source}}" + + - name: preprocess-audio + template: preprocess + dependencies: [fetch-audio] + arguments: + parameters: + - name: sample-rate + value: "{{workflow.parameters.sample-rate}}" + artifacts: + - name: raw-audio + from: "{{tasks.fetch-audio.outputs.artifacts.audio-files}}" + + - name: generate-transcripts + template: transcribe-audio + dependencies: [preprocess-audio, fetch-transcripts] + when: "{{workflow.parameters.transcripts-source}} == ''" + arguments: + parameters: + - name: language + value: "{{workflow.parameters.language}}" + artifacts: + - name: audio-files + from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}" + + - name: prepare-dataset + template: prepare-coqui-dataset + dependencies: [preprocess-audio, generate-transcripts, fetch-transcripts] + arguments: + parameters: + - name: voice-name + value: "{{workflow.parameters.voice-name}}" + - name: language + value: "{{workflow.parameters.language}}" + artifacts: + - name: audio-files + from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}" + - name: transcripts + from: "{{=workflow.parameters.transcriptsSource != '' ? 
tasks.fetch-transcripts.outputs.artifacts.transcripts : tasks.generate-transcripts.outputs.artifacts.transcripts}}" + optional: true + + - name: train-model + template: train-tts + dependencies: [prepare-dataset] + arguments: + parameters: + - name: voice-name + value: "{{workflow.parameters.voice-name}}" + - name: base-model + value: "{{workflow.parameters.base-model}}" + - name: language + value: "{{workflow.parameters.language}}" + - name: num-epochs + value: "{{workflow.parameters.num-epochs}}" + - name: batch-size + value: "{{workflow.parameters.batch-size}}" + - name: learning-rate + value: "{{workflow.parameters.learning-rate}}" + artifacts: + - name: dataset + from: "{{tasks.prepare-dataset.outputs.artifacts.dataset}}" + + - name: export-model + template: export-trained-model + dependencies: [train-model] + arguments: + parameters: + - name: voice-name + value: "{{workflow.parameters.voice-name}}" + - name: output-path + value: "{{workflow.parameters.output-path}}" + artifacts: + - name: trained-model + from: "{{tasks.train-model.outputs.artifacts.model}}" + + # Template: Fetch audio files from source + - name: fetch-audio-files + inputs: + parameters: + - name: audio-source + outputs: + artifacts: + - name: audio-files + path: /tmp/audio + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import os + import subprocess + import urllib.request + from pathlib import Path + import shutil + + source_url = "{{inputs.parameters.audio-source}}" + output_dir = Path("/tmp/audio") + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Fetching audio from: {source_url}") + + if source_url.startswith("s3://"): + subprocess.run(["pip", "install", "boto3", "-q"], check=True) + import boto3 + s3 = boto3.client("s3") + bucket, prefix = source_url[5:].split("/", 1) + response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix) + + audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"} + for obj in response.get("Contents", []): + key = obj["Key"] + if Path(key).suffix.lower() in audio_extensions: + local_path = output_dir / Path(key).name + s3.download_file(bucket, key, str(local_path)) + print(f"Downloaded: {key}") + + elif source_url.startswith("http"): + # Handle single file or directory listing + filename = source_url.split("/")[-1] + if any(ext in filename.lower() for ext in [".wav", ".mp3", ".flac", ".zip"]): + local_path = output_dir / filename + urllib.request.urlretrieve(source_url, local_path) + print(f"Downloaded: {filename}") + + # Extract if zip + if filename.endswith(".zip"): + shutil.unpack_archive(local_path, output_dir) + os.remove(local_path) + print("Extracted zip archive") + else: + print(f"URL doesn't appear to be an audio file: {source_url}") + exit(1) + + elif source_url.startswith("/"): + # Local/NFS path + src_path = Path(source_url) + if src_path.is_dir(): + audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"} + for f in src_path.iterdir(): + if f.suffix.lower() in audio_extensions: + shutil.copy(f, output_dir / f.name) + print(f"Copied: {f.name}") + elif src_path.is_file(): + shutil.copy(src_path, output_dir / src_path.name) + else: + print(f"Path not found: {source_url}") + exit(1) + else: + print(f"Unsupported source: {source_url}") + exit(1) + + # Count files + audio_files = list(output_dir.glob("*")) + print(f"Total audio files: {len(audio_files)}") + + if len(audio_files) == 0: + print("Error: No audio files found!") + exit(1) + resources: + requests: + memory: 512Mi + cpu: 200m + + # Template: Fetch transcripts file + - 
name: fetch-transcript-file + inputs: + parameters: + - name: transcripts-source + outputs: + artifacts: + - name: transcripts + path: /tmp/transcripts + optional: true + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import os + import subprocess + import urllib.request + from pathlib import Path + import shutil + + source_url = "{{inputs.parameters.transcripts-source}}" + output_dir = Path("/tmp/transcripts") + output_dir.mkdir(parents=True, exist_ok=True) + + if not source_url or source_url.strip() == "": + print("No transcripts source provided - will auto-transcribe") + # Create empty placeholder + (output_dir / "placeholder.txt").write_text("auto-transcribe") + exit(0) + + print(f"Fetching transcripts from: {source_url}") + + if source_url.startswith("s3://"): + subprocess.run(["pip", "install", "boto3", "-q"], check=True) + import boto3 + s3 = boto3.client("s3") + bucket, key = source_url[5:].split("/", 1) + local_path = output_dir / Path(key).name + s3.download_file(bucket, key, str(local_path)) + print(f"Downloaded: {key}") + + elif source_url.startswith("http"): + filename = source_url.split("/")[-1] or "transcripts.csv" + local_path = output_dir / filename + urllib.request.urlretrieve(source_url, local_path) + print(f"Downloaded: {filename}") + + elif source_url.startswith("/"): + src_path = Path(source_url) + if src_path.is_file(): + shutil.copy(src_path, output_dir / src_path.name) + print(f"Copied: {src_path.name}") + else: + print(f"File not found: {source_url}") + exit(1) + else: + print(f"Unsupported source: {source_url}") + exit(1) + resources: + requests: + memory: 256Mi + cpu: 100m + + # Template: Preprocess audio files + - name: preprocess + inputs: + parameters: + - name: sample-rate + artifacts: + - name: raw-audio + path: /tmp/raw-audio + outputs: + artifacts: + - name: processed-audio + path: /tmp/processed-audio + container: + image: python:3.13-slim + command: [bash] + args: + - -c + - | + set -e + + # Install ffmpeg and dependencies + apt-get update && apt-get install -y ffmpeg > /dev/null 2>&1 + pip install -q pydub numpy soundfile + + python3 << 'EOF' + import os + from pathlib import Path + from pydub import AudioSegment + import soundfile as sf + + SAMPLE_RATE = int("{{inputs.parameters.sample-rate}}") + input_dir = Path("/tmp/raw-audio") + output_dir = Path("/tmp/processed-audio") + output_dir.mkdir(parents=True, exist_ok=True) + + audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"} + + for audio_file in input_dir.iterdir(): + if audio_file.suffix.lower() not in audio_extensions: + continue + + print(f"Processing: {audio_file.name}") + + try: + # Load audio + audio = AudioSegment.from_file(str(audio_file)) + + # Convert to mono if stereo + if audio.channels > 1: + audio = audio.set_channels(1) + + # Resample to target sample rate + audio = audio.set_frame_rate(SAMPLE_RATE) + + # Normalize audio + audio = audio.normalize() + + # Export as WAV + output_file = output_dir / f"{audio_file.stem}.wav" + audio.export(str(output_file), format="wav") + print(f" -> Saved: {output_file.name}") + + except Exception as e: + print(f" -> Error processing {audio_file.name}: {e}") + continue + + processed_files = list(output_dir.glob("*.wav")) + print(f"\nProcessed {len(processed_files)} audio files") + + if len(processed_files) == 0: + print("Error: No files were successfully processed!") + exit(1) + EOF + resources: + requests: + memory: 2Gi + cpu: "1" + + # Template: Auto-transcribe audio using Coqui STT + - name: transcribe-audio 
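    # Coqui STT expects 16 kHz mono audio; the script below resamples with scipy
    # when the preprocessed WAVs use a different sample rate.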
+ inputs: + parameters: + - name: language + artifacts: + - name: audio-files + path: /tmp/audio + outputs: + artifacts: + - name: transcripts + path: /tmp/transcripts + container: + image: ghcr.io/coqui-ai/stt:latest + command: [bash] + args: + - -c + - | + set -e + + # Install additional dependencies + pip install -q numpy scipy + + python3 << 'EOF' + import csv + import os + import wave + import numpy as np + from pathlib import Path + from stt import Model + + LANGUAGE = "{{inputs.parameters.language}}" + input_dir = Path("/tmp/audio") + output_dir = Path("/tmp/transcripts") + output_dir.mkdir(parents=True, exist_ok=True) + + # Model paths - Coqui STT models are typically pre-installed in the container + # or can be downloaded from https://coqui.ai/models + MODEL_DIR = Path("/models/stt") + + # Try to find model files + model_file = None + scorer_file = None + + # Check for language-specific models + lang_model_dir = MODEL_DIR / LANGUAGE + if lang_model_dir.exists(): + for f in lang_model_dir.glob("*.tflite"): + model_file = f + for f in lang_model_dir.glob("*.scorer"): + scorer_file = f + + # Fallback to default English model location + if model_file is None: + default_paths = [ + MODEL_DIR / "model.tflite", + Path("/usr/share/stt/model.tflite"), + Path("/opt/stt/model.tflite"), + ] + for p in default_paths: + if p.exists(): + model_file = p + break + + if model_file is None: + # Download model if not found + print("Downloading Coqui STT model...") + import urllib.request + import tarfile + + model_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.tflite" + scorer_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.scorer" + + MODEL_DIR.mkdir(parents=True, exist_ok=True) + model_file = MODEL_DIR / "model.tflite" + scorer_file = MODEL_DIR / "model.scorer" + + urllib.request.urlretrieve(model_url, model_file) + urllib.request.urlretrieve(scorer_url, scorer_file) + print("Model downloaded successfully") + + print(f"Loading Coqui STT model: {model_file}") + model = Model(str(model_file)) + + if scorer_file and scorer_file.exists(): + print(f"Loading scorer: {scorer_file}") + model.enableExternalScorer(str(scorer_file)) + + transcripts = [] + + for audio_file in sorted(input_dir.glob("*.wav")): + print(f"Transcribing: {audio_file.name}") + + try: + # Read WAV file + with wave.open(str(audio_file), 'rb') as w: + sample_rate = w.getframerate() + frames = w.getnframes() + audio_data = w.readframes(frames) + + # Convert to int16 array + audio = np.frombuffer(audio_data, dtype=np.int16) + + # Resample if needed (Coqui STT expects 16kHz) + if sample_rate != 16000: + from scipy import signal + audio = signal.resample(audio, int(len(audio) * 16000 / sample_rate)) + audio = audio.astype(np.int16) + + # Run inference + text = model.stt(audio) + + transcripts.append({ + "audio_file": audio_file.name, + "transcript": text + }) + print(f" -> {text[:100] if text else '(empty)'}...") + except Exception as e: + print(f" -> Error: {e}") + continue + + # Write CSV + csv_file = output_dir / "transcripts.csv" + with open(csv_file, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=["audio_file", "transcript"]) + writer.writeheader() + writer.writerows(transcripts) + + print(f"\nTranscribed {len(transcripts)} files") + print(f"Saved to: {csv_file}") + EOF + resources: + requests: + memory: 4Gi + cpu: "2" + limits: + memory: 8Gi + cpu: "4" + + # Template: Prepare dataset in Coqui TTS format 
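  # Output layout follows LJSpeech: a wavs/ directory plus a pipe-separated
  # metadata.csv with one "<file stem>|<text>|<text>" line per sample.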
+ - name: prepare-coqui-dataset + inputs: + parameters: + - name: voice-name + - name: language + artifacts: + - name: audio-files + path: /tmp/audio + - name: transcripts + path: /tmp/transcripts + optional: true + outputs: + artifacts: + - name: dataset + path: /tmp/dataset + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import csv + import json + import os + import shutil + from pathlib import Path + + VOICE_NAME = "{{inputs.parameters.voice-name}}" + LANGUAGE = "{{inputs.parameters.language}}" + + audio_dir = Path("/tmp/audio") + transcripts_dir = Path("/tmp/transcripts") + output_dir = Path("/tmp/dataset") + wavs_dir = output_dir / "wavs" + wavs_dir.mkdir(parents=True, exist_ok=True) + + print(f"Preparing Coqui TTS dataset for voice: {VOICE_NAME}") + + # Find transcripts file + transcripts_file = None + for f in transcripts_dir.glob("*.csv"): + transcripts_file = f + break + + if transcripts_file is None: + # Check for .txt files (simple format: filename|text) + for f in transcripts_dir.glob("*.txt"): + if f.name != "placeholder.txt": + transcripts_file = f + break + + if transcripts_file is None: + print("Error: No transcripts file found!") + exit(1) + + print(f"Using transcripts: {transcripts_file}") + + # Parse transcripts + transcripts = {} + + if transcripts_file.suffix == ".csv": + with open(transcripts_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + # Handle various column name conventions + audio = row.get("audio_file") or row.get("audio") or row.get("file") or row.get("wav") + text = row.get("transcript") or row.get("text") or row.get("sentence") + if audio and text: + transcripts[audio] = text.strip() + else: + # Simple pipe-separated format: filename|text + with open(transcripts_file, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if "|" in line: + parts = line.split("|", 1) + if len(parts) == 2: + transcripts[parts[0]] = parts[1] + + print(f"Loaded {len(transcripts)} transcripts") + + # Copy audio files and create metadata + metadata_lines = [] + + for audio_file in sorted(audio_dir.glob("*.wav")): + # Try to match transcript + text = None + for key in [audio_file.name, audio_file.stem, audio_file.stem + ".wav"]: + if key in transcripts: + text = transcripts[key] + break + + if text is None: + print(f"Warning: No transcript for {audio_file.name}, skipping") + continue + + # Copy audio file + dest_file = wavs_dir / audio_file.name + shutil.copy(audio_file, dest_file) + + # Add to metadata (LJSpeech format: filename|text|text) + # Coqui uses: audio_file|text|text (normalized text optional) + metadata_lines.append(f"{audio_file.stem}|{text}|{text}") + + # Write metadata.csv + metadata_file = output_dir / "metadata.csv" + with open(metadata_file, "w", encoding="utf-8") as f: + f.write("\n".join(metadata_lines)) + + print(f"Created dataset with {len(metadata_lines)} samples") + + # Create dataset config + config = { + "name": VOICE_NAME, + "language": LANGUAGE, + "num_samples": len(metadata_lines), + "format": "ljspeech" + } + + with open(output_dir / "dataset_config.json", "w") as f: + json.dump(config, f, indent=2) + + print(f"Dataset ready at: {output_dir}") + + if len(metadata_lines) < 10: + print("Warning: Very small dataset! 
Recommend at least 100+ samples for good results.") + resources: + requests: + memory: 1Gi + cpu: 500m + + # Template: Train Coqui TTS model + - name: train-tts + inputs: + parameters: + - name: voice-name + - name: base-model + - name: language + - name: num-epochs + - name: batch-size + - name: learning-rate + artifacts: + - name: dataset + path: /tmp/dataset + outputs: + artifacts: + - name: model + path: /tmp/output + container: + image: ghcr.io/coqui-ai/tts:latest + command: [bash] + args: + - -c + - | + set -e + + VOICE_NAME="{{inputs.parameters.voice-name}}" + BASE_MODEL="{{inputs.parameters.base-model}}" + LANGUAGE="{{inputs.parameters.language}}" + NUM_EPOCHS="{{inputs.parameters.num-epochs}}" + BATCH_SIZE="{{inputs.parameters.batch-size}}" + LEARNING_RATE="{{inputs.parameters.learning-rate}}" + + DATASET_DIR="/tmp/dataset" + OUTPUT_DIR="/tmp/output" + mkdir -p "$OUTPUT_DIR" + + echo "=== Coqui TTS Voice Training ===" + echo "Voice Name: $VOICE_NAME" + echo "Base Model: $BASE_MODEL" + echo "Language: $LANGUAGE" + echo "Epochs: $NUM_EPOCHS" + echo "Batch Size: $BATCH_SIZE" + echo "Learning Rate: $LEARNING_RATE" + echo "" + + # Download base model if specified for fine-tuning + RESTORE_PATH="" + if [ "$BASE_MODEL" != "" ] && [ "$BASE_MODEL" != "none" ]; then + echo "Downloading base model for fine-tuning: $BASE_MODEL" + # Use tts to download the model and get its path + MODEL_PATH=$(python3 -c " + from TTS.utils.manage import ModelManager + from TTS.utils.synthesizer import Synthesizer + from pathlib import Path + import os + + model_name = '$BASE_MODEL' + manager = ModelManager() + + # Download the model + model_path, config_path, _ = manager.download_model(model_name) + print(model_path) + ") + RESTORE_PATH="$MODEL_PATH" + echo "Base model path: $RESTORE_PATH" + fi + + # Create and run training script following Coqui docs pattern + python3 << EOF + import os + from pathlib import Path + + # Trainer: Where the magic happens + from trainer import Trainer, TrainerArgs + + # Model configs + from TTS.tts.configs.vits_config import VitsConfig + from TTS.tts.configs.shared_configs import BaseDatasetConfig + from TTS.tts.datasets import load_tts_samples + from TTS.tts.models.vits import Vits + from TTS.tts.utils.text.tokenizer import TTSTokenizer + from TTS.utils.audio import AudioProcessor + + # Paths + DATASET_DIR = Path("$DATASET_DIR") + OUTPUT_DIR = Path("$OUTPUT_DIR") + RESTORE_PATH = "$RESTORE_PATH" if "$RESTORE_PATH" else None + + print(f"Dataset: {DATASET_DIR}") + print(f"Output: {OUTPUT_DIR}") + print(f"Restore from: {RESTORE_PATH}") + + # Define dataset config (LJSpeech format) + dataset_config = BaseDatasetConfig( + formatter="ljspeech", + meta_file_train="metadata.csv", + path=str(DATASET_DIR), + language="$LANGUAGE", + ) + + # Initialize training configuration + config = VitsConfig( + run_name="$VOICE_NAME", + output_path=str(OUTPUT_DIR), + datasets=[dataset_config], + batch_size=int("$BATCH_SIZE"), + eval_batch_size=max(1, int("$BATCH_SIZE") // 2), + num_loader_workers=4, + num_eval_loader_workers=2, + run_eval=True, + test_delay_epochs=5, + epochs=int("$NUM_EPOCHS"), + text_cleaner="phoneme_cleaners", + use_phonemes=True, + phoneme_language="$LANGUAGE", + phoneme_cache_path=str(OUTPUT_DIR / "phoneme_cache"), + compute_input_seq_cache=True, + print_step=25, + print_eval=False, + mixed_precision=True, + save_step=500, + save_n_checkpoints=3, + save_best_after=1000, + lr=float("$LEARNING_RATE"), + # Audio settings for typical voice cloning + audio={ + "sample_rate": 22050, + 
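            # resample inputs to 22.05 kHz and trim leading/trailing silence
            # below the 45 dB threshold before feature extraction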
"resample": True, + "do_trim_silence": True, + "trim_db": 45, + }, + ) + + # Initialize the audio processor + # Used for feature extraction and audio I/O + ap = AudioProcessor.init_from_config(config) + + # Initialize the tokenizer + # Converts text to sequences of token IDs + tokenizer, config = TTSTokenizer.init_from_config(config) + + # Load data samples + # Each sample is [text, audio_file_path, speaker_name] + train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + print(f"Training samples: {len(train_samples)}") + print(f"Eval samples: {len(eval_samples)}") + + # Initialize the model + model = Vits(config, ap, tokenizer, speaker_manager=None) + + # Set up trainer arguments + trainer_args = TrainerArgs( + restore_path=RESTORE_PATH, + skip_train_epoch=False, + ) + + # Initialize the trainer + trainer = Trainer( + trainer_args, + config, + output_path=str(OUTPUT_DIR), + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + ) + + # Start training + print("\n" + "=" * 50) + print("Starting training...") + print("=" * 50 + "\n") + + trainer.fit() + + print("\n" + "=" * 50) + print("Training complete!") + print("=" * 50) + EOF + + echo "" + echo "Training complete!" + echo "Model saved to: $OUTPUT_DIR" + ls -la "$OUTPUT_DIR" + resources: + requests: + memory: 16Gi + cpu: "4" + nvidia.com/gpu: "1" + limits: + memory: 32Gi + cpu: "8" + nvidia.com/gpu: "1" + volumeMounts: + - name: training-workspace + mountPath: /tmp/workspace + + # Template: Export trained model + - name: export-trained-model + inputs: + parameters: + - name: voice-name + - name: output-path + artifacts: + - name: trained-model + path: /tmp/trained-model + outputs: + artifacts: + - name: exported-model + path: /tmp/exported + container: + image: python:3.13-slim + command: [bash] + args: + - -c + - | + set -e + + pip install -q boto3 + + python3 << 'EOF' + import json + import os + import shutil + import subprocess + from pathlib import Path + from datetime import datetime + + VOICE_NAME = "{{inputs.parameters.voice-name}}" + OUTPUT_PATH = "{{inputs.parameters.output-path}}" + + model_dir = Path("/tmp/trained-model") + export_dir = Path("/tmp/exported") + export_dir.mkdir(parents=True, exist_ok=True) + + print(f"Exporting trained model: {VOICE_NAME}") + print(f"Target path: {OUTPUT_PATH}") + + # Find best checkpoint + checkpoints = list(model_dir.glob("best_model*.pth")) + list(model_dir.glob("checkpoint_*.pth")) + if not checkpoints: + checkpoints = list(model_dir.glob("*.pth")) + + if not checkpoints: + print("Error: No model checkpoints found!") + exit(1) + + # Sort by modification time and get newest + checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True) + best_checkpoint = checkpoints[0] + print(f"Using checkpoint: {best_checkpoint.name}") + + # Create export package + package_dir = export_dir / VOICE_NAME + package_dir.mkdir(parents=True, exist_ok=True) + + # Copy model files + shutil.copy(best_checkpoint, package_dir / "model.pth") + + # Copy config if exists + config_file = model_dir / "config.json" + if config_file.exists(): + shutil.copy(config_file, package_dir / "config.json") + + # Create model info + model_info = { + "name": VOICE_NAME, + "created_at": datetime.now().isoformat(), + "checkpoint": best_checkpoint.name, + "type": "coqui-tts" + } + + with open(package_dir / "model_info.json", "w") as f: + json.dump(model_info, f, indent=2) + + # Create 
tarball + archive_name = f"{VOICE_NAME}.tar.gz" + shutil.make_archive( + str(export_dir / VOICE_NAME), + "gztar", + export_dir, + VOICE_NAME + ) + + print(f"Created archive: {archive_name}") + + # Upload to destination + if OUTPUT_PATH.startswith("s3://"): + import boto3 + s3 = boto3.client("s3") + bucket, key = OUTPUT_PATH[5:].split("/", 1) + key = f"{key}/{archive_name}" + s3.upload_file(str(export_dir / archive_name), bucket, key) + print(f"Uploaded to: s3://{bucket}/{key}") + + elif OUTPUT_PATH.startswith("/"): + # Local/NFS path + dest_path = Path(OUTPUT_PATH) + dest_path.mkdir(parents=True, exist_ok=True) + shutil.copy(export_dir / archive_name, dest_path / archive_name) + # Also copy uncompressed for easy access + shutil.copytree(package_dir, dest_path / VOICE_NAME, dirs_exist_ok=True) + print(f"Saved to: {dest_path / archive_name}") + + print("\nExport complete!") + print(f"Model package contents:") + for f in package_dir.iterdir(): + print(f" - {f.name}") + EOF + resources: + requests: + memory: 1Gi + cpu: 500m diff --git a/document-ingestion.yaml b/document-ingestion.yaml new file mode 100644 index 0000000..bbda395 --- /dev/null +++ b/document-ingestion.yaml @@ -0,0 +1,369 @@ +# Document Ingestion Workflow +# Ingests documents from a source URL into Milvus vector database +# Triggered via NATS: ai.pipeline.trigger with pipeline="document-ingestion" +--- +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: document-ingestion + namespace: ai-ml + labels: + app.kubernetes.io/name: document-ingestion + app.kubernetes.io/part-of: llm-workflows +spec: + entrypoint: ingest-documents + serviceAccountName: argo-workflow + + arguments: + parameters: + - name: source-url + description: "URL to fetch documents from (S3, HTTP, or local path)" + - name: collection-name + value: "knowledge_base" + description: "Milvus collection name" + - name: chunk-size + value: "512" + description: "Text chunk size in characters" + - name: chunk-overlap + value: "50" + description: "Overlap between chunks" + + templates: + - name: ingest-documents + dag: + tasks: + - name: fetch-documents + template: fetch-docs + arguments: + parameters: + - name: source-url + value: "{{workflow.parameters.source-url}}" + + - name: chunk-documents + template: chunk-docs + dependencies: [fetch-documents] + arguments: + parameters: + - name: chunk-size + value: "{{workflow.parameters.chunk-size}}" + - name: chunk-overlap + value: "{{workflow.parameters.chunk-overlap}}" + artifacts: + - name: documents + from: "{{tasks.fetch-documents.outputs.artifacts.documents}}" + + - name: generate-embeddings + template: embed-docs + dependencies: [chunk-documents] + arguments: + artifacts: + - name: chunks + from: "{{tasks.chunk-documents.outputs.artifacts.chunks}}" + + - name: store-in-milvus + template: store-docs + dependencies: [generate-embeddings] + arguments: + parameters: + - name: collection-name + value: "{{workflow.parameters.collection-name}}" + artifacts: + - name: embeddings + from: "{{tasks.generate-embeddings.outputs.artifacts.embeddings}}" + + - name: fetch-docs + inputs: + parameters: + - name: source-url + outputs: + artifacts: + - name: documents + path: /tmp/documents + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import json + import os + import urllib.request + from pathlib import Path + + source_url = "{{inputs.parameters.source-url}}" + output_dir = Path("/tmp/documents") + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Fetching documents from: 
{source_url}") + + # Handle different source types + if source_url.startswith("s3://"): + import subprocess + subprocess.run(["pip", "install", "boto3", "-q"], check=True) + import boto3 + s3 = boto3.client("s3") + bucket, prefix = source_url[5:].split("/", 1) + response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix) + for obj in response.get("Contents", []): + key = obj["Key"] + local_path = output_dir / Path(key).name + s3.download_file(bucket, key, str(local_path)) + print(f"Downloaded: {key}") + elif source_url.startswith("http"): + # Single file download + filename = source_url.split("/")[-1] or "document.txt" + local_path = output_dir / filename + urllib.request.urlretrieve(source_url, local_path) + print(f"Downloaded: {filename}") + else: + print(f"Unsupported URL scheme: {source_url}") + exit(1) + + # List downloaded files + files = list(output_dir.glob("*")) + print(f"Downloaded {len(files)} files") + + # Create manifest + manifest = {"files": [str(f) for f in files]} + with open(output_dir / "manifest.json", "w") as f: + json.dump(manifest, f) + resources: + requests: + memory: 256Mi + cpu: 100m + + - name: chunk-docs + inputs: + parameters: + - name: chunk-size + - name: chunk-overlap + artifacts: + - name: documents + path: /tmp/documents + outputs: + artifacts: + - name: chunks + path: /tmp/chunks + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import json + from pathlib import Path + + chunk_size = int("{{inputs.parameters.chunk-size}}") + chunk_overlap = int("{{inputs.parameters.chunk-overlap}}") + + input_dir = Path("/tmp/documents") + output_dir = Path("/tmp/chunks") + output_dir.mkdir(parents=True, exist_ok=True) + + # Load manifest + with open(input_dir / "manifest.json") as f: + manifest = json.load(f) + + all_chunks = [] + + for filepath in manifest["files"]: + filepath = Path(filepath) + if not filepath.exists(): + continue + + print(f"Processing: {filepath.name}") + + # Read file content + try: + with open(filepath, "r", encoding="utf-8") as f: + content = f.read() + except Exception as e: + print(f"Error reading {filepath}: {e}") + continue + + # Simple chunking + chunks = [] + start = 0 + while start < len(content): + end = start + chunk_size + chunk = content[start:end] + if chunk.strip(): + chunks.append({ + "text": chunk, + "source": filepath.name, + "chunk_index": len(chunks) + }) + start = end - chunk_overlap + + all_chunks.extend(chunks) + print(f" Created {len(chunks)} chunks") + + # Save chunks + with open(output_dir / "chunks.json", "w") as f: + json.dump({"chunks": all_chunks}, f) + + print(f"Total chunks: {len(all_chunks)}") + resources: + requests: + memory: 512Mi + cpu: 100m + + - name: embed-docs + inputs: + artifacts: + - name: chunks + path: /tmp/chunks + outputs: + artifacts: + - name: embeddings + path: /tmp/embeddings + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import subprocess + subprocess.run(["pip", "install", "httpx", "-q"], check=True) + + import json + import httpx + from pathlib import Path + + EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local" + BATCH_SIZE = 32 + + input_dir = Path("/tmp/chunks") + output_dir = Path("/tmp/embeddings") + output_dir.mkdir(parents=True, exist_ok=True) + + # Load chunks + with open(input_dir / "chunks.json") as f: + data = json.load(f) + chunks = data["chunks"] + + print(f"Generating embeddings for {len(chunks)} chunks") + + # Generate embeddings in batches + all_embeddings = [] + with httpx.Client(timeout=120.0) 
as client: + for i in range(0, len(chunks), BATCH_SIZE): + batch = chunks[i:i+BATCH_SIZE] + texts = [c["text"] for c in batch] + + response = client.post( + f"{EMBEDDINGS_URL}/embeddings", + json={"input": texts, "model": "bge"} + ) + result = response.json() + + for j, emb_data in enumerate(result.get("data", [])): + all_embeddings.append({ + "text": batch[j]["text"], + "source": batch[j]["source"], + "chunk_index": batch[j]["chunk_index"], + "embedding": emb_data["embedding"] + }) + + print(f" Processed batch {i//BATCH_SIZE + 1}/{(len(chunks)-1)//BATCH_SIZE + 1}") + + # Save embeddings + with open(output_dir / "embeddings.json", "w") as f: + json.dump({"embeddings": all_embeddings}, f) + + print(f"Generated {len(all_embeddings)} embeddings") + envFrom: + - configMapRef: + name: ai-services-config + resources: + requests: + memory: 1Gi + cpu: 200m + + - name: store-docs + inputs: + parameters: + - name: collection-name + artifacts: + - name: embeddings + path: /tmp/embeddings + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import subprocess + subprocess.run(["pip", "install", "pymilvus", "-q"], check=True) + + import json + from pathlib import Path + from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility + + MILVUS_HOST = "milvus.ai-ml.svc.cluster.local" + MILVUS_PORT = 19530 + COLLECTION_NAME = "{{inputs.parameters.collection-name}}" + EMBEDDING_DIM = 1024 # BGE-large dimension + + input_dir = Path("/tmp/embeddings") + + # Load embeddings + with open(input_dir / "embeddings.json") as f: + data = json.load(f) + embeddings = data["embeddings"] + + print(f"Storing {len(embeddings)} embeddings in Milvus") + + # Connect to Milvus + connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) + print("Connected to Milvus") + + # Create collection if not exists + if not utility.has_collection(COLLECTION_NAME): + fields = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM) + ] + schema = CollectionSchema(fields, description="Knowledge base documents") + collection = Collection(COLLECTION_NAME, schema) + + # Create HNSW index + index_params = { + "metric_type": "COSINE", + "index_type": "HNSW", + "params": {"M": 16, "efConstruction": 256} + } + collection.create_index("embedding", index_params) + print(f"Created collection: {COLLECTION_NAME}") + else: + collection = Collection(COLLECTION_NAME) + print(f"Using existing collection: {COLLECTION_NAME}") + + # Insert data in batches + BATCH_SIZE = 100 + for i in range(0, len(embeddings), BATCH_SIZE): + batch = embeddings[i:i+BATCH_SIZE] + + data = [ + [e["text"] for e in batch], + [e["source"] for e in batch], + [e["embedding"] for e in batch] + ] + + collection.insert(data) + print(f" Inserted batch {i//BATCH_SIZE + 1}/{(len(embeddings)-1)//BATCH_SIZE + 1}") + + # Flush to ensure data is persisted + collection.flush() + print(f"Successfully stored {len(embeddings)} documents") + + connections.disconnect("default") + envFrom: + - configMapRef: + name: ai-services-config + resources: + requests: + memory: 512Mi + cpu: 100m diff --git a/eventsource-kfp.yaml b/eventsource-kfp.yaml new file mode 100644 index 0000000..f02fe26 --- /dev/null +++ b/eventsource-kfp.yaml @@ -0,0 +1,270 @@ +# Argo Events - EventSource for KFP and NATS integration +# 
Enables bidirectional triggering between Argo Workflows and Kubeflow Pipelines +--- +apiVersion: argoproj.io/v1alpha1 +kind: EventSource +metadata: + name: kfp-events + namespace: ai-ml + labels: + app.kubernetes.io/name: kfp-events + app.kubernetes.io/part-of: llm-workflows +spec: + service: + ports: + - name: webhook + port: 12000 + targetPort: 12000 + # Webhook to receive KFP pipeline completion events + webhook: + kfp-completion: + port: "12000" + endpoint: /kfp/completion + method: POST + kfp-failure: + port: "12000" + endpoint: /kfp/failure + method: POST + # NATS for receiving pipeline trigger requests + nats: + pipeline-trigger: + url: nats://nats.ai-ml.svc.cluster.local:4222 + subject: ai.pipeline.trigger + jsonBody: true + argo-trigger: + url: nats://nats.ai-ml.svc.cluster.local:4222 + subject: ai.argo.trigger + jsonBody: true + kfp-trigger: + url: nats://nats.ai-ml.svc.cluster.local:4222 + subject: ai.kfp.trigger + jsonBody: true +--- +# Sensor for handling KFP completion events +apiVersion: argoproj.io/v1alpha1 +kind: Sensor +metadata: + name: kfp-completion-sensor + namespace: ai-ml + labels: + app.kubernetes.io/name: kfp-completion-sensor + app.kubernetes.io/part-of: llm-workflows +spec: + dependencies: + - name: kfp-success + eventSourceName: kfp-events + eventName: kfp-completion + filters: + data: + - path: body.status + type: string + value: + - "SUCCEEDED" + - name: kfp-failure + eventSourceName: kfp-events + eventName: kfp-failure + triggers: + # On KFP success, publish to NATS + - template: + name: notify-kfp-success + nats: + url: nats://nats.ai-ml.svc.cluster.local:4222 + subject: ai.pipeline.status.completed + payload: + - src: + dependencyName: kfp-success + dataKey: body.run_id + dest: run_id + - src: + dependencyName: kfp-success + dataKey: body.pipeline_name + dest: pipeline_name + - src: + dependencyName: kfp-success + dataKey: body.status + dest: status + retryStrategy: + steps: 3 + # On KFP failure, trigger recovery workflow + - template: + name: kfp-failure-recovery + k8s: + operation: create + source: + resource: + apiVersion: argoproj.io/v1alpha1 + kind: Workflow + metadata: + generateName: kfp-failure-handler- + namespace: ai-ml + spec: + entrypoint: notify-failure + arguments: + parameters: + - name: run-id + - name: pipeline-name + - name: error-message + templates: + - name: notify-failure + inputs: + parameters: + - name: run-id + - name: pipeline-name + - name: error-message + script: + image: python:3.13-slim + command: [python] + source: | + import subprocess + import sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"]) + + import asyncio + import json + import nats + + async def notify(): + nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222") + await nc.publish( + "ai.pipeline.status.failed", + json.dumps({ + "run_id": "{{inputs.parameters.run-id}}", + "pipeline_name": "{{inputs.parameters.pipeline-name}}", + "error": "{{inputs.parameters.error-message}}", + "source": "kubeflow" + }).encode() + ) + await nc.close() + + asyncio.run(notify()) + parameters: + - src: + dependencyName: kfp-failure + dataKey: body.run_id + dest: spec.arguments.parameters.0.value + - src: + dependencyName: kfp-failure + dataKey: body.pipeline_name + dest: spec.arguments.parameters.1.value + - src: + dependencyName: kfp-failure + dataKey: body.error + dest: spec.arguments.parameters.2.value + retryStrategy: + steps: 3 +--- +# Sensor for NATS-triggered Argo Workflows +apiVersion: argoproj.io/v1alpha1 +kind: Sensor +metadata: 
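  # The trigger below creates a Workflow from a placeholder workflowTemplateRef;
  # the sensor's parameter overlay swaps in body.template and body.parameters
  # from the NATS message, so any WorkflowTemplate can be launched by name.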
+ name: nats-argo-sensor + namespace: ai-ml + labels: + app.kubernetes.io/name: nats-argo-sensor + app.kubernetes.io/part-of: llm-workflows +spec: + dependencies: + - name: argo-trigger + eventSourceName: kfp-events + eventName: argo-trigger + triggers: + - template: + name: trigger-argo-workflow + k8s: + operation: create + source: + resource: + apiVersion: argoproj.io/v1alpha1 + kind: Workflow + metadata: + generateName: nats-triggered- + namespace: ai-ml + spec: + workflowTemplateRef: + name: placeholder + arguments: + parameters: [] + parameters: + - src: + dependencyName: argo-trigger + dataKey: body.template + dest: spec.workflowTemplateRef.name + - src: + dependencyName: argo-trigger + dataKey: body.parameters + dest: spec.arguments.parameters + retryStrategy: + steps: 3 +--- +# Sensor for NATS-triggered KFP Pipelines +apiVersion: argoproj.io/v1alpha1 +kind: Sensor +metadata: + name: nats-kfp-sensor + namespace: ai-ml + labels: + app.kubernetes.io/name: nats-kfp-sensor + app.kubernetes.io/part-of: llm-workflows +spec: + dependencies: + - name: kfp-trigger + eventSourceName: kfp-events + eventName: kfp-trigger + triggers: + # Trigger KFP via Argo Workflow (uses kfp-trigger template) + - template: + name: trigger-kfp-via-argo + k8s: + operation: create + source: + resource: + apiVersion: argoproj.io/v1alpha1 + kind: Workflow + metadata: + generateName: kfp-via-nats- + namespace: ai-ml + spec: + workflowTemplateRef: + name: kfp-trigger + arguments: + parameters: + - name: pipeline-id + value: "" + - name: pipeline-params + value: "{}" + - name: wait-for-completion + value: "true" + parameters: + - src: + dependencyName: kfp-trigger + dataKey: body.pipeline_id + dest: spec.arguments.parameters.0.value + - src: + dependencyName: kfp-trigger + dataKey: body.parameters + dest: spec.arguments.parameters.1.value + operation: "stringify" + - src: + dependencyName: kfp-trigger + dataKey: body.wait + dest: spec.arguments.parameters.2.value + operation: "stringify" + retryStrategy: + steps: 3 +--- +# Service for the EventSource webhook +apiVersion: v1 +kind: Service +metadata: + name: kfp-events-webhook + namespace: ai-ml + labels: + app.kubernetes.io/name: kfp-events + app.kubernetes.io/part-of: llm-workflows +spec: + selector: + eventsource-name: kfp-events + ports: + - name: webhook + port: 12000 + targetPort: 12000 diff --git a/hybrid-ml-training.yaml b/hybrid-ml-training.yaml new file mode 100644 index 0000000..619c5df --- /dev/null +++ b/hybrid-ml-training.yaml @@ -0,0 +1,555 @@ +# Hybrid ML Training Workflow +# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components +# Use case: Train a model using data from Milvus, with checkpointing and evaluation +--- +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: hybrid-ml-training + namespace: ai-ml + labels: + app.kubernetes.io/name: hybrid-ml-training + app.kubernetes.io/part-of: llm-workflows + annotations: + description: | + Demonstrates hybrid Argo+KFP workflow: + - Argo handles orchestration, branching, retry logic + - KFP pipelines handle ML-specific operations (with caching) + - NATS for status updates to frontends +spec: + entrypoint: hybrid-training + serviceAccountName: argo-workflow + + # Artifact repository for model checkpoints + artifactRepositoryRef: + configMap: artifact-repository + key: default + + arguments: + parameters: + - name: collection-name + description: "Milvus collection to pull training data from" + value: "dnd_text_embeddings" + - name: model-name + description: "Base 
model for fine-tuning" + value: "mistralai/Mistral-7B-v0.3" + - name: lora-rank + description: "LoRA rank (higher = more params)" + value: "16" + - name: epochs + description: "Training epochs" + value: "3" + - name: batch-size + description: "Training batch size" + value: "4" + - name: output-path + description: "S3 path for model output" + value: "s3://models/lora-adapters" + - name: notify-nats + description: "Publish status to NATS" + value: "true" + + # Volumes for GPU caching + volumes: + - name: model-cache + persistentVolumeClaim: + claimName: model-cache-pvc + - name: shm + emptyDir: + medium: Memory + sizeLimit: 16Gi + + templates: + # Main DAG orchestrating the workflow + - name: hybrid-training + dag: + tasks: + - name: notify-start + template: nats-notify + when: "{{workflow.parameters.notify-nats}} == true" + arguments: + parameters: + - name: subject + value: "ai.pipeline.status.{{workflow.name}}" + - name: message + value: '{"status": "started", "pipeline": "hybrid-ml-training"}' + + - name: prepare-data + template: extract-training-data + arguments: + parameters: + - name: collection-name + value: "{{workflow.parameters.collection-name}}" + + - name: validate-data + template: validate-dataset + dependencies: [prepare-data] + arguments: + artifacts: + - name: dataset + from: "{{tasks.prepare-data.outputs.artifacts.dataset}}" + + # KFP Pipeline: Run embedding generation if needed + - name: generate-embeddings + template: trigger-kfp + dependencies: [validate-data] + when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true" + arguments: + parameters: + - name: pipeline-id + value: "embedding-generation" + - name: params + value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}' + + # Training step (runs on GPU) + - name: train-lora + template: lora-training + dependencies: [validate-data, generate-embeddings] + arguments: + parameters: + - name: model-name + value: "{{workflow.parameters.model-name}}" + - name: lora-rank + value: "{{workflow.parameters.lora-rank}}" + - name: epochs + value: "{{workflow.parameters.epochs}}" + - name: batch-size + value: "{{workflow.parameters.batch-size}}" + artifacts: + - name: dataset + from: "{{tasks.prepare-data.outputs.artifacts.dataset}}" + + # Evaluate model + - name: evaluate + template: evaluate-model + dependencies: [train-lora] + arguments: + artifacts: + - name: adapter + from: "{{tasks.train-lora.outputs.artifacts.adapter}}" + + # Branch based on evaluation results + - name: check-quality + template: quality-gate + dependencies: [evaluate] + arguments: + parameters: + - name: eval-score + value: "{{tasks.evaluate.outputs.parameters.score}}" + - name: threshold + value: "0.7" + + # If quality is good, upload to S3 + - name: upload-model + template: upload-to-s3 + dependencies: [check-quality] + when: "{{tasks.check-quality.outputs.parameters.passed}} == true" + arguments: + parameters: + - name: output-path + value: "{{workflow.parameters.output-path}}" + artifacts: + - name: adapter + from: "{{tasks.train-lora.outputs.artifacts.adapter}}" + + # If quality is poor, trigger retraining with different params + - name: retry-training + template: adjust-and-retry + dependencies: [check-quality] + when: "{{tasks.check-quality.outputs.parameters.passed}} == false" + arguments: + parameters: + - name: current-rank + value: "{{workflow.parameters.lora-rank}}" + - name: current-epochs + value: "{{workflow.parameters.epochs}}" + + - name: notify-complete + template: nats-notify + when: 
"{{workflow.parameters.notify-nats}} == true" + dependencies: [upload-model] + arguments: + parameters: + - name: subject + value: "ai.pipeline.status.{{workflow.name}}" + - name: message + value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}' + + # Extract training data from Milvus + - name: extract-training-data + inputs: + parameters: + - name: collection-name + outputs: + artifacts: + - name: dataset + path: /tmp/dataset + parameters: + - name: data-path + valueFrom: + path: /tmp/data-path + script: + image: python:3.13-slim + command: [python] + source: | + import subprocess + import sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"]) + + from pymilvus import connections, Collection + import pandas as pd + from pathlib import Path + + connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530) + + collection = Collection("{{inputs.parameters.collection-name}}") + collection.load() + + # Query all training samples + results = collection.query( + expr="source != ''", + output_fields=["text", "source", "metadata"], + limit=10000 + ) + + # Convert to training format + df = pd.DataFrame(results) + output_dir = Path("/tmp/dataset") + output_dir.mkdir(parents=True, exist_ok=True) + + # Save as parquet for efficient loading + df.to_parquet(output_dir / "train.parquet") + + print(f"Extracted {len(df)} samples") + with open("/tmp/data-path", "w") as f: + f.write(str(output_dir / "train.parquet")) + + # Validate dataset + - name: validate-dataset + inputs: + artifacts: + - name: dataset + path: /tmp/dataset + outputs: + parameters: + - name: needs-embeddings + valueFrom: + path: /tmp/needs-embeddings + - name: sample-count + valueFrom: + path: /tmp/sample-count + script: + image: python:3.13-slim + command: [python] + source: | + import subprocess + import sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"]) + + import pandas as pd + from pathlib import Path + + dataset_dir = Path("/tmp/dataset") + parquet_files = list(dataset_dir.glob("*.parquet")) + + if not parquet_files: + raise ValueError("No parquet files found in dataset") + + df = pd.read_parquet(parquet_files[0]) + sample_count = len(df) + print(f"Dataset contains {sample_count} samples") + + # Check if embeddings column exists + needs_embeddings = "embedding" not in df.columns + + with open("/tmp/needs-embeddings", "w") as f: + f.write(str(needs_embeddings).lower()) + + with open("/tmp/sample-count", "w") as f: + f.write(str(sample_count)) + + # Trigger KFP pipeline + - name: trigger-kfp + inputs: + parameters: + - name: pipeline-id + - name: params + outputs: + parameters: + - name: run-id + valueFrom: + path: /tmp/run-id + script: + image: python:3.13-slim + command: [python] + source: | + import subprocess + import sys + import json + subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"]) + + from kfp import Client + + client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888") + params = json.loads('''{{inputs.parameters.params}}''') + + run = client.create_run_from_pipeline_func( + pipeline_func=None, + pipeline_id="{{inputs.parameters.pipeline-id}}", + arguments=params + ) + + print(f"Triggered KFP pipeline: {run.run_id}") + with open("/tmp/run-id", "w") as f: + f.write(run.run_id) + + # LoRA training (GPU) + - name: lora-training + inputs: + parameters: + - name: model-name + - name: lora-rank + - name: epochs + - name: batch-size + 
artifacts: + - name: dataset + path: /data/dataset + outputs: + artifacts: + - name: adapter + path: /output/adapter + - name: logs + path: /output/logs + podSpecPatch: | + containers: + - name: main + resources: + requests: + amd.com/gpu: 1 + limits: + amd.com/gpu: 1 + script: + image: ghcr.io/billy-davies-2/lora-trainer:latest + command: [python] + env: + - name: HF_HOME + value: /cache/huggingface + - name: TRANSFORMERS_CACHE + value: /cache/huggingface + volumeMounts: + - name: model-cache + mountPath: /cache + - name: shm + mountPath: /dev/shm + source: | + import os + import sys + from pathlib import Path + + # Training script + from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments + from peft import LoraConfig, get_peft_model + from datasets import load_dataset + from trl import SFTTrainer + + model_name = "{{inputs.parameters.model-name}}" + lora_rank = int("{{inputs.parameters.lora-rank}}") + epochs = int("{{inputs.parameters.epochs}}") + batch_size = int("{{inputs.parameters.batch-size}}") + + # Load dataset + dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train") + + # Load model + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Configure LoRA + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_rank * 2, + target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" + ) + + model = get_peft_model(model, lora_config) + + # Training + training_args = TrainingArguments( + output_dir="/output/adapter", + num_train_epochs=epochs, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=4, + learning_rate=2e-4, + logging_dir="/output/logs", + save_strategy="epoch" + ) + + trainer = SFTTrainer( + model=model, + train_dataset=dataset, + tokenizer=tokenizer, + args=training_args, + dataset_text_field="text" + ) + + trainer.train() + trainer.save_model("/output/adapter") + print("Training complete!") + + # Evaluate model + - name: evaluate-model + inputs: + artifacts: + - name: adapter + path: /input/adapter + outputs: + parameters: + - name: score + valueFrom: + path: /tmp/score + script: + image: python:3.13-slim + command: [python] + source: | + import subprocess + import sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"]) + + import httpx + import json + + # Run evaluation using vLLM with the adapter + test_prompts = [ + "What is the capital of France?", + "Explain machine learning in simple terms.", + "Write a haiku about coding." 
+ ] + + scores = [] + with httpx.Client(timeout=120.0) as client: + for prompt in test_prompts: + response = client.post( + "http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions", + json={ + "model": "local-adapter", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 200 + } + ) + # Simple scoring based on response coherence + result = response.json() + content = result["choices"][0]["message"]["content"] + score = min(1.0, len(content) / 100) # Placeholder scoring + scores.append(score) + + avg_score = sum(scores) / len(scores) + print(f"Average evaluation score: {avg_score}") + + with open("/tmp/score", "w") as f: + f.write(str(round(avg_score, 3))) + + # Quality gate + - name: quality-gate + inputs: + parameters: + - name: eval-score + - name: threshold + outputs: + parameters: + - name: passed + valueFrom: + path: /tmp/passed + script: + image: python:3.13-slim + command: [python] + source: | + score = float("{{inputs.parameters.eval-score}}") + threshold = float("{{inputs.parameters.threshold}}") + + passed = score >= threshold + print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}") + + with open("/tmp/passed", "w") as f: + f.write(str(passed).lower()) + + # Upload to S3 + - name: upload-to-s3 + inputs: + parameters: + - name: output-path + artifacts: + - name: adapter + path: /input/adapter + script: + image: amazon/aws-cli:latest + command: [bash] + env: + - name: AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: s3-credentials + key: access-key + - name: AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: s3-credentials + key: secret-key + - name: AWS_ENDPOINT_URL + value: "https://quobjects.billy.davies.cloud" + source: | + aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/" + echo "Uploaded adapter to {{inputs.parameters.output-path}}" + + # Adjust parameters and retry + - name: adjust-and-retry + inputs: + parameters: + - name: current-rank + - name: current-epochs + script: + image: python:3.13-slim + command: [python] + source: | + current_rank = int("{{inputs.parameters.current-rank}}") + current_epochs = int("{{inputs.parameters.current-epochs}}") + + # Increase rank and epochs for next attempt + new_rank = min(64, current_rank * 2) + new_epochs = current_epochs + 2 + + print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}") + print("TODO: Trigger new workflow with adjusted parameters") + + # NATS notification + - name: nats-notify + inputs: + parameters: + - name: subject + - name: message + script: + image: python:3.13-slim + command: [python] + source: | + import subprocess + import sys + subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"]) + + import asyncio + import nats + + async def notify(): + nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222") + await nc.publish( + "{{inputs.parameters.subject}}", + b'''{{inputs.parameters.message}}''' + ) + await nc.close() + + asyncio.run(notify()) diff --git a/kfp-integration.yaml b/kfp-integration.yaml new file mode 100644 index 0000000..d27b5f1 --- /dev/null +++ b/kfp-integration.yaml @@ -0,0 +1,237 @@ +# Argo Workflows + Kubeflow Pipelines Integration +# This template allows Argo Workflows to trigger KFP pipelines and vice versa +--- +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: kfp-trigger + namespace: ai-ml + labels: + app.kubernetes.io/name: kfp-trigger + app.kubernetes.io/part-of: 
llm-workflows
spec:
  entrypoint: trigger-kfp-pipeline
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: pipeline-id
        description: "Kubeflow Pipeline ID or name"
      - name: pipeline-params
        description: "JSON object of pipeline parameters"
        value: "{}"
      - name: experiment-name
        description: "KFP Experiment to use"
        value: "Default"
      - name: wait-for-completion
        description: "Wait for pipeline to complete"
        value: "true"

  templates:
    - name: trigger-kfp-pipeline
      steps:
        - - name: submit-run
            template: submit-kfp-run
            arguments:
              parameters:
                - name: pipeline-id
                  value: "{{workflow.parameters.pipeline-id}}"
                - name: pipeline-params
                  value: "{{workflow.parameters.pipeline-params}}"
                - name: experiment-name
                  value: "{{workflow.parameters.experiment-name}}"

        - - name: wait-completion
            template: wait-for-kfp
            when: "{{workflow.parameters.wait-for-completion}} == true"
            arguments:
              parameters:
                - name: run-id
                  value: "{{steps.submit-run.outputs.parameters.run-id}}"

    - name: submit-kfp-run
      inputs:
        parameters:
          - name: pipeline-id
          - name: pipeline-params
          - name: experiment-name
      outputs:
        parameters:
          - name: run-id
            valueFrom:
              path: /tmp/run-id
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import json
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])

          from kfp import Client

          KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"

          client = Client(host=KUBEFLOW_HOST)

          pipeline_id = "{{inputs.parameters.pipeline-id}}"
          params = json.loads('''{{inputs.parameters.pipeline-params}}''')
          experiment_name = "{{inputs.parameters.experiment-name}}"

          # Get or create experiment
          try:
              experiment = client.get_experiment(experiment_name=experiment_name)
          except Exception:
              experiment = client.create_experiment(name=experiment_name)

          # Get pipeline by ID, falling back to resolving the value as a pipeline name
          try:
              pipeline = client.get_pipeline(pipeline_id)
          except Exception:
              resolved_id = client.get_pipeline_id(pipeline_id)
              if resolved_id is None:
                  raise ValueError(f"Pipeline not found: {pipeline_id}")
              pipeline = client.get_pipeline(resolved_id)

          # Create run
          run = client.run_pipeline(
              experiment_id=experiment.experiment_id,
              job_name=f"{pipeline.display_name}-argo-{pipeline_id[:8]}",
              pipeline_id=pipeline.pipeline_id,
              params=params
          )

          print(f"Submitted KFP run: {run.run_id}")
          with open("/tmp/run-id", "w") as f:
              f.write(run.run_id)

    - name: wait-for-kfp
      inputs:
        parameters:
          - name: run-id
      outputs:
        parameters:
          - name: status
            valueFrom:
              path: /tmp/status
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          import time
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])

          from kfp import Client

          KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
          run_id = "{{inputs.parameters.run-id}}"

          client = Client(host=KUBEFLOW_HOST)

          while True:
              # KFP 2.x get_run() returns a V2beta1Run; the run state lives on .state,
              # not on .run.status as in the 1.x client
              run = client.get_run(run_id)
              state = run.state

              print(f"Run {run_id} status: {state}")

              if state in ["SUCCEEDED", "SKIPPED"]:
                  with open("/tmp/status", "w") as f:
                      f.write("SUCCEEDED")
                  break
              elif state in ["FAILED", "ERROR", "CANCELED", "CANCELLED"]:
                  with open("/tmp/status", "w") as f:
                      f.write(state)
                  raise Exception(f"Pipeline failed with status: {state}")

              time.sleep(30)

---
# WorkflowTemplate for running KFP pipeline components as
Argo steps +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: kfp-component-runner + namespace: ai-ml + labels: + app.kubernetes.io/name: kfp-component-runner + app.kubernetes.io/part-of: llm-workflows +spec: + entrypoint: run-component + serviceAccountName: argo-workflow + + arguments: + parameters: + - name: component-name + description: "Name of the KFP component to run" + - name: component-params + description: "JSON parameters for the component" + value: "{}" + + templates: + - name: run-component + inputs: + parameters: + - name: component-name + - name: component-params + outputs: + parameters: + - name: result + valueFrom: + path: /tmp/result.json + script: + image: python:3.13-slim + command: [python] + source: | + import json + import subprocess + import sys + subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-q", + "httpx", "pymilvus" + ]) + + import httpx + + component_name = "{{inputs.parameters.component-name}}" + params = json.loads('''{{inputs.parameters.component-params}}''') + + # Component implementations (mirrors KFP components) + COMPONENTS = { + "transcribe_audio": { + "url": "http://whisper-predictor.ai-ml.svc.cluster.local", + "endpoint": "/v1/audio/transcriptions" + }, + "generate_embeddings": { + "url": "http://embeddings-predictor.ai-ml.svc.cluster.local", + "endpoint": "/embeddings" + }, + "generate_response": { + "url": "http://llm-draft.ai-ml.svc.cluster.local:8000", + "endpoint": "/v1/chat/completions" + }, + "synthesize_speech": { + "url": "http://tts-predictor.ai-ml.svc.cluster.local", + "endpoint": "/v1/audio/speech" + } + } + + if component_name not in COMPONENTS: + raise ValueError(f"Unknown component: {component_name}") + + config = COMPONENTS[component_name] + with httpx.Client(timeout=120.0) as client: + response = client.post( + f"{config['url']}{config['endpoint']}", + json=params + ) + result = response.json() + + with open("/tmp/result.json", "w") as f: + json.dump(result, f) + + print(f"Component {component_name} completed") diff --git a/qlora-training.yaml b/qlora-training.yaml new file mode 100644 index 0000000..07b29e9 --- /dev/null +++ b/qlora-training.yaml @@ -0,0 +1,510 @@ +# QLoRA Fine-tuning Workflow +# Trains QLora adapters from a reference model using data from Milvus vector database +# Triggered via NATS: ai.pipeline.trigger with pipeline="qlora-training" +--- +apiVersion: argoproj.io/v1alpha1 +kind: WorkflowTemplate +metadata: + name: qlora-training + namespace: ai-ml + labels: + app.kubernetes.io/name: qlora-training + app.kubernetes.io/part-of: llm-workflows +spec: + entrypoint: train-qlora + serviceAccountName: argo-workflow + + arguments: + parameters: + - name: reference-model + description: "Base model to fine-tune (HuggingFace model ID or path)" + value: "mistralai/Mistral-7B-Instruct-v0.3" + - name: output-name + description: "Name for the output QLora adapter" + value: "qlora-adapter" + - name: milvus-collections + description: "Comma-separated list of Milvus collections to use (empty = all available)" + value: "" + - name: learning-rate + value: "2e-4" + description: "Learning rate for training" + - name: num-epochs + value: "3" + description: "Number of training epochs" + - name: batch-size + value: "4" + description: "Training batch size" + - name: max-seq-length + value: "2048" + description: "Maximum sequence length" + - name: lora-r + value: "64" + description: "LoRA attention dimension" + - name: lora-alpha + value: "16" + description: "LoRA alpha parameter" + - name: 
lora-dropout + value: "0.05" + description: "LoRA dropout rate" + + volumeClaimTemplates: + - metadata: + name: model-storage + spec: + accessModes: ["ReadWriteMany"] + storageClassName: nfs-slow + resources: + requests: + storage: 50Gi + + templates: + - name: train-qlora + dag: + tasks: + - name: fetch-training-data + template: fetch-data + arguments: + parameters: + - name: milvus-collections + value: "{{workflow.parameters.milvus-collections}}" + + - name: prepare-dataset + template: prepare-data + dependencies: [fetch-training-data] + arguments: + parameters: + - name: max-seq-length + value: "{{workflow.parameters.max-seq-length}}" + artifacts: + - name: raw-data + from: "{{tasks.fetch-training-data.outputs.artifacts.raw-data}}" + + - name: train-model + template: train + dependencies: [prepare-dataset] + arguments: + parameters: + - name: reference-model + value: "{{workflow.parameters.reference-model}}" + - name: output-name + value: "{{workflow.parameters.output-name}}" + - name: learning-rate + value: "{{workflow.parameters.learning-rate}}" + - name: num-epochs + value: "{{workflow.parameters.num-epochs}}" + - name: batch-size + value: "{{workflow.parameters.batch-size}}" + - name: max-seq-length + value: "{{workflow.parameters.max-seq-length}}" + - name: lora-r + value: "{{workflow.parameters.lora-r}}" + - name: lora-alpha + value: "{{workflow.parameters.lora-alpha}}" + - name: lora-dropout + value: "{{workflow.parameters.lora-dropout}}" + artifacts: + - name: training-data + from: "{{tasks.prepare-dataset.outputs.artifacts.training-data}}" + + - name: fetch-data + inputs: + parameters: + - name: milvus-collections + outputs: + artifacts: + - name: raw-data + path: /tmp/raw-data + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import subprocess + subprocess.run(["pip", "install", "pymilvus", "-q"], check=True) + + import json + from pathlib import Path + from pymilvus import connections, Collection, utility + + MILVUS_HOST = "milvus.ai-ml.svc.cluster.local" + MILVUS_PORT = 19530 + collections_param = "{{inputs.parameters.milvus-collections}}" + + output_dir = Path("/tmp/raw-data") + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Connecting to Milvus at {MILVUS_HOST}:{MILVUS_PORT}") + connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) + + # Determine which collections to use + if collections_param and collections_param.strip(): + collection_names = [c.strip() for c in collections_param.split(",")] + print(f"Using specified collections: {collection_names}") + else: + # Get all available collections + collection_names = utility.list_collections() + print(f"Using all available collections: {collection_names}") + + all_training_data = [] + + for collection_name in collection_names: + if not utility.has_collection(collection_name): + print(f"Warning: Collection {collection_name} not found, skipping") + continue + + print(f"Fetching data from collection: {collection_name}") + collection = Collection(collection_name) + collection.load() + + # Query all data from the collection + # Note: Adjust field names based on your schema + try: + # Get collection schema to determine fields + schema = collection.schema + field_names = [field.name for field in schema.fields if field.name != "id"] + + # Query all entities (limited to reasonable batch size) + # For large collections, you may want to implement pagination + results = collection.query( + expr="id >= 0", + output_fields=field_names, + limit=100000 + ) + + print(f" Retrieved {len(results)} 
records from {collection_name}") + + for result in results: + # Extract text field (adjust based on your schema) + text_content = result.get("text", "") + source = result.get("source", collection_name) + + if text_content: + all_training_data.append({ + "text": text_content, + "source": source, + "collection": collection_name + }) + + except Exception as e: + print(f"Error querying collection {collection_name}: {e}") + continue + + # Save all training data + output_file = output_dir / "training_data.json" + with open(output_file, "w") as f: + json.dump({"data": all_training_data}, f) + + print(f"Total training samples collected: {len(all_training_data)}") + print(f"Saved to {output_file}") + + connections.disconnect("default") + envFrom: + - configMapRef: + name: ai-services-config + resources: + requests: + memory: 1Gi + cpu: 500m + + - name: prepare-data + inputs: + parameters: + - name: max-seq-length + artifacts: + - name: raw-data + path: /tmp/raw-data + outputs: + artifacts: + - name: training-data + path: /tmp/training-data + container: + image: python:3.13-slim + command: [python] + args: + - -c + - | + import json + from pathlib import Path + + max_seq_length = int("{{inputs.parameters.max-seq-length}}") + + input_dir = Path("/tmp/raw-data") + output_dir = Path("/tmp/training-data") + output_dir.mkdir(parents=True, exist_ok=True) + + # Load raw data + with open(input_dir / "training_data.json") as f: + data = json.load(f) + + raw_samples = data["data"] + print(f"Processing {len(raw_samples)} raw samples") + + # Prepare data in instruction format for fine-tuning + # Using Alpaca-style format: instruction + response + training_samples = [] + + for sample in raw_samples: + text = sample["text"] + source = sample.get("source", "") + + # Create instruction-response pairs + # You can customize this based on your use case + training_sample = { + "instruction": f"Based on the following information from {source}, provide a comprehensive response:", + "input": text[:max_seq_length // 2], # Truncate if needed + "output": text[:max_seq_length // 2], + "source": source + } + training_samples.append(training_sample) + + # Split into train/validation (90/10) + split_idx = int(len(training_samples) * 0.9) + train_data = training_samples[:split_idx] + val_data = training_samples[split_idx:] + + # Save prepared datasets + with open(output_dir / "train.json", "w") as f: + json.dump(train_data, f) + + with open(output_dir / "validation.json", "w") as f: + json.dump(val_data, f) + + print(f"Prepared {len(train_data)} training samples") + print(f"Prepared {len(val_data)} validation samples") + print("Data preparation complete") + resources: + requests: + memory: 2Gi + cpu: 500m + + - name: train + inputs: + parameters: + - name: reference-model + - name: output-name + - name: learning-rate + - name: num-epochs + - name: batch-size + - name: max-seq-length + - name: lora-r + - name: lora-alpha + - name: lora-dropout + artifacts: + - name: training-data + path: /tmp/training-data + container: + image: python:3.13-slim + command: [bash] + args: + - -c + - | + set -e + + echo "Installing dependencies..." + pip install -q torch transformers peft datasets accelerate bitsandbytes scipy + + echo "Starting QLoRA training..." 
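            # The heredoc below runs the full QLoRA fine-tune in-process; Argo substitutes the {{inputs.parameters.*}} placeholders before the container starts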
+ + python << 'EOF' + import json + import os + from pathlib import Path + from datasets import Dataset + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + TrainingArguments, + Trainer, + DataCollatorForLanguageModeling + ) + from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + import torch + + # Parameters + reference_model = "{{inputs.parameters.reference-model}}" + output_name = "{{inputs.parameters.output-name}}" + learning_rate = float("{{inputs.parameters.learning-rate}}") + num_epochs = int("{{inputs.parameters.num-epochs}}") + batch_size = int("{{inputs.parameters.batch-size}}") + max_seq_length = int("{{inputs.parameters.max-seq-length}}") + lora_r = int("{{inputs.parameters.lora-r}}") + lora_alpha = int("{{inputs.parameters.lora-alpha}}") + lora_dropout = float("{{inputs.parameters.lora-dropout}}") + + data_dir = Path("/tmp/training-data") + output_dir = Path("/mnt/model-storage") / output_name + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Training configuration:") + print(f" Model: {reference_model}") + print(f" Learning rate: {learning_rate}") + print(f" Epochs: {num_epochs}") + print(f" Batch size: {batch_size}") + print(f" Max sequence length: {max_seq_length}") + print(f" LoRA r: {lora_r}, alpha: {lora_alpha}, dropout: {lora_dropout}") + + # Load datasets + with open(data_dir / "train.json") as f: + train_data = json.load(f) + + with open(data_dir / "validation.json") as f: + val_data = json.load(f) + + print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples") + + # Load tokenizer + print(f"Loading tokenizer from {reference_model}...") + tokenizer = AutoTokenizer.from_pretrained(reference_model, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Prepare datasets + def format_sample(sample): + instruction = sample.get("instruction", "") + input_text = sample.get("input", "") + output = sample.get("output", "") + + if input_text: + prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}" + else: + prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}" + + return {"text": prompt} + + train_dataset = Dataset.from_list([format_sample(s) for s in train_data]) + val_dataset = Dataset.from_list([format_sample(s) for s in val_data]) + + # Tokenize datasets + def tokenize_function(examples): + return tokenizer( + examples["text"], + truncation=True, + max_length=max_seq_length, + padding="max_length" + ) + + print("Tokenizing datasets...") + train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) + val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) + + # Configure quantization for QLoRA + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + ) + + # Load model + print(f"Loading model {reference_model} with 4-bit quantization...") + model = AutoModelForCausalLM.from_pretrained( + reference_model, + quantization_config=bnb_config, + device_map="auto", + trust_remote_code=True + ) + + # Prepare model for training + model = prepare_model_for_kbit_training(model) + + # Configure LoRA + lora_config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + lora_dropout=lora_dropout, + bias="none", + 
task_type="CAUSAL_LM" + ) + + # Add LoRA adapters + print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + # Training arguments + training_args = TrainingArguments( + output_dir=str(output_dir / "checkpoints"), + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + gradient_accumulation_steps=4, + learning_rate=learning_rate, + fp16=True, + logging_steps=10, + evaluation_strategy="steps", + eval_steps=50, + save_strategy="steps", + save_steps=100, + save_total_limit=3, + load_best_model_at_end=True, + report_to="none", + remove_unused_columns=False, + ) + + # Data collator + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False + ) + + # Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + data_collator=data_collator, + ) + + # Train + print("Starting training...") + trainer.train() + + # Save final model + print(f"Saving QLora adapter to {output_dir}") + model.save_pretrained(str(output_dir / "final")) + tokenizer.save_pretrained(str(output_dir / "final")) + + # Save training metadata + metadata = { + "reference_model": reference_model, + "output_name": output_name, + "training_params": { + "learning_rate": learning_rate, + "num_epochs": num_epochs, + "batch_size": batch_size, + "max_seq_length": max_seq_length, + "lora_r": lora_r, + "lora_alpha": lora_alpha, + "lora_dropout": lora_dropout + }, + "dataset_info": { + "train_samples": len(train_data), + "val_samples": len(val_data) + } + } + + with open(output_dir / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + print("Training complete!") + print(f"QLora adapter saved to: {output_dir}") + EOF + envFrom: + - configMapRef: + name: ai-services-config + volumeMounts: + - name: model-storage + mountPath: /mnt/model-storage + resources: + requests: + memory: 16Gi + cpu: 4 + nvidia.com/gpu: 1 + limits: + memory: 32Gi + cpu: 8 + nvidia.com/gpu: 1