feat: Add ML training and batch inference workflows
- batch-inference: LLM inference with optional RAG - qlora-training: QLoRA adapter fine-tuning from Milvus - hybrid-ml-training: Multi-GPU distributed training - coqui-voice-training: XTTS voice cloning - document-ingestion: Ingest documents to Milvus - eventsource-kfp: Argo Events / Kubeflow integration - kfp-integration: Bridge between Argo and Kubeflow
This commit is contained in:
969
coqui-voice-training.yaml
Normal file
969
coqui-voice-training.yaml
Normal file
@@ -0,0 +1,969 @@
# Coqui TTS Voice Training Workflow
# Trains a custom voice model using Coqui TTS from audio samples
# Triggered via NATS: ai.pipeline.trigger with pipeline="coqui-voice-training"
---
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: coqui-voice-training
  namespace: ai-ml
  labels:
    app.kubernetes.io/name: coqui-voice-training
    app.kubernetes.io/part-of: llm-workflows
spec:
  entrypoint: train-voice
  serviceAccountName: argo-workflow

  arguments:
    parameters:
      - name: audio-source
        description: "URL to audio files (S3 bucket, HTTP, or NFS path with .wav/.mp3 files)"
      - name: transcripts-source
        description: "URL to transcripts file (CSV with audio_file,transcript columns) - leave empty to auto-transcribe"
        value: ""
      - name: voice-name
        description: "Name for the trained voice model"
        value: "custom-voice"
      - name: base-model
        description: "Base TTS model to fine-tune from"
        value: "tts_models/en/ljspeech/vits"
      - name: language
        description: "Language code (e.g., en, de, fr, es)"
        value: "en"
      - name: num-epochs
        description: "Number of training epochs"
        value: "100"
      - name: batch-size
        description: "Training batch size"
        value: "16"
      - name: learning-rate
        description: "Learning rate for training"
        value: "0.0001"
      - name: sample-rate
        description: "Target sample rate for audio (Hz)"
        value: "22050"
      - name: output-path
        description: "Path to store the trained model (S3 or NFS)"
        value: "/models/tts/custom"

  # Shared scratch volume for the training step (mounted at /tmp/workspace).
  volumeClaimTemplates:
    - metadata:
        name: training-workspace
      spec:
        accessModes: ["ReadWriteMany"]
        storageClassName: nfs-slow
        resources:
          requests:
            storage: 50Gi

  templates:
    # DAG: fetch -> preprocess -> (optionally auto-transcribe) -> prepare -> train -> export
    - name: train-voice
      dag:
        tasks:
          - name: fetch-audio
            template: fetch-audio-files
            arguments:
              parameters:
                - name: audio-source
                  value: "{{workflow.parameters.audio-source}}"

          - name: fetch-transcripts
            template: fetch-transcript-file
            arguments:
              parameters:
                - name: transcripts-source
                  value: "{{workflow.parameters.transcripts-source}}"

          - name: preprocess-audio
            template: preprocess
            dependencies: [fetch-audio]
            arguments:
              parameters:
                - name: sample-rate
                  value: "{{workflow.parameters.sample-rate}}"
              artifacts:
                - name: raw-audio
                  from: "{{tasks.fetch-audio.outputs.artifacts.audio-files}}"

          # Only runs when no transcripts file was supplied.
          - name: generate-transcripts
            template: transcribe-audio
            dependencies: [preprocess-audio, fetch-transcripts]
            when: "{{workflow.parameters.transcripts-source}} == ''"
            arguments:
              parameters:
                - name: language
                  value: "{{workflow.parameters.language}}"
              artifacts:
                - name: audio-files
                  from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"

          - name: prepare-dataset
            template: prepare-coqui-dataset
            dependencies: [preprocess-audio, generate-transcripts, fetch-transcripts]
            arguments:
              parameters:
                - name: voice-name
                  value: "{{workflow.parameters.voice-name}}"
                - name: language
                  value: "{{workflow.parameters.language}}"
              artifacts:
                - name: audio-files
                  from: "{{tasks.preprocess-audio.outputs.artifacts.processed-audio}}"
                - name: transcripts
                  # expr syntax: hyphenated parameter/task names require bracket
                  # access ("transcriptsSource" does not exist as a parameter).
                  from: "{{=workflow.parameters['transcripts-source'] != '' ? tasks['fetch-transcripts'].outputs.artifacts.transcripts : tasks['generate-transcripts'].outputs.artifacts.transcripts}}"
                  optional: true

          - name: train-model
            template: train-tts
            dependencies: [prepare-dataset]
            arguments:
              parameters:
                - name: voice-name
                  value: "{{workflow.parameters.voice-name}}"
                - name: base-model
                  value: "{{workflow.parameters.base-model}}"
                - name: language
                  value: "{{workflow.parameters.language}}"
                - name: num-epochs
                  value: "{{workflow.parameters.num-epochs}}"
                - name: batch-size
                  value: "{{workflow.parameters.batch-size}}"
                - name: learning-rate
                  value: "{{workflow.parameters.learning-rate}}"
              artifacts:
                - name: dataset
                  from: "{{tasks.prepare-dataset.outputs.artifacts.dataset}}"

          - name: export-model
            template: export-trained-model
            dependencies: [train-model]
            arguments:
              parameters:
                - name: voice-name
                  value: "{{workflow.parameters.voice-name}}"
                - name: output-path
                  value: "{{workflow.parameters.output-path}}"
              artifacts:
                - name: trained-model
                  from: "{{tasks.train-model.outputs.artifacts.model}}"

    # Template: Fetch audio files from source (S3, HTTP, or local/NFS path).
    - name: fetch-audio-files
      inputs:
        parameters:
          - name: audio-source
      outputs:
        artifacts:
          - name: audio-files
            path: /tmp/audio
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import os
            import subprocess
            import urllib.request
            from pathlib import Path
            import shutil

            source_url = "{{inputs.parameters.audio-source}}"
            output_dir = Path("/tmp/audio")
            output_dir.mkdir(parents=True, exist_ok=True)

            print(f"Fetching audio from: {source_url}")

            if source_url.startswith("s3://"):
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, prefix = source_url[5:].split("/", 1)
                response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

                audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
                for obj in response.get("Contents", []):
                    key = obj["Key"]
                    if Path(key).suffix.lower() in audio_extensions:
                        local_path = output_dir / Path(key).name
                        s3.download_file(bucket, key, str(local_path))
                        print(f"Downloaded: {key}")

            elif source_url.startswith("http"):
                # Handle single file or directory listing
                filename = source_url.split("/")[-1]
                if any(ext in filename.lower() for ext in [".wav", ".mp3", ".flac", ".zip"]):
                    local_path = output_dir / filename
                    urllib.request.urlretrieve(source_url, local_path)
                    print(f"Downloaded: {filename}")

                    # Extract if zip
                    if filename.endswith(".zip"):
                        shutil.unpack_archive(local_path, output_dir)
                        os.remove(local_path)
                        print("Extracted zip archive")
                else:
                    print(f"URL doesn't appear to be an audio file: {source_url}")
                    exit(1)

            elif source_url.startswith("/"):
                # Local/NFS path
                src_path = Path(source_url)
                if src_path.is_dir():
                    audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}
                    for f in src_path.iterdir():
                        if f.suffix.lower() in audio_extensions:
                            shutil.copy(f, output_dir / f.name)
                            print(f"Copied: {f.name}")
                elif src_path.is_file():
                    shutil.copy(src_path, output_dir / src_path.name)
                else:
                    print(f"Path not found: {source_url}")
                    exit(1)
            else:
                print(f"Unsupported source: {source_url}")
                exit(1)

            # Count files
            audio_files = list(output_dir.glob("*"))
            print(f"Total audio files: {len(audio_files)}")

            if len(audio_files) == 0:
                print("Error: No audio files found!")
                exit(1)
        resources:
          requests:
            memory: 512Mi
            cpu: 200m

    # Template: Fetch transcripts file (optional; a placeholder is emitted when
    # no source was provided so downstream steps can detect auto-transcribe mode).
    - name: fetch-transcript-file
      inputs:
        parameters:
          - name: transcripts-source
      outputs:
        artifacts:
          - name: transcripts
            path: /tmp/transcripts
            optional: true
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import os
            import subprocess
            import urllib.request
            from pathlib import Path
            import shutil

            source_url = "{{inputs.parameters.transcripts-source}}"
            output_dir = Path("/tmp/transcripts")
            output_dir.mkdir(parents=True, exist_ok=True)

            if not source_url or source_url.strip() == "":
                print("No transcripts source provided - will auto-transcribe")
                # Create empty placeholder
                (output_dir / "placeholder.txt").write_text("auto-transcribe")
                exit(0)

            print(f"Fetching transcripts from: {source_url}")

            if source_url.startswith("s3://"):
                subprocess.run(["pip", "install", "boto3", "-q"], check=True)
                import boto3
                s3 = boto3.client("s3")
                bucket, key = source_url[5:].split("/", 1)
                local_path = output_dir / Path(key).name
                s3.download_file(bucket, key, str(local_path))
                print(f"Downloaded: {key}")

            elif source_url.startswith("http"):
                filename = source_url.split("/")[-1] or "transcripts.csv"
                local_path = output_dir / filename
                urllib.request.urlretrieve(source_url, local_path)
                print(f"Downloaded: {filename}")

            elif source_url.startswith("/"):
                src_path = Path(source_url)
                if src_path.is_file():
                    shutil.copy(src_path, output_dir / src_path.name)
                    print(f"Copied: {src_path.name}")
                else:
                    print(f"File not found: {source_url}")
                    exit(1)
            else:
                print(f"Unsupported source: {source_url}")
                exit(1)
        resources:
          requests:
            memory: 256Mi
            cpu: 100m

    # Template: Preprocess audio files (mono, resample, normalize, export WAV).
    - name: preprocess
      inputs:
        parameters:
          - name: sample-rate
        artifacts:
          - name: raw-audio
            path: /tmp/raw-audio
      outputs:
        artifacts:
          - name: processed-audio
            path: /tmp/processed-audio
      container:
        image: python:3.13-slim
        command: [bash]
        args:
          - -c
          - |
            set -e

            # Install ffmpeg and dependencies
            apt-get update && apt-get install -y ffmpeg > /dev/null 2>&1
            pip install -q pydub numpy soundfile

            python3 << 'EOF'
            import os
            from pathlib import Path
            from pydub import AudioSegment
            import soundfile as sf

            SAMPLE_RATE = int("{{inputs.parameters.sample-rate}}")
            input_dir = Path("/tmp/raw-audio")
            output_dir = Path("/tmp/processed-audio")
            output_dir.mkdir(parents=True, exist_ok=True)

            audio_extensions = {".wav", ".mp3", ".flac", ".ogg", ".m4a"}

            for audio_file in input_dir.iterdir():
                if audio_file.suffix.lower() not in audio_extensions:
                    continue

                print(f"Processing: {audio_file.name}")

                try:
                    # Load audio
                    audio = AudioSegment.from_file(str(audio_file))

                    # Convert to mono if stereo
                    if audio.channels > 1:
                        audio = audio.set_channels(1)

                    # Resample to target sample rate
                    audio = audio.set_frame_rate(SAMPLE_RATE)

                    # Normalize audio
                    audio = audio.normalize()

                    # Export as WAV
                    output_file = output_dir / f"{audio_file.stem}.wav"
                    audio.export(str(output_file), format="wav")
                    print(f" -> Saved: {output_file.name}")

                except Exception as e:
                    print(f" -> Error processing {audio_file.name}: {e}")
                    continue

            processed_files = list(output_dir.glob("*.wav"))
            print(f"\nProcessed {len(processed_files)} audio files")

            if len(processed_files) == 0:
                print("Error: No files were successfully processed!")
                exit(1)
            EOF
        resources:
          requests:
            memory: 2Gi
            cpu: "1"

    # Template: Auto-transcribe audio using Coqui STT
    - name: transcribe-audio
      inputs:
        parameters:
          - name: language
        artifacts:
          - name: audio-files
            path: /tmp/audio
      outputs:
        artifacts:
          - name: transcripts
            path: /tmp/transcripts
      container:
        image: ghcr.io/coqui-ai/stt:latest
        command: [bash]
        args:
          - -c
          - |
            set -e

            # Install additional dependencies
            pip install -q numpy scipy

            python3 << 'EOF'
            import csv
            import os
            import wave
            import numpy as np
            from pathlib import Path
            from stt import Model

            LANGUAGE = "{{inputs.parameters.language}}"
            input_dir = Path("/tmp/audio")
            output_dir = Path("/tmp/transcripts")
            output_dir.mkdir(parents=True, exist_ok=True)

            # Model paths - Coqui STT models are typically pre-installed in the container
            # or can be downloaded from https://coqui.ai/models
            MODEL_DIR = Path("/models/stt")

            # Try to find model files
            model_file = None
            scorer_file = None

            # Check for language-specific models
            lang_model_dir = MODEL_DIR / LANGUAGE
            if lang_model_dir.exists():
                for f in lang_model_dir.glob("*.tflite"):
                    model_file = f
                for f in lang_model_dir.glob("*.scorer"):
                    scorer_file = f

            # Fallback to default English model location
            if model_file is None:
                default_paths = [
                    MODEL_DIR / "model.tflite",
                    Path("/usr/share/stt/model.tflite"),
                    Path("/opt/stt/model.tflite"),
                ]
                for p in default_paths:
                    if p.exists():
                        model_file = p
                        break

            if model_file is None:
                # Download model if not found
                print("Downloading Coqui STT model...")
                import urllib.request
                import tarfile

                model_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.tflite"
                scorer_url = "https://github.com/coqui-ai/STT-models/releases/download/english/coqui-stt-1.0.0-lg-vocab.scorer"

                MODEL_DIR.mkdir(parents=True, exist_ok=True)
                model_file = MODEL_DIR / "model.tflite"
                scorer_file = MODEL_DIR / "model.scorer"

                urllib.request.urlretrieve(model_url, model_file)
                urllib.request.urlretrieve(scorer_url, scorer_file)
                print("Model downloaded successfully")

            print(f"Loading Coqui STT model: {model_file}")
            model = Model(str(model_file))

            if scorer_file and scorer_file.exists():
                print(f"Loading scorer: {scorer_file}")
                model.enableExternalScorer(str(scorer_file))

            transcripts = []

            for audio_file in sorted(input_dir.glob("*.wav")):
                print(f"Transcribing: {audio_file.name}")

                try:
                    # Read WAV file
                    with wave.open(str(audio_file), 'rb') as w:
                        sample_rate = w.getframerate()
                        frames = w.getnframes()
                        audio_data = w.readframes(frames)

                    # Convert to int16 array
                    audio = np.frombuffer(audio_data, dtype=np.int16)

                    # Resample if needed (Coqui STT expects 16kHz)
                    if sample_rate != 16000:
                        from scipy import signal
                        audio = signal.resample(audio, int(len(audio) * 16000 / sample_rate))
                        audio = audio.astype(np.int16)

                    # Run inference
                    text = model.stt(audio)

                    transcripts.append({
                        "audio_file": audio_file.name,
                        "transcript": text
                    })
                    print(f" -> {text[:100] if text else '(empty)'}...")
                except Exception as e:
                    print(f" -> Error: {e}")
                    continue

            # Write CSV
            csv_file = output_dir / "transcripts.csv"
            with open(csv_file, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=["audio_file", "transcript"])
                writer.writeheader()
                writer.writerows(transcripts)

            print(f"\nTranscribed {len(transcripts)} files")
            print(f"Saved to: {csv_file}")
            EOF
        resources:
          requests:
            memory: 4Gi
            cpu: "2"
          limits:
            memory: 8Gi
            cpu: "4"

    # Template: Prepare dataset in Coqui TTS (LJSpeech) format
    - name: prepare-coqui-dataset
      inputs:
        parameters:
          - name: voice-name
          - name: language
        artifacts:
          - name: audio-files
            path: /tmp/audio
          - name: transcripts
            path: /tmp/transcripts
            optional: true
      outputs:
        artifacts:
          - name: dataset
            path: /tmp/dataset
      container:
        image: python:3.13-slim
        command: [python]
        args:
          - -c
          - |
            import csv
            import json
            import os
            import shutil
            from pathlib import Path

            VOICE_NAME = "{{inputs.parameters.voice-name}}"
            LANGUAGE = "{{inputs.parameters.language}}"

            audio_dir = Path("/tmp/audio")
            transcripts_dir = Path("/tmp/transcripts")
            output_dir = Path("/tmp/dataset")
            wavs_dir = output_dir / "wavs"
            wavs_dir.mkdir(parents=True, exist_ok=True)

            print(f"Preparing Coqui TTS dataset for voice: {VOICE_NAME}")

            # Find transcripts file
            transcripts_file = None
            for f in transcripts_dir.glob("*.csv"):
                transcripts_file = f
                break

            if transcripts_file is None:
                # Check for .txt files (simple format: filename|text)
                for f in transcripts_dir.glob("*.txt"):
                    if f.name != "placeholder.txt":
                        transcripts_file = f
                        break

            if transcripts_file is None:
                print("Error: No transcripts file found!")
                exit(1)

            print(f"Using transcripts: {transcripts_file}")

            # Parse transcripts
            transcripts = {}

            if transcripts_file.suffix == ".csv":
                with open(transcripts_file, "r", encoding="utf-8") as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        # Handle various column name conventions
                        audio = row.get("audio_file") or row.get("audio") or row.get("file") or row.get("wav")
                        text = row.get("transcript") or row.get("text") or row.get("sentence")
                        if audio and text:
                            transcripts[audio] = text.strip()
            else:
                # Simple pipe-separated format: filename|text
                with open(transcripts_file, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if "|" in line:
                            parts = line.split("|", 1)
                            if len(parts) == 2:
                                transcripts[parts[0]] = parts[1]

            print(f"Loaded {len(transcripts)} transcripts")

            # Copy audio files and create metadata
            metadata_lines = []

            for audio_file in sorted(audio_dir.glob("*.wav")):
                # Try to match transcript
                text = None
                for key in [audio_file.name, audio_file.stem, audio_file.stem + ".wav"]:
                    if key in transcripts:
                        text = transcripts[key]
                        break

                if text is None:
                    print(f"Warning: No transcript for {audio_file.name}, skipping")
                    continue

                # Copy audio file
                dest_file = wavs_dir / audio_file.name
                shutil.copy(audio_file, dest_file)

                # Add to metadata (LJSpeech format: filename|text|text)
                # Coqui uses: audio_file|text|text (normalized text optional)
                metadata_lines.append(f"{audio_file.stem}|{text}|{text}")

            # Write metadata.csv
            metadata_file = output_dir / "metadata.csv"
            with open(metadata_file, "w", encoding="utf-8") as f:
                f.write("\n".join(metadata_lines))

            print(f"Created dataset with {len(metadata_lines)} samples")

            # Create dataset config
            config = {
                "name": VOICE_NAME,
                "language": LANGUAGE,
                "num_samples": len(metadata_lines),
                "format": "ljspeech"
            }

            with open(output_dir / "dataset_config.json", "w") as f:
                json.dump(config, f, indent=2)

            print(f"Dataset ready at: {output_dir}")

            if len(metadata_lines) < 10:
                print("Warning: Very small dataset! Recommend at least 100+ samples for good results.")
        resources:
          requests:
            memory: 1Gi
            cpu: 500m

    # Template: Train Coqui TTS model (VITS fine-tune, GPU-backed)
    - name: train-tts
      inputs:
        parameters:
          - name: voice-name
          - name: base-model
          - name: language
          - name: num-epochs
          - name: batch-size
          - name: learning-rate
        artifacts:
          - name: dataset
            path: /tmp/dataset
      outputs:
        artifacts:
          - name: model
            path: /tmp/output
      container:
        image: ghcr.io/coqui-ai/tts:latest
        command: [bash]
        args:
          - -c
          - |
            set -e

            VOICE_NAME="{{inputs.parameters.voice-name}}"
            BASE_MODEL="{{inputs.parameters.base-model}}"
            LANGUAGE="{{inputs.parameters.language}}"
            NUM_EPOCHS="{{inputs.parameters.num-epochs}}"
            BATCH_SIZE="{{inputs.parameters.batch-size}}"
            LEARNING_RATE="{{inputs.parameters.learning-rate}}"

            DATASET_DIR="/tmp/dataset"
            OUTPUT_DIR="/tmp/output"
            mkdir -p "$OUTPUT_DIR"

            echo "=== Coqui TTS Voice Training ==="
            echo "Voice Name: $VOICE_NAME"
            echo "Base Model: $BASE_MODEL"
            echo "Language: $LANGUAGE"
            echo "Epochs: $NUM_EPOCHS"
            echo "Batch Size: $BATCH_SIZE"
            echo "Learning Rate: $LEARNING_RATE"
            echo ""

            # Download base model if specified for fine-tuning
            RESTORE_PATH=""
            if [ "$BASE_MODEL" != "" ] && [ "$BASE_MODEL" != "none" ]; then
              echo "Downloading base model for fine-tuning: $BASE_MODEL"
              # Use tts to download the model and get its path
              MODEL_PATH=$(python3 -c "
            from TTS.utils.manage import ModelManager
            from TTS.utils.synthesizer import Synthesizer
            from pathlib import Path
            import os

            model_name = '$BASE_MODEL'
            manager = ModelManager()

            # Download the model
            model_path, config_path, _ = manager.download_model(model_name)
            print(model_path)
            ")
              RESTORE_PATH="$MODEL_PATH"
              echo "Base model path: $RESTORE_PATH"
            fi

            # Create and run training script following Coqui docs pattern.
            # NOTE: unquoted EOF so shell variables expand inside the Python source.
            python3 << EOF
            import os
            from pathlib import Path

            # Trainer: Where the magic happens
            from trainer import Trainer, TrainerArgs

            # Model configs
            from TTS.tts.configs.vits_config import VitsConfig
            from TTS.tts.configs.shared_configs import BaseDatasetConfig
            from TTS.tts.datasets import load_tts_samples
            from TTS.tts.models.vits import Vits
            from TTS.tts.utils.text.tokenizer import TTSTokenizer
            from TTS.utils.audio import AudioProcessor

            # Paths
            DATASET_DIR = Path("$DATASET_DIR")
            OUTPUT_DIR = Path("$OUTPUT_DIR")
            RESTORE_PATH = "$RESTORE_PATH" if "$RESTORE_PATH" else None

            print(f"Dataset: {DATASET_DIR}")
            print(f"Output: {OUTPUT_DIR}")
            print(f"Restore from: {RESTORE_PATH}")

            # Define dataset config (LJSpeech format)
            dataset_config = BaseDatasetConfig(
                formatter="ljspeech",
                meta_file_train="metadata.csv",
                path=str(DATASET_DIR),
                language="$LANGUAGE",
            )

            # Initialize training configuration
            config = VitsConfig(
                run_name="$VOICE_NAME",
                output_path=str(OUTPUT_DIR),
                datasets=[dataset_config],
                batch_size=int("$BATCH_SIZE"),
                eval_batch_size=max(1, int("$BATCH_SIZE") // 2),
                num_loader_workers=4,
                num_eval_loader_workers=2,
                run_eval=True,
                test_delay_epochs=5,
                epochs=int("$NUM_EPOCHS"),
                text_cleaner="phoneme_cleaners",
                use_phonemes=True,
                phoneme_language="$LANGUAGE",
                phoneme_cache_path=str(OUTPUT_DIR / "phoneme_cache"),
                compute_input_seq_cache=True,
                print_step=25,
                print_eval=False,
                mixed_precision=True,
                save_step=500,
                save_n_checkpoints=3,
                save_best_after=1000,
                lr=float("$LEARNING_RATE"),
                # Audio settings for typical voice cloning
                audio={
                    "sample_rate": 22050,
                    "resample": True,
                    "do_trim_silence": True,
                    "trim_db": 45,
                },
            )

            # Initialize the audio processor
            # Used for feature extraction and audio I/O
            ap = AudioProcessor.init_from_config(config)

            # Initialize the tokenizer
            # Converts text to sequences of token IDs
            tokenizer, config = TTSTokenizer.init_from_config(config)

            # Load data samples
            # Each sample is [text, audio_file_path, speaker_name]
            train_samples, eval_samples = load_tts_samples(
                dataset_config,
                eval_split=True,
                eval_split_max_size=config.eval_split_max_size,
                eval_split_size=config.eval_split_size,
            )

            print(f"Training samples: {len(train_samples)}")
            print(f"Eval samples: {len(eval_samples)}")

            # Initialize the model
            model = Vits(config, ap, tokenizer, speaker_manager=None)

            # Set up trainer arguments
            trainer_args = TrainerArgs(
                restore_path=RESTORE_PATH,
                skip_train_epoch=False,
            )

            # Initialize the trainer
            trainer = Trainer(
                trainer_args,
                config,
                output_path=str(OUTPUT_DIR),
                model=model,
                train_samples=train_samples,
                eval_samples=eval_samples,
            )

            # Start training
            print("\n" + "=" * 50)
            print("Starting training...")
            print("=" * 50 + "\n")

            trainer.fit()

            print("\n" + "=" * 50)
            print("Training complete!")
            print("=" * 50)
            EOF

            echo ""
            echo "Training complete!"
            echo "Model saved to: $OUTPUT_DIR"
            ls -la "$OUTPUT_DIR"
        resources:
          requests:
            memory: 16Gi
            cpu: "4"
            nvidia.com/gpu: "1"
          limits:
            memory: 32Gi
            cpu: "8"
            nvidia.com/gpu: "1"
        volumeMounts:
          - name: training-workspace
            mountPath: /tmp/workspace

    # Template: Export trained model (package newest checkpoint + config, upload)
    - name: export-trained-model
      inputs:
        parameters:
          - name: voice-name
          - name: output-path
        artifacts:
          - name: trained-model
            path: /tmp/trained-model
      outputs:
        artifacts:
          - name: exported-model
            path: /tmp/exported
      container:
        image: python:3.13-slim
        command: [bash]
        args:
          - -c
          - |
            set -e

            pip install -q boto3

            python3 << 'EOF'
            import json
            import os
            import shutil
            import subprocess
            from pathlib import Path
            from datetime import datetime

            VOICE_NAME = "{{inputs.parameters.voice-name}}"
            OUTPUT_PATH = "{{inputs.parameters.output-path}}"

            model_dir = Path("/tmp/trained-model")
            export_dir = Path("/tmp/exported")
            export_dir.mkdir(parents=True, exist_ok=True)

            print(f"Exporting trained model: {VOICE_NAME}")
            print(f"Target path: {OUTPUT_PATH}")

            # Find best checkpoint
            checkpoints = list(model_dir.glob("best_model*.pth")) + list(model_dir.glob("checkpoint_*.pth"))
            if not checkpoints:
                checkpoints = list(model_dir.glob("*.pth"))

            if not checkpoints:
                print("Error: No model checkpoints found!")
                exit(1)

            # Sort by modification time and get newest
            checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
            best_checkpoint = checkpoints[0]
            print(f"Using checkpoint: {best_checkpoint.name}")

            # Create export package
            package_dir = export_dir / VOICE_NAME
            package_dir.mkdir(parents=True, exist_ok=True)

            # Copy model files
            shutil.copy(best_checkpoint, package_dir / "model.pth")

            # Copy config if exists
            config_file = model_dir / "config.json"
            if config_file.exists():
                shutil.copy(config_file, package_dir / "config.json")

            # Create model info
            model_info = {
                "name": VOICE_NAME,
                "created_at": datetime.now().isoformat(),
                "checkpoint": best_checkpoint.name,
                "type": "coqui-tts"
            }

            with open(package_dir / "model_info.json", "w") as f:
                json.dump(model_info, f, indent=2)

            # Create tarball
            archive_name = f"{VOICE_NAME}.tar.gz"
            shutil.make_archive(
                str(export_dir / VOICE_NAME),
                "gztar",
                export_dir,
                VOICE_NAME
            )

            print(f"Created archive: {archive_name}")

            # Upload to destination
            if OUTPUT_PATH.startswith("s3://"):
                import boto3
                s3 = boto3.client("s3")
                bucket, key = OUTPUT_PATH[5:].split("/", 1)
                key = f"{key}/{archive_name}"
                s3.upload_file(str(export_dir / archive_name), bucket, key)
                print(f"Uploaded to: s3://{bucket}/{key}")

            elif OUTPUT_PATH.startswith("/"):
                # Local/NFS path
                dest_path = Path(OUTPUT_PATH)
                dest_path.mkdir(parents=True, exist_ok=True)
                shutil.copy(export_dir / archive_name, dest_path / archive_name)
                # Also copy uncompressed for easy access
                shutil.copytree(package_dir, dest_path / VOICE_NAME, dirs_exist_ok=True)
                print(f"Saved to: {dest_path / archive_name}")

            print("\nExport complete!")
            print(f"Model package contents:")
            for f in package_dir.iterdir():
                print(f" - {f.name}")
            EOF
        resources:
          requests:
            memory: 1Gi
            cpu: 500m
Reference in New Issue
Block a user