#!/usr/bin/env python3
"""
Voice Cloning Pipeline – Kubeflow Pipelines SDK

Takes an audio file and a transcript, extracts a target speaker's
segments, preprocesses into LJSpeech-format training data, fine-tunes
a Coqui VITS voice model, pushes the model to Gitea, and logs to MLflow.

Usage:
    pip install kfp==2.12.1
    python voice_cloning_pipeline.py
    # Upload voice_cloning_pipeline.yaml to Kubeflow Pipelines UI
"""

from kfp import compiler, dsl
from typing import NamedTuple


# ──────────────────────────────────────────────────────────────
# 1. Transcribe + diarise audio via Whisper to identify speakers
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["requests", "boto3"],
)
def transcribe_and_diarise(
    s3_endpoint: str,
    s3_bucket: str,
    s3_key: str,
    whisper_url: str = "http://ai-inference-serve-svc.kuberay.svc.cluster.local:8000/whisper",
) -> NamedTuple("TranscriptOutput", [("transcript_json", str), ("speakers", str), ("audio_path", str)]):
    """Download audio from Quobjects S3, transcribe via Whisper with timestamps.

    Fetches the object anonymously from S3, normalises it to 16 kHz mono
    16-bit WAV with ffmpeg, posts it (base64-encoded) to the Whisper
    service, and tags every returned segment with a speaker label.

    Returns:
        transcript_json: JSON list of Whisper segments, each carrying
            start/end timestamps, text, and a "speaker" key.
        speakers: JSON list of distinct speaker labels, sorted.
        audio_path: path of the normalised WAV inside this container.
    """
    import base64
    import json
    import os
    import subprocess
    import tempfile

    import boto3
    import requests
    from botocore import UNSIGNED  # sentinel for anonymous (unsigned) S3 requests

    out = NamedTuple("TranscriptOutput", [("transcript_json", str), ("speakers", str), ("audio_path", str)])
    work = tempfile.mkdtemp()

    # ── Download audio from Quobjects S3 ─────────────────────
    ext = os.path.splitext(s3_key)[-1] or ".wav"
    audio_path = os.path.join(work, f"audio_raw{ext}")

    client = boto3.client(
        "s3",
        endpoint_url=s3_endpoint,
        aws_access_key_id="",
        aws_secret_access_key="",
        # FIX: signature_version must be the botocore.UNSIGNED sentinel; the
        # string "UNSIGNED" raises UnknownSignatureVersionError at request time.
        config=boto3.session.Config(signature_version=UNSIGNED),
        verify=False,  # NOTE: TLS verification disabled — lab endpoint presumably self-signed
    )
    print(f"Downloading s3://{s3_bucket}/{s3_key} from {s3_endpoint}")
    client.download_file(s3_bucket, s3_key, audio_path)
    print(f"Downloaded {os.path.getsize(audio_path)} bytes")

    # ── Normalise to 16 kHz mono WAV ─────────────────────────
    wav_path = os.path.join(work, "audio.wav")
    # apt-get update is deliberately best-effort (no check=True): a stale
    # index warning must not fail the step. The install itself is checked.
    subprocess.run(
        ["apt-get", "update", "-qq"],
        capture_output=True,
    )
    subprocess.run(
        ["apt-get", "install", "-y", "-qq", "ffmpeg"],
        capture_output=True, check=True,
    )
    subprocess.run(
        ["ffmpeg", "-y", "-i", audio_path, "-ac", "1",
         "-ar", "16000", "-sample_fmt", "s16", wav_path],
        capture_output=True, check=True,
    )

    # ── Send to Whisper for timestamped transcription ─────────
    with open(wav_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode()

    payload = {
        "audio": audio_b64,
        "response_format": "verbose_json",
        "timestamp_granularities": ["segment"],
    }
    resp = requests.post(whisper_url, json=payload, timeout=600)
    resp.raise_for_status()
    result = resp.json()

    segments = result.get("segments", [])
    print(f"Whisper returned {len(segments)} segments")

    # ── Group segments by speaker if diarisation is present ───
    # Whisper may not diarise; fall back to a synthetic label per block
    # of 10 segments so the downstream speaker filter still has labels.
    speakers = set()
    for i, seg in enumerate(segments):
        spk = seg.get("speaker", f"SPEAKER_{i // 10}")
        seg["speaker"] = spk
        speakers.add(spk)

    speakers_list = sorted(speakers)
    print(f"Detected speakers: {speakers_list}")

    return out(
        transcript_json=json.dumps(segments),
        speakers=json.dumps(speakers_list),
        audio_path=wav_path,
    )


# ──────────────────────────────────────────────────────────────
# 2. Extract target speaker's audio segments
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=[],
)
def extract_speaker_segments(
    transcript_json: str,
    audio_path: str,
    target_speaker: str,
    min_duration_s: float = 1.0,
    max_duration_s: float = 15.0,
) -> NamedTuple("SpeakerSegments", [("segments_json", str), ("num_segments", int), ("total_duration_s", float)]):
    """Cut the source audio into one 22.05 kHz mono WAV per target-speaker utterance."""
    import json
    import os
    import subprocess
    import tempfile

    out = NamedTuple("SpeakerSegments", [("segments_json", str), ("num_segments", int), ("total_duration_s", float)])
    work_dir = tempfile.mkdtemp()
    wavs_dir = os.path.join(work_dir, "wavs")
    os.makedirs(wavs_dir, exist_ok=True)

    # ffmpeg is required for cutting; the index refresh is best-effort.
    subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
    subprocess.run(["apt-get", "install", "-y", "-qq", "ffmpeg"], capture_output=True, check=True)

    segments = json.loads(transcript_json)

    # Fuzzy speaker match: case-insensitive, and either label may
    # contain the other as a substring.
    wanted = target_speaker.lower()
    matched = [
        segment for segment in segments
        if wanted in segment.get("speaker", "").lower()
        or segment.get("speaker", "").lower() in wanted
    ]

    # If the given name never appeared in any label, train on everything
    # rather than producing an empty dataset.
    if not matched:
        print(f"WARNING: No segments matched speaker '{target_speaker}'. "
              f"Using all {len(segments)} segments.")
        matched = segments

    print(f"Matched {len(matched)} segments for speaker '{target_speaker}'")

    kept = []
    total_dur = 0.0
    for idx, segment in enumerate(matched):
        begin = float(segment.get("start", 0))
        finish = float(segment.get("end", begin + 5))
        length = finish - begin
        text = segment.get("text", "").strip()

        # Skip empty or too-short utterances; clip overly long ones.
        if not text or length < min_duration_s:
            continue
        if length > max_duration_s:
            finish = begin + max_duration_s
            length = max_duration_s

        # Filenames keep the enumerate index so skipped segments leave gaps.
        wav_name = f"utt_{idx:04d}.wav"
        subprocess.run(
            ["ffmpeg", "-y", "-i", audio_path,
             "-ss", str(begin), "-to", str(finish),
             "-ac", "1", "-ar", "22050", "-sample_fmt", "s16",
             os.path.join(wavs_dir, wav_name)],
            capture_output=True, check=True,
        )

        kept.append({
            "wav": wav_name,
            "text": text,
            "start": begin,
            "end": finish,
            "duration": round(length, 2),
        })
        total_dur += length

    print(f"Extracted {len(kept)} utterances, total {total_dur:.1f}s")

    return out(
        segments_json=json.dumps({"wavs_dir": wavs_dir, "utterances": kept}),
        num_segments=len(kept),
        total_duration_s=round(total_dur, 2),
    )


# ──────────────────────────────────────────────────────────────
# 3. Prepare LJSpeech-format dataset for Coqui TTS
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=[],
)
def prepare_ljspeech_dataset(
    segments_json: str,
    voice_name: str,
    language: str = "en",
) -> NamedTuple("DatasetOutput", [("dataset_dir", str), ("num_samples", int)]):
    """Create metadata.csv + wavs/ in LJSpeech format.

    Copies every utterance WAV referenced by ``segments_json`` into
    ``<parent>/dataset/wavs`` and writes the LJSpeech metadata file
    (one ``id|text|text`` row per utterance) plus a small
    ``dataset_config.json`` describing the dataset.

    Returns:
        dataset_dir: root of the LJSpeech-layout dataset.
        num_samples: number of metadata rows written.
    """
    import json
    import os
    import shutil

    out = NamedTuple("DatasetOutput", [("dataset_dir", str), ("num_samples", int)])

    data = json.loads(segments_json)
    wavs_src = data["wavs_dir"]
    utterances = data["utterances"]

    # Dataset lives next to the source wavs dir: <parent>/dataset/{metadata.csv,wavs/}
    dataset_dir = os.path.join(os.path.dirname(wavs_src), "dataset")
    wavs_dst = os.path.join(dataset_dir, "wavs")
    os.makedirs(wavs_dst, exist_ok=True)

    lines = []
    for utt in utterances:
        src = os.path.join(wavs_src, utt["wav"])
        dst = os.path.join(wavs_dst, utt["wav"])
        shutil.copy2(src, dst)
        stem = os.path.splitext(utt["wav"])[0]
        # LJSpeech format: id|text|text — strip the field separator from text.
        text = utt["text"].replace("|", " ")
        lines.append(f"{stem}|{text}|{text}")

    metadata_path = os.path.join(dataset_dir, "metadata.csv")
    with open(metadata_path, "w", encoding="utf-8") as f:
        # FIX: terminate the file with a newline so line-based readers
        # do not miss the final row.
        f.write("\n".join(lines) + "\n")

    # Dataset config for reference.
    # FIX: reuse the json module imported above instead of re-importing
    # it under an alias (`import json as _json`).
    config = {
        "name": voice_name,
        "language": language,
        "num_samples": len(lines),
        "format": "ljspeech",
        "sample_rate": 22050,
    }
    with open(os.path.join(dataset_dir, "dataset_config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"LJSpeech dataset ready: {len(lines)} samples")
    return out(dataset_dir=dataset_dir, num_samples=len(lines))


# ──────────────────────────────────────────────────────────────
# 4. Fine-tune Coqui VITS voice model
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="ghcr.io/idiap/coqui-tts:latest",
    packages_to_install=[],
)
def train_vits_voice(
    dataset_dir: str,
    voice_name: str,
    language: str = "en",
    base_model: str = "tts_models/en/ljspeech/vits",
    num_epochs: int = 100,
    batch_size: int = 16,
    learning_rate: float = 0.0001,
) -> NamedTuple("TrainOutput", [("model_dir", str), ("best_checkpoint", str), ("final_loss", float)]):
    """Fine-tune a VITS model on the speaker dataset.

    Downloads the ``base_model`` checkpoint via Coqui's ModelManager
    (skipped when ``base_model`` is empty or ``"none"``), builds a
    VitsConfig over the LJSpeech-format dataset at ``dataset_dir``,
    runs Coqui's Trainer, and reports the best/newest checkpoint found
    under the output directory.

    Returns:
        model_dir: directory containing checkpoints and trainer output.
        best_checkpoint: path of the best (or newest) ``.pth``, "" if none.
        final_loss: average training loss, 0.0 if unavailable.
    """
    import os
    import json  # NOTE(review): imported but unused in this component
    import glob

    out = NamedTuple("TrainOutput", [("model_dir", str), ("best_checkpoint", str), ("final_loss", float)])

    # All trainer artifacts (checkpoints, logs, phoneme cache) land here.
    OUTPUT_DIR = "/tmp/vits_output"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"=== Coqui VITS Voice Training ===")
    print(f"Voice name : {voice_name}")
    print(f"Base model : {base_model}")
    print(f"Dataset : {dataset_dir}")
    print(f"Epochs : {num_epochs}")
    print(f"Batch size : {batch_size}")
    print(f"LR : {learning_rate}")

    # ── Download base model checkpoint ────────────────────────
    # When fine-tuning, training restores weights from this checkpoint;
    # with base_model == "none" the model trains from scratch.
    restore_path = None
    if base_model and base_model != "none":
        from TTS.utils.manage import ModelManager
        manager = ModelManager()
        model_path, config_path, _ = manager.download_model(base_model)
        restore_path = model_path
        print(f"Base model checkpoint: {restore_path}")

    # ── Configure and train ───────────────────────────────────
    # Heavy TTS imports are deferred until after the cheap setup above.
    from trainer import Trainer, TrainerArgs
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.configs.shared_configs import BaseDatasetConfig
    from TTS.tts.datasets import load_tts_samples
    from TTS.tts.models.vits import Vits
    from TTS.tts.utils.text.tokenizer import TTSTokenizer
    from TTS.utils.audio import AudioProcessor

    # Points Coqui's "ljspeech" formatter at dataset_dir/metadata.csv.
    dataset_config = BaseDatasetConfig(
        formatter="ljspeech",
        meta_file_train="metadata.csv",
        path=dataset_dir,
        language=language,
    )

    config = VitsConfig(
        run_name=voice_name,
        output_path=OUTPUT_DIR,
        datasets=[dataset_config],
        batch_size=batch_size,
        # Eval batches are half the training batch size, floored at 1.
        eval_batch_size=max(1, batch_size // 2),
        num_loader_workers=4,
        num_eval_loader_workers=2,
        run_eval=True,
        test_delay_epochs=5,
        epochs=num_epochs,
        text_cleaner="phoneme_cleaners",
        use_phonemes=True,
        phoneme_language=language,
        phoneme_cache_path=os.path.join(OUTPUT_DIR, "phoneme_cache"),
        compute_input_seq_cache=True,
        print_step=25,
        print_eval=False,
        mixed_precision=True,
        save_step=500,
        save_n_checkpoints=3,
        save_best_after=1000,
        lr=learning_rate,
        # Inputs are resampled to match the 22.05 kHz dataset audio.
        audio={
            "sample_rate": 22050,
            "resample": True,
            "do_trim_silence": True,
            "trim_db": 45,
        },
    )

    ap = AudioProcessor.init_from_config(config)
    # NOTE: the tokenizer initialiser also returns an updated config.
    tokenizer, config = TTSTokenizer.init_from_config(config)

    train_samples, eval_samples = load_tts_samples(
        dataset_config,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    print(f"Training samples: {len(train_samples)}")
    print(f"Eval samples: {len(eval_samples)}")

    # Single-speaker fine-tune: no speaker manager.
    model = Vits(config, ap, tokenizer, speaker_manager=None)

    trainer_args = TrainerArgs(
        restore_path=restore_path,
        skip_train_epoch=False,
    )

    trainer = Trainer(
        trainer_args,
        config,
        output_path=OUTPUT_DIR,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )

    trainer.fit()

    # ── Find best checkpoint ──────────────────────────────────
    # Prefer an explicit best_model file; otherwise take the newest .pth.
    best_files = glob.glob(os.path.join(OUTPUT_DIR, "**/best_model*.pth"), recursive=True)
    if not best_files:
        best_files = glob.glob(os.path.join(OUTPUT_DIR, "**/*.pth"), recursive=True)
    best_files.sort(key=os.path.getmtime, reverse=True)
    best_checkpoint = best_files[0] if best_files else ""

    # Try to read final loss from trainer; best-effort — the attribute
    # layout depends on the trainer version, so default to 0.0 on failure.
    final_loss = 0.0
    try:
        final_loss = float(trainer.keep_avg_train["avg_loss"])
    except Exception:
        pass

    print(f"Training complete. Best checkpoint: {best_checkpoint}")
    print(f"Final loss: {final_loss:.4f}")

    return out(model_dir=OUTPUT_DIR, best_checkpoint=best_checkpoint, final_loss=final_loss)


# ──────────────────────────────────────────────────────────────
# 5. Push trained voice model to Gitea repository
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["requests"],
)
def push_model_to_gitea(
    model_dir: str,
    voice_name: str,
    gitea_url: str = "http://gitea-http.gitea.svc.cluster.local:3000",
    gitea_owner: str = "daviestechlabs",
    gitea_repo: str = "voice-models",
    gitea_username: str = "",
    gitea_password: str = "",
) -> NamedTuple("PushOutput", [("repo_url", str), ("files_pushed", int)]):
    """Upload the trained checkpoint, config, and a descriptor to Gitea via its contents API."""
    import base64
    import glob
    import json
    import os

    import requests

    out = NamedTuple("PushOutput", [("repo_url", str), ("files_pushed", int)])

    session = requests.Session()
    if gitea_username:
        session.auth = (gitea_username, gitea_password)
    else:
        session.auth = None

    api = f"{gitea_url}/api/v1"
    repo_url = f"{gitea_url}/{gitea_owner}/{gitea_repo}"

    # ── Ensure repo exists ────────────────────────────────────
    probe = session.get(f"{api}/repos/{gitea_owner}/{gitea_repo}", timeout=30)
    if probe.status_code == 404:
        print(f"Creating repository: {gitea_owner}/{gitea_repo}")
        created = session.post(
            f"{api}/orgs/{gitea_owner}/repos",
            json={
                "name": gitea_repo,
                "description": "Trained voice models from voice cloning pipeline",
                "private": False,
                "auto_init": True,
            },
            timeout=30,
        )
        if created.status_code not in (200, 201):
            # Owner may be a user rather than an org — retry on the user endpoint.
            created = session.post(
                f"{api}/user/repos",
                json={"name": gitea_repo, "description": "Trained voice models", "auto_init": True},
                timeout=30,
            )
            created.raise_for_status()
        print("Repository created")

    # ── Collect model files ───────────────────────────────────
    files_to_push = []

    # Newest checkpoint, preferring an explicit best_model file.
    for pattern in ("**/best_model*.pth", "**/*.pth"):
        hits = glob.glob(os.path.join(model_dir, pattern), recursive=True)
        if hits:
            hits.sort(key=os.path.getmtime, reverse=True)
            files_to_push.append(hits[0])
            break

    # Training config, if one was written.
    config_hits = glob.glob(os.path.join(model_dir, "**/config.json"), recursive=True)
    if config_hits:
        files_to_push.append(config_hits[0])

    # Small descriptor for consumers of the repo.
    info_path = os.path.join(model_dir, "model_info.json")
    descriptor = {
        "name": voice_name,
        "type": "coqui-vits",
        "base_model": "tts_models/en/ljspeech/vits",
        "sample_rate": 22050,
    }
    with open(info_path, "w") as f:
        json.dump(descriptor, f, indent=2)
    files_to_push.append(info_path)

    # ── Push each file ────────────────────────────────────────
    pushed = 0
    for local_path in files_to_push:
        rel = os.path.relpath(local_path, model_dir)
        gitea_path = f"{voice_name}/{rel}"
        print(f"Pushing: {gitea_path} ({os.path.getsize(local_path)} bytes)")

        with open(local_path, "rb") as f:
            payload = {
                "content": base64.b64encode(f.read()).decode(),
                "message": f"Upload {voice_name}: {rel}",
            }

        contents_url = f"{api}/repos/{gitea_owner}/{gitea_repo}/contents/{gitea_path}"
        existing = session.get(contents_url, timeout=30)

        if existing.status_code == 200:
            # File already present: Gitea's update (PUT) needs the current blob SHA.
            payload["sha"] = existing.json().get("sha", "")
            resp = session.put(contents_url, json=payload, timeout=120)
        else:
            # New file: create with POST.
            resp = session.post(contents_url, json=payload, timeout=120)

        if resp.status_code in (200, 201):
            pushed += 1
            print("  ✓ Pushed")
        else:
            print(f"  ✗ Failed ({resp.status_code}): {resp.text[:200]}")

    print(f"\nPushed {pushed}/{len(files_to_push)} files to {repo_url}")
    return out(repo_url=repo_url, files_pushed=pushed)


# ──────────────────────────────────────────────────────────────
# 6. Log metrics to MLflow
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["mlflow>=2.10.0", "requests"],
)
def log_training_metrics(
    voice_name: str,
    num_segments: int,
    total_duration_s: float,
    final_loss: float,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    repo_url: str,
    files_pushed: int,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    experiment_name: str = "voice-cloning",
) -> NamedTuple("LogOutput", [("run_id", str)]):
    """Record the voice-cloning run's params, metrics, and tags in MLflow."""
    from datetime import datetime

    import mlflow

    out = NamedTuple("LogOutput", [("run_id", str)])

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(experiment_name)

    # Assemble everything up front so the run context stays short.
    params = {
        "voice_name": voice_name,
        "base_model": "tts_models/en/ljspeech/vits",
        "model_type": "coqui-vits",
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "sample_rate": 22050,
    }
    metrics = {
        "num_training_segments": num_segments,
        "total_audio_duration_s": total_duration_s,
        "final_loss": final_loss,
        "files_pushed": files_pushed,
    }
    tags = {
        "pipeline": "voice-cloning",
        "gitea_repo": repo_url,
        "voice_name": voice_name,
    }

    run_name = f"voice-clone-{voice_name}-{datetime.now():%Y%m%d-%H%M}"
    with mlflow.start_run(run_name=run_name) as run:
        mlflow.log_params(params)
        mlflow.log_metrics(metrics)
        mlflow.set_tags(tags)
        print(f"Logged to MLflow run: {run.info.run_id}")
        return out(run_id=run.info.run_id)


# ──────────────────────────────────────────────────────────────
# Pipeline definition
# ──────────────────────────────────────────────────────────────
@dsl.pipeline(
    name="Voice Cloning Pipeline",
    description=(
        "Extract a speaker from audio+transcript, fine-tune a Coqui VITS "
        "voice model, push to Gitea, and log metrics to MLflow."
    ),
)
def voice_cloning_pipeline(
    # Source audio location (Quobjects S3).
    s3_endpoint: str = "https://gravenhollow.lab.daviestechlabs.io:30292",
    s3_bucket: str = "training-data",
    s3_key: str = "",
    # Which diarised speaker to clone, and what to call the result.
    target_speaker: str = "SPEAKER_0",
    voice_name: str = "custom-voice",
    language: str = "en",
    # Training hyper-parameters.
    base_model: str = "tts_models/en/ljspeech/vits",
    num_epochs: int = 100,
    batch_size: int = 16,
    learning_rate: float = 0.0001,
    min_segment_duration_s: float = 1.0,
    max_segment_duration_s: float = 15.0,
    # Whisper / inference endpoints
    whisper_url: str = "http://ai-inference-serve-svc.kuberay.svc.cluster.local:8000/whisper",
    # Gitea
    gitea_url: str = "http://gitea-http.gitea.svc.cluster.local:3000",
    gitea_owner: str = "daviestechlabs",
    gitea_repo: str = "voice-models",
    gitea_username: str = "",
    gitea_password: str = "",
    # MLflow
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
):
    """Wire the six components into a linear voice-cloning DAG.

    NOTE(review): steps pass filesystem paths (audio_path, dataset_dir,
    model_dir) between components as plain strings; KFP components
    typically run in separate pods, so these paths only resolve if all
    steps share a volume — confirm the cluster mounts one, or switch to
    KFP artifacts.
    """
    # 1 - Download from Quobjects S3 and transcribe with Whisper
    transcribed = transcribe_and_diarise(
        s3_endpoint=s3_endpoint,
        s3_bucket=s3_bucket,
        s3_key=s3_key,
        whisper_url=whisper_url,
    )

    # 2 - Extract target speaker's segments
    extracted = extract_speaker_segments(
        transcript_json=transcribed.outputs["transcript_json"],
        audio_path=transcribed.outputs["audio_path"],
        target_speaker=target_speaker,
        min_duration_s=min_segment_duration_s,
        max_duration_s=max_segment_duration_s,
    )

    # 3 - Build LJSpeech dataset
    dataset = prepare_ljspeech_dataset(
        segments_json=extracted.outputs["segments_json"],
        voice_name=voice_name,
        language=language,
    )

    # 4 - Train VITS model
    trained = train_vits_voice(
        dataset_dir=dataset.outputs["dataset_dir"],
        voice_name=voice_name,
        language=language,
        base_model=base_model,
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
    # NOTE(review): KFP v2 usually expects a Kubernetes resource name here
    # (e.g. "nvidia.com/gpu") rather than "gpu" — confirm against the
    # cluster's accelerator labels.
    trained.set_accelerator_type("gpu")
    trained.set_gpu_limit(1)
    trained.set_memory_request("16Gi")
    trained.set_memory_limit("32Gi")
    trained.set_cpu_request("4")
    trained.set_cpu_limit("8")

    # 5 - Push model to Gitea
    pushed = push_model_to_gitea(
        model_dir=trained.outputs["model_dir"],
        voice_name=voice_name,
        gitea_url=gitea_url,
        gitea_owner=gitea_owner,
        gitea_repo=gitea_repo,
        gitea_username=gitea_username,
        gitea_password=gitea_password,
    )

    # 6 - Log to MLflow
    log_training_metrics(
        voice_name=voice_name,
        num_segments=extracted.outputs["num_segments"],
        total_duration_s=extracted.outputs["total_duration_s"],
        final_loss=trained.outputs["final_loss"],
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        repo_url=pushed.outputs["repo_url"],
        files_pushed=pushed.outputs["files_pushed"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )


# ──────────────────────────────────────────────────────────────
# Compile
# ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Emit the pipeline spec for upload to the Kubeflow Pipelines UI.
    pipeline_compiler = compiler.Compiler()
    pipeline_compiler.compile(
        pipeline_func=voice_cloning_pipeline,
        package_path="voice_cloning_pipeline.yaml",
    )
    print("Compiled: voice_cloning_pipeline.yaml")