Compare commits
15 Commits
861f11e22b
...
v0.0.6
| Author | SHA1 | Date | |
|---|---|---|---|
| c050d11ab4 | |||
| 454a1c7cf6 | |||
| 71321e5878 | |||
| 1385736556 | |||
| 9faad8be6b | |||
| faa5dc0d9d | |||
| 0cc03aa145 | |||
| 12bdcab180 | |||
| 4069647495 | |||
| 53afea9352 | |||
| 58319b66ee | |||
| 1c5dc7f751 | |||
| b2d2252342 | |||
| 72681217ef | |||
| af67984737 |
224
.gitea/workflows/ci.yml
Normal file
224
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,224 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
NTFY_URL: http://ntfy.observability.svc.cluster.local:80
|
||||
REGISTRY: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs
|
||||
REGISTRY_HOST: gitea-http.gitea.svc.cluster.local:3000
|
||||
IMAGE_NAME: gradio-ui
|
||||
KUSTOMIZE_NAMESPACE: ai-ml
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up uv
|
||||
run: curl -LsSf https://astral.sh/uv/install.sh | sh && echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Run ruff check
|
||||
run: uvx ruff check .
|
||||
|
||||
- name: Run ruff format check
|
||||
run: uvx ruff format --check .
|
||||
|
||||
release:
|
||||
name: Release
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint]
|
||||
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Determine version bump
|
||||
id: version
|
||||
run: |
|
||||
LATEST=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
VERSION=${LATEST#v}
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "$VERSION"
|
||||
|
||||
MSG="${{ gitea.event.head_commit.message }}"
|
||||
if echo "$MSG" | grep -qiE "^major:|BREAKING CHANGE"; then
|
||||
MAJOR=$((MAJOR + 1)); MINOR=0; PATCH=0
|
||||
BUMP="major"
|
||||
elif echo "$MSG" | grep -qiE "^(minor:|feat:)"; then
|
||||
MINOR=$((MINOR + 1)); PATCH=0
|
||||
BUMP="minor"
|
||||
else
|
||||
PATCH=$((PATCH + 1))
|
||||
BUMP="patch"
|
||||
fi
|
||||
|
||||
NEW_VERSION="v${MAJOR}.${MINOR}.${PATCH}"
|
||||
echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "bump=$BUMP" >> $GITHUB_OUTPUT
|
||||
echo "Bumping $LATEST → $NEW_VERSION ($BUMP)"
|
||||
|
||||
- name: Create and push tag
|
||||
run: |
|
||||
git config user.name "gitea-actions[bot]"
|
||||
git config user.email "actions@git.daviestechlabs.io"
|
||||
git tag -a ${{ steps.version.outputs.version }} -m "Release ${{ steps.version.outputs.version }}"
|
||||
git push origin ${{ steps.version.outputs.version }}
|
||||
|
||||
docker:
|
||||
name: Docker Build & Push
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, release]
|
||||
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Login to Gitea Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY_HOST }}
|
||||
username: ${{ secrets.REGISTRY_USER }}
|
||||
password: ${{ secrets.REGISTRY_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-config-inline: |
|
||||
[registry."gitea-http.gitea.svc.cluster.local:3000"]
|
||||
http = true
|
||||
insecure = true
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=semver,pattern={{version}},value=${{ needs.release.outputs.version }}
|
||||
type=semver,pattern={{major}}.{{minor}},value=${{ needs.release.outputs.version }}
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
deploy:
|
||||
name: Deploy to Kubernetes
|
||||
runs-on: ubuntu-latest
|
||||
needs: [docker, release]
|
||||
if: gitea.ref == 'refs/heads/main' && gitea.event_name == 'push'
|
||||
container:
|
||||
image: catthehacker/ubuntu:act-latest
|
||||
volumes:
|
||||
- /secrets/kubeconfig:/secrets/kubeconfig
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install kubectl
|
||||
run: |
|
||||
curl -LO "https://dl.k8s.io/release/$(curl -Ls https://dl.k8s.io/release/stable.txt)/bin/linux/amd64/kubectl"
|
||||
chmod +x kubectl && sudo mv kubectl /usr/local/bin/
|
||||
|
||||
- name: Update image tag in manifests
|
||||
env:
|
||||
KUBECONFIG: /secrets/kubeconfig/config
|
||||
run: |
|
||||
VERSION="${{ needs.release.outputs.version }}"
|
||||
VERSION="${VERSION#v}"
|
||||
for DEPLOY in llm embeddings stt tts; do
|
||||
sed -i "s|image: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:.*|image: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${VERSION}|" "${DEPLOY}.yaml"
|
||||
done
|
||||
|
||||
- name: Apply kustomization
|
||||
env:
|
||||
KUBECONFIG: /secrets/kubeconfig/config
|
||||
run: |
|
||||
kubectl apply -k . --namespace ${{ env.KUSTOMIZE_NAMESPACE }}
|
||||
|
||||
- name: Rollout restart deployments
|
||||
env:
|
||||
KUBECONFIG: /secrets/kubeconfig/config
|
||||
run: |
|
||||
for DEPLOY in llm-ui embeddings-ui stt-ui tts-ui; do
|
||||
kubectl rollout restart deployment/${DEPLOY} -n ${{ env.KUSTOMIZE_NAMESPACE }} 2>/dev/null || true
|
||||
done
|
||||
|
||||
- name: Wait for rollout
|
||||
env:
|
||||
KUBECONFIG: /secrets/kubeconfig/config
|
||||
run: |
|
||||
for DEPLOY in llm-ui embeddings-ui stt-ui tts-ui; do
|
||||
kubectl rollout status deployment/${DEPLOY} -n ${{ env.KUSTOMIZE_NAMESPACE }} --timeout=120s 2>/dev/null || true
|
||||
done
|
||||
|
||||
notify:
|
||||
name: Notify
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, release, docker, deploy]
|
||||
if: always()
|
||||
steps:
|
||||
- name: Notify on success
|
||||
if: needs.lint.result == 'success' && needs.docker.result == 'success'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Title: ✅ CI Passed: ${{ gitea.repository }}" \
|
||||
-H "Priority: default" \
|
||||
-H "Tags: white_check_mark,github" \
|
||||
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
|
||||
-d "Branch: ${{ gitea.ref_name }}
|
||||
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
|
||||
Release: ${{ needs.release.result == 'success' && needs.release.outputs.version || 'skipped' }}
|
||||
Docker: ${{ needs.docker.result }}
|
||||
Deploy: ${{ needs.deploy.result }}" \
|
||||
${{ env.NTFY_URL }}/gitea-ci
|
||||
|
||||
- name: Notify on deploy success
|
||||
if: needs.deploy.result == 'success'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Title: 🚀 Deployed: ${{ gitea.repository }}" \
|
||||
-H "Priority: default" \
|
||||
-H "Tags: rocket,kubernetes" \
|
||||
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
|
||||
-d "Version: ${{ needs.release.outputs.version }}
|
||||
Namespace: ${{ env.KUSTOMIZE_NAMESPACE }}
|
||||
Apps: llm-ui, embeddings-ui, stt-ui, tts-ui" \
|
||||
${{ env.NTFY_URL }}/gitea-ci
|
||||
|
||||
- name: Notify on failure
|
||||
if: needs.lint.result == 'failure' || needs.docker.result == 'failure' || needs.deploy.result == 'failure'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Title: ❌ CI Failed: ${{ gitea.repository }}" \
|
||||
-H "Priority: high" \
|
||||
-H "Tags: x,github" \
|
||||
-H "Click: ${{ gitea.server_url }}/${{ gitea.repository }}/actions/runs/${{ gitea.run_id }}" \
|
||||
-d "Branch: ${{ gitea.ref_name }}
|
||||
Commit: ${{ gitea.event.head_commit.message || gitea.sha }}
|
||||
Lint: ${{ needs.lint.result }}
|
||||
Docker: ${{ needs.docker.result }}
|
||||
Deploy: ${{ needs.deploy.result }}" \
|
||||
${{ env.NTFY_URL }}/gitea-ci
|
||||
229
embeddings.py
229
embeddings.py
@@ -9,6 +9,7 @@ Features:
|
||||
- MLflow metrics logging
|
||||
- Visual embedding dimension display
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
@@ -26,14 +27,79 @@ logger = logging.getLogger("embeddings-demo")
|
||||
|
||||
# Configuration
|
||||
EMBEDDINGS_URL = os.environ.get(
|
||||
"EMBEDDINGS_URL",
|
||||
"EMBEDDINGS_URL",
|
||||
# Default: Ray Serve Embeddings endpoint
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings"
|
||||
)
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings",
|
||||
)
|
||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||
try:
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80",
|
||||
)
|
||||
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||
_mlflow_client = MlflowClient()
|
||||
|
||||
_experiment = _mlflow_client.get_experiment_by_name("gradio-embeddings-tuning")
|
||||
if _experiment is None:
|
||||
_experiment_id = _mlflow_client.create_experiment(
|
||||
"gradio-embeddings-tuning",
|
||||
artifact_location="/mlflow/artifacts/gradio-embeddings-tuning",
|
||||
)
|
||||
else:
|
||||
_experiment_id = _experiment.experiment_id
|
||||
|
||||
_mlflow_run = mlflow.start_run(
|
||||
experiment_id=_experiment_id,
|
||||
run_name=f"gradio-embeddings-{os.environ.get('HOSTNAME', 'local')}",
|
||||
tags={"service": "gradio-embeddings", "endpoint": EMBEDDINGS_URL},
|
||||
)
|
||||
_mlflow_run_id = _mlflow_run.info.run_id
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = True
|
||||
logger.info(
|
||||
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MLflow tracking disabled: %s", exc)
|
||||
_mlflow_client = None
|
||||
_mlflow_run_id = None
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = False
|
||||
|
||||
|
||||
def _log_embedding_metrics(
|
||||
latency: float, batch_size: int, embedding_dims: int = 0
|
||||
) -> None:
|
||||
"""Log embedding inference metrics to MLflow (non-blocking best-effort)."""
|
||||
global _mlflow_step
|
||||
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||
return
|
||||
try:
|
||||
_mlflow_step += 1
|
||||
ts = int(time.time() * 1000)
|
||||
_mlflow_client.log_batch(
|
||||
_mlflow_run_id,
|
||||
metrics=[
|
||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step),
|
||||
mlflow.entities.Metric(
|
||||
"embedding_dims", embedding_dims, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric(
|
||||
"latency_per_text_ms",
|
||||
(latency * 1000 / batch_size) if batch_size > 0 else 0,
|
||||
ts,
|
||||
_mlflow_step,
|
||||
),
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("MLflow log failed", exc_info=True)
|
||||
|
||||
|
||||
# HTTP client
|
||||
client = httpx.Client(timeout=60.0)
|
||||
@@ -42,17 +108,16 @@ client = httpx.Client(timeout=60.0)
|
||||
def get_embeddings(texts: list[str]) -> tuple[list[list[float]], float]:
|
||||
"""Get embeddings from the embeddings service."""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
response = client.post(
|
||||
f"{EMBEDDINGS_URL}/embeddings",
|
||||
json={"input": texts, "model": "bge"}
|
||||
f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
latency = time.time() - start_time
|
||||
result = response.json()
|
||||
embeddings = [d["embedding"] for d in result.get("data", [])]
|
||||
|
||||
|
||||
return embeddings, latency
|
||||
|
||||
|
||||
@@ -67,26 +132,29 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]:
|
||||
"""Generate embedding for a single text."""
|
||||
if not text.strip():
|
||||
return "❌ Please enter some text", "", ""
|
||||
|
||||
|
||||
try:
|
||||
embeddings, latency = get_embeddings([text])
|
||||
|
||||
|
||||
if not embeddings:
|
||||
return "❌ No embedding returned", "", ""
|
||||
|
||||
|
||||
embedding = embeddings[0]
|
||||
dims = len(embedding)
|
||||
|
||||
|
||||
# Log to MLflow
|
||||
_log_embedding_metrics(latency, batch_size=1, embedding_dims=dims)
|
||||
|
||||
# Format output
|
||||
status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms"
|
||||
|
||||
status = f"✅ Generated {dims}-dimensional embedding in {latency * 1000:.1f}ms"
|
||||
|
||||
# Show first/last few dimensions
|
||||
preview = f"Dimensions: {dims}\n\n"
|
||||
preview += "First 10 values:\n"
|
||||
preview += json.dumps(embedding[:10], indent=2)
|
||||
preview += "\n\n...\n\nLast 10 values:\n"
|
||||
preview += json.dumps(embedding[-10:], indent=2)
|
||||
|
||||
|
||||
# Stats
|
||||
stats = f"""
|
||||
**Embedding Statistics:**
|
||||
@@ -96,11 +164,11 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]:
|
||||
- Mean: {np.mean(embedding):.6f}
|
||||
- Std: {np.std(embedding):.6f}
|
||||
- L2 Norm: {np.linalg.norm(embedding):.6f}
|
||||
- Latency: {latency*1000:.1f}ms
|
||||
- Latency: {latency * 1000:.1f}ms
|
||||
"""
|
||||
|
||||
|
||||
return status, preview, stats
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Embedding generation failed")
|
||||
return f"❌ Error: {str(e)}", "", ""
|
||||
@@ -110,15 +178,18 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
|
||||
"""Compare similarity between two texts."""
|
||||
if not text1.strip() or not text2.strip():
|
||||
return "❌ Please enter both texts", ""
|
||||
|
||||
|
||||
try:
|
||||
embeddings, latency = get_embeddings([text1, text2])
|
||||
|
||||
|
||||
if len(embeddings) != 2:
|
||||
return "❌ Failed to get embeddings for both texts", ""
|
||||
|
||||
|
||||
similarity = cosine_similarity(embeddings[0], embeddings[1])
|
||||
|
||||
|
||||
# Log to MLflow
|
||||
_log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0]))
|
||||
|
||||
# Determine similarity level
|
||||
if similarity > 0.9:
|
||||
level = "🟢 Very High"
|
||||
@@ -132,7 +203,7 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
|
||||
else:
|
||||
level = "🔴 Low"
|
||||
desc = "These texts are semantically different"
|
||||
|
||||
|
||||
result = f"""
|
||||
## Similarity Score: {similarity:.4f}
|
||||
|
||||
@@ -141,17 +212,17 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
|
||||
{desc}
|
||||
|
||||
---
|
||||
*Computed in {latency*1000:.1f}ms*
|
||||
*Computed in {latency * 1000:.1f}ms*
|
||||
"""
|
||||
|
||||
|
||||
# Create a simple visual bar
|
||||
bar_length = 50
|
||||
filled = int(similarity * bar_length)
|
||||
bar = "█" * filled + "░" * (bar_length - filled)
|
||||
visual = f"[{bar}] {similarity*100:.1f}%"
|
||||
|
||||
visual = f"[{bar}] {similarity * 100:.1f}%"
|
||||
|
||||
return result, visual
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Comparison failed")
|
||||
return f"❌ Error: {str(e)}", ""
|
||||
@@ -160,16 +231,23 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]:
|
||||
def batch_embed(texts_input: str) -> tuple[str, str]:
|
||||
"""Generate embeddings for multiple texts (one per line)."""
|
||||
texts = [t.strip() for t in texts_input.strip().split("\n") if t.strip()]
|
||||
|
||||
|
||||
if not texts:
|
||||
return "❌ Please enter at least one text (one per line)", ""
|
||||
|
||||
|
||||
try:
|
||||
embeddings, latency = get_embeddings(texts)
|
||||
|
||||
status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms"
|
||||
status += f" ({latency*1000/len(texts):.1f}ms per text)"
|
||||
|
||||
|
||||
# Log to MLflow
|
||||
_log_embedding_metrics(
|
||||
latency,
|
||||
batch_size=len(embeddings),
|
||||
embedding_dims=len(embeddings[0]) if embeddings else 0,
|
||||
)
|
||||
|
||||
status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms"
|
||||
status += f" ({latency * 1000 / len(texts):.1f}ms per text)"
|
||||
|
||||
# Build similarity matrix
|
||||
n = len(embeddings)
|
||||
matrix = []
|
||||
@@ -179,16 +257,16 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
|
||||
sim = cosine_similarity(embeddings[i], embeddings[j])
|
||||
row.append(f"{sim:.3f}")
|
||||
matrix.append(row)
|
||||
|
||||
|
||||
# Format as table
|
||||
header = "| | " + " | ".join([f"Text {i+1}" for i in range(n)]) + " |"
|
||||
header = "| | " + " | ".join([f"Text {i + 1}" for i in range(n)]) + " |"
|
||||
separator = "|---" + "|---" * n + "|"
|
||||
rows = []
|
||||
for i, row in enumerate(matrix):
|
||||
rows.append(f"| **Text {i+1}** | " + " | ".join(row) + " |")
|
||||
|
||||
rows.append(f"| **Text {i + 1}** | " + " | ".join(row) + " |")
|
||||
|
||||
table = "\n".join([header, separator] + rows)
|
||||
|
||||
|
||||
result = f"""
|
||||
## Similarity Matrix
|
||||
|
||||
@@ -198,10 +276,10 @@ def batch_embed(texts_input: str) -> tuple[str, str]:
|
||||
**Texts processed:**
|
||||
"""
|
||||
for i, text in enumerate(texts):
|
||||
result += f"\n{i+1}. {text[:50]}{'...' if len(text) > 50 else ''}"
|
||||
|
||||
result += f"\n{i + 1}. {text[:50]}{'...' if len(text) > 50 else ''}"
|
||||
|
||||
return status, result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Batch embedding failed")
|
||||
return f"❌ Error: {str(e)}", ""
|
||||
@@ -227,14 +305,14 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="Embeddings Demo") a
|
||||
Test the **BGE Embeddings** service for semantic text encoding.
|
||||
Generate embeddings, compare text similarity, and explore vector representations.
|
||||
""")
|
||||
|
||||
|
||||
# Service status
|
||||
with gr.Row():
|
||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||
|
||||
|
||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||
|
||||
|
||||
with gr.Tabs():
|
||||
# Tab 1: Single Embedding
|
||||
with gr.TabItem("📝 Single Text"):
|
||||
@@ -243,71 +321,74 @@ Generate embeddings, compare text similarity, and explore vector representations
|
||||
single_input = gr.Textbox(
|
||||
label="Input Text",
|
||||
placeholder="Enter text to generate embeddings...",
|
||||
lines=3
|
||||
lines=3,
|
||||
)
|
||||
single_btn = gr.Button("Generate Embedding", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
single_status = gr.Textbox(label="Status", interactive=False)
|
||||
single_stats = gr.Markdown(label="Statistics")
|
||||
|
||||
|
||||
single_preview = gr.Code(label="Embedding Preview", language="json")
|
||||
|
||||
|
||||
single_btn.click(
|
||||
fn=generate_single_embedding,
|
||||
inputs=single_input,
|
||||
outputs=[single_status, single_preview, single_stats]
|
||||
outputs=[single_status, single_preview, single_stats],
|
||||
)
|
||||
|
||||
|
||||
# Tab 2: Compare Texts
|
||||
with gr.TabItem("⚖️ Compare Texts"):
|
||||
gr.Markdown("Compare the semantic similarity between two texts.")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
compare_text1 = gr.Textbox(label="Text 1", lines=3)
|
||||
compare_text2 = gr.Textbox(label="Text 2", lines=3)
|
||||
|
||||
|
||||
compare_btn = gr.Button("Compare Similarity", variant="primary")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
compare_result = gr.Markdown(label="Result")
|
||||
compare_visual = gr.Textbox(label="Similarity Bar", interactive=False)
|
||||
|
||||
|
||||
compare_btn.click(
|
||||
fn=compare_texts,
|
||||
inputs=[compare_text1, compare_text2],
|
||||
outputs=[compare_result, compare_visual]
|
||||
outputs=[compare_result, compare_visual],
|
||||
)
|
||||
|
||||
|
||||
# Example pairs
|
||||
gr.Examples(
|
||||
examples=[
|
||||
["The cat sat on the mat.", "A feline was resting on the rug."],
|
||||
["Machine learning is a subset of AI.", "Deep learning uses neural networks."],
|
||||
[
|
||||
"Machine learning is a subset of AI.",
|
||||
"Deep learning uses neural networks.",
|
||||
],
|
||||
["I love pizza.", "The stock market crashed today."],
|
||||
],
|
||||
inputs=[compare_text1, compare_text2],
|
||||
)
|
||||
|
||||
|
||||
# Tab 3: Batch Embeddings
|
||||
with gr.TabItem("📚 Batch Processing"):
|
||||
gr.Markdown("Generate embeddings for multiple texts and see their similarity matrix.")
|
||||
|
||||
gr.Markdown(
|
||||
"Generate embeddings for multiple texts and see their similarity matrix."
|
||||
)
|
||||
|
||||
batch_input = gr.Textbox(
|
||||
label="Texts (one per line)",
|
||||
placeholder="Enter multiple texts, one per line...",
|
||||
lines=6
|
||||
lines=6,
|
||||
)
|
||||
batch_btn = gr.Button("Process Batch", variant="primary")
|
||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||
batch_result = gr.Markdown(label="Similarity Matrix")
|
||||
|
||||
|
||||
batch_btn.click(
|
||||
fn=batch_embed,
|
||||
inputs=batch_input,
|
||||
outputs=[batch_status, batch_result]
|
||||
fn=batch_embed, inputs=batch_input, outputs=[batch_status, batch_result]
|
||||
)
|
||||
|
||||
|
||||
gr.Examples(
|
||||
examples=[
|
||||
"Python is a programming language.\nJava is also a programming language.\nCoffee is a beverage.",
|
||||
@@ -315,13 +396,9 @@ Generate embeddings, compare text similarity, and explore vector representations
|
||||
],
|
||||
inputs=batch_input,
|
||||
)
|
||||
|
||||
|
||||
create_footer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(
|
||||
server_name="0.0.0.0",
|
||||
server_port=7860,
|
||||
show_error=True
|
||||
)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||
|
||||
@@ -20,7 +20,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: gradio
|
||||
image: ghcr.io/billy-davies-2/llm-apps:v2-202601271655
|
||||
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||
imagePullPolicy: Always
|
||||
command: ["python", "embeddings.py"]
|
||||
ports:
|
||||
|
||||
@@ -5,5 +5,6 @@ namespace: ai-ml
|
||||
|
||||
resources:
|
||||
- embeddings.yaml
|
||||
- llm.yaml
|
||||
- tts.yaml
|
||||
- stt.yaml
|
||||
|
||||
388
llm.py
Normal file
388
llm.py
Normal file
@@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
||||
|
||||
Features:
|
||||
- Multi-turn chat with streaming responses
|
||||
- Configurable temperature, max tokens, top-p
|
||||
- System prompt customisation
|
||||
- Token usage and latency metrics
|
||||
- Chat history management
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
import httpx
|
||||
|
||||
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("llm-demo")
|
||||
|
||||
# Configuration
|
||||
LLM_URL = os.environ.get(
|
||||
"LLM_URL",
|
||||
# Default: Ray Serve LLM endpoint
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
|
||||
)
|
||||
|
||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||
try:
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80",
|
||||
)
|
||||
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||
_mlflow_client = MlflowClient()
|
||||
|
||||
# Ensure experiment exists
|
||||
_experiment = _mlflow_client.get_experiment_by_name("gradio-llm-tuning")
|
||||
if _experiment is None:
|
||||
_experiment_id = _mlflow_client.create_experiment(
|
||||
"gradio-llm-tuning",
|
||||
artifact_location="/mlflow/artifacts/gradio-llm-tuning",
|
||||
)
|
||||
else:
|
||||
_experiment_id = _experiment.experiment_id
|
||||
|
||||
# One persistent run per Gradio instance
|
||||
_mlflow_run = mlflow.start_run(
|
||||
experiment_id=_experiment_id,
|
||||
run_name=f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
|
||||
tags={
|
||||
"service": "gradio-llm",
|
||||
"endpoint": LLM_URL,
|
||||
"mlflow.runName": f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
|
||||
},
|
||||
)
|
||||
_mlflow_run_id = _mlflow_run.info.run_id
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = True
|
||||
logger.info(
|
||||
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MLflow tracking disabled: %s", exc)
|
||||
_mlflow_client = None
|
||||
_mlflow_run_id = None
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = False
|
||||
|
||||
|
||||
def _log_llm_metrics(
|
||||
latency: float,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
top_p: float,
|
||||
) -> None:
|
||||
"""Log inference metrics to MLflow (non-blocking best-effort)."""
|
||||
global _mlflow_step
|
||||
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||
return
|
||||
try:
|
||||
_mlflow_step += 1
|
||||
ts = int(time.time() * 1000)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
tps = completion_tokens / latency if latency > 0 else 0
|
||||
_mlflow_client.log_batch(
|
||||
_mlflow_run_id,
|
||||
metrics=[
|
||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||
mlflow.entities.Metric(
|
||||
"prompt_tokens", prompt_tokens, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric(
|
||||
"completion_tokens", completion_tokens, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
|
||||
mlflow.entities.Metric(
|
||||
"max_tokens_requested", max_tokens, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("MLflow log failed", exc_info=True)
|
||||
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
|
||||
"You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
|
||||
"Be concise and helpful."
|
||||
)
|
||||
|
||||
# Use async client for streaming
|
||||
async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
|
||||
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
||||
|
||||
|
||||
async def chat_stream(
|
||||
message: str,
|
||||
history: list[dict[str, str]],
|
||||
system_prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
top_p: float,
|
||||
):
|
||||
"""Stream chat responses from the vLLM endpoint."""
|
||||
if not message.strip():
|
||||
yield ""
|
||||
return
|
||||
|
||||
# Build message list from history
|
||||
messages = []
|
||||
if system_prompt.strip():
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
for entry in history:
|
||||
messages.append({"role": entry["role"], "content": entry["content"]})
|
||||
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await async_client.post(LLM_URL, json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
latency = time.time() - start_time
|
||||
usage = result.get("usage", {})
|
||||
|
||||
logger.info(
|
||||
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
||||
usage.get("total_tokens", 0),
|
||||
latency,
|
||||
usage.get("prompt_tokens", 0),
|
||||
usage.get("completion_tokens", 0),
|
||||
)
|
||||
|
||||
# Log to MLflow
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
# Yield text progressively for a nicer streaming feel
|
||||
chunk_size = 4
|
||||
words = text.split(" ")
|
||||
partial = ""
|
||||
for i, word in enumerate(words):
|
||||
partial += ("" if i == 0 else " ") + word
|
||||
if i % chunk_size == 0 or i == len(words) - 1:
|
||||
yield partial
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("LLM request failed")
|
||||
yield f"❌ LLM service error: {e.response.status_code} — {e.response.text[:200]}"
|
||||
except httpx.ConnectError:
|
||||
yield "❌ Cannot connect to LLM service. Is the Ray Serve cluster running?"
|
||||
except Exception as e:
|
||||
logger.exception("LLM chat failed")
|
||||
yield f"❌ Error: {e}"
|
||||
|
||||
|
||||
def check_service_health() -> str:
|
||||
"""Check if the LLM service is reachable."""
|
||||
try:
|
||||
# Try a lightweight GET against the Ray Serve base first.
|
||||
# This avoids burning GPU time on a full inference round-trip.
|
||||
base_url = LLM_URL.rsplit("/", 1)[0] # strip /llm path
|
||||
response = sync_client.get(f"{base_url}/-/routes")
|
||||
if response.status_code == 200:
|
||||
return "🟢 LLM service is healthy"
|
||||
# Fall back to a minimal inference probe
|
||||
response = sync_client.post(
|
||||
LLM_URL,
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
"max_tokens": 1,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return "🟢 LLM service is healthy"
|
||||
return f"🟡 LLM responded with status {response.status_code}"
|
||||
except httpx.ConnectError:
|
||||
return "🔴 Cannot connect to LLM service"
|
||||
except httpx.TimeoutException:
|
||||
return "🟡 LLM service is reachable but slow to respond"
|
||||
except Exception as e:
|
||||
return f"🔴 Service unavailable: {e}"
|
||||
|
||||
|
||||
def single_prompt(
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
top_p: float,
|
||||
) -> tuple[str, str]:
|
||||
"""Send a single prompt (non-chat mode) and return output + metrics."""
|
||||
if not prompt.strip():
|
||||
return "❌ Please enter a prompt", ""
|
||||
|
||||
messages = []
|
||||
if system_prompt.strip():
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
client = httpx.Client(timeout=300.0)
|
||||
response = client.post(LLM_URL, json=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
latency = time.time() - start_time
|
||||
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
usage = result.get("usage", {})
|
||||
|
||||
# Log to MLflow
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
completion_tokens=usage.get("completion_tokens", 0),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
metrics = f"""
|
||||
**Generation Metrics:**
|
||||
- Latency: {latency:.1f}s
|
||||
- Prompt tokens: {usage.get("prompt_tokens", "N/A")}
|
||||
- Completion tokens: {usage.get("completion_tokens", "N/A")}
|
||||
- Total tokens: {usage.get("total_tokens", "N/A")}
|
||||
- Model: {result.get("model", "N/A")}
|
||||
"""
|
||||
return text, metrics
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
return f"❌ Error {e.response.status_code}: {e.response.text[:300]}", ""
|
||||
except httpx.ConnectError:
|
||||
return "❌ Cannot connect to LLM service", ""
|
||||
except Exception as e:
|
||||
return f"❌ {e}", ""
|
||||
|
||||
|
||||
# ─── Build the Gradio app ────────────────────────────────────────────────
|
||||
|
||||
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="LLM Chat Demo") as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# 🧠 LLM Chat Demo
|
||||
|
||||
Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
||||
"""
|
||||
)
|
||||
|
||||
# Service status
|
||||
with gr.Row():
|
||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||
|
||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||
|
||||
# Shared parameters
|
||||
with gr.Accordion("⚙️ Parameters", open=False):
|
||||
system_prompt = gr.Textbox(
|
||||
label="System Prompt",
|
||||
value=DEFAULT_SYSTEM_PROMPT,
|
||||
lines=3,
|
||||
max_lines=6,
|
||||
)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
||||
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
|
||||
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
||||
|
||||
with gr.Tabs():
|
||||
# Tab 1: Multi-turn Chat
|
||||
with gr.TabItem("💬 Chat"):
|
||||
chatbot = gr.ChatInterface(
|
||||
fn=chat_stream,
|
||||
additional_inputs=[system_prompt, temperature, max_tokens, top_p],
|
||||
examples=[
|
||||
["Hello! What can you tell me about yourself?"],
|
||||
["Explain how a GPU executes a matrix multiplication."],
|
||||
["Write a Python function to compute the Fibonacci sequence."],
|
||||
["What are the pros and cons of running LLMs on AMD GPUs?"],
|
||||
],
|
||||
chatbot=gr.Chatbot(
|
||||
height=520,
|
||||
placeholder="Type a message to start chatting...",
|
||||
),
|
||||
)
|
||||
|
||||
# Tab 2: Single Prompt
|
||||
with gr.TabItem("📝 Single Prompt"):
|
||||
gr.Markdown("Send a one-shot prompt without conversation history.")
|
||||
|
||||
prompt_input = gr.Textbox(
|
||||
label="Prompt",
|
||||
placeholder="Enter your prompt...",
|
||||
lines=4,
|
||||
max_lines=10,
|
||||
)
|
||||
generate_btn = gr.Button("🚀 Generate", variant="primary")
|
||||
|
||||
output_text = gr.Textbox(label="Response", lines=12, interactive=False)
|
||||
output_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
generate_btn.click(
|
||||
fn=single_prompt,
|
||||
inputs=[prompt_input, system_prompt, temperature, max_tokens, top_p],
|
||||
outputs=[output_text, output_metrics],
|
||||
)
|
||||
|
||||
gr.Examples(
|
||||
examples=[
|
||||
[
|
||||
"Summarise the key differences between CUDA and ROCm for ML workloads."
|
||||
],
|
||||
["Write a haiku about Kubernetes."],
|
||||
[
|
||||
"Explain Ray Serve in one paragraph for someone new to ML serving."
|
||||
],
|
||||
["List 5 creative uses for a homelab GPU cluster."],
|
||||
],
|
||||
inputs=[prompt_input],
|
||||
)
|
||||
|
||||
create_footer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||
96
llm.yaml
Normal file
96
llm.yaml
Normal file
@@ -0,0 +1,96 @@
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: llm-ui
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app: llm
|
||||
component: demo-ui
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: llm
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: llm
|
||||
component: demo-ui
|
||||
spec:
|
||||
containers:
|
||||
- name: gradio
|
||||
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||
imagePullPolicy: Always
|
||||
command: ["python", "llm.py"]
|
||||
ports:
|
||||
- containerPort: 7860
|
||||
name: http
|
||||
protocol: TCP
|
||||
env:
|
||||
- name: LLM_URL
|
||||
# Ray Serve endpoint - routes to /llm prefix
|
||||
value: "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm"
|
||||
- name: MLFLOW_TRACKING_URI
|
||||
value: "http://mlflow.mlflow.svc.cluster.local:80"
|
||||
resources:
|
||||
requests:
|
||||
cpu: "100m"
|
||||
memory: "256Mi"
|
||||
limits:
|
||||
cpu: "500m"
|
||||
memory: "512Mi"
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /
|
||||
port: 7860
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 30
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /
|
||||
port: 7860
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 10
|
||||
imagePullSecrets:
|
||||
- name: gitea-registry
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: llm-ui
|
||||
namespace: ai-ml
|
||||
labels:
|
||||
app: llm
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 7860
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
app: llm
|
||||
---
|
||||
apiVersion: gateway.networking.k8s.io/v1
|
||||
kind: HTTPRoute
|
||||
metadata:
|
||||
name: llm-ui
|
||||
namespace: ai-ml
|
||||
annotations:
|
||||
external-dns.alpha.kubernetes.io/hostname: llm-ui.lab.daviestechlabs.io
|
||||
spec:
|
||||
parentRefs:
|
||||
- name: envoy-internal
|
||||
namespace: network
|
||||
sectionName: https-lab
|
||||
hostnames:
|
||||
- llm-ui.lab.daviestechlabs.io
|
||||
rules:
|
||||
- matches:
|
||||
- path:
|
||||
type: PathPrefix
|
||||
value: /
|
||||
backendRefs:
|
||||
- name: llm-ui
|
||||
port: 80
|
||||
7
renovate.json
Normal file
7
renovate.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"extends": [
|
||||
"local>daviestechlabs/renovate-config",
|
||||
"local>daviestechlabs/renovate-config:python"
|
||||
]
|
||||
}
|
||||
222
stt.py
222
stt.py
@@ -9,11 +9,11 @@ Features:
|
||||
- Translation mode
|
||||
- MLflow metrics logging
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import io
|
||||
import tempfile
|
||||
|
||||
import gradio as gr
|
||||
import httpx
|
||||
@@ -30,13 +30,82 @@ logger = logging.getLogger("stt-demo")
|
||||
STT_URL = os.environ.get(
|
||||
"STT_URL",
|
||||
# Default: Ray Serve whisper endpoint
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper"
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper",
|
||||
)
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
||||
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||
)
|
||||
|
||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||
try:
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||
_mlflow_client = MlflowClient()
|
||||
|
||||
_experiment = _mlflow_client.get_experiment_by_name("gradio-stt-tuning")
|
||||
if _experiment is None:
|
||||
_experiment_id = _mlflow_client.create_experiment(
|
||||
"gradio-stt-tuning",
|
||||
artifact_location="/mlflow/artifacts/gradio-stt-tuning",
|
||||
)
|
||||
else:
|
||||
_experiment_id = _experiment.experiment_id
|
||||
|
||||
_mlflow_run = mlflow.start_run(
|
||||
experiment_id=_experiment_id,
|
||||
run_name=f"gradio-stt-{os.environ.get('HOSTNAME', 'local')}",
|
||||
tags={"service": "gradio-stt", "endpoint": STT_URL},
|
||||
)
|
||||
_mlflow_run_id = _mlflow_run.info.run_id
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = True
|
||||
logger.info(
|
||||
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MLflow tracking disabled: %s", exc)
|
||||
_mlflow_client = None
|
||||
_mlflow_run_id = None
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = False
|
||||
|
||||
|
||||
def _log_stt_metrics(
|
||||
latency: float,
|
||||
audio_duration: float,
|
||||
word_count: int,
|
||||
task: str,
|
||||
) -> None:
|
||||
"""Log STT inference metrics to MLflow (non-blocking best-effort)."""
|
||||
global _mlflow_step
|
||||
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||
return
|
||||
try:
|
||||
_mlflow_step += 1
|
||||
ts = int(time.time() * 1000)
|
||||
rtf = latency / audio_duration if audio_duration > 0 else 0
|
||||
_mlflow_client.log_batch(
|
||||
_mlflow_run_id,
|
||||
metrics=[
|
||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||
mlflow.entities.Metric(
|
||||
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step),
|
||||
],
|
||||
params=[]
|
||||
if _mlflow_step > 1
|
||||
else [
|
||||
mlflow.entities.Param("task", task),
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("MLflow log failed", exc_info=True)
|
||||
|
||||
|
||||
# HTTP client with longer timeout for transcription
|
||||
client = httpx.Client(timeout=180.0)
|
||||
|
||||
@@ -63,77 +132,85 @@ LANGUAGES = {
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_input: tuple[int, np.ndarray] | str | None,
|
||||
language: str,
|
||||
task: str
|
||||
audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str
|
||||
) -> tuple[str, str, str]:
|
||||
"""Transcribe audio using the Whisper STT service."""
|
||||
if audio_input is None:
|
||||
return "❌ Please provide audio input", "", ""
|
||||
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Handle different input types
|
||||
if isinstance(audio_input, tuple):
|
||||
# Microphone input: (sample_rate, audio_data)
|
||||
sample_rate, audio_data = audio_input
|
||||
|
||||
|
||||
# Convert to WAV bytes
|
||||
audio_buffer = io.BytesIO()
|
||||
sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
|
||||
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
||||
audio_bytes = audio_buffer.getvalue()
|
||||
audio_duration = len(audio_data) / sample_rate
|
||||
else:
|
||||
# File path
|
||||
with open(audio_input, 'rb') as f:
|
||||
with open(audio_input, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
# Get duration
|
||||
audio_data, sample_rate = sf.read(audio_input)
|
||||
audio_duration = len(audio_data) / sample_rate
|
||||
|
||||
|
||||
# Prepare request
|
||||
lang_code = LANGUAGES.get(language)
|
||||
|
||||
|
||||
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
|
||||
data = {"response_format": "json"}
|
||||
|
||||
|
||||
if lang_code:
|
||||
data["language"] = lang_code
|
||||
|
||||
|
||||
# Choose endpoint based on task
|
||||
if task == "Translate to English":
|
||||
endpoint = f"{STT_URL}/v1/audio/translations"
|
||||
else:
|
||||
endpoint = f"{STT_URL}/v1/audio/transcriptions"
|
||||
|
||||
|
||||
# Send request
|
||||
response = client.post(endpoint, files=files, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
latency = time.time() - start_time
|
||||
result = response.json()
|
||||
|
||||
|
||||
text = result.get("text", "")
|
||||
detected_language = result.get("language", "unknown")
|
||||
|
||||
|
||||
# Log to MLflow
|
||||
_log_stt_metrics(
|
||||
latency=latency,
|
||||
audio_duration=audio_duration,
|
||||
word_count=len(text.split()),
|
||||
task=task,
|
||||
)
|
||||
|
||||
# Status message
|
||||
status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms"
|
||||
|
||||
status = (
|
||||
f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms"
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metrics = f"""
|
||||
**Transcription Statistics:**
|
||||
- Audio Duration: {audio_duration:.2f} seconds
|
||||
- Processing Time: {latency*1000:.0f}ms
|
||||
- Real-time Factor: {latency/audio_duration:.2f}x
|
||||
- Processing Time: {latency * 1000:.0f}ms
|
||||
- Real-time Factor: {latency / audio_duration:.2f}x
|
||||
- Detected Language: {detected_language}
|
||||
- Task: {task}
|
||||
- Word Count: {len(text.split())}
|
||||
- Character Count: {len(text)}
|
||||
"""
|
||||
|
||||
|
||||
return status, text, metrics
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("STT request failed")
|
||||
return f"❌ STT service error: {e.response.status_code}", "", ""
|
||||
@@ -148,12 +225,12 @@ def check_service_health() -> str:
|
||||
response = client.get(f"{STT_URL}/health", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is healthy"
|
||||
|
||||
|
||||
# Try v1/models endpoint (OpenAI-compatible)
|
||||
response = client.get(f"{STT_URL}/v1/models", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is healthy"
|
||||
|
||||
|
||||
return f"🟡 Service returned status {response.status_code}"
|
||||
except Exception as e:
|
||||
return f"🔴 Service unavailable: {str(e)}"
|
||||
@@ -167,99 +244,89 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo:
|
||||
Test the **Whisper** speech-to-text service. Transcribe audio from microphone
|
||||
or file upload with support for 100+ languages.
|
||||
""")
|
||||
|
||||
|
||||
# Service status
|
||||
with gr.Row():
|
||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||
|
||||
|
||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||
|
||||
|
||||
with gr.Tabs():
|
||||
# Tab 1: Microphone Input
|
||||
with gr.TabItem("🎤 Microphone"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
mic_input = gr.Audio(
|
||||
label="Record Audio",
|
||||
sources=["microphone"],
|
||||
type="numpy"
|
||||
label="Record Audio", sources=["microphone"], type="numpy"
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
mic_language = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="Auto-detect",
|
||||
label="Language"
|
||||
label="Language",
|
||||
)
|
||||
mic_task = gr.Radio(
|
||||
choices=["Transcribe", "Translate to English"],
|
||||
value="Transcribe",
|
||||
label="Task"
|
||||
label="Task",
|
||||
)
|
||||
|
||||
|
||||
mic_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
mic_status = gr.Textbox(label="Status", interactive=False)
|
||||
mic_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
mic_output = gr.Textbox(
|
||||
label="Transcription",
|
||||
lines=5
|
||||
)
|
||||
|
||||
|
||||
mic_output = gr.Textbox(label="Transcription", lines=5)
|
||||
|
||||
mic_btn.click(
|
||||
fn=transcribe_audio,
|
||||
inputs=[mic_input, mic_language, mic_task],
|
||||
outputs=[mic_status, mic_output, mic_metrics]
|
||||
outputs=[mic_status, mic_output, mic_metrics],
|
||||
)
|
||||
|
||||
|
||||
# Tab 2: File Upload
|
||||
with gr.TabItem("📁 File Upload"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
file_input = gr.Audio(
|
||||
label="Upload Audio File",
|
||||
sources=["upload"],
|
||||
type="filepath"
|
||||
label="Upload Audio File", sources=["upload"], type="filepath"
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
file_language = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="Auto-detect",
|
||||
label="Language"
|
||||
label="Language",
|
||||
)
|
||||
file_task = gr.Radio(
|
||||
choices=["Transcribe", "Translate to English"],
|
||||
value="Transcribe",
|
||||
label="Task"
|
||||
label="Task",
|
||||
)
|
||||
|
||||
|
||||
file_btn = gr.Button("🎯 Transcribe", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
file_status = gr.Textbox(label="Status", interactive=False)
|
||||
file_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
file_output = gr.Textbox(
|
||||
label="Transcription",
|
||||
lines=5
|
||||
)
|
||||
|
||||
|
||||
file_output = gr.Textbox(label="Transcription", lines=5)
|
||||
|
||||
file_btn.click(
|
||||
fn=transcribe_audio,
|
||||
inputs=[file_input, file_language, file_task],
|
||||
outputs=[file_status, file_output, file_metrics]
|
||||
outputs=[file_status, file_output, file_metrics],
|
||||
)
|
||||
|
||||
|
||||
gr.Markdown("""
|
||||
**Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM
|
||||
|
||||
*For best results, use clear audio with minimal background noise.*
|
||||
""")
|
||||
|
||||
|
||||
# Tab 3: Translation
|
||||
with gr.TabItem("🌍 Translation"):
|
||||
gr.Markdown("""
|
||||
@@ -268,40 +335,33 @@ or file upload with support for 100+ languages.
|
||||
Upload or record audio in any language and get English translation.
|
||||
Whisper will automatically detect the source language.
|
||||
""")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
trans_input = gr.Audio(
|
||||
label="Audio Input",
|
||||
sources=["microphone", "upload"],
|
||||
type="numpy"
|
||||
type="numpy",
|
||||
)
|
||||
trans_btn = gr.Button("🌍 Translate to English", variant="primary")
|
||||
|
||||
|
||||
with gr.Column():
|
||||
trans_status = gr.Textbox(label="Status", interactive=False)
|
||||
trans_metrics = gr.Markdown(label="Metrics")
|
||||
|
||||
trans_output = gr.Textbox(
|
||||
label="English Translation",
|
||||
lines=5
|
||||
)
|
||||
|
||||
|
||||
trans_output = gr.Textbox(label="English Translation", lines=5)
|
||||
|
||||
def translate_audio(audio):
|
||||
return transcribe_audio(audio, "Auto-detect", "Translate to English")
|
||||
|
||||
|
||||
trans_btn.click(
|
||||
fn=translate_audio,
|
||||
inputs=trans_input,
|
||||
outputs=[trans_status, trans_output, trans_metrics]
|
||||
outputs=[trans_status, trans_output, trans_metrics],
|
||||
)
|
||||
|
||||
|
||||
create_footer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(
|
||||
server_name="0.0.0.0",
|
||||
server_port=7860,
|
||||
show_error=True
|
||||
)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||
|
||||
2
stt.yaml
2
stt.yaml
@@ -20,7 +20,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: gradio
|
||||
image: ghcr.io/billy-davies-2/llm-apps:v2-202601271655
|
||||
image: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui:latest
|
||||
imagePullPolicy: Always
|
||||
command: ["python", "stt.py"]
|
||||
ports:
|
||||
|
||||
184
theme.py
184
theme.py
@@ -3,6 +3,7 @@ Shared Gradio theme for Davies Tech Labs AI demos.
|
||||
Consistent styling across all demo applications.
|
||||
Cyberpunk aesthetic - dark with yellow/gold accents.
|
||||
"""
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
@@ -25,7 +26,12 @@ def get_lab_theme() -> gr.Theme:
|
||||
primary_hue=gr.themes.colors.yellow,
|
||||
secondary_hue=gr.themes.colors.amber,
|
||||
neutral_hue=gr.themes.colors.zinc,
|
||||
font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
|
||||
font=[
|
||||
gr.themes.GoogleFont("Space Grotesk"),
|
||||
"ui-sans-serif",
|
||||
"system-ui",
|
||||
"sans-serif",
|
||||
],
|
||||
font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
|
||||
).set(
|
||||
# Background colors
|
||||
@@ -75,10 +81,38 @@ def get_lab_theme() -> gr.Theme:
|
||||
block_background_fill_dark=CYBER_GRAY,
|
||||
block_border_color="#2a2a2a",
|
||||
block_border_color_dark="#2a2a2a",
|
||||
block_label_background_fill="#1a1a00",
|
||||
block_label_background_fill_dark="#1a1a00",
|
||||
block_label_text_color=CYBER_YELLOW,
|
||||
block_label_text_color_dark=CYBER_YELLOW,
|
||||
block_label_border_color=CYBER_YELLOW,
|
||||
block_label_border_color_dark=CYBER_YELLOW,
|
||||
block_title_text_color=CYBER_TEXT,
|
||||
block_title_text_color_dark=CYBER_TEXT,
|
||||
# Table / Dataframe
|
||||
table_border_color="#2a2a2a",
|
||||
table_even_background_fill="#111111",
|
||||
table_even_background_fill_dark="#111111",
|
||||
table_odd_background_fill=CYBER_GRAY,
|
||||
table_odd_background_fill_dark=CYBER_GRAY,
|
||||
table_row_focus="#1f1a00",
|
||||
table_row_focus_dark="#1f1a00",
|
||||
# Panel / accordion
|
||||
panel_background_fill=CYBER_DARK,
|
||||
panel_background_fill_dark=CYBER_DARK,
|
||||
panel_border_color="#2a2a2a",
|
||||
panel_border_color_dark="#2a2a2a",
|
||||
# Checkbox / radio
|
||||
checkbox_background_color=CYBER_DARKER,
|
||||
checkbox_background_color_dark=CYBER_DARKER,
|
||||
checkbox_label_background_fill=CYBER_GRAY,
|
||||
checkbox_label_background_fill_dark=CYBER_GRAY,
|
||||
checkbox_label_text_color=CYBER_TEXT,
|
||||
checkbox_label_text_color_dark=CYBER_TEXT,
|
||||
# Colors
|
||||
color_accent=CYBER_YELLOW,
|
||||
color_accent_soft="#1f1a00",
|
||||
color_accent_soft_dark="#1f1a00",
|
||||
)
|
||||
|
||||
|
||||
@@ -338,6 +372,154 @@ label, .gr-label {
|
||||
.glow-text {
|
||||
text-shadow: 0 0 10px var(--cyber-yellow), 0 0 20px var(--cyber-yellow);
|
||||
}
|
||||
|
||||
/* ── Examples table / Dataframe overrides ── */
|
||||
/* Gradio renders Examples as a <table> inside a Dataset component.
|
||||
The default styles inject white / light-gray rows that blow out the
|
||||
cyberpunk palette. Force them dark here. */
|
||||
.gr-samples-table,
|
||||
.gr-sample-textbox,
|
||||
table.table,
|
||||
.gr-examples table,
|
||||
div[class*="dataset"] table {
|
||||
background: var(--cyber-dark) !important;
|
||||
color: var(--cyber-text) !important;
|
||||
}
|
||||
|
||||
.gr-samples-table tr,
|
||||
.gr-examples table tr,
|
||||
div[class*="dataset"] table tr {
|
||||
background: #111111 !important;
|
||||
border-bottom: 1px solid #222 !important;
|
||||
}
|
||||
|
||||
.gr-samples-table tr:nth-child(even),
|
||||
.gr-examples table tr:nth-child(even),
|
||||
div[class*="dataset"] table tr:nth-child(even) {
|
||||
background: #0d0d0d !important;
|
||||
}
|
||||
|
||||
.gr-samples-table tr:hover,
|
||||
.gr-examples table tr:hover,
|
||||
div[class*="dataset"] table tr:hover {
|
||||
background: #1f1a00 !important;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.gr-samples-table th,
|
||||
.gr-examples table th,
|
||||
div[class*="dataset"] table th {
|
||||
background: var(--cyber-gray) !important;
|
||||
color: var(--cyber-yellow) !important;
|
||||
text-transform: uppercase !important;
|
||||
font-size: 0.75rem !important;
|
||||
letter-spacing: 0.1em !important;
|
||||
border-bottom: 2px solid var(--cyber-yellow) !important;
|
||||
padding: 10px 16px !important;
|
||||
}
|
||||
|
||||
.gr-samples-table td,
|
||||
.gr-examples table td,
|
||||
div[class*="dataset"] table td {
|
||||
color: #999 !important;
|
||||
border-bottom: 1px solid #1a1a1a !important;
|
||||
padding: 10px 16px !important;
|
||||
font-family: 'JetBrains Mono', monospace !important;
|
||||
font-size: 0.85rem !important;
|
||||
}
|
||||
|
||||
/* ── Block label pill (e.g. "GENERATED AUDIO", "STATUS") ── */
|
||||
/* These are the small floating labels above each component block */
|
||||
span[class*="label-wrap"],
|
||||
.gr-block-label,
|
||||
.label-wrap {
|
||||
background: #1a1a00 !important;
|
||||
border: 1px solid var(--cyber-yellow) !important;
|
||||
color: var(--cyber-yellow) !important;
|
||||
}
|
||||
|
||||
/* ── Dropdown / select menus ── */
|
||||
.gr-dropdown,
|
||||
select,
|
||||
ul[role="listbox"],
|
||||
div[class*="dropdown"],
|
||||
.secondary-wrap {
|
||||
background: #0a0a0a !important;
|
||||
color: var(--cyber-text) !important;
|
||||
border-color: #333 !important;
|
||||
}
|
||||
|
||||
ul[role="listbox"] li,
|
||||
div[class*="dropdown"] li {
|
||||
background: #0a0a0a !important;
|
||||
color: var(--cyber-text) !important;
|
||||
}
|
||||
|
||||
ul[role="listbox"] li:hover,
|
||||
ul[role="listbox"] li[aria-selected="true"],
|
||||
div[class*="dropdown"] li:hover {
|
||||
background: #1f1a00 !important;
|
||||
color: var(--cyber-yellow) !important;
|
||||
}
|
||||
|
||||
/* ── Audio player ── */
|
||||
.gr-audio,
|
||||
audio {
|
||||
background: var(--cyber-dark) !important;
|
||||
border: 1px solid #2a2a2a !important;
|
||||
}
|
||||
|
||||
/* Audio waveform container */
|
||||
div[data-testid="waveform-container"],
|
||||
div[class*="audio"] {
|
||||
background: #0a0a0a !important;
|
||||
}
|
||||
|
||||
/* ── Markdown inside blocks ── */
|
||||
.gr-markdown,
|
||||
.gr-markdown p,
|
||||
.prose {
|
||||
color: var(--cyber-text) !important;
|
||||
}
|
||||
|
||||
.gr-markdown h3,
|
||||
.gr-markdown h2 {
|
||||
color: var(--cyber-yellow) !important;
|
||||
letter-spacing: 0.05em !important;
|
||||
}
|
||||
|
||||
.gr-markdown strong {
|
||||
color: var(--cyber-gold) !important;
|
||||
}
|
||||
|
||||
/* ── Examples accordion header ("Examples" label) ── */
|
||||
.gr-examples .label-wrap,
|
||||
div[id*="examples"] .label-wrap,
|
||||
span[data-testid="block-label"] {
|
||||
background: #1a1a00 !important;
|
||||
color: var(--cyber-yellow) !important;
|
||||
border: 1px solid var(--cyber-yellow) !important;
|
||||
font-size: 0.7rem !important;
|
||||
text-transform: uppercase !important;
|
||||
letter-spacing: 0.1em !important;
|
||||
}
|
||||
|
||||
/* ── Misc: tooltip, info text ── */
|
||||
.gr-info,
|
||||
.gr-description {
|
||||
color: #666 !important;
|
||||
}
|
||||
|
||||
/* ── Svelte internal: make sure no white backgrounds leak ── */
|
||||
.contain > div,
|
||||
.wrap > div {
|
||||
background: inherit !important;
|
||||
}
|
||||
|
||||
/* ── Tab content panels ── */
|
||||
.tabitem {
|
||||
background: var(--cyber-dark) !important;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
217
tts.py
217
tts.py
@@ -9,11 +9,11 @@ Features:
|
||||
- MLflow metrics logging
|
||||
- Multiple TTS backends support (Coqui XTTS, Piper, etc.)
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import io
|
||||
import base64
|
||||
|
||||
import gradio as gr
|
||||
import httpx
|
||||
@@ -30,13 +30,79 @@ logger = logging.getLogger("tts-demo")
|
||||
TTS_URL = os.environ.get(
|
||||
"TTS_URL",
|
||||
# Default: Ray Serve TTS endpoint
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts"
|
||||
"http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts",
|
||||
)
|
||||
MLFLOW_TRACKING_URI = os.environ.get(
|
||||
"MLFLOW_TRACKING_URI",
|
||||
"http://mlflow.mlflow.svc.cluster.local:80"
|
||||
"MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80"
|
||||
)
|
||||
|
||||
# ─── MLflow experiment tracking ──────────────────────────────────────────
|
||||
try:
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
|
||||
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||
_mlflow_client = MlflowClient()
|
||||
|
||||
_experiment = _mlflow_client.get_experiment_by_name("gradio-tts-tuning")
|
||||
if _experiment is None:
|
||||
_experiment_id = _mlflow_client.create_experiment(
|
||||
"gradio-tts-tuning",
|
||||
artifact_location="/mlflow/artifacts/gradio-tts-tuning",
|
||||
)
|
||||
else:
|
||||
_experiment_id = _experiment.experiment_id
|
||||
|
||||
_mlflow_run = mlflow.start_run(
|
||||
experiment_id=_experiment_id,
|
||||
run_name=f"gradio-tts-{os.environ.get('HOSTNAME', 'local')}",
|
||||
tags={"service": "gradio-tts", "endpoint": TTS_URL},
|
||||
)
|
||||
_mlflow_run_id = _mlflow_run.info.run_id
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = True
|
||||
logger.info(
|
||||
"MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MLflow tracking disabled: %s", exc)
|
||||
_mlflow_client = None
|
||||
_mlflow_run_id = None
|
||||
_mlflow_step = 0
|
||||
MLFLOW_ENABLED = False
|
||||
|
||||
|
||||
def _log_tts_metrics(
|
||||
latency: float,
|
||||
audio_duration: float,
|
||||
text_chars: int,
|
||||
language: str,
|
||||
) -> None:
|
||||
"""Log TTS inference metrics to MLflow (non-blocking best-effort)."""
|
||||
global _mlflow_step
|
||||
if not MLFLOW_ENABLED or _mlflow_client is None:
|
||||
return
|
||||
try:
|
||||
_mlflow_step += 1
|
||||
ts = int(time.time() * 1000)
|
||||
rtf = latency / audio_duration if audio_duration > 0 else 0
|
||||
cps = text_chars / latency if latency > 0 else 0
|
||||
_mlflow_client.log_batch(
|
||||
_mlflow_run_id,
|
||||
metrics=[
|
||||
mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
|
||||
mlflow.entities.Metric(
|
||||
"audio_duration_s", audio_duration, ts, _mlflow_step
|
||||
),
|
||||
mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step),
|
||||
mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step),
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("MLflow log failed", exc_info=True)
|
||||
|
||||
|
||||
# HTTP client with longer timeout for audio generation
|
||||
client = httpx.Client(timeout=120.0)
|
||||
|
||||
@@ -61,54 +127,63 @@ LANGUAGES = {
|
||||
}
|
||||
|
||||
|
||||
def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
||||
def synthesize_speech(
|
||||
text: str, language: str
|
||||
) -> tuple[str, tuple[int, np.ndarray] | None, str]:
|
||||
"""Synthesize speech from text using the TTS service."""
|
||||
if not text.strip():
|
||||
return "❌ Please enter some text", None, ""
|
||||
|
||||
|
||||
lang_code = LANGUAGES.get(language, "en")
|
||||
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Call TTS service (Coqui XTTS API format)
|
||||
response = client.get(
|
||||
f"{TTS_URL}/api/tts",
|
||||
params={"text": text, "language_id": lang_code}
|
||||
f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
latency = time.time() - start_time
|
||||
audio_bytes = response.content
|
||||
|
||||
|
||||
# Parse audio data
|
||||
audio_io = io.BytesIO(audio_bytes)
|
||||
audio_data, sample_rate = sf.read(audio_io)
|
||||
|
||||
|
||||
# Calculate duration
|
||||
if len(audio_data.shape) == 1:
|
||||
duration = len(audio_data) / sample_rate
|
||||
else:
|
||||
duration = len(audio_data) / sample_rate
|
||||
|
||||
|
||||
# Status message
|
||||
status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms"
|
||||
|
||||
status = f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms"
|
||||
|
||||
# Log to MLflow
|
||||
_log_tts_metrics(
|
||||
latency=latency,
|
||||
audio_duration=duration,
|
||||
text_chars=len(text),
|
||||
language=lang_code,
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metrics = f"""
|
||||
**Audio Statistics:**
|
||||
- Duration: {duration:.2f} seconds
|
||||
- Sample Rate: {sample_rate} Hz
|
||||
- Size: {len(audio_bytes) / 1024:.1f} KB
|
||||
- Generation Time: {latency*1000:.0f}ms
|
||||
- Real-time Factor: {latency/duration:.2f}x
|
||||
- Generation Time: {latency * 1000:.0f}ms
|
||||
- Real-time Factor: {latency / duration:.2f}x
|
||||
- Language: {language} ({lang_code})
|
||||
- Characters: {len(text)}
|
||||
- Chars/sec: {len(text)/latency:.1f}
|
||||
- Chars/sec: {len(text) / latency:.1f}
|
||||
"""
|
||||
|
||||
|
||||
return status, (sample_rate, audio_data), metrics
|
||||
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("TTS request failed")
|
||||
return f"❌ TTS service error: {e.response.status_code}", None, ""
|
||||
@@ -124,12 +199,12 @@ def check_service_health() -> str:
|
||||
response = client.get(f"{TTS_URL}/health", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is healthy"
|
||||
|
||||
|
||||
# Fall back to root endpoint
|
||||
response = client.get(f"{TTS_URL}/", timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
return "🟢 Service is responding"
|
||||
|
||||
|
||||
return f"🟡 Service returned status {response.status_code}"
|
||||
except Exception as e:
|
||||
return f"🔴 Service unavailable: {str(e)}"
|
||||
@@ -143,14 +218,14 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo:
|
||||
Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech
|
||||
in multiple languages.
|
||||
""")
|
||||
|
||||
|
||||
# Service status
|
||||
with gr.Row():
|
||||
health_btn = gr.Button("🔄 Check Service", size="sm")
|
||||
health_status = gr.Textbox(label="Service Status", interactive=False)
|
||||
|
||||
|
||||
health_btn.click(fn=check_service_health, outputs=health_status)
|
||||
|
||||
|
||||
with gr.Tabs():
|
||||
# Tab 1: Basic TTS
|
||||
with gr.TabItem("🎤 Text to Speech"):
|
||||
@@ -160,114 +235,120 @@ in multiple languages.
|
||||
label="Text to Synthesize",
|
||||
placeholder="Enter text to convert to speech...",
|
||||
lines=5,
|
||||
max_lines=10
|
||||
max_lines=10,
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
language = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="English",
|
||||
label="Language"
|
||||
label="Language",
|
||||
)
|
||||
synthesize_btn = gr.Button("🔊 Synthesize", variant="primary", scale=2)
|
||||
|
||||
synthesize_btn = gr.Button(
|
||||
"🔊 Synthesize", variant="primary", scale=2
|
||||
)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
status_output = gr.Textbox(label="Status", interactive=False)
|
||||
metrics_output = gr.Markdown(label="Metrics")
|
||||
|
||||
|
||||
audio_output = gr.Audio(label="Generated Audio", type="numpy")
|
||||
|
||||
|
||||
synthesize_btn.click(
|
||||
fn=synthesize_speech,
|
||||
inputs=[text_input, language],
|
||||
outputs=[status_output, audio_output, metrics_output]
|
||||
outputs=[status_output, audio_output, metrics_output],
|
||||
)
|
||||
|
||||
|
||||
# Example texts
|
||||
gr.Examples(
|
||||
examples=[
|
||||
["Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "English"],
|
||||
["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "English"],
|
||||
["Bonjour! Bienvenue au laboratoire technique de Davies.", "French"],
|
||||
[
|
||||
"Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.",
|
||||
"English",
|
||||
],
|
||||
[
|
||||
"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.",
|
||||
"English",
|
||||
],
|
||||
[
|
||||
"Bonjour! Bienvenue au laboratoire technique de Davies.",
|
||||
"French",
|
||||
],
|
||||
["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"],
|
||||
["Guten Tag! Willkommen im Techniklabor.", "German"],
|
||||
],
|
||||
inputs=[text_input, language],
|
||||
)
|
||||
|
||||
|
||||
# Tab 2: Comparison
|
||||
with gr.TabItem("🔄 Language Comparison"):
|
||||
gr.Markdown("Compare the same text in different languages.")
|
||||
|
||||
|
||||
compare_text = gr.Textbox(
|
||||
label="Text to Compare",
|
||||
value="Hello, how are you today?",
|
||||
lines=2
|
||||
label="Text to Compare", value="Hello, how are you today?", lines=2
|
||||
)
|
||||
|
||||
|
||||
with gr.Row():
|
||||
lang1 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Language 1")
|
||||
lang2 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2")
|
||||
|
||||
lang1 = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()), value="English", label="Language 1"
|
||||
)
|
||||
lang2 = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2"
|
||||
)
|
||||
|
||||
compare_btn = gr.Button("Compare Languages", variant="primary")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown("### Language 1")
|
||||
audio1 = gr.Audio(label="Audio 1", type="numpy")
|
||||
status1 = gr.Textbox(label="Status", interactive=False)
|
||||
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("### Language 2")
|
||||
audio2 = gr.Audio(label="Audio 2", type="numpy")
|
||||
status2 = gr.Textbox(label="Status", interactive=False)
|
||||
|
||||
|
||||
def compare_languages(text, l1, l2):
|
||||
s1, a1, _ = synthesize_speech(text, l1)
|
||||
s2, a2, _ = synthesize_speech(text, l2)
|
||||
return s1, a1, s2, a2
|
||||
|
||||
|
||||
compare_btn.click(
|
||||
fn=compare_languages,
|
||||
inputs=[compare_text, lang1, lang2],
|
||||
outputs=[status1, audio1, status2, audio2]
|
||||
outputs=[status1, audio1, status2, audio2],
|
||||
)
|
||||
|
||||
|
||||
# Tab 3: Batch Processing
|
||||
with gr.TabItem("📚 Batch Synthesis"):
|
||||
gr.Markdown("Synthesize multiple texts at once (one per line).")
|
||||
|
||||
|
||||
batch_input = gr.Textbox(
|
||||
label="Texts (one per line)",
|
||||
placeholder="Enter multiple texts, one per line...",
|
||||
lines=6
|
||||
lines=6,
|
||||
)
|
||||
batch_lang = gr.Dropdown(
|
||||
choices=list(LANGUAGES.keys()),
|
||||
value="English",
|
||||
label="Language"
|
||||
choices=list(LANGUAGES.keys()), value="English", label="Language"
|
||||
)
|
||||
batch_btn = gr.Button("Synthesize All", variant="primary")
|
||||
|
||||
|
||||
batch_status = gr.Textbox(label="Status", interactive=False)
|
||||
batch_audios = gr.Dataset(
|
||||
components=[gr.Audio(type="numpy")],
|
||||
label="Generated Audio Files"
|
||||
components=[gr.Audio(type="numpy")], label="Generated Audio Files"
|
||||
)
|
||||
|
||||
|
||||
# Note: Batch processing would need more complex handling
|
||||
# This is a simplified version
|
||||
gr.Markdown("""
|
||||
*Note: For batch processing of many texts, consider using the API directly
|
||||
or the Kubeflow pipeline for better throughput.*
|
||||
""")
|
||||
|
||||
|
||||
create_footer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(
|
||||
server_name="0.0.0.0",
|
||||
server_port=7860,
|
||||
show_error=True
|
||||
)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||
|
||||
Reference in New Issue
Block a user