Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d1c40725e | |||
| dfe93ae856 | |||
| f5a2545ac8 |
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
__pycache__/
|
||||
@@ -8,3 +8,8 @@ resources:
|
||||
- llm.yaml
|
||||
- tts.yaml
|
||||
- stt.yaml
|
||||
|
||||
images:
|
||||
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
|
||||
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
|
||||
newTag: "0.0.7"
|
||||
|
||||
140
llm.py
140
llm.py
@@ -3,13 +3,14 @@
|
||||
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
||||
|
||||
Features:
|
||||
- Multi-turn chat with streaming responses
|
||||
- Multi-turn chat with true SSE streaming responses
|
||||
- Configurable temperature, max tokens, top-p
|
||||
- System prompt customisation
|
||||
- Token usage and latency metrics
|
||||
- Chat history management
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
@@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
|
||||
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
||||
|
||||
|
||||
def _extract_content(content) -> str:
|
||||
"""Extract plain text from message content.
|
||||
|
||||
Handles both plain strings and Gradio 6.x content-parts format:
|
||||
[{"type": "text", "text": "..."}] or [{"text": "..."}]
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(item.get("text", item.get("content", str(item))))
|
||||
elif isinstance(item, str):
|
||||
parts.append(item)
|
||||
else:
|
||||
parts.append(str(item))
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def chat_stream(
|
||||
message: str,
|
||||
history: list[dict[str, str]],
|
||||
@@ -135,18 +157,23 @@ async def chat_stream(
|
||||
max_tokens: int,
|
||||
top_p: float,
|
||||
):
|
||||
"""Stream chat responses from the vLLM endpoint."""
|
||||
"""Stream chat responses from the vLLM endpoint via SSE."""
|
||||
if not message.strip():
|
||||
yield ""
|
||||
return
|
||||
|
||||
# Build message list from history
|
||||
# Build message list from history, normalising content-parts
|
||||
messages = []
|
||||
if system_prompt.strip():
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
for entry in history:
|
||||
messages.append({"role": entry["role"], "content": entry["content"]})
|
||||
messages.append(
|
||||
{
|
||||
"role": entry["role"],
|
||||
"content": _extract_content(entry["content"]),
|
||||
}
|
||||
)
|
||||
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
@@ -155,45 +182,86 @@ async def chat_stream(
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await async_client.post(LLM_URL, json=payload)
|
||||
response.raise_for_status()
|
||||
# Try true SSE streaming first
|
||||
async with async_client.stream("POST", LLM_URL, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
result = response.json()
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
latency = time.time() - start_time
|
||||
usage = result.get("usage", {})
|
||||
if "text/event-stream" in content_type:
|
||||
# SSE streaming — accumulate deltas
|
||||
full_text = ""
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:]
|
||||
if data.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
delta = (
|
||||
chunk.get("choices", [{}])[0]
|
||||
.get("delta", {})
|
||||
.get("content", "")
|
||||
)
|
||||
if delta:
|
||||
full_text += delta
|
||||
yield full_text
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
||||
usage.get("total_tokens", 0),
|
||||
latency,
|
||||
usage.get("prompt_tokens", 0),
|
||||
usage.get("completion_tokens", 0),
|
||||
)
|
||||
latency = time.time() - start_time
|
||||
logger.info(
|
||||
"LLM streamed response: %d chars in %.1fs", len(full_text), latency
|
||||
)
|
||||
|
||||
# Log to MLflow
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
# Best-effort metrics from the final SSE payload
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=len(full_text.split()),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
else:
|
||||
# Non-streaming fallback (endpoint doesn't support stream)
|
||||
body = await response.aread()
|
||||
result = json.loads(body)
|
||||
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||
latency = time.time() - start_time
|
||||
usage = result.get("usage", {})
|
||||
|
||||
# Yield text progressively for a nicer streaming feel
|
||||
chunk_size = 4
|
||||
words = text.split(" ")
|
||||
partial = ""
|
||||
for i, word in enumerate(words):
|
||||
partial += ("" if i == 0 else " ") + word
|
||||
if i % chunk_size == 0 or i == len(words) - 1:
|
||||
yield partial
|
||||
logger.info(
|
||||
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
||||
usage.get("total_tokens", 0),
|
||||
latency,
|
||||
usage.get("prompt_tokens", 0),
|
||||
usage.get("completion_tokens", 0),
|
||||
)
|
||||
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
# Yield text progressively for a nicer feel
|
||||
chunk_size = 4
|
||||
words = text.split(" ")
|
||||
partial = ""
|
||||
for i, word in enumerate(words):
|
||||
partial += ("" if i == 0 else " ") + word
|
||||
if i % chunk_size == 0 or i == len(words) - 1:
|
||||
yield partial
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("LLM request failed")
|
||||
@@ -266,7 +334,7 @@ def single_prompt(
|
||||
result = response.json()
|
||||
latency = time.time() - start_time
|
||||
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||
usage = result.get("usage", {})
|
||||
|
||||
# Log to MLflow
|
||||
@@ -325,7 +393,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
||||
)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
||||
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
|
||||
max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
|
||||
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
||||
|
||||
with gr.Tabs():
|
||||
|
||||
2
stt.yaml
2
stt.yaml
@@ -28,7 +28,7 @@ spec:
|
||||
name: http
|
||||
protocol: TCP
|
||||
env:
|
||||
- name: WHISPER_URL
|
||||
- name: STT_URL
|
||||
# Ray Serve endpoint - routes to /whisper prefix
|
||||
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
||||
- name: MLFLOW_TRACKING_URI
|
||||
|
||||
343
tts.py
343
tts.py
@@ -5,19 +5,20 @@ TTS Demo - Gradio UI for testing Text-to-Speech service.
|
||||
Features:
|
||||
- Text input with language selection
|
||||
- Audio playback of synthesized speech
|
||||
- Voice/speaker selection (when available)
|
||||
- Sentence-level chunking for better quality
|
||||
- Speed control
|
||||
- MLflow metrics logging
|
||||
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
import io
|
||||
import wave
|
||||
|
||||
import gradio as gr
|
||||
import httpx
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
||||
@@ -126,42 +127,243 @@ LANGUAGES = {
|
||||
"Hungarian": "hu",
|
||||
}
|
||||
|
||||
# ─── Text preprocessing ─────────────────────────────────────────────────
|
||||
|
||||
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)
|
||||
|
||||
_DIGIT_WORDS = {
|
||||
"0": "zero",
|
||||
"1": "one",
|
||||
"2": "two",
|
||||
"3": "three",
|
||||
"4": "four",
|
||||
"5": "five",
|
||||
"6": "six",
|
||||
"7": "seven",
|
||||
"8": "eight",
|
||||
"9": "nine",
|
||||
}
|
||||
|
||||
|
||||
def _expand_numbers(text: str) -> str:
|
||||
"""Expand standalone single digits to words for clearer pronunciation."""
|
||||
return re.sub(
|
||||
r"\b(\d)\b",
|
||||
lambda m: _DIGIT_WORDS.get(m.group(0), m.group(0)),
|
||||
text,
|
||||
)
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
"""Clean and normalise text for TTS input."""
|
||||
text = re.sub(r"[ \t]+", " ", text)
|
||||
text = "\n".join(line.strip() for line in text.splitlines())
|
||||
# Strip markdown / code-fence characters
|
||||
text = re.sub(r"[*#~`|<>{}[\]\\]", "", text)
|
||||
# Expand common symbols
|
||||
text = text.replace("&", " and ")
|
||||
text = text.replace("@", " at ")
|
||||
text = text.replace("%", " percent ")
|
||||
text = text.replace("+", " plus ")
|
||||
text = text.replace("=", " equals ")
|
||||
text = _expand_numbers(text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _split_sentences(text: str) -> list[str]:
|
||||
"""Split text into sentences suitable for TTS.
|
||||
|
||||
Keeps sentences short for best quality while preserving natural phrasing.
|
||||
Very long segments are further split on commas / semicolons.
|
||||
"""
|
||||
text = _clean_text(text)
|
||||
if not text:
|
||||
return []
|
||||
|
||||
raw_parts = _SENTENCE_RE.split(text)
|
||||
sentences: list[str] = []
|
||||
for part in raw_parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
if len(part) > 200:
|
||||
for sp in re.split(r"(?<=[,;])\s+", part):
|
||||
sp = sp.strip()
|
||||
if sp:
|
||||
sentences.append(sp)
|
||||
else:
|
||||
sentences.append(part)
|
||||
return sentences
|
||||
|
||||
|
||||
# ─── Audio helpers ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
|
||||
"""Read WAV audio from bytes, handling scipy wavfile and standard WAV.
|
||||
|
||||
Returns (sample_rate, float32_audio) with values in [-1, 1].
|
||||
"""
|
||||
buf = io.BytesIO(data)
|
||||
|
||||
# Try stdlib wave module first — most robust for PCM WAV from scipy
|
||||
try:
|
||||
with wave.open(buf, "rb") as wf:
|
||||
sr = wf.getframerate()
|
||||
n_frames = wf.getnframes()
|
||||
n_channels = wf.getnchannels()
|
||||
sampwidth = wf.getsampwidth()
|
||||
raw = wf.readframes(n_frames)
|
||||
|
||||
if sampwidth == 2:
|
||||
audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
elif sampwidth == 4:
|
||||
audio = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||
elif sampwidth == 1:
|
||||
audio = (
|
||||
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
||||
) / 128.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported sample width: {sampwidth}")
|
||||
|
||||
if n_channels > 1:
|
||||
audio = audio.reshape(-1, n_channels).mean(axis=1)
|
||||
|
||||
return sr, audio
|
||||
except Exception as exc:
|
||||
logger.debug("wave module failed (%s), trying soundfile", exc)
|
||||
|
||||
# Fallback: soundfile (handles FLAC, OGG, etc.)
|
||||
buf.seek(0)
|
||||
try:
|
||||
import soundfile as sf
|
||||
|
||||
audio, sr = sf.read(buf, dtype="float32")
|
||||
if audio.ndim > 1:
|
||||
audio = audio.mean(axis=1)
|
||||
return sr, audio
|
||||
except Exception as exc:
|
||||
logger.debug("soundfile failed (%s), attempting raw PCM", exc)
|
||||
|
||||
# Last resort: raw 16-bit PCM at 22050 Hz
|
||||
logger.warning(
|
||||
"Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
|
||||
len(data),
|
||||
data[:4],
|
||||
)
|
||||
audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
return 22050, audio
|
||||
|
||||
|
||||
def _concat_audio(
|
||||
chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
|
||||
) -> tuple[int, np.ndarray]:
|
||||
"""Concatenate (sample_rate, audio) chunks with silence gaps."""
|
||||
if not chunks:
|
||||
return 22050, np.array([], dtype=np.float32)
|
||||
if len(chunks) == 1:
|
||||
return chunks[0]
|
||||
|
||||
sr = chunks[0][0]
|
||||
silence = np.zeros(int(sr * pause_ms / 1000), dtype=np.float32)
|
||||
|
||||
parts: list[np.ndarray] = []
|
||||
for sample_rate, audio in chunks:
|
||||
if sample_rate != sr:
|
||||
ratio = sr / sample_rate
|
||||
indices = np.arange(0, len(audio), 1.0 / ratio).astype(int)
|
||||
indices = indices[indices < len(audio)]
|
||||
audio = audio[indices]
|
||||
parts.append(audio)
|
||||
parts.append(silence)
|
||||
|
||||
if parts:
|
||||
parts.pop() # remove trailing silence
|
||||
return sr, np.concatenate(parts)
|
||||
|
||||
|
||||
# ─── TTS synthesis ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
|
||||
"""Synthesize a single text chunk via the TTS backend.
|
||||
|
||||
Uses the JSON POST endpoint (no URL length limits, supports speed).
|
||||
Falls back to the Coqui-compatible GET endpoint if POST fails.
|
||||
"""
|
||||
import base64 as b64
|
||||
|
||||
# Try JSON POST first
|
||||
try:
|
||||
resp = client.post(
|
||||
TTS_URL,
|
||||
json={
|
||||
"text": text,
|
||||
"language": lang_code,
|
||||
"speed": speed,
|
||||
"return_base64": True,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
ct = resp.headers.get("content-type", "")
|
||||
if "application/json" in ct:
|
||||
body = resp.json()
|
||||
if "error" in body:
|
||||
raise RuntimeError(body["error"])
|
||||
audio_b64 = body.get("audio", "")
|
||||
if audio_b64:
|
||||
return b64.b64decode(audio_b64)
|
||||
# Non-JSON response — treat as raw audio bytes
|
||||
return resp.content
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"POST endpoint failed, falling back to GET /api/tts", exc_info=True
|
||||
)
|
||||
|
||||
# Fallback: Coqui-compatible GET (no speed control)
|
||||
resp = client.get(
|
||||
f"{TTS_URL}/api/tts",
|
||||
params={"text": text, "language_id": lang_code},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
|
||||
def synthesize_speech(
|
||||
text: str, language: str
|
||||
text: str, language: str, speed: float
|
||||
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
||||
"""Synthesize speech from text using the TTS service."""
|
||||
"""Synthesize speech from text using the TTS service.
|
||||
|
||||
Long text is split into sentences and synthesized individually
|
||||
for better quality, then concatenated with natural pauses.
|
||||
"""
|
||||
if not text.strip():
|
||||
return "❌ Please enter some text", None, ""
|
||||
|
||||
lang_code = LANGUAGES.get(language, "en")
|
||||
sentences = _split_sentences(text)
|
||||
if not sentences:
|
||||
return "❌ No speakable text found after cleaning", None, ""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
audio_chunks: list[tuple[int, np.ndarray]] = []
|
||||
|
||||
# Call TTS service (Coqui XTTS API format)
|
||||
response = client.get(
|
||||
f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code}
|
||||
)
|
||||
response.raise_for_status()
|
||||
for sentence in sentences:
|
||||
raw_audio = _synthesize_chunk(sentence, lang_code, speed)
|
||||
sr, audio = _read_wav_bytes(raw_audio)
|
||||
audio_chunks.append((sr, audio))
|
||||
|
||||
sample_rate, audio_data = _concat_audio(audio_chunks)
|
||||
latency = time.time() - start_time
|
||||
audio_bytes = response.content
|
||||
duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
|
||||
|
||||
# Parse audio data
|
||||
audio_io = io.BytesIO(audio_bytes)
|
||||
audio_data, sample_rate = sf.read(audio_io)
|
||||
n_chunks = len(sentences)
|
||||
status = (
|
||||
f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
|
||||
f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
|
||||
)
|
||||
|
||||
# Calculate duration
|
||||
if len(audio_data.shape) == 1:
|
||||
duration = len(audio_data) / sample_rate
|
||||
else:
|
||||
duration = len(audio_data) / sample_rate
|
||||
|
||||
# Status message
|
||||
status = f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
|
||||
|
||||
# Log to MLflow
|
||||
_log_tts_metrics(
|
||||
latency=latency,
|
||||
audio_duration=duration,
|
||||
@@ -169,19 +371,19 @@ def synthesize_speech(
|
||||
language=lang_code,
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metrics = f"""
|
||||
**Audio Statistics:**
|
||||
- Duration: {duration:.2f} seconds
|
||||
- Sample Rate: {sample_rate} Hz
|
||||
- Size: {len(audio_bytes) / 1024:.1f} KB
|
||||
- Size: {len(audio_data) * 2 / 1024:.1f} KB
|
||||
- Generation Time: {latency * 1000:.0f}ms
|
||||
- Real-time Factor: {latency / duration:.2f}x
|
||||
- Language: {language} ({lang_code})
|
||||
- Speed: {speed:.1f}x
|
||||
- Sentences: {n_chunks}
|
||||
- Characters: {len(text)}
|
||||
- Chars/sec: {len(text) / latency:.1f}
|
||||
"""
|
||||
|
||||
return status, (sample_rate, audio_data), metrics
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -189,37 +391,33 @@ def synthesize_speech(
|
||||
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
||||
except Exception as e:
|
||||
logger.exception("TTS synthesis failed")
|
||||
return f"❌ Error: {str(e)}", None, ""
|
||||
return f"❌ Error: {e}", None, ""
|
||||
|
||||
|
||||
def check_service_health() -> str:
|
||||
"""Check if the TTS service is healthy."""
|
||||
try:
|
||||
# Try the health endpoint first
|
||||
response = client.get(f"{TTS_URL}/health", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is healthy"
|
||||
|
||||
# Fall back to root endpoint
|
||||
response = client.get(f"{TTS_URL}/", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is responding"
|
||||
|
||||
return f"🟡 Service returned status {response.status_code}"
|
||||
except Exception as e:
|
||||
return f"🔴 Service unavailable: {str(e)}"
|
||||
return f"🔴 Service unavailable: {e}"
|
||||
|
||||
|
||||
# Build the Gradio app
|
||||
# ─── Gradio UI ───────────────────────────────────────────────────────────
|
||||
|
||||
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
|
||||
gr.Markdown("""
|
||||
# 🔊 Text-to-Speech Demo
|
||||
|
||||
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
|
||||
in multiple languages.
|
||||
in multiple languages. Long text is automatically split into sentences for better quality.
|
||||
""")
|
||||
|
||||
# Service status
|
||||
with gr.Row():
|
||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||
@@ -227,7 +425,6 @@ in multiple languages.
|
||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||
|
||||
with gr.Tabs():
|
||||
# Tab 1: Basic TTS
|
||||
with gr.TabItem("🎤 Text to Speech"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
@@ -237,17 +434,24 @@ in multiple languages.
|
||||
lines=5,
|
||||
max_lines=10,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
language = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="English",
|
||||
label="Language",
|
||||
)
|
||||
synthesize_btn = gr.Button(
|
||||
"🔊 Synthesize", variant="primary", scale=2
|
||||
speed = gr.Slider(
|
||||
minimum=0.5,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
label="Speed",
|
||||
)
|
||||
synthesize_btn = gr.Button(
|
||||
"🔊 Synthesize",
|
||||
variant="primary",
|
||||
scale=2,
|
||||
)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
status_output = gr.Textbox(label="Status", interactive=False)
|
||||
metrics_output = gr.Markdown(label="Metrics")
|
||||
@@ -256,39 +460,39 @@ in multiple languages.
|
||||
|
||||
synthesize_btn.click(
|
||||
fn=synthesize_speech,
|
||||
inputs=[text_input, language],
|
||||
inputs=[text_input, language, speed],
|
||||
outputs=[status_output, audio_output, metrics_output],
|
||||
)
|
||||
|
||||
# Example texts
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[
|
||||
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
|
||||
"English",
|
||||
1.0,
|
||||
],
|
||||
[
|
||||
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
|
||||
"English",
|
||||
1.0,
|
||||
],
|
||||
[
|
||||
"Bonjour! Bienvenue au laboratoire technique de Davies.",
|
||||
"French",
|
||||
1.0,
|
||||
],
|
||||
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
|
||||
["Guten Tag! Willkommen im Techniklabor.", "German"],
|
||||
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
|
||||
["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
|
||||
],
|
||||
inputs=[text_input, language],
|
||||
inputs=[text_input, language, speed],
|
||||
)
|
||||
|
||||
# Tab 2: Comparison
|
||||
with gr.TabItem("🔄 Language Comparison"):
|
||||
gr.Markdown("Compare the same text in different languages.")
|
||||
|
||||
compare_text = gr.Textbox(
|
||||
label="Text to Compare", value="Hello, how are you today?", lines=2
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
lang1 = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
|
||||
@@ -296,6 +500,9 @@ in multiple languages.
|
||||
lang2 = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
|
||||
)
|
||||
compare_speed = gr.Slider(
|
||||
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||
)
|
||||
|
||||
compare_btn = gr.Button("Compare Languages", variant="primary")
|
||||
|
||||
@@ -304,24 +511,22 @@ in multiple languages.
|
||||
gr.Markdown("### Language 1")
|
||||
audio1 = gr.Audio(label="Audio 1", type="numpy")
|
||||
status1 = gr.Textbox(label="Status", interactive=False)
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("### Language 2")
|
||||
audio2 = gr.Audio(label="Audio 2", type="numpy")
|
||||
status2 = gr.Textbox(label="Status", interactive=False)
|
||||
|
||||
def compare_languages(text, l1, l2):
|
||||
s1, a1, _ = synthesize_speech(text, l1)
|
||||
s2, a2, _ = synthesize_speech(text, l2)
|
||||
def compare_languages(text, l1, l2, spd):
|
||||
s1, a1, _ = synthesize_speech(text, l1, spd)
|
||||
s2, a2, _ = synthesize_speech(text, l2, spd)
|
||||
return s1, a1, s2, a2
|
||||
|
||||
compare_btn.click(
|
||||
fn=compare_languages,
|
||||
inputs=[compare_text, lang1, lang2],
|
||||
inputs=[compare_text, lang1, lang2, compare_speed],
|
||||
outputs=[status1, audio1, status2, audio2],
|
||||
)
|
||||
|
||||
# Tab 3: Batch Processing
|
||||
with gr.TabItem("📚 Batch Synthesis"):
|
||||
gr.Markdown("Synthesize multiple texts at once (one per line).")
|
||||
|
||||
@@ -333,19 +538,31 @@ in multiple languages.
|
||||
batch_lang = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()), value="English", label="Language"
|
||||
)
|
||||
batch_speed = gr.Slider(
|
||||
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||
)
|
||||
batch_btn = gr.Button("Synthesize All", variant="primary")
|
||||
|
||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||
batch_audios = gr.Dataset(
|
||||
components=[gr.Audio(type="numpy")], label="Generated Audio Files"
|
||||
)
|
||||
batch_audio = gr.Audio(label="Combined Audio", type="numpy")
|
||||
|
||||
# Note: Batch processing would need more complex handling
|
||||
# This is a simplified version
|
||||
gr.Markdown("""
|
||||
*Note: For batch processing of many texts, consider using the API directly
|
||||
or the Kubeflow pipeline for better throughput.*
|
||||
""")
|
||||
def batch_synthesize(texts_raw: str, lang: str, spd: float):
|
||||
lines = [
|
||||
line.strip()
|
||||
for line in texts_raw.strip().splitlines()
|
||||
if line.strip()
|
||||
]
|
||||
if not lines:
|
||||
return "❌ Please enter at least one line of text", None
|
||||
combined = "\n".join(lines)
|
||||
status, audio, _ = synthesize_speech(combined, lang, spd)
|
||||
return status, audio
|
||||
|
||||
batch_btn.click(
|
||||
fn=batch_synthesize,
|
||||
inputs=[batch_input, batch_lang, batch_speed],
|
||||
outputs=[batch_status, batch_audio],
|
||||
)
|
||||
|
||||
create_footer()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user