3 Commits

Author SHA1 Message Date
0d1c40725e style: fix ruff lint and formatting issues
All checks were successful
CI / Docker Build & Push (push) Successful in 5m38s
CI / Deploy to Kubernetes (push) Successful in 1m21s
CI / Notify (push) Successful in 1s
CI / Lint (push) Successful in 1m4s
CI / Release (push) Successful in 54s
- tts.py: rename ambiguous variable 'l' to 'line' (E741)
- tts.py, llm.py: apply ruff formatter
2026-02-22 10:55:00 -05:00
dfe93ae856 fix: stt.yaml env var WHISPER_URL→STT_URL + tts.py improvements
Some checks failed
CI / Lint (push) Failing after 1m38s
CI / Deploy to Kubernetes (push) Has been cancelled
CI / Notify (push) Has been cancelled
CI / Release (push) Has been skipped
CI / Docker Build & Push (push) Has been cancelled
- stt.yaml: rename WHISPER_URL to STT_URL to match what stt.py reads
- tts.py: improve WAV handling (BytesIO fix), sentence splitting, robust
  _read_wav_bytes with wave+soundfile+raw-PCM fallbacks
- Add __pycache__/ to .gitignore
2026-02-22 10:47:10 -05:00
f5a2545ac8 llm streaming outputs, bumped up images.
Some checks failed
CI / Lint (push) Failing after 1m35s
CI / Release (push) Has been skipped
CI / Docker Build & Push (push) Has been skipped
CI / Deploy to Kubernetes (push) Has been skipped
CI / Notify (push) Successful in 1s
2026-02-20 16:53:37 -05:00
5 changed files with 391 additions and 100 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
__pycache__/

View File

@@ -8,3 +8,8 @@ resources:
- llm.yaml - llm.yaml
- tts.yaml - tts.yaml
- stt.yaml - stt.yaml
images:
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
newTag: "0.0.7"

140
llm.py
View File

@@ -3,13 +3,14 @@
LLM Chat Demo - Gradio UI for testing vLLM inference service. LLM Chat Demo - Gradio UI for testing vLLM inference service.
Features: Features:
- Multi-turn chat with streaming responses - Multi-turn chat with true SSE streaming responses
- Configurable temperature, max tokens, top-p - Configurable temperature, max tokens, top-p
- System prompt customisation - System prompt customisation
- Token usage and latency metrics - Token usage and latency metrics
- Chat history management - Chat history management
""" """
import json
import os import os
import time import time
import logging import logging
@@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
def _extract_content(content) -> str:
    """Extract plain text from message content.

    Handles both plain strings and the Gradio 6.x content-parts format:
    ``[{"type": "text", "text": "..."}]`` or ``[{"text": "..."}]``.

    Args:
        content: A string, a list of content parts (dicts and/or strings),
            ``None``, or any other object.

    Returns:
        The concatenated text. ``None`` (at any nesting level) contributes an
        empty string; a dict part without ``text``/``content`` keys and any
        other unrecognised object are rendered with ``str()``.
    """
    if content is None:
        # A missing body must not surface as the literal text "None".
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict):
                if "text" in item:
                    value = item["text"]
                elif "content" in item:
                    value = item["content"]
                else:
                    value = str(item)
                # Recurse so non-string or nested values can never reach
                # "".join() and raise TypeError.
                parts.append(_extract_content(value))
            else:
                parts.append(_extract_content(item))
        return "".join(parts)
    return str(content)
async def chat_stream( async def chat_stream(
message: str, message: str,
history: list[dict[str, str]], history: list[dict[str, str]],
@@ -135,18 +157,23 @@ async def chat_stream(
max_tokens: int, max_tokens: int,
top_p: float, top_p: float,
): ):
"""Stream chat responses from the vLLM endpoint.""" """Stream chat responses from the vLLM endpoint via SSE."""
if not message.strip(): if not message.strip():
yield "" yield ""
return return
# Build message list from history # Build message list from history, normalising content-parts
messages = [] messages = []
if system_prompt.strip(): if system_prompt.strip():
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
for entry in history: for entry in history:
messages.append({"role": entry["role"], "content": entry["content"]}) messages.append(
{
"role": entry["role"],
"content": _extract_content(entry["content"]),
}
)
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
@@ -155,45 +182,86 @@ async def chat_stream(
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"top_p": top_p, "top_p": top_p,
"stream": True,
} }
start_time = time.time() start_time = time.time()
try: try:
response = await async_client.post(LLM_URL, json=payload) # Try true SSE streaming first
response.raise_for_status() async with async_client.stream("POST", LLM_URL, json=payload) as response:
response.raise_for_status()
content_type = response.headers.get("content-type", "")
result = response.json() if "text/event-stream" in content_type:
text = result["choices"][0]["message"]["content"] # SSE streaming — accumulate deltas
latency = time.time() - start_time full_text = ""
usage = result.get("usage", {}) async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:]
if data.strip() == "[DONE]":
break
try:
chunk = json.loads(data)
delta = (
chunk.get("choices", [{}])[0]
.get("delta", {})
.get("content", "")
)
if delta:
full_text += delta
yield full_text
except json.JSONDecodeError:
continue
logger.info( latency = time.time() - start_time
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)", logger.info(
usage.get("total_tokens", 0), "LLM streamed response: %d chars in %.1fs", len(full_text), latency
latency, )
usage.get("prompt_tokens", 0),
usage.get("completion_tokens", 0),
)
# Log to MLflow # Best-effort metrics from the final SSE payload
_log_llm_metrics( _log_llm_metrics(
latency=latency, latency=latency,
prompt_tokens=usage.get("prompt_tokens", 0), prompt_tokens=0,
completion_tokens=usage.get("completion_tokens", 0), completion_tokens=len(full_text.split()),
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
) )
else:
# Non-streaming fallback (endpoint doesn't support stream)
body = await response.aread()
result = json.loads(body)
text = _extract_content(result["choices"][0]["message"]["content"])
latency = time.time() - start_time
usage = result.get("usage", {})
# Yield text progressively for a nicer streaming feel logger.info(
chunk_size = 4 "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
words = text.split(" ") usage.get("total_tokens", 0),
partial = "" latency,
for i, word in enumerate(words): usage.get("prompt_tokens", 0),
partial += ("" if i == 0 else " ") + word usage.get("completion_tokens", 0),
if i % chunk_size == 0 or i == len(words) - 1: )
yield partial
_log_llm_metrics(
latency=latency,
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
# Yield text progressively for a nicer feel
chunk_size = 4
words = text.split(" ")
partial = ""
for i, word in enumerate(words):
partial += ("" if i == 0 else " ") + word
if i % chunk_size == 0 or i == len(words) - 1:
yield partial
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.exception("LLM request failed") logger.exception("LLM request failed")
@@ -266,7 +334,7 @@ def single_prompt(
result = response.json() result = response.json()
latency = time.time() - start_time latency = time.time() - start_time
text = result["choices"][0]["message"]["content"] text = _extract_content(result["choices"][0]["message"]["content"])
usage = result.get("usage", {}) usage = result.get("usage", {})
# Log to MLflow # Log to MLflow
@@ -325,7 +393,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
) )
with gr.Row(): with gr.Row():
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature") temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens") max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p") top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
with gr.Tabs(): with gr.Tabs():

View File

@@ -28,7 +28,7 @@ spec:
name: http name: http
protocol: TCP protocol: TCP
env: env:
- name: WHISPER_URL - name: STT_URL
# Ray Serve endpoint - routes to /whisper prefix # Ray Serve endpoint - routes to /whisper prefix
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper" value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
- name: MLFLOW_TRACKING_URI - name: MLFLOW_TRACKING_URI

343
tts.py
View File

@@ -5,19 +5,20 @@ TTS Demo - Gradio UI for testing Text-to-Speech service.
Features: Features:
- Text input with language selection - Text input with language selection
- Audio playback of synthesized speech - Audio playback of synthesized speech
- Voice/speaker selection (when available) - Sentence-level chunking for better quality
- Speed control
- MLflow metrics logging - MLflow metrics logging
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
""" """
import os import os
import re
import time import time
import logging import logging
import io import io
import wave
import gradio as gr import gradio as gr
import httpx import httpx
import soundfile as sf
import numpy as np import numpy as np
from theme import get_lab_theme, CUSTOM_CSS, create_footer from theme import get_lab_theme, CUSTOM_CSS, create_footer
@@ -126,42 +127,243 @@ LANGUAGES = {
"Hungarian": "hu", "Hungarian": "hu",
} }
# ─── Text preprocessing ─────────────────────────────────────────────────
# Sentence boundary used by re.split(): either a whitespace run that follows
# one of . ! ? ; or a (possibly empty) whitespace run directly after a newline.
# NOTE: the pattern contains no ^ or $ anchors, so re.MULTILINE does not change
# what it matches.
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)
# English spoken forms for the ten ASCII digits.
_DIGIT_WORDS = {
    "0": "zero",
    "1": "one",
    "2": "two",
    "3": "three",
    "4": "four",
    "5": "five",
    "6": "six",
    "7": "seven",
    "8": "eight",
    "9": "nine",
}

# A lone digit that is NOT part of a larger numeric token: not preceded by
# "<digit><sep>" and not followed by "<sep><digit>", where <sep> is one of
# . , : (decimals, thousands groups, clock times).
_STANDALONE_DIGIT_RE = re.compile(r"(?<!\d[.,:])\b(\d)\b(?![.,:]\d)")


def _expand_numbers(text: str) -> str:
    """Expand standalone single digits to words for clearer pronunciation.

    Only a digit that stands alone as a token is replaced ("3 cats" ->
    "three cats"). Digits that belong to a decimal ("3.5"), a grouped number
    ("1,000") or a time ("5:30") are left untouched so they are not mangled
    into forms such as "three.five". Multi-digit numbers and digits glued to
    letters ("a1") are never touched.

    Args:
        text: Input text.

    Returns:
        The text with qualifying digits replaced by their English words.
    """
    return _STANDALONE_DIGIT_RE.sub(
        # Fall back to the matched text for digits with no table entry
        # (\d also matches non-ASCII decimal digits).
        lambda m: _DIGIT_WORDS.get(m.group(0), m.group(0)),
        text,
    )
# Symbols that are spoken as words; padded so they never fuse with neighbours.
_SYMBOL_WORDS = str.maketrans(
    {
        "&": " and ",
        "@": " at ",
        "%": " percent ",
        "+": " plus ",
        "=": " equals ",
    }
)


def _clean_text(text: str) -> str:
    """Clean and normalise text for TTS input.

    Steps, in order: strip markdown / code-fence characters, expand common
    symbols to words, expand standalone digits, then collapse runs of spaces
    and tabs and trim every line. Whitespace is normalised last because the
    earlier steps insert padding (or remove characters between spaces) and
    would otherwise leave doubled spaces behind.

    Args:
        text: Raw user text; newlines are preserved as line separators.

    Returns:
        The cleaned text with leading/trailing whitespace removed.
    """
    # Strip markdown / code-fence characters
    text = re.sub(r"[*#~`|<>{}[\]\\]", "", text)
    # Expand common symbols in a single pass
    text = text.translate(_SYMBOL_WORDS)
    text = _expand_numbers(text)
    # Normalise horizontal whitespace only; newlines still mark line breaks.
    text = re.sub(r"[ \t]+", " ", text)
    text = "\n".join(line.strip() for line in text.splitlines())
    return text.strip()
def _split_sentences(text: str, max_len: int = 200) -> list[str]:
    """Split text into sentences suitable for TTS.

    The text is cleaned with ``_clean_text`` and split on ``_SENTENCE_RE``.
    A sentence longer than ``max_len`` characters is further divided at
    commas / semicolons, and the resulting clauses are packed greedily into
    chunks of at most ``max_len`` characters. Packing avoids emitting a
    separate chunk for every short clause of a long sentence.

    Args:
        text: Raw input text.
        max_len: Target maximum chunk length in characters. A single clause
            that contains no comma/semicolon and exceeds this length is
            emitted unchanged.

    Returns:
        Non-empty, stripped chunks in reading order; ``[]`` when nothing
        remains after cleaning.
    """
    text = _clean_text(text)
    if not text:
        return []

    sentences: list[str] = []
    for part in _SENTENCE_RE.split(text):
        part = part.strip()
        if not part:
            continue
        if len(part) <= max_len:
            sentences.append(part)
            continue

        # Over-long sentence: re-assemble its clauses into chunks <= max_len.
        current = ""
        for clause in re.split(r"(?<=[,;])\s+", part):
            clause = clause.strip()
            if not clause:
                continue
            candidate = f"{current} {clause}" if current else clause
            if current and len(candidate) > max_len:
                sentences.append(current)
                current = clause
            else:
                current = candidate
        if current:
            sentences.append(current)
    return sentences
# ─── Audio helpers ───────────────────────────────────────────────────────
def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
    """Decode audio bytes into mono float32 samples.

    Decoders are tried in order:

    1. stdlib ``wave`` for PCM WAV with 8-, 16-, 24- or 32-bit samples;
    2. ``soundfile`` (imported lazily) for anything ``wave`` rejects;
    3. as a last resort the bytes are interpreted as headerless 16-bit
       little-endian PCM at 22050 Hz.

    Multi-channel audio is down-mixed by averaging the channels.

    Args:
        data: Encoded audio as returned by the TTS backend.

    Returns:
        ``(sample_rate, audio)`` where ``audio`` is a 1-D float32 array.
        Samples decoded via ``wave`` or the raw-PCM fallback are scaled from
        integer PCM into [-1, 1).
    """
    buf = io.BytesIO(data)

    # Try stdlib wave module first
    try:
        with wave.open(buf, "rb") as wf:
            sr = wf.getframerate()
            n_channels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            raw = wf.readframes(wf.getnframes())

        # WAV PCM is little-endian; spell the byte order out explicitly so
        # decoding does not depend on the host's native endianness.
        if sampwidth == 2:
            audio = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
        elif sampwidth == 4:
            audio = np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
        elif sampwidth == 3:
            # NumPy has no 24-bit dtype: assemble from bytes, then sign-extend.
            b = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
            ints = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
            ints = (ints ^ 0x800000) - 0x800000
            audio = ints.astype(np.float32) / 8388608.0
        elif sampwidth == 1:
            # 8-bit WAV is unsigned with a 128 offset.
            audio = (
                np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
            ) / 128.0
        else:
            raise ValueError(f"Unsupported sample width: {sampwidth}")

        if n_channels > 1:
            audio = audio.reshape(-1, n_channels).mean(axis=1)
        return sr, audio
    except Exception as exc:
        logger.debug("wave module failed (%s), trying soundfile", exc)

    # Fallback: soundfile (handles FLAC, OGG, etc.)
    buf.seek(0)
    try:
        import soundfile as sf

        audio, sr = sf.read(buf, dtype="float32")
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        return sr, audio
    except Exception as exc:
        logger.debug("soundfile failed (%s), attempting raw PCM", exc)

    # Last resort: raw 16-bit PCM at 22050 Hz
    logger.warning(
        "Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
        len(data),
        data[:4],
    )
    # Drop a dangling odd byte: frombuffer raises when the buffer size is not
    # a multiple of the 2-byte element size.
    usable = len(data) - (len(data) % 2)
    audio = np.frombuffer(data[:usable], dtype="<i2").astype(np.float32) / 32768.0
    return 22050, audio
def _concat_audio(
    chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
) -> tuple[int, np.ndarray]:
    """Concatenate (sample_rate, audio) chunks with silence gaps.

    The first chunk's sample rate is the output rate. A chunk with a
    different rate is resampled to it by linear interpolation. ``pause_ms``
    of silence is inserted between consecutive chunks, not after the last.

    Args:
        chunks: Mono audio chunks as ``(sample_rate, samples)`` pairs.
        pause_ms: Gap between chunks in milliseconds; negative values are
            treated as zero.

    Returns:
        ``(sample_rate, samples)``. An empty input yields
        ``(22050, empty float32 array)``; a single chunk is returned as-is.
    """
    if not chunks:
        return 22050, np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]

    sr = chunks[0][0]
    silence = np.zeros(max(int(sr * pause_ms / 1000), 0), dtype=np.float32)

    parts: list[np.ndarray] = []
    for sample_rate, audio in chunks:
        # Guard against a zero rate (division) and empty audio (np.interp
        # rejects an empty sample-point array).
        if sample_rate != sr and sample_rate > 0 and len(audio) > 0:
            n_out = int(round(len(audio) * sr / sample_rate))
            # Output sample i sits at source position i * (src_rate / dst_rate).
            positions = np.arange(n_out) * (sample_rate / sr)
            audio = np.interp(positions, np.arange(len(audio)), audio).astype(
                np.float32
            )
        parts.append(audio)
        parts.append(silence)
    parts.pop()  # remove trailing silence
    return sr, np.concatenate(parts)
# ─── TTS synthesis ───────────────────────────────────────────────────────
def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
    """Synthesize a single text chunk via the TTS backend.

    First POSTs a JSON request to ``TTS_URL`` (the text travels in the body
    and ``speed`` is included). If that attempt fails for any reason — HTTP
    error, an ``error`` field in the JSON reply, a JSON reply without audio,
    undecodable base64 — the failure is logged at debug level and the
    Coqui-compatible ``GET {TTS_URL}/api/tts`` endpoint is used instead,
    which is not sent the ``speed`` value.

    Args:
        text: Text to speak.
        lang_code: Language code passed to the backend.
        speed: Playback speed factor, sent only with the POST request.

    Returns:
        Encoded audio bytes as produced by the backend.

    Raises:
        httpx.HTTPStatusError: If the fallback GET returns an error status.
    """
    import base64 as b64

    # Try JSON POST first
    try:
        resp = client.post(
            TTS_URL,
            json={
                "text": text,
                "language": lang_code,
                "speed": speed,
                "return_base64": True,
            },
        )
        resp.raise_for_status()

        ct = resp.headers.get("content-type", "")
        if "application/json" not in ct:
            # Non-JSON response — treat as raw audio bytes
            return resp.content

        body = resp.json()
        if "error" in body:
            raise RuntimeError(body["error"])
        audio_b64 = body.get("audio", "")
        if not audio_b64:
            # Never hand the JSON document itself back as if it were audio.
            raise RuntimeError("TTS JSON response contained no audio payload")
        return b64.b64decode(audio_b64)
    except Exception:
        # Deliberately broad: any POST-path failure triggers the GET fallback.
        logger.debug(
            "POST endpoint failed, falling back to GET /api/tts", exc_info=True
        )

    # Fallback: Coqui-compatible GET (no speed control)
    resp = client.get(
        f"{TTS_URL}/api/tts",
        params={"text": text, "language_id": lang_code},
    )
    resp.raise_for_status()
    return resp.content
def synthesize_speech( def synthesize_speech(
text: str, language: str text: str, language: str, speed: float
) -> tuple[str, tuple[int, np.ndarray] | None, str]: ) -> tuple[str, tuple[int, np.ndarray] | None, str]:
"""Synthesize speech from text using the TTS service.""" """Synthesize speech from text using the TTS service.
Long text is split into sentences and synthesized individually
for better quality, then concatenated with natural pauses.
"""
if not text.strip(): if not text.strip():
return "❌ Please enter some text", None, "" return "❌ Please enter some text", None, ""
lang_code = LANGUAGES.get(language, "en") lang_code = LANGUAGES.get(language, "en")
sentences = _split_sentences(text)
if not sentences:
return "❌ No speakable text found after cleaning", None, ""
try: try:
start_time = time.time() start_time = time.time()
audio_chunks: list[tuple[int, np.ndarray]] = []
# Call TTS service (Coqui XTTS API format) for sentence in sentences:
response = client.get( raw_audio = _synthesize_chunk(sentence, lang_code, speed)
f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code} sr, audio = _read_wav_bytes(raw_audio)
) audio_chunks.append((sr, audio))
response.raise_for_status()
sample_rate, audio_data = _concat_audio(audio_chunks)
latency = time.time() - start_time latency = time.time() - start_time
audio_bytes = response.content duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
# Parse audio data n_chunks = len(sentences)
audio_io = io.BytesIO(audio_bytes) status = (
audio_data, sample_rate = sf.read(audio_io) f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
)
# Calculate duration
if len(audio_data.shape) == 1:
duration = len(audio_data) / sample_rate
else:
duration = len(audio_data) / sample_rate
# Status message
status = f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
# Log to MLflow
_log_tts_metrics( _log_tts_metrics(
latency=latency, latency=latency,
audio_duration=duration, audio_duration=duration,
@@ -169,19 +371,19 @@ def synthesize_speech(
language=lang_code, language=lang_code,
) )
# Metrics
metrics = f""" metrics = f"""
**Audio Statistics:** **Audio Statistics:**
- Duration: {duration:.2f} seconds - Duration: {duration:.2f} seconds
- Sample Rate: {sample_rate} Hz - Sample Rate: {sample_rate} Hz
- Size: {len(audio_bytes) / 1024:.1f} KB - Size: {len(audio_data) * 2 / 1024:.1f} KB
- Generation Time: {latency * 1000:.0f}ms - Generation Time: {latency * 1000:.0f}ms
- Real-time Factor: {latency / duration:.2f}x - Real-time Factor: {latency / duration:.2f}x
- Language: {language} ({lang_code}) - Language: {language} ({lang_code})
- Speed: {speed:.1f}x
- Sentences: {n_chunks}
- Characters: {len(text)} - Characters: {len(text)}
- Chars/sec: {len(text) / latency:.1f} - Chars/sec: {len(text) / latency:.1f}
""" """
return status, (sample_rate, audio_data), metrics return status, (sample_rate, audio_data), metrics
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
@@ -189,37 +391,33 @@ def synthesize_speech(
return f"❌ TTS service error: {e.response.status_code}", None, "" return f"❌ TTS service error: {e.response.status_code}", None, ""
except Exception as e: except Exception as e:
logger.exception("TTS synthesis failed") logger.exception("TTS synthesis failed")
return f"❌ Error: {str(e)}", None, "" return f"❌ Error: {e}", None, ""
def check_service_health() -> str: def check_service_health() -> str:
"""Check if the TTS service is healthy.""" """Check if the TTS service is healthy."""
try: try:
# Try the health endpoint first
response = client.get(f"{TTS_URL}/health", timeout=5.0) response = client.get(f"{TTS_URL}/health", timeout=5.0)
if response.status_code == 200: if response.status_code == 200:
return "🟢 Service is healthy" return "🟢 Service is healthy"
# Fall back to root endpoint
response = client.get(f"{TTS_URL}/", timeout=5.0) response = client.get(f"{TTS_URL}/", timeout=5.0)
if response.status_code == 200: if response.status_code == 200:
return "🟢 Service is responding" return "🟢 Service is responding"
return f"🟡 Service returned status {response.status_code}" return f"🟡 Service returned status {response.status_code}"
except Exception as e: except Exception as e:
return f"🔴 Service unavailable: {str(e)}" return f"🔴 Service unavailable: {e}"
# Build the Gradio app # ─── Gradio UI ───────────────────────────────────────────────────────────
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo: with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
gr.Markdown(""" gr.Markdown("""
# 🔊 Text-to-Speech Demo # 🔊 Text-to-Speech Demo
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
in multiple languages. in multiple languages. Long text is automatically split into sentences for better quality.
""") """)
# Service status
with gr.Row(): with gr.Row():
health_btn = gr.Button("🔄 Check Service", size="sm") health_btn = gr.Button("🔄 Check Service", size="sm")
health_status = gr.Textbox(label="Service Status", interactive=False) health_status = gr.Textbox(label="Service Status", interactive=False)
@@ -227,7 +425,6 @@ in multiple languages.
health_btn.click(fn=check_service_health, outputs=health_status) health_btn.click(fn=check_service_health, outputs=health_status)
with gr.Tabs(): with gr.Tabs():
# Tab 1: Basic TTS
with gr.TabItem("🎤 Text to Speech"): with gr.TabItem("🎤 Text to Speech"):
with gr.Row(): with gr.Row():
with gr.Column(scale=2): with gr.Column(scale=2):
@@ -237,17 +434,24 @@ in multiple languages.
lines=5, lines=5,
max_lines=10, max_lines=10,
) )
with gr.Row(): with gr.Row():
language = gr.Dropdown( language = gr.Dropdown(
choices=list(LANGUAGES.keys()), choices=list(LANGUAGES.keys()),
value="English", value="English",
label="Language", label="Language",
) )
synthesize_btn = gr.Button( speed = gr.Slider(
"🔊 Synthesize", variant="primary", scale=2 minimum=0.5,
maximum=2.0,
value=1.0,
step=0.1,
label="Speed",
)
synthesize_btn = gr.Button(
"🔊 Synthesize",
variant="primary",
scale=2,
) )
with gr.Column(scale=1): with gr.Column(scale=1):
status_output = gr.Textbox(label="Status", interactive=False) status_output = gr.Textbox(label="Status", interactive=False)
metrics_output = gr.Markdown(label="Metrics") metrics_output = gr.Markdown(label="Metrics")
@@ -256,39 +460,39 @@ in multiple languages.
synthesize_btn.click( synthesize_btn.click(
fn=synthesize_speech, fn=synthesize_speech,
inputs=[text_input, language], inputs=[text_input, language, speed],
outputs=[status_output, audio_output, metrics_output], outputs=[status_output, audio_output, metrics_output],
) )
# Example texts
gr.Examples( gr.Examples(
examples=[ examples=[
[ [
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
"English", "English",
1.0,
], ],
[ [
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
"English", "English",
1.0,
], ],
[ [
"Bonjour! Bienvenue au laboratoire technique de Davies.", "Bonjour! Bienvenue au laboratoire technique de Davies.",
"French", "French",
1.0,
], ],
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"], ["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
["Guten Tag! Willkommen im Techniklabor.", "German"], ["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
], ],
inputs=[text_input, language], inputs=[text_input, language, speed],
) )
# Tab 2: Comparison
with gr.TabItem("🔄 Language Comparison"): with gr.TabItem("🔄 Language Comparison"):
gr.Markdown("Compare the same text in different languages.") gr.Markdown("Compare the same text in different languages.")
compare_text = gr.Textbox( compare_text = gr.Textbox(
label="Text to Compare", value="Hello, how are you today?", lines=2 label="Text to Compare", value="Hello, how are you today?", lines=2
) )
with gr.Row(): with gr.Row():
lang1 = gr.Dropdown( lang1 = gr.Dropdown(
choices=list(LANGUAGES.keys()), value="English", label="Language 1" choices=list(LANGUAGES.keys()), value="English", label="Language 1"
@@ -296,6 +500,9 @@ in multiple languages.
lang2 = gr.Dropdown( lang2 = gr.Dropdown(
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2" choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
) )
compare_speed = gr.Slider(
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
)
compare_btn = gr.Button("Compare Languages", variant="primary") compare_btn = gr.Button("Compare Languages", variant="primary")
@@ -304,24 +511,22 @@ in multiple languages.
gr.Markdown("### Language 1") gr.Markdown("### Language 1")
audio1 = gr.Audio(label="Audio 1", type="numpy") audio1 = gr.Audio(label="Audio 1", type="numpy")
status1 = gr.Textbox(label="Status", interactive=False) status1 = gr.Textbox(label="Status", interactive=False)
with gr.Column(): with gr.Column():
gr.Markdown("### Language 2") gr.Markdown("### Language 2")
audio2 = gr.Audio(label="Audio 2", type="numpy") audio2 = gr.Audio(label="Audio 2", type="numpy")
status2 = gr.Textbox(label="Status", interactive=False) status2 = gr.Textbox(label="Status", interactive=False)
def compare_languages(text, l1, l2): def compare_languages(text, l1, l2, spd):
s1, a1, _ = synthesize_speech(text, l1) s1, a1, _ = synthesize_speech(text, l1, spd)
s2, a2, _ = synthesize_speech(text, l2) s2, a2, _ = synthesize_speech(text, l2, spd)
return s1, a1, s2, a2 return s1, a1, s2, a2
compare_btn.click( compare_btn.click(
fn=compare_languages, fn=compare_languages,
inputs=[compare_text, lang1, lang2], inputs=[compare_text, lang1, lang2, compare_speed],
outputs=[status1, audio1, status2, audio2], outputs=[status1, audio1, status2, audio2],
) )
# Tab 3: Batch Processing
with gr.TabItem("📚 Batch Synthesis"): with gr.TabItem("📚 Batch Synthesis"):
gr.Markdown("Synthesize multiple texts at once (one per line).") gr.Markdown("Synthesize multiple texts at once (one per line).")
@@ -333,19 +538,31 @@ in multiple languages.
batch_lang = gr.Dropdown( batch_lang = gr.Dropdown(
choices=list(LANGUAGES.keys()), value="English", label="Language" choices=list(LANGUAGES.keys()), value="English", label="Language"
) )
batch_speed = gr.Slider(
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
)
batch_btn = gr.Button("Synthesize All", variant="primary") batch_btn = gr.Button("Synthesize All", variant="primary")
batch_status = gr.Textbox(label="Status", interactive=False) batch_status = gr.Textbox(label="Status", interactive=False)
batch_audios = gr.Dataset( batch_audio = gr.Audio(label="Combined Audio", type="numpy")
components=[gr.Audio(type="numpy")], label="Generated Audio Files"
)
# Note: Batch processing would need more complex handling def batch_synthesize(texts_raw: str, lang: str, spd: float):
# This is a simplified version lines = [
gr.Markdown(""" line.strip()
*Note: For batch processing of many texts, consider using the API directly for line in texts_raw.strip().splitlines()
or the Kubeflow pipeline for better throughput.* if line.strip()
""") ]
if not lines:
return "❌ Please enter at least one line of text", None
combined = "\n".join(lines)
status, audio, _ = synthesize_speech(combined, lang, spd)
return status, audio
batch_btn.click(
fn=batch_synthesize,
inputs=[batch_input, batch_lang, batch_speed],
outputs=[batch_status, batch_audio],
)
create_footer() create_footer()