feat: Add ML training and batch inference workflows

- batch-inference: LLM inference with optional RAG
- qlora-training: QLoRA adapter fine-tuning from Milvus
- hybrid-ml-training: Multi-GPU distributed training
- coqui-voice-training: XTTS voice cloning
- document-ingestion: Ingest documents to Milvus
- eventsource-kfp: Argo Events / Kubeflow integration
- kfp-integration: Bridge between Argo and Kubeflow
commit 7104698eee
parent a8fc72dd0b
2026-02-01 20:39:42 -05:00
8 changed files with 3365 additions and 1 deletion

README.md

@@ -1,2 +1,128 @@
# Argo Workflows
ML training and batch inference workflows for the DaviesTechLabs AI/ML platform.
## Workflows
| Workflow | Description | Trigger |
|----------|-------------|---------|
| `batch-inference` | Run LLM inference on batch inputs | `ai.pipeline.trigger` (pipeline="batch-inference") |
| `qlora-training` | Train QLoRA adapters from Milvus data | `ai.pipeline.trigger` (pipeline="qlora-training") |
| `hybrid-ml-training` | Multi-GPU distributed training | `ai.pipeline.trigger` (pipeline="hybrid-ml-training") |
| `coqui-voice-training` | XTTS voice cloning/training | `ai.pipeline.trigger` (pipeline="coqui-voice-training") |
| `document-ingestion` | Ingest documents into Milvus | `ai.pipeline.trigger` (pipeline="document-ingestion") |
## Integration
| File | Description |
|------|-------------|
| `eventsource-kfp.yaml` | Argo Events source for Kubeflow Pipelines integration |
| `kfp-integration.yaml` | Bridge workflows between Argo and Kubeflow |
## Architecture
```
NATS (ai.pipeline.trigger)
         │
         ▼
┌─────────────────┐
│   Argo Events   │
│   EventSource   │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│   Argo Sensor   │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│ WorkflowTemplate│
│   (batch-inf,   │
│   qlora, etc)   │
└─────────────────┘
         │
         ├──▶ GPU Pods (AMD ROCm / NVIDIA CUDA)
         ├──▶ Milvus Vector DB
         ├──▶ vLLM / Ray Serve
         └──▶ MLflow Tracking
```
## Workflow Details
### batch-inference
Batch LLM inference with optional RAG:
```bash
argo submit batch-inference.yaml \
-p input-url="s3://bucket/inputs.json" \
-p output-url="s3://bucket/outputs.json" \
-p use-rag="true" \
-p max-tokens="500"
```
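The input file must contain a `requests` array; each request carries an `id` plus a `text` (or `query`) field, which is what the fetch and inference steps validate against. An illustrative `inputs.json`:
```json
{
  "requests": [
    {"id": "q1", "text": "What is Milvus?"},
    {"id": "q2", "text": "Summarize the homelab architecture."}
  ]
}
```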
### qlora-training
Fine-tune QLoRA adapters from Milvus knowledge:
```bash
argo submit qlora-training.yaml \
-p reference-model="mistralai/Mistral-7B-Instruct-v0.3" \
-p output-name="my-adapter" \
-p milvus-collections="docs,wiki" \
-p num-epochs="3"
```
### coqui-voice-training
Train XTTS voice models:
```bash
argo submit coqui-voice-training.yaml \
-p voice-name="my-voice" \
-p audio-samples-url="s3://bucket/samples/"
```
### document-ingestion
Ingest documents into Milvus:
```bash
argo submit document-ingestion.yaml \
-p source-url="s3://bucket/docs/" \
-p collection="knowledge_base" \
-p chunk-size="512"
```
## NATS Trigger Format
Workflows are triggered via NATS `ai.pipeline.trigger`:
```json
{
"pipeline": "qlora-training",
"parameters": {
"reference-model": "mistralai/Mistral-7B-Instruct-v0.3",
"output-name": "custom-adapter",
"num-epochs": "5"
}
}
```
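Any in-cluster service can publish this payload with `nats-py`; a minimal sketch, assuming the same NATS endpoint the Argo Events sources subscribe to:
```python
import asyncio
import json

import nats  # pip install nats-py


async def trigger_pipeline() -> None:
    # In-cluster NATS endpoint used by the Argo Events sources
    nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
    payload = {
        "pipeline": "qlora-training",
        "parameters": {
            "reference-model": "mistralai/Mistral-7B-Instruct-v0.3",
            "output-name": "custom-adapter",
            "num-epochs": "5",
        },
    }
    await nc.publish("ai.pipeline.trigger", json.dumps(payload).encode())
    await nc.flush()
    await nc.close()


asyncio.run(trigger_pipeline())
```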
## GPU Scheduling
Workflows use node affinity for GPU allocation:
| Node | GPU | Best For |
|------|-----|----------|
| khelben | AMD Strix Halo 64GB | Large model training, vLLM |
| elminster | NVIDIA RTX 2070 | Whisper, XTTS |
| drizzt | AMD Radeon 680M | Embeddings |
| danilo | Intel Arc | Reranker |
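Pinning a template to one of these nodes uses standard Kubernetes node affinity; a minimal sketch, assuming the default `kubernetes.io/hostname` label (the actual templates may combine this with GPU resource requests, as `lora-training` does with `amd.com/gpu`):

```yaml
# Illustrative: pin a training step to khelben by hostname
affinity:
  nodeAffinity:
    requiredDuringSchedulingIgnoredDuringExecution:
      nodeSelectorTerms:
        - matchExpressions:
            - key: kubernetes.io/hostname
              operator: In
              values: [khelben]
```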
## Related
- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images
- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Handler library

batch-inference.yaml

@@ -0,0 +1,328 @@
# Batch Inference Workflow
# Runs LLM inference on a batch of inputs
# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: batch-inference
namespace: ai-ml
labels:
app.kubernetes.io/name: batch-inference
app.kubernetes.io/part-of: llm-workflows
spec:
entrypoint: batch-inference
serviceAccountName: argo-workflow
arguments:
parameters:
- name: input-url
description: "URL to JSON file with inference requests"
- name: output-url
description: "URL to store results (S3 path)"
value: ""
- name: use-rag
value: "true"
description: "Whether to use RAG for context"
- name: max-tokens
value: "500"
description: "Maximum tokens per response"
- name: temperature
value: "0.7"
description: "LLM temperature"
templates:
- name: batch-inference
dag:
tasks:
- name: fetch-inputs
template: fetch-input-data
arguments:
parameters:
- name: input-url
value: "{{workflow.parameters.input-url}}"
- name: run-inference
template: inference
dependencies: [fetch-inputs]
arguments:
parameters:
- name: use-rag
value: "{{workflow.parameters.use-rag}}"
- name: max-tokens
value: "{{workflow.parameters.max-tokens}}"
- name: temperature
value: "{{workflow.parameters.temperature}}"
artifacts:
- name: inputs
from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}"
- name: upload-results
template: upload-output
dependencies: [run-inference]
when: "{{workflow.parameters.output-url}} != ''"
arguments:
parameters:
- name: output-url
value: "{{workflow.parameters.output-url}}"
artifacts:
- name: results
from: "{{tasks.run-inference.outputs.artifacts.results}}"
- name: fetch-input-data
inputs:
parameters:
- name: input-url
outputs:
artifacts:
- name: inputs
path: /tmp/inputs
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import json
import urllib.request
from pathlib import Path
input_url = "{{inputs.parameters.input-url}}"
output_dir = Path("/tmp/inputs")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Fetching inputs from: {input_url}")
if input_url.startswith("s3://"):
import subprocess
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
import boto3
s3 = boto3.client("s3")
bucket, key = input_url[5:].split("/", 1)
s3.download_file(bucket, key, str(output_dir / "inputs.json"))
elif input_url.startswith("http"):
urllib.request.urlretrieve(input_url, output_dir / "inputs.json")
else:
print(f"Unsupported URL scheme: {input_url}")
exit(1)
# Validate JSON structure
with open(output_dir / "inputs.json") as f:
data = json.load(f)
if "requests" not in data:
print("Error: JSON must contain 'requests' array")
exit(1)
print(f"Loaded {len(data['requests'])} inference requests")
resources:
requests:
memory: 256Mi
cpu: 100m
- name: inference
inputs:
parameters:
- name: use-rag
- name: max-tokens
- name: temperature
artifacts:
- name: inputs
path: /tmp/inputs
outputs:
artifacts:
- name: results
path: /tmp/results
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import subprocess
subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True)
import json
import httpx
from pathlib import Path
from typing import List, Dict
# Configuration
VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000"
EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local"
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
use_rag = "{{inputs.parameters.use-rag}}" == "true"
max_tokens = int("{{inputs.parameters.max-tokens}}")
temperature = float("{{inputs.parameters.temperature}}")
input_dir = Path("/tmp/inputs")
output_dir = Path("/tmp/results")
output_dir.mkdir(parents=True, exist_ok=True)
# Load inputs
with open(input_dir / "inputs.json") as f:
data = json.load(f)
requests = data["requests"]
print(f"Processing {len(requests)} requests (RAG: {use_rag})")
# Initialize Milvus if using RAG
collection = None
if use_rag:
try:
from pymilvus import connections, Collection, utility
connections.connect(host=MILVUS_HOST, port=19530)
if utility.has_collection("knowledge_base"):
collection = Collection("knowledge_base")
collection.load()
print("Milvus connected")
except Exception as e:
print(f"Milvus connection failed: {e}")
use_rag = False
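# RAG helpers: embed the query, vector-search Milvus for candidates,
# rerank them, and keep the top 3 passages as prompt context.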
def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]:
response = client.post(
f"{EMBEDDINGS_URL}/embeddings",
json={"input": texts, "model": "bge"}
)
result = response.json()
return [d["embedding"] for d in result.get("data", [])]
def search_milvus(embedding: List[float]) -> List[Dict]:
results = collection.search(
data=[embedding],
anns_field="embedding",
param={"metric_type": "COSINE", "params": {"ef": 64}},
limit=5,
output_fields=["text", "source"]
)
docs = []
for hits in results:
for hit in hits:
docs.append({
"text": hit.entity.get("text", ""),
"source": hit.entity.get("source", ""),
"score": hit.score
})
return docs
def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]:
response = client.post(
f"{RERANKER_URL}/v1/rerank",
json={"query": query, "documents": documents}
)
return response.json().get("results", [])
# Process requests
results = []
with httpx.Client(timeout=120.0) as client:
for i, req in enumerate(requests):
query = req.get("text", req.get("query", ""))
req_id = req.get("id", str(i))
print(f"Processing {i+1}/{len(requests)}: {query[:50]}...")
context = ""
rag_sources = []
if use_rag and collection:
try:
# Get embeddings and search
embeddings = get_embeddings([query], client)
if embeddings:
docs = search_milvus(embeddings[0])
if docs:
doc_texts = [d["text"] for d in docs]
reranked = rerank(query, doc_texts, client)
sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3]
context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs])
rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs]
except Exception as e:
print(f" RAG failed: {e}")
# Generate response
try:
messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
if context:
messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
else:
messages.append({"role": "user", "content": query})
response = client.post(
f"{VLLM_URL}/v1/chat/completions",
json={
"model": LLM_MODEL,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature
}
)
result = response.json()
answer = result["choices"][0]["message"]["content"]
except Exception as e:
answer = f"Error: {e}"
results.append({
"id": req_id,
"query": query,
"response": answer,
"used_rag": bool(context),
"rag_sources": rag_sources
})
# Save results
with open(output_dir / "results.json", "w") as f:
json.dump({"results": results}, f, indent=2)
print(f"Completed {len(results)} inferences")
if collection:
from pymilvus import connections
connections.disconnect("default")
envFrom:
- configMapRef:
name: ai-services-config
resources:
requests:
memory: 1Gi
cpu: 500m
- name: upload-output
inputs:
parameters:
- name: output-url
artifacts:
- name: results
path: /tmp/results
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import subprocess
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
import boto3
from pathlib import Path
output_url = "{{inputs.parameters.output-url}}"
results_file = Path("/tmp/results/results.json")
print(f"Uploading results to: {output_url}")
if output_url.startswith("s3://"):
s3 = boto3.client("s3")
bucket, key = output_url[5:].split("/", 1)
s3.upload_file(str(results_file), bucket, key)
print("Upload complete")
else:
print(f"Unsupported URL scheme: {output_url}")
exit(1)
resources:
requests:
memory: 256Mi
cpu: 100m

coqui-voice-training.yaml

@@ -0,0 +1,969 @@
# Coqui TTS Voice Training Workflow
# Trains a custom voice model using Coqui TTS from audio samples
# Triggered via NATS: ai.pipeline.trigger with pipeline="coqui-voice-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: coqui-voice-training
namespace: ai-ml
labels:
app.kubernetes.io/name: coqui-voice-training
app.kubernetes.io/part-of: llm-workflows
spec:
entrypoint: train-voice
serviceAccountName: argo-workflow
arguments:
parameters:
- name: audio-source
description: "URL to audio files (S3 bucket, HTTP, or NFS path with .wav/.mp3 files)"
- name: transcripts-source
description: "URL to transcripts file (CSV with audio_file,transcript columns) - leave empty to auto-transcribe"
value: ""
- name: voice-name
description: "Name for the trained voice model"
value: "custom-voice"
- name: base-model
description: "Base TTS model to fine-tune from"
value: "tts_models/en/ljspeech/vits"
- name: language
description: "Language code (e.g., en, de, fr, es)"
value: "en"
- name: num-epochs
description: "Number of training epochs"
value: "100"
- name: batch-size
description: "Training batch size"
value: "16"
- name: learning-rate
description: "Learning rate for training"
value: "0.0001"
- name: sample-rate
description: "Target sample rate for audio (Hz)"
value: "22050"
- name: output-path
description: "Path to store the trained model (S3 or NFS)"
value: "/models/tts/custom"
volumeClaimTemplates:
- metadata:
name: training-workspace
spec:
accessModes: ["ReadWriteMany"]
storageClassName: nfs-slow
resources:
requests:
storage: 50Gi
templates:
- name: train-voice
dag:
tasks:
- name: fetch-audio
template: fetch-audio-files
arguments:
parameters:
- name: audio-source
value: "{{workflow.parameters.audio-source}}"
- name: fetch-transcripts
template: fetch-transcript-file
arguments:
parameters:
- name: transcripts-source
value: "{{workflow.parameters.transcripts-source}}"
- name: preprocess-audio
template: preprocess
dependencies: [fetch-audio]
arguments:
parameters:
- name: sample-rate
value: "{{workflow.parameters.sample-rate}}"
artifacts:
- name: raw-audio
from: "{{tasks.fetch-audio.outputs.artifacts.audio-files}}"
- name: generate-transcripts
template: transcribe-audio
dependencies: [preprocess-audio, fetch-transcripts]
when: "{{workflow.parameters.transcripts-source}} == ''"
arguments:
parameters:
- name: language
value: "{{workflow.parameters.language}}"
artifacts:
- name: audio-files
from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"
- name: prepare-dataset
template: prepare-coqui-dataset
depends: "preprocess-audio && fetch-transcripts && (generate-transcripts.Succeeded || generate-transcripts.Skipped)"
arguments:
parameters:
- name: voice-name
value: "{{workflow.parameters.voice-name}}"
- name: language
value: "{{workflow.parameters.language}}"
artifacts:
- name: audio-files
from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"
- name: transcripts
from: "{{=workflow.parameters.transcriptsSource != '' ? tasks.fetch-transcripts.outputs.artifacts.transcripts : tasks.generate-transcripts.outputs.artifacts.transcripts}}"
optional: true
- name: train-model
template: train-tts
dependencies: [prepare-dataset]
arguments:
parameters:
- name: voice-name
value: "{{workflow.parameters.voice-name}}"
- name: base-model
value: "{{workflow.parameters.base-model}}"
- name: language
value: "{{workflow.parameters.language}}"
- name: num-epochs
value: "{{workflow.parameters.num-epochs}}"
- name: batch-size
value: "{{workflow.parameters.batch-size}}"
- name: learning-rate
value: "{{workflow.parameters.learning-rate}}"
artifacts:
- name: dataset
from: "{{tasks.prepare-dataset.outputs.artifacts.dataset}}"
- name: export-model
template: export-trained-model
dependencies: [train-model]
arguments:
parameters:
- name: voice-name
value: "{{workflow.parameters.voice-name}}"
- name: output-path
value: "{{workflow.parameters.output-path}}"
artifacts:
- name: trained-model
from: "{{tasks.train-model.outputs.artifacts.model}}"
# Template: Fetch audio files from source
- name: fetch-audio-files
inputs:
parameters:
- name: audio-source
outputs:
artifacts:
- name: audio-files
path: /tmp/audio
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import os
import subprocess
import urllib.request
from pathlib import Path
import shutil
source_url = "{{inputs.parameters.audio-source}}"
output_dir = Path("/tmp/audio")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Fetching audio from: {source_url}")
if source_url.startswith("s3://"):
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
import boto3
s3 = boto3.client("s3")
bucket, prefix = source_url[5:].split("/", 1)
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
for obj in response.get("Contents", []):
key = obj["Key"]
if Path(key).suffix.lower() in audio_extensions:
local_path = output_dir / Path(key).name
s3.download_file(bucket, key, str(local_path))
print(f"Downloaded: {key}")
elif source_url.startswith("http"):
# Handle single file or directory listing
filename = source_url.split("/")[-1]
if any(ext in filename.lower() for ext in [".wav", ".mp3", ".flac", ".zip"]):
local_path = output_dir / filename
urllib.request.urlretrieve(source_url, local_path)
print(f"Downloaded: {filename}")
# Extract if zip
if filename.endswith(".zip"):
shutil.unpack_archive(local_path, output_dir)
os.remove(local_path)
print("Extracted zip archive")
else:
print(f"URL doesn't appear to be an audio file: {source_url}")
exit(1)
elif source_url.startswith("/"):
# Local/NFS path
src_path = Path(source_url)
if src_path.is_dir():
audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
for f in src_path.iterdir():
if f.suffix.lower() in audio_extensions:
shutil.copy(f, output_dir / f.name)
print(f"Copied: {f.name}")
elif src_path.is_file():
shutil.copy(src_path, output_dir / src_path.name)
else:
print(f"Path not found: {source_url}")
exit(1)
else:
print(f"Unsupported source: {source_url}")
exit(1)
# Count files
audio_files = list(output_dir.glob("*"))
print(f"Total audio files: {len(audio_files)}")
if len(audio_files) == 0:
print("Error: No audio files found!")
exit(1)
resources:
requests:
memory: 512Mi
cpu: 200m
# Template: Fetch transcripts file
- name: fetch-transcript-file
inputs:
parameters:
- name: transcripts-source
outputs:
artifacts:
- name: transcripts
path: /tmp/transcripts
optional: true
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import os
import subprocess
import urllib.request
from pathlib import Path
import shutil
source_url = "{{inputs.parameters.transcripts-source}}"
output_dir = Path("/tmp/transcripts")
output_dir.mkdir(parents=True, exist_ok=True)
if not source_url or source_url.strip() == "":
print("No transcripts source provided - will auto-transcribe")
# Create empty placeholder
(output_dir / "placeholder.txt").write_text("auto-transcribe")
exit(0)
print(f"Fetching transcripts from: {source_url}")
if source_url.startswith("s3://"):
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
import boto3
s3 = boto3.client("s3")
bucket, key = source_url[5:].split("/", 1)
local_path = output_dir / Path(key).name
s3.download_file(bucket, key, str(local_path))
print(f"Downloaded: {key}")
elif source_url.startswith("http"):
filename = source_url.split("/")[-1] or "transcripts.csv"
local_path = output_dir / filename
urllib.request.urlretrieve(source_url, local_path)
print(f"Downloaded: {filename}")
elif source_url.startswith("/"):
src_path = Path(source_url)
if src_path.is_file():
shutil.copy(src_path, output_dir / src_path.name)
print(f"Copied: {src_path.name}")
else:
print(f"File not found: {source_url}")
exit(1)
else:
print(f"Unsupported source: {source_url}")
exit(1)
resources:
requests:
memory: 256Mi
cpu: 100m
# Template: Preprocess audio files
- name: preprocess
inputs:
parameters:
- name: sample-rate
artifacts:
- name: raw-audio
path: /tmp/raw-audio
outputs:
artifacts:
- name: processed-audio
path: /tmp/processed-audio
container:
image: python:3.13-slim
command: [bash]
args:
- -c
- |
set -e
# Install ffmpeg and dependencies
apt-get update && apt-get install -y ffmpeg > /dev/null 2>&1
pip install -q pydub numpy soundfile
python3 << 'EOF'
import os
from pathlib import Path
from pydub import AudioSegment
import soundfile as sf
SAMPLE_RATE = int("{{inputs.parameters.sample-rate}}")
input_dir = Path("/tmp/raw-audio")
output_dir = Path("/tmp/processed-audio")
output_dir.mkdir(parents=True, exist_ok=True)
audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
for audio_file in input_dir.iterdir():
if audio_file.suffix.lower() not in audio_extensions:
continue
print(f"Processing: {audio_file.name}")
try:
# Load audio
audio = AudioSegment.from_file(str(audio_file))
# Convert to mono if stereo
if audio.channels > 1:
audio = audio.set_channels(1)
# Resample to target sample rate
audio = audio.set_frame_rate(SAMPLE_RATE)
# Normalize audio
audio = audio.normalize()
# Export as WAV
output_file = output_dir / f"{audio_file.stem}.wav"
audio.export(str(output_file), format="wav")
print(f" -> Saved: {output_file.name}")
except Exception as e:
print(f" -> Error processing {audio_file.name}: {e}")
continue
processed_files = list(output_dir.glob("*.wav"))
print(f"\nProcessed {len(processed_files)} audio files")
if len(processed_files) == 0:
print("Error: No files were successfully processed!")
exit(1)
EOF
resources:
requests:
memory: 2Gi
cpu: "1"
# Template: Auto-transcribe audio using Coqui STT
- name: transcribe-audio
inputs:
parameters:
- name: language
artifacts:
- name: audio-files
path: /tmp/audio
outputs:
artifacts:
- name: transcripts
path: /tmp/transcripts
container:
image: ghcr.io/coqui-ai/stt:latest
command: [bash]
args:
- -c
- |
set -e
# Install additional dependencies
pip install -q numpy scipy
python3 << 'EOF'
import csv
import os
import wave
import numpy as np
from pathlib import Path
from stt import Model
LANGUAGE = "{{inputs.parameters.language}}"
input_dir = Path("/tmp/audio")
output_dir = Path("/tmp/transcripts")
output_dir.mkdir(parents=True, exist_ok=True)
# Model paths - Coqui STT models are typically pre-installed in the container
# or can be downloaded from https://coqui.ai/models
MODEL_DIR = Path("/models/stt")
# Try to find model files
model_file = None
scorer_file = None
# Check for language-specific models
lang_model_dir = MODEL_DIR / LANGUAGE
if lang_model_dir.exists():
for f in lang_model_dir.glob("*.tflite"):
model_file = f
for f in lang_model_dir.glob("*.scorer"):
scorer_file = f
# Fallback to default English model location
if model_file is None:
default_paths = [
MODEL_DIR / "model.tflite",
Path("/usr/share/stt/model.tflite"),
Path("/opt/stt/model.tflite"),
]
for p in default_paths:
if p.exists():
model_file = p
break
if model_file is None:
# Download model if not found
print("Downloading Coqui STT model...")
import urllib.request
import tarfile
model_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.tflite"
scorer_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.scorer"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
model_file = MODEL_DIR / "model.tflite"
scorer_file = MODEL_DIR / "model.scorer"
urllib.request.urlretrieve(model_url, model_file)
urllib.request.urlretrieve(scorer_url, scorer_file)
print("Model downloaded successfully")
print(f"Loading Coqui STT model: {model_file}")
model = Model(str(model_file))
if scorer_file and scorer_file.exists():
print(f"Loading scorer: {scorer_file}")
model.enableExternalScorer(str(scorer_file))
transcripts = []
for audio_file in sorted(input_dir.glob("*.wav")):
print(f"Transcribing: {audio_file.name}")
try:
# Read WAV file
with wave.open(str(audio_file), 'rb') as w:
sample_rate = w.getframerate()
frames = w.getnframes()
audio_data = w.readframes(frames)
# Convert to int16 array
audio = np.frombuffer(audio_data, dtype=np.int16)
# Resample if needed (Coqui STT expects 16kHz)
if sample_rate != 16000:
from scipy import signal
audio = signal.resample(audio, int(len(audio) * 16000 / sample_rate))
audio = audio.astype(np.int16)
# Run inference
text = model.stt(audio)
transcripts.append({
"audio_file": audio_file.name,
"transcript": text
})
print(f" -> {text[:100] if text else '(empty)'}...")
except Exception as e:
print(f" -> Error: {e}")
continue
# Write CSV
csv_file = output_dir / "transcripts.csv"
with open(csv_file, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=["audio_file", "transcript"])
writer.writeheader()
writer.writerows(transcripts)
print(f"\nTranscribed {len(transcripts)} files")
print(f"Saved to: {csv_file}")
EOF
resources:
requests:
memory: 4Gi
cpu: "2"
limits:
memory: 8Gi
cpu: "4"
# Template: Prepare dataset in Coqui TTS format
- name: prepare-coqui-dataset
inputs:
parameters:
- name: voice-name
- name: language
artifacts:
- name: audio-files
path: /tmp/audio
- name: transcripts
path: /tmp/transcripts
optional: true
outputs:
artifacts:
- name: dataset
path: /tmp/dataset
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import csv
import json
import os
import shutil
from pathlib import Path
VOICE_NAME = "{{inputs.parameters.voice-name}}"
LANGUAGE = "{{inputs.parameters.language}}"
audio_dir = Path("/tmp/audio")
transcripts_dir = Path("/tmp/transcripts")
output_dir = Path("/tmp/dataset")
wavs_dir = output_dir / "wavs"
wavs_dir.mkdir(parents=True, exist_ok=True)
print(f"Preparing Coqui TTS dataset for voice: {VOICE_NAME}")
# Find transcripts file
transcripts_file = None
for f in transcripts_dir.glob("*.csv"):
transcripts_file = f
break
if transcripts_file is None:
# Check for .txt files (simple format: filename|text)
for f in transcripts_dir.glob("*.txt"):
if f.name != "placeholder.txt":
transcripts_file = f
break
if transcripts_file is None:
print("Error: No transcripts file found!")
exit(1)
print(f"Using transcripts: {transcripts_file}")
# Parse transcripts
transcripts = {}
if transcripts_file.suffix == ".csv":
with open(transcripts_file, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
# Handle various column name conventions
audio = row.get("audio_file") or row.get("audio") or row.get("file") or row.get("wav")
text = row.get("transcript") or row.get("text") or row.get("sentence")
if audio and text:
transcripts[audio] = text.strip()
else:
# Simple pipe-separated format: filename|text
with open(transcripts_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if "|" in line:
parts = line.split("|", 1)
if len(parts) == 2:
transcripts[parts[0]] = parts[1]
print(f"Loaded {len(transcripts)} transcripts")
# Copy audio files and create metadata
metadata_lines = []
for audio_file in sorted(audio_dir.glob("*.wav")):
# Try to match transcript
text = None
for key in [audio_file.name, audio_file.stem, audio_file.stem + ".wav"]:
if key in transcripts:
text = transcripts[key]
break
if text is None:
print(f"Warning: No transcript for {audio_file.name}, skipping")
continue
# Copy audio file
dest_file = wavs_dir / audio_file.name
shutil.copy(audio_file, dest_file)
# Add to metadata (LJSpeech format: filename|text|text)
# Coqui uses: audio_file|text|text (normalized text optional)
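# e.g. "sample_001|Hello there.|Hello there."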
metadata_lines.append(f"{audio_file.stem}|{text}|{text}")
# Write metadata.csv
metadata_file = output_dir / "metadata.csv"
with open(metadata_file, "w", encoding="utf-8") as f:
f.write("\n".join(metadata_lines))
print(f"Created dataset with {len(metadata_lines)} samples")
# Create dataset config
config = {
"name": VOICE_NAME,
"language": LANGUAGE,
"num_samples": len(metadata_lines),
"format": "ljspeech"
}
with open(output_dir / "dataset_config.json", "w") as f:
json.dump(config, f, indent=2)
print(f"Dataset ready at: {output_dir}")
if len(metadata_lines) < 10:
print("Warning: Very small dataset! Recommend at least 100+ samples for good results.")
resources:
requests:
memory: 1Gi
cpu: 500m
# Template: Train Coqui TTS model
- name: train-tts
inputs:
parameters:
- name: voice-name
- name: base-model
- name: language
- name: num-epochs
- name: batch-size
- name: learning-rate
artifacts:
- name: dataset
path: /tmp/dataset
outputs:
artifacts:
- name: model
path: /tmp/output
container:
image: ghcr.io/coqui-ai/tts:latest
command: [bash]
args:
- -c
- |
set -e
VOICE_NAME="{{inputs.parameters.voice-name}}"
BASE_MODEL="{{inputs.parameters.base-model}}"
LANGUAGE="{{inputs.parameters.language}}"
NUM_EPOCHS="{{inputs.parameters.num-epochs}}"
BATCH_SIZE="{{inputs.parameters.batch-size}}"
LEARNING_RATE="{{inputs.parameters.learning-rate}}"
DATASET_DIR="/tmp/dataset"
OUTPUT_DIR="/tmp/output"
mkdir -p "$OUTPUT_DIR"
echo "=== Coqui TTS Voice Training ==="
echo "Voice Name: $VOICE_NAME"
echo "Base Model: $BASE_MODEL"
echo "Language: $LANGUAGE"
echo "Epochs: $NUM_EPOCHS"
echo "Batch Size: $BATCH_SIZE"
echo "Learning Rate: $LEARNING_RATE"
echo ""
# Download base model if specified for fine-tuning
RESTORE_PATH=""
if [ "$BASE_MODEL" != "" ] && [ "$BASE_MODEL" != "none" ]; then
echo "Downloading base model for fine-tuning: $BASE_MODEL"
# Use tts to download the model and get its path
MODEL_PATH=$(python3 -c "
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from pathlib import Path
import os
model_name = '$BASE_MODEL'
manager = ModelManager()
# Download the model
model_path, config_path, _ = manager.download_model(model_name)
print(model_path)
")
RESTORE_PATH="$MODEL_PATH"
echo "Base model path: $RESTORE_PATH"
fi
# Create and run training script following Coqui docs pattern
python3 << EOF
import os
from pathlib import Path
# Trainer: Where the magic happens
from trainer import Trainer, TrainerArgs
# Model configs
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
# Paths
DATASET_DIR = Path("$DATASET_DIR")
OUTPUT_DIR = Path("$OUTPUT_DIR")
RESTORE_PATH = "$RESTORE_PATH" if "$RESTORE_PATH" else None
print(f"Dataset: {DATASET_DIR}")
print(f"Output: {OUTPUT_DIR}")
print(f"Restore from: {RESTORE_PATH}")
# Define dataset config (LJSpeech format)
dataset_config = BaseDatasetConfig(
formatter="ljspeech",
meta_file_train="metadata.csv",
path=str(DATASET_DIR),
language="$LANGUAGE",
)
# Initialize training configuration
config = VitsConfig(
run_name="$VOICE_NAME",
output_path=str(OUTPUT_DIR),
datasets=[dataset_config],
batch_size=int("$BATCH_SIZE"),
eval_batch_size=max(1, int("$BATCH_SIZE") // 2),
num_loader_workers=4,
num_eval_loader_workers=2,
run_eval=True,
test_delay_epochs=5,
epochs=int("$NUM_EPOCHS"),
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="$LANGUAGE",
phoneme_cache_path=str(OUTPUT_DIR / "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
print_eval=False,
mixed_precision=True,
save_step=500,
save_n_checkpoints=3,
save_best_after=1000,
lr=float("$LEARNING_RATE"),
# Audio settings for typical voice cloning
audio={
"sample_rate": 22050,
"resample": True,
"do_trim_silence": True,
"trim_db": 45,
},
)
# Initialize the audio processor
# Used for feature extraction and audio I/O
ap = AudioProcessor.init_from_config(config)
# Initialize the tokenizer
# Converts text to sequences of token IDs
tokenizer, config = TTSTokenizer.init_from_config(config)
# Load data samples
# Each sample is [text, audio_file_path, speaker_name]
train_samples, eval_samples = load_tts_samples(
dataset_config,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
print(f"Training samples: {len(train_samples)}")
print(f"Eval samples: {len(eval_samples)}")
# Initialize the model
model = Vits(config, ap, tokenizer, speaker_manager=None)
# Set up trainer arguments
trainer_args = TrainerArgs(
restore_path=RESTORE_PATH,
skip_train_epoch=False,
)
# Initialize the trainer
trainer = Trainer(
trainer_args,
config,
output_path=str(OUTPUT_DIR),
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
# Start training
print("\n" + "=" * 50)
print("Starting training...")
print("=" * 50 + "\n")
trainer.fit()
print("\n" + "=" * 50)
print("Training complete!")
print("=" * 50)
EOF
echo ""
echo "Training complete!"
echo "Model saved to: $OUTPUT_DIR"
ls -la "$OUTPUT_DIR"
resources:
requests:
memory: 16Gi
cpu: "4"
nvidia.com/gpu: "1"
limits:
memory: 32Gi
cpu: "8"
nvidia.com/gpu: "1"
volumeMounts:
- name: training-workspace
mountPath: /tmp/workspace
# Template: Export trained model
- name: export-trained-model
inputs:
parameters:
- name: voice-name
- name: output-path
artifacts:
- name: trained-model
path: /tmp/trained-model
outputs:
artifacts:
- name: exported-model
path: /tmp/exported
container:
image: python:3.13-slim
command: [bash]
args:
- -c
- |
set -e
pip install -q boto3
python3 << 'EOF'
import json
import os
import shutil
import subprocess
from pathlib import Path
from datetime import datetime
VOICE_NAME = "{{inputs.parameters.voice-name}}"
OUTPUT_PATH = "{{inputs.parameters.output-path}}"
model_dir = Path("/tmp/trained-model")
export_dir = Path("/tmp/exported")
export_dir.mkdir(parents=True, exist_ok=True)
print(f"Exporting trained model: {VOICE_NAME}")
print(f"Target path: {OUTPUT_PATH}")
# Find best checkpoint
checkpoints = list(model_dir.glob("best_model*.pth")) + list(model_dir.glob("checkpoint_*.pth"))
if not checkpoints:
checkpoints = list(model_dir.glob("*.pth"))
if not checkpoints:
print("Error: No model checkpoints found!")
exit(1)
# Sort by modification time and get newest
checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
best_checkpoint = checkpoints[0]
print(f"Using checkpoint: {best_checkpoint.name}")
# Create export package
package_dir = export_dir / VOICE_NAME
package_dir.mkdir(parents=True, exist_ok=True)
# Copy model files
shutil.copy(best_checkpoint, package_dir / "model.pth")
# Copy config if exists
config_file = model_dir / "config.json"
if config_file.exists():
shutil.copy(config_file, package_dir / "config.json")
# Create model info
model_info = {
"name": VOICE_NAME,
"created_at": datetime.now().isoformat(),
"checkpoint": best_checkpoint.name,
"type": "coqui-tts"
}
with open(package_dir / "model_info.json", "w") as f:
json.dump(model_info, f, indent=2)
# Create tarball
archive_name = f"{VOICE_NAME}.tar.gz"
shutil.make_archive(
str(export_dir / VOICE_NAME),
"gztar",
export_dir,
VOICE_NAME
)
print(f"Created archive: {archive_name}")
# Upload to destination
if OUTPUT_PATH.startswith("s3://"):
import boto3
s3 = boto3.client("s3")
bucket, key = OUTPUT_PATH[5:].split("/", 1)
key = f"{key}/{archive_name}"
s3.upload_file(str(export_dir / archive_name), bucket, key)
print(f"Uploaded to: s3://{bucket}/{key}")
elif OUTPUT_PATH.startswith("/"):
# Local/NFS path
dest_path = Path(OUTPUT_PATH)
dest_path.mkdir(parents=True, exist_ok=True)
shutil.copy(export_dir / archive_name, dest_path / archive_name)
# Also copy uncompressed for easy access
shutil.copytree(package_dir, dest_path / VOICE_NAME, dirs_exist_ok=True)
print(f"Saved to: {dest_path / archive_name}")
print("\nExport complete!")
print(f"Model package contents:")
for f in package_dir.iterdir():
print(f" - {f.name}")
EOF
resources:
requests:
memory: 1Gi
cpu: 500m

document-ingestion.yaml

@@ -0,0 +1,369 @@
# Document Ingestion Workflow
# Ingests documents from a source URL into Milvus vector database
# Triggered via NATS: ai.pipeline.trigger with pipeline="document-ingestion"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: document-ingestion
namespace: ai-ml
labels:
app.kubernetes.io/name: document-ingestion
app.kubernetes.io/part-of: llm-workflows
spec:
entrypoint: ingest-documents
serviceAccountName: argo-workflow
arguments:
parameters:
- name: source-url
description: "URL to fetch documents from (S3, HTTP, or local path)"
- name: collection-name
value: "knowledge_base"
description: "Milvus collection name"
- name: chunk-size
value: "512"
description: "Text chunk size in characters"
- name: chunk-overlap
value: "50"
description: "Overlap between chunks"
templates:
- name: ingest-documents
dag:
tasks:
- name: fetch-documents
template: fetch-docs
arguments:
parameters:
- name: source-url
value: "{{workflow.parameters.source-url}}"
- name: chunk-documents
template: chunk-docs
dependencies: [fetch-documents]
arguments:
parameters:
- name: chunk-size
value: "{{workflow.parameters.chunk-size}}"
- name: chunk-overlap
value: "{{workflow.parameters.chunk-overlap}}"
artifacts:
- name: documents
from: "{{tasks.fetch-documents.outputs.artifacts.documents}}"
- name: generate-embeddings
template: embed-docs
dependencies: [chunk-documents]
arguments:
artifacts:
- name: chunks
from: "{{tasks.chunk-documents.outputs.artifacts.chunks}}"
- name: store-in-milvus
template: store-docs
dependencies: [generate-embeddings]
arguments:
parameters:
- name: collection-name
value: "{{workflow.parameters.collection-name}}"
artifacts:
- name: embeddings
from: "{{tasks.generate-embeddings.outputs.artifacts.embeddings}}"
- name: fetch-docs
inputs:
parameters:
- name: source-url
outputs:
artifacts:
- name: documents
path: /tmp/documents
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import json
import os
import urllib.request
from pathlib import Path
source_url = "{{inputs.parameters.source-url}}"
output_dir = Path("/tmp/documents")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Fetching documents from: {source_url}")
# Handle different source types
if source_url.startswith("s3://"):
import subprocess
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
import boto3
s3 = boto3.client("s3")
bucket, prefix = source_url[5:].split("/", 1)
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
for obj in response.get("Contents", []):
key = obj["Key"]
local_path = output_dir / Path(key).name
s3.download_file(bucket, key, str(local_path))
print(f"Downloaded: {key}")
elif source_url.startswith("http"):
# Single file download
filename = source_url.split("/")[-1] or "document.txt"
local_path = output_dir / filename
urllib.request.urlretrieve(source_url, local_path)
print(f"Downloaded: {filename}")
else:
print(f"Unsupported URL scheme: {source_url}")
exit(1)
# List downloaded files
files = list(output_dir.glob("*"))
print(f"Downloaded {len(files)} files")
# Create manifest
manifest = {"files": [str(f) for f in files]}
with open(output_dir / "manifest.json", "w") as f:
json.dump(manifest, f)
resources:
requests:
memory: 256Mi
cpu: 100m
- name: chunk-docs
inputs:
parameters:
- name: chunk-size
- name: chunk-overlap
artifacts:
- name: documents
path: /tmp/documents
outputs:
artifacts:
- name: chunks
path: /tmp/chunks
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import json
from pathlib import Path
chunk_size = int("{{inputs.parameters.chunk-size}}")
chunk_overlap = int("{{inputs.parameters.chunk-overlap}}")
input_dir = Path("/tmp/documents")
output_dir = Path("/tmp/chunks")
output_dir.mkdir(parents=True, exist_ok=True)
# Load manifest
with open(input_dir / "manifest.json") as f:
manifest = json.load(f)
all_chunks = []
for filepath in manifest["files"]:
filepath = Path(filepath)
if not filepath.exists():
continue
print(f"Processing: {filepath.name}")
# Read file content
try:
with open(filepath, "r", encoding="utf-8") as f:
content = f.read()
except Exception as e:
print(f"Error reading {filepath}: {e}")
continue
# Simple chunking
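# Sliding window: each chunk starts (chunk_size - chunk_overlap) characters
# after the previous one, e.g. a 462-character stride with the 512/50 defaults.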
chunks = []
start = 0
while start < len(content):
end = start + chunk_size
chunk = content[start:end]
if chunk.strip():
chunks.append({
"text": chunk,
"source": filepath.name,
"chunk_index": len(chunks)
})
start = end - chunk_overlap
all_chunks.extend(chunks)
print(f" Created {len(chunks)} chunks")
# Save chunks
with open(output_dir / "chunks.json", "w") as f:
json.dump({"chunks": all_chunks}, f)
print(f"Total chunks: {len(all_chunks)}")
resources:
requests:
memory: 512Mi
cpu: 100m
- name: embed-docs
inputs:
artifacts:
- name: chunks
path: /tmp/chunks
outputs:
artifacts:
- name: embeddings
path: /tmp/embeddings
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import subprocess
subprocess.run(["pip", "install", "httpx", "-q"], check=True)
import json
import httpx
from pathlib import Path
EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
BATCH_SIZE = 32
input_dir = Path("/tmp/chunks")
output_dir = Path("/tmp/embeddings")
output_dir.mkdir(parents=True, exist_ok=True)
# Load chunks
with open(input_dir / "chunks.json") as f:
data = json.load(f)
chunks = data["chunks"]
print(f"Generating embeddings for {len(chunks)} chunks")
# Generate embeddings in batches
all_embeddings = []
with httpx.Client(timeout=120.0) as client:
for i in range(0, len(chunks), BATCH_SIZE):
batch = chunks[i:i+BATCH_SIZE]
texts = [c["text"] for c in batch]
response = client.post(
f"{EMBEDDINGS_URL}/embeddings",
json={"input": texts, "model": "bge"}
)
result = response.json()
for j, emb_data in enumerate(result.get("data", [])):
all_embeddings.append({
"text": batch[j]["text"],
"source": batch[j]["source"],
"chunk_index": batch[j]["chunk_index"],
"embedding": emb_data["embedding"]
})
print(f" Processed batch {i//BATCH_SIZE + 1}/{(len(chunks)-1)//BATCH_SIZE + 1}")
# Save embeddings
with open(output_dir / "embeddings.json", "w") as f:
json.dump({"embeddings": all_embeddings}, f)
print(f"Generated {len(all_embeddings)} embeddings")
envFrom:
- configMapRef:
name: ai-services-config
resources:
requests:
memory: 1Gi
cpu: 200m
- name: store-docs
inputs:
parameters:
- name: collection-name
artifacts:
- name: embeddings
path: /tmp/embeddings
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import subprocess
subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)
import json
from pathlib import Path
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
MILVUS_PORT = 19530
COLLECTION_NAME = "{{inputs.parameters.collection-name}}"
EMBEDDING_DIM = 1024 # BGE-large dimension
input_dir = Path("/tmp/embeddings")
# Load embeddings
with open(input_dir / "embeddings.json") as f:
data = json.load(f)
embeddings = data["embeddings"]
print(f"Storing {len(embeddings)} embeddings in Milvus")
# Connect to Milvus
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
print("Connected to Milvus")
# Create collection if not exists
if not utility.has_collection(COLLECTION_NAME):
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)
]
schema = CollectionSchema(fields, description="Knowledge base documents")
collection = Collection(COLLECTION_NAME, schema)
# Create HNSW index
index_params = {
"metric_type": "COSINE",
"index_type": "HNSW",
"params": {"M": 16, "efConstruction": 256}
}
collection.create_index("embedding", index_params)
print(f"Created collection: {COLLECTION_NAME}")
else:
collection = Collection(COLLECTION_NAME)
print(f"Using existing collection: {COLLECTION_NAME}")
# Insert data in batches
BATCH_SIZE = 100
for i in range(0, len(embeddings), BATCH_SIZE):
batch = embeddings[i:i+BATCH_SIZE]
data = [
[e["text"] for e in batch],
[e["source"] for e in batch],
[e["embedding"] for e in batch]
]
collection.insert(data)
print(f" Inserted batch {i//BATCH_SIZE + 1}/{(len(embeddings)-1)//BATCH_SIZE + 1}")
# Flush to ensure data is persisted
collection.flush()
print(f"Successfully stored {len(embeddings)} documents")
connections.disconnect("default")
envFrom:
- configMapRef:
name: ai-services-config
resources:
requests:
memory: 512Mi
cpu: 100m

eventsource-kfp.yaml

@@ -0,0 +1,270 @@
# Argo Events - EventSource for KFP and NATS integration
# Enables bidirectional triggering between Argo Workflows and Kubeflow Pipelines
---
apiVersion: argoproj.io/v1alpha1
kind: EventSource
metadata:
name: kfp-events
namespace: ai-ml
labels:
app.kubernetes.io/name: kfp-events
app.kubernetes.io/part-of: llm-workflows
spec:
service:
ports:
- name: webhook
port: 12000
targetPort: 12000
# Webhook to receive KFP pipeline completion events
webhook:
kfp-completion:
port: "12000"
endpoint: /kfp/completion
method: POST
kfp-failure:
port: "12000"
endpoint: /kfp/failure
method: POST
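# KFP-side hooks POST run events to these endpoints, e.g. (illustrative):
#   curl -X POST http://kfp-events-webhook.ai-ml:12000/kfp/completion \
#     -d '{"run_id": "...", "pipeline_name": "...", "status": "SUCCEEDED"}'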
# NATS for receiving pipeline trigger requests
nats:
pipeline-trigger:
url: nats://nats.ai-ml.svc.cluster.local:4222
subject: ai.pipeline.trigger
jsonBody: true
argo-trigger:
url: nats://nats.ai-ml.svc.cluster.local:4222
subject: ai.argo.trigger
jsonBody: true
kfp-trigger:
url: nats://nats.ai-ml.svc.cluster.local:4222
subject: ai.kfp.trigger
jsonBody: true
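# Illustrative payloads (field names match the sensor dataKeys below):
#   ai.argo.trigger: {"template": "batch-inference", "parameters": [...]}
#   ai.kfp.trigger:  {"pipeline_id": "<id>", "parameters": {}, "wait": true}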
---
# Sensor for handling KFP completion events
apiVersion: argoproj.io/v1alpha1
kind: Sensor
metadata:
name: kfp-completion-sensor
namespace: ai-ml
labels:
app.kubernetes.io/name: kfp-completion-sensor
app.kubernetes.io/part-of: llm-workflows
spec:
dependencies:
- name: kfp-success
eventSourceName: kfp-events
eventName: kfp-completion
filters:
data:
- path: body.status
type: string
value:
- "SUCCEEDED"
- name: kfp-failure
eventSourceName: kfp-events
eventName: kfp-failure
triggers:
# On KFP success, publish to NATS
- template:
name: notify-kfp-success
nats:
url: nats://nats.ai-ml.svc.cluster.local:4222
subject: ai.pipeline.status.completed
payload:
- src:
dependencyName: kfp-success
dataKey: body.run_id
dest: run_id
- src:
dependencyName: kfp-success
dataKey: body.pipeline_name
dest: pipeline_name
- src:
dependencyName: kfp-success
dataKey: body.status
dest: status
retryStrategy:
steps: 3
# On KFP failure, trigger recovery workflow
- template:
name: kfp-failure-recovery
k8s:
operation: create
source:
resource:
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: kfp-failure-handler-
namespace: ai-ml
spec:
entrypoint: notify-failure
arguments:
parameters:
- name: run-id
- name: pipeline-name
- name: error-message
templates:
- name: notify-failure
inputs:
parameters:
- name: run-id
- name: pipeline-name
- name: error-message
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
import asyncio
import json
import nats
async def notify():
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
await nc.publish(
"ai.pipeline.status.failed",
json.dumps({
"run_id": "{{inputs.parameters.run-id}}",
"pipeline_name": "{{inputs.parameters.pipeline-name}}",
"error": "{{inputs.parameters.error-message}}",
"source": "kubeflow"
}).encode()
)
await nc.close()
asyncio.run(notify())
parameters:
- src:
dependencyName: kfp-failure
dataKey: body.run_id
dest: spec.arguments.parameters.0.value
- src:
dependencyName: kfp-failure
dataKey: body.pipeline_name
dest: spec.arguments.parameters.1.value
- src:
dependencyName: kfp-failure
dataKey: body.error
dest: spec.arguments.parameters.2.value
retryStrategy:
steps: 3
---
# Sensor for NATS-triggered Argo Workflows
apiVersion: argoproj.io/v1alpha1
kind: Sensor
metadata:
name: nats-argo-sensor
namespace: ai-ml
labels:
app.kubernetes.io/name: nats-argo-sensor
app.kubernetes.io/part-of: llm-workflows
spec:
dependencies:
- name: argo-trigger
eventSourceName: kfp-events
eventName: argo-trigger
triggers:
- template:
name: trigger-argo-workflow
k8s:
operation: create
source:
resource:
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: nats-triggered-
namespace: ai-ml
spec:
workflowTemplateRef:
name: placeholder
arguments:
parameters: []
parameters:
- src:
dependencyName: argo-trigger
dataKey: body.template
dest: spec.workflowTemplateRef.name
- src:
dependencyName: argo-trigger
dataKey: body.parameters
dest: spec.arguments.parameters
retryStrategy:
steps: 3
---
# Sensor for NATS-triggered KFP Pipelines
apiVersion: argoproj.io/v1alpha1
kind: Sensor
metadata:
name: nats-kfp-sensor
namespace: ai-ml
labels:
app.kubernetes.io/name: nats-kfp-sensor
app.kubernetes.io/part-of: llm-workflows
spec:
dependencies:
- name: kfp-trigger
eventSourceName: kfp-events
eventName: kfp-trigger
triggers:
# Trigger KFP via Argo Workflow (uses kfp-trigger template)
- template:
name: trigger-kfp-via-argo
k8s:
operation: create
source:
resource:
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: kfp-via-nats-
namespace: ai-ml
spec:
workflowTemplateRef:
name: kfp-trigger
arguments:
parameters:
- name: pipeline-id
value: ""
- name: pipeline-params
value: "{}"
- name: wait-for-completion
value: "true"
parameters:
- src:
dependencyName: kfp-trigger
dataKey: body.pipeline_id
dest: spec.arguments.parameters.0.value
- src:
dependencyName: kfp-trigger
dataKey: body.parameters
dest: spec.arguments.parameters.1.value
operation: "stringify"
- src:
dependencyName: kfp-trigger
dataKey: body.wait
dest: spec.arguments.parameters.2.value
operation: "stringify"
retryStrategy:
steps: 3
---
# Service for the EventSource webhook
apiVersion: v1
kind: Service
metadata:
name: kfp-events-webhook
namespace: ai-ml
labels:
app.kubernetes.io/name: kfp-events
app.kubernetes.io/part-of: llm-workflows
spec:
selector:
eventsource-name: kfp-events
ports:
- name: webhook
port: 12000
targetPort: 12000

hybrid-ml-training.yaml

@@ -0,0 +1,555 @@
# Hybrid ML Training Workflow
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: hybrid-ml-training
namespace: ai-ml
labels:
app.kubernetes.io/name: hybrid-ml-training
app.kubernetes.io/part-of: llm-workflows
annotations:
description: |
Demonstrates hybrid Argo+KFP workflow:
- Argo handles orchestration, branching, retry logic
- KFP pipelines handle ML-specific operations (with caching)
- NATS for status updates to frontends
spec:
entrypoint: hybrid-training
serviceAccountName: argo-workflow
# Artifact repository for model checkpoints
artifactRepositoryRef:
configMap: artifact-repository
key: default
arguments:
parameters:
- name: collection-name
description: "Milvus collection to pull training data from"
value: "dnd_text_embeddings"
- name: model-name
description: "Base model for fine-tuning"
value: "mistralai/Mistral-7B-v0.3"
- name: lora-rank
description: "LoRA rank (higher = more params)"
value: "16"
- name: epochs
description: "Training epochs"
value: "3"
- name: batch-size
description: "Training batch size"
value: "4"
- name: output-path
description: "S3 path for model output"
value: "s3://models/lora-adapters"
- name: notify-nats
description: "Publish status to NATS"
value: "true"
# Volumes for GPU caching
volumes:
- name: model-cache
persistentVolumeClaim:
claimName: model-cache-pvc
- name: shm
emptyDir:
medium: Memory
sizeLimit: 16Gi
templates:
# Main DAG orchestrating the workflow
- name: hybrid-training
dag:
tasks:
- name: notify-start
template: nats-notify
when: "{{workflow.parameters.notify-nats}} == true"
arguments:
parameters:
- name: subject
value: "ai.pipeline.status.{{workflow.name}}"
- name: message
value: '{"status": "started", "pipeline": "hybrid-ml-training"}'
- name: prepare-data
template: extract-training-data
arguments:
parameters:
- name: collection-name
value: "{{workflow.parameters.collection-name}}"
- name: validate-data
template: validate-dataset
dependencies: [prepare-data]
arguments:
artifacts:
- name: dataset
from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
# KFP Pipeline: Run embedding generation if needed
- name: generate-embeddings
template: trigger-kfp
dependencies: [validate-data]
when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
arguments:
parameters:
- name: pipeline-id
value: "embedding-generation"
- name: params
value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'
# Training step (runs on GPU)
- name: train-lora
template: lora-training
depends: "validate-data && (generate-embeddings.Succeeded || generate-embeddings.Skipped)"
arguments:
parameters:
- name: model-name
value: "{{workflow.parameters.model-name}}"
- name: lora-rank
value: "{{workflow.parameters.lora-rank}}"
- name: epochs
value: "{{workflow.parameters.epochs}}"
- name: batch-size
value: "{{workflow.parameters.batch-size}}"
artifacts:
- name: dataset
from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
# Evaluate model
- name: evaluate
template: evaluate-model
dependencies: [train-lora]
arguments:
artifacts:
- name: adapter
from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
# Branch based on evaluation results
- name: check-quality
template: quality-gate
dependencies: [evaluate]
arguments:
parameters:
- name: eval-score
value: "{{tasks.evaluate.outputs.parameters.score}}"
- name: threshold
value: "0.7"
# If quality is good, upload to S3
- name: upload-model
template: upload-to-s3
dependencies: [check-quality]
when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
arguments:
parameters:
- name: output-path
value: "{{workflow.parameters.output-path}}"
artifacts:
- name: adapter
from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
# If quality is poor, trigger retraining with different params
- name: retry-training
template: adjust-and-retry
dependencies: [check-quality]
when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
arguments:
parameters:
- name: current-rank
value: "{{workflow.parameters.lora-rank}}"
- name: current-epochs
value: "{{workflow.parameters.epochs}}"
- name: notify-complete
template: nats-notify
when: "{{workflow.parameters.notify-nats}} == true"
dependencies: [upload-model]
arguments:
parameters:
- name: subject
value: "ai.pipeline.status.{{workflow.name}}"
- name: message
value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'
# Extract training data from Milvus
- name: extract-training-data
inputs:
parameters:
- name: collection-name
outputs:
artifacts:
- name: dataset
path: /tmp/dataset
parameters:
- name: data-path
valueFrom:
path: /tmp/data-path
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"])
from pymilvus import connections, Collection
import pandas as pd
from pathlib import Path
connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)
collection = Collection("{{inputs.parameters.collection-name}}")
collection.load()
# Query all training samples
results = collection.query(
expr="source != ''",
output_fields=["text", "source", "metadata"],
limit=10000
)
# Convert to training format
df = pd.DataFrame(results)
output_dir = Path("/tmp/dataset")
output_dir.mkdir(parents=True, exist_ok=True)
# Save as parquet for efficient loading
df.to_parquet(output_dir / "train.parquet")
print(f"Extracted {len(df)} samples")
with open("/tmp/data-path", "w") as f:
f.write(str(output_dir / "train.parquet"))
# Validate dataset
- name: validate-dataset
inputs:
artifacts:
- name: dataset
path: /tmp/dataset
outputs:
parameters:
- name: needs-embeddings
valueFrom:
path: /tmp/needs-embeddings
- name: sample-count
valueFrom:
path: /tmp/sample-count
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"])
import pandas as pd
from pathlib import Path
dataset_dir = Path("/tmp/dataset")
parquet_files = list(dataset_dir.glob("*.parquet"))
if not parquet_files:
raise ValueError("No parquet files found in dataset")
df = pd.read_parquet(parquet_files[0])
sample_count = len(df)
print(f"Dataset contains {sample_count} samples")
# Check if embeddings column exists
needs_embeddings = "embedding" not in df.columns
with open("/tmp/needs-embeddings", "w") as f:
f.write(str(needs_embeddings).lower())
with open("/tmp/sample-count", "w") as f:
f.write(str(sample_count))
# Trigger KFP pipeline
- name: trigger-kfp
inputs:
parameters:
- name: pipeline-id
- name: params
outputs:
parameters:
- name: run-id
valueFrom:
path: /tmp/run-id
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
import json
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
from kfp import Client
client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
params = json.loads('''{{inputs.parameters.params}}''')
# create_run_from_pipeline_func only accepts a local pipeline function;
# an uploaded pipeline is launched by ID via run_pipeline instead
experiment = client.create_experiment("argo-triggered")  # experiment name is arbitrary
run = client.run_pipeline(
experiment_id=experiment.experiment_id,
job_name="argo-{{workflow.name}}",
pipeline_id="{{inputs.parameters.pipeline-id}}",
params=params
)
print(f"Triggered KFP pipeline: {run.run_id}")
with open("/tmp/run-id", "w") as f:
f.write(run.run_id)
# LoRA training (GPU)
- name: lora-training
inputs:
parameters:
- name: model-name
- name: lora-rank
- name: epochs
- name: batch-size
artifacts:
- name: dataset
path: /data/dataset
outputs:
artifacts:
- name: adapter
path: /output/adapter
- name: logs
path: /output/logs
podSpecPatch: |
containers:
- name: main
resources:
requests:
amd.com/gpu: 1
limits:
amd.com/gpu: 1
script:
image: ghcr.io/billy-davies-2/lora-trainer:latest
command: [python]
env:
- name: HF_HOME
value: /cache/huggingface
- name: TRANSFORMERS_CACHE
value: /cache/huggingface
volumeMounts:
- name: model-cache
mountPath: /cache
- name: shm
mountPath: /dev/shm
source: |
import os
import sys
from pathlib import Path
# Training script
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from trl import SFTTrainer
model_name = "{{inputs.parameters.model-name}}"
lora_rank = int("{{inputs.parameters.lora-rank}}")
epochs = int("{{inputs.parameters.epochs}}")
batch_size = int("{{inputs.parameters.batch-size}}")
# Load dataset
dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Configure LoRA
lora_config = LoraConfig(
r=lora_rank,
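# LoRA scales updates by alpha/r, so alpha = 2*r keeps an effective scale of 2 at any rank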
lora_alpha=lora_rank * 2,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# Training
training_args = TrainingArguments(
output_dir="/output/adapter",
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=4,
learning_rate=2e-4,
logging_dir="/output/logs",
save_strategy="epoch"
)
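# Note: newer trl releases moved dataset_text_field (and tokenizer) into
# SFTConfig; this form assumes the trl version pinned in the trainer image
# still accepts these arguments on SFTTrainer directly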
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
tokenizer=tokenizer,
args=training_args,
dataset_text_field="text"
)
trainer.train()
trainer.save_model("/output/adapter")
print("Training complete!")
# Evaluate model
- name: evaluate-model
inputs:
artifacts:
- name: adapter
path: /input/adapter
outputs:
parameters:
- name: score
valueFrom:
path: /tmp/score
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])
import httpx
import json
# Run evaluation using vLLM with the adapter
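# Assumes the adapter is already registered with the vLLM server under the
# name "local-adapter" (e.g. via vLLM's LoRA serving options)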
test_prompts = [
"What is the capital of France?",
"Explain machine learning in simple terms.",
"Write a haiku about coding."
]
scores = []
with httpx.Client(timeout=120.0) as client:
for prompt in test_prompts:
response = client.post(
"http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
json={
"model": "local-adapter",
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 200
}
    )
    response.raise_for_status()
    # Simple scoring based on response coherence
    result = response.json()
content = result["choices"][0]["message"]["content"]
score = min(1.0, len(content) / 100) # Placeholder scoring
scores.append(score)
avg_score = sum(scores) / len(scores)
print(f"Average evaluation score: {avg_score}")
with open("/tmp/score", "w") as f:
f.write(str(round(avg_score, 3)))
# Quality gate
- name: quality-gate
inputs:
parameters:
- name: eval-score
- name: threshold
outputs:
parameters:
- name: passed
valueFrom:
path: /tmp/passed
script:
image: python:3.13-slim
command: [python]
source: |
score = float("{{inputs.parameters.eval-score}}")
threshold = float("{{inputs.parameters.threshold}}")
passed = score >= threshold
print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")
with open("/tmp/passed", "w") as f:
f.write(str(passed).lower())
# Upload to S3
- name: upload-to-s3
inputs:
parameters:
- name: output-path
artifacts:
- name: adapter
path: /input/adapter
script:
image: amazon/aws-cli:latest
command: [bash]
env:
- name: AWS_ACCESS_KEY_ID
valueFrom:
secretKeyRef:
name: s3-credentials
key: access-key
- name: AWS_SECRET_ACCESS_KEY
valueFrom:
secretKeyRef:
name: s3-credentials
key: secret-key
- name: AWS_ENDPOINT_URL
value: "https://quobjects.billy.davies.cloud"
source: |
aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
echo "Uploaded adapter to {{inputs.parameters.output-path}}"
# Adjust parameters and retry
- name: adjust-and-retry
inputs:
parameters:
- name: current-rank
- name: current-epochs
script:
image: python:3.13-slim
command: [python]
source: |
current_rank = int("{{inputs.parameters.current-rank}}")
current_epochs = int("{{inputs.parameters.current-epochs}}")
# Increase rank and epochs for next attempt
new_rank = min(64, current_rank * 2)
new_epochs = current_epochs + 2
print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
print("TODO: Trigger new workflow with adjusted parameters")
# NATS notification
- name: nats-notify
inputs:
parameters:
- name: subject
- name: message
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
import asyncio
import nats
async def notify():
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
await nc.publish(
"{{inputs.parameters.subject}}",
b'''{{inputs.parameters.message}}'''
)
await nc.close()
asyncio.run(notify())
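# Downstream consumers can watch these status events, e.g. with the NATS CLI
# (assumed subject hierarchy): nats sub "ai.pipeline.status.>"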

237
kfp-integration.yaml Normal file
View File

@@ -0,0 +1,237 @@
# Argo Workflows + Kubeflow Pipelines Integration
# These templates let Argo Workflows trigger KFP pipelines and run KFP-style components as Argo steps
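# Example invocation (hypothetical parameter values):
#   argo submit --from workflowtemplate/kfp-trigger -n ai-ml \
#     -p pipeline-id="rag-indexing" \
#     -p pipeline-params='{"collection": "docs"}'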
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: kfp-trigger
namespace: ai-ml
labels:
app.kubernetes.io/name: kfp-trigger
app.kubernetes.io/part-of: llm-workflows
spec:
entrypoint: trigger-kfp-pipeline
serviceAccountName: argo-workflow
arguments:
parameters:
- name: pipeline-id
description: "Kubeflow Pipeline ID or name"
- name: pipeline-params
description: "JSON object of pipeline parameters"
value: "{}"
- name: experiment-name
description: "KFP Experiment to use"
value: "Default"
- name: wait-for-completion
description: "Wait for pipeline to complete"
value: "true"
templates:
- name: trigger-kfp-pipeline
steps:
- - name: submit-run
template: submit-kfp-run
arguments:
parameters:
- name: pipeline-id
value: "{{workflow.parameters.pipeline-id}}"
- name: pipeline-params
value: "{{workflow.parameters.pipeline-params}}"
- name: experiment-name
value: "{{workflow.parameters.experiment-name}}"
- - name: wait-completion
template: wait-for-kfp
when: "{{workflow.parameters.wait-for-completion}} == true"
arguments:
parameters:
- name: run-id
value: "{{steps.submit-run.outputs.parameters.run-id}}"
- name: submit-kfp-run
inputs:
parameters:
- name: pipeline-id
- name: pipeline-params
- name: experiment-name
outputs:
parameters:
- name: run-id
valueFrom:
path: /tmp/run-id
script:
image: python:3.13-slim
command: [python]
source: |
import json
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
from kfp import Client
KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
client = Client(host=KUBEFLOW_HOST)
pipeline_id = "{{inputs.parameters.pipeline-id}}"
params = json.loads('''{{inputs.parameters.pipeline-params}}''')
experiment_name = "{{inputs.parameters.experiment-name}}"
# Get or create experiment
try:
    experiment = client.get_experiment(experiment_name=experiment_name)
except Exception:
    experiment = client.create_experiment(name=experiment_name)
# Resolve the pipeline by ID, falling back to lookup by display name
try:
    pipeline = client.get_pipeline(pipeline_id)
except Exception:
    resolved_id = client.get_pipeline_id(pipeline_id)
    if resolved_id is None:
        raise ValueError(f"Pipeline not found: {pipeline_id}")
    pipeline = client.get_pipeline(resolved_id)
# Create run
run = client.run_pipeline(
experiment_id=experiment.experiment_id,
job_name=f"{pipeline.display_name}-argo-{pipeline_id[:8]}",
pipeline_id=pipeline.pipeline_id,
params=params
)
print(f"Submitted KFP run: {run.run_id}")
with open("/tmp/run-id", "w") as f:
f.write(run.run_id)
- name: wait-for-kfp
inputs:
parameters:
- name: run-id
outputs:
parameters:
- name: status
valueFrom:
path: /tmp/status
script:
image: python:3.13-slim
command: [python]
source: |
import subprocess
import sys
import time
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
from kfp import Client
KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
run_id = "{{inputs.parameters.run-id}}"
client = Client(host=KUBEFLOW_HOST)
while True:
    run = client.get_run(run_id)
    state = run.state  # kfp 2.x returns V2beta1Run; state is a top-level field
    print(f"Run {run_id} state: {state}")
    if state in ["SUCCEEDED", "SKIPPED"]:
        with open("/tmp/status", "w") as f:
            f.write("SUCCEEDED")
        break
    elif state in ["FAILED", "CANCELED"]:
        with open("/tmp/status", "w") as f:
            f.write(state)
        raise RuntimeError(f"Pipeline failed with state: {state}")
time.sleep(30)
---
# WorkflowTemplate for running KFP pipeline components as Argo steps
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: kfp-component-runner
namespace: ai-ml
labels:
app.kubernetes.io/name: kfp-component-runner
app.kubernetes.io/part-of: llm-workflows
spec:
entrypoint: run-component
serviceAccountName: argo-workflow
arguments:
parameters:
- name: component-name
description: "Name of the KFP component to run"
- name: component-params
description: "JSON parameters for the component"
value: "{}"
templates:
- name: run-component
inputs:
parameters:
- name: component-name
- name: component-params
outputs:
parameters:
- name: result
valueFrom:
path: /tmp/result.json
script:
image: python:3.13-slim
command: [python]
source: |
import json
import subprocess
import sys
subprocess.check_call([
sys.executable, "-m", "pip", "install", "-q",
"httpx", "pymilvus"
])
import httpx
component_name = "{{inputs.parameters.component-name}}"
params = json.loads('''{{inputs.parameters.component-params}}''')
# Component implementations (mirrors KFP components)
COMPONENTS = {
"transcribe_audio": {
"url": "http://whisper-predictor.ai-ml.svc.cluster.local",
"endpoint": "/v1/audio/transcriptions"
},
"generate_embeddings": {
"url": "http://embeddings-predictor.ai-ml.svc.cluster.local",
"endpoint": "/embeddings"
},
"generate_response": {
"url": "http://llm-draft.ai-ml.svc.cluster.local:8000",
"endpoint": "/v1/chat/completions"
},
"synthesize_speech": {
"url": "http://tts-predictor.ai-ml.svc.cluster.local",
"endpoint": "/v1/audio/speech"
}
}
if component_name not in COMPONENTS:
raise ValueError(f"Unknown component: {component_name}")
config = COMPONENTS[component_name]
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{config['url']}{config['endpoint']}",
json=params
    )
    response.raise_for_status()
    result = response.json()
with open("/tmp/result.json", "w") as f:
json.dump(result, f)
print(f"Component {component_name} completed")

510
qlora-training.yaml Normal file
View File

@@ -0,0 +1,510 @@
# QLoRA Fine-tuning Workflow
# Trains QLoRA adapters from a reference model using data from the Milvus vector database
# Triggered via NATS: ai.pipeline.trigger with pipeline="qlora-training"
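# Example trigger payload (hypothetical values, format per the README):
#   {"pipeline": "qlora-training",
#    "parameters": {"reference-model": "mistralai/Mistral-7B-Instruct-v0.3",
#                   "output-name": "my-adapter", "num-epochs": "3"}}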
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: qlora-training
namespace: ai-ml
labels:
app.kubernetes.io/name: qlora-training
app.kubernetes.io/part-of: llm-workflows
spec:
entrypoint: train-qlora
serviceAccountName: argo-workflow
arguments:
parameters:
- name: reference-model
description: "Base model to fine-tune (HuggingFace model ID or path)"
value: "mistralai/Mistral-7B-Instruct-v0.3"
- name: output-name
description: "Name for the output QLora adapter"
value: "qlora-adapter"
- name: milvus-collections
description: "Comma-separated list of Milvus collections to use (empty = all available)"
value: ""
- name: learning-rate
  description: "Learning rate for training"
  value: "2e-4"
- name: num-epochs
  description: "Number of training epochs"
  value: "3"
- name: batch-size
  description: "Training batch size"
  value: "4"
- name: max-seq-length
  description: "Maximum sequence length"
  value: "2048"
- name: lora-r
  description: "LoRA attention dimension"
  value: "64"
- name: lora-alpha
  description: "LoRA alpha parameter"
  value: "16"
- name: lora-dropout
  description: "LoRA dropout rate"
  value: "0.05"
volumeClaimTemplates:
- metadata:
name: model-storage
spec:
accessModes: ["ReadWriteMany"]
storageClassName: nfs-slow
resources:
requests:
storage: 50Gi
templates:
- name: train-qlora
dag:
tasks:
- name: fetch-training-data
template: fetch-data
arguments:
parameters:
- name: milvus-collections
value: "{{workflow.parameters.milvus-collections}}"
- name: prepare-dataset
template: prepare-data
dependencies: [fetch-training-data]
arguments:
parameters:
- name: max-seq-length
value: "{{workflow.parameters.max-seq-length}}"
artifacts:
- name: raw-data
from: "{{tasks.fetch-training-data.outputs.artifacts.raw-data}}"
- name: train-model
template: train
dependencies: [prepare-dataset]
arguments:
parameters:
- name: reference-model
value: "{{workflow.parameters.reference-model}}"
- name: output-name
value: "{{workflow.parameters.output-name}}"
- name: learning-rate
value: "{{workflow.parameters.learning-rate}}"
- name: num-epochs
value: "{{workflow.parameters.num-epochs}}"
- name: batch-size
value: "{{workflow.parameters.batch-size}}"
- name: max-seq-length
value: "{{workflow.parameters.max-seq-length}}"
- name: lora-r
value: "{{workflow.parameters.lora-r}}"
- name: lora-alpha
value: "{{workflow.parameters.lora-alpha}}"
- name: lora-dropout
value: "{{workflow.parameters.lora-dropout}}"
artifacts:
- name: training-data
from: "{{tasks.prepare-dataset.outputs.artifacts.training-data}}"
- name: fetch-data
inputs:
parameters:
- name: milvus-collections
outputs:
artifacts:
- name: raw-data
path: /tmp/raw-data
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import subprocess
import sys
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "pymilvus"], check=True)
import json
from pathlib import Path
from pymilvus import connections, Collection, utility
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
MILVUS_PORT = 19530
collections_param = "{{inputs.parameters.milvus-collections}}"
output_dir = Path("/tmp/raw-data")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Connecting to Milvus at {MILVUS_HOST}:{MILVUS_PORT}")
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
# Determine which collections to use
if collections_param and collections_param.strip():
collection_names = [c.strip() for c in collections_param.split(",")]
print(f"Using specified collections: {collection_names}")
else:
# Get all available collections
collection_names = utility.list_collections()
print(f"Using all available collections: {collection_names}")
all_training_data = []
for collection_name in collection_names:
if not utility.has_collection(collection_name):
print(f"Warning: Collection {collection_name} not found, skipping")
continue
print(f"Fetching data from collection: {collection_name}")
collection = Collection(collection_name)
collection.load()
# Query all data from the collection
# Note: Adjust field names based on your schema
try:
# Get collection schema to determine fields
schema = collection.schema
field_names = [field.name for field in schema.fields if field.name != "id"]
# Milvus caps query offset+limit (commonly 16384), so page through
# the collection with a query iterator instead of one oversized query
iterator = collection.query_iterator(
    batch_size=1000,
    expr="id >= 0",
    output_fields=field_names
)
results = []
while True:
    batch = iterator.next()
    if not batch:
        iterator.close()
        break
    results.extend(batch)
print(f"  Retrieved {len(results)} records from {collection_name}")
for result in results:
# Extract text field (adjust based on your schema)
text_content = result.get("text", "")
source = result.get("source", collection_name)
if text_content:
all_training_data.append({
"text": text_content,
"source": source,
"collection": collection_name
})
except Exception as e:
print(f"Error querying collection {collection_name}: {e}")
continue
# Save all training data
output_file = output_dir / "training_data.json"
with open(output_file, "w") as f:
json.dump({"data": all_training_data}, f)
print(f"Total training samples collected: {len(all_training_data)}")
print(f"Saved to {output_file}")
connections.disconnect("default")
envFrom:
- configMapRef:
name: ai-services-config
resources:
requests:
memory: 1Gi
cpu: 500m
- name: prepare-data
inputs:
parameters:
- name: max-seq-length
artifacts:
- name: raw-data
path: /tmp/raw-data
outputs:
artifacts:
- name: training-data
path: /tmp/training-data
container:
image: python:3.13-slim
command: [python]
args:
- -c
- |
import json
from pathlib import Path
max_seq_length = int("{{inputs.parameters.max-seq-length}}")
input_dir = Path("/tmp/raw-data")
output_dir = Path("/tmp/training-data")
output_dir.mkdir(parents=True, exist_ok=True)
# Load raw data
with open(input_dir / "training_data.json") as f:
data = json.load(f)
raw_samples = data["data"]
print(f"Processing {len(raw_samples)} raw samples")
# Prepare data in instruction format for fine-tuning
# Using Alpaca-style format: instruction + response
training_samples = []
for sample in raw_samples:
text = sample["text"]
source = sample.get("source", "")
# Create instruction-response pairs (Alpaca-style); here the document text
# serves as both input context and target output (continuation-style), so
# customize this mapping for your actual task
training_sample = {
"instruction": f"Based on the following information from {source}, provide a comprehensive response:",
"input": text[:max_seq_length // 2], # Truncate if needed
"output": text[:max_seq_length // 2],
"source": source
}
training_samples.append(training_sample)
# Split into train/validation (90/10)
split_idx = int(len(training_samples) * 0.9)
train_data = training_samples[:split_idx]
val_data = training_samples[split_idx:]
# Save prepared datasets
with open(output_dir / "train.json", "w") as f:
json.dump(train_data, f)
with open(output_dir / "validation.json", "w") as f:
json.dump(val_data, f)
print(f"Prepared {len(train_data)} training samples")
print(f"Prepared {len(val_data)} validation samples")
print("Data preparation complete")
resources:
requests:
memory: 2Gi
cpu: 500m
- name: train
inputs:
parameters:
- name: reference-model
- name: output-name
- name: learning-rate
- name: num-epochs
- name: batch-size
- name: max-seq-length
- name: lora-r
- name: lora-alpha
- name: lora-dropout
artifacts:
- name: training-data
path: /tmp/training-data
container:
image: python:3.13-slim
command: [bash]
args:
- -c
- |
set -e
echo "Installing dependencies..."
pip install -q torch transformers peft datasets accelerate bitsandbytes scipy
echo "Starting QLoRA training..."
python << 'EOF'
import json
import os
from pathlib import Path
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
# Parameters
reference_model = "{{inputs.parameters.reference-model}}"
output_name = "{{inputs.parameters.output-name}}"
learning_rate = float("{{inputs.parameters.learning-rate}}")
num_epochs = int("{{inputs.parameters.num-epochs}}")
batch_size = int("{{inputs.parameters.batch-size}}")
max_seq_length = int("{{inputs.parameters.max-seq-length}}")
lora_r = int("{{inputs.parameters.lora-r}}")
lora_alpha = int("{{inputs.parameters.lora-alpha}}")
lora_dropout = float("{{inputs.parameters.lora-dropout}}")
data_dir = Path("/tmp/training-data")
output_dir = Path("/mnt/model-storage") / output_name
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Training configuration:")
print(f" Model: {reference_model}")
print(f" Learning rate: {learning_rate}")
print(f" Epochs: {num_epochs}")
print(f" Batch size: {batch_size}")
print(f" Max sequence length: {max_seq_length}")
print(f" LoRA r: {lora_r}, alpha: {lora_alpha}, dropout: {lora_dropout}")
# Load datasets
with open(data_dir / "train.json") as f:
train_data = json.load(f)
with open(data_dir / "validation.json") as f:
val_data = json.load(f)
print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples")
# Load tokenizer
print(f"Loading tokenizer from {reference_model}...")
tokenizer = AutoTokenizer.from_pretrained(reference_model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Prepare datasets
def format_sample(sample):
instruction = sample.get("instruction", "")
input_text = sample.get("input", "")
output = sample.get("output", "")
if input_text:
prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
else:
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
return {"text": prompt}
train_dataset = Dataset.from_list([format_sample(s) for s in train_data])
val_dataset = Dataset.from_list([format_sample(s) for s in val_data])
# Tokenize datasets
def tokenize_function(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=max_seq_length,
padding="max_length"
)
print("Tokenizing datasets...")
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# Configure quantization for QLoRA
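# NF4 with double quantization follows the QLoRA paper defaults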
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
# Load model
print(f"Loading model {reference_model} with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
reference_model,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
# Prepare model for training
model = prepare_model_for_kbit_training(model)
# Configure LoRA
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
# Add LoRA adapters
print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training arguments
training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"),
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=4,
learning_rate=learning_rate,
fp16=True,
logging_steps=10,
eval_strategy="steps",  # evaluation_strategy was renamed in recent transformers releases
eval_steps=50,
save_strategy="steps",
save_steps=100,
save_total_limit=3,
load_best_model_at_end=True,
report_to="none",
remove_unused_columns=False,
)
# Data collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=data_collator,
)
# Train
print("Starting training...")
trainer.train()
# Save final model
print(f"Saving QLora adapter to {output_dir}")
model.save_pretrained(str(output_dir / "final"))
tokenizer.save_pretrained(str(output_dir / "final"))
# Save training metadata
metadata = {
"reference_model": reference_model,
"output_name": output_name,
"training_params": {
"learning_rate": learning_rate,
"num_epochs": num_epochs,
"batch_size": batch_size,
"max_seq_length": max_seq_length,
"lora_r": lora_r,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout
},
"dataset_info": {
"train_samples": len(train_data),
"val_samples": len(val_data)
}
}
with open(output_dir / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
print("Training complete!")
print(f"QLora adapter saved to: {output_dir}")
EOF
envFrom:
- configMapRef:
name: ai-services-config
volumeMounts:
- name: model-storage
mountPath: /mnt/model-storage
resources:
requests:
memory: 16Gi
cpu: 4
nvidia.com/gpu: 1
limits:
memory: 32Gi
cpu: 8
nvidia.com/gpu: 1