Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d1c40725e | |||
| dfe93ae856 | |||
| f5a2545ac8 | |||
| c050d11ab4 |
@@ -83,26 +83,10 @@ jobs:
|
|||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
with:
|
|
||||||
buildkitd-config-inline: |
|
|
||||||
[registry."gitea-http.gitea.svc.cluster.local:3000"]
|
|
||||||
http = true
|
|
||||||
insecure = true
|
|
||||||
|
|
||||||
- name: Configure Docker for insecure registry
|
|
||||||
run: |
|
|
||||||
sudo mkdir -p /etc/docker
|
|
||||||
echo '{"insecure-registries": ["${{ env.REGISTRY_HOST }}"]}' | sudo tee /etc/docker/daemon.json
|
|
||||||
sudo systemctl restart docker || sudo service docker restart || true
|
|
||||||
sleep 2
|
|
||||||
|
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
if: vars.DOCKERHUB_USERNAME != ''
|
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ vars.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
- name: Login to Gitea Registry
|
- name: Login to Gitea Registry
|
||||||
@@ -112,6 +96,14 @@ jobs:
|
|||||||
username: ${{ secrets.REGISTRY_USER }}
|
username: ${{ secrets.REGISTRY_USER }}
|
||||||
password: ${{ secrets.REGISTRY_TOKEN }}
|
password: ${{ secrets.REGISTRY_TOKEN }}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
with:
|
||||||
|
buildkitd-config-inline: |
|
||||||
|
[registry."gitea-http.gitea.svc.cluster.local:3000"]
|
||||||
|
http = true
|
||||||
|
insecure = true
|
||||||
|
|
||||||
- name: Extract metadata
|
- name: Extract metadata
|
||||||
id: meta
|
id: meta
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
|
|||||||
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__pycache__/
|
||||||
@@ -8,3 +8,8 @@ resources:
|
|||||||
- llm.yaml
|
- llm.yaml
|
||||||
- tts.yaml
|
- tts.yaml
|
||||||
- stt.yaml
|
- stt.yaml
|
||||||
|
|
||||||
|
images:
|
||||||
|
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
|
||||||
|
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
|
||||||
|
newTag: "0.0.7"
|
||||||
|
|||||||
140
llm.py
140
llm.py
@@ -3,13 +3,14 @@
|
|||||||
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- Multi-turn chat with streaming responses
|
- Multi-turn chat with true SSE streaming responses
|
||||||
- Configurable temperature, max tokens, top-p
|
- Configurable temperature, max tokens, top-p
|
||||||
- System prompt customisation
|
- System prompt customisation
|
||||||
- Token usage and latency metrics
|
- Token usage and latency metrics
|
||||||
- Chat history management
|
- Chat history management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
@@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
|
|||||||
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content(content) -> str:
|
||||||
|
"""Extract plain text from message content.
|
||||||
|
|
||||||
|
Handles both plain strings and Gradio 6.x content-parts format:
|
||||||
|
[{"type": "text", "text": "..."}] or [{"text": "..."}]
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
parts.append(item.get("text", item.get("content", str(item))))
|
||||||
|
elif isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
else:
|
||||||
|
parts.append(str(item))
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
message: str,
|
message: str,
|
||||||
history: list[dict[str, str]],
|
history: list[dict[str, str]],
|
||||||
@@ -135,18 +157,23 @@ async def chat_stream(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
):
|
):
|
||||||
"""Stream chat responses from the vLLM endpoint."""
|
"""Stream chat responses from the vLLM endpoint via SSE."""
|
||||||
if not message.strip():
|
if not message.strip():
|
||||||
yield ""
|
yield ""
|
||||||
return
|
return
|
||||||
|
|
||||||
# Build message list from history
|
# Build message list from history, normalising content-parts
|
||||||
messages = []
|
messages = []
|
||||||
if system_prompt.strip():
|
if system_prompt.strip():
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
for entry in history:
|
for entry in history:
|
||||||
messages.append({"role": entry["role"], "content": entry["content"]})
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": entry["role"],
|
||||||
|
"content": _extract_content(entry["content"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
messages.append({"role": "user", "content": message})
|
messages.append({"role": "user", "content": message})
|
||||||
|
|
||||||
@@ -155,45 +182,86 @@ async def chat_stream(
|
|||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await async_client.post(LLM_URL, json=payload)
|
# Try true SSE streaming first
|
||||||
response.raise_for_status()
|
async with async_client.stream("POST", LLM_URL, json=payload) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
|
||||||
result = response.json()
|
if "text/event-stream" in content_type:
|
||||||
text = result["choices"][0]["message"]["content"]
|
# SSE streaming — accumulate deltas
|
||||||
latency = time.time() - start_time
|
full_text = ""
|
||||||
usage = result.get("usage", {})
|
async for line in response.aiter_lines():
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data = line[6:]
|
||||||
|
if data.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data)
|
||||||
|
delta = (
|
||||||
|
chunk.get("choices", [{}])[0]
|
||||||
|
.get("delta", {})
|
||||||
|
.get("content", "")
|
||||||
|
)
|
||||||
|
if delta:
|
||||||
|
full_text += delta
|
||||||
|
yield full_text
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info(
|
latency = time.time() - start_time
|
||||||
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
logger.info(
|
||||||
usage.get("total_tokens", 0),
|
"LLM streamed response: %d chars in %.1fs", len(full_text), latency
|
||||||
latency,
|
)
|
||||||
usage.get("prompt_tokens", 0),
|
|
||||||
usage.get("completion_tokens", 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log to MLflow
|
# Best-effort metrics from the final SSE payload
|
||||||
_log_llm_metrics(
|
_log_llm_metrics(
|
||||||
latency=latency,
|
latency=latency,
|
||||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
prompt_tokens=0,
|
||||||
completion_tokens=usage.get("completion_tokens", 0),
|
completion_tokens=len(full_text.split()),
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# Non-streaming fallback (endpoint doesn't support stream)
|
||||||
|
body = await response.aread()
|
||||||
|
result = json.loads(body)
|
||||||
|
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||||
|
latency = time.time() - start_time
|
||||||
|
usage = result.get("usage", {})
|
||||||
|
|
||||||
# Yield text progressively for a nicer streaming feel
|
logger.info(
|
||||||
chunk_size = 4
|
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
||||||
words = text.split(" ")
|
usage.get("total_tokens", 0),
|
||||||
partial = ""
|
latency,
|
||||||
for i, word in enumerate(words):
|
usage.get("prompt_tokens", 0),
|
||||||
partial += ("" if i == 0 else " ") + word
|
usage.get("completion_tokens", 0),
|
||||||
if i % chunk_size == 0 or i == len(words) - 1:
|
)
|
||||||
yield partial
|
|
||||||
|
_log_llm_metrics(
|
||||||
|
latency=latency,
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield text progressively for a nicer feel
|
||||||
|
chunk_size = 4
|
||||||
|
words = text.split(" ")
|
||||||
|
partial = ""
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
partial += ("" if i == 0 else " ") + word
|
||||||
|
if i % chunk_size == 0 or i == len(words) - 1:
|
||||||
|
yield partial
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.exception("LLM request failed")
|
logger.exception("LLM request failed")
|
||||||
@@ -266,7 +334,7 @@ def single_prompt(
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
|
|
||||||
text = result["choices"][0]["message"]["content"]
|
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||||
usage = result.get("usage", {})
|
usage = result.get("usage", {})
|
||||||
|
|
||||||
# Log to MLflow
|
# Log to MLflow
|
||||||
@@ -325,7 +393,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
||||||
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
|
max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
|
||||||
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
|
|||||||
2
stt.yaml
2
stt.yaml
@@ -28,7 +28,7 @@ spec:
|
|||||||
name: http
|
name: http
|
||||||
protocol: TCP
|
protocol: TCP
|
||||||
env:
|
env:
|
||||||
- name: WHISPER_URL
|
- name: STT_URL
|
||||||
# Ray Serve endpoint - routes to /whisper prefix
|
# Ray Serve endpoint - routes to /whisper prefix
|
||||||
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
||||||
- name: MLFLOW_TRACKING_URI
|
- name: MLFLOW_TRACKING_URI
|
||||||
|
|||||||
343
tts.py
343
tts.py
@@ -5,19 +5,20 @@ TTS Demo - Gradio UI for testing Text-to-Speech service.
|
|||||||
Features:
|
Features:
|
||||||
- Text input with language selection
|
- Text input with language selection
|
||||||
- Audio playback of synthesized speech
|
- Audio playback of synthesized speech
|
||||||
- Voice/speaker selection (when available)
|
- Sentence-level chunking for better quality
|
||||||
|
- Speed control
|
||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
|
import wave
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import httpx
|
import httpx
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
||||||
@@ -126,42 +127,243 @@ LANGUAGES = {
|
|||||||
"Hungarian": "hu",
|
"Hungarian": "hu",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ─── Text preprocessing ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)
|
||||||
|
|
||||||
|
_DIGIT_WORDS = {
|
||||||
|
"0": "zero",
|
||||||
|
"1": "one",
|
||||||
|
"2": "two",
|
||||||
|
"3": "three",
|
||||||
|
"4": "four",
|
||||||
|
"5": "five",
|
||||||
|
"6": "six",
|
||||||
|
"7": "seven",
|
||||||
|
"8": "eight",
|
||||||
|
"9": "nine",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_numbers(text: str) -> str:
|
||||||
|
"""Expand standalone single digits to words for clearer pronunciation."""
|
||||||
|
return re.sub(
|
||||||
|
r"\b(\d)\b",
|
||||||
|
lambda m: _DIGIT_WORDS.get(m.group(0), m.group(0)),
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_text(text: str) -> str:
|
||||||
|
"""Clean and normalise text for TTS input."""
|
||||||
|
text = re.sub(r"[ \t]+", " ", text)
|
||||||
|
text = "\n".join(line.strip() for line in text.splitlines())
|
||||||
|
# Strip markdown / code-fence characters
|
||||||
|
text = re.sub(r"[*#~`|<>{}[\]\\]", "", text)
|
||||||
|
# Expand common symbols
|
||||||
|
text = text.replace("&", " and ")
|
||||||
|
text = text.replace("@", " at ")
|
||||||
|
text = text.replace("%", " percent ")
|
||||||
|
text = text.replace("+", " plus ")
|
||||||
|
text = text.replace("=", " equals ")
|
||||||
|
text = _expand_numbers(text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _split_sentences(text: str) -> list[str]:
|
||||||
|
"""Split text into sentences suitable for TTS.
|
||||||
|
|
||||||
|
Keeps sentences short for best quality while preserving natural phrasing.
|
||||||
|
Very long segments are further split on commas / semicolons.
|
||||||
|
"""
|
||||||
|
text = _clean_text(text)
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
raw_parts = _SENTENCE_RE.split(text)
|
||||||
|
sentences: list[str] = []
|
||||||
|
for part in raw_parts:
|
||||||
|
part = part.strip()
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
if len(part) > 200:
|
||||||
|
for sp in re.split(r"(?<=[,;])\s+", part):
|
||||||
|
sp = sp.strip()
|
||||||
|
if sp:
|
||||||
|
sentences.append(sp)
|
||||||
|
else:
|
||||||
|
sentences.append(part)
|
||||||
|
return sentences
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Audio helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
|
||||||
|
"""Read WAV audio from bytes, handling scipy wavfile and standard WAV.
|
||||||
|
|
||||||
|
Returns (sample_rate, float32_audio) with values in [-1, 1].
|
||||||
|
"""
|
||||||
|
buf = io.BytesIO(data)
|
||||||
|
|
||||||
|
# Try stdlib wave module first — most robust for PCM WAV from scipy
|
||||||
|
try:
|
||||||
|
with wave.open(buf, "rb") as wf:
|
||||||
|
sr = wf.getframerate()
|
||||||
|
n_frames = wf.getnframes()
|
||||||
|
n_channels = wf.getnchannels()
|
||||||
|
sampwidth = wf.getsampwidth()
|
||||||
|
raw = wf.readframes(n_frames)
|
||||||
|
|
||||||
|
if sampwidth == 2:
|
||||||
|
audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
elif sampwidth == 4:
|
||||||
|
audio = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||||
|
elif sampwidth == 1:
|
||||||
|
audio = (
|
||||||
|
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
||||||
|
) / 128.0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported sample width: {sampwidth}")
|
||||||
|
|
||||||
|
if n_channels > 1:
|
||||||
|
audio = audio.reshape(-1, n_channels).mean(axis=1)
|
||||||
|
|
||||||
|
return sr, audio
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("wave module failed (%s), trying soundfile", exc)
|
||||||
|
|
||||||
|
# Fallback: soundfile (handles FLAC, OGG, etc.)
|
||||||
|
buf.seek(0)
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
audio, sr = sf.read(buf, dtype="float32")
|
||||||
|
if audio.ndim > 1:
|
||||||
|
audio = audio.mean(axis=1)
|
||||||
|
return sr, audio
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("soundfile failed (%s), attempting raw PCM", exc)
|
||||||
|
|
||||||
|
# Last resort: raw 16-bit PCM at 22050 Hz
|
||||||
|
logger.warning(
|
||||||
|
"Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
|
||||||
|
len(data),
|
||||||
|
data[:4],
|
||||||
|
)
|
||||||
|
audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
return 22050, audio
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_audio(
|
||||||
|
chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
|
||||||
|
) -> tuple[int, np.ndarray]:
|
||||||
|
"""Concatenate (sample_rate, audio) chunks with silence gaps."""
|
||||||
|
if not chunks:
|
||||||
|
return 22050, np.array([], dtype=np.float32)
|
||||||
|
if len(chunks) == 1:
|
||||||
|
return chunks[0]
|
||||||
|
|
||||||
|
sr = chunks[0][0]
|
||||||
|
silence = np.zeros(int(sr * pause_ms / 1000), dtype=np.float32)
|
||||||
|
|
||||||
|
parts: list[np.ndarray] = []
|
||||||
|
for sample_rate, audio in chunks:
|
||||||
|
if sample_rate != sr:
|
||||||
|
ratio = sr / sample_rate
|
||||||
|
indices = np.arange(0, len(audio), 1.0 / ratio).astype(int)
|
||||||
|
indices = indices[indices < len(audio)]
|
||||||
|
audio = audio[indices]
|
||||||
|
parts.append(audio)
|
||||||
|
parts.append(silence)
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
parts.pop() # remove trailing silence
|
||||||
|
return sr, np.concatenate(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── TTS synthesis ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
    """Synthesize a single text chunk via the TTS backend.

    Uses the JSON POST endpoint (no URL length limits, supports speed).
    Falls back to the Coqui-compatible GET endpoint if POST fails.

    Args:
        text: The chunk to speak.
        lang_code: Language code sent as ``language`` (POST) or
            ``language_id`` (GET).
        speed: Speed value sent with the POST request only.

    Returns:
        The audio payload as bytes: the decoded base64 ``audio`` field of a
        JSON response, or the raw response body otherwise.

    Raises:
        httpx.HTTPStatusError: If the GET fallback returns an error status.
    """
    import base64 as b64

    # Try JSON POST first
    try:
        resp = client.post(
            TTS_URL,
            json={
                "text": text,
                "language": lang_code,
                "speed": speed,
                "return_base64": True,
            },
        )
        resp.raise_for_status()
        if "application/json" in resp.headers.get("content-type", ""):
            body = resp.json()
            if "error" in body:
                raise RuntimeError(body["error"])
            audio_b64 = body.get("audio", "")
            if not audio_b64:
                # Never hand the JSON document itself back as "audio";
                # treat it as a failed POST so the GET fallback runs.
                raise RuntimeError("TTS JSON response contained no audio")
            return b64.b64decode(audio_b64)
        # Non-JSON response — treat as raw audio bytes
        return resp.content
    except Exception:  # deliberate best-effort: any POST failure -> GET
        logger.debug(
            "POST endpoint failed, falling back to GET /api/tts", exc_info=True
        )

    # Fallback: Coqui-compatible GET (no speed control)
    resp = client.get(
        f"{TTS_URL}/api/tts",
        params={"text": text, "language_id": lang_code},
    )
    resp.raise_for_status()
    return resp.content
|
||||||
|
|
||||||
|
|
||||||
def synthesize_speech(
|
def synthesize_speech(
|
||||||
text: str, language: str
|
text: str, language: str, speed: float
|
||||||
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
||||||
"""Synthesize speech from text using the TTS service."""
|
"""Synthesize speech from text using the TTS service.
|
||||||
|
|
||||||
|
Long text is split into sentences and synthesized individually
|
||||||
|
for better quality, then concatenated with natural pauses.
|
||||||
|
"""
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
return "❌ Please enter some text", None, ""
|
return "❌ Please enter some text", None, ""
|
||||||
|
|
||||||
lang_code = LANGUAGES.get(language, "en")
|
lang_code = LANGUAGES.get(language, "en")
|
||||||
|
sentences = _split_sentences(text)
|
||||||
|
if not sentences:
|
||||||
|
return "❌ No speakable text found after cleaning", None, ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
audio_chunks: list[tuple[int, np.ndarray]] = []
|
||||||
|
|
||||||
# Call TTS service (Coqui XTTS API format)
|
for sentence in sentences:
|
||||||
response = client.get(
|
raw_audio = _synthesize_chunk(sentence, lang_code, speed)
|
||||||
f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code}
|
sr, audio = _read_wav_bytes(raw_audio)
|
||||||
)
|
audio_chunks.append((sr, audio))
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
|
sample_rate, audio_data = _concat_audio(audio_chunks)
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
audio_bytes = response.content
|
duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
|
||||||
|
|
||||||
# Parse audio data
|
n_chunks = len(sentences)
|
||||||
audio_io = io.BytesIO(audio_bytes)
|
status = (
|
||||||
audio_data, sample_rate = sf.read(audio_io)
|
f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
|
||||||
|
f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate duration
|
|
||||||
if len(audio_data.shape) == 1:
|
|
||||||
duration = len(audio_data) / sample_rate
|
|
||||||
else:
|
|
||||||
duration = len(audio_data) / sample_rate
|
|
||||||
|
|
||||||
# Status message
|
|
||||||
status = f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
|
|
||||||
|
|
||||||
# Log to MLflow
|
|
||||||
_log_tts_metrics(
|
_log_tts_metrics(
|
||||||
latency=latency,
|
latency=latency,
|
||||||
audio_duration=duration,
|
audio_duration=duration,
|
||||||
@@ -169,19 +371,19 @@ def synthesize_speech(
|
|||||||
language=lang_code,
|
language=lang_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Metrics
|
|
||||||
metrics = f"""
|
metrics = f"""
|
||||||
**Audio Statistics:**
|
**Audio Statistics:**
|
||||||
- Duration: {duration:.2f} seconds
|
- Duration: {duration:.2f} seconds
|
||||||
- Sample Rate: {sample_rate} Hz
|
- Sample Rate: {sample_rate} Hz
|
||||||
- Size: {len(audio_bytes) / 1024:.1f} KB
|
- Size: {len(audio_data) * 2 / 1024:.1f} KB
|
||||||
- Generation Time: {latency * 1000:.0f}ms
|
- Generation Time: {latency * 1000:.0f}ms
|
||||||
- Real-time Factor: {latency / duration:.2f}x
|
- Real-time Factor: {latency / duration:.2f}x
|
||||||
- Language: {language} ({lang_code})
|
- Language: {language} ({lang_code})
|
||||||
|
- Speed: {speed:.1f}x
|
||||||
|
- Sentences: {n_chunks}
|
||||||
- Characters: {len(text)}
|
- Characters: {len(text)}
|
||||||
- Chars/sec: {len(text) / latency:.1f}
|
- Chars/sec: {len(text) / latency:.1f}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return status, (sample_rate, audio_data), metrics
|
return status, (sample_rate, audio_data), metrics
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
@@ -189,37 +391,33 @@ def synthesize_speech(
|
|||||||
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("TTS synthesis failed")
|
logger.exception("TTS synthesis failed")
|
||||||
return f"❌ Error: {str(e)}", None, ""
|
return f"❌ Error: {e}", None, ""
|
||||||
|
|
||||||
|
|
||||||
def check_service_health() -> str:
|
def check_service_health() -> str:
|
||||||
"""Check if the TTS service is healthy."""
|
"""Check if the TTS service is healthy."""
|
||||||
try:
|
try:
|
||||||
# Try the health endpoint first
|
|
||||||
response = client.get(f"{TTS_URL}/health", timeout=5.0)
|
response = client.get(f"{TTS_URL}/health", timeout=5.0)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return "🟢 Service is healthy"
|
return "🟢 Service is healthy"
|
||||||
|
|
||||||
# Fall back to root endpoint
|
|
||||||
response = client.get(f"{TTS_URL}/", timeout=5.0)
|
response = client.get(f"{TTS_URL}/", timeout=5.0)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return "🟢 Service is responding"
|
return "🟢 Service is responding"
|
||||||
|
|
||||||
return f"🟡 Service returned status {response.status_code}"
|
return f"🟡 Service returned status {response.status_code}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"🔴 Service unavailable: {str(e)}"
|
return f"🔴 Service unavailable: {e}"
|
||||||
|
|
||||||
|
|
||||||
# Build the Gradio app
|
# ─── Gradio UI ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
|
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
|
||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
# 🔊 Text-to-Speech Demo
|
# 🔊 Text-to-Speech Demo
|
||||||
|
|
||||||
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
|
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
|
||||||
in multiple languages.
|
in multiple languages. Long text is automatically split into sentences for better quality.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Service status
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||||
@@ -227,7 +425,6 @@ in multiple languages.
|
|||||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
# Tab 1: Basic TTS
|
|
||||||
with gr.TabItem("🎤 Text to Speech"):
|
with gr.TabItem("🎤 Text to Speech"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=2):
|
with gr.Column(scale=2):
|
||||||
@@ -237,17 +434,24 @@ in multiple languages.
|
|||||||
lines=5,
|
lines=5,
|
||||||
max_lines=10,
|
max_lines=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
language = gr.Dropdown(
|
language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="English",
|
value="English",
|
||||||
label="Language",
|
label="Language",
|
||||||
)
|
)
|
||||||
synthesize_btn = gr.Button(
|
speed = gr.Slider(
|
||||||
"🔊 Synthesize", variant="primary", scale=2
|
minimum=0.5,
|
||||||
|
maximum=2.0,
|
||||||
|
value=1.0,
|
||||||
|
step=0.1,
|
||||||
|
label="Speed",
|
||||||
|
)
|
||||||
|
synthesize_btn = gr.Button(
|
||||||
|
"🔊 Synthesize",
|
||||||
|
variant="primary",
|
||||||
|
scale=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
status_output = gr.Textbox(label="Status", interactive=False)
|
status_output = gr.Textbox(label="Status", interactive=False)
|
||||||
metrics_output = gr.Markdown(label="Metrics")
|
metrics_output = gr.Markdown(label="Metrics")
|
||||||
@@ -256,39 +460,39 @@ in multiple languages.
|
|||||||
|
|
||||||
synthesize_btn.click(
|
synthesize_btn.click(
|
||||||
fn=synthesize_speech,
|
fn=synthesize_speech,
|
||||||
inputs=[text_input, language],
|
inputs=[text_input, language, speed],
|
||||||
outputs=[status_output, audio_output, metrics_output],
|
outputs=[status_output, audio_output, metrics_output],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example texts
|
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
examples=[
|
examples=[
|
||||||
[
|
[
|
||||||
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
|
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
|
||||||
"English",
|
"English",
|
||||||
|
1.0,
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
|
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
|
||||||
"English",
|
"English",
|
||||||
|
1.0,
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
"Bonjour! Bienvenue au laboratoire technique de Davies.",
|
"Bonjour! Bienvenue au laboratoire technique de Davies.",
|
||||||
"French",
|
"French",
|
||||||
|
1.0,
|
||||||
],
|
],
|
||||||
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
|
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
|
||||||
["Guten Tag! Willkommen im Techniklabor.", "German"],
|
["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
|
||||||
],
|
],
|
||||||
inputs=[text_input, language],
|
inputs=[text_input, language, speed],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 2: Comparison
|
|
||||||
with gr.TabItem("🔄 Language Comparison"):
|
with gr.TabItem("🔄 Language Comparison"):
|
||||||
gr.Markdown("Compare the same text in different languages.")
|
gr.Markdown("Compare the same text in different languages.")
|
||||||
|
|
||||||
compare_text = gr.Textbox(
|
compare_text = gr.Textbox(
|
||||||
label="Text to Compare", value="Hello, how are you today?", lines=2
|
label="Text to Compare", value="Hello, how are you today?", lines=2
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lang1 = gr.Dropdown(
|
lang1 = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
|
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
|
||||||
@@ -296,6 +500,9 @@ in multiple languages.
|
|||||||
lang2 = gr.Dropdown(
|
lang2 = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
|
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
|
||||||
)
|
)
|
||||||
|
compare_speed = gr.Slider(
|
||||||
|
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||||
|
)
|
||||||
|
|
||||||
compare_btn = gr.Button("Compare Languages", variant="primary")
|
compare_btn = gr.Button("Compare Languages", variant="primary")
|
||||||
|
|
||||||
@@ -304,24 +511,22 @@ in multiple languages.
|
|||||||
gr.Markdown("### Language 1")
|
gr.Markdown("### Language 1")
|
||||||
audio1 = gr.Audio(label="Audio 1", type="numpy")
|
audio1 = gr.Audio(label="Audio 1", type="numpy")
|
||||||
status1 = gr.Textbox(label="Status", interactive=False)
|
status1 = gr.Textbox(label="Status", interactive=False)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("### Language 2")
|
gr.Markdown("### Language 2")
|
||||||
audio2 = gr.Audio(label="Audio 2", type="numpy")
|
audio2 = gr.Audio(label="Audio 2", type="numpy")
|
||||||
status2 = gr.Textbox(label="Status", interactive=False)
|
status2 = gr.Textbox(label="Status", interactive=False)
|
||||||
|
|
||||||
def compare_languages(text, l1, l2):
|
def compare_languages(text, l1, l2, spd):
|
||||||
s1, a1, _ = synthesize_speech(text, l1)
|
s1, a1, _ = synthesize_speech(text, l1, spd)
|
||||||
s2, a2, _ = synthesize_speech(text, l2)
|
s2, a2, _ = synthesize_speech(text, l2, spd)
|
||||||
return s1, a1, s2, a2
|
return s1, a1, s2, a2
|
||||||
|
|
||||||
compare_btn.click(
|
compare_btn.click(
|
||||||
fn=compare_languages,
|
fn=compare_languages,
|
||||||
inputs=[compare_text, lang1, lang2],
|
inputs=[compare_text, lang1, lang2, compare_speed],
|
||||||
outputs=[status1, audio1, status2, audio2],
|
outputs=[status1, audio1, status2, audio2],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 3: Batch Processing
|
|
||||||
with gr.TabItem("📚 Batch Synthesis"):
|
with gr.TabItem("📚 Batch Synthesis"):
|
||||||
gr.Markdown("Synthesize multiple texts at once (one per line).")
|
gr.Markdown("Synthesize multiple texts at once (one per line).")
|
||||||
|
|
||||||
@@ -333,19 +538,31 @@ in multiple languages.
|
|||||||
batch_lang = gr.Dropdown(
|
batch_lang = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()), value="English", label="Language"
|
choices=list(LANGUAGES.keys()), value="English", label="Language"
|
||||||
)
|
)
|
||||||
|
batch_speed = gr.Slider(
|
||||||
|
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||||
|
)
|
||||||
batch_btn = gr.Button("Synthesize All", variant="primary")
|
batch_btn = gr.Button("Synthesize All", variant="primary")
|
||||||
|
|
||||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||||
batch_audios = gr.Dataset(
|
batch_audio = gr.Audio(label="Combined Audio", type="numpy")
|
||||||
components=[gr.Audio(type="numpy")], label="Generated Audio Files"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: Batch processing would need more complex handling
|
def batch_synthesize(texts_raw: str, lang: str, spd: float):
|
||||||
# This is a simplified version
|
lines = [
|
||||||
gr.Markdown("""
|
line.strip()
|
||||||
*Note: For batch processing of many texts, consider using the API directly
|
for line in texts_raw.strip().splitlines()
|
||||||
or the Kubeflow pipeline for better throughput.*
|
if line.strip()
|
||||||
""")
|
]
|
||||||
|
if not lines:
|
||||||
|
return "❌ Please enter at least one line of text", None
|
||||||
|
combined = "\n".join(lines)
|
||||||
|
status, audio, _ = synthesize_speech(combined, lang, spd)
|
||||||
|
return status, audio
|
||||||
|
|
||||||
|
batch_btn.click(
|
||||||
|
fn=batch_synthesize,
|
||||||
|
inputs=[batch_input, batch_lang, batch_speed],
|
||||||
|
outputs=[batch_status, batch_audio],
|
||||||
|
)
|
||||||
|
|
||||||
create_footer()
|
create_footer()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user