feat: Add ML training and batch inference workflows

- batch-inference: LLM inference with optional RAG
- qlora-training: QLoRA adapter fine-tuning from Milvus
- hybrid-ml-training: Multi-GPU distributed training
- coqui-voice-training: XTTS voice cloning
- document-ingestion: Ingest documents to Milvus
- eventsource-kfp: Argo Events / Kubeflow integration
- kfp-integration: Bridge between Argo and Kubeflow
128
README.md
@@ -1,2 +1,128 @@
# argo
# Argo Workflows

ML training and batch inference workflows for the DaviesTechLabs AI/ML platform.

## Workflows

| Workflow | Description | Trigger |
|----------|-------------|---------|
| `batch-inference` | Run LLM inference on batch inputs | `ai.pipeline.trigger` (pipeline="batch-inference") |
| `qlora-training` | Train QLoRA adapters from Milvus data | `ai.pipeline.trigger` (pipeline="qlora-training") |
| `hybrid-ml-training` | Multi-GPU distributed training | `ai.pipeline.trigger` (pipeline="hybrid-ml-training") |
| `coqui-voice-training` | XTTS voice cloning/training | `ai.pipeline.trigger` (pipeline="coqui-voice-training") |
| `document-ingestion` | Ingest documents into Milvus | `ai.pipeline.trigger` (pipeline="document-ingestion") |

## Integration

| File | Description |
|------|-------------|
| `eventsource-kfp.yaml` | Argo Events source for Kubeflow Pipelines integration |
| `kfp-integration.yaml` | Bridge workflows between Argo and Kubeflow |

## Architecture

```
NATS (ai.pipeline.trigger)
         │
         ▼
┌─────────────────┐
│  Argo Events    │
│  EventSource    │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│  Argo Sensor    │
└─────────────────┘
         │
         ▼
┌─────────────────┐
│ WorkflowTemplate│
│ (batch-inf,     │
│  qlora, etc)    │
└─────────────────┘
         │
         ├──▶ GPU Pods (AMD ROCm / NVIDIA CUDA)
         ├──▶ Milvus Vector DB
         ├──▶ vLLM / Ray Serve
         └──▶ MLflow Tracking
```

## Workflow Details

### batch-inference

Batch LLM inference with optional RAG:

```bash
# batch-inference.yaml defines a WorkflowTemplate: register it once, then submit from it
kubectl apply -f batch-inference.yaml
argo submit --from workflowtemplate/batch-inference -n ai-ml \
  -p input-url="s3://bucket/inputs.json" \
  -p output-url="s3://bucket/outputs.json" \
  -p use-rag="true" \
  -p max-tokens="500"
```

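The file behind `input-url` is read by the `fetch-input-data` and `inference` steps of `batch-inference.yaml`: it must be a JSON object with a top-level `requests` array, and each entry carries its prompt under `text` (falling back to `query`) plus an optional `id`. The sketch below builds and checks such a file locally. The helper names `build_inputs` and `validate_inputs` are illustrative and not part of this repository, and the empty-prompt check is stricter than what the workflow itself enforces.

```python
import json
from typing import Any


def build_inputs(prompts: list[str]) -> dict[str, Any]:
    """Wrap plain prompts in the {"requests": [...]} envelope the workflow reads."""
    return {
        "requests": [
            {"id": f"req-{i}", "text": prompt} for i, prompt in enumerate(prompts)
        ]
    }


def validate_inputs(data: dict[str, Any]) -> int:
    """Check the envelope and return the number of requests."""
    # Same top-level check the fetch-input-data step performs.
    if "requests" not in data:
        raise ValueError("JSON must contain 'requests' array")
    for req in data["requests"]:
        # The inference step reads "text" and falls back to "query".
        # Rejecting empty prompts is an extra check added for this sketch.
        if not req.get("text", req.get("query", "")):
            raise ValueError(f"request {req.get('id')!r} has no 'text' or 'query'")
    return len(data["requests"])


if __name__ == "__main__":
    payload = build_inputs(["What is Argo Workflows?", "Summarise the Milvus docs."])
    validate_inputs(payload)
    print(json.dumps(payload, indent=2))
```

Writing the printed JSON to `inputs.json` and uploading it to the bucket gives a file that passes the workflow's own `requests` check.
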
### qlora-training

Fine-tune QLoRA adapters from Milvus knowledge:

```bash
argo submit qlora-training.yaml \
  -p reference-model="mistralai/Mistral-7B-Instruct-v0.3" \
  -p output-name="my-adapter" \
  -p milvus-collections="docs,wiki" \
  -p num-epochs="3"
```

### coqui-voice-training

Train XTTS voice models:

```bash
# coqui-voice-training.yaml defines a WorkflowTemplate: register it once, then submit from it
kubectl apply -f coqui-voice-training.yaml
argo submit --from workflowtemplate/coqui-voice-training -n ai-ml \
  -p voice-name="my-voice" \
  -p audio-source="s3://bucket/samples/"
```

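The workflow also accepts an optional `transcripts-source` parameter, described in `coqui-voice-training.yaml` as a CSV with `audio_file,transcript` columns. When it is left empty the audio is auto-transcribed. A minimal sketch for producing such a file follows. It assumes a header row with exactly those two column names, which the parameter description suggests but the visible part of the workflow does not confirm. The helper names are hypothetical.

```python
import csv
import io


def write_transcripts(rows: list[tuple[str, str]]) -> str:
    """Serialise (audio_file, transcript) pairs as CSV text with a header row."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["audio_file", "transcript"])  # assumed header
    writer.writerows(rows)
    return buf.getvalue()


def read_transcripts(text: str) -> dict[str, str]:
    """Parse the CSV back into a {audio_file: transcript} mapping."""
    reader = csv.DictReader(io.StringIO(text))
    return {row["audio_file"]: row["transcript"] for row in reader}


if __name__ == "__main__":
    csv_text = write_transcripts([
        ("sample_001.wav", "Hello, this is the first sample."),
        ("sample_002.wav", "And this is the second one."),
    ])
    print(csv_text)
```

Using the `csv` module rather than string joins keeps transcripts that contain commas or quotes intact.
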
### document-ingestion

Ingest documents into Milvus:

```bash
argo submit document-ingestion.yaml \
  -p source-url="s3://bucket/docs/" \
  -p collection="knowledge_base" \
  -p chunk-size="512"
```

## NATS Trigger Format

Workflows are triggered via NATS `ai.pipeline.trigger`:

```json
{
  "pipeline": "qlora-training",
  "parameters": {
    "reference-model": "mistralai/Mistral-7B-Instruct-v0.3",
    "output-name": "custom-adapter",
    "num-epochs": "5"
  }
}
```

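A trigger message in this format can be built and published from Python. The sketch below keeps the payload builder free of dependencies and imports `nats-py` only inside the publish function. The server URL `nats://localhost:4222` and the helper names are assumptions for illustration. The pipeline names come from the Workflows table.

```python
import json

SUBJECT = "ai.pipeline.trigger"

# Pipeline names as listed in the Workflows table of this README.
KNOWN_PIPELINES = {
    "batch-inference",
    "qlora-training",
    "hybrid-ml-training",
    "coqui-voice-training",
    "document-ingestion",
}


def build_trigger(pipeline: str, parameters: dict) -> bytes:
    """Encode a trigger message; values are stringified, as in the example above."""
    if pipeline not in KNOWN_PIPELINES:
        raise ValueError(f"unknown pipeline: {pipeline}")
    body = {
        "pipeline": pipeline,
        "parameters": {key: str(value) for key, value in parameters.items()},
    }
    return json.dumps(body).encode("utf-8")


async def publish_trigger(payload: bytes, url: str = "nats://localhost:4222") -> None:
    """Publish one message with nats-py (pip install nats-py)."""
    import nats  # imported lazily so build_trigger works without the package

    nc = await nats.connect(url)
    await nc.publish(SUBJECT, payload)
    await nc.drain()  # flush pending messages, then close the connection


if __name__ == "__main__":
    message = build_trigger(
        "qlora-training",
        {
            "reference-model": "mistralai/Mistral-7B-Instruct-v0.3",
            "output-name": "custom-adapter",
            "num-epochs": 5,
        },
    )
    print(message.decode())
```

To send the message, call `asyncio.run(publish_trigger(message, url=...))` with the URL of a reachable NATS server.
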
## GPU Scheduling

Workflows use node affinity for GPU allocation:

| Node | GPU | Best For |
|------|-----|----------|
| khelben | AMD Strix Halo 64GB | Large model training, vLLM |
| elminster | NVIDIA RTX 2070 | Whisper, XTTS |
| drizzt | AMD Radeon 680M | Embeddings |
| danilo | Intel Arc | Reranker |

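As a rough illustration of the table above, the following sketch maps a workload type to a node and emits the matching Kubernetes `affinity` stanza as a Python dict. It assumes pods are pinned by the standard `kubernetes.io/hostname` label with a required rule. The actual workflow manifests may select nodes differently (for example by GPU vendor labels or preferred affinity), and the workload keys are hypothetical.

```python
# Node assignments copied from the GPU Scheduling table; workload keys are illustrative.
NODE_FOR_WORKLOAD = {
    "training": "khelben",
    "vllm": "khelben",
    "whisper": "elminster",
    "xtts": "elminster",
    "embeddings": "drizzt",
    "reranker": "danilo",
}


def node_affinity(workload: str) -> dict:
    """Return an `affinity` stanza that pins a pod to the node for this workload."""
    node = NODE_FOR_WORKLOAD[workload]
    return {
        "nodeAffinity": {
            "requiredDuringSchedulingIgnoredDuringExecution": {
                "nodeSelectorTerms": [
                    {
                        "matchExpressions": [
                            {
                                # Assumption: nodes are selected by hostname label.
                                "key": "kubernetes.io/hostname",
                                "operator": "In",
                                "values": [node],
                            }
                        ]
                    }
                ]
            }
        }
    }


if __name__ == "__main__":
    import json

    print(json.dumps(node_affinity("xtts"), indent=2))
```
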
## Related

- [homelab-design](https://git.daviestechlabs.io/daviestechlabs/homelab-design) - Architecture docs
- [kuberay-images](https://git.daviestechlabs.io/daviestechlabs/kuberay-images) - Ray worker images
- [handler-base](https://git.daviestechlabs.io/daviestechlabs/handler-base) - Handler library
328
batch-inference.yaml
Normal file
@@ -0,0 +1,328 @@
# Batch Inference Workflow
# Runs LLM inference on a batch of inputs
# Triggered via NATS: ai.pipeline.trigger with pipeline="batch-inference"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: batch-inference
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: batch-inference
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: batch-inference
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: input-url
        description: "URL to JSON file with inference requests"
      - name: output-url
        description: "URL to store results (S3 path)"
        value: ""
      - name: use-rag
        value: "true"
        description: "Whether to use RAG for context"
      - name: max-tokens
        value: "500"
        description: "Maximum tokens per response"
      - name: temperature
        value: "0.7"
        description: "LLM temperature"

  templates:
    - name: batch-inference
      dag:
        tasks:
          - name: fetch-inputs
            template: fetch-input-data
            arguments:
              parameters:
                - name: input-url
                  value: "{{workflow.parameters.input-url}}"

          - name: run-inference
            template: inference
            dependencies: [fetch-inputs]
            arguments:
              parameters:
                - name: use-rag
                  value: "{{workflow.parameters.use-rag}}"
                - name: max-tokens
                  value: "{{workflow.parameters.max-tokens}}"
                - name: temperature
                  value: "{{workflow.parameters.temperature}}"
              artifacts:
                - name: inputs
                  from: "{{tasks.fetch-inputs.outputs.artifacts.inputs}}"

          - name: upload-results
            template: upload-output
            dependencies: [run-inference]
            # Quote the substituted value so an empty or URL-valued parameter still parses
            when: "'{{workflow.parameters.output-url}}' != ''"
            arguments:
              parameters:
                - name: output-url
                  value: "{{workflow.parameters.output-url}}"
              artifacts:
                - name: results
                  from: "{{tasks.run-inference.outputs.artifacts.results}}"

    - name: fetch-input-data
      inputs:
        parameters:
          - name: input-url
      outputs:
        artifacts:
          - name: inputs
            path: /tmp/inputs
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import json
            import urllib.request
            from pathlib import Path

            input_url = "{{inputs.parameters.input-url}}"
            output_dir = Path("/tmp/inputs")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Fetching inputs from: {input_url}")

            if input_url.startswith("s3://"):
                import subprocess
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, key = input_url[5:].split("/", 1)
                s3.download_file(bucket, key, str(output_dir / "inputs.json"))
            elif input_url.startswith("http"):
                urllib.request.urlretrieve(input_url, output_dir / "inputs.json")
            else:
                print(f"Unsupported URL scheme: {input_url}")
                exit(1)

            # Validate JSON structure
            with open(output_dir / "inputs.json") as f:
                data = json.load(f)

            if "requests" not in data:
                print("Error: JSON must contain 'requests' array")
                exit(1)

            print(f"Loaded {len(data['requests'])} inference requests")
        resources:
          requests:
            memory: 256Mi
            cpu: 100m

    - name: inference
      inputs:
        parameters:
          - name: use-rag
          - name: max-tokens
          - name: temperature
        artifacts:
          - name: inputs
            path: /tmp/inputs
      outputs:
        artifacts:
          - name: results
            path: /tmp/results
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "httpx", "pymilvus", "-q"], check=True)

            import json
            import httpx
            from pathlib import Path
            from typing import List, Dict

            # Configuration
            VLLM_URL = "http://llm-draft.ai-ml.svc.cluster.local:8000"
            EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
            RERANKER_URL = "http://reranker-predictor.ai-ml.svc.cluster.local"
            MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
            LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"

            use_rag = "{{inputs.parameters.use-rag}}" == "true"
            max_tokens = int("{{inputs.parameters.max-tokens}}")
            temperature = float("{{inputs.parameters.temperature}}")

            input_dir = Path("/tmp/inputs")
            output_dir = Path("/tmp/results")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Load inputs
            with open(input_dir / "inputs.json") as f:
                data = json.load(f)
            requests = data["requests"]

            print(f"Processing {len(requests)} requests (RAG: {use_rag})")

            # Initialize Milvus if using RAG
            collection = None
            if use_rag:
                try:
                    from pymilvus import connections, Collection, utility
                    connections.connect(host=MILVUS_HOST, port=19530)
                    if utility.has_collection("knowledge_base"):
                        collection = Collection("knowledge_base")
                        collection.load()
                    print("Milvus connected")
                except Exception as e:
                    print(f"Milvus connection failed: {e}")
                    use_rag = False

            def get_embeddings(texts: List[str], client: httpx.Client) -> List[List[float]]:
                response = client.post(
                    f"{EMBEDDINGS_URL}/embeddings",
                    json={"input": texts, "model": "bge"}
                )
                # Surface HTTP errors instead of silently returning no embeddings
                response.raise_for_status()
                result = response.json()
                return [d["embedding"] for d in result.get("data", [])]

            def search_milvus(embedding: List[float]) -> List[Dict]:
                results = collection.search(
                    data=[embedding],
                    anns_field="embedding",
                    param={"metric_type": "COSINE", "params": {"ef": 64}},
                    limit=5,
                    output_fields=["text", "source"]
                )
                docs = []
                for hits in results:
                    for hit in hits:
                        docs.append({
                            "text": hit.entity.get("text", ""),
                            "source": hit.entity.get("source", ""),
                            "score": hit.score
                        })
                return docs

            def rerank(query: str, documents: List[str], client: httpx.Client) -> List[Dict]:
                response = client.post(
                    f"{RERANKER_URL}/v1/rerank",
                    json={"query": query, "documents": documents}
                )
                response.raise_for_status()
                return response.json().get("results", [])

            # Process requests
            results = []
            with httpx.Client(timeout=120.0) as client:
                for i, req in enumerate(requests):
                    query = req.get("text", req.get("query", ""))
                    req_id = req.get("id", str(i))

                    print(f"Processing {i+1}/{len(requests)}: {query[:50]}...")

                    context = ""
                    rag_sources = []

                    if use_rag and collection:
                        try:
                            # Get embeddings and search
                            embeddings = get_embeddings([query], client)
                            if embeddings:
                                docs = search_milvus(embeddings[0])
                                if docs:
                                    doc_texts = [d["text"] for d in docs]
                                    reranked = rerank(query, doc_texts, client)
                                    sorted_docs = sorted(reranked, key=lambda x: x.get("relevance_score", 0), reverse=True)[:3]
                                    context = "\n\n".join([doc_texts[d["index"]] for d in sorted_docs])
                                    rag_sources = [docs[d["index"]].get("source", "") for d in sorted_docs]
                        except Exception as e:
                            print(f"  RAG failed: {e}")

                    # Generate response
                    try:
                        messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
                        if context:
                            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
                        else:
                            messages.append({"role": "user", "content": query})

                        response = client.post(
                            f"{VLLM_URL}/v1/chat/completions",
                            json={
                                "model": LLM_MODEL,
                                "messages": messages,
                                "max_tokens": max_tokens,
                                "temperature": temperature
                            }
                        )
                        # Report the HTTP status rather than a KeyError on "choices"
                        response.raise_for_status()
                        result = response.json()
                        answer = result["choices"][0]["message"]["content"]
                    except Exception as e:
                        answer = f"Error: {e}"

                    results.append({
                        "id": req_id,
                        "query": query,
                        "response": answer,
                        "used_rag": bool(context),
                        "rag_sources": rag_sources
                    })

            # Save results
            with open(output_dir / "results.json", "w") as f:
                json.dump({"results": results}, f, indent=2)

            print(f"Completed {len(results)} inferences")

            if collection:
                from pymilvus import connections
                connections.disconnect("default")
        envFrom:
          - configMapRef:
              name: ai-services-config
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    - name: upload-output
      inputs:
        parameters:
          - name: output-url
        artifacts:
          - name: results
            path: /tmp/results
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "boto3", "-q"], check=True)

            import boto3
            from pathlib import Path

            output_url = "{{inputs.parameters.output-url}}"
            results_file = Path("/tmp/results/results.json")

            print(f"Uploading results to: {output_url}")

            if output_url.startswith("s3://"):
                s3 = boto3.client("s3")
                bucket, key = output_url[5:].split("/", 1)
                s3.upload_file(str(results_file), bucket, key)
                print("Upload complete")
            else:
                print(f"Unsupported URL scheme: {output_url}")
                exit(1)
        resources:
          requests:
            memory: 256Mi
            cpu: 100m
969
coqui-voice-training.yaml
Normal file
@@ -0,0 +1,969 @@
# Coqui TTS Voice Training Workflow
# Trains a custom voice model using Coqui TTS from audio samples
# Triggered via NATS: ai.pipeline.trigger with pipeline="coqui-voice-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: coqui-voice-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: coqui-voice-training
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: train-voice
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: audio-source
        description: "URL to audio files (S3 bucket, HTTP, or NFS path with .wav/.mp3 files)"
      - name: transcripts-source
        description: "URL to transcripts file (CSV with audio_file,transcript columns) - leave empty to auto-transcribe"
        value: ""
      - name: voice-name
        description: "Name for the trained voice model"
        value: "custom-voice"
      - name: base-model
        description: "Base TTS model to fine-tune from"
        value: "tts_models/en/ljspeech/vits"
      - name: language
        description: "Language code (e.g., en, de, fr, es)"
        value: "en"
      - name: num-epochs
        description: "Number of training epochs"
        value: "100"
      - name: batch-size
        description: "Training batch size"
        value: "16"
      - name: learning-rate
        description: "Learning rate for training"
        value: "0.0001"
      - name: sample-rate
        description: "Target sample rate for audio (Hz)"
        value: "22050"
      - name: output-path
        description: "Path to store the trained model (S3 or NFS)"
        value: "/models/tts/custom"

  volumeClaimTemplates:
    - metadata:
        name: training-workspace
      spec:
        accessModes: ["ReadWriteMany"]
        storageClassName: nfs-slow
        resources:
          requests:
            storage: 50Gi

  templates:
    - name: train-voice
      dag:
        tasks:
          - name: fetch-audio
            template: fetch-audio-files
            arguments:
              parameters:
                - name: audio-source
                  value: "{{workflow.parameters.audio-source}}"

          - name: fetch-transcripts
            template: fetch-transcript-file
            arguments:
              parameters:
                - name: transcripts-source
                  value: "{{workflow.parameters.transcripts-source}}"

          - name: preprocess-audio
            template: preprocess
            dependencies: [fetch-audio]
            arguments:
              parameters:
                - name: sample-rate
                  value: "{{workflow.parameters.sample-rate}}"
              artifacts:
                - name: raw-audio
                  from: "{{tasks.fetch-audio.outputs.artifacts.audio-files}}"

          - name: generate-transcripts
            template: transcribe-audio
            dependencies: [preprocess-audio, fetch-transcripts]
            # Quote the substituted value so an empty parameter still parses
            when: "'{{workflow.parameters.transcripts-source}}' == ''"
            arguments:
              parameters:
                - name: language
                  value: "{{workflow.parameters.language}}"
              artifacts:
                - name: audio-files
                  from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"

          - name: prepare-dataset
            template: prepare-coqui-dataset
            dependencies: [preprocess-audio, generate-transcripts, fetch-transcripts]
            arguments:
              parameters:
                - name: voice-name
                  value: "{{workflow.parameters.voice-name}}"
                - name: language
                  value: "{{workflow.parameters.language}}"
              artifacts:
                - name: audio-files
                  from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"
                - name: transcripts
                  # Hyphenated names need bracket notation inside {{= }} expressions
                  from: "{{=workflow.parameters['transcripts-source'] != '' ? tasks['fetch-transcripts'].outputs.artifacts.transcripts : tasks['generate-transcripts'].outputs.artifacts.transcripts}}"
                  optional: true

          - name: train-model
            template: train-tts
            dependencies: [prepare-dataset]
            arguments:
              parameters:
                - name: voice-name
                  value: "{{workflow.parameters.voice-name}}"
                - name: base-model
                  value: "{{workflow.parameters.base-model}}"
                - name: language
                  value: "{{workflow.parameters.language}}"
                - name: num-epochs
                  value: "{{workflow.parameters.num-epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
                - name: learning-rate
                  value: "{{workflow.parameters.learning-rate}}"
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-dataset.outputs.artifacts.dataset}}"

          - name: export-model
            template: export-trained-model
            dependencies: [train-model]
            arguments:
              parameters:
                - name: voice-name
                  value: "{{workflow.parameters.voice-name}}"
                - name: output-path
                  value: "{{workflow.parameters.output-path}}"
              artifacts:
                - name: trained-model
                  from: "{{tasks.train-model.outputs.artifacts.model}}"

    # Template: Fetch audio files from source
    - name: fetch-audio-files
      inputs:
        parameters:
          - name: audio-source
      outputs:
        artifacts:
          - name: audio-files
            path: /tmp/audio
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import os
            import subprocess
            import urllib.request
            from pathlib import Path
            import shutil

            source_url = "{{inputs.parameters.audio-source}}"
            output_dir = Path("/tmp/audio")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Fetching audio from: {source_url}")

            if source_url.startswith("s3://"):
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, prefix = source_url[5:].split("/", 1)
                # Paginate: a single list_objects_v2 call returns at most 1000 keys
                paginator = s3.get_paginator("list_objects_v2")

                audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
                for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
                    for obj in page.get("Contents", []):
                        key = obj["Key"]
                        if Path(key).suffix.lower() in audio_extensions:
                            local_path = output_dir / Path(key).name
                            s3.download_file(bucket, key, str(local_path))
                            print(f"Downloaded: {key}")

            elif source_url.startswith("http"):
                # Handle a single audio file or a zip archive
                filename = source_url.split("/")[-1]
                if any(ext in filename.lower() for ext in [".wav", ".mp3", ".flac", ".zip"]):
                    local_path = output_dir / filename
                    urllib.request.urlretrieve(source_url, local_path)
                    print(f"Downloaded: {filename}")

                    # Extract if zip
                    if filename.endswith(".zip"):
                        shutil.unpack_archive(local_path, output_dir)
                        os.remove(local_path)
                        print("Extracted zip archive")
                else:
                    print(f"URL doesn't appear to be an audio file: {source_url}")
                    exit(1)

            elif source_url.startswith("/"):
                # Local/NFS path
                src_path = Path(source_url)
                if src_path.is_dir():
                    audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
                    for f in src_path.iterdir():
                        if f.suffix.lower() in audio_extensions:
                            shutil.copy(f, output_dir / f.name)
                            print(f"Copied: {f.name}")
                elif src_path.is_file():
                    shutil.copy(src_path, output_dir / src_path.name)
                else:
                    print(f"Path not found: {source_url}")
                    exit(1)
            else:
                print(f"Unsupported source: {source_url}")
                exit(1)

            # Count files
            audio_files = list(output_dir.glob("*"))
            print(f"Total audio files: {len(audio_files)}")

            if len(audio_files) == 0:
                print("Error: No audio files found!")
                exit(1)
        resources:
          requests:
            memory: 512Mi
            cpu: 200m

    # Template: Fetch transcripts file
    - name: fetch-transcript-file
      inputs:
        parameters:
          - name: transcripts-source
      outputs:
        artifacts:
          - name: transcripts
            path: /tmp/transcripts
            optional: true
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import os
            import subprocess
            import urllib.request
            from pathlib import Path
            import shutil

            source_url = "{{inputs.parameters.transcripts-source}}"
            output_dir = Path("/tmp/transcripts")
            output_dir.mkdir(parents=True, exist_ok=True)

            if not source_url or source_url.strip() == "":
                print("No transcripts source provided - will auto-transcribe")
                # Create empty placeholder
                (output_dir / "placeholder.txt").write_text("auto-transcribe")
                exit(0)

            print(f"Fetching transcripts from: {source_url}")

            if source_url.startswith("s3://"):
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, key = source_url[5:].split("/", 1)
                local_path = output_dir / Path(key).name
                s3.download_file(bucket, key, str(local_path))
                print(f"Downloaded: {key}")

            elif source_url.startswith("http"):
                filename = source_url.split("/")[-1] or "transcripts.csv"
                local_path = output_dir / filename
                urllib.request.urlretrieve(source_url, local_path)
                print(f"Downloaded: {filename}")

            elif source_url.startswith("/"):
                src_path = Path(source_url)
                if src_path.is_file():
                    shutil.copy(src_path, output_dir / src_path.name)
                    print(f"Copied: {src_path.name}")
                else:
                    print(f"File not found: {source_url}")
                    exit(1)
            else:
                print(f"Unsupported source: {source_url}")
                exit(1)
        resources:
          requests:
            memory: 256Mi
            cpu: 100m

    # Template: Preprocess audio files
    - name: preprocess
      inputs:
        parameters:
          - name: sample-rate
        artifacts:
          - name: raw-audio
            path: /tmp/raw-audio
      outputs:
        artifacts:
          - name: processed-audio
            path: /tmp/processed-audio
      container:
        image: python:3.13-slim
        command: [bash]
        args:
          - -c
          - |
            set -e

            # Install ffmpeg and dependencies
            apt-get update && apt-get install -y ffmpeg > /dev/null 2>&1
            # audioop-lts restores the audioop module that pydub imports (removed from the stdlib in Python 3.13)
            pip install -q pydub numpy soundfile audioop-lts

            python3 << 'EOF'
            import os
            from pathlib import Path
            from pydub import AudioSegment
            import soundfile as sf

            SAMPLE_RATE = int("{{inputs.parameters.sample-rate}}")
            input_dir = Path("/tmp/raw-audio")
            output_dir = Path("/tmp/processed-audio")
            output_dir.mkdir(parents=True, exist_ok=True)

            audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}

            for audio_file in input_dir.iterdir():
                if audio_file.suffix.lower() not in audio_extensions:
                    continue

                print(f"Processing: {audio_file.name}")

                try:
                    # Load audio
                    audio = AudioSegment.from_file(str(audio_file))

                    # Convert to mono if stereo
                    if audio.channels > 1:
                        audio = audio.set_channels(1)

                    # Resample to target sample rate
                    audio = audio.set_frame_rate(SAMPLE_RATE)

                    # Normalize audio
                    audio = audio.normalize()

                    # Export as WAV
                    output_file = output_dir / f"{audio_file.stem}.wav"
                    audio.export(str(output_file), format="wav")
                    print(f"  -> Saved: {output_file.name}")

                except Exception as e:
                    print(f"  -> Error processing {audio_file.name}: {e}")
                    continue

            processed_files = list(output_dir.glob("*.wav"))
            print(f"\nProcessed {len(processed_files)} audio files")

            if len(processed_files) == 0:
                print("Error: No files were successfully processed!")
                exit(1)
            EOF
        resources:
          requests:
            memory: 2Gi
            cpu: "1"

    # Template: Auto-transcribe audio using Coqui STT
    - name: transcribe-audio
      inputs:
        parameters:
          - name: language
        artifacts:
          - name: audio-files
            path: /tmp/audio
      outputs:
        artifacts:
          - name: transcripts
            path: /tmp/transcripts
      container:
        image: ghcr.io/coqui-ai/stt:latest
        command: [bash]
        args:
          - -c
          - |
            set -e

            # Install additional dependencies
            pip install -q numpy scipy

            python3 << 'EOF'
            import csv
            import os
            import wave
            import numpy as np
            from pathlib import Path
            from stt import Model

            LANGUAGE = "{{inputs.parameters.language}}"
            input_dir = Path("/tmp/audio")
            output_dir = Path("/tmp/transcripts")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Model paths - Coqui STT models are typically pre-installed in the container
            # or can be downloaded from https://coqui.ai/models
            MODEL_DIR = Path("/models/stt")

            # Try to find model files
            model_file = None
            scorer_file = None

            # Check for language-specific models
            lang_model_dir = MODEL_DIR / LANGUAGE
            if lang_model_dir.exists():
                for f in lang_model_dir.glob("*.tflite"):
                    model_file = f
                for f in lang_model_dir.glob("*.scorer"):
                    scorer_file = f

            # Fallback to default English model location
            if model_file is None:
                default_paths = [
                    MODEL_DIR / "model.tflite",
                    Path("/usr/share/stt/model.tflite"),
                    Path("/opt/stt/model.tflite"),
                ]
                for p in default_paths:
                    if p.exists():
                        model_file = p
                        break

            if model_file is None:
                # Download model if not found
                print("Downloading Coqui STT model...")
                import urllib.request
                import tarfile

                model_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.tflite"
                scorer_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.scorer"

                MODEL_DIR.mkdir(parents=True, exist_ok=True)
                model_file = MODEL_DIR / "model.tflite"
                scorer_file = MODEL_DIR / "model.scorer"

                urllib.request.urlretrieve(model_url, model_file)
                urllib.request.urlretrieve(scorer_url, scorer_file)
                print("Model downloaded successfully")

            print(f"Loading Coqui STT model: {model_file}")
            model = Model(str(model_file))

            if scorer_file and scorer_file.exists():
                print(f"Loading scorer: {scorer_file}")
                model.enableExternalScorer(str(scorer_file))

            transcripts = []

            for audio_file in sorted(input_dir.glob("*.wav")):
                print(f"Transcribing: {audio_file.name}")

                try:
                    # Read WAV file
                    with wave.open(str(audio_file), 'rb') as w:
                        sample_rate = w.getframerate()
                        frames = w.getnframes()
                        audio_data = w.readframes(frames)

                    # Convert to int16 array
                    audio = np.frombuffer(audio_data, dtype=np.int16)

                    # Resample if needed (Coqui STT expects 16kHz)
                    if sample_rate != 16000:
                        from scipy import signal
                        audio = signal.resample(audio, int(len(audio) * 16000 / sample_rate))
|
||||||
|
audio = audio.astype(np.int16)
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
text = model.stt(audio)
|
||||||
|
|
||||||
|
transcripts.append({
|
||||||
|
"audio_file": audio_file.name,
|
||||||
|
"transcript": text
|
||||||
|
})
|
||||||
|
print(f" -> {text[:100] if text else '(empty)'}...")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" -> Error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Write CSV
|
||||||
|
csv_file = output_dir / "transcripts.csv"
|
||||||
|
with open(csv_file, "w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=["audio_file", "transcript"])
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(transcripts)
|
||||||
|
|
||||||
|
print(f"\nTranscribed {len(transcripts)} files")
|
||||||
|
print(f"Saved to: {csv_file}")
|
||||||
|
EOF
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 4Gi
|
||||||
|
cpu: "2"
|
||||||
|
limits:
|
||||||
|
memory: 8Gi
|
||||||
|
cpu: "4"
|
||||||
|
|
||||||
|
# Template: Prepare dataset in Coqui TTS format
|
||||||
|
- name: prepare-coqui-dataset
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: voice-name
|
||||||
|
- name: language
|
||||||
|
artifacts:
|
||||||
|
- name: audio-files
|
||||||
|
path: /tmp/audio
|
||||||
|
- name: transcripts
|
||||||
|
path: /tmp/transcripts
|
||||||
|
optional: true
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
path: /tmp/dataset
|
||||||
|
container:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
VOICE_NAME = "{{inputs.parameters.voice-name}}"
|
||||||
|
LANGUAGE = "{{inputs.parameters.language}}"
|
||||||
|
|
||||||
|
audio_dir = Path("/tmp/audio")
|
||||||
|
transcripts_dir = Path("/tmp/transcripts")
|
||||||
|
output_dir = Path("/tmp/dataset")
|
||||||
|
wavs_dir = output_dir / "wavs"
|
||||||
|
wavs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Preparing Coqui TTS dataset for voice: {VOICE_NAME}")
|
||||||
|
|
||||||
|
# Find transcripts file
|
||||||
|
transcripts_file = None
|
||||||
|
for f in transcripts_dir.glob("*.csv"):
|
||||||
|
transcripts_file = f
|
||||||
|
break
|
||||||
|
|
||||||
|
if transcripts_file is None:
|
||||||
|
# Check for .txt files (simple format: filename|text)
|
||||||
|
for f in transcripts_dir.glob("*.txt"):
|
||||||
|
if f.name != "placeholder.txt":
|
||||||
|
transcripts_file = f
|
||||||
|
break
|
||||||
|
|
||||||
|
if transcripts_file is None:
|
||||||
|
print("Error: No transcripts file found!")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print(f"Using transcripts: {transcripts_file}")
|
||||||
|
|
||||||
|
# Parse transcripts
|
||||||
|
transcripts = {}
|
||||||
|
|
||||||
|
if transcripts_file.suffix == ".csv":
|
||||||
|
with open(transcripts_file, "r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
# Handle various column name conventions
|
||||||
|
audio = row.get("audio_file") or row.get("audio") or row.get("file") or row.get("wav")
|
||||||
|
text = row.get("transcript") or row.get("text") or row.get("sentence")
|
||||||
|
if audio and text:
|
||||||
|
transcripts[audio] = text.strip()
|
||||||
|
else:
|
||||||
|
# Simple pipe-separated format: filename|text
|
||||||
|
with open(transcripts_file, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if "|" in line:
|
||||||
|
parts = line.split("|", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
transcripts[parts[0]] = parts[1]
|
||||||
|
|
||||||
|
print(f"Loaded {len(transcripts)} transcripts")
|
||||||
|
|
||||||
|
# Copy audio files and create metadata
|
||||||
|
metadata_lines = []
|
||||||
|
|
||||||
|
for audio_file in sorted(audio_dir.glob("*.wav")):
|
||||||
|
# Try to match transcript
|
||||||
|
text = None
|
||||||
|
for key in [audio_file.name, audio_file.stem, audio_file.stem + ".wav"]:
|
||||||
|
if key in transcripts:
|
||||||
|
text = transcripts[key]
|
||||||
|
break
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
print(f"Warning: No transcript for {audio_file.name}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Copy audio file
|
||||||
|
dest_file = wavs_dir / audio_file.name
|
||||||
|
shutil.copy(audio_file, dest_file)
|
||||||
|
|
||||||
|
# Add to metadata in LJSpeech format: file_id|text|normalized_text
# No separate normalization is done here, so the same text fills both columns
|
||||||
|
metadata_lines.append(f"{audio_file.stem}|{text}|{text}")
|
||||||
|
|
||||||
|
# Write metadata.csv
|
||||||
|
metadata_file = output_dir / "metadata.csv"
|
||||||
|
with open(metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(metadata_lines))
|
||||||
|
|
||||||
|
print(f"Created dataset with {len(metadata_lines)} samples")
|
||||||
|
|
||||||
|
# Create dataset config
|
||||||
|
config = {
|
||||||
|
"name": VOICE_NAME,
|
||||||
|
"language": LANGUAGE,
|
||||||
|
"num_samples": len(metadata_lines),
|
||||||
|
"format": "ljspeech"
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(output_dir / "dataset_config.json", "w") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Dataset ready at: {output_dir}")
|
||||||
|
|
||||||
|
if len(metadata_lines) < 10:
|
||||||
|
print("Warning: Very small dataset! At least 100 samples are recommended for good results.")
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 1Gi
|
||||||
|
cpu: 500m
|
||||||
|
|
||||||
|
# Template: Train Coqui TTS model
|
||||||
|
- name: train-tts
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: voice-name
|
||||||
|
- name: base-model
|
||||||
|
- name: language
|
||||||
|
- name: num-epochs
|
||||||
|
- name: batch-size
|
||||||
|
- name: learning-rate
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
path: /tmp/dataset
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: model
|
||||||
|
path: /tmp/output
|
||||||
|
container:
|
||||||
|
image: ghcr.io/coqui-ai/tts:latest
|
||||||
|
command: [bash]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
set -e
|
||||||
|
|
||||||
|
VOICE_NAME="{{inputs.parameters.voice-name}}"
|
||||||
|
BASE_MODEL="{{inputs.parameters.base-model}}"
|
||||||
|
LANGUAGE="{{inputs.parameters.language}}"
|
||||||
|
NUM_EPOCHS="{{inputs.parameters.num-epochs}}"
|
||||||
|
BATCH_SIZE="{{inputs.parameters.batch-size}}"
|
||||||
|
LEARNING_RATE="{{inputs.parameters.learning-rate}}"
|
||||||
|
|
||||||
|
DATASET_DIR="/tmp/dataset"
|
||||||
|
OUTPUT_DIR="/tmp/output"
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
echo "=== Coqui TTS Voice Training ==="
|
||||||
|
echo "Voice Name: $VOICE_NAME"
|
||||||
|
echo "Base Model: $BASE_MODEL"
|
||||||
|
echo "Language: $LANGUAGE"
|
||||||
|
echo "Epochs: $NUM_EPOCHS"
|
||||||
|
echo "Batch Size: $BATCH_SIZE"
|
||||||
|
echo "Learning Rate: $LEARNING_RATE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Download base model if specified for fine-tuning
|
||||||
|
RESTORE_PATH=""
|
||||||
|
if [ "$BASE_MODEL" != "" ] && [ "$BASE_MODEL" != "none" ]; then
|
||||||
|
echo "Downloading base model for fine-tuning: $BASE_MODEL"
|
||||||
|
# Use tts to download the model and get its path
|
||||||
|
MODEL_PATH=$(python3 -c "
|
||||||
|
from TTS.utils.manage import ModelManager
|
||||||
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
model_name = '$BASE_MODEL'
|
||||||
|
manager = ModelManager()
|
||||||
|
|
||||||
|
# Download the model
|
||||||
|
model_path, config_path, _ = manager.download_model(model_name)
|
||||||
|
print(model_path)
|
||||||
|
")
|
||||||
|
RESTORE_PATH="$MODEL_PATH"
|
||||||
|
echo "Base model path: $RESTORE_PATH"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create and run training script following Coqui docs pattern
|
||||||
|
python3 << EOF
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Trainer: runs the training and evaluation loop
|
||||||
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
|
# Model configs
|
||||||
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig
|
||||||
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.models.vits import Vits
|
||||||
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
DATASET_DIR = Path("$DATASET_DIR")
|
||||||
|
OUTPUT_DIR = Path("$OUTPUT_DIR")
|
||||||
|
RESTORE_PATH = "$RESTORE_PATH" if "$RESTORE_PATH" else None
|
||||||
|
|
||||||
|
print(f"Dataset: {DATASET_DIR}")
|
||||||
|
print(f"Output: {OUTPUT_DIR}")
|
||||||
|
print(f"Restore from: {RESTORE_PATH}")
|
||||||
|
|
||||||
|
# Define dataset config (LJSpeech format)
|
||||||
|
dataset_config = BaseDatasetConfig(
|
||||||
|
formatter="ljspeech",
|
||||||
|
meta_file_train="metadata.csv",
|
||||||
|
path=str(DATASET_DIR),
|
||||||
|
language="$LANGUAGE",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize training configuration
|
||||||
|
config = VitsConfig(
|
||||||
|
run_name="$VOICE_NAME",
|
||||||
|
output_path=str(OUTPUT_DIR),
|
||||||
|
datasets=[dataset_config],
|
||||||
|
batch_size=int("$BATCH_SIZE"),
|
||||||
|
eval_batch_size=max(1, int("$BATCH_SIZE") // 2),
|
||||||
|
num_loader_workers=4,
|
||||||
|
num_eval_loader_workers=2,
|
||||||
|
run_eval=True,
|
||||||
|
test_delay_epochs=5,
|
||||||
|
epochs=int("$NUM_EPOCHS"),
|
||||||
|
text_cleaner="phoneme_cleaners",
|
||||||
|
use_phonemes=True,
|
||||||
|
phoneme_language="$LANGUAGE",
|
||||||
|
phoneme_cache_path=str(OUTPUT_DIR / "phoneme_cache"),
|
||||||
|
compute_input_seq_cache=True,
|
||||||
|
print_step=25,
|
||||||
|
print_eval=False,
|
||||||
|
mixed_precision=True,
|
||||||
|
save_step=500,
|
||||||
|
save_n_checkpoints=3,
|
||||||
|
save_best_after=1000,
|
||||||
|
lr=float("$LEARNING_RATE"),
|
||||||
|
# Audio settings for typical voice cloning
|
||||||
|
audio={
|
||||||
|
"sample_rate": 22050,
|
||||||
|
"resample": True,
|
||||||
|
"do_trim_silence": True,
|
||||||
|
"trim_db": 45,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the audio processor
|
||||||
|
# Used for feature extraction and audio I/O
|
||||||
|
ap = AudioProcessor.init_from_config(config)
|
||||||
|
|
||||||
|
# Initialize the tokenizer
|
||||||
|
# Converts text to sequences of token IDs
|
||||||
|
tokenizer, config = TTSTokenizer.init_from_config(config)
|
||||||
|
|
||||||
|
# Load data samples
|
||||||
|
# Each sample is [text, audio_file_path, speaker_name]
|
||||||
|
train_samples, eval_samples = load_tts_samples(
|
||||||
|
dataset_config,
|
||||||
|
eval_split=True,
|
||||||
|
eval_split_max_size=config.eval_split_max_size,
|
||||||
|
eval_split_size=config.eval_split_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Training samples: {len(train_samples)}")
|
||||||
|
print(f"Eval samples: {len(eval_samples)}")
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = Vits(config, ap, tokenizer, speaker_manager=None)
|
||||||
|
|
||||||
|
# Set up trainer arguments
|
||||||
|
trainer_args = TrainerArgs(
|
||||||
|
restore_path=RESTORE_PATH,
|
||||||
|
skip_train_epoch=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
trainer = Trainer(
|
||||||
|
trainer_args,
|
||||||
|
config,
|
||||||
|
output_path=str(OUTPUT_DIR),
|
||||||
|
model=model,
|
||||||
|
train_samples=train_samples,
|
||||||
|
eval_samples=eval_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Starting training...")
|
||||||
|
print("=" * 50 + "\n")
|
||||||
|
|
||||||
|
trainer.fit()
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Training complete!")
|
||||||
|
print("=" * 50)
|
||||||
|
EOF
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Training complete!"
|
||||||
|
echo "Model saved to: $OUTPUT_DIR"
|
||||||
|
ls -la "$OUTPUT_DIR"
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 16Gi
|
||||||
|
cpu: "4"
|
||||||
|
nvidia.com/gpu: "1"
|
||||||
|
limits:
|
||||||
|
memory: 32Gi
|
||||||
|
cpu: "8"
|
||||||
|
nvidia.com/gpu: "1"
|
||||||
|
volumeMounts:
|
||||||
|
- name: training-workspace
|
||||||
|
mountPath: /tmp/workspace
|
||||||
|
|
||||||
|
# Template: Export trained model
|
||||||
|
- name: export-trained-model
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: voice-name
|
||||||
|
- name: output-path
|
||||||
|
artifacts:
|
||||||
|
- name: trained-model
|
||||||
|
path: /tmp/trained-model
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: exported-model
|
||||||
|
path: /tmp/exported
|
||||||
|
container:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [bash]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
set -e
|
||||||
|
|
||||||
|
pip install -q boto3
|
||||||
|
|
||||||
|
python3 << 'EOF'
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
VOICE_NAME = "{{inputs.parameters.voice-name}}"
|
||||||
|
OUTPUT_PATH = "{{inputs.parameters.output-path}}"
|
||||||
|
|
||||||
|
model_dir = Path("/tmp/trained-model")
|
||||||
|
export_dir = Path("/tmp/exported")
|
||||||
|
export_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Exporting trained model: {VOICE_NAME}")
|
||||||
|
print(f"Target path: {OUTPUT_PATH}")
|
||||||
|
|
||||||
|
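# Find best checkpoint: prefer best_model*.pth, fall back to periodic checkpoints.
# Search recursively because the trainer may write into a run subdirectory.
checkpoints = list(model_dir.rglob("best_model*.pth")) or list(model_dir.rglob("checkpoint_*.pth"))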
|
||||||
|
if not checkpoints:
|
||||||
|
checkpoints = list(model_dir.rglob("*.pth"))
|
||||||
|
|
||||||
|
if not checkpoints:
|
||||||
|
print("Error: No model checkpoints found!")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# Sort by modification time and get newest
|
||||||
|
checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
|
||||||
|
best_checkpoint = checkpoints[0]
|
||||||
|
print(f"Using checkpoint: {best_checkpoint.name}")
|
||||||
|
|
||||||
|
# Create export package
|
||||||
|
package_dir = export_dir / VOICE_NAME
|
||||||
|
package_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Copy model files
|
||||||
|
shutil.copy(best_checkpoint, package_dir / "model.pth")
|
||||||
|
|
||||||
|
# Copy config if exists
|
||||||
|
config_file = best_checkpoint.parent / "config.json"
|
||||||
|
if config_file.exists():
|
||||||
|
shutil.copy(config_file, package_dir / "config.json")
|
||||||
|
|
||||||
|
# Create model info
|
||||||
|
model_info = {
|
||||||
|
"name": VOICE_NAME,
|
||||||
|
"created_at": datetime.now().isoformat(),
|
||||||
|
"checkpoint": best_checkpoint.name,
|
||||||
|
"type": "coqui-tts"
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(package_dir / "model_info.json", "w") as f:
|
||||||
|
json.dump(model_info, f, indent=2)
|
||||||
|
|
||||||
|
# Create tarball
|
||||||
|
archive_name = f"{VOICE_NAME}.tar.gz"
|
||||||
|
shutil.make_archive(
|
||||||
|
str(export_dir / VOICE_NAME),
|
||||||
|
"gztar",
|
||||||
|
export_dir,
|
||||||
|
VOICE_NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Created archive: {archive_name}")
|
||||||
|
|
||||||
|
# Upload to destination
|
||||||
|
if OUTPUT_PATH.startswith("s3://"):
|
||||||
|
import boto3
|
||||||
|
s3 = boto3.client("s3")
|
||||||
|
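# partition() also handles a bare "s3://bucket" with no key prefix
bucket, _, key = OUTPUT_PATH[5:].partition("/")
key = f"{key.rstrip('/')}/{archive_name}" if key.strip("/") else archive_name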
|
||||||
|
s3.upload_file(str(export_dir / archive_name), bucket, key)
|
||||||
|
print(f"Uploaded to: s3://{bucket}/{key}")
|
||||||
|
|
||||||
|
elif OUTPUT_PATH.startswith("/"):
|
||||||
|
# Local/NFS path
|
||||||
|
dest_path = Path(OUTPUT_PATH)
|
||||||
|
dest_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy(export_dir / archive_name, dest_path / archive_name)
|
||||||
|
# Also copy uncompressed for easy access
|
||||||
|
shutil.copytree(package_dir, dest_path / VOICE_NAME, dirs_exist_ok=True)
|
||||||
|
print(f"Saved to: {dest_path / archive_name}")
|
||||||
|
|
||||||
|
print("\nExport complete!")
|
||||||
|
print("Model package contents:")
|
||||||
|
for f in package_dir.iterdir():
|
||||||
|
print(f" - {f.name}")
|
||||||
|
EOF
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 1Gi
|
||||||
|
cpu: 500m
|
||||||
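The export template above has to choose one checkpoint from the training output. The following is a minimal sketch of that selection, not the workflow's actual code. It assumes that checkpoints may sit in a run subdirectory and that a `best_model*.pth` file should win over periodic `checkpoint_*.pth` files. The helper name `pick_checkpoint` is hypothetical.

```python
from pathlib import Path
from typing import Optional


def pick_checkpoint(model_dir: Path) -> Optional[Path]:
    """Return the preferred .pth file under model_dir, or None if there is none."""
    # Assumption: patterns are tried from most to least specific, and the
    # search is recursive so that run subdirectories are covered as well.
    for pattern in ("best_model*.pth", "checkpoint_*.pth", "*.pth"):
        matches = sorted(model_dir.rglob(pattern), key=lambda p: p.stat().st_mtime)
        if matches:
            # Newest file among those matching the first pattern that hits.
            return matches[-1]
    return None
```

Returning `None` instead of exiting leaves the error handling to the caller, which can then print a message and fail the step.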
369
document-ingestion.yaml
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
# Document Ingestion Workflow
|
||||||
|
# Ingests documents from a source URL into Milvus vector database
|
||||||
|
# Triggered via NATS: ai.pipeline.trigger with pipeline="document-ingestion"
|
||||||
|
---
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: WorkflowTemplate
|
||||||
|
metadata:
|
||||||
|
name: document-ingestion
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: document-ingestion
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
entrypoint: ingest-documents
|
||||||
|
serviceAccountName: argo-workflow
|
||||||
|
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: source-url
|
||||||
|
description: "URL to fetch documents from (s3://bucket/prefix or a single http(s):// file)"
|
||||||
|
- name: collection-name
|
||||||
|
value: "knowledge_base"
|
||||||
|
description: "Milvus collection name"
|
||||||
|
- name: chunk-size
|
||||||
|
value: "512"
|
||||||
|
description: "Text chunk size in characters"
|
||||||
|
- name: chunk-overlap
|
||||||
|
value: "50"
|
||||||
|
description: "Overlap between chunks"
|
||||||
|
|
||||||
|
templates:
|
||||||
|
- name: ingest-documents
|
||||||
|
dag:
|
||||||
|
tasks:
|
||||||
|
- name: fetch-documents
|
||||||
|
template: fetch-docs
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: source-url
|
||||||
|
value: "{{workflow.parameters.source-url}}"
|
||||||
|
|
||||||
|
- name: chunk-documents
|
||||||
|
template: chunk-docs
|
||||||
|
dependencies: [fetch-documents]
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: chunk-size
|
||||||
|
value: "{{workflow.parameters.chunk-size}}"
|
||||||
|
- name: chunk-overlap
|
||||||
|
value: "{{workflow.parameters.chunk-overlap}}"
|
||||||
|
artifacts:
|
||||||
|
- name: documents
|
||||||
|
from: "{{tasks.fetch-documents.outputs.artifacts.documents}}"
|
||||||
|
|
||||||
|
- name: generate-embeddings
|
||||||
|
template: embed-docs
|
||||||
|
dependencies: [chunk-documents]
|
||||||
|
arguments:
|
||||||
|
artifacts:
|
||||||
|
- name: chunks
|
||||||
|
from: "{{tasks.chunk-documents.outputs.artifacts.chunks}}"
|
||||||
|
|
||||||
|
- name: store-in-milvus
|
||||||
|
template: store-docs
|
||||||
|
dependencies: [generate-embeddings]
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: collection-name
|
||||||
|
value: "{{workflow.parameters.collection-name}}"
|
||||||
|
artifacts:
|
||||||
|
- name: embeddings
|
||||||
|
from: "{{tasks.generate-embeddings.outputs.artifacts.embeddings}}"
|
||||||
|
|
||||||
|
- name: fetch-docs
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: source-url
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: documents
|
||||||
|
path: /tmp/documents
|
||||||
|
container:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
source_url = "{{inputs.parameters.source-url}}"
|
||||||
|
output_dir = Path("/tmp/documents")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Fetching documents from: {source_url}")
|
||||||
|
|
||||||
|
# Handle different source types
|
||||||
|
if source_url.startswith("s3://"):
|
||||||
|
import subprocess
|
||||||
|
subprocess.run(["pip", "install", "boto3", "-q"], check=True)
|
||||||
|
import boto3
|
||||||
|
s3 = boto3.client("s3")
|
||||||
|
bucket, _, prefix = source_url[5:].partition("/")  # also accepts a bare s3://bucket
|
||||||
|
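# Paginate: a single list_objects_v2 call returns at most 1000 keys
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
for obj in (o for page in pages for o in page.get("Contents", [])):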
|
||||||
|
key = obj["Key"]
|
||||||
|
local_path = output_dir / Path(key).name
|
||||||
|
s3.download_file(bucket, key, str(local_path))
|
||||||
|
print(f"Downloaded: {key}")
|
||||||
|
elif source_url.startswith(("http://", "https://")):
|
||||||
|
# Single file download
|
||||||
|
filename = source_url.split("/")[-1] or "document.txt"
|
||||||
|
local_path = output_dir / filename
|
||||||
|
urllib.request.urlretrieve(source_url, local_path)
|
||||||
|
print(f"Downloaded: {filename}")
|
||||||
|
else:
|
||||||
|
print(f"Unsupported URL scheme: {source_url}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# List downloaded files
|
||||||
|
files = list(output_dir.glob("*"))
|
||||||
|
print(f"Downloaded {len(files)} files")
|
||||||
|
|
||||||
|
# Create manifest
|
||||||
|
manifest = {"files": [str(f) for f in files]}
|
||||||
|
with open(output_dir / "manifest.json", "w") as f:
|
||||||
|
json.dump(manifest, f)
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 256Mi
|
||||||
|
cpu: 100m
|
||||||
|
|
||||||
|
- name: chunk-docs
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: chunk-size
|
||||||
|
- name: chunk-overlap
|
||||||
|
artifacts:
|
||||||
|
- name: documents
|
||||||
|
path: /tmp/documents
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: chunks
|
||||||
|
path: /tmp/chunks
|
||||||
|
container:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
chunk_size = int("{{inputs.parameters.chunk-size}}")
|
||||||
|
chunk_overlap = int("{{inputs.parameters.chunk-overlap}}")
|
||||||
|
|
||||||
|
input_dir = Path("/tmp/documents")
|
||||||
|
output_dir = Path("/tmp/chunks")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Load manifest
|
||||||
|
with open(input_dir / "manifest.json") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
|
||||||
|
for filepath in manifest["files"]:
|
||||||
|
filepath = Path(filepath)
|
||||||
|
if not filepath.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Processing: {filepath.name}")
|
||||||
|
|
||||||
|
# Read file content
|
||||||
|
try:
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Simple chunking
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
while start < len(content):
|
||||||
|
end = start + chunk_size
|
||||||
|
chunk = content[start:end]
|
||||||
|
if chunk.strip():
|
||||||
|
chunks.append({
|
||||||
|
"text": chunk,
|
||||||
|
"source": filepath.name,
|
||||||
|
"chunk_index": len(chunks)
|
||||||
|
})
|
||||||
|
start = max(end - chunk_overlap, start + 1)  # always advance, even if overlap >= chunk size
|
||||||
|
|
||||||
|
all_chunks.extend(chunks)
|
||||||
|
print(f" Created {len(chunks)} chunks")
|
||||||
|
|
||||||
|
# Save chunks
|
||||||
|
with open(output_dir / "chunks.json", "w") as f:
|
||||||
|
json.dump({"chunks": all_chunks}, f)
|
||||||
|
|
||||||
|
print(f"Total chunks: {len(all_chunks)}")
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 512Mi
|
||||||
|
cpu: 100m
|
||||||
|
|
||||||
|
- name: embed-docs
|
||||||
|
inputs:
|
||||||
|
artifacts:
|
||||||
|
- name: chunks
|
||||||
|
path: /tmp/chunks
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: embeddings
|
||||||
|
path: /tmp/embeddings
|
||||||
|
container:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
import subprocess
|
||||||
|
subprocess.run(["pip", "install", "httpx", "-q"], check=True)
|
||||||
|
|
||||||
|
import json
|
||||||
|
import httpx
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
EMBEDDINGS_URL = "http://embeddings-predictor.ai-ml.svc.cluster.local"
|
||||||
|
BATCH_SIZE = 32
|
||||||
|
|
||||||
|
input_dir = Path("/tmp/chunks")
|
||||||
|
output_dir = Path("/tmp/embeddings")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Load chunks
|
||||||
|
with open(input_dir / "chunks.json") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
chunks = data["chunks"]
|
||||||
|
|
||||||
|
print(f"Generating embeddings for {len(chunks)} chunks")
|
||||||
|
|
||||||
|
# Generate embeddings in batches
|
||||||
|
all_embeddings = []
|
||||||
|
with httpx.Client(timeout=120.0) as client:
|
||||||
|
for i in range(0, len(chunks), BATCH_SIZE):
|
||||||
|
batch = chunks[i:i+BATCH_SIZE]
|
||||||
|
texts = [c["text"] for c in batch]
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
f"{EMBEDDINGS_URL}/embeddings",
|
||||||
|
json={"input": texts, "model": "bge"}
|
||||||
|
)
|
||||||
|
response.raise_for_status()  # fail the step on HTTP errors instead of silently skipping the batch
result = response.json()
|
||||||
|
|
||||||
|
for j, emb_data in enumerate(result.get("data", [])):
|
||||||
|
all_embeddings.append({
|
||||||
|
"text": batch[j]["text"],
|
||||||
|
"source": batch[j]["source"],
|
||||||
|
"chunk_index": batch[j]["chunk_index"],
|
||||||
|
"embedding": emb_data["embedding"]
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f" Processed batch {i//BATCH_SIZE + 1}/{(len(chunks)-1)//BATCH_SIZE + 1}")
|
||||||
|
|
||||||
|
# Save embeddings
|
||||||
|
with open(output_dir / "embeddings.json", "w") as f:
|
||||||
|
json.dump({"embeddings": all_embeddings}, f)
|
||||||
|
|
||||||
|
print(f"Generated {len(all_embeddings)} embeddings")
|
||||||
|
envFrom:
|
||||||
|
- configMapRef:
|
||||||
|
name: ai-services-config
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 1Gi
|
||||||
|
cpu: 200m
|
||||||
|
|
||||||
|
- name: store-docs
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: collection-name
|
||||||
|
artifacts:
|
||||||
|
- name: embeddings
|
||||||
|
path: /tmp/embeddings
|
||||||
|
container:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
import subprocess
|
||||||
|
subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
|
||||||
|
|
||||||
|
MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
|
||||||
|
MILVUS_PORT = 19530
|
||||||
|
COLLECTION_NAME = "{{inputs.parameters.collection-name}}"
|
||||||
|
EMBEDDING_DIM = 1024 # BGE-large dimension
|
||||||
|
|
||||||
|
input_dir = Path("/tmp/embeddings")
|
||||||
|
|
||||||
|
# Load embeddings
|
||||||
|
with open(input_dir / "embeddings.json") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
embeddings = data["embeddings"]
|
||||||
|
|
||||||
|
print(f"Storing {len(embeddings)} embeddings in Milvus")
|
||||||
|
|
||||||
|
# Connect to Milvus
|
||||||
|
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||||
|
print("Connected to Milvus")
|
||||||
|
|
||||||
|
# Create collection if not exists
|
||||||
|
if not utility.has_collection(COLLECTION_NAME):
|
||||||
|
fields = [
|
||||||
|
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
||||||
|
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||||
|
FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
|
||||||
|
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)
|
||||||
|
]
|
||||||
|
schema = CollectionSchema(fields, description="Knowledge base documents")
|
||||||
|
collection = Collection(COLLECTION_NAME, schema)
|
||||||
|
|
||||||
|
# Create HNSW index
|
||||||
|
index_params = {
|
||||||
|
"metric_type": "COSINE",
|
||||||
|
"index_type": "HNSW",
|
||||||
|
"params": {"M": 16, "efConstruction": 256}
|
||||||
|
}
|
||||||
|
collection.create_index("embedding", index_params)
|
||||||
|
print(f"Created collection: {COLLECTION_NAME}")
|
||||||
|
else:
|
||||||
|
collection = Collection(COLLECTION_NAME)
|
||||||
|
print(f"Using existing collection: {COLLECTION_NAME}")
|
||||||
|
|
||||||
|
# Insert data in batches
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
for i in range(0, len(embeddings), BATCH_SIZE):
|
||||||
|
batch = embeddings[i:i+BATCH_SIZE]
|
||||||
|
|
||||||
|
data = [
|
||||||
|
[e["text"] for e in batch],
|
||||||
|
[e["source"] for e in batch],
|
||||||
|
[e["embedding"] for e in batch]
|
||||||
|
]
|
||||||
|
|
||||||
|
collection.insert(data)
|
||||||
|
print(f" Inserted batch {i//BATCH_SIZE + 1}/{(len(embeddings)-1)//BATCH_SIZE + 1}")
|
||||||
|
|
||||||
|
# Flush to ensure data is persisted
|
||||||
|
collection.flush()
|
||||||
|
print(f"Successfully stored {len(embeddings)} documents")
|
||||||
|
|
||||||
|
connections.disconnect("default")
|
||||||
|
envFrom:
|
||||||
|
- configMapRef:
|
||||||
|
name: ai-services-config
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 512Mi
|
||||||
|
cpu: 100m
|
||||||
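The `chunk-docs` template above splits each document into fixed-size character windows with an overlap. The following is a minimal standalone sketch of that technique, not the workflow's actual code. The function name `chunk_text` is hypothetical, and the parameter validation and the early stop at the end of the text are assumptions added for illustration.

```python
from typing import List


def chunk_text(content: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Split content into overlapping character windows."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far each window moves forward
    chunks: List[str] = []
    start = 0
    while start < len(content):
        chunk = content[start:start + chunk_size]
        if chunk.strip():  # skip windows that contain only whitespace
            chunks.append(chunk)
        if start + chunk_size >= len(content):
            break  # this window already reached the end of the text
        start += step
    return chunks
```

Rejecting an overlap that is not smaller than the chunk size rules out a loop that never advances. The early `break` avoids a final chunk that only repeats the tail of the previous one.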
270
eventsource-kfp.yaml
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
# Argo Events - EventSource for KFP and NATS integration
|
||||||
|
# Enables bidirectional triggering between Argo Workflows and Kubeflow Pipelines
|
||||||
|
---
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: EventSource
|
||||||
|
metadata:
|
||||||
|
name: kfp-events
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: kfp-events
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
service:
|
||||||
|
ports:
|
||||||
|
- name: webhook
|
||||||
|
port: 12000
|
||||||
|
targetPort: 12000
|
||||||
|
# Webhook to receive KFP pipeline completion events
|
||||||
|
webhook:
|
||||||
|
kfp-completion:
|
||||||
|
port: "12000"
|
||||||
|
endpoint: /kfp/completion
|
||||||
|
method: POST
|
||||||
|
kfp-failure:
|
||||||
|
port: "12000"
|
||||||
|
endpoint: /kfp/failure
|
||||||
|
method: POST
|
||||||
|
# NATS for receiving pipeline trigger requests
|
||||||
|
nats:
|
||||||
|
pipeline-trigger:
|
||||||
|
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||||
|
subject: ai.pipeline.trigger
|
||||||
|
jsonBody: true
|
||||||
|
argo-trigger:
|
||||||
|
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||||
|
subject: ai.argo.trigger
|
||||||
|
jsonBody: true
|
||||||
|
kfp-trigger:
|
||||||
|
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||||
|
subject: ai.kfp.trigger
|
||||||
|
jsonBody: true
|
||||||
|
---
|
||||||
|
# Sensor for handling KFP completion events
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Sensor
|
||||||
|
metadata:
|
||||||
|
name: kfp-completion-sensor
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: kfp-completion-sensor
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
dependencies:
|
||||||
|
- name: kfp-success
|
||||||
|
eventSourceName: kfp-events
|
||||||
|
eventName: kfp-completion
|
||||||
|
filters:
|
||||||
|
data:
|
||||||
|
- path: body.status
|
||||||
|
type: string
|
||||||
|
value:
|
||||||
|
- "SUCCEEDED"
|
||||||
|
- name: kfp-failure
|
||||||
|
eventSourceName: kfp-events
|
||||||
|
eventName: kfp-failure
|
||||||
|
triggers:
|
||||||
|
# On KFP success, publish to NATS
|
||||||
|
- template:
|
||||||
|
name: notify-kfp-success
|
||||||
|
nats:
|
||||||
|
url: nats://nats.ai-ml.svc.cluster.local:4222
|
||||||
|
subject: ai.pipeline.status.completed
|
||||||
|
payload:
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-success
|
||||||
|
dataKey: body.run_id
|
||||||
|
dest: run_id
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-success
|
||||||
|
dataKey: body.pipeline_name
|
||||||
|
dest: pipeline_name
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-success
|
||||||
|
dataKey: body.status
|
||||||
|
dest: status
|
||||||
|
retryStrategy:
|
||||||
|
steps: 3
|
||||||
|
# On KFP failure, trigger recovery workflow
|
||||||
|
- template:
|
||||||
|
name: kfp-failure-recovery
|
||||||
|
k8s:
|
||||||
|
operation: create
|
||||||
|
source:
|
||||||
|
resource:
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Workflow
|
||||||
|
metadata:
|
||||||
|
generateName: kfp-failure-handler-
|
||||||
|
namespace: ai-ml
|
||||||
|
spec:
|
||||||
|
entrypoint: notify-failure
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: run-id
|
||||||
|
- name: pipeline-name
|
||||||
|
- name: error-message
|
||||||
|
templates:
|
||||||
|
- name: notify-failure
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: run-id
|
||||||
|
- name: pipeline-name
|
||||||
|
- name: error-message
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import nats
|
||||||
|
|
||||||
|
async def notify():
|
||||||
|
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
|
||||||
|
await nc.publish(
|
||||||
|
"ai.pipeline.status.failed",
|
||||||
|
json.dumps({
|
||||||
|
"run_id": "{{inputs.parameters.run-id}}",
|
||||||
|
"pipeline_name": "{{inputs.parameters.pipeline-name}}",
|
||||||
|
"error": "{{inputs.parameters.error-message}}",
|
||||||
|
"source": "kubeflow"
|
||||||
|
}).encode()
|
||||||
|
)
|
||||||
|
await nc.close()
|
||||||
|
|
||||||
|
asyncio.run(notify())
|
||||||
|
parameters:
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-failure
|
||||||
|
dataKey: body.run_id
|
||||||
|
dest: spec.arguments.parameters.0.value
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-failure
|
||||||
|
dataKey: body.pipeline_name
|
||||||
|
dest: spec.arguments.parameters.1.value
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-failure
|
||||||
|
dataKey: body.error
|
||||||
|
dest: spec.arguments.parameters.2.value
|
||||||
|
retryStrategy:
|
||||||
|
steps: 3
|
||||||
|
---
|
||||||
|
# Sensor for NATS-triggered Argo Workflows
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Sensor
|
||||||
|
metadata:
|
||||||
|
name: nats-argo-sensor
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: nats-argo-sensor
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
dependencies:
|
||||||
|
- name: argo-trigger
|
||||||
|
eventSourceName: kfp-events
|
||||||
|
eventName: argo-trigger
|
||||||
|
triggers:
|
||||||
|
- template:
|
||||||
|
name: trigger-argo-workflow
|
||||||
|
k8s:
|
||||||
|
operation: create
|
||||||
|
source:
|
||||||
|
resource:
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Workflow
|
||||||
|
metadata:
|
||||||
|
generateName: nats-triggered-
|
||||||
|
namespace: ai-ml
|
||||||
|
spec:
|
||||||
|
workflowTemplateRef:
|
||||||
|
name: placeholder
|
||||||
|
arguments:
|
||||||
|
parameters: []
|
||||||
|
parameters:
|
||||||
|
- src:
|
||||||
|
dependencyName: argo-trigger
|
||||||
|
dataKey: body.template
|
||||||
|
dest: spec.workflowTemplateRef.name
|
||||||
|
- src:
|
||||||
|
dependencyName: argo-trigger
|
||||||
|
dataKey: body.parameters
|
||||||
|
dest: spec.arguments.parameters
|
||||||
|
retryStrategy:
|
||||||
|
steps: 3
|
||||||
|
---
|
||||||
|
# Sensor for NATS-triggered KFP Pipelines
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Sensor
|
||||||
|
metadata:
|
||||||
|
name: nats-kfp-sensor
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: nats-kfp-sensor
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
dependencies:
|
||||||
|
- name: kfp-trigger
|
||||||
|
eventSourceName: kfp-events
|
||||||
|
eventName: kfp-trigger
|
||||||
|
triggers:
|
||||||
|
# Trigger KFP via Argo Workflow (uses kfp-trigger template)
|
||||||
|
- template:
|
||||||
|
name: trigger-kfp-via-argo
|
||||||
|
k8s:
|
||||||
|
operation: create
|
||||||
|
source:
|
||||||
|
resource:
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Workflow
|
||||||
|
metadata:
|
||||||
|
generateName: kfp-via-nats-
|
||||||
|
namespace: ai-ml
|
||||||
|
spec:
|
||||||
|
workflowTemplateRef:
|
||||||
|
name: kfp-trigger
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: pipeline-id
|
||||||
|
value: ""
|
||||||
|
- name: pipeline-params
|
||||||
|
value: "{}"
|
||||||
|
- name: wait-for-completion
|
||||||
|
value: "true"
|
||||||
|
parameters:
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-trigger
|
||||||
|
dataKey: body.pipeline_id
|
||||||
|
dest: spec.arguments.parameters.0.value
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-trigger
|
||||||
|
dataKey: body.parameters
|
||||||
|
dest: spec.arguments.parameters.1.value
|
||||||
|
operation: "stringify"
|
||||||
|
- src:
|
||||||
|
dependencyName: kfp-trigger
|
||||||
|
dataKey: body.wait
|
||||||
|
dest: spec.arguments.parameters.2.value
|
||||||
|
operation: "stringify"
|
||||||
|
retryStrategy:
|
||||||
|
steps: 3
|
||||||
|
---
|
||||||
|
# Service for the EventSource webhook
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: kfp-events-webhook
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: kfp-events
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
selector:
|
||||||
|
eventsource-name: kfp-events
|
||||||
|
ports:
|
||||||
|
- name: webhook
|
||||||
|
port: 12000
|
||||||
|
targetPort: 12000
|
||||||
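The `nats-kfp-sensor` above reads three keys from each message on `ai.kfp.trigger`: `pipeline_id`, `parameters`, and `wait`. A hedged sketch of a publisher follows. The payload shape is inferred from the sensor's `dataKey` entries, the function names are illustrative, and `nats-py` is assumed to be installed wherever `publish_trigger` is actually called.

```python
import json
from typing import Optional

NATS_URL = "nats://nats.ai-ml.svc.cluster.local:4222"  # URL used by the EventSource
SUBJECT = "ai.kfp.trigger"


def build_kfp_trigger(pipeline_id: str,
                      parameters: Optional[dict] = None,
                      wait: bool = True) -> bytes:
    """Encode the JSON body whose keys the sensor maps onto workflow parameters."""
    payload = {
        "pipeline_id": pipeline_id,      # mapped to pipeline-id
        "parameters": parameters or {},  # mapped to pipeline-params (stringified)
        "wait": wait,                    # mapped to wait-for-completion (stringified)
    }
    return json.dumps(payload).encode()


async def publish_trigger(message: bytes) -> None:
    """Publish one trigger message. nats-py is imported lazily so the
    payload builder can be used and tested without it."""
    import nats  # third-party: pip install nats-py

    nc = await nats.connect(NATS_URL)
    await nc.publish(SUBJECT, message)
    await nc.close()


if __name__ == "__main__":
    msg = build_kfp_trigger(
        "embedding-generation",
        {"input_path": "/tmp/dataset/train.parquet"},
    )
    print(msg.decode())
```

With a reachable NATS server, `asyncio.run(publish_trigger(msg))` would send the message. Because the EventSource sets `jsonBody: true`, the sensor can address these keys as `body.pipeline_id`, `body.parameters`, and `body.wait`.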
555
hybrid-ml-training.yaml
Normal file
@@ -0,0 +1,555 @@
|
|||||||
|
# Hybrid ML Training Workflow
|
||||||
|
# Combines Argo Workflows orchestration with Kubeflow Pipeline ML components
|
||||||
|
# Use case: Train a model using data from Milvus, with checkpointing and evaluation
|
||||||
|
---
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: WorkflowTemplate
|
||||||
|
metadata:
|
||||||
|
name: hybrid-ml-training
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: hybrid-ml-training
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
annotations:
|
||||||
|
description: |
|
||||||
|
Demonstrates hybrid Argo+KFP workflow:
|
||||||
|
- Argo handles orchestration, branching, retry logic
|
||||||
|
- KFP pipelines handle ML-specific operations (with caching)
|
||||||
|
- NATS for status updates to frontends
|
||||||
|
spec:
|
||||||
|
entrypoint: hybrid-training
|
||||||
|
serviceAccountName: argo-workflow
|
||||||
|
|
||||||
|
# Artifact repository for model checkpoints
|
||||||
|
artifactRepositoryRef:
|
||||||
|
configMap: artifact-repository
|
||||||
|
key: default
|
||||||
|
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: collection-name
|
||||||
|
description: "Milvus collection to pull training data from"
|
||||||
|
value: "dnd_text_embeddings"
|
||||||
|
- name: model-name
|
||||||
|
description: "Base model for fine-tuning"
|
||||||
|
value: "mistralai/Mistral-7B-v0.3"
|
||||||
|
- name: lora-rank
|
||||||
|
description: "LoRA rank (higher = more trainable parameters)"
|
||||||
|
value: "16"
|
||||||
|
- name: epochs
|
||||||
|
description: "Training epochs"
|
||||||
|
value: "3"
|
||||||
|
- name: batch-size
|
||||||
|
description: "Training batch size"
|
||||||
|
value: "4"
|
||||||
|
- name: output-path
|
||||||
|
description: "S3 path for model output"
|
||||||
|
value: "s3://models/lora-adapters"
|
||||||
|
- name: notify-nats
|
||||||
|
description: "Publish status to NATS"
|
||||||
|
value: "true"
|
||||||
|
|
||||||
|
# Volumes: persistent model cache and in-memory /dev/shm for the GPU training pod
|
||||||
|
volumes:
|
||||||
|
- name: model-cache
|
||||||
|
persistentVolumeClaim:
|
||||||
|
claimName: model-cache-pvc
|
||||||
|
- name: shm
|
||||||
|
emptyDir:
|
||||||
|
medium: Memory
|
||||||
|
sizeLimit: 16Gi
|
||||||
|
|
||||||
|
templates:
|
||||||
|
# Main DAG orchestrating the workflow
|
||||||
|
- name: hybrid-training
|
||||||
|
dag:
|
||||||
|
tasks:
|
||||||
|
- name: notify-start
|
||||||
|
template: nats-notify
|
||||||
|
when: "{{workflow.parameters.notify-nats}} == true"
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: subject
|
||||||
|
value: "ai.pipeline.status.{{workflow.name}}"
|
||||||
|
- name: message
|
||||||
|
value: '{"status": "started", "pipeline": "hybrid-ml-training"}'
|
||||||
|
|
||||||
|
- name: prepare-data
|
||||||
|
template: extract-training-data
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: collection-name
|
||||||
|
value: "{{workflow.parameters.collection-name}}"
|
||||||
|
|
||||||
|
- name: validate-data
|
||||||
|
template: validate-dataset
|
||||||
|
dependencies: [prepare-data]
|
||||||
|
arguments:
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
|
||||||
|
|
||||||
|
# KFP Pipeline: Run embedding generation if needed
|
||||||
|
- name: generate-embeddings
|
||||||
|
template: trigger-kfp
|
||||||
|
dependencies: [validate-data]
|
||||||
|
when: "{{tasks.validate-data.outputs.parameters.needs-embeddings}} == true"
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: pipeline-id
|
||||||
|
value: "embedding-generation"
|
||||||
|
- name: params
|
||||||
|
value: '{"input_path": "{{tasks.prepare-data.outputs.parameters.data-path}}"}'
|
||||||
|
|
||||||
|
# Training step (runs on GPU)
|
||||||
|
- name: train-lora
|
||||||
|
template: lora-training
|
||||||
|
dependencies: [validate-data, generate-embeddings]
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: model-name
|
||||||
|
value: "{{workflow.parameters.model-name}}"
|
||||||
|
- name: lora-rank
|
||||||
|
value: "{{workflow.parameters.lora-rank}}"
|
||||||
|
- name: epochs
|
||||||
|
value: "{{workflow.parameters.epochs}}"
|
||||||
|
- name: batch-size
|
||||||
|
value: "{{workflow.parameters.batch-size}}"
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
from: "{{tasks.prepare-data.outputs.artifacts.dataset}}"
|
||||||
|
|
||||||
|
# Evaluate model
|
||||||
|
- name: evaluate
|
||||||
|
template: evaluate-model
|
||||||
|
dependencies: [train-lora]
|
||||||
|
arguments:
|
||||||
|
artifacts:
|
||||||
|
- name: adapter
|
||||||
|
from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
|
||||||
|
|
||||||
|
# Branch based on evaluation results
|
||||||
|
- name: check-quality
|
||||||
|
template: quality-gate
|
||||||
|
dependencies: [evaluate]
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: eval-score
|
||||||
|
value: "{{tasks.evaluate.outputs.parameters.score}}"
|
||||||
|
- name: threshold
|
||||||
|
value: "0.7"
|
||||||
|
|
||||||
|
# If quality is good, upload to S3
|
||||||
|
- name: upload-model
|
||||||
|
template: upload-to-s3
|
||||||
|
dependencies: [check-quality]
|
||||||
|
when: "{{tasks.check-quality.outputs.parameters.passed}} == true"
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: output-path
|
||||||
|
value: "{{workflow.parameters.output-path}}"
|
||||||
|
artifacts:
|
||||||
|
- name: adapter
|
||||||
|
from: "{{tasks.train-lora.outputs.artifacts.adapter}}"
|
||||||
|
|
||||||
|
# If quality is poor, trigger retraining with different params
|
||||||
|
- name: retry-training
|
||||||
|
template: adjust-and-retry
|
||||||
|
dependencies: [check-quality]
|
||||||
|
when: "{{tasks.check-quality.outputs.parameters.passed}} == false"
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: current-rank
|
||||||
|
value: "{{workflow.parameters.lora-rank}}"
|
||||||
|
- name: current-epochs
|
||||||
|
value: "{{workflow.parameters.epochs}}"
|
||||||
|
|
||||||
|
- name: notify-complete
|
||||||
|
template: nats-notify
|
||||||
|
when: "{{workflow.parameters.notify-nats}} == true"
|
||||||
|
dependencies: [upload-model]
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: subject
|
||||||
|
value: "ai.pipeline.status.{{workflow.name}}"
|
||||||
|
- name: message
|
||||||
|
value: '{"status": "completed", "score": "{{tasks.evaluate.outputs.parameters.score}}"}'
|
||||||
|
|
||||||
|
# Extract training data from Milvus
|
||||||
|
- name: extract-training-data
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: collection-name
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
path: /tmp/dataset
|
||||||
|
parameters:
|
||||||
|
- name: data-path
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/data-path
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pymilvus", "pandas", "pyarrow"])
|
||||||
|
|
||||||
|
from pymilvus import connections, Collection
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
connections.connect(host="milvus.ai-ml.svc.cluster.local", port=19530)
|
||||||
|
|
||||||
|
collection = Collection("{{inputs.parameters.collection-name}}")
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
# Query all training samples
|
||||||
|
results = collection.query(
|
||||||
|
expr="source != ''",
|
||||||
|
output_fields=["text", "source", "metadata"],
|
||||||
|
limit=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to training format
|
||||||
|
df = pd.DataFrame(results)
|
||||||
|
output_dir = Path("/tmp/dataset")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save as parquet for efficient loading
|
||||||
|
df.to_parquet(output_dir / "train.parquet")
|
||||||
|
|
||||||
|
print(f"Extracted {len(df)} samples")
|
||||||
|
with open("/tmp/data-path", "w") as f:
|
||||||
|
f.write(str(output_dir / "train.parquet"))
|
||||||
|
|
||||||
|
# Validate dataset
|
||||||
|
- name: validate-dataset
|
||||||
|
inputs:
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
path: /tmp/dataset
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: needs-embeddings
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/needs-embeddings
|
||||||
|
- name: sample-count
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/sample-count
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pandas", "pyarrow"])
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
dataset_dir = Path("/tmp/dataset")
|
||||||
|
parquet_files = list(dataset_dir.glob("*.parquet"))
|
||||||
|
|
||||||
|
if not parquet_files:
|
||||||
|
raise ValueError("No parquet files found in dataset")
|
||||||
|
|
||||||
|
df = pd.read_parquet(parquet_files[0])
|
||||||
|
sample_count = len(df)
|
||||||
|
print(f"Dataset contains {sample_count} samples")
|
||||||
|
|
||||||
|
# Check if embeddings column exists
|
||||||
|
needs_embeddings = "embedding" not in df.columns
|
||||||
|
|
||||||
|
with open("/tmp/needs-embeddings", "w") as f:
|
||||||
|
f.write(str(needs_embeddings).lower())
|
||||||
|
|
||||||
|
with open("/tmp/sample-count", "w") as f:
|
||||||
|
f.write(str(sample_count))
|
||||||
|
|
||||||
|
# Trigger KFP pipeline
|
||||||
|
- name: trigger-kfp
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: pipeline-id
|
||||||
|
- name: params
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: run-id
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/run-id
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
|
||||||
|
|
||||||
|
from kfp import Client
|
||||||
|
|
||||||
|
client = Client(host="http://ml-pipeline.kubeflow.svc.cluster.local:8888")
|
||||||
|
params = json.loads('''{{inputs.parameters.params}}''')
|
||||||
|
|
||||||
|
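# create_run_from_pipeline_func takes a pipeline function, not a pipeline_id,
# so submit the uploaded pipeline with run_pipeline instead.
# Assumption: runs go into the "Default" experiment.
experiment = client.create_experiment(name="Default")
run = client.run_pipeline(
    experiment_id=experiment.experiment_id,
    job_name="{{inputs.parameters.pipeline-id}}-{{workflow.name}}",
    pipeline_id="{{inputs.parameters.pipeline-id}}",
    params=params
)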
|
||||||
|
|
||||||
|
print(f"Triggered KFP pipeline: {run.run_id}")
|
||||||
|
with open("/tmp/run-id", "w") as f:
|
||||||
|
f.write(run.run_id)
|
||||||
|
|
||||||
|
# LoRA training (GPU)
|
||||||
|
- name: lora-training
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: model-name
|
||||||
|
- name: lora-rank
|
||||||
|
- name: epochs
|
||||||
|
- name: batch-size
|
||||||
|
artifacts:
|
||||||
|
- name: dataset
|
||||||
|
path: /data/dataset
|
||||||
|
outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: adapter
|
||||||
|
path: /output/adapter
|
||||||
|
- name: logs
|
||||||
|
path: /output/logs
|
||||||
|
podSpecPatch: |
|
||||||
|
containers:
|
||||||
|
- name: main
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
amd.com/gpu: 1
|
||||||
|
limits:
|
||||||
|
amd.com/gpu: 1
|
||||||
|
script:
|
||||||
|
image: ghcr.io/billy-davies-2/lora-trainer:latest
|
||||||
|
command: [python]
|
||||||
|
env:
|
||||||
|
- name: HF_HOME
|
||||||
|
value: /cache/huggingface
|
||||||
|
- name: TRANSFORMERS_CACHE
|
||||||
|
value: /cache/huggingface
|
||||||
|
volumeMounts:
|
||||||
|
- name: model-cache
|
||||||
|
mountPath: /cache
|
||||||
|
- name: shm
|
||||||
|
mountPath: /dev/shm
|
||||||
|
source: |
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Training script
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||||
|
from peft import LoraConfig, get_peft_model
|
||||||
|
from datasets import load_dataset
|
||||||
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
model_name = "{{inputs.parameters.model-name}}"
|
||||||
|
lora_rank = int("{{inputs.parameters.lora-rank}}")
|
||||||
|
epochs = int("{{inputs.parameters.epochs}}")
|
||||||
|
batch_size = int("{{inputs.parameters.batch-size}}")
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
dataset = load_dataset("parquet", data_files="/data/dataset/*.parquet", split="train")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map="auto"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
# Configure LoRA
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=lora_rank,
|
||||||
|
lora_alpha=lora_rank * 2,
|
||||||
|
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||||
|
lora_dropout=0.05,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir="/output/adapter",
|
||||||
|
num_train_epochs=epochs,
|
||||||
|
per_device_train_batch_size=batch_size,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
learning_rate=2e-4,
|
||||||
|
logging_dir="/output/logs",
|
||||||
|
save_strategy="epoch"
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
train_dataset=dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
args=training_args,
|
||||||
|
dataset_text_field="text"
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
trainer.save_model("/output/adapter")
|
||||||
|
print("Training complete!")
|
||||||
|
|
||||||
|
# Evaluate model
|
||||||
|
- name: evaluate-model
|
||||||
|
inputs:
|
||||||
|
artifacts:
|
||||||
|
- name: adapter
|
||||||
|
path: /input/adapter
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: score
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/score
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "httpx"])
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Run evaluation using vLLM with the adapter
|
||||||
|
test_prompts = [
|
||||||
|
"What is the capital of France?",
|
||||||
|
"Explain machine learning in simple terms.",
|
||||||
|
"Write a haiku about coding."
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
with httpx.Client(timeout=120.0) as client:
|
||||||
|
for prompt in test_prompts:
|
||||||
|
response = client.post(
|
||||||
|
"http://llm-draft.ai-ml.svc.cluster.local:8000/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "local-adapter",
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": 200
|
||||||
|
}
|
||||||
|
)
|
||||||
|
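response.raise_for_status()  # fail the step on HTTP errors instead of on a missing key
# Placeholder scoring: response length as a crude stand-in for quality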
|
||||||
|
result = response.json()
|
||||||
|
content = result["choices"][0]["message"]["content"]
|
||||||
|
score = min(1.0, len(content) / 100) # Placeholder scoring
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
avg_score = sum(scores) / len(scores)
|
||||||
|
print(f"Average evaluation score: {avg_score}")
|
||||||
|
|
||||||
|
with open("/tmp/score", "w") as f:
|
||||||
|
f.write(str(round(avg_score, 3)))
|
||||||
|
|
||||||
|
# Quality gate
|
||||||
|
- name: quality-gate
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: eval-score
|
||||||
|
- name: threshold
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: passed
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/passed
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
score = float("{{inputs.parameters.eval-score}}")
|
||||||
|
threshold = float("{{inputs.parameters.threshold}}")
|
||||||
|
|
||||||
|
passed = score >= threshold
|
||||||
|
print(f"Score {score} {'passed' if passed else 'failed'} threshold {threshold}")
|
||||||
|
|
||||||
|
with open("/tmp/passed", "w") as f:
|
||||||
|
f.write(str(passed).lower())
|
||||||
|
|
||||||
|
# Upload to S3
|
||||||
|
- name: upload-to-s3
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: output-path
|
||||||
|
artifacts:
|
||||||
|
- name: adapter
|
||||||
|
path: /input/adapter
|
||||||
|
script:
|
||||||
|
image: amazon/aws-cli:latest
|
||||||
|
command: [bash]
|
||||||
|
env:
|
||||||
|
- name: AWS_ACCESS_KEY_ID
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: s3-credentials
|
||||||
|
key: access-key
|
||||||
|
- name: AWS_SECRET_ACCESS_KEY
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: s3-credentials
|
||||||
|
key: secret-key
|
||||||
|
- name: AWS_ENDPOINT_URL
|
||||||
|
value: "https://quobjects.billy.davies.cloud"
|
||||||
|
source: |
|
||||||
|
aws s3 cp --recursive /input/adapter "{{inputs.parameters.output-path}}/$(date +%Y%m%d-%H%M%S)/"
|
||||||
|
echo "Uploaded adapter to {{inputs.parameters.output-path}}"
|
||||||
|
|
||||||
|
# Adjust parameters and retry
|
||||||
|
- name: adjust-and-retry
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: current-rank
|
||||||
|
- name: current-epochs
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
current_rank = int("{{inputs.parameters.current-rank}}")
|
||||||
|
current_epochs = int("{{inputs.parameters.current-epochs}}")
|
||||||
|
|
||||||
|
# Increase rank and epochs for next attempt
|
||||||
|
new_rank = min(64, current_rank * 2)
|
||||||
|
new_epochs = current_epochs + 2
|
||||||
|
|
||||||
|
print(f"Adjusting parameters: rank {current_rank}->{new_rank}, epochs {current_epochs}->{new_epochs}")
|
||||||
|
print("TODO: Trigger new workflow with adjusted parameters")
|
||||||
|
|
||||||
|
# NATS notification
|
||||||
|
- name: nats-notify
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: subject
|
||||||
|
- name: message
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nats-py"])
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import nats
|
||||||
|
|
||||||
|
async def notify():
|
||||||
|
nc = await nats.connect("nats://nats.ai-ml.svc.cluster.local:4222")
|
||||||
|
await nc.publish(
|
||||||
|
"{{inputs.parameters.subject}}",
|
||||||
|
b'''{{inputs.parameters.message}}'''
|
||||||
|
)
|
||||||
|
await nc.close()
|
||||||
|
|
||||||
|
asyncio.run(notify())
|
||||||
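The evaluation, quality-gate, and retry templates above reduce to three small calculations: a length-based placeholder score per prompt, an averaged score compared against the threshold, and a parameter bump for the next attempt. The sketch below restates them as pure functions so the branching can be checked without a cluster. The function names are illustrative, and the workflow's retry step currently only prints the adjusted values.

```python
from typing import List, Tuple


def placeholder_score(content: str) -> float:
    """Length-based placeholder: 100 or more characters scores 1.0."""
    return min(1.0, len(content) / 100)


def quality_gate(scores: List[float], threshold: float = 0.7) -> Tuple[float, bool]:
    """Average the per-prompt scores, round to 3 places as the evaluate
    step does before writing /tmp/score, and compare with the threshold."""
    avg = round(sum(scores) / len(scores), 3)
    return avg, avg >= threshold


def adjust_params(rank: int, epochs: int) -> Tuple[int, int]:
    """Double the LoRA rank (capped at 64) and add two epochs."""
    return min(64, rank * 2), epochs + 2


if __name__ == "__main__":
    # Hypothetical response lengths standing in for real model output
    scores = [placeholder_score("x" * n) for n in (40, 80, 150)]
    avg, passed = quality_gate(scores)
    print(f"score={avg} passed={passed}")
    if not passed:
        print("next attempt:", adjust_params(16, 3))
```

Since the placeholder score depends only on response length, any reply of 100 characters or more passes the default 0.7 threshold. A real metric would need to replace `placeholder_score` before the gate is meaningful.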
237
kfp-integration.yaml
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
# Argo Workflows + Kubeflow Pipelines Integration
|
||||||
|
# This template allows Argo Workflows to trigger KFP pipelines and vice versa
|
||||||
|
---
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: WorkflowTemplate
|
||||||
|
metadata:
|
||||||
|
name: kfp-trigger
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app.kubernetes.io/name: kfp-trigger
|
||||||
|
app.kubernetes.io/part-of: llm-workflows
|
||||||
|
spec:
|
||||||
|
entrypoint: trigger-kfp-pipeline
|
||||||
|
serviceAccountName: argo-workflow
|
||||||
|
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: pipeline-id
|
||||||
|
description: "Kubeflow Pipeline ID or name"
|
||||||
|
- name: pipeline-params
|
||||||
|
description: "JSON object of pipeline parameters"
|
||||||
|
value: "{}"
|
||||||
|
- name: experiment-name
|
||||||
|
description: "KFP Experiment to use"
|
||||||
|
value: "Default"
|
||||||
|
- name: wait-for-completion
|
||||||
|
description: "Wait for pipeline to complete"
|
||||||
|
value: "true"
|
||||||
|
|
||||||
|
templates:
|
||||||
|
- name: trigger-kfp-pipeline
|
||||||
|
steps:
|
||||||
|
- - name: submit-run
|
||||||
|
template: submit-kfp-run
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: pipeline-id
|
||||||
|
value: "{{workflow.parameters.pipeline-id}}"
|
||||||
|
- name: pipeline-params
|
||||||
|
value: "{{workflow.parameters.pipeline-params}}"
|
||||||
|
- name: experiment-name
|
||||||
|
value: "{{workflow.parameters.experiment-name}}"
|
||||||
|
|
||||||
|
- - name: wait-completion
|
||||||
|
template: wait-for-kfp
|
||||||
|
when: "{{workflow.parameters.wait-for-completion}} == true"
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: run-id
|
||||||
|
value: "{{steps.submit-run.outputs.parameters.run-id}}"
|
||||||
|
|
||||||
|
- name: submit-kfp-run
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: pipeline-id
|
||||||
|
- name: pipeline-params
|
||||||
|
- name: experiment-name
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: run-id
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/run-id
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
|
||||||
|
|
||||||
|
from kfp import Client
|
||||||
|
|
||||||
|
KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
||||||
|
|
||||||
|
client = Client(host=KUBEFLOW_HOST)
|
||||||
|
|
||||||
|
pipeline_id = "{{inputs.parameters.pipeline-id}}"
|
||||||
|
params = json.loads('''{{inputs.parameters.pipeline-params}}''')
|
||||||
|
experiment_name = "{{inputs.parameters.experiment-name}}"
|
||||||
|
|
||||||
|
# Get or create experiment
|
||||||
|
try:
|
||||||
|
experiment = client.get_experiment(experiment_name=experiment_name)
|
||||||
|
except Exception:
|
||||||
|
experiment = client.create_experiment(name=experiment_name)
|
||||||
|
|
||||||
|
# Get pipeline by name or ID
|
||||||
|
try:
|
||||||
|
pipeline = client.get_pipeline(pipeline_id)
|
||||||
|
except Exception:
    # Not a valid ID: fall back to a lookup by pipeline name
    resolved_id = client.get_pipeline_id(pipeline_id)
    if resolved_id is None:
        raise ValueError(f"Pipeline not found: {pipeline_id}")
    pipeline = client.get_pipeline(resolved_id)
|
||||||
|
|
||||||
|
# Create run
|
||||||
|
run = client.run_pipeline(
|
||||||
|
experiment_id=experiment.experiment_id,
|
||||||
|
job_name=f"{pipeline.display_name}-argo-{pipeline_id[:8]}",
|
||||||
|
pipeline_id=pipeline.pipeline_id,
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Submitted KFP run: {run.run_id}")
|
||||||
|
with open("/tmp/run-id", "w") as f:
|
||||||
|
f.write(run.run_id)
|
||||||
|
|
||||||
|
- name: wait-for-kfp
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: run-id
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: status
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/status
|
||||||
|
script:
|
||||||
|
image: python:3.13-slim
|
||||||
|
command: [python]
|
||||||
|
source: |
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kfp==2.12.1"])
|
||||||
|
|
||||||
|
from kfp import Client
|
||||||
|
|
||||||
|
KUBEFLOW_HOST = "http://ml-pipeline.kubeflow.svc.cluster.local:8888"
|
||||||
|
run_id = "{{inputs.parameters.run-id}}"
|
||||||
|
|
||||||
|
client = Client(host=KUBEFLOW_HOST)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
run = client.get_run(run_id)
|
||||||
|
state = run.state  # the kfp 2.x client returns the run object directly
|
||||||
|
|
||||||
|
print(f"Run {run_id} status: {state}")
|
||||||
|
|
||||||
|
if state in ["SUCCEEDED", "SKIPPED"]:
|
||||||
|
with open("/tmp/status", "w") as f:
|
||||||
|
f.write("SUCCEEDED")
|
||||||
|
break
|
||||||
|
elif state in ["FAILED", "ERROR", "CANCELED", "CANCELLED"]:
|
||||||
|
with open("/tmp/status", "w") as f:
|
||||||
|
f.write(state)
|
||||||
|
raise Exception(f"Pipeline failed with status: {state}")
|
||||||
|
|
||||||
|
time.sleep(30)
|
||||||
|
|
||||||
|
---
# WorkflowTemplate for running KFP pipeline components as Argo steps
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: kfp-component-runner
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: kfp-component-runner
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: run-component
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: component-name
        description: "Name of the KFP component to run"
      - name: component-params
        description: "JSON parameters for the component"
        value: "{}"

  templates:
    - name: run-component
      inputs:
        parameters:
          - name: component-name
          - name: component-params
      outputs:
        parameters:
          - name: result
            valueFrom:
              path: /tmp/result.json
      script:
        image: python:3.13-slim
        command: [python]
        source: |
          import json
          import subprocess
          import sys
          subprocess.check_call([
              sys.executable, "-m", "pip", "install", "-q",
              "httpx", "pymilvus"
          ])

          import httpx

          component_name = "{{inputs.parameters.component-name}}"
          params = json.loads('''{{inputs.parameters.component-params}}''')

          # Component implementations (mirrors KFP components)
          COMPONENTS = {
              "transcribe_audio": {
                  "url": "http://whisper-predictor.ai-ml.svc.cluster.local",
                  "endpoint": "/v1/audio/transcriptions"
              },
              "generate_embeddings": {
                  "url": "http://embeddings-predictor.ai-ml.svc.cluster.local",
                  "endpoint": "/embeddings"
              },
              "generate_response": {
                  "url": "http://llm-draft.ai-ml.svc.cluster.local:8000",
                  "endpoint": "/v1/chat/completions"
              },
              "synthesize_speech": {
                  "url": "http://tts-predictor.ai-ml.svc.cluster.local",
                  "endpoint": "/v1/audio/speech"
              }
          }

          if component_name not in COMPONENTS:
              raise ValueError(f"Unknown component: {component_name}")

          config = COMPONENTS[component_name]
          with httpx.Client(timeout=120.0) as client:
              response = client.post(
                  f"{config['url']}{config['endpoint']}",
                  json=params
              )
              # Fail the step on HTTP errors instead of saving an error body as the result
              response.raise_for_status()
              result = response.json()

          with open("/tmp/result.json", "w") as f:
              json.dump(result, f)

          print(f"Component {component_name} completed")
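The polling loop in `wait-for-kfp` runs until the run reaches a terminal state and has no upper bound on how long it waits. A minimal sketch of a bounded variant follows, assuming the KFP v2 state names (`SUCCEEDED`, `SKIPPED`, `FAILED`, `CANCELED`). The helper `wait_for_terminal_state` and its parameters are hypothetical and not part of this commit.

```python
import time
from typing import Callable

# Terminal runtime states (assumption: KFP v2beta1 naming).
SUCCESS_STATES = {"SUCCEEDED", "SKIPPED"}
FAILURE_STATES = {"FAILED", "CANCELED"}


def wait_for_terminal_state(
    get_state: Callable[[], str],
    timeout_s: float = 3600.0,
    poll_interval_s: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Poll get_state() until it returns a terminal state or the deadline passes."""
    deadline = clock() + timeout_s
    while True:
        # Normalise missing or lower-case states before comparing.
        state = (get_state() or "").upper()
        if state in SUCCESS_STATES or state in FAILURE_STATES:
            return state
        if clock() >= deadline:
            raise TimeoutError(f"run still in state {state!r} after {timeout_s}s")
        sleep(poll_interval_s)
```

Inside the workflow script, `get_state` could be something like `lambda: client.get_run(run_id).state`. The `sleep` and `clock` arguments exist only so the loop can be exercised without real waiting.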
510
qlora-training.yaml
Normal file
@@ -0,0 +1,510 @@
# QLoRA Fine-tuning Workflow
# Trains QLoRA adapters from a reference model using data from the Milvus vector database
# Triggered via NATS: ai.pipeline.trigger with pipeline="qlora-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: qlora-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: qlora-training
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: train-qlora
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: reference-model
        description: "Base model to fine-tune (HuggingFace model ID or path)"
        value: "mistralai/Mistral-7B-Instruct-v0.3"
      - name: output-name
        description: "Name for the output QLoRA adapter"
        value: "qlora-adapter"
      - name: milvus-collections
        description: "Comma-separated list of Milvus collections to use (empty = all available)"
        value: ""
      - name: learning-rate
        value: "2e-4"
        description: "Learning rate for training"
      - name: num-epochs
        value: "3"
        description: "Number of training epochs"
      - name: batch-size
        value: "4"
        description: "Training batch size"
      - name: max-seq-length
        value: "2048"
        description: "Maximum sequence length"
      - name: lora-r
        value: "64"
        description: "LoRA attention dimension"
      - name: lora-alpha
        value: "16"
        description: "LoRA alpha parameter"
      - name: lora-dropout
        value: "0.05"
        description: "LoRA dropout rate"

  volumeClaimTemplates:
    - metadata:
        name: model-storage
      spec:
        accessModes: ["ReadWriteMany"]
        storageClassName: nfs-slow
        resources:
          requests:
            storage: 50Gi

  templates:
    - name: train-qlora
      dag:
        tasks:
          - name: fetch-training-data
            template: fetch-data
            arguments:
              parameters:
                - name: milvus-collections
                  value: "{{workflow.parameters.milvus-collections}}"

          - name: prepare-dataset
            template: prepare-data
            dependencies: [fetch-training-data]
            arguments:
              parameters:
                - name: max-seq-length
                  value: "{{workflow.parameters.max-seq-length}}"
              artifacts:
                - name: raw-data
                  from: "{{tasks.fetch-training-data.outputs.artifacts.raw-data}}"

          - name: train-model
            template: train
            dependencies: [prepare-dataset]
            arguments:
              parameters:
                - name: reference-model
                  value: "{{workflow.parameters.reference-model}}"
                - name: output-name
                  value: "{{workflow.parameters.output-name}}"
                - name: learning-rate
                  value: "{{workflow.parameters.learning-rate}}"
                - name: num-epochs
                  value: "{{workflow.parameters.num-epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
                - name: max-seq-length
                  value: "{{workflow.parameters.max-seq-length}}"
                - name: lora-r
                  value: "{{workflow.parameters.lora-r}}"
                - name: lora-alpha
                  value: "{{workflow.parameters.lora-alpha}}"
                - name: lora-dropout
                  value: "{{workflow.parameters.lora-dropout}}"
              artifacts:
                - name: training-data
                  from: "{{tasks.prepare-dataset.outputs.artifacts.training-data}}"

    - name: fetch-data
      inputs:
        parameters:
          - name: milvus-collections
      outputs:
        artifacts:
          - name: raw-data
            path: /tmp/raw-data
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import subprocess
            subprocess.run(["pip", "install", "pymilvus", "-q"], check=True)

            import json
            from pathlib import Path
            from pymilvus import connections, Collection, utility

            MILVUS_HOST = "milvus.ai-ml.svc.cluster.local"
            MILVUS_PORT = 19530
            collections_param = "{{inputs.parameters.milvus-collections}}"

            output_dir = Path("/tmp/raw-data")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Connecting to Milvus at {MILVUS_HOST}:{MILVUS_PORT}")
            connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

            # Determine which collections to use
            if collections_param and collections_param.strip():
                collection_names = [c.strip() for c in collections_param.split(",")]
                print(f"Using specified collections: {collection_names}")
            else:
                # Get all available collections
                collection_names = utility.list_collections()
                print(f"Using all available collections: {collection_names}")

            all_training_data = []

            for collection_name in collection_names:
                if not utility.has_collection(collection_name):
                    print(f"Warning: Collection {collection_name} not found, skipping")
                    continue

                print(f"Fetching data from collection: {collection_name}")
                collection = Collection(collection_name)
                collection.load()

                # Query all data from the collection
                # Note: Adjust field names based on your schema
                try:
                    # Get collection schema to determine fields
                    schema = collection.schema
                    field_names = [field.name for field in schema.fields if field.name != "id"]

                    # Query entities in a single batch. Milvus 2.x rejects query
                    # windows larger than 16384 rows, so larger collections need
                    # pagination.
                    results = collection.query(
                        expr="id >= 0",
                        output_fields=field_names,
                        limit=16384
                    )

                    print(f" Retrieved {len(results)} records from {collection_name}")

                    for result in results:
                        # Extract text field (adjust based on your schema)
                        text_content = result.get("text", "")
                        source = result.get("source", collection_name)

                        if text_content:
                            all_training_data.append({
                                "text": text_content,
                                "source": source,
                                "collection": collection_name
                            })

                except Exception as e:
                    print(f"Error querying collection {collection_name}: {e}")
                    continue

            # Save all training data
            output_file = output_dir / "training_data.json"
            with open(output_file, "w") as f:
                json.dump({"data": all_training_data}, f)

            print(f"Total training samples collected: {len(all_training_data)}")
            print(f"Saved to {output_file}")

            connections.disconnect("default")
        envFrom:
          - configMapRef:
              name: ai-services-config
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    - name: prepare-data
      inputs:
        parameters:
          - name: max-seq-length
        artifacts:
          - name: raw-data
            path: /tmp/raw-data
      outputs:
        artifacts:
          - name: training-data
            path: /tmp/training-data
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import json
            from pathlib import Path

            max_seq_length = int("{{inputs.parameters.max-seq-length}}")

            input_dir = Path("/tmp/raw-data")
            output_dir = Path("/tmp/training-data")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Load raw data
            with open(input_dir / "training_data.json") as f:
                data = json.load(f)

            raw_samples = data["data"]
            print(f"Processing {len(raw_samples)} raw samples")

            # Prepare data in instruction format for fine-tuning
            # Using Alpaca-style format: instruction + response
            training_samples = []

            for sample in raw_samples:
                text = sample["text"]
                source = sample.get("source", "")

                # Create instruction-response pairs
                # You can customize this based on your use case
                training_sample = {
                    "instruction": f"Based on the following information from {source}, provide a comprehensive response:",
                    "input": text[:max_seq_length // 2],  # Truncate if needed
                    "output": text[:max_seq_length // 2],
                    "source": source
                }
                training_samples.append(training_sample)

            # Split into train/validation (90/10)
            split_idx = int(len(training_samples) * 0.9)
            train_data = training_samples[:split_idx]
            val_data = training_samples[split_idx:]

            # Save prepared datasets
            with open(output_dir / "train.json", "w") as f:
                json.dump(train_data, f)

            with open(output_dir / "validation.json", "w") as f:
                json.dump(val_data, f)

            print(f"Prepared {len(train_data)} training samples")
            print(f"Prepared {len(val_data)} validation samples")
            print("Data preparation complete")
        resources:
          requests:
            memory: 2Gi
            cpu: 500m

    - name: train
      inputs:
        parameters:
          - name: reference-model
          - name: output-name
          - name: learning-rate
          - name: num-epochs
          - name: batch-size
          - name: max-seq-length
          - name: lora-r
          - name: lora-alpha
          - name: lora-dropout
        artifacts:
          - name: training-data
            path: /tmp/training-data
      container:
        image: python:3.13-slim
        command: [bash]
        args:
          - -c
          - |
            set -e

            echo "Installing dependencies..."
            pip install -q torch transformers peft datasets accelerate bitsandbytes scipy

            echo "Starting QLoRA training..."

            python << 'EOF'
            import json
            import os
            from pathlib import Path
            from datasets import Dataset
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                BitsAndBytesConfig,
                TrainingArguments,
                Trainer,
                DataCollatorForLanguageModeling
            )
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
            import torch

            # Parameters
            reference_model = "{{inputs.parameters.reference-model}}"
            output_name = "{{inputs.parameters.output-name}}"
            learning_rate = float("{{inputs.parameters.learning-rate}}")
            num_epochs = int("{{inputs.parameters.num-epochs}}")
            batch_size = int("{{inputs.parameters.batch-size}}")
            max_seq_length = int("{{inputs.parameters.max-seq-length}}")
            lora_r = int("{{inputs.parameters.lora-r}}")
            lora_alpha = int("{{inputs.parameters.lora-alpha}}")
            lora_dropout = float("{{inputs.parameters.lora-dropout}}")

            data_dir = Path("/tmp/training-data")
            output_dir = Path("/mnt/model-storage") / output_name
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Training configuration:")
            print(f" Model: {reference_model}")
            print(f" Learning rate: {learning_rate}")
            print(f" Epochs: {num_epochs}")
            print(f" Batch size: {batch_size}")
            print(f" Max sequence length: {max_seq_length}")
            print(f" LoRA r: {lora_r}, alpha: {lora_alpha}, dropout: {lora_dropout}")

            # Load datasets
            with open(data_dir / "train.json") as f:
                train_data = json.load(f)

            with open(data_dir / "validation.json") as f:
                val_data = json.load(f)

            print(f"Loaded {len(train_data)} training samples, {len(val_data)} validation samples")

            # Load tokenizer
            print(f"Loading tokenizer from {reference_model}...")
            tokenizer = AutoTokenizer.from_pretrained(reference_model, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Prepare datasets
            def format_sample(sample):
                instruction = sample.get("instruction", "")
                input_text = sample.get("input", "")
                output = sample.get("output", "")

                if input_text:
                    prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
                else:
                    prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"

                return {"text": prompt}

            train_dataset = Dataset.from_list([format_sample(s) for s in train_data])
            val_dataset = Dataset.from_list([format_sample(s) for s in val_data])

            # Tokenize datasets
            def tokenize_function(examples):
                return tokenizer(
                    examples["text"],
                    truncation=True,
                    max_length=max_seq_length,
                    padding="max_length"
                )

            print("Tokenizing datasets...")
            train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
            val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

            # Configure quantization for QLoRA
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

            # Load model
            print(f"Loading model {reference_model} with 4-bit quantization...")
            model = AutoModelForCausalLM.from_pretrained(
                reference_model,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True
            )

            # Prepare model for training
            model = prepare_model_for_kbit_training(model)

            # Configure LoRA
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )

            # Add LoRA adapters
            print("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

            # Training arguments
            training_args = TrainingArguments(
                output_dir=str(output_dir / "checkpoints"),
                num_train_epochs=num_epochs,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                gradient_accumulation_steps=4,
                learning_rate=learning_rate,
                fp16=True,
                logging_steps=10,
                # `evaluation_strategy` was renamed to `eval_strategy` in recent
                # transformers releases; the unpinned install above pulls the latest.
                eval_strategy="steps",
                eval_steps=50,
                save_strategy="steps",
                save_steps=100,
                save_total_limit=3,
                load_best_model_at_end=True,
                report_to="none",
                remove_unused_columns=False,
            )

            # Data collator
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False
            )

            # Trainer
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=data_collator,
            )

            # Train
            print("Starting training...")
            trainer.train()

            # Save final model
            print(f"Saving QLoRA adapter to {output_dir}")
            model.save_pretrained(str(output_dir / "final"))
            tokenizer.save_pretrained(str(output_dir / "final"))

            # Save training metadata
            metadata = {
                "reference_model": reference_model,
                "output_name": output_name,
                "training_params": {
                    "learning_rate": learning_rate,
                    "num_epochs": num_epochs,
                    "batch_size": batch_size,
                    "max_seq_length": max_seq_length,
                    "lora_r": lora_r,
                    "lora_alpha": lora_alpha,
                    "lora_dropout": lora_dropout
                },
                "dataset_info": {
                    "train_samples": len(train_data),
                    "val_samples": len(val_data)
                }
            }

            with open(output_dir / "metadata.json", "w") as f:
                json.dump(metadata, f, indent=2)

            print("Training complete!")
            print(f"QLoRA adapter saved to: {output_dir}")
            EOF
        envFrom:
          - configMapRef:
              name: ai-services-config
        volumeMounts:
          - name: model-storage
            mountPath: /mnt/model-storage
        resources:
          requests:
            memory: 16Gi
            cpu: 4
            nvidia.com/gpu: 1
          limits:
            memory: 32Gi
            cpu: 8
            nvidia.com/gpu: 1
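The default parameters of `qlora-training` interact in ways that are easy to misread: `lora-alpha` is applied relative to `lora-r`, and the per-device `batch-size` is multiplied by the hard-coded `gradient_accumulation_steps=4`. The sketch below works through that arithmetic under the standard LoRA formulation (scaling = alpha / r, adapter parameters = r * (d_in + d_out) per adapted matrix). The helper name `lora_summary` and the 4096x4096 projection shape are illustrative assumptions, not values read from the workflow or the model.

```python
def lora_summary(r, alpha, d_in, d_out, batch_size, grad_accum_steps, num_gpus=1):
    """Return derived quantities for one LoRA-adapted weight matrix."""
    return {
        # Factor applied to the low-rank update B @ A in standard LoRA.
        "scaling": alpha / r,
        # A is (r x d_in) and B is (d_out x r), so r * (d_in + d_out) weights.
        "adapter_params": r * (d_in + d_out),
        # Samples contributing to each optimizer step.
        "effective_batch_size": batch_size * grad_accum_steps * num_gpus,
    }


# Workflow defaults: lora-r=64, lora-alpha=16, batch-size=4, accumulation=4.
# The 4096 x 4096 shape is a hypothetical projection size used for illustration.
summary = lora_summary(r=64, alpha=16, d_in=4096, d_out=4096,
                       batch_size=4, grad_accum_steps=4)
print(summary)
```

With these inputs the scaling factor is 0.25 and each optimizer step sees 16 samples, so raising `lora-r` without also raising `lora-alpha` shrinks the adapter's effective contribution.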