Files
kubeflow/voice_cloning_pipeline.py
Billy D. d4eb54d92b
All checks were successful
Compile and Upload Pipelines / Compile & Upload (push) Successful in 15s
Compile and Upload Pipelines / Notify (push) Successful in 1s
pipelines go to gravenhollow now.
2026-02-18 07:14:12 -05:00

688 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Voice Cloning Pipeline Kubeflow Pipelines SDK
Takes an audio file and a transcript, extracts a target speaker's
segments, preprocesses into LJSpeech-format training data, fine-tunes
a Coqui VITS voice model, pushes the model to Gitea, and logs to MLflow.
Usage:
pip install kfp==2.12.1
python voice_cloning_pipeline.py
# Upload voice_cloning_pipeline.yaml to Kubeflow Pipelines UI
"""
from kfp import compiler, dsl
from typing import NamedTuple
# ──────────────────────────────────────────────────────────────
# 1. Transcribe + diarise audio via Whisper to identify speakers
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["requests", "boto3"],
)
def transcribe_and_diarise(
    s3_endpoint: str,
    s3_bucket: str,
    s3_key: str,
    whisper_url: str = "http://ai-inference-serve-svc.kuberay.svc.cluster.local:8000/whisper",
) -> NamedTuple("TranscriptOutput", [("transcript_json", str), ("speakers", str), ("audio_path", str)]):
    """Download audio from S3, transcribe via Whisper with timestamps.

    Returns:
        transcript_json: JSON list of Whisper segments; each segment is
            tagged with a "speaker" key (synthesised as SPEAKER_<i//10>
            when Whisper provides no diarisation labels).
        speakers: JSON-encoded sorted list of distinct speaker labels.
        audio_path: container-local path of the normalised 16 kHz mono WAV.
            NOTE(review): this path is only usable by the next component if
            both pods share storage — confirm the cluster mounts a shared
            volume for /tmp-style workdirs.
    """
    import base64
    import json
    import os
    import subprocess
    import tempfile

    import boto3
    import requests
    # BUG FIX: anonymous S3 access requires the botocore.UNSIGNED sentinel;
    # the previous string "UNSIGNED" is not a valid signature version and
    # raises UnknownSignatureVersionError.
    from botocore import UNSIGNED
    from botocore.config import Config

    out = NamedTuple("TranscriptOutput", [("transcript_json", str), ("speakers", str), ("audio_path", str)])
    work = tempfile.mkdtemp()

    # ── Download audio from S3 ────────────────────────────────
    ext = os.path.splitext(s3_key)[-1] or ".wav"
    audio_path = os.path.join(work, f"audio_raw{ext}")
    client = boto3.client(
        "s3",
        endpoint_url=s3_endpoint,
        aws_access_key_id="",
        aws_secret_access_key="",
        config=Config(signature_version=UNSIGNED),
        verify=False,  # NOTE(review): TLS verification disabled — acceptable only for the lab-internal endpoint
    )
    print(f"Downloading s3://{s3_bucket}/{s3_key} from {s3_endpoint}")
    client.download_file(s3_bucket, s3_key, audio_path)
    print(f"Downloaded {os.path.getsize(audio_path)} bytes")

    # ── Normalise to 16 kHz mono WAV ──────────────────────────
    # ffmpeg is not shipped in python:3.13-slim; install at runtime
    # (apt-get update is best-effort, install itself is checked).
    wav_path = os.path.join(work, "audio.wav")
    subprocess.run(
        ["apt-get", "update", "-qq"],
        capture_output=True,
    )
    subprocess.run(
        ["apt-get", "install", "-y", "-qq", "ffmpeg"],
        capture_output=True, check=True,
    )
    subprocess.run(
        ["ffmpeg", "-y", "-i", audio_path, "-ac", "1",
         "-ar", "16000", "-sample_fmt", "s16", wav_path],
        capture_output=True, check=True,
    )

    # ── Send to Whisper for timestamped transcription ─────────
    with open(wav_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode()
    payload = {
        "audio": audio_b64,
        "response_format": "verbose_json",
        "timestamp_granularities": ["segment"],
    }
    resp = requests.post(whisper_url, json=payload, timeout=600)
    resp.raise_for_status()
    result = resp.json()
    segments = result.get("segments", [])
    print(f"Whisper returned {len(segments)} segments")

    # ── Group segments by speaker if diarisation is present ───
    # Whisper may not diarise; synthesise a coarse label per block of ten
    # segments so downstream filtering always has something to match on.
    speakers = set()
    for i, seg in enumerate(segments):
        spk = seg.get("speaker", f"SPEAKER_{i // 10}")
        seg["speaker"] = spk
        speakers.add(spk)
    speakers_list = sorted(speakers)
    print(f"Detected speakers: {speakers_list}")
    return out(
        transcript_json=json.dumps(segments),
        speakers=json.dumps(speakers_list),
        audio_path=wav_path,
    )
# ──────────────────────────────────────────────────────────────
# 2. Extract target speaker's audio segments
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=[],
)
def extract_speaker_segments(
    transcript_json: str,
    audio_path: str,
    target_speaker: str,
    min_duration_s: float = 1.0,
    max_duration_s: float = 15.0,
) -> NamedTuple("SpeakerSegments", [("segments_json", str), ("num_segments", int), ("total_duration_s", float)]):
    """Slice the audio into per-utterance WAV files for the target speaker.

    Segments shorter than ``min_duration_s`` or with empty text are dropped;
    segments longer than ``max_duration_s`` are truncated. Output WAVs are
    22.05 kHz mono s16 — the rate the downstream VITS config expects.

    Returns:
        segments_json: JSON object with ``wavs_dir`` (container-local path)
            and a list of kept ``utterances`` (wav name, text, timing).
        num_segments: number of utterances extracted.
        total_duration_s: summed duration of kept utterances in seconds.
    """
    import json
    import os
    import subprocess
    import tempfile

    out = NamedTuple("SpeakerSegments", [("segments_json", str), ("num_segments", int), ("total_duration_s", float)])
    work = tempfile.mkdtemp()
    wavs_dir = os.path.join(work, "wavs")
    os.makedirs(wavs_dir, exist_ok=True)
    # ffmpeg is not in the slim base image; install at runtime.
    subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
    subprocess.run(["apt-get", "install", "-y", "-qq", "ffmpeg"], capture_output=True, check=True)
    segments = json.loads(transcript_json)

    # Filter by speaker — fuzzy match (case-insensitive, partial).
    # BUG FIX: an empty speaker label previously matched EVERY target,
    # because "" is a substring of any string ("" in x is always True);
    # require a non-empty label before the substring tests.
    target_lower = target_speaker.lower()
    matched = []
    for seg in segments:
        spk = seg.get("speaker", "").lower()
        if spk and (target_lower in spk or spk in target_lower):
            matched.append(seg)

    # If no speaker labels matched, the user may have given a name
    # that doesn't appear. Fall back to using ALL segments.
    if not matched:
        print(f"WARNING: No segments matched speaker '{target_speaker}'. "
              f"Using all {len(segments)} segments.")
        matched = segments
    print(f"Matched {len(matched)} segments for speaker '{target_speaker}'")

    kept = []
    total_dur = 0.0
    for i, seg in enumerate(matched):
        start = float(seg.get("start", 0))
        end = float(seg.get("end", start + 5))
        duration = end - start
        text = seg.get("text", "").strip()
        # Too short or textless clips hurt TTS training; skip them.
        if duration < min_duration_s or not text:
            continue
        # Truncate over-long segments rather than dropping them.
        if duration > max_duration_s:
            end = start + max_duration_s
            duration = max_duration_s
        wav_name = f"utt_{i:04d}.wav"
        wav_out = os.path.join(wavs_dir, wav_name)
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path,
             "-ss", str(start), "-to", str(end),
             "-ac", "1", "-ar", "22050", "-sample_fmt", "s16",
             wav_out],
            capture_output=True, check=True,
        )
        kept.append({
            "wav": wav_name,
            "text": text,
            "start": start,
            "end": end,
            "duration": round(duration, 2),
        })
        total_dur += duration
    print(f"Extracted {len(kept)} utterances, total {total_dur:.1f}s")
    return out(
        segments_json=json.dumps({"wavs_dir": wavs_dir, "utterances": kept}),
        num_segments=len(kept),
        total_duration_s=round(total_dur, 2),
    )
# ──────────────────────────────────────────────────────────────
# 3. Prepare LJSpeech-format dataset for Coqui TTS
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=[],
)
def prepare_ljspeech_dataset(
    segments_json: str,
    voice_name: str,
    language: str = "en",
) -> NamedTuple("DatasetOutput", [("dataset_dir", str), ("num_samples", int)]):
    """Create metadata.csv + wavs/ in LJSpeech format.

    Copies the extracted utterance WAVs next to a pipe-delimited
    ``metadata.csv`` (``id|text|text``) and writes a small
    ``dataset_config.json`` describing the dataset for reference.

    Returns:
        dataset_dir: container-local path of the assembled dataset.
        num_samples: number of metadata rows written.
    """
    import json
    import os
    import shutil

    out = NamedTuple("DatasetOutput", [("dataset_dir", str), ("num_samples", int)])
    data = json.loads(segments_json)
    wavs_src = data["wavs_dir"]
    utterances = data["utterances"]
    dataset_dir = os.path.join(os.path.dirname(wavs_src), "dataset")
    wavs_dst = os.path.join(dataset_dir, "wavs")
    os.makedirs(wavs_dst, exist_ok=True)
    lines = []
    for utt in utterances:
        src = os.path.join(wavs_src, utt["wav"])
        dst = os.path.join(wavs_dst, utt["wav"])
        shutil.copy2(src, dst)
        stem = os.path.splitext(utt["wav"])[0]
        # LJSpeech format: id|text|text — strip the pipe delimiter from text
        # so it cannot corrupt the CSV columns.
        text = utt["text"].replace("|", " ")
        lines.append(f"{stem}|{text}|{text}")
    metadata_path = os.path.join(dataset_dir, "metadata.csv")
    with open(metadata_path, "w", encoding="utf-8") as f:
        # Trailing newline so the last row is a complete line.
        f.write("\n".join(lines) + "\n")
    # Dataset config for reference (the redundant `import json as _json`
    # was removed — json is already imported above).
    config = {
        "name": voice_name,
        "language": language,
        "num_samples": len(lines),
        "format": "ljspeech",
        "sample_rate": 22050,
    }
    with open(os.path.join(dataset_dir, "dataset_config.json"), "w") as f:
        json.dump(config, f, indent=2)
    print(f"LJSpeech dataset ready: {len(lines)} samples")
    return out(dataset_dir=dataset_dir, num_samples=len(lines))
# ──────────────────────────────────────────────────────────────
# 4. Fine-tune Coqui VITS voice model
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="ghcr.io/idiap/coqui-tts:latest",
    packages_to_install=[],
)
def train_vits_voice(
    dataset_dir: str,
    voice_name: str,
    language: str = "en",
    base_model: str = "tts_models/en/ljspeech/vits",
    num_epochs: int = 100,
    batch_size: int = 16,
    learning_rate: float = 0.0001,
) -> NamedTuple("TrainOutput", [("model_dir", str), ("best_checkpoint", str), ("final_loss", float)]):
    """Fine-tune a VITS model on the speaker dataset.

    Downloads the Coqui base checkpoint (unless ``base_model`` is ``"none"``
    or empty), configures a VitsConfig over the LJSpeech-format dataset at
    ``dataset_dir``, runs the Coqui Trainer, and reports the newest
    checkpoint plus a best-effort final training loss.

    Returns:
        model_dir: container-local output directory of the training run.
            NOTE(review): consumed by the push step — assumes both pods
            share storage; confirm a shared volume is mounted.
        best_checkpoint: path of the best (or most recent) ``.pth`` file,
            empty string if none was produced.
        final_loss: average training loss if readable from the trainer,
            otherwise 0.0.
    """
    import os
    import json
    import glob
    out = NamedTuple("TrainOutput", [("model_dir", str), ("best_checkpoint", str), ("final_loss", float)])
    OUTPUT_DIR = "/tmp/vits_output"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"=== Coqui VITS Voice Training ===")
    print(f"Voice name : {voice_name}")
    print(f"Base model : {base_model}")
    print(f"Dataset    : {dataset_dir}")
    print(f"Epochs     : {num_epochs}")
    print(f"Batch size : {batch_size}")
    print(f"LR         : {learning_rate}")
    # ── Download base model checkpoint ────────────────────────
    # Fine-tuning restores weights from the released model; "none" (or an
    # empty string) trains from scratch instead.
    restore_path = None
    if base_model and base_model != "none":
        from TTS.utils.manage import ModelManager
        manager = ModelManager()
        model_path, config_path, _ = manager.download_model(base_model)
        restore_path = model_path
        print(f"Base model checkpoint: {restore_path}")
    # ── Configure and train ───────────────────────────────────
    # Imports are deferred so the component fails fast on bad inputs before
    # pulling in the heavy TTS/trainer stack.
    from trainer import Trainer, TrainerArgs
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.configs.shared_configs import BaseDatasetConfig
    from TTS.tts.datasets import load_tts_samples
    from TTS.tts.models.vits import Vits
    from TTS.tts.utils.text.tokenizer import TTSTokenizer
    from TTS.utils.audio import AudioProcessor
    dataset_config = BaseDatasetConfig(
        formatter="ljspeech",
        meta_file_train="metadata.csv",
        path=dataset_dir,
        language=language,
    )
    config = VitsConfig(
        run_name=voice_name,
        output_path=OUTPUT_DIR,
        datasets=[dataset_config],
        batch_size=batch_size,
        eval_batch_size=max(1, batch_size // 2),  # never 0, even for batch_size=1
        num_loader_workers=4,
        num_eval_loader_workers=2,
        run_eval=True,
        test_delay_epochs=5,
        epochs=num_epochs,
        text_cleaner="phoneme_cleaners",
        use_phonemes=True,
        phoneme_language=language,
        phoneme_cache_path=os.path.join(OUTPUT_DIR, "phoneme_cache"),
        compute_input_seq_cache=True,
        print_step=25,
        print_eval=False,
        mixed_precision=True,
        save_step=500,
        save_n_checkpoints=3,
        save_best_after=1000,
        lr=learning_rate,
        # Matches the 22.05 kHz rate the extraction step produced.
        audio={
            "sample_rate": 22050,
            "resample": True,
            "do_trim_silence": True,
            "trim_db": 45,
        },
    )
    ap = AudioProcessor.init_from_config(config)
    # TTSTokenizer.init_from_config may amend the config; keep its copy.
    tokenizer, config = TTSTokenizer.init_from_config(config)
    train_samples, eval_samples = load_tts_samples(
        dataset_config,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    print(f"Training samples: {len(train_samples)}")
    print(f"Eval samples: {len(eval_samples)}")
    model = Vits(config, ap, tokenizer, speaker_manager=None)
    trainer_args = TrainerArgs(
        restore_path=restore_path,  # None => train from scratch
        skip_train_epoch=False,
    )
    trainer = Trainer(
        trainer_args,
        config,
        output_path=OUTPUT_DIR,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()
    # ── Find best checkpoint ──────────────────────────────────
    # Prefer an explicit best_model*.pth; otherwise fall back to the most
    # recently written checkpoint of any name.
    best_files = glob.glob(os.path.join(OUTPUT_DIR, "**/best_model*.pth"), recursive=True)
    if not best_files:
        best_files = glob.glob(os.path.join(OUTPUT_DIR, "**/*.pth"), recursive=True)
    best_files.sort(key=os.path.getmtime, reverse=True)
    best_checkpoint = best_files[0] if best_files else ""
    # Try to read final loss from trainer; this attribute is internal to the
    # Coqui trainer, so failures are deliberately swallowed (best-effort).
    final_loss = 0.0
    try:
        final_loss = float(trainer.keep_avg_train["avg_loss"])
    except Exception:
        pass
    print(f"Training complete. Best checkpoint: {best_checkpoint}")
    print(f"Final loss: {final_loss:.4f}")
    return out(model_dir=OUTPUT_DIR, best_checkpoint=best_checkpoint, final_loss=final_loss)
# ──────────────────────────────────────────────────────────────
# 5. Push trained voice model to Gitea repository
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["requests"],
)
def push_model_to_gitea(
    model_dir: str,
    voice_name: str,
    gitea_url: str = "http://gitea-http.gitea.svc.cluster.local:3000",
    gitea_owner: str = "daviestechlabs",
    gitea_repo: str = "voice-models",
    gitea_username: str = "",
    gitea_password: str = "",
) -> NamedTuple("PushOutput", [("repo_url", str), ("files_pushed", int)]):
    """Package and push the trained model to a Gitea repository.

    Ensures the target repository exists (creating it under the org, or
    under the authenticated user as a fallback), then uploads the newest
    checkpoint, the training config, and a generated ``model_info.json``
    via the Gitea contents API under ``<voice_name>/``.

    Returns:
        repo_url: browsable URL of the target repository.
        files_pushed: count of files successfully created/updated.
    """
    import base64
    import glob
    import json
    import os

    import requests

    out = NamedTuple("PushOutput", [("repo_url", str), ("files_pushed", int)])
    http = requests.Session()
    http.auth = (gitea_username, gitea_password) if gitea_username else None
    api = f"{gitea_url}/api/v1"
    repo_url = f"{gitea_url}/{gitea_owner}/{gitea_repo}"

    # ── Ensure repo exists ────────────────────────────────────
    probe = http.get(f"{api}/repos/{gitea_owner}/{gitea_repo}", timeout=30)
    if probe.status_code == 404:
        print(f"Creating repository: {gitea_owner}/{gitea_repo}")
        created = http.post(
            f"{api}/orgs/{gitea_owner}/repos",
            json={
                "name": gitea_repo,
                "description": "Trained voice models from voice cloning pipeline",
                "private": False,
                "auto_init": True,
            },
            timeout=30,
        )
        if created.status_code not in (200, 201):
            # Owner may be a user rather than an org — retry there.
            created = http.post(
                f"{api}/user/repos",
                json={"name": gitea_repo, "description": "Trained voice models", "auto_init": True},
                timeout=30,
            )
        created.raise_for_status()
        print("Repository created")

    # ── Collect model files ───────────────────────────────────
    files_to_push = []
    # Newest checkpoint: prefer an explicit best_model*.pth, else any .pth.
    for pattern in ("**/best_model*.pth", "**/*.pth"):
        hits = glob.glob(os.path.join(model_dir, pattern), recursive=True)
        if hits:
            hits.sort(key=os.path.getmtime, reverse=True)
            files_to_push.append(hits[0])
            break
    # Training config, if one was written.
    config_hits = glob.glob(os.path.join(model_dir, "**/config.json"), recursive=True)
    if config_hits:
        files_to_push.append(config_hits[0])
    # Generated model descriptor.
    info_path = os.path.join(model_dir, "model_info.json")
    with open(info_path, "w") as fh:
        json.dump(
            {
                "name": voice_name,
                "type": "coqui-vits",
                "base_model": "tts_models/en/ljspeech/vits",
                "sample_rate": 22050,
            },
            fh,
            indent=2,
        )
    files_to_push.append(info_path)

    # ── Push each file ────────────────────────────────────────
    pushed = 0
    for local_path in files_to_push:
        rel = os.path.relpath(local_path, model_dir)
        remote_path = f"{voice_name}/{rel}"
        print(f"Pushing: {remote_path} ({os.path.getsize(local_path)} bytes)")
        with open(local_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode()
        contents_url = f"{api}/repos/{gitea_owner}/{gitea_repo}/contents/{remote_path}"
        body = {
            "content": encoded,
            "message": f"Upload {voice_name}: {rel}",
        }
        # Gitea requires PUT + current sha to update, POST to create.
        existing = http.get(contents_url, timeout=30)
        if existing.status_code == 200:
            body["sha"] = existing.json().get("sha", "")
            resp = http.put(contents_url, json=body, timeout=120)
        else:
            resp = http.post(contents_url, json=body, timeout=120)
        if resp.status_code in (200, 201):
            pushed += 1
            print(" ✓ Pushed")
        else:
            print(f" ✗ Failed ({resp.status_code}): {resp.text[:200]}")
    print(f"\nPushed {pushed}/{len(files_to_push)} files to {repo_url}")
    return out(repo_url=repo_url, files_pushed=pushed)
# ──────────────────────────────────────────────────────────────
# 6. Log metrics to MLflow
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["mlflow>=2.10.0", "requests"],
)
def log_training_metrics(
    voice_name: str,
    num_segments: int,
    total_duration_s: float,
    final_loss: float,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    repo_url: str,
    files_pushed: int,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    experiment_name: str = "voice-cloning",
) -> NamedTuple("LogOutput", [("run_id", str)]):
    """Log training run to MLflow.

    Records the hyper-parameters, dataset/training metrics, and pipeline
    tags of one voice-cloning run under ``experiment_name``.

    Returns:
        run_id: the MLflow run id of the created run.
    """
    from datetime import datetime

    import mlflow

    out = NamedTuple("LogOutput", [("run_id", str)])
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(experiment_name)

    run_params = {
        "voice_name": voice_name,
        "base_model": "tts_models/en/ljspeech/vits",
        "model_type": "coqui-vits",
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "sample_rate": 22050,
    }
    run_metrics = {
        "num_training_segments": num_segments,
        "total_audio_duration_s": total_duration_s,
        "final_loss": final_loss,
        "files_pushed": files_pushed,
    }
    run_tags = {
        "pipeline": "voice-cloning",
        "gitea_repo": repo_url,
        "voice_name": voice_name,
    }
    run_label = f"voice-clone-{voice_name}-{datetime.now():%Y%m%d-%H%M}"
    with mlflow.start_run(run_name=run_label) as active_run:
        mlflow.log_params(run_params)
        mlflow.log_metrics(run_metrics)
        mlflow.set_tags(run_tags)
        print(f"Logged to MLflow run: {active_run.info.run_id}")
        return out(run_id=active_run.info.run_id)
# ──────────────────────────────────────────────────────────────
# Pipeline definition
# ──────────────────────────────────────────────────────────────
@dsl.pipeline(
    name="Voice Cloning Pipeline",
    description=(
        "Extract a speaker from audio+transcript, fine-tune a Coqui VITS "
        "voice model, push to Gitea, and log metrics to MLflow."
    ),
)
def voice_cloning_pipeline(
    s3_endpoint: str = "https://gravenhollow.lab.daviestechlabs.io:30292",
    s3_bucket: str = "training-data",
    s3_key: str = "",
    target_speaker: str = "SPEAKER_0",
    voice_name: str = "custom-voice",
    language: str = "en",
    base_model: str = "tts_models/en/ljspeech/vits",
    num_epochs: int = 100,
    batch_size: int = 16,
    learning_rate: float = 0.0001,
    min_segment_duration_s: float = 1.0,
    max_segment_duration_s: float = 15.0,
    # Whisper / inference endpoints
    whisper_url: str = "http://ai-inference-serve-svc.kuberay.svc.cluster.local:8000/whisper",
    # Gitea
    gitea_url: str = "http://gitea-http.gitea.svc.cluster.local:3000",
    gitea_owner: str = "daviestechlabs",
    gitea_repo: str = "voice-models",
    gitea_username: str = "",
    gitea_password: str = "",
    # MLflow
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
):
    """Wire the six components into a linear voice-cloning DAG.

    NOTE(review): several steps pass container-local filesystem paths
    (audio_path, wavs_dir, dataset_dir, model_dir) between components as
    plain strings. Each kfp component runs in its own pod, so this only
    works if the pods mount shared storage — confirm the cluster config.
    """
    # 1 - Download from S3 and transcribe with Whisper
    transcribed = transcribe_and_diarise(
        s3_endpoint=s3_endpoint,
        s3_bucket=s3_bucket,
        s3_key=s3_key,
        whisper_url=whisper_url,
    )
    # 2 - Extract target speaker's segments
    extracted = extract_speaker_segments(
        transcript_json=transcribed.outputs["transcript_json"],
        audio_path=transcribed.outputs["audio_path"],
        target_speaker=target_speaker,
        min_duration_s=min_segment_duration_s,
        max_duration_s=max_segment_duration_s,
    )
    # 3 - Build LJSpeech dataset
    dataset = prepare_ljspeech_dataset(
        segments_json=extracted.outputs["segments_json"],
        voice_name=voice_name,
        language=language,
    )
    # 4 - Train VITS model
    trained = train_vits_voice(
        dataset_dir=dataset.outputs["dataset_dir"],
        voice_name=voice_name,
        language=language,
        base_model=base_model,
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
    # Training is the only step that needs a GPU and big memory.
    trained.set_accelerator_type("gpu")
    trained.set_gpu_limit(1)
    trained.set_memory_request("16Gi")
    trained.set_memory_limit("32Gi")
    trained.set_cpu_request("4")
    trained.set_cpu_limit("8")
    # 5 - Push model to Gitea
    pushed = push_model_to_gitea(
        model_dir=trained.outputs["model_dir"],
        voice_name=voice_name,
        gitea_url=gitea_url,
        gitea_owner=gitea_owner,
        gitea_repo=gitea_repo,
        gitea_username=gitea_username,
        gitea_password=gitea_password,
    )
    # 6 - Log to MLflow (depends on extract, train, and push outputs)
    log_training_metrics(
        voice_name=voice_name,
        num_segments=extracted.outputs["num_segments"],
        total_duration_s=extracted.outputs["total_duration_s"],
        final_loss=trained.outputs["final_loss"],
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        repo_url=pushed.outputs["repo_url"],
        files_pushed=pushed.outputs["files_pushed"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
# ──────────────────────────────────────────────────────────────
# Compile
# ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Compile the pipeline to a YAML package for upload to the KFP UI.
    package_file = "voice_cloning_pipeline.yaml"
    compiler.Compiler().compile(
        pipeline_func=voice_cloning_pipeline,
        package_path=package_file,
    )
    print(f"Compiled: {package_file}")