308 lines
9.6 KiB
Python
308 lines
9.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
STT Demo - Gradio UI for testing Speech-to-Text (Whisper) service.
|
|
|
|
Features:
|
|
- Microphone recording input
|
|
- Audio file upload support
|
|
- Multiple language support
|
|
- Translation mode
|
|
- MLflow tracking URI configuration (metrics logging is not implemented in this file)
|
|
"""
|
|
import os
|
|
import time
|
|
import logging
|
|
import io
|
|
import tempfile
|
|
|
|
import gradio as gr
|
|
import httpx
|
|
import soundfile as sf
|
|
import numpy as np
|
|
|
|
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
|
|
|
# Logging: INFO level on the root logger, plus a named logger for this demo.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("stt-demo")

# Service endpoints. Each value can be overridden through the environment
# variable of the same name; the fallbacks are the in-file defaults.
_DEFAULT_STT_URL = (
    # Default: Ray Serve whisper endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
)
_DEFAULT_MLFLOW_TRACKING_URI = "http://mlflow.mlflow.svc.cluster.local:80"

STT_URL = os.getenv("STT_URL", _DEFAULT_STT_URL)
MLFLOW_TRACKING_URI = os.getenv(
    "MLFLOW_TRACKING_URI", _DEFAULT_MLFLOW_TRACKING_URI
)
|
|
|
|
# HTTP client with longer timeout for transcription
|
|
# Shared by transcribe_audio() and check_service_health(); the health check
# passes its own shorter per-request timeout instead of this default.
client = httpx.Client(timeout=180.0)
|
|
|
|
# Language choices offered in the UI, as (display name, language code) pairs.
# The order here is the order shown in the dropdowns.
_LANGUAGE_CODES = (
    ("English", "en"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("German", "de"),
    ("Italian", "it"),
    ("Portuguese", "pt"),
    ("Dutch", "nl"),
    ("Russian", "ru"),
    ("Chinese", "zh"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Arabic", "ar"),
    ("Hindi", "hi"),
    ("Turkish", "tr"),
    ("Polish", "pl"),
    ("Ukrainian", "uk"),
)

# Display name -> code sent to the service. "Auto-detect" maps to None, which
# makes transcribe_audio() omit the "language" field from the request.
LANGUAGES = {"Auto-detect": None, **dict(_LANGUAGE_CODES)}
|
|
|
|
|
|
def transcribe_audio(
|
|
audio_input: tuple[int, np.ndarray] | str | None,
|
|
language: str,
|
|
task: str
|
|
) -> tuple[str, str, str]:
|
|
"""Transcribe audio using the Whisper STT service."""
|
|
if audio_input is None:
|
|
return "❌ Please provide audio input", "", ""
|
|
|
|
try:
|
|
start_time = time.time()
|
|
|
|
# Handle different input types
|
|
if isinstance(audio_input, tuple):
|
|
# Microphone input: (sample_rate, audio_data)
|
|
sample_rate, audio_data = audio_input
|
|
|
|
# Convert to WAV bytes
|
|
audio_buffer = io.BytesIO()
|
|
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
|
|
audio_bytes = audio_buffer.getvalue()
|
|
audio_duration = len(audio_data) / sample_rate
|
|
else:
|
|
# File path
|
|
with open(audio_input, 'rb') as f:
|
|
audio_bytes = f.read()
|
|
# Get duration
|
|
audio_data, sample_rate = sf.read(audio_input)
|
|
audio_duration = len(audio_data) / sample_rate
|
|
|
|
# Prepare request
|
|
lang_code = LANGUAGES.get(language)
|
|
|
|
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
|
data = {"response_format": "json"}
|
|
|
|
if lang_code:
|
|
data["language"] = lang_code
|
|
|
|
# Choose endpoint based on task
|
|
if task == "Translate to English":
|
|
endpoint = f"{STT_URL}/v1/audio/translations"
|
|
else:
|
|
endpoint = f"{STT_URL}/v1/audio/transcriptions"
|
|
|
|
# Send request
|
|
response = client.post(endpoint, files=files, data=data)
|
|
response.raise_for_status()
|
|
|
|
latency = time.time() - start_time
|
|
result = response.json()
|
|
|
|
text = result.get("text", "")
|
|
detected_language = result.get("language", "unknown")
|
|
|
|
# Status message
|
|
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
|
|
|
|
# Metrics
|
|
metrics = f"""
|
|
**Transcription Statistics:**
|
|
- Audio Duration: {audio_duration:.2f} seconds
|
|
- Processing Time: {latency*1000:.0f}ms
|
|
- Real-time Factor: {latency/audio_duration:.2f}x
|
|
- Detected Language: {detected_language}
|
|
- Task: {task}
|
|
- Word Count: {len(text.split())}
|
|
- Character Count: {len(text)}
|
|
"""
|
|
|
|
return status, text, metrics
|
|
|
|
except httpx.HTTPStatusError as e:
|
|
logger.exception("STT request failed")
|
|
return f"❌ STT service error: {e.response.status_code}", "", ""
|
|
except Exception as e:
|
|
logger.exception("Transcription failed")
|
|
return f"❌ Error: {str(e)}", "", ""
|
|
|
|
|
|
def check_service_health() -> str:
    """Probe the STT service and summarise its state as one status line.

    Returns:
        A "🟢" message as soon as a probe answers 200, a "🟡" message with the
        last probe's status code when neither does, or a "🔴" message with the
        exception text when a request raises.
    """
    try:
        # The health route is tried first, then the OpenAI-compatible model
        # listing; each probe uses a 5 second timeout.
        for probe_path in ("/health", "/v1/models"):
            response = client.get(f"{STT_URL}{probe_path}", timeout=5.0)
            if response.status_code == 200:
                return "🟢 Service is healthy"

        # Neither probe returned 200: report the status of the last one.
        return f"🟡 Service returned status {response.status_code}"
    except Exception as exc:
        return f"🔴 Service unavailable: {exc}"
|
|
|
|
|
|
# Build the Gradio app
# NOTE(review): the nesting of the layout below was reconstructed from
# statement order, since the pasted source had lost its indentation.
# Confirm it against the rendered UI.

# Task options shared by the microphone and file-upload tabs.
_TASK_CHOICES = ["Transcribe", "Translate to English"]

_INTRO_MD = """
# 🎙️ Speech-to-Text Demo

Test the **Whisper** speech-to-text service. Transcribe audio from microphone
or file upload with support for 100+ languages.
"""

_SUPPORTED_FORMATS_MD = """
**Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM

*For best results, use clear audio with minimal background noise.*
"""

_TRANSLATION_MD = """
### Speech Translation

Upload or record audio in any language and get English translation.
Whisper will automatically detect the source language.
"""


def _language_and_task_row():
    """Create the language dropdown and task radio side by side in one row."""
    with gr.Row():
        language = gr.Dropdown(
            choices=list(LANGUAGES),
            value="Auto-detect",
            label="Language",
        )
        task = gr.Radio(
            choices=list(_TASK_CHOICES),
            value="Transcribe",
            label="Task",
        )
    return language, task


def _results_column(output_label):
    """Create the status/metrics/text column; return (status, output, metrics)."""
    with gr.Column():
        status = gr.Textbox(label="Status", interactive=False)
        metrics = gr.Markdown(label="Metrics")
        output = gr.Textbox(label=output_label, lines=5)
    return status, output, metrics


def translate_audio(audio):
    """Run transcribe_audio with auto-detection and the translate task."""
    return transcribe_audio(audio, "Auto-detect", "Translate to English")


with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
    gr.Markdown(_INTRO_MD)

    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)

    health_btn.click(fn=check_service_health, outputs=health_status)

    with gr.Tabs():
        # Tab 1: Microphone Input
        with gr.TabItem("🎤 Microphone"):
            with gr.Row():
                with gr.Column():
                    mic_input = gr.Audio(
                        label="Record Audio",
                        sources=["microphone"],
                        type="numpy",
                    )
                    mic_language, mic_task = _language_and_task_row()
                    mic_btn = gr.Button("🎯 Transcribe", variant="primary")

                mic_status, mic_output, mic_metrics = _results_column(
                    "Transcription"
                )

            mic_btn.click(
                fn=transcribe_audio,
                inputs=[mic_input, mic_language, mic_task],
                outputs=[mic_status, mic_output, mic_metrics],
            )

        # Tab 2: File Upload
        with gr.TabItem("📁 File Upload"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.Audio(
                        label="Upload Audio File",
                        sources=["upload"],
                        type="filepath",
                    )
                    file_language, file_task = _language_and_task_row()
                    file_btn = gr.Button("🎯 Transcribe", variant="primary")

                file_status, file_output, file_metrics = _results_column(
                    "Transcription"
                )

            file_btn.click(
                fn=transcribe_audio,
                inputs=[file_input, file_language, file_task],
                outputs=[file_status, file_output, file_metrics],
            )

            gr.Markdown(_SUPPORTED_FORMATS_MD)

        # Tab 3: Translation
        with gr.TabItem("🌍 Translation"):
            gr.Markdown(_TRANSLATION_MD)

            with gr.Row():
                with gr.Column():
                    trans_input = gr.Audio(
                        label="Audio Input",
                        sources=["microphone", "upload"],
                        type="numpy",
                    )
                    trans_btn = gr.Button(
                        "🌍 Translate to English", variant="primary"
                    )

                trans_status, trans_output, trans_metrics = _results_column(
                    "English Translation"
                )

            trans_btn.click(
                fn=translate_audio,
                inputs=trans_input,
                outputs=[trans_status, trans_output, trans_metrics],
            )

    create_footer()
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 with show_error enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)
|