Each UI now logs per-request metrics to MLflow:
- llm.py: latency, tokens/sec, prompt/completion tokens (gradio-llm-tuning)
- embeddings.py: latency, text length, batch size (gradio-embeddings-tuning)
- stt.py: latency, audio duration, real-time factor (gradio-stt-tuning)
- tts.py: latency, text length, audio duration (gradio-tts-tuning)
MLflow imports and setup are guarded with try/except, so the UIs keep working when MLflow is unreachable. Each Gradio instance keeps one persistent run, and metrics are logged in batches via MlflowClient.log_batch().
342 lines
12 KiB
Python
342 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
TTS Demo - Gradio UI for testing Text-to-Speech service.
|
|
|
|
Features:
|
|
- Text input with language selection
|
|
- Audio playback of synthesized speech
|
|
- Voice/speaker selection (when available)
|
|
- MLflow metrics logging
|
|
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
|
"""
|
|
import os
|
|
import time
|
|
import logging
|
|
import io
|
|
import base64
|
|
|
|
import gradio as gr
|
|
import httpx
|
|
import soundfile as sf
|
|
import numpy as np
|
|
|
|
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
|
|
|
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tts-demo")

# Configuration: each endpoint can be overridden by an environment variable of
# the same name; the defaults point at in-cluster service addresses.
TTS_URL = os.getenv(
    "TTS_URL",
    # Default: Ray Serve TTS endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
)
MLFLOW_TRACKING_URI = os.getenv(
    "MLFLOW_TRACKING_URI",
    "http://mlflow.mlflow.svc.cluster.local:80",
)
|
|
|
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: if the mlflow package is missing or the tracking server
# cannot be reached, tracking is switched off and the UI keeps working.
# A single run is opened for this process and reused for every request.
_mlflow_step = 0

try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Reuse the experiment if it already exists; otherwise create it with an
    # explicit artifact location.
    _experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
    if _experiment is not None:
        _experiment_id = _experiment.experiment_id
    else:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-tts-tuning",
            artifact_location="/mlflow/artifacts/gradio-tts-tuning",
        )

    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        # HOSTNAME presumably identifies the pod/container; "local" otherwise.
        run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-tts", "endpoint": TTS_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    MLFLOW_ENABLED = False
else:
    MLFLOW_ENABLED = True
    logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
|
|
|
|
|
|
def _log_tts_metrics(
    latency: float, audio_duration: float, text_chars: int, language: str,
) -> None:
    """Record one TTS request as a batch of MLflow metrics (best-effort).

    The metrics are appended to the process-wide run at the next step number.
    The MLflow call is synchronous, but any error it raises is logged at DEBUG
    level and swallowed so a tracking outage never breaks synthesis. Nothing
    happens when tracking is disabled.

    Args:
        latency: Request latency in seconds (logged as ``latency_s``).
        audio_duration: Generated audio length in seconds.
        text_chars: Number of characters that were synthesized.
        language: Language code of the request (accepted but not logged).
    """
    global _mlflow_step
    tracking_ready = MLFLOW_ENABLED and _mlflow_client is not None
    if not tracking_ready:
        return
    try:
        _mlflow_step += 1
        step = _mlflow_step
        now_ms = int(time.time() * 1000)
        # Both ratios fall back to 0 instead of dividing by zero.
        realtime_factor = latency / audio_duration if audio_duration > 0 else 0
        chars_per_second = text_chars / latency if latency > 0 else 0
        readings = (
            ("latency_s", latency),
            ("audio_duration_s", audio_duration),
            ("realtime_factor", realtime_factor),
            ("chars_per_second", chars_per_second),
            ("text_chars", text_chars),
        )
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[mlflow.entities.Metric(key, value, now_ms, step) for key, value in readings],
        )
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
|
|
|
|
|
|
# Shared HTTP client for the TTS service. Speech generation can be slow, so
# the client-wide timeout is a generous 120 seconds; individual requests may
# still pass their own shorter timeout.
client = httpx.Client(timeout=httpx.Timeout(120.0))
|
|
|
|
# Supported languages for XTTS: display name -> language code sent to the
# service as ``language_id``. Insertion order is the order the dropdowns show.
LANGUAGES = dict(
    English="en",
    Spanish="es",
    French="fr",
    German="de",
    Italian="it",
    Portuguese="pt",
    Polish="pl",
    Turkish="tr",
    Russian="ru",
    Dutch="nl",
    Czech="cs",
    Arabic="ar",
    Chinese="zh-cn",
    Japanese="ja",
    Korean="ko",
    Hungarian="hu",
)
|
|
|
|
|
|
def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
|
"""Synthesize speech from text using the TTS service."""
|
|
if not text.strip():
|
|
return "❌ Please enter some text", None, ""
|
|
|
|
lang_code = LANGUAGES.get(language, "en")
|
|
|
|
try:
|
|
start_time = time.time()
|
|
|
|
# Call TTS service (Coqui XTTS API format)
|
|
response = client.get(
|
|
f"{TTS_URL}/api/tts",
|
|
params={"text": text, "language_id": lang_code}
|
|
)
|
|
response.raise_for_status()
|
|
|
|
latency = time.time() - start_time
|
|
audio_bytes = response.content
|
|
|
|
# Parse audio data
|
|
audio_io = io.BytesIO(audio_bytes)
|
|
audio_data, sample_rate = sf.read(audio_io)
|
|
|
|
# Calculate duration
|
|
if len(audio_data.shape) == 1:
|
|
duration = len(audio_data) / sample_rate
|
|
else:
|
|
duration = len(audio_data) / sample_rate
|
|
|
|
# Status message
|
|
status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms"
|
|
|
|
# Log to MLflow
|
|
_log_tts_metrics(
|
|
latency=latency,
|
|
audio_duration=duration,
|
|
text_chars=len(text),
|
|
language=lang_code,
|
|
)
|
|
|
|
# Metrics
|
|
metrics = f"""
|
|
**Audio Statistics:**
|
|
- Duration: {duration:.2f} seconds
|
|
- Sample Rate: {sample_rate} Hz
|
|
- Size: {len(audio_bytes) / 1024:.1f} KB
|
|
- Generation Time: {latency*1000:.0f}ms
|
|
- Real-time Factor: {latency/duration:.2f}x
|
|
- Language: {language} ({lang_code})
|
|
- Characters: {len(text)}
|
|
- Chars/sec: {len(text)/latency:.1f}
|
|
"""
|
|
|
|
return status, (sample_rate, audio_data), metrics
|
|
|
|
except httpx.HTTPStatusError as e:
|
|
logger.exception("TTS request failed")
|
|
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
|
except Exception as e:
|
|
logger.exception("TTS synthesis failed")
|
|
return f"❌ Error: {str(e)}", None, ""
|
|
|
|
|
|
def check_service_health() -> str:
    """Probe the TTS service and describe its state for the status box.

    ``{TTS_URL}/health`` is tried first, then ``{TTS_URL}/``, each with a
    5-second timeout; the first HTTP 200 wins. Otherwise the status code of
    the last probe is reported, and any exception is reported as unavailable.
    """
    probes = (
        ("/health", "🟢 Service is healthy"),
        # Fall back to root endpoint in case /health is not exposed.
        ("/", "🟢 Service is responding"),
    )
    try:
        for suffix, healthy_message in probes:
            reply = client.get(f"{TTS_URL}{suffix}", timeout=5.0)
            if reply.status_code == 200:
                return healthy_message
        return f"🟡 Service returned status {reply.status_code}"
    except Exception as e:
        return f"🔴 Service unavailable: {str(e)}"
|
|
|
|
|
|
# Build the Gradio app

# Silence inserted between consecutive clips in batch mode, in seconds.
_BATCH_GAP_S = 0.3


def _batch_synthesize(texts: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None]:
    """Synthesize every non-empty line of ``texts`` and join the clips.

    Each line is sent through ``synthesize_speech`` on its own. Successful
    clips are concatenated, separated by ``_BATCH_GAP_S`` seconds of silence,
    so the whole batch plays back from a single ``gr.Audio`` component.

    Returns:
        ``(status, audio)``. ``status`` starts with a summary line, followed
        by one line per input. ``audio`` is ``(sample_rate, samples)``, or
        ``None`` when no line could be synthesized.
    """
    lines = [line.strip() for line in (texts or "").splitlines() if line.strip()]
    if not lines:
        return "❌ Please enter some text", None

    clips: list[tuple[int, np.ndarray]] = []
    report: list[str] = []
    for number, line in enumerate(lines, start=1):
        status, audio, _ = synthesize_speech(line, language)
        report.append(f"{number}. {status}")
        if audio is not None:
            clips.append(audio)

    if not clips:
        return "\n".join(report), None

    sample_rate, first = clips[0]
    # Clips can only be joined when they share sample rate and channel layout.
    joinable = all(
        rate == sample_rate and samples.shape[1:] == first.shape[1:]
        for rate, samples in clips
    )
    if joinable:
        gap = np.zeros((int(sample_rate * _BATCH_GAP_S), *first.shape[1:]), dtype=first.dtype)
        pieces: list[np.ndarray] = []
        for _, samples in clips:
            if pieces:
                pieces.append(gap)
            pieces.append(samples)
        combined = np.concatenate(pieces)
        summary = f"✅ Synthesized {len(clips)}/{len(lines)} line(s)"
    else:
        combined = first
        summary = (
            f"⚠️ Synthesized {len(clips)}/{len(lines)} line(s); clips use different "
            "audio formats, so only the first one is shown"
        )
    return "\n".join([summary, *report]), (sample_rate, combined)


with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
    gr.Markdown("""
# 🔊 Text-to-Speech Demo

Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
in multiple languages.
""")

    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)

    health_btn.click(fn=check_service_health, outputs=health_status)

    with gr.Tabs():
        # Tab 1: Basic TTS
        with gr.TabItem("🎤 Text to Speech"):
            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Text to Synthesize",
                        placeholder="Enter text to convert to speech...",
                        lines=5,
                        max_lines=10,
                    )

                    with gr.Row():
                        language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="English",
                            label="Language",
                        )
                        synthesize_btn = gr.Button("🔊 Synthesize", variant="primary", scale=2)

                with gr.Column(scale=1):
                    status_output = gr.Textbox(label="Status", interactive=False)
                    metrics_output = gr.Markdown(label="Metrics")

            audio_output = gr.Audio(label="Generated Audio", type="numpy")

            synthesize_btn.click(
                fn=synthesize_speech,
                inputs=[text_input, language],
                outputs=[status_output, audio_output, metrics_output],
            )

            # Example texts
            gr.Examples(
                examples=[
                    ["Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "English"],
                    ["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "English"],
                    ["Bonjour! Bienvenue au laboratoire technique de Davies.", "French"],
                    ["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
                    ["Guten Tag! Willkommen im Techniklabor.", "German"],
                ],
                inputs=[text_input, language],
            )

        # Tab 2: Comparison
        with gr.TabItem("🔄 Language Comparison"):
            gr.Markdown("Compare the same text in different languages.")

            compare_text = gr.Textbox(
                label="Text to Compare",
                value="Hello, how are you today?",
                lines=2,
            )

            with gr.Row():
                lang1 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Language 1")
                lang2 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2")

            compare_btn = gr.Button("Compare Languages", variant="primary")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Language 1")
                    audio1 = gr.Audio(label="Audio 1", type="numpy")
                    status1 = gr.Textbox(label="Status", interactive=False)

                with gr.Column():
                    gr.Markdown("### Language 2")
                    audio2 = gr.Audio(label="Audio 2", type="numpy")
                    status2 = gr.Textbox(label="Status", interactive=False)

            def compare_languages(text, l1, l2):
                """Synthesize ``text`` in two languages (sequentially); metrics are dropped."""
                s1, a1, _ = synthesize_speech(text, l1)
                s2, a2, _ = synthesize_speech(text, l2)
                return s1, a1, s2, a2

            compare_btn.click(
                fn=compare_languages,
                inputs=[compare_text, lang1, lang2],
                outputs=[status1, audio1, status2, audio2],
            )

        # Tab 3: Batch Processing
        with gr.TabItem("📚 Batch Synthesis"):
            gr.Markdown("Synthesize multiple texts at once (one per line).")

            batch_input = gr.Textbox(
                label="Texts (one per line)",
                placeholder="Enter multiple texts, one per line...",
                lines=6,
            )
            batch_lang = gr.Dropdown(
                choices=list(LANGUAGES.keys()),
                value="English",
                label="Language",
            )
            batch_btn = gr.Button("Synthesize All", variant="primary")

            batch_status = gr.Textbox(label="Status", interactive=False)
            # A single player for the concatenated result: the previous
            # gr.Dataset could not display freshly generated numpy audio and
            # the button was never wired to a handler.
            batch_audios = gr.Audio(label="Generated Audio (all lines)", type="numpy")

            batch_btn.click(
                fn=_batch_synthesize,
                inputs=[batch_input, batch_lang],
                outputs=[batch_status, batch_audios],
            )

            # Lines are synthesized one after another in a single request,
            # so large batches are slow in the UI.
            gr.Markdown("""
*Note: For batch processing of many texts, consider using the API directly
or the Kubeflow pipeline for better throughput.*
""")

    create_footer()
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to every interface (not only localhost) on port 7860 and show
    # handler errors in the browser UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)
|