It's coqui-tts, not coqui.
This commit is contained in:
@@ -65,11 +65,12 @@ def fetch_pdfs_from_s3(
|
|||||||
|
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{s3_endpoint}",
|
endpoint_url=s3_endpoint,
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
region_name="us-east-1",
|
region_name="us-east-1",
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
paginator = client.get_paginator("list_objects_v2")
|
paginator = client.get_paginator("list_objects_v2")
|
||||||
@@ -226,11 +227,12 @@ def upload_data_to_s3(
|
|||||||
|
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{s3_endpoint}",
|
endpoint_url=s3_endpoint,
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
region_name="us-east-1",
|
region_name="us-east-1",
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
for fname in ["train.json", "val.json"]:
|
for fname in ["train.json", "val.json"]:
|
||||||
@@ -273,11 +275,12 @@ from trl import SFTTrainer
|
|||||||
def _s3_client(cfg):
|
def _s3_client(cfg):
|
||||||
return boto3.client(
|
return boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{cfg['s3_endpoint']}",
|
endpoint_url=cfg['s3_endpoint'],
|
||||||
aws_access_key_id=cfg["aws_access_key_id"],
|
aws_access_key_id=cfg["aws_access_key_id"],
|
||||||
aws_secret_access_key=cfg["aws_secret_access_key"],
|
aws_secret_access_key=cfg["aws_secret_access_key"],
|
||||||
region_name="us-east-1",
|
region_name="us-east-1",
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -693,11 +696,12 @@ def submit_ray_training_job(
|
|||||||
# ── Read results from S3 ────────────────────────────────
|
# ── Read results from S3 ────────────────────────────────
|
||||||
s3 = boto3.client(
|
s3 = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{s3_endpoint}",
|
endpoint_url=s3_endpoint,
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
region_name="us-east-1",
|
region_name="us-east-1",
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
results_obj = s3.get_object(
|
results_obj = s3.get_object(
|
||||||
@@ -751,11 +755,12 @@ def download_adapter_from_s3(
|
|||||||
|
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{s3_endpoint}",
|
endpoint_url=s3_endpoint,
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
region_name="us-east-1",
|
region_name="us-east-1",
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
adapter_path = "/tmp/adapter"
|
adapter_path = "/tmp/adapter"
|
||||||
@@ -1019,7 +1024,7 @@ def log_training_metrics(
|
|||||||
)
|
)
|
||||||
def cpu_training_pipeline(
|
def cpu_training_pipeline(
|
||||||
# ── S3 / Quobjects ──
|
# ── S3 / Quobjects ──
|
||||||
s3_endpoint: str = "candlekeep.lab.daviestechlabs.io",
|
s3_endpoint: str = "https://gravenhollow.lab.daviestechlabs.io:30292",
|
||||||
s3_bucket: str = "training-data",
|
s3_bucket: str = "training-data",
|
||||||
s3_prefix: str = "",
|
s3_prefix: str = "",
|
||||||
aws_access_key_id: str = "",
|
aws_access_key_id: str = "",
|
||||||
|
|||||||
@@ -46,11 +46,12 @@ def fetch_pdfs_from_s3(
|
|||||||
|
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{s3_endpoint}",
|
endpoint_url=s3_endpoint,
|
||||||
aws_access_key_id=aws_access_key_id,
|
aws_access_key_id=aws_access_key_id,
|
||||||
aws_secret_access_key=aws_secret_access_key,
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
region_name="us-east-1",
|
region_name="us-east-1",
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
paginator = client.get_paginator("list_objects_v2")
|
paginator = client.get_paginator("list_objects_v2")
|
||||||
@@ -597,7 +598,7 @@ def log_training_metrics(
|
|||||||
)
|
)
|
||||||
def qlora_pdf_pipeline(
|
def qlora_pdf_pipeline(
|
||||||
# ── S3 / Quobjects ──
|
# ── S3 / Quobjects ──
|
||||||
s3_endpoint: str = "candlekeep.lab.daviestechlabs.io",
|
s3_endpoint: str = "https://gravenhollow.lab.daviestechlabs.io:30292",
|
||||||
s3_bucket: str = "training-data",
|
s3_bucket: str = "training-data",
|
||||||
s3_prefix: str = "",
|
s3_prefix: str = "",
|
||||||
aws_access_key_id: str = "",
|
aws_access_key_id: str = "",
|
||||||
|
|||||||
@@ -47,10 +47,11 @@ def transcribe_and_diarise(
|
|||||||
|
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"http://{s3_endpoint}",
|
endpoint_url=s3_endpoint,
|
||||||
aws_access_key_id="",
|
aws_access_key_id="",
|
||||||
aws_secret_access_key="",
|
aws_secret_access_key="",
|
||||||
config=boto3.session.Config(signature_version="UNSIGNED"),
|
config=boto3.session.Config(signature_version="UNSIGNED"),
|
||||||
|
verify=False,
|
||||||
)
|
)
|
||||||
print(f"Downloading s3://{s3_bucket}/{s3_key} from {s3_endpoint}")
|
print(f"Downloading s3://{s3_bucket}/{s3_key} from {s3_endpoint}")
|
||||||
client.download_file(s3_bucket, s3_key, audio_path)
|
client.download_file(s3_bucket, s3_key, audio_path)
|
||||||
@@ -258,7 +259,7 @@ def prepare_ljspeech_dataset(
|
|||||||
# 4. Fine-tune Coqui VITS voice model
|
# 4. Fine-tune Coqui VITS voice model
|
||||||
# ──────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────
|
||||||
@dsl.component(
|
@dsl.component(
|
||||||
base_image="ghcr.io/coqui-ai/tts:latest",
|
base_image="ghcr.io/idiap/coqui-tts:latest",
|
||||||
packages_to_install=[],
|
packages_to_install=[],
|
||||||
)
|
)
|
||||||
def train_vits_voice(
|
def train_vits_voice(
|
||||||
@@ -585,7 +586,7 @@ def log_training_metrics(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
def voice_cloning_pipeline(
|
def voice_cloning_pipeline(
|
||||||
s3_endpoint: str = "candlekeep.lab.daviestechlabs.io",
|
s3_endpoint: str = "https://gravenhollow.lab.daviestechlabs.io:30292",
|
||||||
s3_bucket: str = "training-data",
|
s3_bucket: str = "training-data",
|
||||||
s3_key: str = "",
|
s3_key: str = "",
|
||||||
target_speaker: str = "SPEAKER_0",
|
target_speaker: str = "SPEAKER_0",
|
||||||
|
|||||||
@@ -547,7 +547,7 @@ deploymentSpec:
|
|||||||
\ complete. Best checkpoint: {best_checkpoint}\")\n print(f\"Final loss:\
|
\ complete. Best checkpoint: {best_checkpoint}\")\n print(f\"Final loss:\
|
||||||
\ {final_loss:.4f}\")\n\n return out(model_dir=OUTPUT_DIR, best_checkpoint=best_checkpoint,\
|
\ {final_loss:.4f}\")\n\n return out(model_dir=OUTPUT_DIR, best_checkpoint=best_checkpoint,\
|
||||||
\ final_loss=final_loss)\n\n"
|
\ final_loss=final_loss)\n\n"
|
||||||
image: ghcr.io/coqui-ai/tts:latest
|
image: ghcr.io/idiap/coqui-tts:latest
|
||||||
resources:
|
resources:
|
||||||
accelerator:
|
accelerator:
|
||||||
resourceCount: '1'
|
resourceCount: '1'
|
||||||
|
|||||||
Reference in New Issue
Block a user