Files
gradio-ui/tts.py
Billy D. 0d1c40725e
All checks were successful
CI / Docker Build & Push (push) Successful in 5m38s
CI / Deploy to Kubernetes (push) Successful in 1m21s
CI / Notify (push) Successful in 1s
CI / Lint (push) Successful in 1m4s
CI / Release (push) Successful in 54s
style: fix ruff lint and formatting issues
- tts.py: rename ambiguous variable 'l' to 'line' (E741)
- tts.py, llm.py: apply ruff formatter
2026-02-22 10:55:00 -05:00

572 lines
19 KiB
Python

#!/usr/bin/env python3
"""
TTS Demo - Gradio UI for testing Text-to-Speech service.
Features:
- Text input with language selection
- Audio playback of synthesized speech
- Sentence-level chunking for better quality
- Speed control
- MLflow metrics logging
"""
import os
import re
import time
import logging
import io
import wave
import gradio as gr
import httpx
import numpy as np
from theme import get_lab_theme, CUSTOM_CSS, create_footer
# ─── Logging & configuration ─────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tts-demo")

# TTS backend endpoint. Default points at the in-cluster Ray Serve service;
# override with the TTS_URL environment variable for local development.
TTS_URL = os.environ.get(
    "TTS_URL",
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
)

# MLflow tracking server used for best-effort metrics logging.
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: if the MLflow library or server is unavailable, the
# demo still runs and metric logging degrades to a no-op.
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Reuse the experiment when it already exists, otherwise create it.
    _experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-tts-tuning",
            artifact_location="/mlflow/artifacts/gradio-tts-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id

    # One long-lived run per process; each request logs a new step into it.
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-tts", "endpoint": TTS_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info(
        "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
    )
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False
def _log_tts_metrics(
    latency: float,
    audio_duration: float,
    text_chars: int,
    language: str,
) -> None:
    """Log TTS inference metrics to MLflow (non-blocking best-effort).

    ``language`` is accepted for interface stability but is currently not
    logged (MLflow metrics must be numeric).
    """
    global _mlflow_step

    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        stamp = int(time.time() * 1000)
        # Realtime factor (<1 means faster than realtime) and throughput,
        # both guarded against division by zero.
        realtime_factor = latency / audio_duration if audio_duration > 0 else 0
        chars_per_sec = text_chars / latency if latency > 0 else 0
        metric = mlflow.entities.Metric
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                metric("latency_s", latency, stamp, _mlflow_step),
                metric("audio_duration_s", audio_duration, stamp, _mlflow_step),
                metric("realtime_factor", realtime_factor, stamp, _mlflow_step),
                metric("chars_per_second", chars_per_sec, stamp, _mlflow_step),
                metric("text_chars", text_chars, stamp, _mlflow_step),
            ],
        )
    except Exception:
        # Metrics must never break the UI request path; log quietly.
        logger.debug("MLflow log failed", exc_info=True)
# Shared HTTP client for all calls to the TTS backend. Synthesis of long
# texts can be slow, hence the generous 120-second timeout.
client = httpx.Client(timeout=120.0)
# Supported languages for the XTTS backend.
# Maps the human-readable dropdown label to the XTTS language code; the
# insertion order controls the ordering shown in the Gradio dropdowns.
LANGUAGES = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt",
    "Polish": "pl",
    "Turkish": "tr",
    "Russian": "ru",
    "Dutch": "nl",
    "Czech": "cs",
    "Arabic": "ar",
    "Chinese": "zh-cn",
    "Japanese": "ja",
    "Korean": "ko",
    "Hungarian": "hu",
}
# ─── Text preprocessing ─────────────────────────────────────────────────
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)
_DIGIT_WORDS = {
"0": "zero",
"1": "one",
"2": "two",
"3": "three",
"4": "four",
"5": "five",
"6": "six",
"7": "seven",
"8": "eight",
"9": "nine",
}
def _expand_numbers(text: str) -> str:
"""Expand standalone single digits to words for clearer pronunciation."""
return re.sub(
r"\b(\d)\b",
lambda m: _DIGIT_WORDS.get(m.group(0), m.group(0)),
text,
)
def _clean_text(text: str) -> str:
"""Clean and normalise text for TTS input."""
text = re.sub(r"[ \t]+", " ", text)
text = "\n".join(line.strip() for line in text.splitlines())
# Strip markdown / code-fence characters
text = re.sub(r"[*#~`|<>{}[\]\\]", "", text)
# Expand common symbols
text = text.replace("&", " and ")
text = text.replace("@", " at ")
text = text.replace("%", " percent ")
text = text.replace("+", " plus ")
text = text.replace("=", " equals ")
text = _expand_numbers(text)
return text.strip()
def _split_sentences(text: str) -> list[str]:
"""Split text into sentences suitable for TTS.
Keeps sentences short for best quality while preserving natural phrasing.
Very long segments are further split on commas / semicolons.
"""
text = _clean_text(text)
if not text:
return []
raw_parts = _SENTENCE_RE.split(text)
sentences: list[str] = []
for part in raw_parts:
part = part.strip()
if not part:
continue
if len(part) > 200:
for sp in re.split(r"(?<=[,;])\s+", part):
sp = sp.strip()
if sp:
sentences.append(sp)
else:
sentences.append(part)
return sentences
# ─── Audio helpers ───────────────────────────────────────────────────────
def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
"""Read WAV audio from bytes, handling scipy wavfile and standard WAV.
Returns (sample_rate, float32_audio) with values in [-1, 1].
"""
buf = io.BytesIO(data)
# Try stdlib wave module first — most robust for PCM WAV from scipy
try:
with wave.open(buf, "rb") as wf:
sr = wf.getframerate()
n_frames = wf.getnframes()
n_channels = wf.getnchannels()
sampwidth = wf.getsampwidth()
raw = wf.readframes(n_frames)
if sampwidth == 2:
audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
elif sampwidth == 4:
audio = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
elif sampwidth == 1:
audio = (
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
) / 128.0
else:
raise ValueError(f"Unsupported sample width: {sampwidth}")
if n_channels > 1:
audio = audio.reshape(-1, n_channels).mean(axis=1)
return sr, audio
except Exception as exc:
logger.debug("wave module failed (%s), trying soundfile", exc)
# Fallback: soundfile (handles FLAC, OGG, etc.)
buf.seek(0)
try:
import soundfile as sf
audio, sr = sf.read(buf, dtype="float32")
if audio.ndim > 1:
audio = audio.mean(axis=1)
return sr, audio
except Exception as exc:
logger.debug("soundfile failed (%s), attempting raw PCM", exc)
# Last resort: raw 16-bit PCM at 22050 Hz
logger.warning(
"Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
len(data),
data[:4],
)
audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
return 22050, audio
def _concat_audio(
chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
) -> tuple[int, np.ndarray]:
"""Concatenate (sample_rate, audio) chunks with silence gaps."""
if not chunks:
return 22050, np.array([], dtype=np.float32)
if len(chunks) == 1:
return chunks[0]
sr = chunks[0][0]
silence = np.zeros(int(sr * pause_ms / 1000), dtype=np.float32)
parts: list[np.ndarray] = []
for sample_rate, audio in chunks:
if sample_rate != sr:
ratio = sr / sample_rate
indices = np.arange(0, len(audio), 1.0 / ratio).astype(int)
indices = indices[indices < len(audio)]
audio = audio[indices]
parts.append(audio)
parts.append(silence)
if parts:
parts.pop() # remove trailing silence
return sr, np.concatenate(parts)
# ─── TTS synthesis ───────────────────────────────────────────────────────
def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
    """Synthesize a single text chunk via the TTS backend.

    Uses the JSON POST endpoint (no URL length limits, supports speed).
    Falls back to the Coqui-compatible GET endpoint if POST fails.
    """
    import base64 as b64

    # Primary path: JSON POST, optionally returning base64-encoded audio.
    try:
        response = client.post(
            TTS_URL,
            json={
                "text": text,
                "language": lang_code,
                "speed": speed,
                "return_base64": True,
            },
        )
        response.raise_for_status()
        content_type = response.headers.get("content-type", "")
        if "application/json" in content_type:
            payload = response.json()
            if "error" in payload:
                raise RuntimeError(payload["error"])
            encoded = payload.get("audio", "")
            if encoded:
                return b64.b64decode(encoded)
        # Non-JSON response — treat as raw audio bytes
        return response.content
    except Exception:
        logger.debug(
            "POST endpoint failed, falling back to GET /api/tts", exc_info=True
        )

    # Fallback: Coqui-compatible GET (no speed control)
    response = client.get(
        f"{TTS_URL}/api/tts",
        params={"text": text, "language_id": lang_code},
    )
    response.raise_for_status()
    return response.content
def synthesize_speech(
text: str, language: str, speed: float
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
"""Synthesize speech from text using the TTS service.
Long text is split into sentences and synthesized individually
for better quality, then concatenated with natural pauses.
"""
if not text.strip():
return "❌ Please enter some text", None, ""
lang_code = LANGUAGES.get(language, "en")
sentences = _split_sentences(text)
if not sentences:
return "❌ No speakable text found after cleaning", None, ""
try:
start_time = time.time()
audio_chunks: list[tuple[int, np.ndarray]] = []
for sentence in sentences:
raw_audio = _synthesize_chunk(sentence, lang_code, speed)
sr, audio = _read_wav_bytes(raw_audio)
audio_chunks.append((sr, audio))
sample_rate, audio_data = _concat_audio(audio_chunks)
latency = time.time() - start_time
duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
n_chunks = len(sentences)
status = (
f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
)
_log_tts_metrics(
latency=latency,
audio_duration=duration,
text_chars=len(text),
language=lang_code,
)
metrics = f"""
**Audio Statistics:**
- Duration: {duration:.2f} seconds
- Sample Rate: {sample_rate} Hz
- Size: {len(audio_data) * 2 / 1024:.1f} KB
- Generation Time: {latency * 1000:.0f}ms
- Real-time Factor: {latency / duration:.2f}x
- Language: {language} ({lang_code})
- Speed: {speed:.1f}x
- Sentences: {n_chunks}
- Characters: {len(text)}
- Chars/sec: {len(text) / latency:.1f}
"""
return status, (sample_rate, audio_data), metrics
except httpx.HTTPStatusError as e:
logger.exception("TTS request failed")
return f"❌ TTS service error: {e.response.status_code}", None, ""
except Exception as e:
logger.exception("TTS synthesis failed")
return f"❌ Error: {e}", None, ""
def check_service_health() -> str:
    """Probe the TTS service and return a human-readable status string."""
    try:
        # Prefer the dedicated health endpoint; fall back to the root path.
        for path in ("/health", "/"):
            response = client.get(f"{TTS_URL}{path}", timeout=5.0)
            if response.status_code == 200:
                if path == "/health":
                    return "🟢 Service is healthy"
                return "🟢 Service is responding"
        return f"🟡 Service returned status {response.status_code}"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"
# ─── Gradio UI ───────────────────────────────────────────────────────────
# NOTE(review): suite indentation was lost in the source capture; the layout
# nesting below is reconstructed from component order and scales — confirm
# against the rendered UI.
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
    gr.Markdown("""
# 🔊 Text-to-Speech Demo
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
in multiple languages. Long text is automatically split into sentences for better quality.
""")

    # Service health probe row.
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)

    with gr.Tabs():
        # ── Tab 1: single-text synthesis ──
        with gr.TabItem("🎤 Text to Speech"):
            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Text to Synthesize",
                        placeholder="Enter text to convert to speech...",
                        lines=5,
                        max_lines=10,
                    )
                    with gr.Row():
                        language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="English",
                            label="Language",
                        )
                        speed = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="Speed",
                        )
                        synthesize_btn = gr.Button(
                            "🔊 Synthesize",
                            variant="primary",
                            scale=2,
                        )
                with gr.Column(scale=1):
                    status_output = gr.Textbox(label="Status", interactive=False)
                    metrics_output = gr.Markdown(label="Metrics")
            audio_output = gr.Audio(label="Generated Audio", type="numpy")
            synthesize_btn.click(
                fn=synthesize_speech,
                inputs=[text_input, language, speed],
                outputs=[status_output, audio_output, metrics_output],
            )
            gr.Examples(
                examples=[
                    [
                        "Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
                        "English",
                        1.0,
                    ],
                    [
                        "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
                        "English",
                        1.0,
                    ],
                    [
                        "Bonjour! Bienvenue au laboratoire technique de Davies.",
                        "French",
                        1.0,
                    ],
                    ["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
                    ["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
                ],
                inputs=[text_input, language, speed],
            )

        # ── Tab 2: side-by-side language comparison ──
        with gr.TabItem("🔄 Language Comparison"):
            gr.Markdown("Compare the same text in different languages.")
            compare_text = gr.Textbox(
                label="Text to Compare", value="Hello, how are you today?", lines=2
            )
            with gr.Row():
                lang1 = gr.Dropdown(
                    choices=list(LANGUAGES.keys()), value="English", label="Language 1"
                )
                lang2 = gr.Dropdown(
                    choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
                )
                compare_speed = gr.Slider(
                    minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
                )
            compare_btn = gr.Button("Compare Languages", variant="primary")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Language 1")
                    audio1 = gr.Audio(label="Audio 1", type="numpy")
                    status1 = gr.Textbox(label="Status", interactive=False)
                with gr.Column():
                    gr.Markdown("### Language 2")
                    audio2 = gr.Audio(label="Audio 2", type="numpy")
                    status2 = gr.Textbox(label="Status", interactive=False)

            def compare_languages(text, l1, l2, spd):
                """Synthesize the same text in two languages for A/B listening."""
                s1, a1, _ = synthesize_speech(text, l1, spd)
                s2, a2, _ = synthesize_speech(text, l2, spd)
                return s1, a1, s2, a2

            compare_btn.click(
                fn=compare_languages,
                inputs=[compare_text, lang1, lang2, compare_speed],
                outputs=[status1, audio1, status2, audio2],
            )

        # ── Tab 3: batch synthesis of multiple lines ──
        with gr.TabItem("📚 Batch Synthesis"):
            gr.Markdown("Synthesize multiple texts at once (one per line).")
            batch_input = gr.Textbox(
                label="Texts (one per line)",
                placeholder="Enter multiple texts, one per line...",
                lines=6,
            )
            batch_lang = gr.Dropdown(
                choices=list(LANGUAGES.keys()), value="English", label="Language"
            )
            batch_speed = gr.Slider(
                minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
            )
            batch_btn = gr.Button("Synthesize All", variant="primary")
            batch_status = gr.Textbox(label="Status", interactive=False)
            batch_audio = gr.Audio(label="Combined Audio", type="numpy")

            def batch_synthesize(texts_raw: str, lang: str, spd: float):
                """Join the non-empty lines and synthesize them in one request."""
                lines = [
                    line.strip()
                    for line in texts_raw.strip().splitlines()
                    if line.strip()
                ]
                if not lines:
                    return "❌ Please enter at least one line of text", None
                combined = "\n".join(lines)
                status, audio, _ = synthesize_speech(combined, lang, spd)
                return status, audio

            batch_btn.click(
                fn=batch_synthesize,
                inputs=[batch_input, batch_lang, batch_speed],
                outputs=[batch_status, batch_audio],
            )

    create_footer()

if __name__ == "__main__":
    # Bind on all interfaces so the pod's service can route to the UI.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)