fix: remove unused imports and apply ruff formatting
- Remove unused imports: json (llm.py), tempfile (stt.py), base64 (tts.py) - Apply ruff format to all Python files
This commit is contained in:
@@ -9,6 +9,7 @@ Features:
|
|||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
- Visual embedding dimension display
|
- Visual embedding dimension display
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
@@ -28,7 +29,7 @@ logger = logging.getLogger("embeddings-demo")
|
|||||||
EMBEDDINGS_URL = os.environ.get(
|
EMBEDDINGS_URL = os.environ.get(
|
||||||
"EMBEDDINGS_URL",
|
"EMBEDDINGS_URL",
|
||||||
# Default: Ray Serve Embeddings endpoint
|
# Default: Ray Serve Embeddings endpoint
|
||||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings"
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
|
||||||
)
|
)
|
||||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
try:
|
try:
|
||||||
@@ -59,7 +60,9 @@ try:
|
|||||||
_mlflow_run_id = _mlflow_run.info.run_id
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
_mlflow_step = 0
|
_mlflow_step = 0
|
||||||
MLFLOW_ENABLED = True
|
MLFLOW_ENABLED = True
|
||||||
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("MLflow tracking disabled: %s", exc)
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
_mlflow_client = None
|
_mlflow_client = None
|
||||||
@@ -68,7 +71,9 @@ except Exception as exc:
|
|||||||
MLFLOW_ENABLED = False
|
MLFLOW_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int = 0) -> None:
|
def _log_embedding_metrics(
|
||||||
|
latency: float, batch_size: int, embedding_dims: int = 0
|
||||||
|
) -> None:
|
||||||
"""Log embedding inference metrics to MLflow (non-blocking best-effort)."""
|
"""Log embedding inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
global _mlflow_step
|
global _mlflow_step
|
||||||
if not MLFLOW_ENABLED or _mlflow_client is None:
|
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||||
@@ -81,8 +86,15 @@ def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int
|
|||||||
metrics=[
|
metrics=[
|
||||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step),
|
mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("embedding_dims", embedding_dims, ts, _mlflow_step),
|
mlflow.entities.Metric(
|
||||||
mlflow.entities.Metric("latency_per_text_ms", (latency * 1000 / batch_size) if batch_size > 0 else 0, ts, _mlflow_step),
|
"embedding_dims", embedding_dims, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"latency_per_text_ms",
|
||||||
|
(latency * 1000 / batch_size) if batch_size > 0 else 0,
|
||||||
|
ts,
|
||||||
|
_mlflow_step,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -98,8 +110,7 @@ def get_embeddings(texts: list[str]) -> tuple[list[list[float]], float]:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"{EMBEDDINGS_URL}/embeddings",
|
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
|
||||||
json={"input": texts, "model": "bge"}
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -228,7 +239,11 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
|
|||||||
embeddings, latency = get_embeddings(texts)
|
embeddings, latency = get_embeddings(texts)
|
||||||
|
|
||||||
# Log to MLflow
|
# Log to MLflow
|
||||||
_log_embedding_metrics(latency, batch_size=len(embeddings), embedding_dims=len(embeddings[0]) if embeddings else 0)
|
_log_embedding_metrics(
|
||||||
|
latency,
|
||||||
|
batch_size=len(embeddings),
|
||||||
|
embedding_dims=len(embeddings[0]) if embeddings else 0,
|
||||||
|
)
|
||||||
|
|
||||||
status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms"
|
status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms"
|
||||||
status += f" ({latency * 1000 / len(texts):.1f}ms per text)"
|
status += f" ({latency * 1000 / len(texts):.1f}ms per text)"
|
||||||
@@ -306,7 +321,7 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
single_input = gr.Textbox(
|
single_input = gr.Textbox(
|
||||||
label="Input Text",
|
label="Input Text",
|
||||||
placeholder="Enter text to generate embeddings...",
|
placeholder="Enter text to generate embeddings...",
|
||||||
lines=3
|
lines=3,
|
||||||
)
|
)
|
||||||
single_btn = gr.Button("Generate Embedding", variant="primary")
|
single_btn = gr.Button("Generate Embedding", variant="primary")
|
||||||
|
|
||||||
@@ -319,7 +334,7 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
single_btn.click(
|
single_btn.click(
|
||||||
fn=generate_single_embedding,
|
fn=generate_single_embedding,
|
||||||
inputs=single_input,
|
inputs=single_input,
|
||||||
outputs=[single_status, single_preview, single_stats]
|
outputs=[single_status, single_preview, single_stats],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 2: Compare Texts
|
# Tab 2: Compare Texts
|
||||||
@@ -339,14 +354,17 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
compare_btn.click(
|
compare_btn.click(
|
||||||
fn=compare_texts,
|
fn=compare_texts,
|
||||||
inputs=[compare_text1, compare_text2],
|
inputs=[compare_text1, compare_text2],
|
||||||
outputs=[compare_result, compare_visual]
|
outputs=[compare_result, compare_visual],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example pairs
|
# Example pairs
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
examples=[
|
examples=[
|
||||||
["The cat sat on the mat.", "A feline was resting on the rug."],
|
["The cat sat on the mat.", "A feline was resting on the rug."],
|
||||||
["Machine learning is a subset of AI.", "Deep learning uses neural networks."],
|
[
|
||||||
|
"Machine learning is a subset of AI.",
|
||||||
|
"Deep learning uses neural networks.",
|
||||||
|
],
|
||||||
["I love pizza.", "The stock market crashed today."],
|
["I love pizza.", "The stock market crashed today."],
|
||||||
],
|
],
|
||||||
inputs=[compare_text1, compare_text2],
|
inputs=[compare_text1, compare_text2],
|
||||||
@@ -354,21 +372,21 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
|
|
||||||
# Tab 3: Batch Embeddings
|
# Tab 3: Batch Embeddings
|
||||||
with gr.TabItem("📚 Batch Processing"):
|
with gr.TabItem("📚 Batch Processing"):
|
||||||
gr.Markdown("Generate embeddings for multiple texts and see their similarity matrix.")
|
gr.Markdown(
|
||||||
|
"Generate embeddings for multiple texts and see their similarity matrix."
|
||||||
|
)
|
||||||
|
|
||||||
batch_input = gr.Textbox(
|
batch_input = gr.Textbox(
|
||||||
label="Texts (one per line)",
|
label="Texts (one per line)",
|
||||||
placeholder="Enter multiple texts, one per line...",
|
placeholder="Enter multiple texts, one per line...",
|
||||||
lines=6
|
lines=6,
|
||||||
)
|
)
|
||||||
batch_btn = gr.Button("Process Batch", variant="primary")
|
batch_btn = gr.Button("Process Batch", variant="primary")
|
||||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||||
batch_result = gr.Markdown(label="Similarity Matrix")
|
batch_result = gr.Markdown(label="Similarity Matrix")
|
||||||
|
|
||||||
batch_btn.click(
|
batch_btn.click(
|
||||||
fn=batch_embed,
|
fn=batch_embed, inputs=batch_input, outputs=[batch_status, batch_result]
|
||||||
inputs=batch_input,
|
|
||||||
outputs=[batch_status, batch_result]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
@@ -383,8 +401,4 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
server_name="0.0.0.0",
|
|
||||||
server_port=7860,
|
|
||||||
show_error=True
|
|
||||||
)
|
|
||||||
|
|||||||
35
llm.py
35
llm.py
@@ -9,10 +9,10 @@ Features:
|
|||||||
- Token usage and latency metrics
|
- Token usage and latency metrics
|
||||||
- Chat history management
|
- Chat history management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import json
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import httpx
|
import httpx
|
||||||
@@ -65,7 +65,9 @@ try:
|
|||||||
_mlflow_run_id = _mlflow_run.info.run_id
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
_mlflow_step = 0
|
_mlflow_step = 0
|
||||||
MLFLOW_ENABLED = True
|
MLFLOW_ENABLED = True
|
||||||
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("MLflow tracking disabled: %s", exc)
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
_mlflow_client = None
|
_mlflow_client = None
|
||||||
@@ -95,18 +97,25 @@ def _log_llm_metrics(
|
|||||||
_mlflow_run_id,
|
_mlflow_run_id,
|
||||||
metrics=[
|
metrics=[
|
||||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("prompt_tokens", prompt_tokens, ts, _mlflow_step),
|
mlflow.entities.Metric(
|
||||||
mlflow.entities.Metric("completion_tokens", completion_tokens, ts, _mlflow_step),
|
"prompt_tokens", prompt_tokens, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"completion_tokens", completion_tokens, ts, _mlflow_step
|
||||||
|
),
|
||||||
mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
|
mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
|
mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
|
mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("max_tokens_requested", max_tokens, ts, _mlflow_step),
|
mlflow.entities.Metric(
|
||||||
|
"max_tokens_requested", max_tokens, ts, _mlflow_step
|
||||||
|
),
|
||||||
mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
|
mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("MLflow log failed", exc_info=True)
|
logger.debug("MLflow log failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_SYSTEM_PROMPT = (
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
|
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
|
||||||
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
|
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
|
||||||
@@ -273,10 +282,10 @@ def single_prompt(
|
|||||||
metrics = f"""
|
metrics = f"""
|
||||||
**Generation Metrics:**
|
**Generation Metrics:**
|
||||||
- Latency: {latency:.1f}s
|
- Latency: {latency:.1f}s
|
||||||
- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}
|
- Prompt tokens: {usage.get("prompt_tokens", "N/A")}
|
||||||
- Completion tokens: {usage.get('completion_tokens', 'N/A')}
|
- Completion tokens: {usage.get("completion_tokens", "N/A")}
|
||||||
- Total tokens: {usage.get('total_tokens', 'N/A')}
|
- Total tokens: {usage.get("total_tokens", "N/A")}
|
||||||
- Model: {result.get('model', 'N/A')}
|
- Model: {result.get("model", "N/A")}
|
||||||
"""
|
"""
|
||||||
return text, metrics
|
return text, metrics
|
||||||
|
|
||||||
@@ -360,9 +369,13 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
|||||||
|
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
examples=[
|
examples=[
|
||||||
["Summarise the key differences between CUDA and ROCm for ML workloads."],
|
[
|
||||||
|
"Summarise the key differences between CUDA and ROCm for ML workloads."
|
||||||
|
],
|
||||||
["Write a haiku about Kubernetes."],
|
["Write a haiku about Kubernetes."],
|
||||||
["Explain Ray Serve in one paragraph for someone new to ML serving."],
|
[
|
||||||
|
"Explain Ray Serve in one paragraph for someone new to ML serving."
|
||||||
|
],
|
||||||
["List 5 creative uses for a homelab GPU cluster."],
|
["List 5 creative uses for a homelab GPU cluster."],
|
||||||
],
|
],
|
||||||
inputs=[prompt_input],
|
inputs=[prompt_input],
|
||||||
|
|||||||
81
stt.py
81
stt.py
@@ -9,11 +9,11 @@ Features:
|
|||||||
- Translation mode
|
- Translation mode
|
||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import httpx
|
import httpx
|
||||||
@@ -30,11 +30,10 @@ logger = logging.getLogger("stt-demo")
|
|||||||
STT_URL = os.environ.get(
|
STT_URL = os.environ.get(
|
||||||
"STT_URL",
|
"STT_URL",
|
||||||
# Default: Ray Serve whisper endpoint
|
# Default: Ray Serve whisper endpoint
|
||||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper",
|
||||||
)
|
)
|
||||||
MLFLOW_TRACKING_URI = os.environ.get(
|
MLFLOW_TRACKING_URI = os.environ.get(
|
||||||
"MLFLOW_TRACKING_URI",
|
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
@@ -62,7 +61,9 @@ try:
|
|||||||
_mlflow_run_id = _mlflow_run.info.run_id
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
_mlflow_step = 0
|
_mlflow_step = 0
|
||||||
MLFLOW_ENABLED = True
|
MLFLOW_ENABLED = True
|
||||||
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("MLflow tracking disabled: %s", exc)
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
_mlflow_client = None
|
_mlflow_client = None
|
||||||
@@ -72,7 +73,10 @@ except Exception as exc:
|
|||||||
|
|
||||||
|
|
||||||
def _log_stt_metrics(
|
def _log_stt_metrics(
|
||||||
latency: float, audio_duration: float, word_count: int, task: str,
|
latency: float,
|
||||||
|
audio_duration: float,
|
||||||
|
word_count: int,
|
||||||
|
task: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
|
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
global _mlflow_step
|
global _mlflow_step
|
||||||
@@ -86,11 +90,15 @@ def _log_stt_metrics(
|
|||||||
_mlflow_run_id,
|
_mlflow_run_id,
|
||||||
metrics=[
|
metrics=[
|
||||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
|
mlflow.entities.Metric(
|
||||||
|
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||||
|
),
|
||||||
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
|
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
|
||||||
],
|
],
|
||||||
params=[] if _mlflow_step > 1 else [
|
params=[]
|
||||||
|
if _mlflow_step > 1
|
||||||
|
else [
|
||||||
mlflow.entities.Param("task", task),
|
mlflow.entities.Param("task", task),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -124,9 +132,7 @@ LANGUAGES = {
|
|||||||
|
|
||||||
|
|
||||||
def transcribe_audio(
|
def transcribe_audio(
|
||||||
audio_input: tuple[int, np.ndarray] | str | None,
|
audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
|
||||||
language: str,
|
|
||||||
task: str
|
|
||||||
) -> tuple[str, str, str]:
|
) -> tuple[str, str, str]:
|
||||||
"""Transcribe audio using the Whisper STT service."""
|
"""Transcribe audio using the Whisper STT service."""
|
||||||
if audio_input is None:
|
if audio_input is None:
|
||||||
@@ -142,12 +148,12 @@ def transcribe_audio(
|
|||||||
|
|
||||||
# Convert to WAV bytes
|
# Convert to WAV bytes
|
||||||
audio_buffer = io.BytesIO()
|
audio_buffer = io.BytesIO()
|
||||||
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
|
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
||||||
audio_bytes = audio_buffer.getvalue()
|
audio_bytes = audio_buffer.getvalue()
|
||||||
audio_duration = len(audio_data) / sample_rate
|
audio_duration = len(audio_data) / sample_rate
|
||||||
else:
|
else:
|
||||||
# File path
|
# File path
|
||||||
with open(audio_input, 'rb') as f:
|
with open(audio_input, "rb") as f:
|
||||||
audio_bytes = f.read()
|
audio_bytes = f.read()
|
||||||
# Get duration
|
# Get duration
|
||||||
audio_data, sample_rate = sf.read(audio_input)
|
audio_data, sample_rate = sf.read(audio_input)
|
||||||
@@ -187,7 +193,9 @@ def transcribe_audio(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Status message
|
# Status message
|
||||||
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
|
status = (
|
||||||
|
f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
|
||||||
|
)
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
metrics = f"""
|
metrics = f"""
|
||||||
@@ -250,21 +258,19 @@ or file upload with support for 100+ languages.
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
mic_input = gr.Audio(
|
mic_input = gr.Audio(
|
||||||
label="Record Audio",
|
label="Record Audio", sources=["microphone"], type="numpy"
|
||||||
sources=["microphone"],
|
|
||||||
type="numpy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
mic_language = gr.Dropdown(
|
mic_language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="Auto-detect",
|
value="Auto-detect",
|
||||||
label="Language"
|
label="Language",
|
||||||
)
|
)
|
||||||
mic_task = gr.Radio(
|
mic_task = gr.Radio(
|
||||||
choices=["Transcribe", "Translate to English"],
|
choices=["Transcribe", "Translate to English"],
|
||||||
value="Transcribe",
|
value="Transcribe",
|
||||||
label="Task"
|
label="Task",
|
||||||
)
|
)
|
||||||
|
|
||||||
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
|
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||||
@@ -273,15 +279,12 @@ or file upload with support for 100+ languages.
|
|||||||
mic_status = gr.Textbox(label="Status", interactive=False)
|
mic_status = gr.Textbox(label="Status", interactive=False)
|
||||||
mic_metrics = gr.Markdown(label="Metrics")
|
mic_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
mic_output = gr.Textbox(
|
mic_output = gr.Textbox(label="Transcription", lines=5)
|
||||||
label="Transcription",
|
|
||||||
lines=5
|
|
||||||
)
|
|
||||||
|
|
||||||
mic_btn.click(
|
mic_btn.click(
|
||||||
fn=transcribe_audio,
|
fn=transcribe_audio,
|
||||||
inputs=[mic_input, mic_language, mic_task],
|
inputs=[mic_input, mic_language, mic_task],
|
||||||
outputs=[mic_status, mic_output, mic_metrics]
|
outputs=[mic_status, mic_output, mic_metrics],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 2: File Upload
|
# Tab 2: File Upload
|
||||||
@@ -289,21 +292,19 @@ or file upload with support for 100+ languages.
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
file_input = gr.Audio(
|
file_input = gr.Audio(
|
||||||
label="Upload Audio File",
|
label="Upload Audio File", sources=["upload"], type="filepath"
|
||||||
sources=["upload"],
|
|
||||||
type="filepath"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
file_language = gr.Dropdown(
|
file_language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="Auto-detect",
|
value="Auto-detect",
|
||||||
label="Language"
|
label="Language",
|
||||||
)
|
)
|
||||||
file_task = gr.Radio(
|
file_task = gr.Radio(
|
||||||
choices=["Transcribe", "Translate to English"],
|
choices=["Transcribe", "Translate to English"],
|
||||||
value="Transcribe",
|
value="Transcribe",
|
||||||
label="Task"
|
label="Task",
|
||||||
)
|
)
|
||||||
|
|
||||||
file_btn = gr.Button("🎯 Transcribe", variant="primary")
|
file_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||||
@@ -312,15 +313,12 @@ or file upload with support for 100+ languages.
|
|||||||
file_status = gr.Textbox(label="Status", interactive=False)
|
file_status = gr.Textbox(label="Status", interactive=False)
|
||||||
file_metrics = gr.Markdown(label="Metrics")
|
file_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
file_output = gr.Textbox(
|
file_output = gr.Textbox(label="Transcription", lines=5)
|
||||||
label="Transcription",
|
|
||||||
lines=5
|
|
||||||
)
|
|
||||||
|
|
||||||
file_btn.click(
|
file_btn.click(
|
||||||
fn=transcribe_audio,
|
fn=transcribe_audio,
|
||||||
inputs=[file_input, file_language, file_task],
|
inputs=[file_input, file_language, file_task],
|
||||||
outputs=[file_status, file_output, file_metrics]
|
outputs=[file_status, file_output, file_metrics],
|
||||||
)
|
)
|
||||||
|
|
||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
@@ -343,7 +341,7 @@ Whisper will automatically detect the source language.
|
|||||||
trans_input = gr.Audio(
|
trans_input = gr.Audio(
|
||||||
label="Audio Input",
|
label="Audio Input",
|
||||||
sources=["microphone", "upload"],
|
sources=["microphone", "upload"],
|
||||||
type="numpy"
|
type="numpy",
|
||||||
)
|
)
|
||||||
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
|
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
|
||||||
|
|
||||||
@@ -351,10 +349,7 @@ Whisper will automatically detect the source language.
|
|||||||
trans_status = gr.Textbox(label="Status", interactive=False)
|
trans_status = gr.Textbox(label="Status", interactive=False)
|
||||||
trans_metrics = gr.Markdown(label="Metrics")
|
trans_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
trans_output = gr.Textbox(
|
trans_output = gr.Textbox(label="English Translation", lines=5)
|
||||||
label="English Translation",
|
|
||||||
lines=5
|
|
||||||
)
|
|
||||||
|
|
||||||
def translate_audio(audio):
|
def translate_audio(audio):
|
||||||
return transcribe_audio(audio, "Auto-detect", "Translate to English")
|
return transcribe_audio(audio, "Auto-detect", "Translate to English")
|
||||||
@@ -362,15 +357,11 @@ Whisper will automatically detect the source language.
|
|||||||
trans_btn.click(
|
trans_btn.click(
|
||||||
fn=translate_audio,
|
fn=translate_audio,
|
||||||
inputs=trans_input,
|
inputs=trans_input,
|
||||||
outputs=[trans_status, trans_output, trans_metrics]
|
outputs=[trans_status, trans_output, trans_metrics],
|
||||||
)
|
)
|
||||||
|
|
||||||
create_footer()
|
create_footer()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
server_name="0.0.0.0",
|
|
||||||
server_port=7860,
|
|
||||||
show_error=True
|
|
||||||
)
|
|
||||||
|
|||||||
8
theme.py
8
theme.py
@@ -3,6 +3,7 @@ Shared Gradio theme for Davies Tech Labs AI demos.
|
|||||||
Consistent styling across all demo applications.
|
Consistent styling across all demo applications.
|
||||||
Cyberpunk aesthetic - dark with yellow/gold accents.
|
Cyberpunk aesthetic - dark with yellow/gold accents.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
@@ -25,7 +26,12 @@ def get_lab_theme() -> gr.Theme:
|
|||||||
primary_hue=gr.themes.colors.yellow,
|
primary_hue=gr.themes.colors.yellow,
|
||||||
secondary_hue=gr.themes.colors.amber,
|
secondary_hue=gr.themes.colors.amber,
|
||||||
neutral_hue=gr.themes.colors.zinc,
|
neutral_hue=gr.themes.colors.zinc,
|
||||||
font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
|
font=[
|
||||||
|
gr.themes.GoogleFont("Space Grotesk"),
|
||||||
|
"ui-sans-serif",
|
||||||
|
"system-ui",
|
||||||
|
"sans-serif",
|
||||||
|
],
|
||||||
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
|
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
|
||||||
).set(
|
).set(
|
||||||
# Background colors
|
# Background colors
|
||||||
|
|||||||
81
tts.py
81
tts.py
@@ -9,11 +9,11 @@ Features:
|
|||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
import base64
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import httpx
|
import httpx
|
||||||
@@ -30,11 +30,10 @@ logger = logging.getLogger("tts-demo")
|
|||||||
TTS_URL = os.environ.get(
|
TTS_URL = os.environ.get(
|
||||||
"TTS_URL",
|
"TTS_URL",
|
||||||
# Default: Ray Serve TTS endpoint
|
# Default: Ray Serve TTS endpoint
|
||||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts"
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
|
||||||
)
|
)
|
||||||
MLFLOW_TRACKING_URI = os.environ.get(
|
MLFLOW_TRACKING_URI = os.environ.get(
|
||||||
"MLFLOW_TRACKING_URI",
|
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
@@ -62,7 +61,9 @@ try:
|
|||||||
_mlflow_run_id = _mlflow_run.info.run_id
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
_mlflow_step = 0
|
_mlflow_step = 0
|
||||||
MLFLOW_ENABLED = True
|
MLFLOW_ENABLED = True
|
||||||
logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("MLflow tracking disabled: %s", exc)
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
_mlflow_client = None
|
_mlflow_client = None
|
||||||
@@ -72,7 +73,10 @@ except Exception as exc:
|
|||||||
|
|
||||||
|
|
||||||
def _log_tts_metrics(
|
def _log_tts_metrics(
|
||||||
latency: float, audio_duration: float, text_chars: int, language: str,
|
latency: float,
|
||||||
|
audio_duration: float,
|
||||||
|
text_chars: int,
|
||||||
|
language: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Log TTS inference metrics to MLflow (non-blocking best-effort)."""
|
"""Log TTS inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
global _mlflow_step
|
global _mlflow_step
|
||||||
@@ -87,7 +91,9 @@ def _log_tts_metrics(
|
|||||||
_mlflow_run_id,
|
_mlflow_run_id,
|
||||||
metrics=[
|
metrics=[
|
||||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step),
|
mlflow.entities.Metric(
|
||||||
|
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||||
|
),
|
||||||
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
|
mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
|
||||||
mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
|
mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
|
||||||
@@ -121,7 +127,9 @@ LANGUAGES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
def synthesize_speech(
|
||||||
|
text: str, language: str
|
||||||
|
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
||||||
"""Synthesize speech from text using the TTS service."""
|
"""Synthesize speech from text using the TTS service."""
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
return "❌ Please enter some text", None, ""
|
return "❌ Please enter some text", None, ""
|
||||||
@@ -133,8 +141,7 @@ def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndar
|
|||||||
|
|
||||||
# Call TTS service (Coqui XTTS API format)
|
# Call TTS service (Coqui XTTS API format)
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"{TTS_URL}/api/tts",
|
f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code}
|
||||||
params={"text": text, "language_id": lang_code}
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -228,16 +235,18 @@ in multiple languages.
|
|||||||
label="Text to Synthesize",
|
label="Text to Synthesize",
|
||||||
placeholder="Enter text to convert to speech...",
|
placeholder="Enter text to convert to speech...",
|
||||||
lines=5,
|
lines=5,
|
||||||
max_lines=10
|
max_lines=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
language = gr.Dropdown(
|
language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="English",
|
value="English",
|
||||||
label="Language"
|
label="Language",
|
||||||
|
)
|
||||||
|
synthesize_btn = gr.Button(
|
||||||
|
"🔊 Synthesize", variant="primary", scale=2
|
||||||
)
|
)
|
||||||
synthesize_btn = gr.Button("🔊 Synthesize", variant="primary", scale=2)
|
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
status_output = gr.Textbox(label="Status", interactive=False)
|
status_output = gr.Textbox(label="Status", interactive=False)
|
||||||
@@ -248,15 +257,24 @@ in multiple languages.
|
|||||||
synthesize_btn.click(
|
synthesize_btn.click(
|
||||||
fn=synthesize_speech,
|
fn=synthesize_speech,
|
||||||
inputs=[text_input, language],
|
inputs=[text_input, language],
|
||||||
outputs=[status_output, audio_output, metrics_output]
|
outputs=[status_output, audio_output, metrics_output],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example texts
|
# Example texts
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
examples=[
|
examples=[
|
||||||
["Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "English"],
|
[
|
||||||
["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "English"],
|
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
|
||||||
["Bonjour! Bienvenue au laboratoire technique de Davies.", "French"],
|
"English",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
|
||||||
|
"English",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"Bonjour! Bienvenue au laboratoire technique de Davies.",
|
||||||
|
"French",
|
||||||
|
],
|
||||||
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
|
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
|
||||||
["Guten Tag! Willkommen im Techniklabor.", "German"],
|
["Guten Tag! Willkommen im Techniklabor.", "German"],
|
||||||
],
|
],
|
||||||
@@ -268,14 +286,16 @@ in multiple languages.
|
|||||||
gr.Markdown("Compare the same text in different languages.")
|
gr.Markdown("Compare the same text in different languages.")
|
||||||
|
|
||||||
compare_text = gr.Textbox(
|
compare_text = gr.Textbox(
|
||||||
label="Text to Compare",
|
label="Text to Compare", value="Hello, how are you today?", lines=2
|
||||||
value="Hello, how are you today?",
|
|
||||||
lines=2
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lang1 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Language 1")
|
lang1 = gr.Dropdown(
|
||||||
lang2 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2")
|
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
|
||||||
|
)
|
||||||
|
lang2 = gr.Dropdown(
|
||||||
|
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
|
||||||
|
)
|
||||||
|
|
||||||
compare_btn = gr.Button("Compare Languages", variant="primary")
|
compare_btn = gr.Button("Compare Languages", variant="primary")
|
||||||
|
|
||||||
@@ -298,7 +318,7 @@ in multiple languages.
|
|||||||
compare_btn.click(
|
compare_btn.click(
|
||||||
fn=compare_languages,
|
fn=compare_languages,
|
||||||
inputs=[compare_text, lang1, lang2],
|
inputs=[compare_text, lang1, lang2],
|
||||||
outputs=[status1, audio1, status2, audio2]
|
outputs=[status1, audio1, status2, audio2],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 3: Batch Processing
|
# Tab 3: Batch Processing
|
||||||
@@ -308,19 +328,16 @@ in multiple languages.
|
|||||||
batch_input = gr.Textbox(
|
batch_input = gr.Textbox(
|
||||||
label="Texts (one per line)",
|
label="Texts (one per line)",
|
||||||
placeholder="Enter multiple texts, one per line...",
|
placeholder="Enter multiple texts, one per line...",
|
||||||
lines=6
|
lines=6,
|
||||||
)
|
)
|
||||||
batch_lang = gr.Dropdown(
|
batch_lang = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()), value="English", label="Language"
|
||||||
value="English",
|
|
||||||
label="Language"
|
|
||||||
)
|
)
|
||||||
batch_btn = gr.Button("Synthesize All", variant="primary")
|
batch_btn = gr.Button("Synthesize All", variant="primary")
|
||||||
|
|
||||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||||
batch_audios = gr.Dataset(
|
batch_audios = gr.Dataset(
|
||||||
components=[gr.Audio(type="numpy")],
|
components=[gr.Audio(type="numpy")], label="Generated Audio Files"
|
||||||
label="Generated Audio Files"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: Batch processing would need more complex handling
|
# Note: Batch processing would need more complex handling
|
||||||
@@ -334,8 +351,4 @@ or the Kubeflow pipeline for better throughput.*
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
server_name="0.0.0.0",
|
|
||||||
server_port=7860,
|
|
||||||
show_error=True
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user