feat: Add ML training and batch inference workflows
- batch-inference: LLM inference with optional RAG
- qlora-training: QLoRA adapter fine-tuning from Milvus
- hybrid-ml-training: Multi-GPU distributed training
- coqui-voice-training: XTTS voice cloning
- document-ingestion: Ingest documents to Milvus
- eventsource-kfp: Argo Events / Kubeflow integration
- kfp-integration: Bridge between Argo and Kubeflow
128
README.md
@@ -1,2 +1,128 @@
|
||||
# argo
# Argo Workflows

ML training and batch inference workflows for the DaviesTechLabs AI/ML platform.

## Workflows

| Workflow | Description | Trigger |
|----------|-------------|---------|
| `batch-inference` | Run LLM inference on batch inputs | `ai.pipeline.trigger` (pipeline="batch-inference") |
| `qlora-training` | Train QLoRA adapters from Milvus data | `ai.pipeline.trigger` (pipeline="qlora-training") |
| `hybrid-ml-training` | Multi-GPU distributed training | `ai.pipeline.trigger` (pipeline="hybrid-ml-training") |
| `coqui-voice-training` | XTTS voice cloning/training | `ai.pipeline.trigger` (pipeline="coqui-voice-training") |
| `document-ingestion` | Ingest documents into Milvus | `ai.pipeline.trigger` (pipeline="document-ingestion") |

## Integration

| File | Description |
|------|-------------|
| `eventsource-kfp.yaml` | Argo Events source for Kubeflow Pipelines integration |
| `kfp-integration.yaml` | Bridge workflows between Argo and Kubeflow |

## Architecture

```
NATS (ai.pipeline.trigger)
         │
         ▼
┌─────────────────┐
│   Argo Events   │
│   EventSource   │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│   Argo Sensor   │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│ WorkflowTemplate│
│  (batch-inf,    │
│   qlora, etc)   │
└─────────────────┘
         │
         ├──▶ GPU Pods (AMD ROCm / NVIDIA CUDA)
         ├──▶ Milvus Vector DB
         ├──▶ vLLM / Ray Serve
         └──▶ MLflow Tracking
```

## Workflow Details

### batch-inference

Batch LLM inference with optional RAG:

```bash
argo submit batch-inference.yaml \
  -p input-url="s3://bucket/inputs.json" \
  -p output-url="s3://bucket/outputs.json" \
  -p use-rag="true" \
  -p max-tokens="500"
```
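
The file behind `input-url` has to match what the workflow scripts in this commit read: the fetch step rejects JSON without a top-level `requests` key, and the inference step reads each request's `text` (falling back to `query`) and `id` (falling back to the list index). A minimal sketch of building such a file follows; the `req-N` id scheme and the sample prompts are illustrative assumptions, not part of the repository.

```python
import json


def build_batch_inputs(prompts):
    """Wrap prompts in the {"requests": [...]} envelope the fetch step validates."""
    return {
        "requests": [
            # "id" is optional in the workflow; the "req-N" scheme is hypothetical.
            {"id": f"req-{i}", "text": prompt}
            for i, prompt in enumerate(prompts)
        ]
    }


payload = build_batch_inputs(["What is QLoRA?", "Summarize the ingestion flow."])
serialized = json.dumps(payload, indent=2)
print(serialized)
```

Writing `serialized` to a file and uploading it to the location passed as `input-url` should be enough for the fetch step's validation to pass.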

### qlora-training

Fine-tune QLoRA adapters from Milvus knowledge:

```bash
argo submit qlora-training.yaml \
  -p reference-model="mistralai/Mistral-7B-Instruct-v0.3" \
  -p output-name="my-adapter" \
  -p milvus-collections="docs,wiki" \
  -p num-epochs="3"
```

### coqui-voice-training

Train XTTS voice models:

```bash
argo submit coqui-voice-training.yaml \
  -p voice-name="my-voice" \
  -p audio-source="s3://bucket/samples/"
```
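
The workflow also accepts an optional `transcripts-source` parameter pointing at a CSV with `audio_file,transcript` columns; when it is left empty, the audio is auto-transcribed instead. A minimal sketch of producing that CSV with the standard library is shown below. The file names and sentences are made-up sample data.

```python
import csv
import io


def write_transcripts(rows):
    """Render (audio_file, transcript) pairs as CSV text with a header row."""
    buffer = io.StringIO()
    writer = csv.DictWriter(
        buffer,
        fieldnames=["audio_file", "transcript"],
        lineterminator="\n",
    )
    writer.writeheader()
    for audio_file, transcript in rows:
        writer.writerow({"audio_file": audio_file, "transcript": transcript})
    return buffer.getvalue()


# Hypothetical sample data; real rows should name the uploaded audio files.
csv_text = write_transcripts([
    ("sample_001.wav", "Hello, this is a test."),
    ("sample_002.wav", "The quick brown fox."),
])
print(csv_text)
```

The `csv` module quotes any transcript that contains a comma, so free-form sentences survive the round trip through `csv.DictReader` in the dataset preparation step.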

### document-ingestion

Ingest documents into Milvus:

```bash
argo submit document-ingestion.yaml \
  -p source-url="s3://bucket/docs/" \
  -p collection="knowledge_base" \
  -p chunk-size="512"
```
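
The examples above pass `s3://` URLs. The batch-inference and voice-training scripts in this commit split them with `url[5:].split("/", 1)`, which raises `ValueError` when the URL names only a bucket. The sketch below is a more defensive split using `urllib.parse`; `split_s3_url` is a hypothetical helper, and whether `document-ingestion` parses URLs the same way is an assumption.

```python
from urllib.parse import urlparse


def split_s3_url(url: str) -> tuple[str, str]:
    """Split an s3:// URL into (bucket, key_or_prefix); the key may be empty."""
    parsed = urlparse(url)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an s3:// URL: {url!r}")
    # The path keeps its leading slash, which is not part of the object key.
    return parsed.netloc, parsed.path.lstrip("/")


bucket, prefix = split_s3_url("s3://bucket/docs/")
print(bucket, prefix)
```

A bucket-only URL such as `s3://bucket` yields an empty prefix rather than an unpacking error, which callers can then treat as "list the whole bucket" or reject explicitly.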

## NATS Trigger Format

Workflows are triggered via NATS `ai.pipeline.trigger`:

```json
{
  "pipeline": "qlora-training",
  "parameters": {
    "reference-model": "mistralai/Mistral-7B-Instruct-v0.3",
    "output-name": "custom-adapter",
    "num-epochs": "5"
  }
}
```
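
To make the mapping concrete, the sketch below translates a trigger message into the roughly equivalent `argo submit` invocation. It assumes that `pipeline` matches the WorkflowTemplate name and that `parameters` map one-to-one onto `-p` flags. The Sensor that does this in the cluster is not shown in this README, so `trigger_to_argo_command` is purely illustrative. The `ai-ml` namespace is taken from the templates in this commit.

```python
import json
import shlex


def trigger_to_argo_command(message: str) -> list[str]:
    """Translate a trigger message into an equivalent `argo submit` argument list."""
    payload = json.loads(message)
    pipeline = payload["pipeline"]  # required; raises KeyError if absent
    parameters = payload.get("parameters", {})
    command = ["argo", "submit", "--from", f"workflowtemplate/{pipeline}", "-n", "ai-ml"]
    for name, value in parameters.items():
        command += ["-p", f"{name}={value}"]
    return command


message = json.dumps({
    "pipeline": "qlora-training",
    "parameters": {"output-name": "custom-adapter", "num-epochs": "5"},
})
command = trigger_to_argo_command(message)
print(shlex.join(command))
```

Parameter values are passed as strings, matching the quoted values in the JSON example above.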

## GPU Scheduling

Workflows use node affinity for GPU allocation:

| Node | GPU | Best For |
|------|-----|----------|
| khelben | AMD Strix Halo 64GB | Large model training, vLLM |
| elminster | NVIDIA RTX 2070 | Whisper, XTTS |
| drizzt | AMD Radeon 680M | Embeddings |
| danilo | Intel Arc | Reranker |
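
As a rough illustration of what such an affinity rule looks like, the sketch below builds a Kubernetes `affinity` stanza that pins a pod to one of the nodes in the table. It assumes the nodes carry the standard `kubernetes.io/hostname` label with the node name as its value; the templates in this repository may select on different labels, and the workload keys in `PREFERRED_NODE` are hypothetical.

```python
# Hypothetical workload-to-node map derived from the table above;
# the workload keys are illustrative names, not identifiers from the repo.
PREFERRED_NODE = {
    "vllm": "khelben",
    "xtts": "elminster",
    "embeddings": "drizzt",
    "reranker": "danilo",
}


def node_affinity_for(workload: str) -> dict:
    """Build an `affinity` stanza that requires scheduling on the preferred node."""
    hostname = PREFERRED_NODE[workload]
    return {
        "nodeAffinity": {
            "requiredDuringSchedulingIgnoredDuringExecution": {
                "nodeSelectorTerms": [
                    {
                        "matchExpressions": [
                            {
                                # Assumes the standard hostname label is present.
                                "key": "kubernetes.io/hostname",
                                "operator": "In",
                                "values": [hostname],
                            }
                        ]
                    }
                ]
            }
        }
    }


affinity = node_affinity_for("xtts")
print(affinity)
```

The resulting dictionary has the same shape as the `affinity` field of a pod spec, so it can be serialized to YAML and placed on a workflow template.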

## Related

- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images
- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Handler library
328
batch-inference.yaml
Normal file
@@ -0,0 +1,328 @@
|
||||
# Batch Inference Workflow
|
||||
# Runs LLM inference on a batch of inputs
|
||||
# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference"
|
||||
---
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: WorkflowTemplate
|
||||
metadata:
|
||||
name: batch-inference
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: batch-inference
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
entrypoint: batch-inference
|
||||
serviceAccountName: argo-workflow
|
||||
|
||||
arguments:
|
||||
parameters:
|
||||
- name: input-url
|
||||
description: "URL to JSON file with inference requests"
|
||||
- name: output-url
|
||||
description: "URL to store results (S3 path)"
|
||||
value: ""
|
||||
- name: use-rag
|
||||
value: "true"
|
||||
description: "Whether to use RAG for context"
|
||||
- name: max-tokens
|
||||
value: "500"
|
||||
description: "Maximum tokens per response"
|
||||
- name: temperature
|
||||
value: "0.7"
|
||||
description: "LLM temperature"
|
||||
|
||||
templates:
|
||||
- name: batch-inference
|
||||
dag:
|
||||
tasks:
|
||||
- name: fetch-inputs
|
||||
template: fetch-input-data
|
||||
arguments:
|
||||
parameters:
|
||||
- name: input-url
|
||||
value: "{{workflow.parameters.input-url}}"
|
||||
|
||||
- name: run-inference
|
||||
template: inference
|
||||
dependencies: [fetch-inputs]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: use-rag
|
||||
value: "{{workflow.parameters.use-rag}}"
|
||||
- name: max-tokens
|
||||
value: "{{workflow.parameters.max-tokens}}"
|
||||
- name: temperature
|
||||
value: "{{workflow.parameters.temperature}}"
|
||||
artifacts:
|
||||
- name: inputs
|
||||
from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}"
|
||||
|
||||
- name: upload-results
|
||||
template: upload-output
|
||||
dependencies: [run-inference]
|
||||
when: "'{{workflow.parameters.output-url}}' != ''"
|
||||
arguments:
|
||||
parameters:
|
||||
- name: output-url
|
||||
value: "{{workflow.parameters.output-url}}"
|
||||
artifacts:
|
||||
- name: results
|
||||
from: "{{tasks.run-inference.outputs.artifacts.results}}"
|
||||
|
||||
- name: fetch-input-data
|
||||
inputs:
|
||||
parameters:
|
||||
- name: input-url
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: inputs
|
||||
path: /tmp/inputs
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import json
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
input_url = "{{inputs.parameters.input-url}}"
|
||||
output_dir = Path("/tmp/inputs")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Fetching inputs from: {input_url}")
|
||||
|
||||
if input_url.startswith("s3://"):
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
import boto3
|
||||
s3 = boto3.client("s3")
|
||||
bucket, key = input_url[5:].split("/", 1)
|
||||
s3.download_file(bucket, key, str(output_dir / "inputs.json"))
|
||||
elif input_url.startswith("http"):
|
||||
urllib.request.urlretrieve(input_url, output_dir / "inputs.json")
|
||||
else:
|
||||
print(f"Unsupported URL scheme: {input_url}")
|
||||
exit(1)
|
||||
|
||||
# Validate JSON structure
|
||||
with open(output_dir / "inputs.json") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if "requests" not in data:
|
||||
print("Error: JSON must contain 'requests' array")
|
||||
exit(1)
|
||||
|
||||
print(f"Loaded {len(data['requests'])} inference requests")
|
||||
resources:
|
||||
requests:
|
||||
memory: 256Mi
|
||||
cpu: 100m
|
||||
|
||||
- name: inference
|
||||
inputs:
|
||||
parameters:
|
||||
- name: use-rag
|
||||
- name: max-tokens
|
||||
- name: temperature
|
||||
artifacts:
|
||||
- name: inputs
|
||||
path: /tmp/inputs
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: results
|
||||
path: /tmp/results
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True)
|
||||
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
|
||||
# Configuration
|
||||
VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000"
|
||||
EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
||||
RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local"
|
||||
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
|
||||
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
use_rag = "{{inputs.parameters.use-rag}}" == "true"
|
||||
max_tokens = int("{{inputs.parameters.max-tokens}}")
|
||||
temperature = float("{{inputs.parameters.temperature}}")
|
||||
|
||||
input_dir = Path("/tmp/inputs")
|
||||
output_dir = Path("/tmp/results")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load inputs
|
||||
with open(input_dir / "inputs.json") as f:
|
||||
data = json.load(f)
|
||||
requests = data["requests"]
|
||||
|
||||
print(f"Processing {len(requests)} requests (RAG: {use_rag})")
|
||||
|
||||
# Initialize Milvus if using RAG
|
||||
collection = None
|
||||
if use_rag:
|
||||
try:
|
||||
from pymilvus import connections, Collection, utility
|
||||
connections.connect(host=MILVUS_HOST, port=19530)
|
||||
if utility.has_collection("knowledge_base"):
|
||||
collection = Collection("knowledge_base")
|
||||
collection.load()
|
||||
print("Milvus connected")
|
||||
except Exception as e:
|
||||
print(f"Milvus connection failed: {e}")
|
||||
use_rag = False
|
||||
|
||||
def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]:
|
||||
response = client.post(
|
||||
f"{EMBEDDINGS_URL}/embeddings",
|
||||
json={"input": texts, "model": "bge"}
|
||||
)
|
||||
result = response.json()
|
||||
return [d["embedding"] for d in result.get("data", [])]
|
||||
|
||||
def search_milvus(embedding: List[float]) -> List[Dict]:
|
||||
results = collection.search(
|
||||
data=[embedding],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"ef": 64}},
|
||||
limit=5,
|
||||
output_fields=["text", "source"]
|
||||
)
|
||||
docs = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
docs.append({
|
||||
"text": hit.entity.get("text", ""),
|
||||
"source": hit.entity.get("source", ""),
|
||||
"score": hit.score
|
||||
})
|
||||
return docs
|
||||
|
||||
def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]:
|
||||
response = client.post(
|
||||
f"{RERANKER_URL}/v1/rerank",
|
||||
json={"query": query, "documents": documents}
|
||||
)
|
||||
return response.json().get("results", [])
|
||||
|
||||
# Process requests
|
||||
results = []
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
for i, req in enumerate(requests):
|
||||
query = req.get("text", req.get("query", ""))
|
||||
req_id = req.get("id", str(i))
|
||||
|
||||
print(f"Processing {i+1}/{len(requests)}: {query[:50]}...")
|
||||
|
||||
context = ""
|
||||
rag_sources = []
|
||||
|
||||
if use_rag and collection:
|
||||
try:
|
||||
# Get embeddings and search
|
||||
embeddings = get_embeddings([query], client)
|
||||
if embeddings:
|
||||
docs = search_milvus(embeddings[0])
|
||||
if docs:
|
||||
doc_texts = [d["text"] for d in docs]
|
||||
reranked = rerank(query, doc_texts, client)
|
||||
sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3]
|
||||
context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs])
|
||||
rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs]
|
||||
except Exception as e:
|
||||
print(f" RAG failed: {e}")
|
||||
|
||||
# Generate response
|
||||
try:
|
||||
messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
|
||||
if context:
|
||||
messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
|
||||
else:
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
response = client.post(
|
||||
f"{VLLM_URL}/v1/chat/completions",
|
||||
json={
|
||||
"model": LLM_MODEL,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature
|
||||
}
|
||||
)
|
||||
result = response.json()
|
||||
answer = result["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
answer = f"Error: {e}"
|
||||
|
||||
results.append({
|
||||
"id": req_id,
|
||||
"query": query,
|
||||
"response": answer,
|
||||
"used_rag": bool(context),
|
||||
"rag_sources": rag_sources
|
||||
})
|
||||
|
||||
# Save results
|
||||
with open(output_dir / "results.json", "w") as f:
|
||||
json.dump({"results": results}, f, indent=2)
|
||||
|
||||
print(f"Completed {len(results)} inferences")
|
||||
|
||||
if collection:
|
||||
from pymilvus import connections
|
||||
connections.disconnect("default")
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-services-config
|
||||
resources:
|
||||
requests:
|
||||
memory: 1Gi
|
||||
cpu: 500m
|
||||
|
||||
- name: upload-output
|
||||
inputs:
|
||||
parameters:
|
||||
- name: output-url
|
||||
artifacts:
|
||||
- name: results
|
||||
path: /tmp/results
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
|
||||
import boto3
|
||||
from pathlib import Path
|
||||
|
||||
output_url = "{{inputs.parameters.output-url}}"
|
||||
results_file = Path("/tmp/results/results.json")
|
||||
|
||||
print(f"Uploading results to: {output_url}")
|
||||
|
||||
if output_url.startswith("s3://"):
|
||||
s3 = boto3.client("s3")
|
||||
bucket, key = output_url[5:].split("/", 1)
|
||||
s3.upload_file(str(results_file), bucket, key)
|
||||
print("Upload complete")
|
||||
else:
|
||||
print(f"Unsupported URL scheme: {output_url}")
|
||||
exit(1)
|
||||
resources:
|
||||
requests:
|
||||
memory: 256Mi
|
||||
cpu: 100m
|
||||
969
coqui-voice-training.yaml
Normal file
@@ -0,0 +1,969 @@
|
||||
# Coqui TTS Voice Training Workflow
|
||||
# Trains a custom voice model using Coqui TTS from audio samples
|
||||
# Triggered via NATS: ai.pipeline.trigger with pipeline="coqui-voice-training"
|
||||
---
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: WorkflowTemplate
|
||||
metadata:
|
||||
name: coqui-voice-training
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: coqui-voice-training
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
entrypoint: train-voice
|
||||
serviceAccountName: argo-workflow
|
||||
|
||||
arguments:
|
||||
parameters:
|
||||
- name: audio-source
|
||||
description: "URL to audio files (S3 bucket, HTTP, or NFS path with .wav/.mp3 files)"
|
||||
- name: transcripts-source
|
||||
description: "URL to transcripts file (CSV with audio_file,transcript columns) - leave empty to auto-transcribe"
|
||||
value: ""
|
||||
- name: voice-name
|
||||
description: "Name for the trained voice model"
|
||||
value: "custom-voice"
|
||||
- name: base-model
|
||||
description: "Base TTS model to fine-tune from"
|
||||
value: "tts_models/en/ljspeech/vits"
|
||||
- name: language
|
||||
description: "Language code (e.g., en, de, fr, es)"
|
||||
value: "en"
|
||||
- name: num-epochs
|
||||
description: "Number of training epochs"
|
||||
value: "100"
|
||||
- name: batch-size
|
||||
description: "Training batch size"
|
||||
value: "16"
|
||||
- name: learning-rate
|
||||
description: "Learning rate for training"
|
||||
value: "0.0001"
|
||||
- name: sample-rate
|
||||
description: "Target sample rate for audio (Hz)"
|
||||
value: "22050"
|
||||
- name: output-path
|
||||
description: "Path to store the trained model (S3 or NFS)"
|
||||
value: "/models/tts/custom"
|
||||
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: training-workspace
|
||||
spec:
|
||||
accessModes: ["ReadWriteMany"]
|
||||
storageClassName: nfs-slow
|
||||
resources:
|
||||
requests:
|
||||
storage: 50Gi
|
||||
|
||||
templates:
|
||||
- name: train-voice
|
||||
dag:
|
||||
tasks:
|
||||
- name: fetch-audio
|
||||
template: fetch-audio-files
|
||||
arguments:
|
||||
parameters:
|
||||
- name: audio-source
|
||||
value: "{{workflow.parameters.audio-source}}"
|
||||
|
||||
- name: fetch-transcripts
|
||||
template: fetch-transcript-file
|
||||
arguments:
|
||||
parameters:
|
||||
- name: transcripts-source
|
||||
value: "{{workflow.parameters.transcripts-source}}"
|
||||
|
||||
- name: preprocess-audio
|
||||
template: preprocess
|
||||
dependencies: [fetch-audio]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: sample-rate
|
||||
value: "{{workflow.parameters.sample-rate}}"
|
||||
artifacts:
|
||||
- name: raw-audio
|
||||
from: "{{tasks.fetch-audio.outputs.artifacts.audio-files}}"
|
||||
|
||||
- name: generate-transcripts
|
||||
template: transcribe-audio
|
||||
dependencies: [preprocess-audio, fetch-transcripts]
|
||||
when: "'{{workflow.parameters.transcripts-source}}' == ''"
|
||||
arguments:
|
||||
parameters:
|
||||
- name: language
|
||||
value: "{{workflow.parameters.language}}"
|
||||
artifacts:
|
||||
- name: audio-files
|
||||
from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"
|
||||
|
||||
- name: prepare-dataset
|
||||
template: prepare-coqui-dataset
|
||||
dependencies: [preprocess-audio, generate-transcripts, fetch-transcripts]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: voice-name
|
||||
value: "{{workflow.parameters.voice-name}}"
|
||||
- name: language
|
||||
value: "{{workflow.parameters.language}}"
|
||||
artifacts:
|
||||
- name: audio-files
|
||||
from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"
|
||||
- name: transcripts
|
||||
from: "{{=workflow.parameters['transcripts-source'] != '' ? tasks['fetch-transcripts'].outputs.artifacts.transcripts : tasks['generate-transcripts'].outputs.artifacts.transcripts}}"
|
||||
optional: true
|
||||
|
||||
- name: train-model
|
||||
template: train-tts
|
||||
dependencies: [prepare-dataset]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: voice-name
|
||||
value: "{{workflow.parameters.voice-name}}"
|
||||
- name: base-model
|
||||
value: "{{workflow.parameters.base-model}}"
|
||||
- name: language
|
||||
value: "{{workflow.parameters.language}}"
|
||||
- name: num-epochs
|
||||
value: "{{workflow.parameters.num-epochs}}"
|
||||
- name: batch-size
|
||||
value: "{{workflow.parameters.batch-size}}"
|
||||
- name: learning-rate
|
||||
value: "{{workflow.parameters.learning-rate}}"
|
||||
artifacts:
|
||||
- name: dataset
|
||||
from: "{{tasks.prepare-dataset.outputs.artifacts.dataset}}"
|
||||
|
||||
- name: export-model
|
||||
template: export-trained-model
|
||||
dependencies: [train-model]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: voice-name
|
||||
value: "{{workflow.parameters.voice-name}}"
|
||||
- name: output-path
|
||||
value: "{{workflow.parameters.output-path}}"
|
||||
artifacts:
|
||||
- name: trained-model
|
||||
from: "{{tasks.train-model.outputs.artifacts.model}}"
|
||||
|
||||
# Template: Fetch audio files from source
|
||||
- name: fetch-audio-files
|
||||
inputs:
|
||||
parameters:
|
||||
- name: audio-source
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: audio-files
|
||||
path: /tmp/audio
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import os
|
||||
import subprocess
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
source_url = "{{inputs.parameters.audio-source}}"
|
||||
output_dir = Path("/tmp/audio")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Fetching audio from: {source_url}")
|
||||
|
||||
if source_url.startswith("s3://"):
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
import boto3
|
||||
s3 = boto3.client("s3")
|
||||
bucket, prefix = source_url[5:].split("/", 1)
|
||||
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
|
||||
|
||||
audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
|
||||
for obj in response.get("Contents", []):
|
||||
key = obj["Key"]
|
||||
if Path(key).suffix.lower() in audio_extensions:
|
||||
local_path = output_dir / Path(key).name
|
||||
s3.download_file(bucket, key, str(local_path))
|
||||
print(f"Downloaded: {key}")
|
||||
|
||||
elif source_url.startswith("http"):
|
||||
# Handle single file or directory listing
|
||||
filename = source_url.split("/")[-1]
|
||||
if any(ext in filename.lower() for ext in [".wav", ".mp3", ".flac", ".zip"]):
|
||||
local_path = output_dir / filename
|
||||
urllib.request.urlretrieve(source_url, local_path)
|
||||
print(f"Downloaded: {filename}")
|
||||
|
||||
# Extract if zip
|
||||
if filename.endswith(".zip"):
|
||||
shutil.unpack_archive(local_path, output_dir)
|
||||
os.remove(local_path)
|
||||
print("Extracted zip archive")
|
||||
else:
|
||||
print(f"URL doesn't appear to be an audio file: {source_url}")
|
||||
exit(1)
|
||||
|
||||
elif source_url.startswith("/"):
|
||||
# Local/NFS path
|
||||
src_path = Path(source_url)
|
||||
if src_path.is_dir():
|
||||
audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
|
||||
for f in src_path.iterdir():
|
||||
if f.suffix.lower() in audio_extensions:
|
||||
shutil.copy(f, output_dir / f.name)
|
||||
print(f"Copied: {f.name}")
|
||||
elif src_path.is_file():
|
||||
shutil.copy(src_path, output_dir / src_path.name)
|
||||
else:
|
||||
print(f"Path not found: {source_url}")
|
||||
exit(1)
|
||||
else:
|
||||
print(f"Unsupported source: {source_url}")
|
||||
exit(1)
|
||||
|
||||
# Count files
|
||||
audio_files = list(output_dir.glob("*"))
|
||||
print(f"Total audio files: {len(audio_files)}")
|
||||
|
||||
if len(audio_files) == 0:
|
||||
print("Error: No audio files found!")
|
||||
exit(1)
|
||||
resources:
|
||||
requests:
|
||||
memory: 512Mi
|
||||
cpu: 200m
|
||||
|
||||
# Template: Fetch transcripts file
|
||||
- name: fetch-transcript-file
|
||||
inputs:
|
||||
parameters:
|
||||
- name: transcripts-source
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: transcripts
|
||||
path: /tmp/transcripts
|
||||
optional: true
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import os
|
||||
import subprocess
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
source_url = "{{inputs.parameters.transcripts-source}}"
|
||||
output_dir = Path("/tmp/transcripts")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not source_url or source_url.strip() == "":
|
||||
print("No transcripts source provided - will auto-transcribe")
|
||||
# Create empty placeholder
|
||||
(output_dir / "placeholder.txt").write_text("auto-transcribe")
|
||||
exit(0)
|
||||
|
||||
print(f"Fetching transcripts from: {source_url}")
|
||||
|
||||
if source_url.startswith("s3://"):
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
import boto3
|
||||
s3 = boto3.client("s3")
|
||||
bucket, key = source_url[5:].split("/", 1)
|
||||
local_path = output_dir / Path(key).name
|
||||
s3.download_file(bucket, key, str(local_path))
|
||||
print(f"Downloaded: {key}")
|
||||
|
||||
elif source_url.startswith("http"):
|
||||
filename = source_url.split("/")[-1] or "transcripts.csv"
|
||||
local_path = output_dir / filename
|
||||
urllib.request.urlretrieve(source_url, local_path)
|
||||
print(f"Downloaded: {filename}")
|
||||
|
||||
elif source_url.startswith("/"):
|
||||
src_path = Path(source_url)
|
||||
if src_path.is_file():
|
||||
shutil.copy(src_path, output_dir / src_path.name)
|
||||
print(f"Copied: {src_path.name}")
|
||||
else:
|
||||
print(f"File not found: {source_url}")
|
||||
exit(1)
|
||||
else:
|
||||
print(f"Unsupported source: {source_url}")
|
||||
exit(1)
|
||||
resources:
|
||||
requests:
|
||||
memory: 256Mi
|
||||
cpu: 100m
|
||||
|
||||
# Template: Preprocess audio files
|
||||
- name: preprocess
|
||||
inputs:
|
||||
parameters:
|
||||
- name: sample-rate
|
||||
artifacts:
|
||||
- name: raw-audio
|
||||
path: /tmp/raw-audio
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: processed-audio
|
||||
path: /tmp/processed-audio
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [bash]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
set -e
|
||||
|
||||
# Install ffmpeg and dependencies
|
||||
apt-get update && apt-get install -y ffmpeg > /dev/null 2>&1
|
||||
pip install -q pydub numpy soundfile
|
||||
|
||||
python3 << 'EOF'
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydub import AudioSegment
|
||||
import soundfile as sf
|
||||
|
||||
SAMPLE_RATE = int("{{inputs.parameters.sample-rate}}")
|
||||
input_dir = Path("/tmp/raw-audio")
|
||||
output_dir = Path("/tmp/processed-audio")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
|
||||
|
||||
for audio_file in input_dir.iterdir():
|
||||
if audio_file.suffix.lower() not in audio_extensions:
|
||||
continue
|
||||
|
||||
print(f"Processing: {audio_file.name}")
|
||||
|
||||
try:
|
||||
# Load audio
|
||||
audio = AudioSegment.from_file(str(audio_file))
|
||||
|
||||
# Convert to mono if stereo
|
||||
if audio.channels > 1:
|
||||
audio = audio.set_channels(1)
|
||||
|
||||
# Resample to target sample rate
|
||||
audio = audio.set_frame_rate(SAMPLE_RATE)
|
||||
|
||||
# Normalize audio
|
||||
audio = audio.normalize()
|
||||
|
||||
# Export as WAV
|
||||
output_file = output_dir / f"{audio_file.stem}.wav"
|
||||
audio.export(str(output_file), format="wav")
|
||||
print(f" -> Saved: {output_file.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" -> Error processing {audio_file.name}: {e}")
|
||||
continue
|
||||
|
||||
processed_files = list(output_dir.glob("*.wav"))
|
||||
print(f"\nProcessed {len(processed_files)} audio files")
|
||||
|
||||
if len(processed_files) == 0:
|
||||
print("Error: No files were successfully processed!")
|
||||
exit(1)
|
||||
EOF
|
||||
resources:
|
||||
requests:
|
||||
memory: 2Gi
|
||||
cpu: "1"
|
||||
|
||||
# Template: Auto-transcribe audio using Coqui STT
|
||||
- name: transcribe-audio
|
||||
inputs:
|
||||
parameters:
|
||||
- name: language
|
||||
artifacts:
|
||||
- name: audio-files
|
||||
path: /tmp/audio
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: transcripts
|
||||
path: /tmp/transcripts
|
||||
container:
|
||||
image: ghcr.io/coqui-ai/stt:latest
|
||||
command: [bash]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
set -e
|
||||
|
||||
# Install additional dependencies
|
||||
pip install -q numpy scipy
|
||||
|
||||
python3 << 'EOF'
|
||||
import csv
|
||||
import os
|
||||
import wave
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from stt import Model
|
||||
|
||||
LANGUAGE = "{{inputs.parameters.language}}"
|
||||
input_dir = Path("/tmp/audio")
|
||||
output_dir = Path("/tmp/transcripts")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Model paths - Coqui STT models are typically pre-installed in the container
|
||||
# or can be downloaded from https://coqui.ai/models
|
||||
MODEL_DIR = Path("/models/stt")
|
||||
|
||||
# Try to find model files
|
||||
model_file = None
|
||||
scorer_file = None
|
||||
|
||||
# Check for language-specific models
|
||||
lang_model_dir = MODEL_DIR / LANGUAGE
|
||||
if lang_model_dir.exists():
|
||||
for f in lang_model_dir.glob("*.tflite"):
|
||||
model_file = f
|
||||
for f in lang_model_dir.glob("*.scorer"):
|
||||
scorer_file = f
|
||||
|
||||
# Fallback to default English model location
|
||||
if model_file is None:
|
||||
default_paths = [
|
||||
MODEL_DIR / "model.tflite",
|
||||
Path("/usr/share/stt/model.tflite"),
|
||||
Path("/opt/stt/model.tflite"),
|
||||
]
|
||||
for p in default_paths:
|
||||
if p.exists():
|
||||
model_file = p
|
||||
break
|
||||
|
||||
if model_file is None:
|
||||
# Download model if not found
|
||||
print("Downloading Coqui STT model...")
|
||||
import urllib.request
|
||||
import tarfile
|
||||
|
||||
model_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.tflite"
|
||||
scorer_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.scorer"
|
||||
|
||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
model_file = MODEL_DIR / "model.tflite"
|
||||
scorer_file = MODEL_DIR / "model.scorer"
|
||||
|
||||
urllib.request.urlretrieve(model_url, model_file)
|
||||
urllib.request.urlretrieve(scorer_url, scorer_file)
|
||||
print("Model downloaded successfully")
|
||||
|
||||
print(f"Loading Coqui STT model: {model_file}")
|
||||
model = Model(str(model_file))
|
||||
|
||||
if scorer_file and scorer_file.exists():
|
||||
print(f"Loading scorer: {scorer_file}")
|
||||
model.enableExternalScorer(str(scorer_file))
|
||||
|
||||
transcripts = []
|
||||
|
||||
for audio_file in sorted(input_dir.glob("*.wav")):
|
||||
print(f"Transcribing: {audio_file.name}")
|
||||
|
||||
try:
|
||||
# Read WAV file
|
||||
with wave.open(str(audio_file), 'rb') as w:
|
||||
sample_rate = w.getframerate()
|
||||
frames = w.getnframes()
|
||||
audio_data = w.readframes(frames)
|
||||
|
||||
# Convert to int16 array
|
||||
audio = np.frombuffer(audio_data, dtype=np.int16)
|
||||
|
||||
# Resample if needed (Coqui STT expects 16kHz)
|
||||
if sample_rate != 16000:
|
||||
from scipy import signal
|
||||
audio = signal.resample(audio, int(len(audio) * 16000 / sample_rate))
|
||||
audio = audio.astype(np.int16)
|
||||
|
||||
# Run inference
|
||||
text = model.stt(audio)
|
||||
|
||||
transcripts.append({
|
||||
"audio_file": audio_file.name,
|
||||
"transcript": text
|
||||
})
|
||||
print(f" -> {text[:100] if text else '(empty)'}...")
|
||||
except Exception as e:
|
||||
print(f" -> Error: {e}")
|
||||
continue
|
||||
|
||||
# Write CSV
|
||||
csv_file = output_dir / "transcripts.csv"
|
||||
with open(csv_file, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=["audio_file", "transcript"])
|
||||
writer.writeheader()
|
||||
writer.writerows(transcripts)
|
||||
|
||||
print(f"\nTranscribed {len(transcripts)} files")
|
||||
print(f"Saved to: {csv_file}")
|
||||
EOF
|
||||
resources:
|
||||
requests:
|
||||
memory: 4Gi
|
||||
cpu: "2"
|
||||
limits:
|
||||
memory: 8Gi
|
||||
cpu: "4"
|
||||
|
||||
# Template: Prepare dataset in Coqui TTS format
|
||||
- name: prepare-coqui-dataset
|
||||
inputs:
|
||||
parameters:
|
||||
- name: voice-name
|
||||
- name: language
|
||||
artifacts:
|
||||
- name: audio-files
|
||||
path: /tmp/audio
|
||||
- name: transcripts
|
||||
path: /tmp/transcripts
|
||||
optional: true
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: dataset
|
||||
path: /tmp/dataset
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
VOICE_NAME = "{{inputs.parameters.voice-name}}"
|
||||
LANGUAGE = "{{inputs.parameters.language}}"
|
||||
|
||||
audio_dir = Path("/tmp/audio")
|
||||
transcripts_dir = Path("/tmp/transcripts")
|
||||
output_dir = Path("/tmp/dataset")
|
||||
wavs_dir = output_dir / "wavs"
|
||||
wavs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Preparing Coqui TTS dataset for voice: {VOICE_NAME}")
|
||||
|
||||
# Find transcripts file
|
||||
transcripts_file = None
|
||||
for f in transcripts_dir.glob("*.csv"):
|
||||
transcripts_file = f
|
||||
break
|
||||
|
||||
if transcripts_file is None:
|
||||
# Check for .txt files (simple format: filename|text)
|
||||
for f in transcripts_dir.glob("*.txt"):
|
||||
if f.name != "placeholder.txt":
|
||||
transcripts_file = f
|
||||
break
|
||||
|
||||
if transcripts_file is None:
|
||||
print("Error: No transcripts file found!")
|
||||
exit(1)
|
||||
|
||||
print(f"Using transcripts: {transcripts_file}")
|
||||
|
||||
# Parse transcripts
|
||||
transcripts = {}
|
||||
|
||||
if transcripts_file.suffix == ".csv":
|
||||
with open(transcripts_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
# Handle various column name conventions
|
||||
audio = row.get("audio_file") or row.get("audio") or row.get("file") or row.get("wav")
|
||||
text = row.get("transcript") or row.get("text") or row.get("sentence")
|
||||
if audio and text:
|
||||
transcripts[audio] = text.strip()
|
||||
else:
|
||||
# Simple pipe-separated format: filename|text
|
||||
with open(transcripts_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if "|" in line:
|
||||
parts = line.split("|", 1)
|
||||
if len(parts) == 2:
|
||||
transcripts[parts[0]] = parts[1]
|
||||
|
||||
print(f"Loaded {len(transcripts)} transcripts")
|
||||
|
||||
# Copy audio files and create metadata
|
||||
metadata_lines = []
|
||||
|
||||
for audio_file in sorted(audio_dir.glob("*.wav")):
|
||||
# Try to match transcript
|
||||
text = None
|
||||
for key in [audio_file.name, audio_file.stem, audio_file.stem + ".wav"]:
|
||||
if key in transcripts:
|
||||
text = transcripts[key]
|
||||
break
|
||||
|
||||
if text is None:
|
||||
print(f"Warning: No transcript for {audio_file.name}, skipping")
|
||||
continue
|
||||
|
||||
# Copy audio file
|
||||
dest_file = wavs_dir / audio_file.name
|
||||
shutil.copy(audio_file, dest_file)
|
||||
|
||||
# Add to metadata (LJSpeech format: filename|text|text)
|
||||
# Coqui uses: audio_file|text|text (normalized text optional)
|
||||
metadata_lines.append(f"{audio_file.stem}|{text}|{text}")
|
||||
|
||||
# Write metadata.csv
|
||||
metadata_file = output_dir / "metadata.csv"
|
||||
with open(metadata_file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(metadata_lines))
|
||||
|
||||
print(f"Created dataset with {len(metadata_lines)} samples")
|
||||
|
||||
# Create dataset config
|
||||
config = {
|
||||
"name": VOICE_NAME,
|
||||
"language": LANGUAGE,
|
||||
"num_samples": len(metadata_lines),
|
||||
"format": "ljspeech"
|
||||
}
|
||||
|
||||
with open(output_dir / "dataset_config.json", "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Dataset ready at: {output_dir}")
|
||||
|
||||
if len(metadata_lines) < 10:
|
||||
print("Warning: Very small dataset! At least 100 samples are recommended for good results.")
|
||||
resources:
|
||||
requests:
|
||||
memory: 1Gi
|
||||
cpu: 500m
|
||||
|
||||
# Template: Train Coqui TTS model
|
||||
- name: train-tts
|
||||
inputs:
|
||||
parameters:
|
||||
- name: voice-name
|
||||
- name: base-model
|
||||
- name: language
|
||||
- name: num-epochs
|
||||
- name: batch-size
|
||||
- name: learning-rate
|
||||
artifacts:
|
||||
- name: dataset
|
||||
path: /tmp/dataset
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: model
|
||||
path: /tmp/output
|
||||
container:
|
||||
image: ghcr.io/coqui-ai/tts:latest
|
||||
command: [bash]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
set -e
|
||||
|
||||
VOICE_NAME="{{inputs.parameters.voice-name}}"
|
||||
BASE_MODEL="{{inputs.parameters.base-model}}"
|
||||
LANGUAGE="{{inputs.parameters.language}}"
|
||||
NUM_EPOCHS="{{inputs.parameters.num-epochs}}"
|
||||
BATCH_SIZE="{{inputs.parameters.batch-size}}"
|
||||
LEARNING_RATE="{{inputs.parameters.learning-rate}}"
|
||||
|
||||
DATASET_DIR="/tmp/dataset"
|
||||
OUTPUT_DIR="/tmp/output"
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
echo "=== Coqui TTS Voice Training ==="
|
||||
echo "Voice Name: $VOICE_NAME"
|
||||
echo "Base Model: $BASE_MODEL"
|
||||
echo "Language: $LANGUAGE"
|
||||
echo "Epochs: $NUM_EPOCHS"
|
||||
echo "Batch Size: $BATCH_SIZE"
|
||||
echo "Learning Rate: $LEARNING_RATE"
|
||||
echo ""
|
||||
|
||||
# Download base model if specified for fine-tuning
|
||||
RESTORE_PATH=""
|
||||
if [ "$BASE_MODEL" != "" ] && [ "$BASE_MODEL" != "none" ]; then
|
||||
echo "Downloading base model for fine-tuning: $BASE_MODEL"
|
||||
# Use tts to download the model and get its path
|
||||
MODEL_PATH=$(python3 -c "
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
model_name = '$BASE_MODEL'
|
||||
manager = ModelManager()
|
||||
|
||||
# Download the model
|
||||
model_path, config_path, _ = manager.download_model(model_name)
|
||||
print(model_path)
|
||||
")
|
||||
RESTORE_PATH="$MODEL_PATH"
|
||||
echo "Base model path: $RESTORE_PATH"
|
||||
fi
|
||||
|
||||
# Create and run training script following Coqui docs pattern
|
||||
python3 << EOF
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Trainer: Where the magic happens
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
# Model configs
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models.vits import Vits
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
# Paths
|
||||
DATASET_DIR = Path("$DATASET_DIR")
|
||||
OUTPUT_DIR = Path("$OUTPUT_DIR")
|
||||
RESTORE_PATH = "$RESTORE_PATH" if "$RESTORE_PATH" else None
|
||||
|
||||
print(f"Dataset: {DATASET_DIR}")
|
||||
print(f"Output: {OUTPUT_DIR}")
|
||||
print(f"Restore from: {RESTORE_PATH}")
|
||||
|
||||
# Define dataset config (LJSpeech format)
|
||||
dataset_config = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
meta_file_train="metadata.csv",
|
||||
path=str(DATASET_DIR),
|
||||
language="$LANGUAGE",
|
||||
)
|
||||
|
||||
# Initialize training configuration
|
||||
config = VitsConfig(
|
||||
run_name="$VOICE_NAME",
|
||||
output_path=str(OUTPUT_DIR),
|
||||
datasets=[dataset_config],
|
||||
batch_size=int("$BATCH_SIZE"),
|
||||
eval_batch_size=max(1, int("$BATCH_SIZE") // 2),
|
||||
num_loader_workers=4,
|
||||
num_eval_loader_workers=2,
|
||||
run_eval=True,
|
||||
test_delay_epochs=5,
|
||||
epochs=int("$NUM_EPOCHS"),
|
||||
text_cleaner="phoneme_cleaners",
|
||||
use_phonemes=True,
|
||||
phoneme_language="$LANGUAGE",
|
||||
phoneme_cache_path=str(OUTPUT_DIR / "phoneme_cache"),
|
||||
compute_input_seq_cache=True,
|
||||
print_step=25,
|
||||
print_eval=False,
|
||||
mixed_precision=True,
|
||||
save_step=500,
|
||||
save_n_checkpoints=3,
|
||||
save_best_after=1000,
|
||||
lr=float("$LEARNING_RATE"),
|
||||
# Audio settings for typical voice cloning
|
||||
audio={
|
||||
"sample_rate": 22050,
|
||||
"resample": True,
|
||||
"do_trim_silence": True,
|
||||
"trim_db": 45,
|
||||
},
|
||||
)
|
||||
|
||||
# Initialize the audio processor
|
||||
# Used for feature extraction and audio I/O
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
|
||||
# Initialize the tokenizer
|
||||
# Converts text to sequences of token IDs
|
||||
tokenizer, config = TTSTokenizer.init_from_config(config)
|
||||
|
||||
# Load data samples
|
||||
# Each sample is [text, audio_file_path, speaker_name]
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
dataset_config,
|
||||
eval_split=True,
|
||||
eval_split_max_size=config.eval_split_max_size,
|
||||
eval_split_size=config.eval_split_size,
|
||||
)
|
||||
|
||||
print(f"Training samples: {len(train_samples)}")
|
||||
print(f"Eval samples: {len(eval_samples)}")
|
||||
|
||||
# Initialize the model
|
||||
model = Vits(config, ap, tokenizer, speaker_manager=None)
|
||||
|
||||
# Set up trainer arguments
|
||||
trainer_args = TrainerArgs(
|
||||
restore_path=RESTORE_PATH,
|
||||
skip_train_epoch=False,
|
||||
)
|
||||
|
||||
# Initialize the trainer
|
||||
trainer = Trainer(
|
||||
trainer_args,
|
||||
config,
|
||||
output_path=str(OUTPUT_DIR),
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
)
|
||||
|
||||
# Start training
|
||||
print("\n" + "=" * 50)
|
||||
print("Starting training...")
|
||||
print("=" * 50 + "\n")
|
||||
|
||||
trainer.fit()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Training complete!")
|
||||
print("=" * 50)
|
||||
EOF
|
||||
|
||||
echo ""
|
||||
echo "Training complete!"
|
||||
echo "Model saved to: $OUTPUT_DIR"
|
||||
ls -la "$OUTPUT_DIR"
|
||||
resources:
|
||||
requests:
|
||||
memory: 16Gi
|
||||
cpu: "4"
|
||||
nvidia.com/gpu: "1"
|
||||
limits:
|
||||
memory: 32Gi
|
||||
cpu: "8"
|
||||
nvidia.com/gpu: "1"
|
||||
volumeMounts:
|
||||
- name: training-workspace
|
||||
mountPath: /tmp/workspace
|
||||
|
||||
# Template: Export trained model
|
||||
- name: export-trained-model
|
||||
inputs:
|
||||
parameters:
|
||||
- name: voice-name
|
||||
- name: output-path
|
||||
artifacts:
|
||||
- name: trained-model
|
||||
path: /tmp/trained-model
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: exported-model
|
||||
path: /tmp/exported
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [bash]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
set -e
|
||||
|
||||
pip install -q boto3
|
||||
|
||||
python3 << 'EOF'
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
VOICE_NAME = "{{inputs.parameters.voice-name}}"
|
||||
OUTPUT_PATH = "{{inputs.parameters.output-path}}"
|
||||
|
||||
model_dir = Path("/tmp/trained-model")
|
||||
export_dir = Path("/tmp/exported")
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Exporting trained model: {VOICE_NAME}")
|
||||
print(f"Target path: {OUTPUT_PATH}")
|
||||
|
||||
# Find best checkpoint
|
||||
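# Prefer best_model files; search recursively because the trainer may write into a per-run subfolder
checkpoints = list(model_dir.rglob("best_model*.pth")) or list(model_dir.rglob("checkpoint_*.pth"))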
|
||||
if not checkpoints:
|
||||
checkpoints = list(model_dir.rglob("*.pth"))
|
||||
|
||||
if not checkpoints:
|
||||
print("Error: No model checkpoints found!")
|
||||
exit(1)
|
||||
|
||||
# Sort by modification time and get newest
|
||||
checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
|
||||
best_checkpoint = checkpoints[0]
|
||||
print(f"Using checkpoint: {best_checkpoint.name}")
|
||||
|
||||
# Create export package
|
||||
package_dir = export_dir / VOICE_NAME
|
||||
package_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy model files
|
||||
shutil.copy(best_checkpoint, package_dir / "model.pth")
|
||||
|
||||
# Copy config if exists
|
||||
config_file = best_checkpoint.parent / "config.json"
|
||||
if config_file.exists():
|
||||
shutil.copy(config_file, package_dir / "config.json")
|
||||
|
||||
# Create model info
|
||||
model_info = {
|
||||
"name": VOICE_NAME,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"checkpoint": best_checkpoint.name,
|
||||
"type": "coqui-tts"
|
||||
}
|
||||
|
||||
with open(package_dir / "model_info.json", "w") as f:
|
||||
json.dump(model_info, f, indent=2)
|
||||
|
||||
# Create tarball
|
||||
archive_name = f"{VOICE_NAME}.tar.gz"
|
||||
shutil.make_archive(
|
||||
str(export_dir / VOICE_NAME),
|
||||
"gztar",
|
||||
export_dir,
|
||||
VOICE_NAME
|
||||
)
|
||||
|
||||
print(f"Created archive: {archive_name}")
|
||||
|
||||
# Upload to destination
|
||||
if OUTPUT_PATH.startswith("s3://"):
|
||||
import boto3
|
||||
s3 = boto3.client("s3")
|
||||
bucket, key = OUTPUT_PATH[5:].split("/", 1)
|
||||
key = f"{key}/{archive_name}"
|
||||
s3.upload_file(str(export_dir / archive_name), bucket, key)
|
||||
print(f"Uploaded to: s3://{bucket}/{key}")
|
||||
|
||||
elif OUTPUT_PATH.startswith("/"):
|
||||
# Local/NFS path
|
||||
dest_path = Path(OUTPUT_PATH)
|
||||
dest_path.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(export_dir / archive_name, dest_path / archive_name)
|
||||
# Also copy uncompressed for easy access
|
||||
shutil.copytree(package_dir, dest_path / VOICE_NAME, dirs_exist_ok=True)
|
||||
print(f"Saved to: {dest_path / archive_name}")
|
||||
|
||||
print("\nExport complete!")
|
||||
print("Model package contents:")
|
||||
for f in package_dir.iterdir():
|
||||
print(f" - {f.name}")
|
||||
EOF
|
||||
resources:
|
||||
requests:
|
||||
memory: 1Gi
|
||||
cpu: 500m
|
||||
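The export step above picks a checkpoint from the trained-model artifact. The following is a minimal sketch of one way to make that choice explicit, not the workflow's exact logic. It assumes that `best_model*.pth` should win over `checkpoint_*.pth` when both exist, and that the files may sit in a run subfolder below the artifact root. The helper name `pick_checkpoint` is hypothetical.

```python
from pathlib import Path
from typing import Optional


def pick_checkpoint(model_dir: Path) -> Optional[Path]:
    """Return the preferred checkpoint under model_dir, or None if there is none.

    Preference order (an assumption for this sketch): best_model*.pth,
    then checkpoint_*.pth, then any other *.pth file. Within a group the
    most recently modified file wins.
    """
    for pattern in ("best_model*.pth", "checkpoint_*.pth", "*.pth"):
        # rglob also searches subfolders, in case the trainer wrote its
        # files into a per-run directory below model_dir.
        candidates = sorted(
            model_dir.rglob(pattern),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        if candidates:
            return candidates[0]
    return None
```

Returning `None` instead of exiting leaves the error handling to the caller, which keeps the helper easy to test outside the container.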
369
document-ingestion.yaml
Normal file
@@ -0,0 +1,369 @@
|
||||
# Document Ingestion Workflow
|
||||
# Ingests documents from a source URL into the Milvus vector database
|
||||
# Triggered via NATS: ai.pipeline.trigger with pipeline="document-ingestion"
|
||||
---
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: WorkflowTemplate
|
||||
metadata:
|
||||
name: document-ingestion
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: document-ingestion
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
entrypoint: ingest-documents
|
||||
serviceAccountName: argo-workflow
|
||||
|
||||
arguments:
|
||||
parameters:
|
||||
- name: source-url
|
||||
description: "URL to fetch documents from (s3:// or http(s)://; other schemes are rejected)"
|
||||
- name: collection-name
|
||||
value: "knowledge_base"
|
||||
description: "Milvus collection name"
|
||||
- name: chunk-size
|
||||
value: "512"
|
||||
description: "Text chunk size in characters"
|
||||
- name: chunk-overlap
|
||||
value: "50"
|
||||
description: "Overlap between chunks"
|
||||
|
||||
templates:
|
||||
- name: ingest-documents
|
||||
dag:
|
||||
tasks:
|
||||
- name: fetch-documents
|
||||
template: fetch-docs
|
||||
arguments:
|
||||
parameters:
|
||||
- name: source-url
|
||||
value: "{{workflow.parameters.source-url}}"
|
||||
|
||||
- name: chunk-documents
|
||||
template: chunk-docs
|
||||
dependencies: [fetch-documents]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: chunk-size
|
||||
value: "{{workflow.parameters.chunk-size}}"
|
||||
- name: chunk-overlap
|
||||
value: "{{workflow.parameters.chunk-overlap}}"
|
||||
artifacts:
|
||||
- name: documents
|
||||
from: "{{tasks.fetch-documents.outputs.artifacts.documents}}"
|
||||
|
||||
- name: generate-embeddings
|
||||
template: embed-docs
|
||||
dependencies: [chunk-documents]
|
||||
arguments:
|
||||
artifacts:
|
||||
- name: chunks
|
||||
from: "{{tasks.chunk-documents.outputs.artifacts.chunks}}"
|
||||
|
||||
- name: store-in-milvus
|
||||
template: store-docs
|
||||
dependencies: [generate-embeddings]
|
||||
arguments:
|
||||
parameters:
|
||||
- name: collection-name
|
||||
value: "{{workflow.parameters.collection-name}}"
|
||||
artifacts:
|
||||
- name: embeddings
|
||||
from: "{{tasks.generate-embeddings.outputs.artifacts.embeddings}}"
|
||||
|
||||
- name: fetch-docs
|
||||
inputs:
|
||||
parameters:
|
||||
- name: source-url
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: documents
|
||||
path: /tmp/documents
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import json
|
||||
import os
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
source_url = "{{inputs.parameters.source-url}}"
|
||||
output_dir = Path("/tmp/documents")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Fetching documents from: {source_url}")
|
||||
|
||||
# Handle different source types
|
||||
if source_url.startswith("s3://"):
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||
import boto3
|
||||
s3 = boto3.client("s3")
|
||||
bucket, prefix = source_url[5:].split("/", 1)
|
||||
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
|
||||
for obj in response.get("Contents", []):
|
||||
key = obj["Key"]
|
||||
local_path = output_dir / Path(key).name
|
||||
s3.download_file(bucket, key, str(local_path))
|
||||
print(f"Downloaded: {key}")
|
||||
elif source_url.startswith("http"):
|
||||
# Single file download
|
||||
filename = source_url.split("/")[-1] or "document.txt"
|
||||
local_path = output_dir / filename
|
||||
urllib.request.urlretrieve(source_url, local_path)
|
||||
print(f"Downloaded: {filename}")
|
||||
else:
|
||||
print(f"Unsupported URL scheme: {source_url}")
|
||||
exit(1)
|
||||
|
||||
# List downloaded files
|
||||
files = list(output_dir.glob("*"))
|
||||
print(f"Downloaded {len(files)} files")
|
||||
|
||||
# Create manifest
|
||||
manifest = {"files": [str(f) for f in files]}
|
||||
with open(output_dir / "manifest.json", "w") as f:
|
||||
json.dump(manifest, f)
|
||||
resources:
|
||||
requests:
|
||||
memory: 256Mi
|
||||
cpu: 100m
|
||||
|
||||
- name: chunk-docs
|
||||
inputs:
|
||||
parameters:
|
||||
- name: chunk-size
|
||||
- name: chunk-overlap
|
||||
artifacts:
|
||||
- name: documents
|
||||
path: /tmp/documents
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: chunks
|
||||
path: /tmp/chunks
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
chunk_size = int("{{inputs.parameters.chunk-size}}")
|
||||
chunk_overlap = int("{{inputs.parameters.chunk-overlap}}")
|
||||
|
||||
input_dir = Path("/tmp/documents")
|
||||
output_dir = Path("/tmp/chunks")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load manifest
|
||||
with open(input_dir / "manifest.json") as f:
|
||||
manifest = json.load(f)
|
||||
|
||||
all_chunks = []
|
||||
|
||||
for filepath in manifest["files"]:
|
||||
filepath = Path(filepath)
|
||||
if not filepath.exists():
|
||||
continue
|
||||
|
||||
print(f"Processing: {filepath.name}")
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading {filepath}: {e}")
|
||||
continue
|
||||
|
||||
# Simple chunking
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(content):
|
||||
end = start + chunk_size
|
||||
chunk = content[start:end]
|
||||
if chunk.strip():
|
||||
chunks.append({
|
||||
"text": chunk,
|
||||
"source": filepath.name,
|
||||
"chunk_index": len(chunks)
|
||||
})
|
||||
start += max(1, chunk_size - chunk_overlap)  # always advance, even if overlap >= chunk size
|
||||
|
||||
all_chunks.extend(chunks)
|
||||
print(f" Created {len(chunks)} chunks")
|
||||
|
||||
# Save chunks
|
||||
with open(output_dir / "chunks.json", "w") as f:
|
||||
json.dump({"chunks": all_chunks}, f)
|
||||
|
||||
print(f"Total chunks: {len(all_chunks)}")
|
||||
resources:
|
||||
requests:
|
||||
memory: 512Mi
|
||||
cpu: 100m
|
||||
|
||||
- name: embed-docs
|
||||
inputs:
|
||||
artifacts:
|
||||
- name: chunks
|
||||
path: /tmp/chunks
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: embeddings
|
||||
path: /tmp/embeddings
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "httpx", "-q"], check=True)
|
||||
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
|
||||
EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
||||
BATCH_SIZE = 32
|
||||
|
||||
input_dir = Path("/tmp/chunks")
|
||||
output_dir = Path("/tmp/embeddings")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load chunks
|
||||
with open(input_dir / "chunks.json") as f:
|
||||
data = json.load(f)
|
||||
chunks = data["chunks"]
|
||||
|
||||
print(f"Generating embeddings for {len(chunks)} chunks")
|
||||
|
||||
# Generate embeddings in batches
|
||||
all_embeddings = []
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
for i in range(0, len(chunks), BATCH_SIZE):
|
||||
batch = chunks[i:i+BATCH_SIZE]
|
||||
texts = [c["text"] for c in batch]
|
||||
|
||||
response = client.post(
|
||||
f"{EMBEDDINGS_URL}/embeddings",
|
||||
json={"input": texts, "model": "bge"}
|
||||
)
|
||||
response.raise_for_status()  # fail fast instead of silently writing zero embeddings
result = response.json()
|
||||
|
||||
for j, emb_data in enumerate(result.get("data", [])):
|
||||
all_embeddings.append({
|
||||
"text": batch[j]["text"],
|
||||
"source": batch[j]["source"],
|
||||
"chunk_index": batch[j]["chunk_index"],
|
||||
"embedding": emb_data["embedding"]
|
||||
})
|
||||
|
||||
print(f" Processed batch {i//BATCH_SIZE + 1}/{(len(chunks)-1)//BATCH_SIZE + 1}")
|
||||
|
||||
# Save embeddings
|
||||
with open(output_dir / "embeddings.json", "w") as f:
|
||||
json.dump({"embeddings": all_embeddings}, f)
|
||||
|
||||
print(f"Generated {len(all_embeddings)} embeddings")
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-services-config
|
||||
resources:
|
||||
requests:
|
||||
memory: 1Gi
|
||||
cpu: 200m
|
||||
|
||||
- name: store-docs
|
||||
inputs:
|
||||
parameters:
|
||||
- name: collection-name
|
||||
artifacts:
|
||||
- name: embeddings
|
||||
path: /tmp/embeddings
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
import subprocess
|
||||
subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
|
||||
|
||||
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
|
||||
MILVUS_PORT = 19530
|
||||
COLLECTION_NAME = "{{inputs.parameters.collection-name}}"
|
||||
EMBEDDING_DIM = 1024 # BGE-large dimension
|
||||
|
||||
input_dir = Path("/tmp/embeddings")
|
||||
|
||||
# Load embeddings
|
||||
with open(input_dir / "embeddings.json") as f:
|
||||
data = json.load(f)
|
||||
embeddings = data["embeddings"]
|
||||
|
||||
print(f"Storing {len(embeddings)} embeddings in Milvus")
|
||||
|
||||
# Connect to Milvus
|
||||
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
print("Connected to Milvus")
|
||||
|
||||
# Create collection if not exists
|
||||
if not utility.has_collection(COLLECTION_NAME):
|
||||
fields = [
|
||||
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
|
||||
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)
|
||||
]
|
||||
schema = CollectionSchema(fields, description="Knowledge base documents")
|
||||
collection = Collection(COLLECTION_NAME, schema)
|
||||
|
||||
# Create HNSW index
|
||||
index_params = {
|
||||
"metric_type": "COSINE",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 16, "efConstruction": 256}
|
||||
}
|
||||
collection.create_index("embedding", index_params)
|
||||
print(f"Created collection: {COLLECTION_NAME}")
|
||||
else:
|
||||
collection = Collection(COLLECTION_NAME)
|
||||
print(f"Using existing collection: {COLLECTION_NAME}")
|
||||
|
||||
# Insert data in batches
|
||||
BATCH_SIZE = 100
|
||||
for i in range(0, len(embeddings), BATCH_SIZE):
|
||||
batch = embeddings[i:i+BATCH_SIZE]
|
||||
|
||||
data = [
|
||||
[e["text"] for e in batch],
|
||||
[e["source"] for e in batch],
|
||||
[e["embedding"] for e in batch]
|
||||
]
|
||||
|
||||
collection.insert(data)
|
||||
print(f" Inserted batch {i//BATCH_SIZE + 1}/{(len(embeddings)-1)//BATCH_SIZE + 1}")
|
||||
|
||||
# Flush to ensure data is persisted
|
||||
collection.flush()
|
||||
print(f"Successfully stored {len(embeddings)} documents")
|
||||
|
||||
connections.disconnect("default")
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-services-config
|
||||
resources:
|
||||
requests:
|
||||
memory: 512Mi
|
||||
cpu: 100m
|
||||
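The `chunk-docs` step above splits each document into fixed-size character windows that overlap by `chunk-overlap` characters. A chunker of this kind only makes progress when the overlap is smaller than the chunk size, so the sketch below validates that up front. It is an illustrative standalone version under that assumption; the function name `chunk_text` is hypothetical and not part of the workflow.

```python
from typing import Dict, List


def chunk_text(content: str, source: str, chunk_size: int = 512,
               chunk_overlap: int = 50) -> List[Dict[str, object]]:
    """Split content into overlapping character windows.

    Produces the same record shape as the workflow step
    (text, source, chunk_index) and skips whitespace-only windows.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # distance between window starts
    chunks: List[Dict[str, object]] = []
    for start in range(0, len(content), step):
        piece = content[start:start + chunk_size]
        if piece.strip():
            chunks.append({
                "text": piece,
                "source": source,
                "chunk_index": len(chunks),
            })
    return chunks
```

With the workflow defaults (512 and 50) each window starts 462 characters after the previous one.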
270
eventsource-kfp.yaml
Normal file
@@ -0,0 +1,270 @@
|
||||
# Argo Events - EventSource for KFP and NATS integration
|
||||
# Enables bidirectional triggering between Argo Workflows and Kubeflow Pipelines
|
||||
---
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: EventSource
|
||||
metadata:
|
||||
name: kfp-events
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: kfp-events
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
service:
|
||||
ports:
|
||||
- name: webhook
|
||||
port: 12000
|
||||
targetPort: 12000
|
||||
# Webhook to receive KFP pipeline completion events
|
||||
webhook:
|
||||
kfp-completion:
|
||||
port: "12000"
|
||||
endpoint: /kfp/completion
|
||||
method: POST
|
||||
kfp-failure:
|
||||
port: "12000"
|
||||
endpoint: /kfp/failure
|
||||
method: POST
|
||||
# NATS for receiving pipeline trigger requests
|
||||
nats:
|
||||
pipeline-trigger:
|
||||
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||
subject: ai.pipeline.trigger
|
||||
jsonBody: true
|
||||
argo-trigger:
|
||||
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||
subject: ai.argo.trigger
|
||||
jsonBody: true
|
||||
kfp-trigger:
|
||||
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||
subject: ai.kfp.trigger
|
||||
jsonBody: true
|
||||
---
|
||||
# Sensor for handling KFP completion events
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Sensor
|
||||
metadata:
|
||||
name: kfp-completion-sensor
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: kfp-completion-sensor
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
dependencies:
|
||||
- name: kfp-success
|
||||
eventSourceName: kfp-events
|
||||
eventName: kfp-completion
|
||||
filters:
|
||||
data:
|
||||
- path: body.status
|
||||
type: string
|
||||
value:
|
||||
- "SUCCEEDED"
|
||||
- name: kfp-failure
|
||||
eventSourceName: kfp-events
|
||||
eventName: kfp-failure
|
||||
triggers:
|
||||
# On KFP success, publish to NATS
|
||||
- template:
|
||||
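name: notify-kfp-success
conditions: "kfp-success"  # fire on the success event alone; without this, all dependencies must be met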
|
||||
nats:
|
||||
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||
subject: ai.pipeline.status.completed
|
||||
payload:
|
||||
- src:
|
||||
dependencyName: kfp-success
|
||||
dataKey: body.run_id
|
||||
dest: run_id
|
||||
- src:
|
||||
dependencyName: kfp-success
|
||||
dataKey: body.pipeline_name
|
||||
dest: pipeline_name
|
||||
- src:
|
||||
dependencyName: kfp-success
|
||||
dataKey: body.status
|
||||
dest: status
|
||||
retryStrategy:
|
||||
steps: 3
|
||||
# On KFP failure, trigger recovery workflow
|
||||
- template:
|
||||
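name: kfp-failure-recovery
conditions: "kfp-failure"  # fire on the failure event alone; without this, all dependencies must be met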
|
||||
k8s:
|
||||
operation: create
|
||||
source:
|
||||
resource:
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Workflow
|
||||
metadata:
|
||||
generateName: kfp-failure-handler-
|
||||
namespace: ai-ml
|
||||
spec:
|
||||
entrypoint: notify-failure
|
||||
arguments:
|
||||
parameters:
|
||||
- name: run-id
|
||||
- name: pipeline-name
|
||||
- name: error-message
|
||||
templates:
|
||||
- name: notify-failure
|
||||
inputs:
|
||||
parameters:
|
||||
- name: run-id
|
||||
- name: pipeline-name
|
||||
- name: error-message
|
||||
script:
|
||||
image: python:3.13-slim
|
||||
command: [python]
|
||||
source: |
|
||||
import subprocess
|
||||
import sys
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import nats
|
||||
|
||||
async def notify():
|
||||
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
|
||||
await nc.publish(
|
||||
"ai.pipeline.status.failed",
|
||||
json.dumps({
|
||||
"run_id": "{{inputs.parameters.run-id}}",
|
||||
"pipeline_name": "{{inputs.parameters.pipeline-name}}",
|
||||
"error": "{{inputs.parameters.error-message}}",
|
||||
"source": "kubeflow"
|
||||
}).encode()
|
||||
)
|
||||
await nc.close()
|
||||
|
||||
asyncio.run(notify())
|
||||
parameters:
|
||||
- src:
|
||||
dependencyName: kfp-failure
|
||||
dataKey: body.run_id
|
||||
dest: spec.arguments.parameters.0.value
|
||||
- src:
|
||||
dependencyName: kfp-failure
|
||||
dataKey: body.pipeline_name
|
||||
dest: spec.arguments.parameters.1.value
|
||||
- src:
|
||||
dependencyName: kfp-failure
|
||||
dataKey: body.error
|
||||
dest: spec.arguments.parameters.2.value
|
||||
retryStrategy:
|
||||
steps: 3
|
||||
---
|
||||
# Sensor for NATS-triggered Argo Workflows
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Sensor
|
||||
metadata:
|
||||
name: nats-argo-sensor
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: nats-argo-sensor
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
dependencies:
|
||||
- name: argo-trigger
|
||||
eventSourceName: kfp-events
|
||||
eventName: argo-trigger
|
||||
triggers:
|
||||
- template:
|
||||
name: trigger-argo-workflow
|
||||
k8s:
|
||||
operation: create
|
||||
source:
|
||||
resource:
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Workflow
|
||||
metadata:
|
||||
generateName: nats-triggered-
|
||||
namespace: ai-ml
|
||||
spec:
|
||||
workflowTemplateRef:
|
||||
name: placeholder
|
||||
arguments:
|
||||
parameters: []
|
||||
parameters:
|
||||
- src:
|
||||
dependencyName: argo-trigger
|
||||
dataKey: body.template
|
||||
dest: spec.workflowTemplateRef.name
|
||||
- src:
|
||||
dependencyName: argo-trigger
|
||||
dataKey: body.parameters
|
||||
dest: spec.arguments.parameters
|
||||
retryStrategy:
|
||||
steps: 3
|
||||
---
|
||||
# Sensor for NATS-triggered KFP Pipelines
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Sensor
|
||||
metadata:
|
||||
name: nats-kfp-sensor
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: nats-kfp-sensor
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
dependencies:
|
||||
- name: kfp-trigger
|
||||
eventSourceName: kfp-events
|
||||
eventName: kfp-trigger
|
||||
triggers:
|
||||
# Trigger KFP via Argo Workflow (uses kfp-trigger template)
|
||||
- template:
|
||||
name: trigger-kfp-via-argo
|
||||
k8s:
|
||||
operation: create
|
||||
source:
|
||||
resource:
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Workflow
|
||||
metadata:
|
||||
generateName: kfp-via-nats-
|
||||
namespace: ai-ml
|
||||
spec:
|
||||
workflowTemplateRef:
|
||||
name: kfp-trigger
|
||||
arguments:
|
||||
parameters:
|
||||
- name: pipeline-id
|
||||
value: ""
|
||||
- name: pipeline-params
|
||||
value: "{}"
|
||||
- name: wait-for-completion
|
||||
value: "true"
|
||||
parameters:
|
||||
- src:
|
||||
dependencyName: kfp-trigger
|
||||
dataKey: body.pipeline_id
|
||||
dest: spec.arguments.parameters.0.value
|
||||
- src:
|
||||
dependencyName: kfp-trigger
|
||||
dataKey: body.parameters
|
||||
dest: spec.arguments.parameters.1.value
|
||||
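operation: overwrite  # "stringify" is not a supported operation (prepend, append, overwrite); publishers should send body.parameters as a JSON-encoded string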
|
||||
- src:
|
||||
dependencyName: kfp-trigger
|
||||
dataKey: body.wait
|
||||
dest: spec.arguments.parameters.2.value
|
||||
operation: overwrite  # "stringify" is not a supported operation; overwrite is the default
|
||||
retryStrategy:
|
||||
steps: 3
|
||||
---
|
||||
# Service for the EventSource webhook
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: kfp-events-webhook
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app.kubernetes.io/name: kfp-events
|
||||
app.kubernetes.io/part-of: llm-workflows
|
||||
spec:
|
||||
selector:
|
||||
eventsource-name: kfp-events
|
||||
ports:
|
||||
- name: webhook
|
||||
port: 12000
|
||||
targetPort: 12000
|
||||
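The `nats-argo-sensor` above reads `body.template` and `body.parameters` from messages on `ai.argo.trigger` and copies them into the generated Workflow. The sketch below shows what a publisher might look like. It assumes `parameters` is sent as a list of `{name, value}` objects matching `spec.arguments.parameters`, and that the sensor forwards that list unchanged; neither is confirmed by the manifests. The helper names are hypothetical, and `nats-py` is imported lazily so that the payload builder can be used without it.

```python
import json
from typing import Dict

NATS_URL = "nats://nats.ai-ml.svc.cluster.local:4222"  # URL used by the EventSource
SUBJECT = "ai.argo.trigger"  # subject of the `argo-trigger` event


def build_trigger_payload(template: str, params: Dict[str, str]) -> bytes:
    """Encode a trigger message for the nats-argo-sensor.

    The parameter list shape is an assumption based on the sensor's
    `dest: spec.arguments.parameters` mapping.
    """
    body = {
        "template": template,
        "parameters": [{"name": k, "value": v} for k, v in params.items()],
    }
    return json.dumps(body).encode("utf-8")


async def publish_trigger(template: str, params: Dict[str, str]) -> None:
    """Publish one trigger message; requires `pip install nats-py`."""
    import nats  # lazy import so build_trigger_payload works without nats-py

    nc = await nats.connect(NATS_URL)
    try:
        await nc.publish(SUBJECT, build_trigger_payload(template, params))
        await nc.flush()  # make sure the message left the client buffer
    finally:
        await nc.close()
```

From a pod inside the cluster this could be run as `asyncio.run(publish_trigger("document-ingestion", {"source-url": "https://example.com/doc.txt"}))`, where the URL is a placeholder.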
555
hybrid-ml-training.yaml
Normal file
@@ -0,0 +1,555 @@
# Hybrid ML Training Workflow
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: hybrid-ml-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: hybrid-ml-training
    app.kubernetes.io/part-of: llm-workflows
  annotations:
    description: |
      Demonstrates hybrid Argo+KFP workflow:
      - Argo handles orchestration, branching, retry logic
      - KFP pipelines handle ML-specific operations (with caching)
      - NATS for status updates to frontends
spec:
  entrypoint: hybrid-training
  serviceAccountName: argo-workflow

  # Artifact repository for model checkpoints
  artifactRepositoryRef:
    configMap: artifact-repository
    key: default

  arguments:
    parameters:
      - name: collection-name
        description: "Milvus collection to pull training data from"
        value: "dnd_text_embeddings"
      - name: model-name
        description: "Base model for fine-tuning"
        value: "mistralai/Mistral-7B-v0.3"
      - name: lora-rank
        description: "LoRA rank (higher = more params)"
        value: "16"
      - name: epochs
        description: "Training epochs"
        value: "3"
      - name: batch-size
        description: "Training batch size"
        value: "4"
      - name: output-path
        description: "S3 path for model output"
        value: "s3://models/lora-adapters"
      - name: notify-nats
        description: "Publish status to NATS"
        value: "true"

  # Volumes for GPU caching
  volumes:
    - name: model-cache
      persistentVolumeClaim:
        claimName: model-cache-pvc
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: 16Gi

  templates:
    # Main DAG orchestrating the workflow
    - name: hybrid-training
      dag:
        tasks:
          - name: notify-start
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "started", "pipeline": "hybrid-ml-training"}'

          - name: prepare-data
            template: extract-training-data
            arguments:
              parameters:
                - name: collection-name
                  value: "{{workflow.parameters.collection-name}}"

          - name: validate-data
            template: validate-dataset
            dependencies: [prepare-data]
            arguments:
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"

          # KFP Pipeline: Run embedding generation if needed
          - name: generate-embeddings
            template: trigger-kfp
            dependencies: [validate-data]
            when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
            arguments:
              parameters:
                - name: pipeline-id
                  value: "embedding-generation"
                - name: params
                  value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'

          # Training step (runs on GPU)
          - name: train-lora
            template: lora-training
            dependencies: [validate-data, generate-embeddings]
            arguments:
              parameters:
                - name: model-name
                  value: "{{workflow.parameters.model-name}}"
                - name: lora-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: epochs
                  value: "{{workflow.parameters.epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"

          # Evaluate model
          - name: evaluate
            template: evaluate-model
            dependencies: [train-lora]
            arguments:
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"

          # Branch based on evaluation results
          - name: check-quality
            template: quality-gate
            dependencies: [evaluate]
            arguments:
              parameters:
                - name: eval-score
                  value: "{{tasks.evaluate.outputs.parameters.score}}"
                - name: threshold
                  value: "0.7"

          # If quality is good, upload to S3
          - name: upload-model
            template: upload-to-s3
            dependencies: [check-quality]
            when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
            arguments:
              parameters:
                - name: output-path
                  value: "{{workflow.parameters.output-path}}"
              artifacts:
                - name: adapter
                  from: "{{tasks.train-lora.outputs.artifacts.adapter}}"

          # If quality is poor, trigger retraining with different params
          - name: retry-training
            template: adjust-and-retry
            dependencies: [check-quality]
            when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
            arguments:
              parameters:
                - name: current-rank
                  value: "{{workflow.parameters.lora-rank}}"
                - name: current-epochs
                  value: "{{workflow.parameters.epochs}}"

          - name: notify-complete
            template: nats-notify
            when: "{{workflow.parameters.notify-nats}} == true"
            dependencies: [upload-model]
            arguments:
              parameters:
                - name: subject
                  value: "ai.pipeline.status.{{workflow.name}}"
                - name: message
                  value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'

    # Extract training data from Milvus
    - name: extract-training-data
      inputs:
        parameters:
          - name: collection-name
      outputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
        parameters:
          - name: data-path
            valueFrom:
              path: /tmp/data-path
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"])

          from pymilvus import connections, Collection
          import pandas as pd
          from pathlib import Path

          connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)

          collection = Collection("{{inputs.parameters.collection-name}}")
          collection.load()

          # Query all training samples
          results = collection.query(
              expr="source != ''",
              output_fields=["text", "source", "metadata"],
              limit=10000
          )

          # Convert to training format
          df = pd.DataFrame(results)
          output_dir = Path("/tmp/dataset")
          output_dir.mkdir(parents=True, exist_ok=True)

          # Save as parquet for efficient loading
          df.to_parquet(output_dir / "train.parquet")

          print(f"Extracted {len(df)} samples")
          with open("/tmp/data-path", "w") as f:
              f.write(str(output_dir / "train.parquet"))

    # Validate dataset
    - name: validate-dataset
      inputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
      outputs:
        parameters:
          - name: needs-embeddings
            valueFrom:
              path: /tmp/needs-embeddings
          - name: sample-count
            valueFrom:
              path: /tmp/sample-count
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"])

          import pandas as pd
          from pathlib import Path

          dataset_dir = Path("/tmp/dataset")
          parquet_files = list(dataset_dir.glob("*.parquet"))

          if not parquet_files:
              raise ValueError("No parquet files found in dataset")

          df = pd.read_parquet(parquet_files[0])
          sample_count = len(df)
          print(f"Dataset contains {sample_count} samples")

          # Check if embeddings column exists
          needs_embeddings = "embedding" not in df.columns

          with open("/tmp/needs-embeddings", "w") as f:
              f.write(str(needs_embeddings).lower())

          with open("/tmp/sample-count", "w") as f:
              f.write(str(sample_count))

    # Trigger KFP pipeline
    - name: trigger-kfp
      inputs:
        parameters:
          - name: pipeline-id
          - name: params
      outputs:
        parameters:
          - name: run-id
            valueFrom:
              path: /tmp/run-id
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          import json
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])

          from kfp import Client

          client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
          params = json.loads('''{{inputs.parameters.params}}''')
          pipeline_ref = "{{inputs.parameters.pipeline-id}}"

          # create_run_from_pipeline_func() expects a pipeline function and has no
          # pipeline_id argument, so submit the uploaded pipeline with run_pipeline().
          # Accept either a pipeline name or an ID; get_pipeline_id() returns None
          # when no pipeline has that name.
          pipeline_id = client.get_pipeline_id(pipeline_ref) or pipeline_ref

          # Assumption: runs go to the "Default" experiment. create_experiment()
          # returns the existing experiment when one with that name already exists.
          experiment = client.create_experiment(name="Default")

          run = client.run_pipeline(
              experiment_id=experiment.experiment_id,
              job_name=f"{pipeline_ref}-argo",
              pipeline_id=pipeline_id,
              params=params
          )

          print(f"Triggered KFP pipeline: {run.run_id}")
          with open("/tmp/run-id", "w") as f:
              f.write(run.run_id)

    # LoRA training (GPU)
    - name: lora-training
      inputs:
        parameters:
          - name: model-name
          - name: lora-rank
          - name: epochs
          - name: batch-size
        artifacts:
          - name: dataset
            path: /data/dataset
      outputs:
        artifacts:
          - name: adapter
            path: /output/adapter
          - name: logs
            path: /output/logs
      podSpecPatch: |
        containers:
          - name: main
            resources:
              requests:
                amd.com/gpu: 1
              limits:
                amd.com/gpu: 1
      script:
        image: ghcr.io/billy-davies-2/lora-trainer:latest
        command: [python]
        env:
          - name: HF_HOME
            value: /cache/huggingface
          - name: TRANSFORMERS_CACHE
            value: /cache/huggingface
        volumeMounts:
          - name: model-cache
            mountPath: /cache
          - name: shm
            mountPath: /dev/shm
        source: |
          import os
          import sys
          from pathlib import Path

          # Training script
          from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
          from peft import LoraConfig, get_peft_model
          from datasets import load_dataset
          from trl import SFTTrainer

          model_name = "{{inputs.parameters.model-name}}"
          lora_rank = int("{{inputs.parameters.lora-rank}}")
          epochs = int("{{inputs.parameters.epochs}}")
          batch_size = int("{{inputs.parameters.batch-size}}")

          # Load dataset
          dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")

          # Load model
          model = AutoModelForCausalLM.from_pretrained(
              model_name,
              torch_dtype="auto",
              device_map="auto"
          )
          tokenizer = AutoTokenizer.from_pretrained(model_name)

          # Configure LoRA
          lora_config = LoraConfig(
              r=lora_rank,
              lora_alpha=lora_rank * 2,
              target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
              lora_dropout=0.05,
              bias="none",
              task_type="CAUSAL_LM"
          )

          model = get_peft_model(model, lora_config)

          # Training
          training_args = TrainingArguments(
              output_dir="/output/adapter",
              num_train_epochs=epochs,
              per_device_train_batch_size=batch_size,
              gradient_accumulation_steps=4,
              learning_rate=2e-4,
              logging_dir="/output/logs",
              save_strategy="epoch"
          )

          trainer = SFTTrainer(
              model=model,
              train_dataset=dataset,
              tokenizer=tokenizer,
              args=training_args,
              dataset_text_field="text"
          )

          trainer.train()
          trainer.save_model("/output/adapter")
          print("Training complete!")

    # Evaluate model
    - name: evaluate-model
      inputs:
        artifacts:
          - name: adapter
            path: /input/adapter
      outputs:
        parameters:
          - name: score
            valueFrom:
              path: /tmp/score
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])

          import httpx
          import json

          # Run evaluation using vLLM with the adapter
          test_prompts = [
              "What is the capital of France?",
              "Explain machine learning in simple terms.",
              "Write a haiku about coding."
          ]

          scores = []
          with httpx.Client(timeout=120.0) as client:
              for prompt in test_prompts:
                  response = client.post(
                      "http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
                      json={
                          "model": "local-adapter",
                          "messages": [{"role": "user", "content": prompt}],
                          "max_tokens": 200
                      }
                  )
                  # Fail fast on HTTP errors instead of hitting a KeyError below
                  response.raise_for_status()
                  # Placeholder scoring: longer responses score higher, capped at 1.0
                  result = response.json()
                  content = result["choices"][0]["message"]["content"]
                  score = min(1.0, len(content) / 100)
                  scores.append(score)

          avg_score = sum(scores) / len(scores)
          print(f"Average evaluation score: {avg_score}")

          with open("/tmp/score", "w") as f:
              f.write(str(round(avg_score, 3)))

    # Quality gate
    - name: quality-gate
      inputs:
        parameters:
          - name: eval-score
          - name: threshold
      outputs:
        parameters:
          - name: passed
            valueFrom:
              path: /tmp/passed
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          score = float("{{inputs.parameters.eval-score}}")
          threshold = float("{{inputs.parameters.threshold}}")

          passed = score >= threshold
          print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")

          with open("/tmp/passed", "w") as f:
              f.write(str(passed).lower())

    # Upload to S3
    - name: upload-to-s3
      inputs:
        parameters:
          - name: output-path
        artifacts:
          - name: adapter
            path: /input/adapter
      script:
        image: amazon/aws-cli:latest
        command: [bash]
        env:
          - name: AWS_ACCESS_KEY_ID
            valueFrom:
              secretKeyRef:
                name: s3-credentials
                key: access-key
          - name: AWS_SECRET_ACCESS_KEY
            valueFrom:
              secretKeyRef:
                name: s3-credentials
                key: secret-key
          - name: AWS_ENDPOINT_URL
            value: "https://quobjects.billy.davies.cloud"
        source: |
          aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
          echo "Uploaded adapter to {{inputs.parameters.output-path}}"

    # Adjust parameters and retry
    - name: adjust-and-retry
      inputs:
        parameters:
          - name: current-rank
          - name: current-epochs
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          current_rank = int("{{inputs.parameters.current-rank}}")
          current_epochs = int("{{inputs.parameters.current-epochs}}")

          # Increase rank and epochs for next attempt
          new_rank = min(64, current_rank * 2)
          new_epochs = current_epochs + 2

          print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
          print("TODO: Trigger new workflow with adjusted parameters")

    # NATS notification
    - name: nats-notify
      inputs:
        parameters:
          - name: subject
          - name: message
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])

          import asyncio
          import nats

          async def notify():
              nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
              await nc.publish(
                  "{{inputs.parameters.subject}}",
                  b'''{{inputs.parameters.message}}'''
              )
              await nc.close()

          asyncio.run(notify())
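The branching in `hybrid-ml-training.yaml` comes down to three small rules: the placeholder score in `evaluate-model`, the threshold check in `quality-gate`, and the parameter bump in `adjust-and-retry`. The sketch below restates them as plain functions so they can be checked outside the cluster. The names `TrainingParams`, `placeholder_score`, `passes_quality_gate`, and `next_attempt` are hypothetical and do not appear in the workflow; the constants (100 characters, threshold 0.7, rank cap 64, +2 epochs) are copied from the templates.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingParams:
    """Hypothetical container for the two values adjust-and-retry changes."""
    lora_rank: int
    epochs: int


def placeholder_score(content: str) -> float:
    # Same placeholder as evaluate-model: response length over 100, capped at 1.0.
    return min(1.0, len(content) / 100)


def passes_quality_gate(score: float, threshold: float = 0.7) -> bool:
    # Same comparison as quality-gate: pass when the score reaches the threshold.
    return score >= threshold


def next_attempt(params: TrainingParams, max_rank: int = 64) -> TrainingParams:
    # Same adjustment as adjust-and-retry: double the rank (capped), add two epochs.
    return TrainingParams(
        lora_rank=min(max_rank, params.lora_rank * 2),
        epochs=params.epochs + 2,
    )
```

Because the score only measures response length, any reply of 70 characters or more clears the default gate, which is why the workflow itself labels the scoring as a placeholder.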
237
kfp-integration.yaml
Normal file
@@ -0,0 +1,237 @@
# Argo Workflows + Kubeflow Pipelines Integration
# This template allows Argo Workflows to trigger KFP pipelines and vice versa
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: kfp-trigger
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: kfp-trigger
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: trigger-kfp-pipeline
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: pipeline-id
        description: "Kubeflow Pipeline ID or name"
      - name: pipeline-params
        description: "JSON object of pipeline parameters"
        value: "{}"
      - name: experiment-name
        description: "KFP Experiment to use"
        value: "Default"
      - name: wait-for-completion
        description: "Wait for pipeline to complete"
        value: "true"

  templates:
    - name: trigger-kfp-pipeline
      steps:
        - - name: submit-run
            template: submit-kfp-run
            arguments:
              parameters:
                - name: pipeline-id
                  value: "{{workflow.parameters.pipeline-id}}"
                - name: pipeline-params
                  value: "{{workflow.parameters.pipeline-params}}"
                - name: experiment-name
                  value: "{{workflow.parameters.experiment-name}}"

        - - name: wait-completion
            template: wait-for-kfp
            when: "{{workflow.parameters.wait-for-completion}} == true"
            arguments:
              parameters:
                - name: run-id
                  value: "{{steps.submit-run.outputs.parameters.run-id}}"

    - name: submit-kfp-run
      inputs:
        parameters:
          - name: pipeline-id
          - name: pipeline-params
          - name: experiment-name
      outputs:
        parameters:
          - name: run-id
            valueFrom:
              path: /tmp/run-id
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import json
          import subprocess
          import sys
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])

          from kfp import Client

          KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"

          client = Client(host=KUBEFLOW_HOST)

          pipeline_id = "{{inputs.parameters.pipeline-id}}"
          params = json.loads('''{{inputs.parameters.pipeline-params}}''')
          experiment_name = "{{inputs.parameters.experiment-name}}"

          # Get or create experiment
          try:
              experiment = client.get_experiment(experiment_name=experiment_name)
          except Exception:
              experiment = client.create_experiment(name=experiment_name)

          # Get pipeline by ID, falling back to a lookup by name
          try:
              pipeline = client.get_pipeline(pipeline_id)
          except Exception:
              # get_pipeline_id() resolves a pipeline name to its ID (None if not found)
              resolved_id = client.get_pipeline_id(pipeline_id)
              if resolved_id is None:
                  raise ValueError(f"Pipeline not found: {pipeline_id}")
              pipeline = client.get_pipeline(resolved_id)

          # Create run
          run = client.run_pipeline(
              experiment_id=experiment.experiment_id,
              job_name=f"{pipeline.display_name}-argo-{pipeline_id[:8]}",
              pipeline_id=pipeline.pipeline_id,
              params=params
          )

          print(f"Submitted KFP run: {run.run_id}")
          with open("/tmp/run-id", "w") as f:
              f.write(run.run_id)

    - name: wait-for-kfp
      inputs:
        parameters:
          - name: run-id
      outputs:
        parameters:
          - name: status
            valueFrom:
              path: /tmp/status
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import subprocess
          import sys
          import time
          subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])

          from kfp import Client

          KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
          run_id = "{{inputs.parameters.run-id}}"

          client = Client(host=KUBEFLOW_HOST)

          while True:
              run = client.get_run(run_id)
              # kfp 2.x returns the run object directly; its state is on run.state
              # (the v1-style run.run.status attribute does not exist here)
              state = run.state

              print(f"Run {run_id} status: {state}")

              if state in ["SUCCEEDED", "SKIPPED"]:
                  with open("/tmp/status", "w") as f:
                      f.write("SUCCEEDED")
                  break
              # The v2 API spells the cancelled state "CANCELED"; keep the old names too
              elif state in ["FAILED", "ERROR", "CANCELED", "CANCELLED"]:
                  with open("/tmp/status", "w") as f:
                      f.write(state)
                  raise Exception(f"Pipeline failed with status: {state}")

              time.sleep(30)

---
# WorkflowTemplate for running KFP pipeline components as Argo steps
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: kfp-component-runner
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: kfp-component-runner
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: run-component
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: component-name
        description: "Name of the KFP component to run"
      - name: component-params
        description: "JSON parameters for the component"
        value: "{}"

  templates:
    - name: run-component
      inputs:
        parameters:
          - name: component-name
          - name: component-params
      outputs:
        parameters:
          - name: result
            valueFrom:
              path: /tmp/result.json
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import json
          import subprocess
          import sys
          subprocess.check_call([
              sys.executable, "-m", "pip", "install", "-q",
              "httpx", "pymilvus"
          ])

          import httpx

          component_name = "{{inputs.parameters.component-name}}"
          params = json.loads('''{{inputs.parameters.component-params}}''')

          # Component implementations (mirrors KFP components)
          COMPONENTS = {
              "transcribe_audio": {
                  "url": "http://whisper-predictor.ai-ml.svc.cluster.local",
                  "endpoint": "/v1/audio/transcriptions"
              },
              "generate_embeddings": {
                  "url": "http://embeddings-predictor.ai-ml.svc.cluster.local",
                  "endpoint": "/embeddings"
              },
              "generate_response": {
                  "url": "http://llm-draft.ai-ml.svc.cluster.local:8000",
                  "endpoint": "/v1/chat/completions"
              },
              "synthesize_speech": {
                  "url": "http://tts-predictor.ai-ml.svc.cluster.local",
                  "endpoint": "/v1/audio/speech"
              }
          }

          if component_name not in COMPONENTS:
              raise ValueError(f"Unknown component: {component_name}")

          config = COMPONENTS[component_name]
          with httpx.Client(timeout=120.0) as client:
              response = client.post(
                  f"{config['url']}{config['endpoint']}",
                  json=params
              )
              # Fail the step on HTTP errors rather than saving an error body as the result
              response.raise_for_status()
              result = response.json()

          with open("/tmp/result.json", "w") as f:
              json.dump(result, f)

          print(f"Component {component_name} completed")
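The `wait-for-kfp` template polls the run state every 30 seconds until it reaches a terminal state, with no upper bound on how long it waits. A minimal sketch of the same loop with a deadline added is shown below. It does not call KFP: the `get_state` callable stands in for reading the run state, and the function name, the timeout, and the injectable `sleep` and `clock` arguments are assumptions made so the loop can be exercised without a cluster.

```python
import time
from typing import Callable

# Terminal states as checked by wait-for-kfp; both spellings of "cancelled" are accepted.
SUCCESS_STATES = {"SUCCEEDED", "SKIPPED"}
FAILURE_STATES = {"FAILED", "ERROR", "CANCELED", "CANCELLED"}


def wait_for_terminal_state(
    get_state: Callable[[], str],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Poll get_state() until success, failure, or timeout (hypothetical helper)."""
    deadline = clock() + timeout_seconds
    while True:
        state = get_state()
        if state in SUCCESS_STATES:
            return state
        if state in FAILURE_STATES:
            raise RuntimeError(f"Pipeline failed with status: {state}")
        if clock() >= deadline:
            raise TimeoutError(f"Run still in state {state} after {timeout_seconds}s")
        sleep(poll_seconds)
```

In the workflow itself, a similar bound could also come from Argo's own step timeouts; the sketch only shows the loop logic.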
510
qlora-training.yaml
Normal file
@@ -0,0 +1,510 @@
# QLoRA Fine-tuning Workflow
# Trains QLoRA adapters from a reference model using data from the Milvus vector database
# Triggered via NATS: ai.pipeline.trigger with pipeline="qlora-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: qlora-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: qlora-training
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: train-qlora
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: reference-model
        description: "Base model to fine-tune (HuggingFace model ID or path)"
        value: "mistralai/Mistral-7B-Instruct-v0.3"
      - name: output-name
        description: "Name for the output QLoRA adapter"
        value: "qlora-adapter"
      - name: milvus-collections
        description: "Comma-separated list of Milvus collections to use (empty = all available)"
        value: ""
      - name: learning-rate
        value: "2e-4"
        description: "Learning rate for training"
      - name: num-epochs
        value: "3"
        description: "Number of training epochs"
      - name: batch-size
        value: "4"
        description: "Training batch size"
      - name: max-seq-length
        value: "2048"
        description: "Maximum sequence length"
      - name: lora-r
        value: "64"
        description: "LoRA attention dimension"
      - name: lora-alpha
        value: "16"
        description: "LoRA alpha parameter"
      - name: lora-dropout
        value: "0.05"
        description: "LoRA dropout rate"

  volumeClaimTemplates:
    - metadata:
        name: model-storage
      spec:
        accessModes: ["ReadWriteMany"]
        storageClassName: nfs-slow
        resources:
          requests:
            storage: 50Gi

  templates:
    - name: train-qlora
      dag:
        tasks:
          - name: fetch-training-data
            template: fetch-data
            arguments:
              parameters:
                - name: milvus-collections
                  value: "{{workflow.parameters.milvus-collections}}"

          - name: prepare-dataset
            template: prepare-data
            dependencies: [fetch-training-data]
            arguments:
              parameters:
                - name: max-seq-length
                  value: "{{workflow.parameters.max-seq-length}}"
              artifacts:
                - name: raw-data
                  from: "{{tasks.fetch-training-data.outputs.artifacts.raw-data}}"

          - name: train-model
            template: train
            dependencies: [prepare-dataset]
            arguments:
              parameters:
                - name: reference-model
                  value: "{{workflow.parameters.reference-model}}"
                - name: output-name
                  value: "{{workflow.parameters.output-name}}"
                - name: learning-rate
                  value: "{{workflow.parameters.learning-rate}}"
                - name: num-epochs
                  value: "{{workflow.parameters.num-epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
                - name: max-seq-length
                  value: "{{workflow.parameters.max-seq-length}}"
                - name: lora-r
                  value: "{{workflow.parameters.lora-r}}"
                - name: lora-alpha
                  value: "{{workflow.parameters.lora-alpha}}"
                - name: lora-dropout
                  value: "{{workflow.parameters.lora-dropout}}"
              artifacts:
                - name: training-data
                  from: "{{tasks.prepare-dataset.outputs.artifacts.training-data}}"

    - name: fetch-data
      inputs:
        parameters:
          - name: milvus-collections
      outputs:
        artifacts:
          - name: raw-data
            path: /tmp/raw-data
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)

            import json
            from pathlib import Path
            from pymilvus import connections, Collection, utility

            MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
            MILVUS_PORT = 19530
            collections_param = "{{inputs.parameters.milvus-collections}}"

            output_dir = Path("/tmp/raw-data")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Connecting to Milvus at {MILVUS_HOST}:{MILVUS_PORT}")
            connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

            # Determine which collections to use
            if collections_param and collections_param.strip():
                collection_names = [c.strip() for c in collections_param.split(",")]
                print(f"Using specified collections: {collection_names}")
            else:
                # Get all available collections
                collection_names = utility.list_collections()
                print(f"Using all available collections: {collection_names}")

            all_training_data = []

            for collection_name in collection_names:
                if not utility.has_collection(collection_name):
                    print(f"Warning: Collection {collection_name} not found, skipping")
                    continue

                print(f"Fetching data from collection: {collection_name}")
                collection = Collection(collection_name)
                collection.load()

                # Query all data from the collection
                # Note: Adjust field names based on your schema
                try:
                    # Get collection schema to determine fields
                    schema = collection.schema
                    field_names = [field.name for field in schema.fields if field.name != "id"]

                    # Query all entities (limited to reasonable batch size)
                    # For large collections, you may want to implement pagination
                    results = collection.query(
                        expr="id >= 0",
                        output_fields=field_names,
                        limit=100000
                    )

                    print(f"  Retrieved {len(results)} records from {collection_name}")

                    for result in results:
                        # Extract text field (adjust based on your schema)
                        text_content = result.get("text", "")
                        source = result.get("source", collection_name)

                        if text_content:
                            all_training_data.append({
                                "text": text_content,
                                "source": source,
                                "collection": collection_name
                            })

                except Exception as e:
                    print(f"Error querying collection {collection_name}: {e}")
                    continue

            # Save all training data
            output_file = output_dir / "training_data.json"
            with open(output_file, "w") as f:
                json.dump({"data": all_training_data}, f)

            print(f"Total training samples collected: {len(all_training_data)}")
            print(f"Saved to {output_file}")

            connections.disconnect("default")
        envFrom:
          - configMapRef:
              name: ai-services-config
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    - name: prepare-data
      inputs:
        parameters:
          - name: max-seq-length
        artifacts:
          - name: raw-data
            path: /tmp/raw-data
      outputs:
        artifacts:
          - name: training-data
            path: /tmp/training-data
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import json
            from pathlib import Path

            max_seq_length = int("{{inputs.parameters.max-seq-length}}")

            input_dir = Path("/tmp/raw-data")
            output_dir = Path("/tmp/training-data")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Load raw data
            with open(input_dir / "training_data.json") as f:
                data = json.load(f)

            raw_samples = data["data"]
            print(f"Processing {len(raw_samples)} raw samples")

            # Prepare data in instruction format for fine-tuning
            # Using Alpaca-style format: instruction + response
            training_samples = []

            for sample in raw_samples:
                text = sample["text"]
                source = sample.get("source", "")

                # Create instruction-response pairs
                # You can customize this based on your use case
                training_sample = {
                    "instruction": f"Based on the following information from {source}, provide a comprehensive response:",
                    "input": text[:max_seq_length // 2],  # Truncate if needed
                    "output": text[:max_seq_length // 2],
                    "source": source
                }
                training_samples.append(training_sample)

            # Split into train/validation (90/10)
            split_idx = int(len(training_samples) * 0.9)
            train_data = training_samples[:split_idx]
            val_data = training_samples[split_idx:]

            # Save prepared datasets
            with open(output_dir / "train.json", "w") as f:
                json.dump(train_data, f)

            with open(output_dir / "validation.json", "w") as f:
                json.dump(val_data, f)

            print(f"Prepared {len(train_data)} training samples")
            print(f"Prepared {len(val_data)} validation samples")
            print("Data preparation complete")
        resources:
          requests:
            memory: 2Gi
            cpu: 500m

- name: train
|
||||
inputs:
|
||||
parameters:
|
||||
- name: reference-model
|
||||
- name: output-name
|
||||
- name: learning-rate
|
||||
- name: num-epochs
|
||||
- name: batch-size
|
||||
- name: max-seq-length
|
||||
- name: lora-r
|
||||
- name: lora-alpha
|
||||
- name: lora-dropout
|
||||
artifacts:
|
||||
- name: training-data
|
||||
path: /tmp/training-data
|
||||
container:
|
||||
image: python:3.13-slim
|
||||
command: [bash]
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
set -e
|
||||
|
||||
echo "Installing dependencies..."
|
||||
pip install -q torch transformers peft datasets accelerate bitsandbytes scipy
|
||||
|
||||
echo "Starting QLoRA training..."
|
||||
|
||||
python << 'EOF'
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForLanguageModeling
|
||||
)
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
import torch
|
||||
|
||||
# Parameters
|
||||
reference_model = "{{inputs.parameters.reference-model}}"
|
||||
output_name = "{{inputs.parameters.output-name}}"
|
||||
learning_rate = float("{{inputs.parameters.learning-rate}}")
|
||||
num_epochs = int("{{inputs.parameters.num-epochs}}")
|
||||
batch_size = int("{{inputs.parameters.batch-size}}")
|
||||
max_seq_length = int("{{inputs.parameters.max-seq-length}}")
|
||||
lora_r = int("{{inputs.parameters.lora-r}}")
|
||||
lora_alpha = int("{{inputs.parameters.lora-alpha}}")
|
||||
lora_dropout = float("{{inputs.parameters.lora-dropout}}")
|
||||
|
||||
data_dir = Path("/tmp/training-data")
|
||||
output_dir = Path("/mnt/model-storage") / output_name
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("Training configuration:")
|
||||
print(f" Model: {reference_model}")
|
||||
print(f" Learning rate: {learning_rate}")
|
||||
print(f" Epochs: {num_epochs}")
|
||||
print(f" Batch size: {batch_size}")
|
||||
print(f" Max sequence length: {max_seq_length}")
|
||||
print(f" LoRA r: {lora_r}, alpha: {lora_alpha}, dropout: {lora_dropout}")
|
||||
|
||||
# Load datasets
|
||||
with open(data_dir / "train.json") as f:
|
||||
train_data = json.load(f)
|
||||
|
||||
with open(data_dir / "validation.json") as f:
|
||||
val_data = json.load(f)
|
||||
|
||||
print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples")
|
||||
|
||||
# Load tokenizer
|
||||
print(f"Loading tokenizer from {reference_model}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(reference_model, trust_remote_code=True)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Prepare datasets
|
||||
def format_sample(sample):
|
||||
instruction = sample.get("instruction", "")
|
||||
input_text = sample.get("input", "")
|
||||
output = sample.get("output", "")
|
||||
|
||||
if input_text:
|
||||
prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
|
||||
else:
|
||||
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
|
||||
|
||||
return {"text": prompt}
|
||||
|
||||
train_dataset = Dataset.from_list([format_sample(s) for s in train_data])
|
||||
val_dataset = Dataset.from_list([format_sample(s) for s in val_data])
|
||||
|
||||
# Tokenize datasets
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(
|
||||
examples["text"],
|
||||
truncation=True,
|
||||
max_length=max_seq_length,
|
||||
padding="max_length"
|
||||
)
|
||||
|
||||
print("Tokenizing datasets...")
|
||||
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
|
||||
# Configure quantization for QLoRA
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
# Load model
|
||||
print(f"Loading model {reference_model} with 4-bit quantization...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
reference_model,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Prepare model for training
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=lora_r,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Add LoRA adapters
|
||||
print("Adding LoRA adapters...")
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# Training arguments
|
||||
training_args = TrainingArguments(
|
||||
output_dir=str(output_dir / "checkpoints"),
|
||||
num_train_epochs=num_epochs,
|
||||
per_device_train_batch_size=batch_size,
|
||||
per_device_eval_batch_size=batch_size,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=learning_rate,
|
||||
fp16=True,
|
||||
logging_steps=10,
|
||||
eval_strategy="steps",  # named evaluation_strategy in transformers releases before 4.41
|
||||
eval_steps=50,
|
||||
save_strategy="steps",
|
||||
save_steps=100,
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=True,
|
||||
report_to="none",
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=False
|
||||
)
|
||||
|
||||
# Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving QLoRA adapter to {output_dir}")
|
||||
model.save_pretrained(str(output_dir / "final"))
|
||||
tokenizer.save_pretrained(str(output_dir / "final"))
|
||||
|
||||
# Save training metadata
|
||||
metadata = {
|
||||
"reference_model": reference_model,
|
||||
"output_name": output_name,
|
||||
"training_params": {
|
||||
"learning_rate": learning_rate,
|
||||
"num_epochs": num_epochs,
|
||||
"batch_size": batch_size,
|
||||
"max_seq_length": max_seq_length,
|
||||
"lora_r": lora_r,
|
||||
"lora_alpha": lora_alpha,
|
||||
"lora_dropout": lora_dropout
|
||||
},
|
||||
"dataset_info": {
|
||||
"train_samples": len(train_data),
|
||||
"val_samples": len(val_data)
|
||||
}
|
||||
}
|
||||
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
print("Training complete!")
|
||||
print(f"QLoRA adapter saved to: {output_dir}")
|
||||
EOF
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: ai-services-config
|
||||
volumeMounts:
|
||||
- name: model-storage
|
||||
mountPath: /mnt/model-storage
|
||||
resources:
|
||||
requests:
|
||||
memory: 16Gi
|
||||
cpu: 4
|
||||
nvidia.com/gpu: 1
|
||||
limits:
|
||||
memory: 32Gi
|
||||
cpu: 8
|
||||
nvidia.com/gpu: 1
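
The `prepare-data` step above builds Alpaca-style records and splits them 90/10 in order. With very small inputs, that split can leave the training or validation set empty, and the `train` step evaluates every 50 steps against the validation set. The following is a minimal sketch of the same logic as standalone functions, with a guard on the split index. The function names `build_samples` and `split_train_val` and the guard itself are assumptions for illustration and are not part of the workflow.

```python
from typing import Dict, List, Tuple

Sample = Dict[str, str]


def build_samples(raw_samples: List[Sample], max_seq_length: int) -> List[Sample]:
    """Mirror the prepare-data loop: one Alpaca-style record per raw sample."""
    # Character budget, not a token count; the tokenizer truncates again later.
    budget = max_seq_length // 2
    samples = []
    for sample in raw_samples:
        text = sample["text"]
        source = sample.get("source", "")
        samples.append({
            "instruction": (
                f"Based on the following information from {source}, "
                "provide a comprehensive response:"
            ),
            "input": text[:budget],
            "output": text[:budget],
            "source": source,
        })
    return samples


def split_train_val(
    samples: List[Sample], train_fraction: float = 0.9
) -> Tuple[List[Sample], List[Sample]]:
    """Split in order (hypothetical guard: keep both sides non-empty when possible)."""
    split_idx = max(int(len(samples) * train_fraction), 1)
    if len(samples) >= 2:
        # Reserve at least one sample for validation.
        split_idx = min(split_idx, len(samples) - 1)
    return samples[:split_idx], samples[split_idx:]
```

With a single sample the sketch puts it in the training set and returns an empty validation set, so a caller would still need to decide whether to skip evaluation in that case.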
|
||||