#!/usr/bin/env python3
"""
STT Demo - Gradio UI for testing Speech-to-Text (Whisper) service.

Features:
- Microphone recording input
- Audio file upload support
- Multiple language support
- Translation mode
- MLflow metrics logging
"""

import os
import time
import logging
import io

import gradio as gr
import httpx
import soundfile as sf
import numpy as np

from theme import get_lab_theme, CUSTOM_CSS, create_footer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("stt-demo")

# Configuration
STT_URL = os.environ.get(
    "STT_URL",
    # Default: Ray Serve whisper endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper",
)
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)

# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort: if MLflow is unreachable or not installed, the demo still
# runs with metric logging disabled (MLFLOW_ENABLED stays False).
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()
    # Reuse the experiment if it already exists; otherwise create it.
    _experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-stt-tuning",
            artifact_location="/mlflow/artifacts/gradio-stt-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id
    # One long-lived run per process; metrics are logged as steps within it.
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-stt", "endpoint": STT_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info(
        "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
    )
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False


def _log_stt_metrics(
    latency: float,
    audio_duration: float,
    word_count: int,
    task: str,
) -> None:
    """Log STT inference metrics to MLflow (non-blocking best-effort).

    Args:
        latency: End-to-end request latency in seconds.
        audio_duration: Duration of the submitted audio in seconds.
        word_count: Number of whitespace-separated words in the transcript.
        task: The task label ("Transcribe" or "Translate to English");
            logged as a run parameter on the first call only.

    Never raises: any MLflow failure is swallowed and logged at DEBUG.
    """
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        ts = int(time.time() * 1000)
        # Real-time factor: processing time per second of audio (0 if unknown).
        rtf = latency / audio_duration if audio_duration > 0 else 0
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
                mlflow.entities.Metric(
                    "audio_duration_s", audio_duration, ts, _mlflow_step
                ),
                mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
                mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
            ],
            # Params are immutable in MLflow, so only log "task" once per run.
            params=[] if _mlflow_step > 1 else [
                mlflow.entities.Param("task", task),
            ],
        )
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)


# HTTP client with longer timeout for transcription
client = httpx.Client(timeout=180.0)

# Whisper supported languages (display name -> ISO 639-1 code; None = auto)
LANGUAGES = {
    "Auto-detect": None,
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt",
    "Dutch": "nl",
    "Russian": "ru",
    "Chinese": "zh",
    "Japanese": "ja",
    "Korean": "ko",
    "Arabic": "ar",
    "Hindi": "hi",
    "Turkish": "tr",
    "Polish": "pl",
    "Ukrainian": "uk",
}


def transcribe_audio(
    audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
) -> tuple[str, str, str]:
    """Transcribe audio using the Whisper STT service.

    Args:
        audio_input: Either a Gradio microphone tuple ``(sample_rate,
            audio_data)``, a filesystem path to an audio file, or None.
        language: Display name from LANGUAGES ("Auto-detect" omits the hint).
        task: "Translate to English" hits the translations endpoint;
            anything else hits the transcriptions endpoint.

    Returns:
        (status_message, transcribed_text, metrics_markdown). On failure the
        status carries the error and the other two values are empty strings.
    """
    if audio_input is None:
        return "❌ Please provide audio input", "", ""

    try:
        start_time = time.time()

        # Handle different input types
        if isinstance(audio_input, tuple):
            # Microphone input: (sample_rate, audio_data)
            sample_rate, audio_data = audio_input
            # Convert to WAV bytes
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
            audio_bytes = audio_buffer.getvalue()
            audio_duration = len(audio_data) / sample_rate
        else:
            # File path
            with open(audio_input, "rb") as f:
                audio_bytes = f.read()
            # Get duration
            audio_data, sample_rate = sf.read(audio_input)
            audio_duration = len(audio_data) / sample_rate

        # BUGFIX: a zero-length recording previously caused a
        # ZeroDivisionError when formatting the metrics (latency /
        # audio_duration) and surfaced as an opaque generic error.
        if audio_duration <= 0:
            return "❌ Audio is empty - please record or upload a longer clip", "", ""

        # Prepare request
        lang_code = LANGUAGES.get(language)
        files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
        data = {"response_format": "json"}
        if lang_code:
            data["language"] = lang_code

        # Choose endpoint based on task
        if task == "Translate to English":
            endpoint = f"{STT_URL}/v1/audio/translations"
        else:
            endpoint = f"{STT_URL}/v1/audio/transcriptions"

        # Send request
        response = client.post(endpoint, files=files, data=data)
        response.raise_for_status()

        latency = time.time() - start_time
        result = response.json()
        text = result.get("text", "")
        detected_language = result.get("language", "unknown")

        # Log to MLflow
        _log_stt_metrics(
            latency=latency,
            audio_duration=audio_duration,
            word_count=len(text.split()),
            task=task,
        )

        # Status message
        status = (
            f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
        )

        # Metrics (audio_duration > 0 is guaranteed by the guard above)
        metrics = f"""
**Transcription Statistics:**
- Audio Duration: {audio_duration:.2f} seconds
- Processing Time: {latency * 1000:.0f}ms
- Real-time Factor: {latency / audio_duration:.2f}x
- Detected Language: {detected_language}
- Task: {task}
- Word Count: {len(text.split())}
- Character Count: {len(text)}
"""

        return status, text, metrics

    except httpx.HTTPStatusError as e:
        logger.exception("STT request failed")
        return f"❌ STT service error: {e.response.status_code}", "", ""
    except Exception as e:
        logger.exception("Transcription failed")
        return f"❌ Error: {str(e)}", "", ""


def check_service_health() -> str:
    """Check if the STT service is healthy.

    Probes /health first, then falls back to the OpenAI-compatible
    /v1/models endpoint. Returns a human-readable status string.
    """
    try:
        response = client.get(f"{STT_URL}/health", timeout=5.0)
        if response.status_code == 200:
            return "🟢 Service is healthy"
        # Try v1/models endpoint (OpenAI-compatible)
        response = client.get(f"{STT_URL}/v1/models", timeout=5.0)
        if response.status_code == 200:
            return "🟢 Service is healthy"
        return f"🟡 Service returned status {response.status_code}"
    except Exception as e:
        return f"🔴 Service unavailable: {str(e)}"


# Build the Gradio app
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
    gr.Markdown("""
    # 🎙️ Speech-to-Text Demo

    Test the **Whisper** speech-to-text service.
    Transcribe audio from microphone or file upload with support for 100+ languages.
    """)

    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)

    with gr.Tabs():
        # Tab 1: Microphone Input
        with gr.TabItem("🎤 Microphone"):
            with gr.Row():
                with gr.Column():
                    mic_input = gr.Audio(
                        label="Record Audio", sources=["microphone"], type="numpy"
                    )
                    with gr.Row():
                        mic_language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="Auto-detect",
                            label="Language",
                        )
                        mic_task = gr.Radio(
                            choices=["Transcribe", "Translate to English"],
                            value="Transcribe",
                            label="Task",
                        )
                    mic_btn = gr.Button("🎯 Transcribe", variant="primary")
                with gr.Column():
                    mic_status = gr.Textbox(label="Status", interactive=False)
                    mic_metrics = gr.Markdown(label="Metrics")
            mic_output = gr.Textbox(label="Transcription", lines=5)

            mic_btn.click(
                fn=transcribe_audio,
                inputs=[mic_input, mic_language, mic_task],
                outputs=[mic_status, mic_output, mic_metrics],
            )

        # Tab 2: File Upload
        with gr.TabItem("📁 File Upload"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.Audio(
                        label="Upload Audio File", sources=["upload"], type="filepath"
                    )
                    with gr.Row():
                        file_language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="Auto-detect",
                            label="Language",
                        )
                        file_task = gr.Radio(
                            choices=["Transcribe", "Translate to English"],
                            value="Transcribe",
                            label="Task",
                        )
                    file_btn = gr.Button("🎯 Transcribe", variant="primary")
                with gr.Column():
                    file_status = gr.Textbox(label="Status", interactive=False)
                    file_metrics = gr.Markdown(label="Metrics")
            file_output = gr.Textbox(label="Transcription", lines=5)

            file_btn.click(
                fn=transcribe_audio,
                inputs=[file_input, file_language, file_task],
                outputs=[file_status, file_output, file_metrics],
            )

            gr.Markdown("""
            **Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM

            *For best results, use clear audio with minimal background noise.*
            """)

        # Tab 3: Translation
        with gr.TabItem("🌍 Translation"):
            gr.Markdown("""
            ### Speech Translation

            Upload or record audio in any language and get English translation.
            Whisper will automatically detect the source language.
            """)
            with gr.Row():
                with gr.Column():
                    trans_input = gr.Audio(
                        label="Audio Input",
                        sources=["microphone", "upload"],
                        type="numpy",
                    )
                    trans_btn = gr.Button("🌍 Translate to English", variant="primary")
                with gr.Column():
                    trans_status = gr.Textbox(label="Status", interactive=False)
                    trans_metrics = gr.Markdown(label="Metrics")
            trans_output = gr.Textbox(label="English Translation", lines=5)

            def translate_audio(audio):
                # Thin wrapper: force auto-detect + translation task.
                return transcribe_audio(audio, "Auto-detect", "Translate to English")

            trans_btn.click(
                fn=translate_audio,
                inputs=trans_input,
                outputs=[trans_status, trans_output, trans_metrics],
            )

    create_footer()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)