fix: remove unused imports and apply ruff formatting

- Remove unused imports: json (llm.py), tempfile (stt.py), base64 (tts.py)
- Apply ruff format to all Python files
This commit is contained in:
2026-02-18 18:36:16 -05:00
parent 0cc03aa145
commit 8552a02a25
5 changed files with 283 additions and 246 deletions

View File

@@ -9,6 +9,7 @@ Features:
- MLflow metrics logging
- Visual embedding dimension display
"""
import os
import time
import logging
@@ -26,9 +27,9 @@ logger = logging.getLogger("embeddings-demo")
# Configuration
EMBEDDINGS_URL = os.environ.get(
"EMBEDDINGS_URL",
"EMBEDDINGS_URL",
# Default: Ray Serve Embeddings endpoint
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings"
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
try:
@@ -59,7 +60,9 @@ try:
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
logger.info(
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
@@ -68,7 +71,9 @@ except Exception as exc:
MLFLOW_ENABLED = False
def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int = 0) -> None:
def _log_embedding_metrics(
latency: float, batch_size: int, embedding_dims: int = 0
) -> None:
"""Log embedding inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
if not MLFLOW_ENABLED or _mlflow_client is None:
@@ -81,8 +86,15 @@ def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step),
mlflow.entities.Metric("embedding_dims", embedding_dims, ts, _mlflow_step),
mlflow.entities.Metric("latency_per_text_ms", (latency * 1000 / batch_size) if batch_size > 0 else 0, ts, _mlflow_step),
mlflow.entities.Metric(
"embedding_dims", embedding_dims, ts, _mlflow_step
),
mlflow.entities.Metric(
"latency_per_text_ms",
(latency * 1000 / batch_size) if batch_size > 0 else 0,
ts,
_mlflow_step,
),
],
)
except Exception:
@@ -96,17 +108,16 @@ client = httpx.Client(timeout=60.0)
def get_embeddings(texts: list[str]) -> tuple[list[list[float]], float]:
"""Get embeddings from the embeddings service."""
start_time = time.time()
response = client.post(
f"{EMBEDDINGS_URL}/embeddings",
json={"input": texts, "model": "bge"}
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
)
response.raise_for_status()
latency = time.time() - start_time
result = response.json()
embeddings = [d["embedding"] for d in result.get("data", [])]
return embeddings, latency
@@ -121,29 +132,29 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]:
"""Generate embedding for a single text."""
if not text.strip():
return "❌ Please enter some text", "", ""
try:
embeddings, latency = get_embeddings([text])
if not embeddings:
return "❌ No embedding returned", "", ""
embedding = embeddings[0]
dims = len(embedding)
# Log to MLflow
_log_embedding_metrics(latency, batch_size=1, embedding_dims=dims)
# Format output
status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms"
status = f"✅ Generated {dims}-dimensional embedding in {latency * 1000:.1f}ms"
# Show first/last few dimensions
preview = f"Dimensions: {dims}\n\n"
preview += "First 10 values:\n"
preview += json.dumps(embedding[:10], indent=2)
preview += "\n\n...\n\nLast 10 values:\n"
preview += json.dumps(embedding[-10:], indent=2)
# Stats
stats = f"""
**Embedding Statistics:**
@@ -153,11 +164,11 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]:
- Mean: {np.mean(embedding):.6f}
- Std: {np.std(embedding):.6f}
- L2 Norm: {np.linalg.norm(embedding):.6f}
- Latency: {latency*1000:.1f}ms
- Latency: {latency * 1000:.1f}ms
"""
return status, preview, stats
except Exception as e:
logger.exception("Embedding generation failed")
return f"❌ Error: {str(e)}", "", ""
@@ -167,18 +178,18 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
"""Compare similarity between two texts."""
if not text1.strip() or not text2.strip():
return "❌ Please enter both texts", ""
try:
embeddings, latency = get_embeddings([text1, text2])
if len(embeddings) != 2:
return "❌ Failed to get embeddings for both texts", ""
similarity = cosine_similarity(embeddings[0], embeddings[1])
# Log to MLflow
_log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0]))
# Determine similarity level
if similarity > 0.9:
level = "🟢 Very High"
@@ -192,7 +203,7 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
else:
level = "🔴 Low"
desc = "These texts are semantically different"
result = f"""
## Similarity Score: {similarity:.4f}
@@ -201,17 +212,17 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
{desc}
---
*Computed in {latency*1000:.1f}ms*
*Computed in {latency * 1000:.1f}ms*
"""
# Create a simple visual bar
bar_length = 50
filled = int(similarity * bar_length)
bar = "" * filled + "" * (bar_length - filled)
visual = f"[{bar}] {similarity*100:.1f}%"
visual = f"[{bar}] {similarity * 100:.1f}%"
return result, visual
except Exception as e:
logger.exception("Comparison failed")
return f"❌ Error: {str(e)}", ""
@@ -220,19 +231,23 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
def batch_embed(texts_input: str) -> tuple[str, str]:
"""Generate embeddings for multiple texts (one per line)."""
texts = [t.strip() for t in texts_input.strip().split("\n") if t.strip()]
if not texts:
return "❌ Please enter at least one text (one per line)", ""
try:
embeddings, latency = get_embeddings(texts)
# Log to MLflow
_log_embedding_metrics(latency, batch_size=len(embeddings), embedding_dims=len(embeddings[0]) if embeddings else 0)
status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms"
status += f" ({latency*1000/len(texts):.1f}ms per text)"
# Log to MLflow
_log_embedding_metrics(
latency,
batch_size=len(embeddings),
embedding_dims=len(embeddings[0]) if embeddings else 0,
)
status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms"
status += f" ({latency * 1000 / len(texts):.1f}ms per text)"
# Build similarity matrix
n = len(embeddings)
matrix = []
@@ -242,16 +257,16 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
sim = cosine_similarity(embeddings[i], embeddings[j])
row.append(f"{sim:.3f}")
matrix.append(row)
# Format as table
header = "| | " + " | ".join([f"Text {i+1}" for i in range(n)]) + " |"
header = "| | " + " | ".join([f"Text {i + 1}" for i in range(n)]) + " |"
separator = "|---" + "|---" * n + "|"
rows = []
for i, row in enumerate(matrix):
rows.append(f"| **Text {i+1}** | " + " | ".join(row) + " |")
rows.append(f"| **Text {i + 1}** | " + " | ".join(row) + " |")
table = "\n".join([header, separator] + rows)
result = f"""
## Similarity Matrix
@@ -261,10 +276,10 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
**Texts processed:**
"""
for i, text in enumerate(texts):
result += f"\n{i+1}. {text[:50]}{'...' if len(text) > 50 else ''}"
result += f"\n{i + 1}. {text[:50]}{'...' if len(text) > 50 else ''}"
return status, result
except Exception as e:
logger.exception("Batch embedding failed")
return f"❌ Error: {str(e)}", ""
@@ -290,14 +305,14 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="Embeddings Demo") a
Test the **BGE Embeddings** service for semantic text encoding.
Generate embeddings, compare text similarity, and explore vector representations.
""")
# Service status
with gr.Row():
health_btn = gr.Button("🔄 Check Service", size="sm")
health_status = gr.Textbox(label="Service Status", interactive=False)
health_btn.click(fn=check_service_health, outputs=health_status)
with gr.Tabs():
# Tab 1: Single Embedding
with gr.TabItem("📝 Single Text"):
@@ -306,71 +321,74 @@ Generate embeddings, compare text similarity, and explore vector representations
single_input = gr.Textbox(
label="Input Text",
placeholder="Enter text to generate embeddings...",
lines=3
lines=3,
)
single_btn = gr.Button("Generate Embedding", variant="primary")
with gr.Column():
single_status = gr.Textbox(label="Status", interactive=False)
single_stats = gr.Markdown(label="Statistics")
single_preview = gr.Code(label="Embedding Preview", language="json")
single_btn.click(
fn=generate_single_embedding,
inputs=single_input,
outputs=[single_status, single_preview, single_stats]
outputs=[single_status, single_preview, single_stats],
)
# Tab 2: Compare Texts
with gr.TabItem("⚖️ Compare Texts"):
gr.Markdown("Compare the semantic similarity between two texts.")
with gr.Row():
compare_text1 = gr.Textbox(label="Text 1", lines=3)
compare_text2 = gr.Textbox(label="Text 2", lines=3)
compare_btn = gr.Button("Compare Similarity", variant="primary")
with gr.Row():
compare_result = gr.Markdown(label="Result")
compare_visual = gr.Textbox(label="Similarity Bar", interactive=False)
compare_btn.click(
fn=compare_texts,
inputs=[compare_text1, compare_text2],
outputs=[compare_result, compare_visual]
outputs=[compare_result, compare_visual],
)
# Example pairs
gr.Examples(
examples=[
["The cat sat on the mat.", "A feline was resting on the rug."],
["Machine learning is a subset of AI.", "Deep learning uses neural networks."],
[
"Machine learning is a subset of AI.",
"Deep learning uses neural networks.",
],
["I love pizza.", "The stock market crashed today."],
],
inputs=[compare_text1, compare_text2],
)
# Tab 3: Batch Embeddings
with gr.TabItem("📚 Batch Processing"):
gr.Markdown("Generate embeddings for multiple texts and see their similarity matrix.")
gr.Markdown(
"Generate embeddings for multiple texts and see their similarity matrix."
)
batch_input = gr.Textbox(
label="Texts (one per line)",
placeholder="Enter multiple texts, one per line...",
lines=6
lines=6,
)
batch_btn = gr.Button("Process Batch", variant="primary")
batch_status = gr.Textbox(label="Status", interactive=False)
batch_result = gr.Markdown(label="Similarity Matrix")
batch_btn.click(
fn=batch_embed,
inputs=batch_input,
outputs=[batch_status, batch_result]
fn=batch_embed, inputs=batch_input, outputs=[batch_status, batch_result]
)
gr.Examples(
examples=[
"Python is a programming language.\nJava is also a programming language.\nCoffee is a beverage.",
@@ -378,13 +396,9 @@ Generate embeddings, compare text similarity, and explore vector representations
],
inputs=batch_input,
)
create_footer()
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)

35
llm.py
View File

@@ -9,10 +9,10 @@ Features:
- Token usage and latency metrics
- Chat history management
"""
import os
import time
import logging
import json
import gradio as gr
import httpx
@@ -65,7 +65,9 @@ try:
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
logger.info(
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
@@ -95,18 +97,25 @@ def _log_llm_metrics(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("prompt_tokens", prompt_tokens, ts, _mlflow_step),
mlflow.entities.Metric("completion_tokens", completion_tokens, ts, _mlflow_step),
mlflow.entities.Metric(
"prompt_tokens", prompt_tokens, ts, _mlflow_step
),
mlflow.entities.Metric(
"completion_tokens", completion_tokens, ts, _mlflow_step
),
mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
mlflow.entities.Metric("max_tokens_requested", max_tokens, ts, _mlflow_step),
mlflow.entities.Metric(
"max_tokens_requested", max_tokens, ts, _mlflow_step
),
mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
],
)
except Exception:
logger.debug("MLflow log failed", exc_info=True)
DEFAULT_SYSTEM_PROMPT = (
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
@@ -273,10 +282,10 @@ def single_prompt(
metrics = f"""
**Generation Metrics:**
- Latency: {latency:.1f}s
- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}
- Completion tokens: {usage.get('completion_tokens', 'N/A')}
- Total tokens: {usage.get('total_tokens', 'N/A')}
- Model: {result.get('model', 'N/A')}
- Prompt tokens: {usage.get("prompt_tokens", "N/A")}
- Completion tokens: {usage.get("completion_tokens", "N/A")}
- Total tokens: {usage.get("total_tokens", "N/A")}
- Model: {result.get("model", "N/A")}
"""
return text, metrics
@@ -360,9 +369,13 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
gr.Examples(
examples=[
["Summarise the key differences between CUDA and ROCm for ML workloads."],
[
"Summarise the key differences between CUDA and ROCm for ML workloads."
],
["Write a haiku about Kubernetes."],
["Explain Ray Serve in one paragraph for someone new to ML serving."],
[
"Explain Ray Serve in one paragraph for someone new to ML serving."
],
["List 5 creative uses for a homelab GPU cluster."],
],
inputs=[prompt_input],

161
stt.py
View File

@@ -9,11 +9,11 @@ Features:
- Translation mode
- MLflow metrics logging
"""
import os
import time
import logging
import io
import tempfile
import gradio as gr
import httpx
@@ -30,11 +30,10 @@ logger = logging.getLogger("stt-demo")
STT_URL = os.environ.get(
"STT_URL",
# Default: Ray Serve whisper endpoint
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper",
)
MLFLOW_TRACKING_URI = os.environ.get(
"MLFLOW_TRACKING_URI",
"http://mlflow.mlflow.svc.cluster.local:80"
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
@@ -62,7 +61,9 @@ try:
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
logger.info(
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
@@ -72,7 +73,10 @@ except Exception as exc:
def _log_stt_metrics(
latency: float, audio_duration: float, word_count: int, task: str,
latency: float,
audio_duration: float,
word_count: int,
task: str,
) -> None:
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
@@ -86,11 +90,15 @@ def _log_stt_metrics(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
mlflow.entities.Metric(
"audio_duration_s", audio_duration, ts, _mlflow_step
),
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
],
params=[] if _mlflow_step > 1 else [
params=[]
if _mlflow_step > 1
else [
mlflow.entities.Param("task", task),
],
)
@@ -124,57 +132,55 @@ LANGUAGES = {
def transcribe_audio(
audio_input: tuple[int, np.ndarray] | str | None,
language: str,
task: str
audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
) -> tuple[str, str, str]:
"""Transcribe audio using the Whisper STT service."""
if audio_input is None:
return "❌ Please provide audio input", "", ""
try:
start_time = time.time()
# Handle different input types
if isinstance(audio_input, tuple):
# Microphone input: (sample_rate, audio_data)
sample_rate, audio_data = audio_input
# Convert to WAV bytes
audio_buffer = io.BytesIO()
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
audio_bytes = audio_buffer.getvalue()
audio_duration = len(audio_data) / sample_rate
else:
# File path
with open(audio_input, 'rb') as f:
with open(audio_input, "rb") as f:
audio_bytes = f.read()
# Get duration
audio_data, sample_rate = sf.read(audio_input)
audio_duration = len(audio_data) / sample_rate
# Prepare request
lang_code = LANGUAGES.get(language)
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
data = {"response_format": "json"}
if lang_code:
data["language"] = lang_code
# Choose endpoint based on task
if task == "Translate to English":
endpoint = f"{STT_URL}/v1/audio/translations"
else:
endpoint = f"{STT_URL}/v1/audio/transcriptions"
# Send request
response = client.post(endpoint, files=files, data=data)
response.raise_for_status()
latency = time.time() - start_time
result = response.json()
text = result.get("text", "")
detected_language = result.get("language", "unknown")
@@ -185,24 +191,26 @@ def transcribe_audio(
word_count=len(text.split()),
task=task,
)
# Status message
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
status = (
f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
)
# Metrics
metrics = f"""
**Transcription Statistics:**
- Audio Duration: {audio_duration:.2f} seconds
- Processing Time: {latency*1000:.0f}ms
- Real-time Factor: {latency/audio_duration:.2f}x
- Processing Time: {latency * 1000:.0f}ms
- Real-time Factor: {latency / audio_duration:.2f}x
- Detected Language: {detected_language}
- Task: {task}
- Word Count: {len(text.split())}
- Character Count: {len(text)}
"""
return status, text, metrics
except httpx.HTTPStatusError as e:
logger.exception("STT request failed")
return f"❌ STT service error: {e.response.status_code}", "", ""
@@ -217,12 +225,12 @@ def check_service_health() -> str:
response = client.get(f"{STT_URL}/health", timeout=5.0)
if response.status_code == 200:
return "🟢 Service is healthy"
# Try v1/models endpoint (OpenAI-compatible)
response = client.get(f"{STT_URL}/v1/models", timeout=5.0)
if response.status_code == 200:
return "🟢 Service is healthy"
return f"🟡 Service returned status {response.status_code}"
except Exception as e:
return f"🔴 Service unavailable: {str(e)}"
@@ -236,99 +244,89 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
Test the **Whisper** speech-to-text service. Transcribe audio from microphone
or file upload with support for 100+ languages.
""")
# Service status
with gr.Row():
health_btn = gr.Button("🔄 Check Service", size="sm")
health_status = gr.Textbox(label="Service Status", interactive=False)
health_btn.click(fn=check_service_health, outputs=health_status)
with gr.Tabs():
# Tab 1: Microphone Input
with gr.TabItem("🎤 Microphone"):
with gr.Row():
with gr.Column():
mic_input = gr.Audio(
label="Record Audio",
sources=["microphone"],
type="numpy"
label="Record Audio", sources=["microphone"], type="numpy"
)
with gr.Row():
mic_language = gr.Dropdown(
choices=list(LANGUAGES.keys()),
value="Auto-detect",
label="Language"
label="Language",
)
mic_task = gr.Radio(
choices=["Transcribe", "Translate to English"],
value="Transcribe",
label="Task"
label="Task",
)
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
with gr.Column():
mic_status = gr.Textbox(label="Status", interactive=False)
mic_metrics = gr.Markdown(label="Metrics")
mic_output = gr.Textbox(
label="Transcription",
lines=5
)
mic_output = gr.Textbox(label="Transcription", lines=5)
mic_btn.click(
fn=transcribe_audio,
inputs=[mic_input, mic_language, mic_task],
outputs=[mic_status, mic_output, mic_metrics]
outputs=[mic_status, mic_output, mic_metrics],
)
# Tab 2: File Upload
with gr.TabItem("📁 File Upload"):
with gr.Row():
with gr.Column():
file_input = gr.Audio(
label="Upload Audio File",
sources=["upload"],
type="filepath"
label="Upload Audio File", sources=["upload"], type="filepath"
)
with gr.Row():
file_language = gr.Dropdown(
choices=list(LANGUAGES.keys()),
value="Auto-detect",
label="Language"
label="Language",
)
file_task = gr.Radio(
choices=["Transcribe", "Translate to English"],
value="Transcribe",
label="Task"
label="Task",
)
file_btn = gr.Button("🎯 Transcribe", variant="primary")
with gr.Column():
file_status = gr.Textbox(label="Status", interactive=False)
file_metrics = gr.Markdown(label="Metrics")
file_output = gr.Textbox(
label="Transcription",
lines=5
)
file_output = gr.Textbox(label="Transcription", lines=5)
file_btn.click(
fn=transcribe_audio,
inputs=[file_input, file_language, file_task],
outputs=[file_status, file_output, file_metrics]
outputs=[file_status, file_output, file_metrics],
)
gr.Markdown("""
**Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM
*For best results, use clear audio with minimal background noise.*
""")
# Tab 3: Translation
with gr.TabItem("🌍 Translation"):
gr.Markdown("""
@@ -337,40 +335,33 @@ or file upload with support for 100+ languages.
Upload or record audio in any language and get English translation.
Whisper will automatically detect the source language.
""")
with gr.Row():
with gr.Column():
trans_input = gr.Audio(
label="Audio Input",
sources=["microphone", "upload"],
type="numpy"
type="numpy",
)
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
with gr.Column():
trans_status = gr.Textbox(label="Status", interactive=False)
trans_metrics = gr.Markdown(label="Metrics")
trans_output = gr.Textbox(
label="English Translation",
lines=5
)
trans_output = gr.Textbox(label="English Translation", lines=5)
def translate_audio(audio):
return transcribe_audio(audio, "Auto-detect", "Translate to English")
trans_btn.click(
fn=translate_audio,
inputs=trans_input,
outputs=[trans_status, trans_output, trans_metrics]
outputs=[trans_status, trans_output, trans_metrics],
)
create_footer()
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)

View File

@@ -3,6 +3,7 @@ Shared Gradio theme for Davies Tech Labs AI demos.
Consistent styling across all demo applications.
Cyberpunk aesthetic - dark with yellow/gold accents.
"""
import gradio as gr
@@ -25,7 +26,12 @@ def get_lab_theme() -> gr.Theme:
primary_hue=gr.themes.colors.yellow,
secondary_hue=gr.themes.colors.amber,
neutral_hue=gr.themes.colors.zinc,
font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
font=[
gr.themes.GoogleFont("Space Grotesk"),
"ui-sans-serif",
"system-ui",
"sans-serif",
],
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
).set(
# Background colors

155
tts.py
View File

@@ -9,11 +9,11 @@ Features:
- MLflow metrics logging
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
"""
import os
import time
import logging
import io
import base64
import gradio as gr
import httpx
@@ -30,11 +30,10 @@ logger = logging.getLogger("tts-demo")
TTS_URL = os.environ.get(
"TTS_URL",
# Default: Ray Serve TTS endpoint
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts"
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
)
MLFLOW_TRACKING_URI = os.environ.get(
"MLFLOW_TRACKING_URI",
"http://mlflow.mlflow.svc.cluster.local:80"
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
)
# ─── MLflow experiment tracking ──────────────────────────────────────────
@@ -62,7 +61,9 @@ try:
_mlflow_run_id = _mlflow_run.info.run_id
_mlflow_step = 0
MLFLOW_ENABLED = True
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
logger.info(
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
)
except Exception as exc:
logger.warning("MLflow tracking disabled: %s", exc)
_mlflow_client = None
@@ -72,7 +73,10 @@ except Exception as exc:
def _log_tts_metrics(
latency: float, audio_duration: float, text_chars: int, language: str,
latency: float,
audio_duration: float,
text_chars: int,
language: str,
) -> None:
"""Log TTS inference metrics to MLflow (non-blocking best-effort)."""
global _mlflow_step
@@ -87,7 +91,9 @@ def _log_tts_metrics(
_mlflow_run_id,
metrics=[
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
mlflow.entities.Metric(
"audio_duration_s", audio_duration, ts, _mlflow_step
),
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
@@ -121,38 +127,39 @@ LANGUAGES = {
}
def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None, str]:
def synthesize_speech(
text: str, language: str
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
"""Synthesize speech from text using the TTS service."""
if not text.strip():
return "❌ Please enter some text", None, ""
lang_code = LANGUAGES.get(language, "en")
try:
start_time = time.time()
# Call TTS service (Coqui XTTS API format)
response = client.get(
f"{TTS_URL}/api/tts",
params={"text": text, "language_id": lang_code}
f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code}
)
response.raise_for_status()
latency = time.time() - start_time
audio_bytes = response.content
# Parse audio data
audio_io = io.BytesIO(audio_bytes)
audio_data, sample_rate = sf.read(audio_io)
# Calculate duration
if len(audio_data.shape) == 1:
duration = len(audio_data) / sample_rate
else:
duration = len(audio_data) / sample_rate
# Status message
status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms"
status = f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
# Log to MLflow
_log_tts_metrics(
@@ -161,22 +168,22 @@ def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndar
text_chars=len(text),
language=lang_code,
)
# Metrics
metrics = f"""
**Audio Statistics:**
- Duration: {duration:.2f} seconds
- Sample Rate: {sample_rate} Hz
- Size: {len(audio_bytes) / 1024:.1f} KB
- Generation Time: {latency*1000:.0f}ms
- Real-time Factor: {latency/duration:.2f}x
- Generation Time: {latency * 1000:.0f}ms
- Real-time Factor: {latency / duration:.2f}x
- Language: {language} ({lang_code})
- Characters: {len(text)}
- Chars/sec: {len(text)/latency:.1f}
- Chars/sec: {len(text) / latency:.1f}
"""
return status, (sample_rate, audio_data), metrics
except httpx.HTTPStatusError as e:
logger.exception("TTS request failed")
return f"❌ TTS service error: {e.response.status_code}", None, ""
@@ -192,12 +199,12 @@ def check_service_health() -> str:
response = client.get(f"{TTS_URL}/health", timeout=5.0)
if response.status_code == 200:
return "🟢 Service is healthy"
# Fall back to root endpoint
response = client.get(f"{TTS_URL}/", timeout=5.0)
if response.status_code == 200:
return "🟢 Service is responding"
return f"🟡 Service returned status {response.status_code}"
except Exception as e:
return f"🔴 Service unavailable: {str(e)}"
@@ -211,14 +218,14 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
in multiple languages.
""")
# Service status
with gr.Row():
health_btn = gr.Button("🔄 Check Service", size="sm")
health_status = gr.Textbox(label="Service Status", interactive=False)
health_btn.click(fn=check_service_health, outputs=health_status)
with gr.Tabs():
# Tab 1: Basic TTS
with gr.TabItem("🎤 Text to Speech"):
@@ -228,114 +235,120 @@ in multiple languages.
label="Text to Synthesize",
placeholder="Enter text to convert to speech...",
lines=5,
max_lines=10
max_lines=10,
)
with gr.Row():
language = gr.Dropdown(
choices=list(LANGUAGES.keys()),
value="English",
label="Language"
label="Language",
)
synthesize_btn = gr.Button("🔊 Synthesize", variant="primary", scale=2)
synthesize_btn = gr.Button(
"🔊 Synthesize", variant="primary", scale=2
)
with gr.Column(scale=1):
status_output = gr.Textbox(label="Status", interactive=False)
metrics_output = gr.Markdown(label="Metrics")
audio_output = gr.Audio(label="Generated Audio", type="numpy")
synthesize_btn.click(
fn=synthesize_speech,
inputs=[text_input, language],
outputs=[status_output, audio_output, metrics_output]
outputs=[status_output, audio_output, metrics_output],
)
# Example texts
gr.Examples(
examples=[
["Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "English"],
["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "English"],
["Bonjour! Bienvenue au laboratoire technique de Davies.", "French"],
[
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
"English",
],
[
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
"English",
],
[
"Bonjour! Bienvenue au laboratoire technique de Davies.",
"French",
],
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
["Guten Tag! Willkommen im Techniklabor.", "German"],
],
inputs=[text_input, language],
)
# Tab 2: Comparison
with gr.TabItem("🔄 Language Comparison"):
gr.Markdown("Compare the same text in different languages.")
compare_text = gr.Textbox(
label="Text to Compare",
value="Hello, how are you today?",
lines=2
label="Text to Compare", value="Hello, how are you today?", lines=2
)
with gr.Row():
lang1 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Language 1")
lang2 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2")
lang1 = gr.Dropdown(
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
)
lang2 = gr.Dropdown(
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
)
compare_btn = gr.Button("Compare Languages", variant="primary")
with gr.Row():
with gr.Column():
gr.Markdown("### Language 1")
audio1 = gr.Audio(label="Audio 1", type="numpy")
status1 = gr.Textbox(label="Status", interactive=False)
with gr.Column():
gr.Markdown("### Language 2")
audio2 = gr.Audio(label="Audio 2", type="numpy")
status2 = gr.Textbox(label="Status", interactive=False)
def compare_languages(text, l1, l2):
s1, a1, _ = synthesize_speech(text, l1)
s2, a2, _ = synthesize_speech(text, l2)
return s1, a1, s2, a2
compare_btn.click(
fn=compare_languages,
inputs=[compare_text, lang1, lang2],
outputs=[status1, audio1, status2, audio2]
outputs=[status1, audio1, status2, audio2],
)
# Tab 3: Batch Processing
with gr.TabItem("📚 Batch Synthesis"):
gr.Markdown("Synthesize multiple texts at once (one per line).")
batch_input = gr.Textbox(
label="Texts (one per line)",
placeholder="Enter multiple texts, one per line...",
lines=6
lines=6,
)
batch_lang = gr.Dropdown(
choices=list(LANGUAGES.keys()),
value="English",
label="Language"
choices=list(LANGUAGES.keys()), value="English", label="Language"
)
batch_btn = gr.Button("Synthesize All", variant="primary")
batch_status = gr.Textbox(label="Status", interactive=False)
batch_audios = gr.Dataset(
components=[gr.Audio(type="numpy")],
label="Generated Audio Files"
components=[gr.Audio(type="numpy")], label="Generated Audio Files"
)
# Note: Batch processing would need more complex handling
# This is a simplified version
gr.Markdown("""
*Note: For batch processing of many texts, consider using the API directly
or the Kubeflow pipeline for better throughput.*
""")
create_footer()
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)