Compare commits
18 Commits
861f11e22b
...
v0.0.7
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d1c40725e | |||
| dfe93ae856 | |||
| f5a2545ac8 | |||
| c050d11ab4 | |||
| 454a1c7cf6 | |||
| 71321e5878 | |||
| 1385736556 | |||
| 9faad8be6b | |||
| faa5dc0d9d | |||
| 0cc03aa145 | |||
| 12bdcab180 | |||
| 4069647495 | |||
| 53afea9352 | |||
| 58319b66ee | |||
| 1c5dc7f751 | |||
| b2d2252342 | |||
| 72681217ef | |||
| af67984737 |
224
.gitea/workflows/ci.yml
Normal file
224
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
env:
|
||||||
|
NTFY_URL: http://ntfy.observability.svc.cluster.local:80
|
||||||
|
REGISTRY: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs
|
||||||
|
REGISTRY_HOST: gitea-http.gitea.svc.cluster.local:3000
|
||||||
|
IMAGE_NAME: gradio-ui
|
||||||
|
KUSTOMIZE_NAMESPACE: ai-ml
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up uv
|
||||||
|
run: curl -LsSf https://astral.sh/uv/install.sh | sh && echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
|
- name: Run ruff check
|
||||||
|
run: uvx ruff check .
|
||||||
|
|
||||||
|
- name: Run ruff format check
|
||||||
|
run: uvx ruff format --check .
|
||||||
|
|
||||||
|
release:
|
||||||
|
name: Release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [lint]
|
||||||
|
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.version.outputs.version }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Determine version bump
|
||||||
|
id: version
|
||||||
|
run: |
|
||||||
|
LATEST=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||||
|
VERSION=${LATEST#v}
|
||||||
|
IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION"
|
||||||
|
|
||||||
|
MSG="${{ gitea.event.head_commit.message }}"
|
||||||
|
if echo "$MSG" | grep -qiE "^major:|BREAKING CHANGE"; then
|
||||||
|
MAJOR=$((MAJOR + 1)); MINOR=0; PATCH=0
|
||||||
|
BUMP="major"
|
||||||
|
elif echo "$MSG" | grep -qiE "^(minor:|feat:)"; then
|
||||||
|
MINOR=$((MINOR + 1)); PATCH=0
|
||||||
|
BUMP="minor"
|
||||||
|
else
|
||||||
|
PATCH=$((PATCH + 1))
|
||||||
|
BUMP="patch"
|
||||||
|
fi
|
||||||
|
|
||||||
|
NEW_VERSION="v${MAJOR}.${MINOR}.${PATCH}"
|
||||||
|
echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||||
|
echo "bump=$BUMP" >> $GITHUB_OUTPUT
|
||||||
|
echo "Bumping $LATEST → $NEW_VERSION ($BUMP)"
|
||||||
|
|
||||||
|
- name: Create and push tag
|
||||||
|
run: |
|
||||||
|
git config user.name "gitea-actions[bot]"
|
||||||
|
git config user.email "actions@git.daviestechlabs.io"
|
||||||
|
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
|
||||||
|
git push origin ${{ steps.version.outputs.version }}
|
||||||
|
|
||||||
|
docker:
|
||||||
|
name: Docker Build & Push
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [lint, release]
|
||||||
|
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Login to Gitea Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY_HOST }}
|
||||||
|
username: ${{ secrets.REGISTRY_USER }}
|
||||||
|
password: ${{ secrets.REGISTRY_TOKEN }}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
with:
|
||||||
|
buildkitd-config-inline: |
|
||||||
|
[registry."gitea-http.gitea.svc.cluster.local:3000"]
|
||||||
|
http = true
|
||||||
|
insecure = true
|
||||||
|
|
||||||
|
- name: Extract metadata
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||||
|
tags: |
|
||||||
|
type=semver,pattern={{version}},value=${{ needs.release.outputs.version }}
|
||||||
|
type=semver,pattern={{major}}.{{minor}},value=${{ needs.release.outputs.version }}
|
||||||
|
type=raw,value=latest,enable={{is_default_branch}}
|
||||||
|
|
||||||
|
- name: Build and push
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
cache-from: type=gha
|
||||||
|
cache-to: type=gha,mode=max
|
||||||
|
|
||||||
|
deploy:
|
||||||
|
name: Deploy to Kubernetes
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [docker, release]
|
||||||
|
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||||
|
container:
|
||||||
|
image: catthehacker/ubuntu:act-latest
|
||||||
|
volumes:
|
||||||
|
- /secrets/kubeconfig:/secrets/kubeconfig
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install kubectl
|
||||||
|
run: |
|
||||||
|
curl -LO "https://dl.k8s.io/release/$(curl -Ls https://dl.k8s.io/release/stable.txt)/bin/linux/amd64/kubectl"
|
||||||
|
chmod +x kubectl && sudo mv kubectl /usr/local/bin/
|
||||||
|
|
||||||
|
- name: Update image tag in manifests
|
||||||
|
env:
|
||||||
|
KUBECONFIG: /secrets/kubeconfig/config
|
||||||
|
run: |
|
||||||
|
VERSION="${{ needs.release.outputs.version }}"
|
||||||
|
VERSION="${VERSION#v}"
|
||||||
|
for DEPLOY in llm embeddings stt tts; do
|
||||||
|
sed -i "s|image: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:.*|image: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${VERSION}|" "${DEPLOY}.yaml"
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Apply kustomization
|
||||||
|
env:
|
||||||
|
KUBECONFIG: /secrets/kubeconfig/config
|
||||||
|
run: |
|
||||||
|
kubectl apply -k . --namespace ${{ env.KUSTOMIZE_NAMESPACE }}
|
||||||
|
|
||||||
|
- name: Rollout restart deployments
|
||||||
|
env:
|
||||||
|
KUBECONFIG: /secrets/kubeconfig/config
|
||||||
|
run: |
|
||||||
|
for DEPLOY in llm-ui embeddings-ui stt-ui tts-ui; do
|
||||||
|
kubectl rollout restart deployment/${DEPLOY} -n ${{ env.KUSTOMIZE_NAMESPACE }} 2>/dev/null || true
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Wait for rollout
|
||||||
|
env:
|
||||||
|
KUBECONFIG: /secrets/kubeconfig/config
|
||||||
|
run: |
|
||||||
|
for DEPLOY in llm-ui embeddings-ui stt-ui tts-ui; do
|
||||||
|
kubectl rollout status deployment/${DEPLOY} -n ${{ env.KUSTOMIZE_NAMESPACE }} --timeout=120s 2>/dev/null || true
|
||||||
|
done
|
||||||
|
|
||||||
|
notify:
|
||||||
|
name: Notify
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [lint, release, docker, deploy]
|
||||||
|
if: always()
|
||||||
|
steps:
|
||||||
|
- name: Notify on success
|
||||||
|
if: needs.lint.result == 'success' && needs.docker.result == 'success'
|
||||||
|
run: |
|
||||||
|
curl -s \
|
||||||
|
-H "Title: ✅ CI Passed: ${{ gitea.repository }}" \
|
||||||
|
-H "Priority: default" \
|
||||||
|
-H "Tags: white_check_mark,github" \
|
||||||
|
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
|
||||||
|
-d "Branch: ${{ gitea.ref_name }}
|
||||||
|
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
|
||||||
|
Release: ${{ needs.release.result == 'success' && needs.release.outputs.version || 'skipped' }}
|
||||||
|
Docker: ${{ needs.docker.result }}
|
||||||
|
Deploy: ${{ needs.deploy.result }}" \
|
||||||
|
${{ env.NTFY_URL }}/gitea-ci
|
||||||
|
|
||||||
|
- name: Notify on deploy success
|
||||||
|
if: needs.deploy.result == 'success'
|
||||||
|
run: |
|
||||||
|
curl -s \
|
||||||
|
-H "Title: 🚀 Deployed: ${{ gitea.repository }}" \
|
||||||
|
-H "Priority: default" \
|
||||||
|
-H "Tags: rocket,kubernetes" \
|
||||||
|
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
|
||||||
|
-d "Version: ${{ needs.release.outputs.version }}
|
||||||
|
Namespace: ${{ env.KUSTOMIZE_NAMESPACE }}
|
||||||
|
Apps: llm-ui, embeddings-ui, stt-ui, tts-ui" \
|
||||||
|
${{ env.NTFY_URL }}/gitea-ci
|
||||||
|
|
||||||
|
- name: Notify on failure
|
||||||
|
if: needs.lint.result == 'failure' || needs.docker.result == 'failure' || needs.deploy.result == 'failure'
|
||||||
|
run: |
|
||||||
|
curl -s \
|
||||||
|
-H "Title: ❌ CI Failed: ${{ gitea.repository }}" \
|
||||||
|
-H "Priority: high" \
|
||||||
|
-H "Tags: x,github" \
|
||||||
|
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
|
||||||
|
-d "Branch: ${{ gitea.ref_name }}
|
||||||
|
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
|
||||||
|
Lint: ${{ needs.lint.result }}
|
||||||
|
Docker: ${{ needs.docker.result }}
|
||||||
|
Deploy: ${{ needs.deploy.result }}" \
|
||||||
|
${{ env.NTFY_URL }}/gitea-ci
|
||||||
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__pycache__/
|
||||||
113
embeddings.py
113
embeddings.py
@@ -9,6 +9,7 @@ Features:
|
|||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
- Visual embedding dimension display
|
- Visual embedding dimension display
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
@@ -28,12 +29,77 @@ logger = logging.getLogger("embeddings-demo")
|
|||||||
EMBEDDINGS_URL = os.environ.get(
|
EMBEDDINGS_URL = os.environ.get(
|
||||||
"EMBEDDINGS_URL",
|
"EMBEDDINGS_URL",
|
||||||
# Default: Ray Serve Embeddings endpoint
|
# Default: Ray Serve Embeddings endpoint
|
||||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings"
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
|
||||||
)
|
)
|
||||||
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
MLFLOW_TRACKING_URI = os.environ.get(
|
MLFLOW_TRACKING_URI = os.environ.get(
|
||||||
"MLFLOW_TRACKING_URI",
|
"MLFLOW_TRACKING_URI",
|
||||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
"http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
)
|
)
|
||||||
|
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||||
|
_mlflow_client = MlflowClient()
|
||||||
|
|
||||||
|
_experiment = _mlflow_client.get_experiment_by_name("gradio-embeddings-tuning")
|
||||||
|
if _experiment is None:
|
||||||
|
_experiment_id = _mlflow_client.create_experiment(
|
||||||
|
"gradio-embeddings-tuning",
|
||||||
|
artifact_location="/mlflow/artifacts/gradio-embeddings-tuning",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_experiment_id = _experiment.experiment_id
|
||||||
|
|
||||||
|
_mlflow_run = mlflow.start_run(
|
||||||
|
experiment_id=_experiment_id,
|
||||||
|
run_name=f"gradio-embeddings-{os.environ.get('HOSTNAME', 'local')}",
|
||||||
|
tags={"service": "gradio-embeddings", "endpoint": EMBEDDINGS_URL},
|
||||||
|
)
|
||||||
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = True
|
||||||
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
|
_mlflow_client = None
|
||||||
|
_mlflow_run_id = None
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def _log_embedding_metrics(
|
||||||
|
latency: float, batch_size: int, embedding_dims: int = 0
|
||||||
|
) -> None:
|
||||||
|
"""Log embedding inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
|
global _mlflow_step
|
||||||
|
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
_mlflow_step += 1
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
_mlflow_client.log_batch(
|
||||||
|
_mlflow_run_id,
|
||||||
|
metrics=[
|
||||||
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"embedding_dims", embedding_dims, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"latency_per_text_ms",
|
||||||
|
(latency * 1000 / batch_size) if batch_size > 0 else 0,
|
||||||
|
ts,
|
||||||
|
_mlflow_step,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("MLflow log failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
# HTTP client
|
# HTTP client
|
||||||
client = httpx.Client(timeout=60.0)
|
client = httpx.Client(timeout=60.0)
|
||||||
@@ -44,8 +110,7 @@ def get_embeddings(texts: list[str]) -> tuple[list[list[float]], float]:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
f"{EMBEDDINGS_URL}/embeddings",
|
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
|
||||||
json={"input": texts, "model": "bge"}
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -77,6 +142,9 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]:
|
|||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
dims = len(embedding)
|
dims = len(embedding)
|
||||||
|
|
||||||
|
# Log to MLflow
|
||||||
|
_log_embedding_metrics(latency, batch_size=1, embedding_dims=dims)
|
||||||
|
|
||||||
# Format output
|
# Format output
|
||||||
status = f"✅ Generated {dims}-dimensional embedding in {latency * 1000:.1f}ms"
|
status = f"✅ Generated {dims}-dimensional embedding in {latency * 1000:.1f}ms"
|
||||||
|
|
||||||
@@ -119,6 +187,9 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
|
|||||||
|
|
||||||
similarity = cosine_similarity(embeddings[0], embeddings[1])
|
similarity = cosine_similarity(embeddings[0], embeddings[1])
|
||||||
|
|
||||||
|
# Log to MLflow
|
||||||
|
_log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0]))
|
||||||
|
|
||||||
# Determine similarity level
|
# Determine similarity level
|
||||||
if similarity > 0.9:
|
if similarity > 0.9:
|
||||||
level = "🟢 Very High"
|
level = "🟢 Very High"
|
||||||
@@ -167,6 +238,13 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
|
|||||||
try:
|
try:
|
||||||
embeddings, latency = get_embeddings(texts)
|
embeddings, latency = get_embeddings(texts)
|
||||||
|
|
||||||
|
# Log to MLflow
|
||||||
|
_log_embedding_metrics(
|
||||||
|
latency,
|
||||||
|
batch_size=len(embeddings),
|
||||||
|
embedding_dims=len(embeddings[0]) if embeddings else 0,
|
||||||
|
)
|
||||||
|
|
||||||
status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms"
|
status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms"
|
||||||
status += f" ({latency * 1000 / len(texts):.1f}ms per text)"
|
status += f" ({latency * 1000 / len(texts):.1f}ms per text)"
|
||||||
|
|
||||||
@@ -243,7 +321,7 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
single_input = gr.Textbox(
|
single_input = gr.Textbox(
|
||||||
label="Input Text",
|
label="Input Text",
|
||||||
placeholder="Enter text to generate embeddings...",
|
placeholder="Enter text to generate embeddings...",
|
||||||
lines=3
|
lines=3,
|
||||||
)
|
)
|
||||||
single_btn = gr.Button("Generate Embedding", variant="primary")
|
single_btn = gr.Button("Generate Embedding", variant="primary")
|
||||||
|
|
||||||
@@ -256,7 +334,7 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
single_btn.click(
|
single_btn.click(
|
||||||
fn=generate_single_embedding,
|
fn=generate_single_embedding,
|
||||||
inputs=single_input,
|
inputs=single_input,
|
||||||
outputs=[single_status, single_preview, single_stats]
|
outputs=[single_status, single_preview, single_stats],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 2: Compare Texts
|
# Tab 2: Compare Texts
|
||||||
@@ -276,14 +354,17 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
compare_btn.click(
|
compare_btn.click(
|
||||||
fn=compare_texts,
|
fn=compare_texts,
|
||||||
inputs=[compare_text1, compare_text2],
|
inputs=[compare_text1, compare_text2],
|
||||||
outputs=[compare_result, compare_visual]
|
outputs=[compare_result, compare_visual],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example pairs
|
# Example pairs
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
examples=[
|
examples=[
|
||||||
["The cat sat on the mat.", "A feline was resting on the rug."],
|
["The cat sat on the mat.", "A feline was resting on the rug."],
|
||||||
["Machine learning is a subset of AI.", "Deep learning uses neural networks."],
|
[
|
||||||
|
"Machine learning is a subset of AI.",
|
||||||
|
"Deep learning uses neural networks.",
|
||||||
|
],
|
||||||
["I love pizza.", "The stock market crashed today."],
|
["I love pizza.", "The stock market crashed today."],
|
||||||
],
|
],
|
||||||
inputs=[compare_text1, compare_text2],
|
inputs=[compare_text1, compare_text2],
|
||||||
@@ -291,21 +372,21 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
|
|
||||||
# Tab 3: Batch Embeddings
|
# Tab 3: Batch Embeddings
|
||||||
with gr.TabItem("📚 Batch Processing"):
|
with gr.TabItem("📚 Batch Processing"):
|
||||||
gr.Markdown("Generate embeddings for multiple texts and see their similarity matrix.")
|
gr.Markdown(
|
||||||
|
"Generate embeddings for multiple texts and see their similarity matrix."
|
||||||
|
)
|
||||||
|
|
||||||
batch_input = gr.Textbox(
|
batch_input = gr.Textbox(
|
||||||
label="Texts (one per line)",
|
label="Texts (one per line)",
|
||||||
placeholder="Enter multiple texts, one per line...",
|
placeholder="Enter multiple texts, one per line...",
|
||||||
lines=6
|
lines=6,
|
||||||
)
|
)
|
||||||
batch_btn = gr.Button("Process Batch", variant="primary")
|
batch_btn = gr.Button("Process Batch", variant="primary")
|
||||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||||
batch_result = gr.Markdown(label="Similarity Matrix")
|
batch_result = gr.Markdown(label="Similarity Matrix")
|
||||||
|
|
||||||
batch_btn.click(
|
batch_btn.click(
|
||||||
fn=batch_embed,
|
fn=batch_embed, inputs=batch_input, outputs=[batch_status, batch_result]
|
||||||
inputs=batch_input,
|
|
||||||
outputs=[batch_status, batch_result]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
@@ -320,8 +401,4 @@ Generate embeddings, compare text similarity, and explore vector representations
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
server_name="0.0.0.0",
|
|
||||||
server_port=7860,
|
|
||||||
show_error=True
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ spec:
|
|||||||
spec:
|
spec:
|
||||||
containers:
|
containers:
|
||||||
- name: gradio
|
- name: gradio
|
||||||
image: ghcr.io/billy-davies-2/llm-apps:v2-202601271655
|
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||||
imagePullPolicy: Always
|
imagePullPolicy: Always
|
||||||
command: ["python", "embeddings.py"]
|
command: ["python", "embeddings.py"]
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
@@ -5,5 +5,11 @@ namespace: ai-ml
|
|||||||
|
|
||||||
resources:
|
resources:
|
||||||
- embeddings.yaml
|
- embeddings.yaml
|
||||||
|
- llm.yaml
|
||||||
- tts.yaml
|
- tts.yaml
|
||||||
- stt.yaml
|
- stt.yaml
|
||||||
|
|
||||||
|
images:
|
||||||
|
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
|
||||||
|
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
|
||||||
|
newTag: "0.0.7"
|
||||||
|
|||||||
456
llm.py
Normal file
456
llm.py
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Multi-turn chat with true SSE streaming responses
|
||||||
|
- Configurable temperature, max tokens, top-p
|
||||||
|
- System prompt customisation
|
||||||
|
- Token usage and latency metrics
|
||||||
|
- Chat history management
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger("llm-demo")
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
LLM_URL = os.environ.get(
|
||||||
|
"LLM_URL",
|
||||||
|
# Default: Ray Serve LLM endpoint
|
||||||
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
MLFLOW_TRACKING_URI = os.environ.get(
|
||||||
|
"MLFLOW_TRACKING_URI",
|
||||||
|
"http://mlflow.mlflow.svc.cluster.local:80",
|
||||||
|
)
|
||||||
|
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||||
|
_mlflow_client = MlflowClient()
|
||||||
|
|
||||||
|
# Ensure experiment exists
|
||||||
|
_experiment = _mlflow_client.get_experiment_by_name("gradio-llm-tuning")
|
||||||
|
if _experiment is None:
|
||||||
|
_experiment_id = _mlflow_client.create_experiment(
|
||||||
|
"gradio-llm-tuning",
|
||||||
|
artifact_location="/mlflow/artifacts/gradio-llm-tuning",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_experiment_id = _experiment.experiment_id
|
||||||
|
|
||||||
|
# One persistent run per Gradio instance
|
||||||
|
_mlflow_run = mlflow.start_run(
|
||||||
|
experiment_id=_experiment_id,
|
||||||
|
run_name=f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
|
||||||
|
tags={
|
||||||
|
"service": "gradio-llm",
|
||||||
|
"endpoint": LLM_URL,
|
||||||
|
"mlflow.runName": f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = True
|
||||||
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
|
_mlflow_client = None
|
||||||
|
_mlflow_run_id = None
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def _log_llm_metrics(
|
||||||
|
latency: float,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
) -> None:
|
||||||
|
"""Log inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
|
global _mlflow_step
|
||||||
|
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
_mlflow_step += 1
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
tps = completion_tokens / latency if latency > 0 else 0
|
||||||
|
_mlflow_client.log_batch(
|
||||||
|
_mlflow_run_id,
|
||||||
|
metrics=[
|
||||||
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"prompt_tokens", prompt_tokens, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"completion_tokens", completion_tokens, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"max_tokens_requested", max_tokens, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("MLflow log failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
|
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
|
||||||
|
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
|
||||||
|
"Be concise and helpful."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use async client for streaming
|
||||||
|
async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
|
||||||
|
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content(content) -> str:
|
||||||
|
"""Extract plain text from message content.
|
||||||
|
|
||||||
|
Handles both plain strings and Gradio 6.x content-parts format:
|
||||||
|
[{"type": "text", "text": "..."}] or [{"text": "..."}]
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
parts.append(item.get("text", item.get("content", str(item))))
|
||||||
|
elif isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
else:
|
||||||
|
parts.append(str(item))
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
message: str,
|
||||||
|
history: list[dict[str, str]],
|
||||||
|
system_prompt: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
):
|
||||||
|
"""Stream chat responses from the vLLM endpoint via SSE."""
|
||||||
|
if not message.strip():
|
||||||
|
yield ""
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build message list from history, normalising content-parts
|
||||||
|
messages = []
|
||||||
|
if system_prompt.strip():
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
|
for entry in history:
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": entry["role"],
|
||||||
|
"content": _extract_content(entry["content"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"top_p": top_p,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try true SSE streaming first
|
||||||
|
async with async_client.stream("POST", LLM_URL, json=payload) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
# SSE streaming — accumulate deltas
|
||||||
|
full_text = ""
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data = line[6:]
|
||||||
|
if data.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data)
|
||||||
|
delta = (
|
||||||
|
chunk.get("choices", [{}])[0]
|
||||||
|
.get("delta", {})
|
||||||
|
.get("content", "")
|
||||||
|
)
|
||||||
|
if delta:
|
||||||
|
full_text += delta
|
||||||
|
yield full_text
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
latency = time.time() - start_time
|
||||||
|
logger.info(
|
||||||
|
"LLM streamed response: %d chars in %.1fs", len(full_text), latency
|
||||||
|
)
|
||||||
|
|
||||||
|
# Best-effort metrics from the final SSE payload
|
||||||
|
_log_llm_metrics(
|
||||||
|
latency=latency,
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=len(full_text.split()),
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Non-streaming fallback (endpoint doesn't support stream)
|
||||||
|
body = await response.aread()
|
||||||
|
result = json.loads(body)
|
||||||
|
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||||
|
latency = time.time() - start_time
|
||||||
|
usage = result.get("usage", {})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
||||||
|
usage.get("total_tokens", 0),
|
||||||
|
latency,
|
||||||
|
usage.get("prompt_tokens", 0),
|
||||||
|
usage.get("completion_tokens", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
_log_llm_metrics(
|
||||||
|
latency=latency,
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield text progressively for a nicer feel
|
||||||
|
chunk_size = 4
|
||||||
|
words = text.split(" ")
|
||||||
|
partial = ""
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
partial += ("" if i == 0 else " ") + word
|
||||||
|
if i % chunk_size == 0 or i == len(words) - 1:
|
||||||
|
yield partial
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.exception("LLM request failed")
|
||||||
|
yield f"❌ LLM service error: {e.response.status_code} — {e.response.text[:200]}"
|
||||||
|
except httpx.ConnectError:
|
||||||
|
yield "❌ Cannot connect to LLM service. Is the Ray Serve cluster running?"
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("LLM chat failed")
|
||||||
|
yield f"❌ Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_service_health() -> str:
|
||||||
|
"""Check if the LLM service is reachable."""
|
||||||
|
try:
|
||||||
|
# Try a lightweight GET against the Ray Serve base first.
|
||||||
|
# This avoids burning GPU time on a full inference round-trip.
|
||||||
|
base_url = LLM_URL.rsplit("/", 1)[0] # strip /llm path
|
||||||
|
response = sync_client.get(f"{base_url}/-/routes")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return "🟢 LLM service is healthy"
|
||||||
|
# Fall back to a minimal inference probe
|
||||||
|
response = sync_client.post(
|
||||||
|
LLM_URL,
|
||||||
|
json={
|
||||||
|
"messages": [{"role": "user", "content": "ping"}],
|
||||||
|
"max_tokens": 1,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return "🟢 LLM service is healthy"
|
||||||
|
return f"🟡 LLM responded with status {response.status_code}"
|
||||||
|
except httpx.ConnectError:
|
||||||
|
return "🔴 Cannot connect to LLM service"
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return "🟡 LLM service is reachable but slow to respond"
|
||||||
|
except Exception as e:
|
||||||
|
return f"🔴 Service unavailable: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def single_prompt(
|
||||||
|
prompt: str,
|
||||||
|
system_prompt: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""Send a single prompt (non-chat mode) and return output + metrics."""
|
||||||
|
if not prompt.strip():
|
||||||
|
return "❌ Please enter a prompt", ""
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt.strip():
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"top_p": top_p,
|
||||||
|
}
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = httpx.Client(timeout=300.0)
|
||||||
|
response = client.post(LLM_URL, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
latency = time.time() - start_time
|
||||||
|
|
||||||
|
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||||
|
usage = result.get("usage", {})
|
||||||
|
|
||||||
|
# Log to MLflow
|
||||||
|
_log_llm_metrics(
|
||||||
|
latency=latency,
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = f"""
|
||||||
|
**Generation Metrics:**
|
||||||
|
- Latency: {latency:.1f}s
|
||||||
|
- Prompt tokens: {usage.get("prompt_tokens", "N/A")}
|
||||||
|
- Completion tokens: {usage.get("completion_tokens", "N/A")}
|
||||||
|
- Total tokens: {usage.get("total_tokens", "N/A")}
|
||||||
|
- Model: {result.get("model", "N/A")}
|
||||||
|
"""
|
||||||
|
return text, metrics
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
return f"❌ Error {e.response.status_code}: {e.response.text[:300]}", ""
|
||||||
|
except httpx.ConnectError:
|
||||||
|
return "❌ Cannot connect to LLM service", ""
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ {e}", ""
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Build the Gradio app ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="LLM Chat Demo") as demo:
|
||||||
|
gr.Markdown(
|
||||||
|
"""
|
||||||
|
# 🧠 LLM Chat Demo
|
||||||
|
|
||||||
|
Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Service status
|
||||||
|
with gr.Row():
|
||||||
|
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||||
|
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||||
|
|
||||||
|
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||||
|
|
||||||
|
# Shared parameters
|
||||||
|
with gr.Accordion("⚙️ Parameters", open=False):
|
||||||
|
system_prompt = gr.Textbox(
|
||||||
|
label="System Prompt",
|
||||||
|
value=DEFAULT_SYSTEM_PROMPT,
|
||||||
|
lines=3,
|
||||||
|
max_lines=6,
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
||||||
|
max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
|
||||||
|
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
||||||
|
|
||||||
|
with gr.Tabs():
|
||||||
|
# Tab 1: Multi-turn Chat
|
||||||
|
with gr.TabItem("💬 Chat"):
|
||||||
|
chatbot = gr.ChatInterface(
|
||||||
|
fn=chat_stream,
|
||||||
|
additional_inputs=[system_prompt, temperature, max_tokens, top_p],
|
||||||
|
examples=[
|
||||||
|
["Hello! What can you tell me about yourself?"],
|
||||||
|
["Explain how a GPU executes a matrix multiplication."],
|
||||||
|
["Write a Python function to compute the Fibonacci sequence."],
|
||||||
|
["What are the pros and cons of running LLMs on AMD GPUs?"],
|
||||||
|
],
|
||||||
|
chatbot=gr.Chatbot(
|
||||||
|
height=520,
|
||||||
|
placeholder="Type a message to start chatting...",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tab 2: Single Prompt
|
||||||
|
with gr.TabItem("📝 Single Prompt"):
|
||||||
|
gr.Markdown("Send a one-shot prompt without conversation history.")
|
||||||
|
|
||||||
|
prompt_input = gr.Textbox(
|
||||||
|
label="Prompt",
|
||||||
|
placeholder="Enter your prompt...",
|
||||||
|
lines=4,
|
||||||
|
max_lines=10,
|
||||||
|
)
|
||||||
|
generate_btn = gr.Button("🚀 Generate", variant="primary")
|
||||||
|
|
||||||
|
output_text = gr.Textbox(label="Response", lines=12, interactive=False)
|
||||||
|
output_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
|
generate_btn.click(
|
||||||
|
fn=single_prompt,
|
||||||
|
inputs=[prompt_input, system_prompt, temperature, max_tokens, top_p],
|
||||||
|
outputs=[output_text, output_metrics],
|
||||||
|
)
|
||||||
|
|
||||||
|
gr.Examples(
|
||||||
|
examples=[
|
||||||
|
[
|
||||||
|
"Summarise the key differences between CUDA and ROCm for ML workloads."
|
||||||
|
],
|
||||||
|
["Write a haiku about Kubernetes."],
|
||||||
|
[
|
||||||
|
"Explain Ray Serve in one paragraph for someone new to ML serving."
|
||||||
|
],
|
||||||
|
["List 5 creative uses for a homelab GPU cluster."],
|
||||||
|
],
|
||||||
|
inputs=[prompt_input],
|
||||||
|
)
|
||||||
|
|
||||||
|
create_footer()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
96
llm.yaml
Normal file
96
llm.yaml
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
---
|
||||||
|
apiVersion: apps/v1
|
||||||
|
kind: Deployment
|
||||||
|
metadata:
|
||||||
|
name: llm-ui
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app: llm
|
||||||
|
component: demo-ui
|
||||||
|
spec:
|
||||||
|
replicas: 1
|
||||||
|
selector:
|
||||||
|
matchLabels:
|
||||||
|
app: llm
|
||||||
|
template:
|
||||||
|
metadata:
|
||||||
|
labels:
|
||||||
|
app: llm
|
||||||
|
component: demo-ui
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: gradio
|
||||||
|
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||||
|
imagePullPolicy: Always
|
||||||
|
command: ["python", "llm.py"]
|
||||||
|
ports:
|
||||||
|
- containerPort: 7860
|
||||||
|
name: http
|
||||||
|
protocol: TCP
|
||||||
|
env:
|
||||||
|
- name: LLM_URL
|
||||||
|
# Ray Serve endpoint - routes to /llm prefix
|
||||||
|
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm"
|
||||||
|
- name: MLFLOW_TRACKING_URI
|
||||||
|
value: "http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
cpu: "100m"
|
||||||
|
memory: "256Mi"
|
||||||
|
limits:
|
||||||
|
cpu: "500m"
|
||||||
|
memory: "512Mi"
|
||||||
|
livenessProbe:
|
||||||
|
httpGet:
|
||||||
|
path: /
|
||||||
|
port: 7860
|
||||||
|
initialDelaySeconds: 10
|
||||||
|
periodSeconds: 30
|
||||||
|
readinessProbe:
|
||||||
|
httpGet:
|
||||||
|
path: /
|
||||||
|
port: 7860
|
||||||
|
initialDelaySeconds: 5
|
||||||
|
periodSeconds: 10
|
||||||
|
imagePullSecrets:
|
||||||
|
- name: gitea-registry
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: llm-ui
|
||||||
|
namespace: ai-ml
|
||||||
|
labels:
|
||||||
|
app: llm
|
||||||
|
spec:
|
||||||
|
type: ClusterIP
|
||||||
|
ports:
|
||||||
|
- port: 80
|
||||||
|
targetPort: 7860
|
||||||
|
protocol: TCP
|
||||||
|
name: http
|
||||||
|
selector:
|
||||||
|
app: llm
|
||||||
|
---
|
||||||
|
apiVersion: gateway.networking.k8s.io/v1
|
||||||
|
kind: HTTPRoute
|
||||||
|
metadata:
|
||||||
|
name: llm-ui
|
||||||
|
namespace: ai-ml
|
||||||
|
annotations:
|
||||||
|
external-dns.alpha.kubernetes.io/hostname: llm-ui.lab.daviestechlabs.io
|
||||||
|
spec:
|
||||||
|
parentRefs:
|
||||||
|
- name: envoy-internal
|
||||||
|
namespace: network
|
||||||
|
sectionName: https-lab
|
||||||
|
hostnames:
|
||||||
|
- llm-ui.lab.daviestechlabs.io
|
||||||
|
rules:
|
||||||
|
- matches:
|
||||||
|
- path:
|
||||||
|
type: PathPrefix
|
||||||
|
value: /
|
||||||
|
backendRefs:
|
||||||
|
- name: llm-ui
|
||||||
|
port: 80
|
||||||
7
renovate.json
Normal file
7
renovate.json
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||||
|
"extends": [
|
||||||
|
"local>daviestechlabs/renovate-config",
|
||||||
|
"local>daviestechlabs/renovate-config:python"
|
||||||
|
]
|
||||||
|
}
|
||||||
142
stt.py
142
stt.py
@@ -9,11 +9,11 @@ Features:
|
|||||||
- Translation mode
|
- Translation mode
|
||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import httpx
|
import httpx
|
||||||
@@ -30,13 +30,82 @@ logger = logging.getLogger("stt-demo")
|
|||||||
STT_URL = os.environ.get(
|
STT_URL = os.environ.get(
|
||||||
"STT_URL",
|
"STT_URL",
|
||||||
# Default: Ray Serve whisper endpoint
|
# Default: Ray Serve whisper endpoint
|
||||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper",
|
||||||
)
|
)
|
||||||
MLFLOW_TRACKING_URI = os.environ.get(
|
MLFLOW_TRACKING_URI = os.environ.get(
|
||||||
"MLFLOW_TRACKING_URI",
|
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||||
|
_mlflow_client = MlflowClient()
|
||||||
|
|
||||||
|
_experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning")
|
||||||
|
if _experiment is None:
|
||||||
|
_experiment_id = _mlflow_client.create_experiment(
|
||||||
|
"gradio-stt-tuning",
|
||||||
|
artifact_location="/mlflow/artifacts/gradio-stt-tuning",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_experiment_id = _experiment.experiment_id
|
||||||
|
|
||||||
|
_mlflow_run = mlflow.start_run(
|
||||||
|
experiment_id=_experiment_id,
|
||||||
|
run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}",
|
||||||
|
tags={"service": "gradio-stt", "endpoint": STT_URL},
|
||||||
|
)
|
||||||
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = True
|
||||||
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
|
_mlflow_client = None
|
||||||
|
_mlflow_run_id = None
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def _log_stt_metrics(
|
||||||
|
latency: float,
|
||||||
|
audio_duration: float,
|
||||||
|
word_count: int,
|
||||||
|
task: str,
|
||||||
|
) -> None:
|
||||||
|
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
|
global _mlflow_step
|
||||||
|
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
_mlflow_step += 1
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
rtf = latency / audio_duration if audio_duration > 0 else 0
|
||||||
|
_mlflow_client.log_batch(
|
||||||
|
_mlflow_run_id,
|
||||||
|
metrics=[
|
||||||
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
|
||||||
|
],
|
||||||
|
params=[]
|
||||||
|
if _mlflow_step > 1
|
||||||
|
else [
|
||||||
|
mlflow.entities.Param("task", task),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("MLflow log failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
# HTTP client with longer timeout for transcription
|
# HTTP client with longer timeout for transcription
|
||||||
client = httpx.Client(timeout=180.0)
|
client = httpx.Client(timeout=180.0)
|
||||||
|
|
||||||
@@ -63,9 +132,7 @@ LANGUAGES = {
|
|||||||
|
|
||||||
|
|
||||||
def transcribe_audio(
|
def transcribe_audio(
|
||||||
audio_input: tuple[int, np.ndarray] | str | None,
|
audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
|
||||||
language: str,
|
|
||||||
task: str
|
|
||||||
) -> tuple[str, str, str]:
|
) -> tuple[str, str, str]:
|
||||||
"""Transcribe audio using the Whisper STT service."""
|
"""Transcribe audio using the Whisper STT service."""
|
||||||
if audio_input is None:
|
if audio_input is None:
|
||||||
@@ -81,12 +148,12 @@ def transcribe_audio(
|
|||||||
|
|
||||||
# Convert to WAV bytes
|
# Convert to WAV bytes
|
||||||
audio_buffer = io.BytesIO()
|
audio_buffer = io.BytesIO()
|
||||||
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
|
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
||||||
audio_bytes = audio_buffer.getvalue()
|
audio_bytes = audio_buffer.getvalue()
|
||||||
audio_duration = len(audio_data) / sample_rate
|
audio_duration = len(audio_data) / sample_rate
|
||||||
else:
|
else:
|
||||||
# File path
|
# File path
|
||||||
with open(audio_input, 'rb') as f:
|
with open(audio_input, "rb") as f:
|
||||||
audio_bytes = f.read()
|
audio_bytes = f.read()
|
||||||
# Get duration
|
# Get duration
|
||||||
audio_data, sample_rate = sf.read(audio_input)
|
audio_data, sample_rate = sf.read(audio_input)
|
||||||
@@ -117,8 +184,18 @@ def transcribe_audio(
|
|||||||
text = result.get("text", "")
|
text = result.get("text", "")
|
||||||
detected_language = result.get("language", "unknown")
|
detected_language = result.get("language", "unknown")
|
||||||
|
|
||||||
|
# Log to MLflow
|
||||||
|
_log_stt_metrics(
|
||||||
|
latency=latency,
|
||||||
|
audio_duration=audio_duration,
|
||||||
|
word_count=len(text.split()),
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
# Status message
|
# Status message
|
||||||
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
|
status = (
|
||||||
|
f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
|
||||||
|
)
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
metrics = f"""
|
metrics = f"""
|
||||||
@@ -181,21 +258,19 @@ or file upload with support for 100+ languages.
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
mic_input = gr.Audio(
|
mic_input = gr.Audio(
|
||||||
label="Record Audio",
|
label="Record Audio", sources=["microphone"], type="numpy"
|
||||||
sources=["microphone"],
|
|
||||||
type="numpy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
mic_language = gr.Dropdown(
|
mic_language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="Auto-detect",
|
value="Auto-detect",
|
||||||
label="Language"
|
label="Language",
|
||||||
)
|
)
|
||||||
mic_task = gr.Radio(
|
mic_task = gr.Radio(
|
||||||
choices=["Transcribe", "Translate to English"],
|
choices=["Transcribe", "Translate to English"],
|
||||||
value="Transcribe",
|
value="Transcribe",
|
||||||
label="Task"
|
label="Task",
|
||||||
)
|
)
|
||||||
|
|
||||||
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
|
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||||
@@ -204,15 +279,12 @@ or file upload with support for 100+ languages.
|
|||||||
mic_status = gr.Textbox(label="Status", interactive=False)
|
mic_status = gr.Textbox(label="Status", interactive=False)
|
||||||
mic_metrics = gr.Markdown(label="Metrics")
|
mic_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
mic_output = gr.Textbox(
|
mic_output = gr.Textbox(label="Transcription", lines=5)
|
||||||
label="Transcription",
|
|
||||||
lines=5
|
|
||||||
)
|
|
||||||
|
|
||||||
mic_btn.click(
|
mic_btn.click(
|
||||||
fn=transcribe_audio,
|
fn=transcribe_audio,
|
||||||
inputs=[mic_input, mic_language, mic_task],
|
inputs=[mic_input, mic_language, mic_task],
|
||||||
outputs=[mic_status, mic_output, mic_metrics]
|
outputs=[mic_status, mic_output, mic_metrics],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 2: File Upload
|
# Tab 2: File Upload
|
||||||
@@ -220,21 +292,19 @@ or file upload with support for 100+ languages.
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
file_input = gr.Audio(
|
file_input = gr.Audio(
|
||||||
label="Upload Audio File",
|
label="Upload Audio File", sources=["upload"], type="filepath"
|
||||||
sources=["upload"],
|
|
||||||
type="filepath"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
file_language = gr.Dropdown(
|
file_language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="Auto-detect",
|
value="Auto-detect",
|
||||||
label="Language"
|
label="Language",
|
||||||
)
|
)
|
||||||
file_task = gr.Radio(
|
file_task = gr.Radio(
|
||||||
choices=["Transcribe", "Translate to English"],
|
choices=["Transcribe", "Translate to English"],
|
||||||
value="Transcribe",
|
value="Transcribe",
|
||||||
label="Task"
|
label="Task",
|
||||||
)
|
)
|
||||||
|
|
||||||
file_btn = gr.Button("🎯 Transcribe", variant="primary")
|
file_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||||
@@ -243,15 +313,12 @@ or file upload with support for 100+ languages.
|
|||||||
file_status = gr.Textbox(label="Status", interactive=False)
|
file_status = gr.Textbox(label="Status", interactive=False)
|
||||||
file_metrics = gr.Markdown(label="Metrics")
|
file_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
file_output = gr.Textbox(
|
file_output = gr.Textbox(label="Transcription", lines=5)
|
||||||
label="Transcription",
|
|
||||||
lines=5
|
|
||||||
)
|
|
||||||
|
|
||||||
file_btn.click(
|
file_btn.click(
|
||||||
fn=transcribe_audio,
|
fn=transcribe_audio,
|
||||||
inputs=[file_input, file_language, file_task],
|
inputs=[file_input, file_language, file_task],
|
||||||
outputs=[file_status, file_output, file_metrics]
|
outputs=[file_status, file_output, file_metrics],
|
||||||
)
|
)
|
||||||
|
|
||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
@@ -274,7 +341,7 @@ Whisper will automatically detect the source language.
|
|||||||
trans_input = gr.Audio(
|
trans_input = gr.Audio(
|
||||||
label="Audio Input",
|
label="Audio Input",
|
||||||
sources=["microphone", "upload"],
|
sources=["microphone", "upload"],
|
||||||
type="numpy"
|
type="numpy",
|
||||||
)
|
)
|
||||||
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
|
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
|
||||||
|
|
||||||
@@ -282,10 +349,7 @@ Whisper will automatically detect the source language.
|
|||||||
trans_status = gr.Textbox(label="Status", interactive=False)
|
trans_status = gr.Textbox(label="Status", interactive=False)
|
||||||
trans_metrics = gr.Markdown(label="Metrics")
|
trans_metrics = gr.Markdown(label="Metrics")
|
||||||
|
|
||||||
trans_output = gr.Textbox(
|
trans_output = gr.Textbox(label="English Translation", lines=5)
|
||||||
label="English Translation",
|
|
||||||
lines=5
|
|
||||||
)
|
|
||||||
|
|
||||||
def translate_audio(audio):
|
def translate_audio(audio):
|
||||||
return transcribe_audio(audio, "Auto-detect", "Translate to English")
|
return transcribe_audio(audio, "Auto-detect", "Translate to English")
|
||||||
@@ -293,15 +357,11 @@ Whisper will automatically detect the source language.
|
|||||||
trans_btn.click(
|
trans_btn.click(
|
||||||
fn=translate_audio,
|
fn=translate_audio,
|
||||||
inputs=trans_input,
|
inputs=trans_input,
|
||||||
outputs=[trans_status, trans_output, trans_metrics]
|
outputs=[trans_status, trans_output, trans_metrics],
|
||||||
)
|
)
|
||||||
|
|
||||||
create_footer()
|
create_footer()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
server_name="0.0.0.0",
|
|
||||||
server_port=7860,
|
|
||||||
show_error=True
|
|
||||||
)
|
|
||||||
|
|||||||
4
stt.yaml
4
stt.yaml
@@ -20,7 +20,7 @@ spec:
|
|||||||
spec:
|
spec:
|
||||||
containers:
|
containers:
|
||||||
- name: gradio
|
- name: gradio
|
||||||
image: ghcr.io/billy-davies-2/llm-apps:v2-202601271655
|
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||||
imagePullPolicy: Always
|
imagePullPolicy: Always
|
||||||
command: ["python", "stt.py"]
|
command: ["python", "stt.py"]
|
||||||
ports:
|
ports:
|
||||||
@@ -28,7 +28,7 @@ spec:
|
|||||||
name: http
|
name: http
|
||||||
protocol: TCP
|
protocol: TCP
|
||||||
env:
|
env:
|
||||||
- name: WHISPER_URL
|
- name: STT_URL
|
||||||
# Ray Serve endpoint - routes to /whisper prefix
|
# Ray Serve endpoint - routes to /whisper prefix
|
||||||
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
||||||
- name: MLFLOW_TRACKING_URI
|
- name: MLFLOW_TRACKING_URI
|
||||||
|
|||||||
184
theme.py
184
theme.py
@@ -3,6 +3,7 @@ Shared Gradio theme for Davies Tech Labs AI demos.
|
|||||||
Consistent styling across all demo applications.
|
Consistent styling across all demo applications.
|
||||||
Cyberpunk aesthetic - dark with yellow/gold accents.
|
Cyberpunk aesthetic - dark with yellow/gold accents.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
@@ -25,7 +26,12 @@ def get_lab_theme() -> gr.Theme:
|
|||||||
primary_hue=gr.themes.colors.yellow,
|
primary_hue=gr.themes.colors.yellow,
|
||||||
secondary_hue=gr.themes.colors.amber,
|
secondary_hue=gr.themes.colors.amber,
|
||||||
neutral_hue=gr.themes.colors.zinc,
|
neutral_hue=gr.themes.colors.zinc,
|
||||||
font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
|
font=[
|
||||||
|
gr.themes.GoogleFont("Space Grotesk"),
|
||||||
|
"ui-sans-serif",
|
||||||
|
"system-ui",
|
||||||
|
"sans-serif",
|
||||||
|
],
|
||||||
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
|
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
|
||||||
).set(
|
).set(
|
||||||
# Background colors
|
# Background colors
|
||||||
@@ -75,10 +81,38 @@ def get_lab_theme() -> gr.Theme:
|
|||||||
block_background_fill_dark=CYBER_GRAY,
|
block_background_fill_dark=CYBER_GRAY,
|
||||||
block_border_color="#2a2a2a",
|
block_border_color="#2a2a2a",
|
||||||
block_border_color_dark="#2a2a2a",
|
block_border_color_dark="#2a2a2a",
|
||||||
|
block_label_background_fill="#1a1a00",
|
||||||
|
block_label_background_fill_dark="#1a1a00",
|
||||||
block_label_text_color=CYBER_YELLOW,
|
block_label_text_color=CYBER_YELLOW,
|
||||||
block_label_text_color_dark=CYBER_YELLOW,
|
block_label_text_color_dark=CYBER_YELLOW,
|
||||||
|
block_label_border_color=CYBER_YELLOW,
|
||||||
|
block_label_border_color_dark=CYBER_YELLOW,
|
||||||
block_title_text_color=CYBER_TEXT,
|
block_title_text_color=CYBER_TEXT,
|
||||||
block_title_text_color_dark=CYBER_TEXT,
|
block_title_text_color_dark=CYBER_TEXT,
|
||||||
|
# Table / Dataframe
|
||||||
|
table_border_color="#2a2a2a",
|
||||||
|
table_even_background_fill="#111111",
|
||||||
|
table_even_background_fill_dark="#111111",
|
||||||
|
table_odd_background_fill=CYBER_GRAY,
|
||||||
|
table_odd_background_fill_dark=CYBER_GRAY,
|
||||||
|
table_row_focus="#1f1a00",
|
||||||
|
table_row_focus_dark="#1f1a00",
|
||||||
|
# Panel / accordion
|
||||||
|
panel_background_fill=CYBER_DARK,
|
||||||
|
panel_background_fill_dark=CYBER_DARK,
|
||||||
|
panel_border_color="#2a2a2a",
|
||||||
|
panel_border_color_dark="#2a2a2a",
|
||||||
|
# Checkbox / radio
|
||||||
|
checkbox_background_color=CYBER_DARKER,
|
||||||
|
checkbox_background_color_dark=CYBER_DARKER,
|
||||||
|
checkbox_label_background_fill=CYBER_GRAY,
|
||||||
|
checkbox_label_background_fill_dark=CYBER_GRAY,
|
||||||
|
checkbox_label_text_color=CYBER_TEXT,
|
||||||
|
checkbox_label_text_color_dark=CYBER_TEXT,
|
||||||
|
# Colors
|
||||||
|
color_accent=CYBER_YELLOW,
|
||||||
|
color_accent_soft="#1f1a00",
|
||||||
|
color_accent_soft_dark="#1f1a00",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -338,6 +372,154 @@ label, .gr-label {
|
|||||||
.glow-text {
|
.glow-text {
|
||||||
text-shadow: 0 0 10px var(--cyber-yellow), 0 0 20px var(--cyber-yellow);
|
text-shadow: 0 0 10px var(--cyber-yellow), 0 0 20px var(--cyber-yellow);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ── Examples table / Dataframe overrides ── */
|
||||||
|
/* Gradio renders Examples as a <table> inside a Dataset component.
|
||||||
|
The default styles inject white / light-gray rows that blow out the
|
||||||
|
cyberpunk palette. Force them dark here. */
|
||||||
|
.gr-samples-table,
|
||||||
|
.gr-sample-textbox,
|
||||||
|
table.table,
|
||||||
|
.gr-examples table,
|
||||||
|
div[class*="dataset"] table {
|
||||||
|
background: var(--cyber-dark) !important;
|
||||||
|
color: var(--cyber-text) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-samples-table tr,
|
||||||
|
.gr-examples table tr,
|
||||||
|
div[class*="dataset"] table tr {
|
||||||
|
background: #111111 !important;
|
||||||
|
border-bottom: 1px solid #222 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-samples-table tr:nth-child(even),
|
||||||
|
.gr-examples table tr:nth-child(even),
|
||||||
|
div[class*="dataset"] table tr:nth-child(even) {
|
||||||
|
background: #0d0d0d !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-samples-table tr:hover,
|
||||||
|
.gr-examples table tr:hover,
|
||||||
|
div[class*="dataset"] table tr:hover {
|
||||||
|
background: #1f1a00 !important;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-samples-table th,
|
||||||
|
.gr-examples table th,
|
||||||
|
div[class*="dataset"] table th {
|
||||||
|
background: var(--cyber-gray) !important;
|
||||||
|
color: var(--cyber-yellow) !important;
|
||||||
|
text-transform: uppercase !important;
|
||||||
|
font-size: 0.75rem !important;
|
||||||
|
letter-spacing: 0.1em !important;
|
||||||
|
border-bottom: 2px solid var(--cyber-yellow) !important;
|
||||||
|
padding: 10px 16px !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-samples-table td,
|
||||||
|
.gr-examples table td,
|
||||||
|
div[class*="dataset"] table td {
|
||||||
|
color: #999 !important;
|
||||||
|
border-bottom: 1px solid #1a1a1a !important;
|
||||||
|
padding: 10px 16px !important;
|
||||||
|
font-family: 'JetBrains Mono', monospace !important;
|
||||||
|
font-size: 0.85rem !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Block label pill (e.g. "GENERATED AUDIO", "STATUS") ── */
|
||||||
|
/* These are the small floating labels above each component block */
|
||||||
|
span[class*="label-wrap"],
|
||||||
|
.gr-block-label,
|
||||||
|
.label-wrap {
|
||||||
|
background: #1a1a00 !important;
|
||||||
|
border: 1px solid var(--cyber-yellow) !important;
|
||||||
|
color: var(--cyber-yellow) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Dropdown / select menus ── */
|
||||||
|
.gr-dropdown,
|
||||||
|
select,
|
||||||
|
ul[role="listbox"],
|
||||||
|
div[class*="dropdown"],
|
||||||
|
.secondary-wrap {
|
||||||
|
background: #0a0a0a !important;
|
||||||
|
color: var(--cyber-text) !important;
|
||||||
|
border-color: #333 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
ul[role="listbox"] li,
|
||||||
|
div[class*="dropdown"] li {
|
||||||
|
background: #0a0a0a !important;
|
||||||
|
color: var(--cyber-text) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
ul[role="listbox"] li:hover,
|
||||||
|
ul[role="listbox"] li[aria-selected="true"],
|
||||||
|
div[class*="dropdown"] li:hover {
|
||||||
|
background: #1f1a00 !important;
|
||||||
|
color: var(--cyber-yellow) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Audio player ── */
|
||||||
|
.gr-audio,
|
||||||
|
audio {
|
||||||
|
background: var(--cyber-dark) !important;
|
||||||
|
border: 1px solid #2a2a2a !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Audio waveform container */
|
||||||
|
div[data-testid="waveform-container"],
|
||||||
|
div[class*="audio"] {
|
||||||
|
background: #0a0a0a !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Markdown inside blocks ── */
|
||||||
|
.gr-markdown,
|
||||||
|
.gr-markdown p,
|
||||||
|
.prose {
|
||||||
|
color: var(--cyber-text) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-markdown h3,
|
||||||
|
.gr-markdown h2 {
|
||||||
|
color: var(--cyber-yellow) !important;
|
||||||
|
letter-spacing: 0.05em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.gr-markdown strong {
|
||||||
|
color: var(--cyber-gold) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Examples accordion header ("Examples" label) ── */
|
||||||
|
.gr-examples .label-wrap,
|
||||||
|
div[id*="examples"] .label-wrap,
|
||||||
|
span[data-testid="block-label"] {
|
||||||
|
background: #1a1a00 !important;
|
||||||
|
color: var(--cyber-yellow) !important;
|
||||||
|
border: 1px solid var(--cyber-yellow) !important;
|
||||||
|
font-size: 0.7rem !important;
|
||||||
|
text-transform: uppercase !important;
|
||||||
|
letter-spacing: 0.1em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Misc: tooltip, info text ── */
|
||||||
|
.gr-info,
|
||||||
|
.gr-description {
|
||||||
|
color: #666 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Svelte internal: make sure no white backgrounds leak ── */
|
||||||
|
.contain > div,
|
||||||
|
.wrap > div {
|
||||||
|
background: inherit !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Tab content panels ── */
|
||||||
|
.tabitem {
|
||||||
|
background: var(--cyber-dark) !important;
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
472
tts.py
472
tts.py
@@ -5,19 +5,20 @@ TTS Demo - Gradio UI for testing Text-to-Speech service.
|
|||||||
Features:
|
Features:
|
||||||
- Text input with language selection
|
- Text input with language selection
|
||||||
- Audio playback of synthesized speech
|
- Audio playback of synthesized speech
|
||||||
- Voice/speaker selection (when available)
|
- Sentence-level chunking for better quality
|
||||||
|
- Speed control
|
||||||
- MLflow metrics logging
|
- MLflow metrics logging
|
||||||
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import io
|
import io
|
||||||
import base64
|
import wave
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import httpx
|
import httpx
|
||||||
import soundfile as sf
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
||||||
@@ -30,13 +31,79 @@ logger = logging.getLogger("tts-demo")
|
|||||||
TTS_URL = os.environ.get(
|
TTS_URL = os.environ.get(
|
||||||
"TTS_URL",
|
"TTS_URL",
|
||||||
# Default: Ray Serve TTS endpoint
|
# Default: Ray Serve TTS endpoint
|
||||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts"
|
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
|
||||||
)
|
)
|
||||||
MLFLOW_TRACKING_URI = os.environ.get(
|
MLFLOW_TRACKING_URI = os.environ.get(
|
||||||
"MLFLOW_TRACKING_URI",
|
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
import mlflow
|
||||||
|
from mlflow.tracking import MlflowClient
|
||||||
|
|
||||||
|
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||||
|
_mlflow_client = MlflowClient()
|
||||||
|
|
||||||
|
_experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
|
||||||
|
if _experiment is None:
|
||||||
|
_experiment_id = _mlflow_client.create_experiment(
|
||||||
|
"gradio-tts-tuning",
|
||||||
|
artifact_location="/mlflow/artifacts/gradio-tts-tuning",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_experiment_id = _experiment.experiment_id
|
||||||
|
|
||||||
|
_mlflow_run = mlflow.start_run(
|
||||||
|
experiment_id=_experiment_id,
|
||||||
|
run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
|
||||||
|
tags={"service": "gradio-tts", "endpoint": TTS_URL},
|
||||||
|
)
|
||||||
|
_mlflow_run_id = _mlflow_run.info.run_id
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = True
|
||||||
|
logger.info(
|
||||||
|
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("MLflow tracking disabled: %s", exc)
|
||||||
|
_mlflow_client = None
|
||||||
|
_mlflow_run_id = None
|
||||||
|
_mlflow_step = 0
|
||||||
|
MLFLOW_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def _log_tts_metrics(
|
||||||
|
latency: float,
|
||||||
|
audio_duration: float,
|
||||||
|
text_chars: int,
|
||||||
|
language: str,
|
||||||
|
) -> None:
|
||||||
|
"""Log TTS inference metrics to MLflow (non-blocking best-effort)."""
|
||||||
|
global _mlflow_step
|
||||||
|
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
_mlflow_step += 1
|
||||||
|
ts = int(time.time() * 1000)
|
||||||
|
rtf = latency / audio_duration if audio_duration > 0 else 0
|
||||||
|
cps = text_chars / latency if latency > 0 else 0
|
||||||
|
_mlflow_client.log_batch(
|
||||||
|
_mlflow_run_id,
|
||||||
|
metrics=[
|
||||||
|
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric(
|
||||||
|
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||||
|
),
|
||||||
|
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
|
||||||
|
mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("MLflow log failed", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
# HTTP client with longer timeout for audio generation
|
# HTTP client with longer timeout for audio generation
|
||||||
client = httpx.Client(timeout=120.0)
|
client = httpx.Client(timeout=120.0)
|
||||||
|
|
||||||
@@ -60,53 +127,263 @@ LANGUAGES = {
|
|||||||
"Hungarian": "hu",
|
"Hungarian": "hu",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ─── Text preprocessing ─────────────────────────────────────────────────
|
||||||
|
|
||||||
def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
_SENTENCE_RE = re.compile(r"(?<=[.!?;])\s+|(?<=\n)\s*", re.MULTILINE)
|
||||||
"""Synthesize speech from text using the TTS service."""
|
|
||||||
|
_DIGIT_WORDS = {
|
||||||
|
"0": "zero",
|
||||||
|
"1": "one",
|
||||||
|
"2": "two",
|
||||||
|
"3": "three",
|
||||||
|
"4": "four",
|
||||||
|
"5": "five",
|
||||||
|
"6": "six",
|
||||||
|
"7": "seven",
|
||||||
|
"8": "eight",
|
||||||
|
"9": "nine",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_numbers(text: str) -> str:
|
||||||
|
"""Expand standalone single digits to words for clearer pronunciation."""
|
||||||
|
return re.sub(
|
||||||
|
r"\b(\d)\b",
|
||||||
|
lambda m: _DIGIT_WORDS.get(m.group(0), m.group(0)),
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_text(text: str) -> str:
|
||||||
|
"""Clean and normalise text for TTS input."""
|
||||||
|
text = re.sub(r"[ \t]+", " ", text)
|
||||||
|
text = "\n".join(line.strip() for line in text.splitlines())
|
||||||
|
# Strip markdown / code-fence characters
|
||||||
|
text = re.sub(r"[*#~`|<>{}[\]\\]", "", text)
|
||||||
|
# Expand common symbols
|
||||||
|
text = text.replace("&", " and ")
|
||||||
|
text = text.replace("@", " at ")
|
||||||
|
text = text.replace("%", " percent ")
|
||||||
|
text = text.replace("+", " plus ")
|
||||||
|
text = text.replace("=", " equals ")
|
||||||
|
text = _expand_numbers(text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _split_sentences(text: str) -> list[str]:
|
||||||
|
"""Split text into sentences suitable for TTS.
|
||||||
|
|
||||||
|
Keeps sentences short for best quality while preserving natural phrasing.
|
||||||
|
Very long segments are further split on commas / semicolons.
|
||||||
|
"""
|
||||||
|
text = _clean_text(text)
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
raw_parts = _SENTENCE_RE.split(text)
|
||||||
|
sentences: list[str] = []
|
||||||
|
for part in raw_parts:
|
||||||
|
part = part.strip()
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
if len(part) > 200:
|
||||||
|
for sp in re.split(r"(?<=[,;])\s+", part):
|
||||||
|
sp = sp.strip()
|
||||||
|
if sp:
|
||||||
|
sentences.append(sp)
|
||||||
|
else:
|
||||||
|
sentences.append(part)
|
||||||
|
return sentences
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Audio helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _read_wav_bytes(data: bytes) -> tuple[int, np.ndarray]:
|
||||||
|
"""Read WAV audio from bytes, handling scipy wavfile and standard WAV.
|
||||||
|
|
||||||
|
Returns (sample_rate, float32_audio) with values in [-1, 1].
|
||||||
|
"""
|
||||||
|
buf = io.BytesIO(data)
|
||||||
|
|
||||||
|
# Try stdlib wave module first — most robust for PCM WAV from scipy
|
||||||
|
try:
|
||||||
|
with wave.open(buf, "rb") as wf:
|
||||||
|
sr = wf.getframerate()
|
||||||
|
n_frames = wf.getnframes()
|
||||||
|
n_channels = wf.getnchannels()
|
||||||
|
sampwidth = wf.getsampwidth()
|
||||||
|
raw = wf.readframes(n_frames)
|
||||||
|
|
||||||
|
if sampwidth == 2:
|
||||||
|
audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
elif sampwidth == 4:
|
||||||
|
audio = np.frombuffer(raw, dtype=np.int32).astype(np.float32) / 2147483648.0
|
||||||
|
elif sampwidth == 1:
|
||||||
|
audio = (
|
||||||
|
np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0
|
||||||
|
) / 128.0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported sample width: {sampwidth}")
|
||||||
|
|
||||||
|
if n_channels > 1:
|
||||||
|
audio = audio.reshape(-1, n_channels).mean(axis=1)
|
||||||
|
|
||||||
|
return sr, audio
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("wave module failed (%s), trying soundfile", exc)
|
||||||
|
|
||||||
|
# Fallback: soundfile (handles FLAC, OGG, etc.)
|
||||||
|
buf.seek(0)
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
audio, sr = sf.read(buf, dtype="float32")
|
||||||
|
if audio.ndim > 1:
|
||||||
|
audio = audio.mean(axis=1)
|
||||||
|
return sr, audio
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("soundfile failed (%s), attempting raw PCM", exc)
|
||||||
|
|
||||||
|
# Last resort: raw 16-bit PCM at 22050 Hz
|
||||||
|
logger.warning(
|
||||||
|
"Could not parse WAV header (len=%d, first 4 bytes=%r); raw PCM decode",
|
||||||
|
len(data),
|
||||||
|
data[:4],
|
||||||
|
)
|
||||||
|
audio = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||||
|
return 22050, audio
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_audio(
|
||||||
|
chunks: list[tuple[int, np.ndarray]], pause_ms: int = 200
|
||||||
|
) -> tuple[int, np.ndarray]:
|
||||||
|
"""Concatenate (sample_rate, audio) chunks with silence gaps."""
|
||||||
|
if not chunks:
|
||||||
|
return 22050, np.array([], dtype=np.float32)
|
||||||
|
if len(chunks) == 1:
|
||||||
|
return chunks[0]
|
||||||
|
|
||||||
|
sr = chunks[0][0]
|
||||||
|
silence = np.zeros(int(sr * pause_ms / 1000), dtype=np.float32)
|
||||||
|
|
||||||
|
parts: list[np.ndarray] = []
|
||||||
|
for sample_rate, audio in chunks:
|
||||||
|
if sample_rate != sr:
|
||||||
|
ratio = sr / sample_rate
|
||||||
|
indices = np.arange(0, len(audio), 1.0 / ratio).astype(int)
|
||||||
|
indices = indices[indices < len(audio)]
|
||||||
|
audio = audio[indices]
|
||||||
|
parts.append(audio)
|
||||||
|
parts.append(silence)
|
||||||
|
|
||||||
|
if parts:
|
||||||
|
parts.pop() # remove trailing silence
|
||||||
|
return sr, np.concatenate(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── TTS synthesis ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _synthesize_chunk(text: str, lang_code: str, speed: float = 1.0) -> bytes:
|
||||||
|
"""Synthesize a single text chunk via the TTS backend.
|
||||||
|
|
||||||
|
Uses the JSON POST endpoint (no URL length limits, supports speed).
|
||||||
|
Falls back to the Coqui-compatible GET endpoint if POST fails.
|
||||||
|
"""
|
||||||
|
import base64 as b64
|
||||||
|
|
||||||
|
# Try JSON POST first
|
||||||
|
try:
|
||||||
|
resp = client.post(
|
||||||
|
TTS_URL,
|
||||||
|
json={
|
||||||
|
"text": text,
|
||||||
|
"language": lang_code,
|
||||||
|
"speed": speed,
|
||||||
|
"return_base64": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
ct = resp.headers.get("content-type", "")
|
||||||
|
if "application/json" in ct:
|
||||||
|
body = resp.json()
|
||||||
|
if "error" in body:
|
||||||
|
raise RuntimeError(body["error"])
|
||||||
|
audio_b64 = body.get("audio", "")
|
||||||
|
if audio_b64:
|
||||||
|
return b64.b64decode(audio_b64)
|
||||||
|
# Non-JSON response — treat as raw audio bytes
|
||||||
|
return resp.content
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"POST endpoint failed, falling back to GET /api/tts", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: Coqui-compatible GET (no speed control)
|
||||||
|
resp = client.get(
|
||||||
|
f"{TTS_URL}/api/tts",
|
||||||
|
params={"text": text, "language_id": lang_code},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_speech(
|
||||||
|
text: str, language: str, speed: float
|
||||||
|
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
||||||
|
"""Synthesize speech from text using the TTS service.
|
||||||
|
|
||||||
|
Long text is split into sentences and synthesized individually
|
||||||
|
for better quality, then concatenated with natural pauses.
|
||||||
|
"""
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
return "❌ Please enter some text", None, ""
|
return "❌ Please enter some text", None, ""
|
||||||
|
|
||||||
lang_code = LANGUAGES.get(language, "en")
|
lang_code = LANGUAGES.get(language, "en")
|
||||||
|
sentences = _split_sentences(text)
|
||||||
|
if not sentences:
|
||||||
|
return "❌ No speakable text found after cleaning", None, ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
audio_chunks: list[tuple[int, np.ndarray]] = []
|
||||||
|
|
||||||
# Call TTS service (Coqui XTTS API format)
|
for sentence in sentences:
|
||||||
response = client.get(
|
raw_audio = _synthesize_chunk(sentence, lang_code, speed)
|
||||||
f"{TTS_URL}/api/tts",
|
sr, audio = _read_wav_bytes(raw_audio)
|
||||||
params={"text": text, "language_id": lang_code}
|
audio_chunks.append((sr, audio))
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
|
sample_rate, audio_data = _concat_audio(audio_chunks)
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
audio_bytes = response.content
|
duration = len(audio_data) / sample_rate if sample_rate > 0 else 0
|
||||||
|
|
||||||
# Parse audio data
|
n_chunks = len(sentences)
|
||||||
audio_io = io.BytesIO(audio_bytes)
|
status = (
|
||||||
audio_data, sample_rate = sf.read(audio_io)
|
f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
|
||||||
|
f" ({n_chunks} sentence{'s' if n_chunks != 1 else ''})"
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate duration
|
_log_tts_metrics(
|
||||||
if len(audio_data.shape) == 1:
|
latency=latency,
|
||||||
duration = len(audio_data) / sample_rate
|
audio_duration=duration,
|
||||||
else:
|
text_chars=len(text),
|
||||||
duration = len(audio_data) / sample_rate
|
language=lang_code,
|
||||||
|
)
|
||||||
|
|
||||||
# Status message
|
|
||||||
status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms"
|
|
||||||
|
|
||||||
# Metrics
|
|
||||||
metrics = f"""
|
metrics = f"""
|
||||||
**Audio Statistics:**
|
**Audio Statistics:**
|
||||||
- Duration: {duration:.2f} seconds
|
- Duration: {duration:.2f} seconds
|
||||||
- Sample Rate: {sample_rate} Hz
|
- Sample Rate: {sample_rate} Hz
|
||||||
- Size: {len(audio_bytes) / 1024:.1f} KB
|
- Size: {len(audio_data) * 2 / 1024:.1f} KB
|
||||||
- Generation Time: {latency * 1000:.0f}ms
|
- Generation Time: {latency * 1000:.0f}ms
|
||||||
- Real-time Factor: {latency / duration:.2f}x
|
- Real-time Factor: {latency / duration:.2f}x
|
||||||
- Language: {language} ({lang_code})
|
- Language: {language} ({lang_code})
|
||||||
|
- Speed: {speed:.1f}x
|
||||||
|
- Sentences: {n_chunks}
|
||||||
- Characters: {len(text)}
|
- Characters: {len(text)}
|
||||||
- Chars/sec: {len(text) / latency:.1f}
|
- Chars/sec: {len(text) / latency:.1f}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return status, (sample_rate, audio_data), metrics
|
return status, (sample_rate, audio_data), metrics
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
@@ -114,37 +391,33 @@ def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndar
|
|||||||
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("TTS synthesis failed")
|
logger.exception("TTS synthesis failed")
|
||||||
return f"❌ Error: {str(e)}", None, ""
|
return f"❌ Error: {e}", None, ""
|
||||||
|
|
||||||
|
|
||||||
def check_service_health() -> str:
|
def check_service_health() -> str:
|
||||||
"""Check if the TTS service is healthy."""
|
"""Check if the TTS service is healthy."""
|
||||||
try:
|
try:
|
||||||
# Try the health endpoint first
|
|
||||||
response = client.get(f"{TTS_URL}/health", timeout=5.0)
|
response = client.get(f"{TTS_URL}/health", timeout=5.0)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return "🟢 Service is healthy"
|
return "🟢 Service is healthy"
|
||||||
|
|
||||||
# Fall back to root endpoint
|
|
||||||
response = client.get(f"{TTS_URL}/", timeout=5.0)
|
response = client.get(f"{TTS_URL}/", timeout=5.0)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return "🟢 Service is responding"
|
return "🟢 Service is responding"
|
||||||
|
|
||||||
return f"🟡 Service returned status {response.status_code}"
|
return f"🟡 Service returned status {response.status_code}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"🔴 Service unavailable: {str(e)}"
|
return f"🔴 Service unavailable: {e}"
|
||||||
|
|
||||||
|
|
||||||
# Build the Gradio app
|
# ─── Gradio UI ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
|
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
|
||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
# 🔊 Text-to-Speech Demo
|
# 🔊 Text-to-Speech Demo
|
||||||
|
|
||||||
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
|
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
|
||||||
in multiple languages.
|
in multiple languages. Long text is automatically split into sentences for better quality.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Service status
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||||
@@ -152,7 +425,6 @@ in multiple languages.
|
|||||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
# Tab 1: Basic TTS
|
|
||||||
with gr.TabItem("🎤 Text to Speech"):
|
with gr.TabItem("🎤 Text to Speech"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=2):
|
with gr.Column(scale=2):
|
||||||
@@ -160,17 +432,26 @@ in multiple languages.
|
|||||||
label="Text to Synthesize",
|
label="Text to Synthesize",
|
||||||
placeholder="Enter text to convert to speech...",
|
placeholder="Enter text to convert to speech...",
|
||||||
lines=5,
|
lines=5,
|
||||||
max_lines=10
|
max_lines=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
language = gr.Dropdown(
|
language = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()),
|
||||||
value="English",
|
value="English",
|
||||||
label="Language"
|
label="Language",
|
||||||
|
)
|
||||||
|
speed = gr.Slider(
|
||||||
|
minimum=0.5,
|
||||||
|
maximum=2.0,
|
||||||
|
value=1.0,
|
||||||
|
step=0.1,
|
||||||
|
label="Speed",
|
||||||
|
)
|
||||||
|
synthesize_btn = gr.Button(
|
||||||
|
"🔊 Synthesize",
|
||||||
|
variant="primary",
|
||||||
|
scale=2,
|
||||||
)
|
)
|
||||||
synthesize_btn = gr.Button("🔊 Synthesize", variant="primary", scale=2)
|
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
status_output = gr.Textbox(label="Status", interactive=False)
|
status_output = gr.Textbox(label="Status", interactive=False)
|
||||||
metrics_output = gr.Markdown(label="Metrics")
|
metrics_output = gr.Markdown(label="Metrics")
|
||||||
@@ -179,35 +460,49 @@ in multiple languages.
|
|||||||
|
|
||||||
synthesize_btn.click(
|
synthesize_btn.click(
|
||||||
fn=synthesize_speech,
|
fn=synthesize_speech,
|
||||||
inputs=[text_input, language],
|
inputs=[text_input, language, speed],
|
||||||
outputs=[status_output, audio_output, metrics_output]
|
outputs=[status_output, audio_output, metrics_output],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example texts
|
|
||||||
gr.Examples(
|
gr.Examples(
|
||||||
examples=[
|
examples=[
|
||||||
["Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "English"],
|
[
|
||||||
["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "English"],
|
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
|
||||||
["Bonjour! Bienvenue au laboratoire technique de Davies.", "French"],
|
"English",
|
||||||
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
|
1.0,
|
||||||
["Guten Tag! Willkommen im Techniklabor.", "German"],
|
|
||||||
],
|
],
|
||||||
inputs=[text_input, language],
|
[
|
||||||
|
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
|
||||||
|
"English",
|
||||||
|
1.0,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"Bonjour! Bienvenue au laboratoire technique de Davies.",
|
||||||
|
"French",
|
||||||
|
1.0,
|
||||||
|
],
|
||||||
|
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish", 1.0],
|
||||||
|
["Guten Tag! Willkommen im Techniklabor.", "German", 1.0],
|
||||||
|
],
|
||||||
|
inputs=[text_input, language, speed],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 2: Comparison
|
|
||||||
with gr.TabItem("🔄 Language Comparison"):
|
with gr.TabItem("🔄 Language Comparison"):
|
||||||
gr.Markdown("Compare the same text in different languages.")
|
gr.Markdown("Compare the same text in different languages.")
|
||||||
|
|
||||||
compare_text = gr.Textbox(
|
compare_text = gr.Textbox(
|
||||||
label="Text to Compare",
|
label="Text to Compare", value="Hello, how are you today?", lines=2
|
||||||
value="Hello, how are you today?",
|
|
||||||
lines=2
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lang1 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Language 1")
|
lang1 = gr.Dropdown(
|
||||||
lang2 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2")
|
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
|
||||||
|
)
|
||||||
|
lang2 = gr.Dropdown(
|
||||||
|
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
|
||||||
|
)
|
||||||
|
compare_speed = gr.Slider(
|
||||||
|
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||||
|
)
|
||||||
|
|
||||||
compare_btn = gr.Button("Compare Languages", variant="primary")
|
compare_btn = gr.Button("Compare Languages", variant="primary")
|
||||||
|
|
||||||
@@ -216,58 +511,61 @@ in multiple languages.
|
|||||||
gr.Markdown("### Language 1")
|
gr.Markdown("### Language 1")
|
||||||
audio1 = gr.Audio(label="Audio 1", type="numpy")
|
audio1 = gr.Audio(label="Audio 1", type="numpy")
|
||||||
status1 = gr.Textbox(label="Status", interactive=False)
|
status1 = gr.Textbox(label="Status", interactive=False)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("### Language 2")
|
gr.Markdown("### Language 2")
|
||||||
audio2 = gr.Audio(label="Audio 2", type="numpy")
|
audio2 = gr.Audio(label="Audio 2", type="numpy")
|
||||||
status2 = gr.Textbox(label="Status", interactive=False)
|
status2 = gr.Textbox(label="Status", interactive=False)
|
||||||
|
|
||||||
def compare_languages(text, l1, l2):
|
def compare_languages(text, l1, l2, spd):
|
||||||
s1, a1, _ = synthesize_speech(text, l1)
|
s1, a1, _ = synthesize_speech(text, l1, spd)
|
||||||
s2, a2, _ = synthesize_speech(text, l2)
|
s2, a2, _ = synthesize_speech(text, l2, spd)
|
||||||
return s1, a1, s2, a2
|
return s1, a1, s2, a2
|
||||||
|
|
||||||
compare_btn.click(
|
compare_btn.click(
|
||||||
fn=compare_languages,
|
fn=compare_languages,
|
||||||
inputs=[compare_text, lang1, lang2],
|
inputs=[compare_text, lang1, lang2, compare_speed],
|
||||||
outputs=[status1, audio1, status2, audio2]
|
outputs=[status1, audio1, status2, audio2],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tab 3: Batch Processing
|
|
||||||
with gr.TabItem("📚 Batch Synthesis"):
|
with gr.TabItem("📚 Batch Synthesis"):
|
||||||
gr.Markdown("Synthesize multiple texts at once (one per line).")
|
gr.Markdown("Synthesize multiple texts at once (one per line).")
|
||||||
|
|
||||||
batch_input = gr.Textbox(
|
batch_input = gr.Textbox(
|
||||||
label="Texts (one per line)",
|
label="Texts (one per line)",
|
||||||
placeholder="Enter multiple texts, one per line...",
|
placeholder="Enter multiple texts, one per line...",
|
||||||
lines=6
|
lines=6,
|
||||||
)
|
)
|
||||||
batch_lang = gr.Dropdown(
|
batch_lang = gr.Dropdown(
|
||||||
choices=list(LANGUAGES.keys()),
|
choices=list(LANGUAGES.keys()), value="English", label="Language"
|
||||||
value="English",
|
)
|
||||||
label="Language"
|
batch_speed = gr.Slider(
|
||||||
|
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||||
)
|
)
|
||||||
batch_btn = gr.Button("Synthesize All", variant="primary")
|
batch_btn = gr.Button("Synthesize All", variant="primary")
|
||||||
|
|
||||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||||
batch_audios = gr.Dataset(
|
batch_audio = gr.Audio(label="Combined Audio", type="numpy")
|
||||||
components=[gr.Audio(type="numpy")],
|
|
||||||
label="Generated Audio Files"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: Batch processing would need more complex handling
|
def batch_synthesize(texts_raw: str, lang: str, spd: float):
|
||||||
# This is a simplified version
|
lines = [
|
||||||
gr.Markdown("""
|
line.strip()
|
||||||
*Note: For batch processing of many texts, consider using the API directly
|
for line in texts_raw.strip().splitlines()
|
||||||
or the Kubeflow pipeline for better throughput.*
|
if line.strip()
|
||||||
""")
|
]
|
||||||
|
if not lines:
|
||||||
|
return "❌ Please enter at least one line of text", None
|
||||||
|
combined = "\n".join(lines)
|
||||||
|
status, audio, _ = synthesize_speech(combined, lang, spd)
|
||||||
|
return status, audio
|
||||||
|
|
||||||
|
batch_btn.click(
|
||||||
|
fn=batch_synthesize,
|
||||||
|
inputs=[batch_input, batch_lang, batch_speed],
|
||||||
|
outputs=[batch_status, batch_audio],
|
||||||
|
)
|
||||||
|
|
||||||
create_footer()
|
create_footer()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo.launch(
|
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||||
server_name="0.0.0.0",
|
|
||||||
server_port=7860,
|
|
||||||
show_error=True
|
|
||||||
)
|
|
||||||
|
|||||||
2
tts.yaml
2
tts.yaml
@@ -20,7 +20,7 @@ spec:
|
|||||||
spec:
|
spec:
|
||||||
containers:
|
containers:
|
||||||
- name: gradio
|
- name: gradio
|
||||||
image: ghcr.io/billy-davies-2/llm-apps:v2-202601271655
|
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||||
imagePullPolicy: Always
|
imagePullPolicy: Always
|
||||||
command: ["python", "tts.py"]
|
command: ["python", "tts.py"]
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
Reference in New Issue
Block a user