#!/usr/bin/env python3
"""
TTS Demo - Gradio UI for testing Text-to-Speech service.

Features:
- Text input with language selection
- Audio playback of synthesized speech
- Sentence-level chunking for better quality
- Speed control
- MLflow metrics logging
"""

import io
import logging
import os
import re
import time
import wave

import gradio as gr
import httpx
import numpy as np

from theme import get_lab_theme, CUSTOM_CSS, create_footer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tts-demo")

# Configuration
TTS_URL = os.environ.get(
    "TTS_URL",
    # Default: Ray Serve TTS endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
)
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)

# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort: if MLflow is unreachable or not installed, the demo still
# works — metric logging is simply disabled.
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()
    _experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-tts-tuning",
            artifact_location="/mlflow/artifacts/gradio-tts-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-tts", "endpoint": TTS_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info(
        "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
    )
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False


def _log_tts_metrics(
    latency: float,
    audio_duration: float,
    text_chars: int,
    language: str,
) -> None:
    """Log TTS inference metrics to MLflow (non-blocking best-effort).

    Args:
        latency: Wall-clock synthesis time in seconds.
        audio_duration: Duration of the generated audio in seconds.
        text_chars: Number of characters in the source text.
        language: Language code of the request.
            NOTE: accepted for interface stability but not logged —
            MLflow metrics are numeric-only.

    Never raises: any logging failure is swallowed and debug-logged.
    """
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        ts = int(time.time() * 1000)
        # Guard derived ratios against zero-length audio / instant returns.
        rtf = latency / audio_duration if audio_duration > 0 else 0
        cps = text_chars / latency if latency > 0 else 0
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
                mlflow.entities.Metric(
                    "audio_duration_s", audio_duration, ts, _mlflow_step
                ),
                mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
                mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
                mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
            ],
        )
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)


# HTTP client with longer timeout for audio generation
client = httpx.Client(timeout=120.0)

# Supported languages for XTTS
LANGUAGES = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt",
    "Polish": "pl",
    "Turkish": "tr",
    "Russian": "ru",
    "Dutch": "nl",
    "Czech": "cs",
    "Arabic": "ar",
    "Chinese": "zh-cn",
    "Japanese": "ja",
    "Korean": "ko",
    "Hungarian": "hu",
}

# ─── Text preprocessing ──────────────────────────────────────────────────

# Split after sentence-final punctuation or at newlines.
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)

_DIGIT_WORDS = {
    "0": "zero",
    "1": "one",
    "2": "two",
    "3": "three",
    "4": "four",
    "5": "five",
    "6": "six",
    "7": "seven",
    "8": "eight",
    "9": "nine",
}


def _expand_numbers(text: str) -> str:
    """Expand standalone single digits to words for clearer pronunciation."""
    return re.sub(
        r"\b(\d)\b",
        lambda m: _DIGIT_WORDS.get(m.group(0), m.group(0)),
        text,
    )


def _clean_text(text: str) -> str:
    """Clean and normalise text for TTS input.

    Collapses whitespace, strips markdown punctuation, expands common
    symbols and standalone digits into words.
    """
    text = re.sub(r"[ \t]+", " ", text)
    text = "\n".join(line.strip() for line in text.splitlines())
    # Strip markdown / code-fence characters
    text = re.sub(r"[*#~`|<>{}[\]\\]", "", text)
    # Expand common symbols
    text = text.replace("&", " and ")
    text = text.replace("@", " at ")
    text = text.replace("%", " percent ")
    text = text.replace("+", " plus ")
    text = text.replace("=", " equals ")
    text = _expand_numbers(text)
    return text.strip()


def _split_sentences(text: str) -> list[str]:
    """Split text into sentences suitable for TTS.

    Keeps sentences short for best quality while preserving natural
    phrasing. Very long segments are further split on commas / semicolons.
    """
    text = _clean_text(text)
    if not text:
        return []
    raw_parts = _SENTENCE_RE.split(text)
    sentences: list[str] = []
    for part in raw_parts:
        part = part.strip()
        if not part:
            continue
        if len(part) > 200:
            # Over-long segment: break on clause punctuation as well.
            for sp in re.split(r"(?<=[,;])\s+", part):
                sp = sp.strip()
                if sp:
                    sentences.append(sp)
        else:
            sentences.append(part)
    return sentences


# ─── Audio helpers ───────────────────────────────────────────────────────


def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
    """Read WAV audio from bytes, handling scipy wavfile and standard WAV.

    Returns (sample_rate, float32_audio) with values in [-1, 1].
    Multi-channel input is downmixed to mono by averaging channels.
    """
    buf = io.BytesIO(data)
    # Try stdlib wave module first — most robust for PCM WAV from scipy
    try:
        with wave.open(buf, "rb") as wf:
            sr = wf.getframerate()
            n_frames = wf.getnframes()
            n_channels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            raw = wf.readframes(n_frames)
        if sampwidth == 2:
            audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
        elif sampwidth == 4:
            audio = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
        elif sampwidth == 1:
            # 8-bit WAV is unsigned, centred on 128.
            audio = (
                np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
            ) / 128.0
        else:
            raise ValueError(f"Unsupported sample width: {sampwidth}")
        if n_channels > 1:
            audio = audio.reshape(-1, n_channels).mean(axis=1)
        return sr, audio
    except Exception as exc:
        logger.debug("wave module failed (%s), trying soundfile", exc)

    # Fallback: soundfile (handles FLAC, OGG, etc.)
    buf.seek(0)
    try:
        import soundfile as sf

        audio, sr = sf.read(buf, dtype="float32")
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        return sr, audio
    except Exception as exc:
        logger.debug("soundfile failed (%s), attempting raw PCM", exc)

    # Last resort: raw 16-bit PCM at 22050 Hz
    logger.warning(
        "Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
        len(data),
        data[:4],
    )
    audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
    return 22050, audio


def _concat_audio(
    chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
) -> tuple[int, np.ndarray]:
    """Concatenate (sample_rate, audio) chunks with silence gaps.

    All chunks are coerced to the first chunk's sample rate via crude
    nearest-neighbour resampling (adequate for short TTS segments).
    """
    if not chunks:
        return 22050, np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]
    sr = chunks[0][0]
    silence = np.zeros(int(sr * pause_ms / 1000), dtype=np.float32)
    parts: list[np.ndarray] = []
    for sample_rate, audio in chunks:
        if sample_rate != sr:
            # Nearest-neighbour resample to the reference rate.
            ratio = sr / sample_rate
            indices = np.arange(0, len(audio), 1.0 / ratio).astype(int)
            indices = indices[indices < len(audio)]
            audio = audio[indices]
        parts.append(audio)
        parts.append(silence)
    if parts:
        parts.pop()  # remove trailing silence
    return sr, np.concatenate(parts)


# ─── TTS synthesis ───────────────────────────────────────────────────────


def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
    """Synthesize a single text chunk via the TTS backend.

    Uses the JSON POST endpoint (no URL length limits, supports speed).
    Falls back to the Coqui-compatible GET endpoint if POST fails.
    """
    import base64 as b64

    # Try JSON POST first
    try:
        resp = client.post(
            TTS_URL,
            json={
                "text": text,
                "language": lang_code,
                "speed": speed,
                "return_base64": True,
            },
        )
        resp.raise_for_status()
        ct = resp.headers.get("content-type", "")
        if "application/json" in ct:
            body = resp.json()
            if "error" in body:
                raise RuntimeError(body["error"])
            audio_b64 = body.get("audio", "")
            if audio_b64:
                return b64.b64decode(audio_b64)
        # Non-JSON response — treat as raw audio bytes
        return resp.content
    except Exception:
        logger.debug(
            "POST endpoint failed, falling back to GET /api/tts", exc_info=True
        )

    # Fallback: Coqui-compatible GET (no speed control)
    resp = client.get(
        f"{TTS_URL}/api/tts",
        params={"text": text, "language_id": lang_code},
    )
    resp.raise_for_status()
    return resp.content


def synthesize_speech(
    text: str, language: str, speed: float
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
    """Synthesize speech from text using the TTS service.

    Long text is split into sentences and synthesized individually for
    better quality, then concatenated with natural pauses.

    Returns:
        (status message, (sample_rate, audio) or None, metrics markdown)
    """
    if not text.strip():
        return "❌ Please enter some text", None, ""
    lang_code = LANGUAGES.get(language, "en")
    sentences = _split_sentences(text)
    if not sentences:
        return "❌ No speakable text found after cleaning", None, ""
    try:
        start_time = time.time()
        audio_chunks: list[tuple[int, np.ndarray]] = []
        for sentence in sentences:
            raw_audio = _synthesize_chunk(sentence, lang_code, speed)
            sr, audio = _read_wav_bytes(raw_audio)
            audio_chunks.append((sr, audio))
        sample_rate, audio_data = _concat_audio(audio_chunks)
        latency = time.time() - start_time
        duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
        n_chunks = len(sentences)
        status = (
            f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
            f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
        )
        _log_tts_metrics(
            latency=latency,
            audio_duration=duration,
            text_chars=len(text),
            language=lang_code,
        )
        # Guard derived ratios: duration is 0 for empty audio and latency can
        # round to 0 — an unguarded division would abort with a ZeroDivisionError.
        rtf = latency / duration if duration > 0 else 0
        cps = len(text) / latency if latency > 0 else 0
        metrics = f"""
**Audio Statistics:**
- Duration: {duration:.2f} seconds
- Sample Rate: {sample_rate} Hz
- Size: {len(audio_data) * 2 / 1024:.1f} KB
- Generation Time: {latency * 1000:.0f}ms
- Real-time Factor: {rtf:.2f}x
- Language: {language} ({lang_code})
- Speed: {speed:.1f}x
- Sentences: {n_chunks}
- Characters: {len(text)}
- Chars/sec: {cps:.1f}
"""
        return status, (sample_rate, audio_data), metrics
    except httpx.HTTPStatusError as e:
        logger.exception("TTS request failed")
        return f"❌ TTS service error: {e.response.status_code}", None, ""
    except Exception as e:
        logger.exception("TTS synthesis failed")
        return f"❌ Error: {e}", None, ""


def check_service_health() -> str:
    """Check if the TTS service is healthy."""
    try:
        response = client.get(f"{TTS_URL}/health", timeout=5.0)
        if response.status_code == 200:
            return "🟢 Service is healthy"
        response = client.get(f"{TTS_URL}/", timeout=5.0)
        if response.status_code == 200:
            return "🟢 Service is responding"
        return f"🟡 Service returned status {response.status_code}"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"


# ─── Gradio UI ───────────────────────────────────────────────────────────

with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
    gr.Markdown("""
    # 🔊 Text-to-Speech Demo

    Test the **Coqui XTTS** text-to-speech service.
    Convert text to natural-sounding speech in multiple languages.
    Long text is automatically split into sentences for better quality.
    """)

    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)

    with gr.Tabs():
        with gr.TabItem("🎤 Text to Speech"):
            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Text to Synthesize",
                        placeholder="Enter text to convert to speech...",
                        lines=5,
                        max_lines=10,
                    )
                    with gr.Row():
                        language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="English",
                            label="Language",
                        )
                        speed = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            label="Speed",
                        )
                        synthesize_btn = gr.Button(
                            "🔊 Synthesize",
                            variant="primary",
                            scale=2,
                        )
                with gr.Column(scale=1):
                    status_output = gr.Textbox(label="Status", interactive=False)
                    metrics_output = gr.Markdown(label="Metrics")
            audio_output = gr.Audio(label="Generated Audio", type="numpy")

            synthesize_btn.click(
                fn=synthesize_speech,
                inputs=[text_input, language, speed],
                outputs=[status_output, audio_output, metrics_output],
            )

            gr.Examples(
                examples=[
                    [
                        "Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
                        "English",
                        1.0,
                    ],
                    [
                        "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
                        "English",
                        1.0,
                    ],
                    [
                        "Bonjour! Bienvenue au laboratoire technique de Davies.",
                        "French",
                        1.0,
                    ],
                    ["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
                    ["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
                ],
                inputs=[text_input, language, speed],
            )

        with gr.TabItem("🔄 Language Comparison"):
            gr.Markdown("Compare the same text in different languages.")
            compare_text = gr.Textbox(
                label="Text to Compare", value="Hello, how are you today?", lines=2
            )
            with gr.Row():
                lang1 = gr.Dropdown(
                    choices=list(LANGUAGES.keys()), value="English", label="Language 1"
                )
                lang2 = gr.Dropdown(
                    choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
                )
            compare_speed = gr.Slider(
                minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
            )
            compare_btn = gr.Button("Compare Languages", variant="primary")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Language 1")
                    audio1 = gr.Audio(label="Audio 1", type="numpy")
                    status1 = gr.Textbox(label="Status", interactive=False)
                with gr.Column():
                    gr.Markdown("### Language 2")
                    audio2 = gr.Audio(label="Audio 2", type="numpy")
                    status2 = gr.Textbox(label="Status", interactive=False)

            def compare_languages(text, l1, l2, spd):
                # Run the two languages sequentially; metrics markdown is dropped.
                s1, a1, _ = synthesize_speech(text, l1, spd)
                s2, a2, _ = synthesize_speech(text, l2, spd)
                return s1, a1, s2, a2

            compare_btn.click(
                fn=compare_languages,
                inputs=[compare_text, lang1, lang2, compare_speed],
                outputs=[status1, audio1, status2, audio2],
            )

        with gr.TabItem("📚 Batch Synthesis"):
            gr.Markdown("Synthesize multiple texts at once (one per line).")
            batch_input = gr.Textbox(
                label="Texts (one per line)",
                placeholder="Enter multiple texts, one per line...",
                lines=6,
            )
            batch_lang = gr.Dropdown(
                choices=list(LANGUAGES.keys()), value="English", label="Language"
            )
            batch_speed = gr.Slider(
                minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
            )
            batch_btn = gr.Button("Synthesize All", variant="primary")
            batch_status = gr.Textbox(label="Status", interactive=False)
            batch_audio = gr.Audio(label="Combined Audio", type="numpy")

            def batch_synthesize(texts_raw: str, lang: str, spd: float):
                # Drop blank lines, then feed all lines as one newline-joined
                # text — the sentence splitter breaks on newlines anyway.
                lines = [
                    line.strip()
                    for line in texts_raw.strip().splitlines()
                    if line.strip()
                ]
                if not lines:
                    return "❌ Please enter at least one line of text", None
                combined = "\n".join(lines)
                status, audio, _ = synthesize_speech(combined, lang, spd)
                return status, audio

            batch_btn.click(
                fn=batch_synthesize,
                inputs=[batch_input, batch_lang, batch_speed],
                outputs=[batch_status, batch_audio],
            )

    create_footer()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)