Files
gradio-ui/stt.py
Billy D. faa5dc0d9d
Some checks failed
CI / Docker Build & Push (push) Failing after 2m37s
CI / Deploy to Kubernetes (push) Has been skipped
CI / Notify (push) Successful in 1s
CI / Lint (push) Successful in 10s
CI / Release (push) Successful in 4s
fix: remove unused imports and apply ruff formatting
- Remove unused imports: json (llm.py), tempfile (stt.py), base64 (tts.py)
- Apply ruff format to all Python files
2026-02-18 18:36:16 -05:00

368 lines
12 KiB
Python

#!/usr/bin/env python3
"""
STT Demo - Gradio UI for testing Speech-to-Text (Whisper) service.
Features:
- Microphone recording input
- Audio file upload support
- Multiple language support
- Translation mode
- MLflow metrics logging
"""
import io
import logging
import mimetypes
import os
import time

import gradio as gr
import httpx
import numpy as np
import soundfile as sf

from theme import CUSTOM_CSS, create_footer, get_lab_theme
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("stt-demo")

# Configuration: both service endpoints can be overridden through the
# environment; the fallbacks are the in-cluster service addresses.
# STT_URL defaults to the Ray Serve whisper endpoint.
STT_URL = os.getenv(
    "STT_URL", "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
)
MLFLOW_TRACKING_URI = os.getenv(
    "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: any failure (mlflow missing, tracking server unreachable,
# ...) is logged and leaves MLFLOW_ENABLED False instead of breaking the UI.
_mlflow_step = 0  # per-request step counter, advanced by _log_stt_metrics
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Reuse the shared experiment when it exists; create it on first use.
    _experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning")
    _experiment_id = (
        _experiment.experiment_id
        if _experiment is not None
        else _mlflow_client.create_experiment(
            "gradio-stt-tuning",
            artifact_location="/mlflow/artifacts/gradio-stt-tuning",
        )
    )

    # A single run is opened when the module is imported; the run name embeds
    # $HOSTNAME (presumably the pod name when deployed in Kubernetes).
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}",
        tags={"service": "gradio-stt", "endpoint": STT_URL},
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    MLFLOW_ENABLED = True
    logger.info(
        "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
    )
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    MLFLOW_ENABLED = False
def _log_stt_metrics(
    latency: float,
    audio_duration: float,
    word_count: int,
    task: str,
) -> None:
    """Record one request's metrics on the module's MLflow run.

    Best-effort: returns at once when tracking is disabled, and any error
    raised while logging is swallowed (logged at DEBUG) so a tracking outage
    never affects transcription.

    Args:
        latency: Seconds the request took.
        audio_duration: Length of the submitted audio in seconds; the
            real-time factor is reported as 0 when this is not positive.
        word_count: Number of words in the returned text.
        task: Task label from the UI, sent as the "task" param.
    """
    global _mlflow_step
    if not (MLFLOW_ENABLED and _mlflow_client is not None):
        return
    try:
        _mlflow_step += 1
        timestamp_ms = int(time.time() * 1000)
        realtime_factor = latency / audio_duration if audio_duration > 0 else 0
        metric_cls = mlflow.entities.Metric
        batch = [
            metric_cls(key, value, timestamp_ms, _mlflow_step)
            for key, value in (
                ("latency_s", latency),
                ("audio_duration_s", audio_duration),
                ("realtime_factor", realtime_factor),
                ("word_count", word_count),
            )
        ]
        # MLflow params cannot change within a run, so "task" is attached only
        # to the first batch; later tasks are not recorded as params.
        first_params = (
            [mlflow.entities.Param("task", task)] if _mlflow_step <= 1 else []
        )
        _mlflow_client.log_batch(_mlflow_run_id, metrics=batch, params=first_params)
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
# Shared HTTP client for the STT service. Transcription requests can be slow,
# so the default timeout (connect/read/write/pool) is a generous 180 seconds.
client = httpx.Client(timeout=httpx.Timeout(180.0))
# Dropdown label -> ISO 639-1 code sent to Whisper as the "language" hint.
# None means no hint is sent (Whisper auto-detects). Insertion order is the
# order the choices appear in the UI dropdowns.
LANGUAGES = {"Auto-detect": None} | dict(
    [
        ("English", "en"),
        ("Spanish", "es"),
        ("French", "fr"),
        ("German", "de"),
        ("Italian", "it"),
        ("Portuguese", "pt"),
        ("Dutch", "nl"),
        ("Russian", "ru"),
        ("Chinese", "zh"),
        ("Japanese", "ja"),
        ("Korean", "ko"),
        ("Arabic", "ar"),
        ("Hindi", "hi"),
        ("Turkish", "tr"),
        ("Polish", "pl"),
        ("Ukrainian", "uk"),
    ]
)
def transcribe_audio(
audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
) -> tuple[str, str, str]:
"""Transcribe audio using the Whisper STT service."""
if audio_input is None:
return "❌ Please provide audio input", "", ""
try:
start_time = time.time()
# Handle different input types
if isinstance(audio_input, tuple):
# Microphone input: (sample_rate, audio_data)
sample_rate, audio_data = audio_input
# Convert to WAV bytes
audio_buffer = io.BytesIO()
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
audio_bytes = audio_buffer.getvalue()
audio_duration = len(audio_data) / sample_rate
else:
# File path
with open(audio_input, "rb") as f:
audio_bytes = f.read()
# Get duration
audio_data, sample_rate = sf.read(audio_input)
audio_duration = len(audio_data) / sample_rate
# Prepare request
lang_code = LANGUAGES.get(language)
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
data = {"response_format": "json"}
if lang_code:
data["language"] = lang_code
# Choose endpoint based on task
if task == "Translate to English":
endpoint = f"{STT_URL}/v1/audio/translations"
else:
endpoint = f"{STT_URL}/v1/audio/transcriptions"
# Send request
response = client.post(endpoint, files=files, data=data)
response.raise_for_status()
latency = time.time() - start_time
result = response.json()
text = result.get("text", "")
detected_language = result.get("language", "unknown")
# Log to MLflow
_log_stt_metrics(
latency=latency,
audio_duration=audio_duration,
word_count=len(text.split()),
task=task,
)
# Status message
status = (
f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
)
# Metrics
metrics = f"""
**Transcription Statistics:**
- Audio Duration: {audio_duration:.2f} seconds
- Processing Time: {latency * 1000:.0f}ms
- Real-time Factor: {latency / audio_duration:.2f}x
- Detected Language: {detected_language}
- Task: {task}
- Word Count: {len(text.split())}
- Character Count: {len(text)}
"""
return status, text, metrics
except httpx.HTTPStatusError as e:
logger.exception("STT request failed")
return f"❌ STT service error: {e.response.status_code}", "", ""
except Exception as e:
logger.exception("Transcription failed")
return f"❌ Error: {str(e)}", "", ""
def check_service_health() -> str:
    """Probe the STT service and return a one-line, emoji-prefixed status.

    ``/health`` is tried first, then the OpenAI-compatible ``/v1/models``;
    the first HTTP 200 counts as healthy. If neither returns 200, the status
    code of the last probe is reported. Any exception (connection refused,
    timeout, ...) is reported as unavailable. Each probe uses a 5 s timeout.
    """
    try:
        for probe_path in ("/health", "/v1/models"):
            reply = client.get(f"{STT_URL}{probe_path}", timeout=5.0)
            if reply.status_code == 200:
                return "🟢 Service is healthy"
        return f"🟡 Service returned status {reply.status_code}"
    except Exception as e:
        return f"🔴 Service unavailable: {str(e)}"
# Build the Gradio app
# Layout: a service health-check row, then three tabs (microphone, file
# upload, translation). Each tab's button calls transcribe_audio (directly or
# through a small adapter) and routes its (status, text, metrics) result to
# that tab's Status textbox, transcript textbox and Metrics markdown.
# Component creation order below is also the on-screen order.
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
    gr.Markdown("""
# 🎙️ Speech-to-Text Demo
Test the **Whisper** speech-to-text service. Transcribe audio from microphone
or file upload with support for 100+ languages.
""")
    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    # Clicking runs check_service_health and shows its status string.
    health_btn.click(fn=check_service_health, outputs=health_status)
    with gr.Tabs():
        # Tab 1: Microphone Input
        with gr.TabItem("🎤 Microphone"):
            with gr.Row():
                with gr.Column():
                    # type="numpy" hands transcribe_audio a (sample_rate, samples) tuple.
                    mic_input = gr.Audio(
                        label="Record Audio", sources=["microphone"], type="numpy"
                    )
                    with gr.Row():
                        # Choices are the LANGUAGES keys; the code lookup happens in transcribe_audio.
                        mic_language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="Auto-detect",
                            label="Language",
                        )
                        # transcribe_audio compares this value literally:
                        # "Translate to English" selects the translations endpoint.
                        mic_task = gr.Radio(
                            choices=["Transcribe", "Translate to English"],
                            value="Transcribe",
                            label="Task",
                        )
                    mic_btn = gr.Button("🎯 Transcribe", variant="primary")
                with gr.Column():
                    mic_status = gr.Textbox(label="Status", interactive=False)
                    mic_metrics = gr.Markdown(label="Metrics")
            mic_output = gr.Textbox(label="Transcription", lines=5)
            # outputs order matches transcribe_audio's (status, text, metrics) return.
            mic_btn.click(
                fn=transcribe_audio,
                inputs=[mic_input, mic_language, mic_task],
                outputs=[mic_status, mic_output, mic_metrics],
            )
        # Tab 2: File Upload
        with gr.TabItem("📁 File Upload"):
            with gr.Row():
                with gr.Column():
                    # type="filepath" hands transcribe_audio a path to the uploaded file.
                    file_input = gr.Audio(
                        label="Upload Audio File", sources=["upload"], type="filepath"
                    )
                    with gr.Row():
                        file_language = gr.Dropdown(
                            choices=list(LANGUAGES.keys()),
                            value="Auto-detect",
                            label="Language",
                        )
                        file_task = gr.Radio(
                            choices=["Transcribe", "Translate to English"],
                            value="Transcribe",
                            label="Task",
                        )
                    file_btn = gr.Button("🎯 Transcribe", variant="primary")
                with gr.Column():
                    file_status = gr.Textbox(label="Status", interactive=False)
                    file_metrics = gr.Markdown(label="Metrics")
            file_output = gr.Textbox(label="Transcription", lines=5)
            file_btn.click(
                fn=transcribe_audio,
                inputs=[file_input, file_language, file_task],
                outputs=[file_status, file_output, file_metrics],
            )
            gr.Markdown("""
**Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM
*For best results, use clear audio with minimal background noise.*
""")
        # Tab 3: Translation
        with gr.TabItem("🌍 Translation"):
            gr.Markdown("""
### Speech Translation
Upload or record audio in any language and get English translation.
Whisper will automatically detect the source language.
""")
            with gr.Row():
                with gr.Column():
                    # type="numpy" here too, so uploads also arrive as
                    # (sample_rate, samples) and are re-encoded as WAV.
                    trans_input = gr.Audio(
                        label="Audio Input",
                        sources=["microphone", "upload"],
                        type="numpy",
                    )
                    trans_btn = gr.Button("🌍 Translate to English", variant="primary")
                with gr.Column():
                    trans_status = gr.Textbox(label="Status", interactive=False)
                    trans_metrics = gr.Markdown(label="Metrics")
            trans_output = gr.Textbox(label="English Translation", lines=5)

            def translate_audio(audio) -> tuple[str, str, str]:
                """Call transcribe_audio with auto-detected language and the translate task.

                Fixing language/task lets this tab wire a single audio input.
                """
                return transcribe_audio(audio, "Auto-detect", "Translate to English")

            trans_btn.click(
                fn=translate_audio,
                inputs=trans_input,
                outputs=[trans_status, trans_output, trans_metrics],
            )
    # Footer helper from the local theme module (presumably the shared lab footer).
    create_footer()
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860 and show handler errors in
    # the browser UI (show_error) instead of only in the server log.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )