fix: remove unused imports and apply ruff formatting
- Remove unused imports: json (llm.py), tempfile (stt.py), base64 (tts.py) - Apply ruff format to all Python files
This commit is contained in:
161
stt.py
161
stt.py
@@ -9,11 +9,11 @@ Features:
|
||||
- Translation mode
|
||||
- MLflow metrics logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import io
|
||||
import tempfile
|
||||
|
||||
import gradio as gr
|
||||
import httpx
|
||||
@@ -30,11 +30,10 @@ logger = logging.getLogger("stt-demo")
|
||||
STT_URL = os.environ.get(
|
||||
"STT_URL",
|
||||
# Default: Ray Serve whisper endpoint
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper",
|
||||
)
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
||||
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||
)
|
||||
|
||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||
@@ -62,7 +61,9 @@ try:
|
||||
_mlflow_run_id = _mlflow_run.info.run_id
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = True
|
||||
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
|
||||
logger.info(
|
||||
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MLflow tracking disabled: %s", exc)
|
||||
_mlflow_client = None
|
||||
@@ -72,7 +73,10 @@ except Exception as exc:
|
||||
|
||||
|
||||
def _log_stt_metrics(
|
||||
latency: float, audio_duration: float, word_count: int, task: str,
|
||||
latency: float,
|
||||
audio_duration: float,
|
||||
word_count: int,
|
||||
task: str,
|
||||
) -> None:
|
||||
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
|
||||
global _mlflow_step
|
||||
@@ -86,11 +90,15 @@ def _log_stt_metrics(
|
||||
_mlflow_run_id,
|
||||
metrics=[
|
||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
|
||||
mlflow.entities.Metric(
|
||||
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
|
||||
],
|
||||
params=[] if _mlflow_step > 1 else [
|
||||
params=[]
|
||||
if _mlflow_step > 1
|
||||
else [
|
||||
mlflow.entities.Param("task", task),
|
||||
],
|
||||
)
|
||||
@@ -124,57 +132,55 @@ LANGUAGES = {
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_input: tuple[int, np.ndarray] | str | None,
|
||||
language: str,
|
||||
task: str
|
||||
audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
|
||||
) -> tuple[str, str, str]:
|
||||
"""Transcribe audio using the Whisper STT service."""
|
||||
if audio_input is None:
|
||||
return "❌ Please provide audio input", "", ""
|
||||
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Handle different input types
|
||||
if isinstance(audio_input, tuple):
|
||||
# Microphone input: (sample_rate, audio_data)
|
||||
sample_rate, audio_data = audio_input
|
||||
|
||||
|
||||
# Convert to WAV bytes
|
||||
audio_buffer = io.BytesIO()
|
||||
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
|
||||
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
||||
audio_bytes = audio_buffer.getvalue()
|
||||
audio_duration = len(audio_data) / sample_rate
|
||||
else:
|
||||
# File path
|
||||
with open(audio_input, 'rb') as f:
|
||||
with open(audio_input, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
# Get duration
|
||||
audio_data, sample_rate = sf.read(audio_input)
|
||||
audio_duration = len(audio_data) / sample_rate
|
||||
|
||||
|
||||
# Prepare request
|
||||
lang_code = LANGUAGES.get(language)
|
||||
|
||||
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
data = {"response_format": "json"}
|
||||
|
||||
|
||||
if lang_code:
|
||||
data["language"] = lang_code
|
||||
|
||||
|
||||
# Choose endpoint based on task
|
||||
if task == "Translate to English":
|
||||
endpoint = f"{STT_URL}/v1/audio/translations"
|
||||
else:
|
||||
endpoint = f"{STT_URL}/v1/audio/transcriptions"
|
||||
|
||||
|
||||
# Send request
|
||||
response = client.post(endpoint, files=files, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
latency = time.time() - start_time
|
||||
result = response.json()
|
||||
|
||||
|
||||
text = result.get("text", "")
|
||||
detected_language = result.get("language", "unknown")
|
||||
|
||||
@@ -185,24 +191,26 @@ def transcribe_audio(
|
||||
word_count=len(text.split()),
|
||||
task=task,
|
||||
)
|
||||
|
||||
|
||||
# Status message
|
||||
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
|
||||
|
||||
status = (
|
||||
f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metrics = f"""
|
||||
**Transcription Statistics:**
|
||||
- Audio Duration: {audio_duration:.2f} seconds
|
||||
- Processing Time: {latency*1000:.0f}ms
|
||||
- Real-time Factor: {latency/audio_duration:.2f}x
|
||||
- Processing Time: {latency * 1000:.0f}ms
|
||||
- Real-time Factor: {latency / audio_duration:.2f}x
|
||||
- Detected Language: {detected_language}
|
||||
- Task: {task}
|
||||
- Word Count: {len(text.split())}
|
||||
- Character Count: {len(text)}
|
||||
"""
|
||||
|
||||
|
||||
return status, text, metrics
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("STT request failed")
|
||||
return f"❌ STT service error: {e.response.status_code}", "", ""
|
||||
@@ -217,12 +225,12 @@ def check_service_health() -> str:
|
||||
response = client.get(f"{STT_URL}/health", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is healthy"
|
||||
|
||||
|
||||
# Try v1/models endpoint (OpenAI-compatible)
|
||||
response = client.get(f"{STT_URL}/v1/models", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is healthy"
|
||||
|
||||
|
||||
return f"🟡 Service returned status {response.status_code}"
|
||||
except Exception as e:
|
||||
return f"🔴 Service unavailable: {str(e)}"
|
||||
@@ -236,99 +244,89 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
|
||||
Test the **Whisper** speech-to-text service. Transcribe audio from microphone
|
||||
or file upload with support for 100+ languages.
|
||||
""")
|
||||
|
||||
|
||||
# Service status
|
||||
with gr.Row():
|
||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||
|
||||
|
||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||
|
||||
|
||||
with gr.Tabs():
|
||||
# Tab 1: Microphone Input
|
||||
with gr.TabItem("🎤 Microphone"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
mic_input = gr.Audio(
|
||||
label="Record Audio",
|
||||
sources=["microphone"],
|
||||
type="numpy"
|
||||
label="Record Audio", sources=["microphone"], type="numpy"
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
mic_language = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="Auto-detect",
|
||||
label="Language"
|
||||
label="Language",
|
||||
)
|
||||
mic_task = gr.Radio(
|
||||
choices=["Transcribe", "Translate to English"],
|
||||
value="Transcribe",
|
||||
label="Task"
|
||||
label="Task",
|
||||
)
|
||||
|
||||
|
||||
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
mic_status = gr.Textbox(label="Status", interactive=False)
|
||||
mic_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
mic_output = gr.Textbox(
|
||||
label="Transcription",
|
||||
lines=5
|
||||
)
|
||||
|
||||
|
||||
mic_output = gr.Textbox(label="Transcription", lines=5)
|
||||
|
||||
mic_btn.click(
|
||||
fn=transcribe_audio,
|
||||
inputs=[mic_input, mic_language, mic_task],
|
||||
outputs=[mic_status, mic_output, mic_metrics]
|
||||
outputs=[mic_status, mic_output, mic_metrics],
|
||||
)
|
||||
|
||||
|
||||
# Tab 2: File Upload
|
||||
with gr.TabItem("📁 File Upload"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
file_input = gr.Audio(
|
||||
label="Upload Audio File",
|
||||
sources=["upload"],
|
||||
type="filepath"
|
||||
label="Upload Audio File", sources=["upload"], type="filepath"
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
file_language = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="Auto-detect",
|
||||
label="Language"
|
||||
label="Language",
|
||||
)
|
||||
file_task = gr.Radio(
|
||||
choices=["Transcribe", "Translate to English"],
|
||||
value="Transcribe",
|
||||
label="Task"
|
||||
label="Task",
|
||||
)
|
||||
|
||||
|
||||
file_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
file_status = gr.Textbox(label="Status", interactive=False)
|
||||
file_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
file_output = gr.Textbox(
|
||||
label="Transcription",
|
||||
lines=5
|
||||
)
|
||||
|
||||
|
||||
file_output = gr.Textbox(label="Transcription", lines=5)
|
||||
|
||||
file_btn.click(
|
||||
fn=transcribe_audio,
|
||||
inputs=[file_input, file_language, file_task],
|
||||
outputs=[file_status, file_output, file_metrics]
|
||||
outputs=[file_status, file_output, file_metrics],
|
||||
)
|
||||
|
||||
|
||||
gr.Markdown("""
|
||||
**Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM
|
||||
|
||||
*For best results, use clear audio with minimal background noise.*
|
||||
""")
|
||||
|
||||
|
||||
# Tab 3: Translation
|
||||
with gr.TabItem("🌍 Translation"):
|
||||
gr.Markdown("""
|
||||
@@ -337,40 +335,33 @@ or file upload with support for 100+ languages.
|
||||
Upload or record audio in any language and get English translation.
|
||||
Whisper will automatically detect the source language.
|
||||
""")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
trans_input = gr.Audio(
|
||||
label="Audio Input",
|
||||
sources=["microphone", "upload"],
|
||||
type="numpy"
|
||||
type="numpy",
|
||||
)
|
||||
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
trans_status = gr.Textbox(label="Status", interactive=False)
|
||||
trans_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
trans_output = gr.Textbox(
|
||||
label="English Translation",
|
||||
lines=5
|
||||
)
|
||||
|
||||
|
||||
trans_output = gr.Textbox(label="English Translation", lines=5)
|
||||
|
||||
def translate_audio(audio):
|
||||
return transcribe_audio(audio, "Auto-detect", "Translate to English")
|
||||
|
||||
|
||||
trans_btn.click(
|
||||
fn=translate_audio,
|
||||
inputs=trans_input,
|
||||
outputs=[trans_status, trans_output, trans_metrics]
|
||||
outputs=[trans_status, trans_output, trans_metrics],
|
||||
)
|
||||
|
||||
|
||||
create_footer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(
|
||||
server_name="0.0.0.0",
|
||||
server_port=7860,
|
||||
show_error=True
|
||||
)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||
|
||||
Reference in New Issue
Block a user