# Changelog: tts.py — renamed ambiguous variable 'l' to 'line' (E741);
# applied ruff formatter (tts.py, llm.py).
#!/usr/bin/env python3
"""
TTS Demo - Gradio UI for testing Text-to-Speech service.

Features:
- Text input with language selection
- Audio playback of synthesized speech
- Sentence-level chunking for better quality
- Speed control
- MLflow metrics logging
"""

import os
import re
import time
import logging
import io
import wave

import gradio as gr
import httpx
import numpy as np

from theme import get_lab_theme, CUSTOM_CSS, create_footer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tts-demo")

# Configuration
# Base endpoint of the TTS backend; override via the TTS_URL env var.
TTS_URL = os.environ.get(
    "TTS_URL",
    # Default: Ray Serve TTS endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
)
# MLflow tracking server used for best-effort metrics logging.
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
|
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: if the mlflow package or the tracking server is
# unavailable, the demo still runs with metrics logging disabled.
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Reuse the experiment if it already exists, otherwise create it.
    _experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-tts-tuning",
            artifact_location="/mlflow/artifacts/gradio-tts-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id

    # One long-lived run per process; HOSTNAME distinguishes pod replicas.
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-tts", "endpoint": TTS_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0  # monotonically increasing metric step counter
    MLFLOW_ENABLED = True
    logger.info(
        "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
    )
except Exception as exc:
    # Any failure (import, network, auth) disables tracking entirely.
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False
|
|
|
|
def _log_tts_metrics(
    latency: float,
    audio_duration: float,
    text_chars: int,
    language: str,
) -> None:
    """Log one TTS inference's metrics to MLflow (non-blocking best-effort).

    Does nothing when MLflow tracking is disabled; any logging failure is
    swallowed (debug-logged) so it never impacts synthesis.

    NOTE: ``language`` is accepted for interface stability but is not
    currently recorded in the metric batch.
    """
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        step = _mlflow_step
        timestamp_ms = int(time.time() * 1000)
        # Derived rates; guard against zero denominators.
        realtime_factor = latency / audio_duration if audio_duration > 0 else 0
        chars_per_sec = text_chars / latency if latency > 0 else 0
        metric_values = [
            ("latency_s", latency),
            ("audio_duration_s", audio_duration),
            ("realtime_factor", realtime_factor),
            ("chars_per_second", chars_per_sec),
            ("text_chars", text_chars),
        ]
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                mlflow.entities.Metric(key, value, timestamp_ms, step)
                for key, value in metric_values
            ],
        )
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
|
|
|
|
|
|
# HTTP client with longer timeout for audio generation
client = httpx.Client(timeout=120.0)

# Supported languages for XTTS
# Maps the human-readable dropdown label -> XTTS language code.
LANGUAGES = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt",
    "Polish": "pl",
    "Turkish": "tr",
    "Russian": "ru",
    "Dutch": "nl",
    "Czech": "cs",
    "Arabic": "ar",
    "Chinese": "zh-cn",
    "Japanese": "ja",
    "Korean": "ko",
    "Hungarian": "hu",
}
|
|
|
|
# ─── Text preprocessing ─────────────────────────────────────────────────

# Split after sentence-ending punctuation (. ! ? ;) or at line breaks.
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)

# Spoken-word equivalents for standalone single digits (see _expand_numbers).
_DIGIT_WORDS = {
    "0": "zero",
    "1": "one",
    "2": "two",
    "3": "three",
    "4": "four",
    "5": "five",
    "6": "six",
    "7": "seven",
    "8": "eight",
    "9": "nine",
}
|
|
|
|
|
|
def _expand_numbers(text: str) -> str:
    """Replace standalone single digits with their spoken English words.

    Only digits surrounded by word boundaries (e.g. "room 5") are
    expanded; multi-digit numbers are left untouched.
    """

    def _to_word(match: re.Match) -> str:
        digit = match.group(0)
        return _DIGIT_WORDS.get(digit, digit)

    return re.sub(r"\b(\d)\b", _to_word, text)
|
|
|
|
|
|
def _clean_text(text: str) -> str:
    """Normalise raw text into a TTS-friendly form.

    Collapses runs of spaces/tabs, trims every line, removes markdown and
    code-fence punctuation, spells out common symbols, and expands
    standalone single digits.
    """
    # Collapse horizontal whitespace, then trim each line individually.
    collapsed = re.sub(r"[ \t]+", " ", text)
    trimmed = "\n".join(part.strip() for part in collapsed.splitlines())
    # Strip markdown / code-fence characters
    cleaned = re.sub(r"[*#~`|<>{}[\]\\]", "", trimmed)
    # Expand common symbols into words.
    for symbol, spoken in (
        ("&", " and "),
        ("@", " at "),
        ("%", " percent "),
        ("+", " plus "),
        ("=", " equals "),
    ):
        cleaned = cleaned.replace(symbol, spoken)
    return _expand_numbers(cleaned).strip()
|
|
|
|
|
|
def _split_sentences(text: str) -> list[str]:
    """Split text into short sentences suitable for TTS.

    Keeps segments short for best synthesis quality while preserving
    natural phrasing; segments longer than 200 characters are further
    split on commas and semicolons.
    """
    text = _clean_text(text)
    if not text:
        return []

    sentences: list[str] = []
    for raw in _SENTENCE_RE.split(text):
        segment = raw.strip()
        if not segment:
            continue
        if len(segment) <= 200:
            sentences.append(segment)
            continue
        # Over-long segment: break further on clause punctuation.
        sentences.extend(
            clause.strip()
            for clause in re.split(r"(?<=[,;])\s+", segment)
            if clause.strip()
        )
    return sentences
|
|
|
|
|
|
# ─── Audio helpers ───────────────────────────────────────────────────────
|
|
|
|
|
|
def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
|
|
"""Read WAV audio from bytes, handling scipy wavfile and standard WAV.
|
|
|
|
Returns (sample_rate, float32_audio) with values in [-1, 1].
|
|
"""
|
|
buf = io.BytesIO(data)
|
|
|
|
# Try stdlib wave module first — most robust for PCM WAV from scipy
|
|
try:
|
|
with wave.open(buf, "rb") as wf:
|
|
sr = wf.getframerate()
|
|
n_frames = wf.getnframes()
|
|
n_channels = wf.getnchannels()
|
|
sampwidth = wf.getsampwidth()
|
|
raw = wf.readframes(n_frames)
|
|
|
|
if sampwidth == 2:
|
|
audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
|
elif sampwidth == 4:
|
|
audio = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
|
elif sampwidth == 1:
|
|
audio = (
|
|
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
|
) / 128.0
|
|
else:
|
|
raise ValueError(f"Unsupported sample width: {sampwidth}")
|
|
|
|
if n_channels > 1:
|
|
audio = audio.reshape(-1, n_channels).mean(axis=1)
|
|
|
|
return sr, audio
|
|
except Exception as exc:
|
|
logger.debug("wave module failed (%s), trying soundfile", exc)
|
|
|
|
# Fallback: soundfile (handles FLAC, OGG, etc.)
|
|
buf.seek(0)
|
|
try:
|
|
import soundfile as sf
|
|
|
|
audio, sr = sf.read(buf, dtype="float32")
|
|
if audio.ndim > 1:
|
|
audio = audio.mean(axis=1)
|
|
return sr, audio
|
|
except Exception as exc:
|
|
logger.debug("soundfile failed (%s), attempting raw PCM", exc)
|
|
|
|
# Last resort: raw 16-bit PCM at 22050 Hz
|
|
logger.warning(
|
|
"Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
|
|
len(data),
|
|
data[:4],
|
|
)
|
|
audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
|
return 22050, audio
|
|
|
|
|
|
def _concat_audio(
|
|
chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
|
|
) -> tuple[int, np.ndarray]:
|
|
"""Concatenate (sample_rate, audio) chunks with silence gaps."""
|
|
if not chunks:
|
|
return 22050, np.array([], dtype=np.float32)
|
|
if len(chunks) == 1:
|
|
return chunks[0]
|
|
|
|
sr = chunks[0][0]
|
|
silence = np.zeros(int(sr * pause_ms / 1000), dtype=np.float32)
|
|
|
|
parts: list[np.ndarray] = []
|
|
for sample_rate, audio in chunks:
|
|
if sample_rate != sr:
|
|
ratio = sr / sample_rate
|
|
indices = np.arange(0, len(audio), 1.0 / ratio).astype(int)
|
|
indices = indices[indices < len(audio)]
|
|
audio = audio[indices]
|
|
parts.append(audio)
|
|
parts.append(silence)
|
|
|
|
if parts:
|
|
parts.pop() # remove trailing silence
|
|
return sr, np.concatenate(parts)
|
|
|
|
|
|
# ─── TTS synthesis ───────────────────────────────────────────────────────
|
|
|
|
|
|
def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
    """Synthesize one text chunk via the TTS backend; return audio bytes.

    Prefers the JSON POST endpoint (no URL length limits, supports speed);
    falls back to the Coqui-compatible GET endpoint when POST fails.
    """
    import base64 as b64

    # Preferred path: JSON POST, optionally base64-encoded audio.
    try:
        payload = {
            "text": text,
            "language": lang_code,
            "speed": speed,
            "return_base64": True,
        }
        resp = client.post(TTS_URL, json=payload)
        resp.raise_for_status()
        if "application/json" in resp.headers.get("content-type", ""):
            body = resp.json()
            if "error" in body:
                raise RuntimeError(body["error"])
            audio_b64 = body.get("audio", "")
            if audio_b64:
                return b64.b64decode(audio_b64)
        # Non-JSON response — treat as raw audio bytes
        return resp.content
    except Exception:
        logger.debug(
            "POST endpoint failed, falling back to GET /api/tts", exc_info=True
        )

    # Fallback: Coqui-compatible GET (no speed control)
    resp = client.get(
        f"{TTS_URL}/api/tts",
        params={"text": text, "language_id": lang_code},
    )
    resp.raise_for_status()
    return resp.content
|
|
|
|
|
|
def synthesize_speech(
    text: str, language: str, speed: float
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
    """Synthesize speech from text using the TTS service.

    Long text is split into sentences and synthesized individually
    for better quality, then concatenated with natural pauses.

    Args:
        text: Raw user text; cleaned and sentence-chunked before synthesis.
        language: Human-readable language name (a key of LANGUAGES;
            unknown names fall back to English).
        speed: Playback speed multiplier passed to the backend.

    Returns:
        (status message, (sample_rate, float32 audio) or None on error,
         metrics markdown string)
    """
    if not text.strip():
        return "❌ Please enter some text", None, ""

    lang_code = LANGUAGES.get(language, "en")
    sentences = _split_sentences(text)
    if not sentences:
        return "❌ No speakable text found after cleaning", None, ""

    try:
        start_time = time.time()
        audio_chunks: list[tuple[int, np.ndarray]] = []

        # Synthesize sentence-by-sentence for better quality on long input.
        for sentence in sentences:
            raw_audio = _synthesize_chunk(sentence, lang_code, speed)
            sr, audio = _read_wav_bytes(raw_audio)
            audio_chunks.append((sr, audio))

        sample_rate, audio_data = _concat_audio(audio_chunks)
        latency = time.time() - start_time
        duration = len(audio_data) / sample_rate if sample_rate > 0 else 0

        n_chunks = len(sentences)
        status = (
            f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
            f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
        )

        _log_tts_metrics(
            latency=latency,
            audio_duration=duration,
            text_chars=len(text),
            language=lang_code,
        )

        # FIX: guard the derived ratios. ``duration`` is 0 when the backend
        # returns empty audio, and the unguarded f-string divisions below
        # previously raised ZeroDivisionError (surfacing as "❌ Error: ...").
        rtf = latency / duration if duration > 0 else 0.0
        cps = len(text) / latency if latency > 0 else 0.0

        metrics = f"""
**Audio Statistics:**
- Duration: {duration:.2f} seconds
- Sample Rate: {sample_rate} Hz
- Size: {len(audio_data) * 2 / 1024:.1f} KB
- Generation Time: {latency * 1000:.0f}ms
- Real-time Factor: {rtf:.2f}x
- Language: {language} ({lang_code})
- Speed: {speed:.1f}x
- Sentences: {n_chunks}
- Characters: {len(text)}
- Chars/sec: {cps:.1f}
"""
        return status, (sample_rate, audio_data), metrics

    except httpx.HTTPStatusError as e:
        logger.exception("TTS request failed")
        return f"❌ TTS service error: {e.response.status_code}", None, ""
    except Exception as e:
        logger.exception("TTS synthesis failed")
        return f"❌ Error: {e}", None, ""
|
|
|
|
|
|
def check_service_health() -> str:
    """Probe the TTS service and return a human-readable status line."""
    try:
        # Prefer a dedicated /health endpoint; fall back to the root path.
        probe = client.get(f"{TTS_URL}/health", timeout=5.0)
        if probe.status_code == 200:
            return "🟢 Service is healthy"
        probe = client.get(f"{TTS_URL}/", timeout=5.0)
        if probe.status_code == 200:
            return "🟢 Service is responding"
        return f"🟡 Service returned status {probe.status_code}"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"
|
|
|
|
|
|
# ─── Gradio UI ───────────────────────────────────────────────────────────
|
|
|
|
# Declarative Gradio UI: three tabs (single synthesis, two-language
# comparison, batch synthesis) plus a service-health row.
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
    gr.Markdown("""
# 🔊 Text-to-Speech Demo

Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
in multiple languages. Long text is automatically split into sentences for better quality.
""")

    # Service health check row.
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)

    health_btn.click(fn=check_service_health, outputs=health_status)

    with gr.Tabs():
        # Tab 1: single-text synthesis with language and speed controls.
        with gr.TabItem("🎤 Text to Speech"):
            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Text to Synthesize",
                        placeholder="Enter text to convert to speech...",
                        lines=5,
                        max_lines=10,
                    )
                    with gr.Row():
                        language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="English",
                            label="Language",
                        )
                        speed = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="Speed",
                        )
                        synthesize_btn = gr.Button(
                            "🔊 Synthesize",
                            variant="primary",
                            scale=2,
                        )
                with gr.Column(scale=1):
                    status_output = gr.Textbox(label="Status", interactive=False)
                    metrics_output = gr.Markdown(label="Metrics")

            audio_output = gr.Audio(label="Generated Audio", type="numpy")

            synthesize_btn.click(
                fn=synthesize_speech,
                inputs=[text_input, language, speed],
                outputs=[status_output, audio_output, metrics_output],
            )

            # Canned examples covering several languages.
            gr.Examples(
                examples=[
                    [
                        "Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
                        "English",
                        1.0,
                    ],
                    [
                        "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
                        "English",
                        1.0,
                    ],
                    [
                        "Bonjour! Bienvenue au laboratoire technique de Davies.",
                        "French",
                        1.0,
                    ],
                    ["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
                    ["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
                ],
                inputs=[text_input, language, speed],
            )

        # Tab 2: side-by-side comparison of two languages for one text.
        with gr.TabItem("🔄 Language Comparison"):
            gr.Markdown("Compare the same text in different languages.")

            compare_text = gr.Textbox(
                label="Text to Compare", value="Hello, how are you today?", lines=2
            )
            with gr.Row():
                lang1 = gr.Dropdown(
                    choices=list(LANGUAGES.keys()), value="English", label="Language 1"
                )
                lang2 = gr.Dropdown(
                    choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
                )
                compare_speed = gr.Slider(
                    minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
                )

            compare_btn = gr.Button("Compare Languages", variant="primary")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Language 1")
                    audio1 = gr.Audio(label="Audio 1", type="numpy")
                    status1 = gr.Textbox(label="Status", interactive=False)
                with gr.Column():
                    gr.Markdown("### Language 2")
                    audio2 = gr.Audio(label="Audio 2", type="numpy")
                    status2 = gr.Textbox(label="Status", interactive=False)

            def compare_languages(text, l1, l2, spd):
                """Synthesize the same text in two languages (metrics dropped)."""
                s1, a1, _ = synthesize_speech(text, l1, spd)
                s2, a2, _ = synthesize_speech(text, l2, spd)
                return s1, a1, s2, a2

            compare_btn.click(
                fn=compare_languages,
                inputs=[compare_text, lang1, lang2, compare_speed],
                outputs=[status1, audio1, status2, audio2],
            )

        # Tab 3: batch synthesis — one text per line, combined output.
        with gr.TabItem("📚 Batch Synthesis"):
            gr.Markdown("Synthesize multiple texts at once (one per line).")

            batch_input = gr.Textbox(
                label="Texts (one per line)",
                placeholder="Enter multiple texts, one per line...",
                lines=6,
            )
            batch_lang = gr.Dropdown(
                choices=list(LANGUAGES.keys()), value="English", label="Language"
            )
            batch_speed = gr.Slider(
                minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
            )
            batch_btn = gr.Button("Synthesize All", variant="primary")

            batch_status = gr.Textbox(label="Status", interactive=False)
            batch_audio = gr.Audio(label="Combined Audio", type="numpy")

            def batch_synthesize(texts_raw: str, lang: str, spd: float):
                """Join non-empty input lines and synthesize them as one text."""
                # Newlines are preserved so the sentence splitter treats
                # each line as its own chunk.
                lines = [
                    line.strip()
                    for line in texts_raw.strip().splitlines()
                    if line.strip()
                ]
                if not lines:
                    return "❌ Please enter at least one line of text", None
                combined = "\n".join(lines)
                status, audio, _ = synthesize_speech(combined, lang, spd)
                return status, audio

            batch_btn.click(
                fn=batch_synthesize,
                inputs=[batch_input, batch_lang, batch_speed],
                outputs=[batch_status, batch_audio],
            )

    create_footer()
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces for in-cluster access; 7860 is Gradio's default port.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|