Files
gradio-ui/stt.py
Billy D. 1c5dc7f751 feat: add MLflow experiment tracking to all 4 Gradio UIs
Each UI now logs per-request metrics to MLflow:
- llm.py: latency, tokens/sec, prompt/completion tokens (gradio-llm-tuning)
- embeddings.py: latency, text length, batch size (gradio-embeddings-tuning)
- stt.py: latency, audio duration, real-time factor (gradio-stt-tuning)
- tts.py: latency, text length, audio duration (gradio-tts-tuning)

Uses try/except guarded imports so UIs still work if MLflow is
unreachable. Persistent run per Gradio instance, batched metric logging
via MlflowClient.log_batch().
2026-02-13 07:54:06 -05:00

377 lines
12 KiB
Python

#!/usr/bin/env python3
"""
STT Demo - Gradio UI for testing Speech-to-Text (Whisper) service.
Features:
- Microphone recording input
- Audio file upload support
- Multiple language support
- Translation mode
- MLflow metrics logging
"""
import io
import logging
import mimetypes
import os
import tempfile
import time

import gradio as gr
import httpx
import numpy as np
import soundfile as sf

from theme import get_lab_theme, CUSTOM_CSS, create_footer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("stt-demo")

# Configuration
# Base URL of the STT service. Request paths such as "/health" and
# "/v1/audio/transcriptions" are appended with a leading "/", so a trailing
# slash is stripped here to avoid producing "//" in the final URLs.
STT_URL = os.environ.get(
    "STT_URL",
    # Default: Ray Serve whisper endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
).rstrip("/")
# MLflow tracking server used for per-request metrics logging.
MLFLOW_TRACKING_URI = os.environ.get(
    "MLFLOW_TRACKING_URI",
    "http://mlflow.mlflow.svc.cluster.local:80"
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Tracking is optional: any failure below (mlflow not importable, tracking
# server unreachable, ...) is logged and leaves MLFLOW_ENABLED = False so the
# UI still starts and serves requests.
_MLFLOW_EXPERIMENT_NAME = "gradio-stt-tuning"
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Reuse the experiment when it already exists; otherwise create it.
    _experiment = _mlflow_client.get_experiment_by_name(_MLFLOW_EXPERIMENT_NAME)
    if _experiment is not None:
        _experiment_id = _experiment.experiment_id
    else:
        _experiment_id = _mlflow_client.create_experiment(
            _MLFLOW_EXPERIMENT_NAME,
            artifact_location=f"/mlflow/artifacts/{_MLFLOW_EXPERIMENT_NAME}",
        )

    # A single run per process, named after the host; each request later
    # appends its metrics to this run at an increasing step.
    _host_name = os.environ.get('HOSTNAME', 'local')
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-stt-{_host_name}",
        tags={"service": "gradio-stt", "endpoint": STT_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False
def _log_stt_metrics(
    latency: float, audio_duration: float, word_count: int, task: str,
) -> None:
    """Record one request's STT metrics on the persistent MLflow run.

    Best-effort: returns immediately when tracking is disabled, and any MLflow
    error is only logged at DEBUG level, never raised. Each call advances the
    shared step counter. The ``task`` param is attached on the first step
    only. The real-time factor is reported as 0 unless ``audio_duration`` is
    positive.
    """
    global _mlflow_step
    if not (MLFLOW_ENABLED and _mlflow_client is not None):
        return
    try:
        _mlflow_step += 1
        timestamp_ms = int(time.time() * 1000)
        if audio_duration > 0:
            realtime_factor = latency / audio_duration
        else:
            realtime_factor = 0
        readings = (
            ("latency_s", latency),
            ("audio_duration_s", audio_duration),
            ("realtime_factor", realtime_factor),
            ("word_count", word_count),
        )
        metrics = [
            mlflow.entities.Metric(name, value, timestamp_ms, _mlflow_step)
            for name, value in readings
        ]
        params = [mlflow.entities.Param("task", task)] if _mlflow_step <= 1 else []
        _mlflow_client.log_batch(_mlflow_run_id, metrics=metrics, params=params)
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
# Shared HTTP client for the STT service. Transcription requests can take a
# while, so the default timeout is generous; the health check passes its own
# shorter per-call timeout.
_TRANSCRIBE_TIMEOUT_S = 180.0
client = httpx.Client(timeout=_TRANSCRIBE_TIMEOUT_S)
# Language choices offered in the UI dropdowns: display name -> language code
# sent to Whisper. "Auto-detect" maps to None, in which case no language hint
# is included in the request and the service detects the language itself.
_LANGUAGE_CODES = (
    ("English", "en"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("German", "de"),
    ("Italian", "it"),
    ("Portuguese", "pt"),
    ("Dutch", "nl"),
    ("Russian", "ru"),
    ("Chinese", "zh"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Arabic", "ar"),
    ("Hindi", "hi"),
    ("Turkish", "tr"),
    ("Polish", "pl"),
    ("Ukrainian", "uk"),
)
# "Auto-detect" must stay first: it is the dropdowns' default and leads the list.
LANGUAGES = {"Auto-detect": None, **dict(_LANGUAGE_CODES)}
def transcribe_audio(
audio_input: tuple[int, np.ndarray] | str | None,
language: str,
task: str
) -> tuple[str, str, str]:
"""Transcribe audio using the Whisper STT service."""
if audio_input is None:
return "❌ Please provide audio input", "", ""
try:
start_time = time.time()
# Handle different input types
if isinstance(audio_input, tuple):
# Microphone input: (sample_rate, audio_data)
sample_rate, audio_data = audio_input
# Convert to WAV bytes
audio_buffer = io.BytesIO()
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
audio_bytes = audio_buffer.getvalue()
audio_duration = len(audio_data) / sample_rate
else:
# File path
with open(audio_input, 'rb') as f:
audio_bytes = f.read()
# Get duration
audio_data, sample_rate = sf.read(audio_input)
audio_duration = len(audio_data) / sample_rate
# Prepare request
lang_code = LANGUAGES.get(language)
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
data = {"response_format": "json"}
if lang_code:
data["language"] = lang_code
# Choose endpoint based on task
if task == "Translate to English":
endpoint = f"{STT_URL}/v1/audio/translations"
else:
endpoint = f"{STT_URL}/v1/audio/transcriptions"
# Send request
response = client.post(endpoint, files=files, data=data)
response.raise_for_status()
latency = time.time() - start_time
result = response.json()
text = result.get("text", "")
detected_language = result.get("language", "unknown")
# Log to MLflow
_log_stt_metrics(
latency=latency,
audio_duration=audio_duration,
word_count=len(text.split()),
task=task,
)
# Status message
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
# Metrics
metrics = f"""
**Transcription Statistics:**
- Audio Duration: {audio_duration:.2f} seconds
- Processing Time: {latency*1000:.0f}ms
- Real-time Factor: {latency/audio_duration:.2f}x
- Detected Language: {detected_language}
- Task: {task}
- Word Count: {len(text.split())}
- Character Count: {len(text)}
"""
return status, text, metrics
except httpx.HTTPStatusError as e:
logger.exception("STT request failed")
return f"❌ STT service error: {e.response.status_code}", "", ""
except Exception as e:
logger.exception("Transcription failed")
return f"❌ Error: {str(e)}", "", ""
def check_service_health() -> str:
    """Probe the STT service and return a one-line, emoji-prefixed status.

    ``/health`` is tried first, then the OpenAI-compatible ``/v1/models``.
    A 200 from either counts as healthy. If neither returns 200, the status
    code of the last probe is reported. Any exception (connection refused,
    timeout, ...) is reported as unavailable. Each probe uses a 5 s timeout.
    """
    try:
        for probe_path in ("/health", "/v1/models"):
            response = client.get(f"{STT_URL}{probe_path}", timeout=5.0)
            if response.status_code == 200:
                return "🟢 Service is healthy"
        return f"🟡 Service returned status {response.status_code}"
    except Exception as err:
        return f"🔴 Service unavailable: {err}"
# Build the Gradio app
# Gradio places components in the order they are created inside the nested
# `with` contexts, so the statement order below defines the page layout.
# NOTE(review): this copy of the file had lost its indentation; the nesting
# below (e.g. whether each *_output textbox sits inside the right-hand column
# or below the row) is a reconstruction — confirm against the rendered layout.
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
    gr.Markdown("""
# 🎙️ Speech-to-Text Demo
Test the **Whisper** speech-to-text service. Transcribe audio from microphone
or file upload with support for 100+ languages.
""")
    # Service status: a button that fills a read-only textbox with the result
    # of check_service_health().
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)
    with gr.Tabs():
        # Tab 1: Microphone Input
        # type="numpy" makes the component deliver a (sample_rate, samples)
        # tuple, the form transcribe_audio encodes to WAV itself.
        with gr.TabItem("🎤 Microphone"):
            with gr.Row():
                with gr.Column():
                    mic_input = gr.Audio(
                        label="Record Audio",
                        sources=["microphone"],
                        type="numpy"
                    )
                    with gr.Row():
                        # Choices are the LANGUAGES display names; the
                        # selected name is mapped to a code in transcribe_audio.
                        mic_language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="Auto-detect",
                            label="Language"
                        )
                        # "Translate to English" switches transcribe_audio to
                        # the translations endpoint.
                        mic_task = gr.Radio(
                            choices=["Transcribe", "Translate to English"],
                            value="Transcribe",
                            label="Task"
                        )
                    mic_btn = gr.Button("🎯 Transcribe", variant="primary")
                with gr.Column():
                    mic_status = gr.Textbox(label="Status", interactive=False)
                    mic_metrics = gr.Markdown(label="Metrics")
            mic_output = gr.Textbox(
                label="Transcription",
                lines=5
            )
            # transcribe_audio returns (status, text, metrics) in this order.
            mic_btn.click(
                fn=transcribe_audio,
                inputs=[mic_input, mic_language, mic_task],
                outputs=[mic_status, mic_output, mic_metrics]
            )
        # Tab 2: File Upload
        # type="filepath" makes the component deliver a path string, whose
        # bytes transcribe_audio reads and sends to the service.
        with gr.TabItem("📁 File Upload"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.Audio(
                        label="Upload Audio File",
                        sources=["upload"],
                        type="filepath"
                    )
                    with gr.Row():
                        file_language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="Auto-detect",
                            label="Language"
                        )
                        file_task = gr.Radio(
                            choices=["Transcribe", "Translate to English"],
                            value="Transcribe",
                            label="Task"
                        )
                    file_btn = gr.Button("🎯 Transcribe", variant="primary")
                with gr.Column():
                    file_status = gr.Textbox(label="Status", interactive=False)
                    file_metrics = gr.Markdown(label="Metrics")
            file_output = gr.Textbox(
                label="Transcription",
                lines=5
            )
            file_btn.click(
                fn=transcribe_audio,
                inputs=[file_input, file_language, file_task],
                outputs=[file_status, file_output, file_metrics]
            )
            gr.Markdown("""
**Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM
*For best results, use clear audio with minimal background noise.*
""")
        # Tab 3: Translation
        # Accepts microphone or upload, both delivered as numpy (type="numpy").
        with gr.TabItem("🌍 Translation"):
            gr.Markdown("""
### Speech Translation
Upload or record audio in any language and get English translation.
Whisper will automatically detect the source language.
""")
            with gr.Row():
                with gr.Column():
                    trans_input = gr.Audio(
                        label="Audio Input",
                        sources=["microphone", "upload"],
                        type="numpy"
                    )
                    trans_btn = gr.Button("🌍 Translate to English", variant="primary")
                with gr.Column():
                    trans_status = gr.Textbox(label="Status", interactive=False)
                    trans_metrics = gr.Markdown(label="Metrics")
            trans_output = gr.Textbox(
                label="English Translation",
                lines=5
            )

            # This tab has no language/task controls, so the adapter always
            # sends no language hint and uses the translations endpoint.
            def translate_audio(audio):
                return transcribe_audio(audio, "Auto-detect", "Translate to English")
            trans_btn.click(
                fn=translate_audio,
                inputs=trans_input,
                outputs=[trans_status, trans_output, trans_metrics]
            )
    # Shared footer from the local `theme` module, rendered below the tabs.
    create_footer()
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, and surface handler errors in the UI.
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    demo.launch(**launch_kwargs)