diff --git a/embeddings.py b/embeddings.py index cae7867..7cd82de 100644 --- a/embeddings.py +++ b/embeddings.py @@ -9,6 +9,7 @@ Features: - MLflow metrics logging - Visual embedding dimension display """ + import os import time import logging @@ -26,9 +27,9 @@ logger = logging.getLogger("embeddings-demo") # Configuration EMBEDDINGS_URL = os.environ.get( - "EMBEDDINGS_URL", + "EMBEDDINGS_URL", # Default: Ray Serve Embeddings endpoint - "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings" + "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings", ) # ─── MLflow experiment tracking ────────────────────────────────────────── try: @@ -59,7 +60,9 @@ try: _mlflow_run_id = _mlflow_run.info.run_id _mlflow_step = 0 MLFLOW_ENABLED = True - logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) + logger.info( + "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id + ) except Exception as exc: logger.warning("MLflow tracking disabled: %s", exc) _mlflow_client = None @@ -68,7 +71,9 @@ except Exception as exc: MLFLOW_ENABLED = False -def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int = 0) -> None: +def _log_embedding_metrics( + latency: float, batch_size: int, embedding_dims: int = 0 +) -> None: """Log embedding inference metrics to MLflow (non-blocking best-effort).""" global _mlflow_step if not MLFLOW_ENABLED or _mlflow_client is None: @@ -81,8 +86,15 @@ def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int metrics=[ mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step), - mlflow.entities.Metric("embedding_dims", embedding_dims, ts, _mlflow_step), - mlflow.entities.Metric("latency_per_text_ms", (latency * 1000 / batch_size) if batch_size > 0 else 0, ts, _mlflow_step), + mlflow.entities.Metric( + "embedding_dims", embedding_dims, ts, _mlflow_step + ), + mlflow.entities.Metric( + "latency_per_text_ms", + (latency * 1000 / batch_size) if batch_size > 0 else 0, + ts, + _mlflow_step, + ), ], ) except Exception: @@ -96,17 +108,16 @@ client = httpx.Client(timeout=60.0) def get_embeddings(texts: list[str]) -> tuple[list[list[float]], float]: """Get embeddings from the embeddings service.""" start_time = time.time() - + response = client.post( - f"{EMBEDDINGS_URL}/embeddings", - json={"input": texts, "model": "bge"} + f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"} ) response.raise_for_status() - + latency = time.time() - start_time result = response.json() embeddings = [d["embedding"] for d in result.get("data", [])] - + return embeddings, latency @@ -121,29 +132,29 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]: """Generate embedding for a single text.""" if not text.strip(): return "❌ Please enter some text", "", "" - + try: embeddings, latency = get_embeddings([text]) - + if not embeddings: return "❌ No embedding returned", "", "" - + embedding = embeddings[0] dims = len(embedding) # Log to MLflow _log_embedding_metrics(latency, batch_size=1, embedding_dims=dims) - + # Format output - status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms" - + status = f"✅ Generated {dims}-dimensional embedding in {latency * 1000:.1f}ms" + # Show first/last few dimensions preview = f"Dimensions: {dims}\n\n" preview += "First 10 values:\n" preview += json.dumps(embedding[:10], indent=2) preview += "\n\n...\n\nLast 10 values:\n" preview += json.dumps(embedding[-10:], indent=2) - + # Stats stats = f""" **Embedding Statistics:** @@ -153,11 +164,11 @@ def generate_single_embedding(text: str) -> tuple[str, str, str]: - Mean: {np.mean(embedding):.6f} - Std: {np.std(embedding):.6f} - L2 Norm: {np.linalg.norm(embedding):.6f} -- Latency: {latency*1000:.1f}ms +- Latency: {latency * 1000:.1f}ms """ - + return status, preview, stats - + except Exception as e: logger.exception("Embedding generation failed") return f"❌ Error: {str(e)}", "", "" @@ -167,18 +178,18 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]: """Compare similarity between two texts.""" if not text1.strip() or not text2.strip(): return "❌ Please enter both texts", "" - + try: embeddings, latency = get_embeddings([text1, text2]) - + if len(embeddings) != 2: return "❌ Failed to get embeddings for both texts", "" - + similarity = cosine_similarity(embeddings[0], embeddings[1]) # Log to MLflow _log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0])) - + # Determine similarity level if similarity > 0.9: level = "🟢 Very High" @@ -192,7 +203,7 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]: else: level = "🔴 Low" desc = "These texts are semantically different" - + result = f""" ## Similarity Score: {similarity:.4f} @@ -201,17 +212,17 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]: {desc} --- -*Computed in {latency*1000:.1f}ms* +*Computed in {latency * 1000:.1f}ms* """ - + # Create a simple visual bar bar_length = 50 filled = int(similarity * bar_length) bar = "█" * filled + "░" * (bar_length - filled) - visual = f"[{bar}] {similarity*100:.1f}%" - + visual = f"[{bar}] {similarity * 100:.1f}%" + return result, visual - + except Exception as e: logger.exception("Comparison failed") return f"❌ Error: {str(e)}", "" @@ -220,19 +231,23 @@ def compare_texts(text1: str, text2: str) -> tuple[str, str]: def batch_embed(texts_input: str) -> tuple[str, str]: """Generate embeddings for multiple texts (one per line).""" texts = [t.strip() for t in texts_input.strip().split("\n") if t.strip()] - + if not texts: return "❌ Please enter at least one text (one per line)", "" - + try: embeddings, latency = get_embeddings(texts) - - # Log to MLflow - _log_embedding_metrics(latency, batch_size=len(embeddings), embedding_dims=len(embeddings[0]) if embeddings else 0) - status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms" - status += f" ({latency*1000/len(texts):.1f}ms per text)" - + # Log to MLflow + _log_embedding_metrics( + latency, + batch_size=len(embeddings), + embedding_dims=len(embeddings[0]) if embeddings else 0, + ) + + status = f"✅ Generated {len(embeddings)} embeddings in {latency * 1000:.1f}ms" + status += f" ({latency * 1000 / len(texts):.1f}ms per text)" + # Build similarity matrix n = len(embeddings) matrix = [] @@ -242,16 +257,16 @@ def batch_embed(texts_input: str) -> tuple[str, str]: sim = cosine_similarity(embeddings[i], embeddings[j]) row.append(f"{sim:.3f}") matrix.append(row) - + # Format as table - header = "| | " + " | ".join([f"Text {i+1}" for i in range(n)]) + " |" + header = "| | " + " | ".join([f"Text {i + 1}" for i in range(n)]) + " |" separator = "|---" + "|---" * n + "|" rows = [] for i, row in enumerate(matrix): - rows.append(f"| **Text {i+1}** | " + " | ".join(row) + " |") - + rows.append(f"| **Text {i + 1}** | " + " | ".join(row) + " |") + table = "\n".join([header, separator] + rows) - + result = f""" ## Similarity Matrix @@ -261,10 +276,10 @@ def batch_embed(texts_input: str) -> tuple[str, str]: **Texts processed:** """ for i, text in enumerate(texts): - result += f"\n{i+1}. {text[:50]}{'...' if len(text) > 50 else ''}" - + result += f"\n{i + 1}. {text[:50]}{'...' if len(text) > 50 else ''}" + return status, result - + except Exception as e: logger.exception("Batch embedding failed") return f"❌ Error: {str(e)}", "" @@ -290,14 +305,14 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="Embeddings Demo") a Test the **BGE Embeddings** service for semantic text encoding. Generate embeddings, compare text similarity, and explore vector representations. """) - + # Service status with gr.Row(): health_btn = gr.Button("🔄 Check Service", size="sm") health_status = gr.Textbox(label="Service Status", interactive=False) - + health_btn.click(fn=check_service_health, outputs=health_status) - + with gr.Tabs(): # Tab 1: Single Embedding with gr.TabItem("📝 Single Text"): @@ -306,71 +321,74 @@ Generate embeddings, compare text similarity, and explore vector representations single_input = gr.Textbox( label="Input Text", placeholder="Enter text to generate embeddings...", - lines=3 + lines=3, ) single_btn = gr.Button("Generate Embedding", variant="primary") - + with gr.Column(): single_status = gr.Textbox(label="Status", interactive=False) single_stats = gr.Markdown(label="Statistics") - + single_preview = gr.Code(label="Embedding Preview", language="json") - + single_btn.click( fn=generate_single_embedding, inputs=single_input, - outputs=[single_status, single_preview, single_stats] + outputs=[single_status, single_preview, single_stats], ) - + # Tab 2: Compare Texts with gr.TabItem("⚖️ Compare Texts"): gr.Markdown("Compare the semantic similarity between two texts.") - + with gr.Row(): compare_text1 = gr.Textbox(label="Text 1", lines=3) compare_text2 = gr.Textbox(label="Text 2", lines=3) - + compare_btn = gr.Button("Compare Similarity", variant="primary") - + with gr.Row(): compare_result = gr.Markdown(label="Result") compare_visual = gr.Textbox(label="Similarity Bar", interactive=False) - + compare_btn.click( fn=compare_texts, inputs=[compare_text1, compare_text2], - outputs=[compare_result, compare_visual] + outputs=[compare_result, compare_visual], ) - + # Example pairs gr.Examples( examples=[ ["The cat sat on the mat.", "A feline was resting on the rug."], - ["Machine learning is a subset of AI.", "Deep learning uses neural networks."], + [ + "Machine learning is a subset of AI.", + "Deep learning uses neural networks.", + ], ["I love pizza.", "The stock market crashed today."], ], inputs=[compare_text1, compare_text2], ) - + # Tab 3: Batch Embeddings with gr.TabItem("📚 Batch Processing"): - gr.Markdown("Generate embeddings for multiple texts and see their similarity matrix.") - + gr.Markdown( + "Generate embeddings for multiple texts and see their similarity matrix." + ) + batch_input = gr.Textbox( label="Texts (one per line)", placeholder="Enter multiple texts, one per line...", - lines=6 + lines=6, ) batch_btn = gr.Button("Process Batch", variant="primary") batch_status = gr.Textbox(label="Status", interactive=False) batch_result = gr.Markdown(label="Similarity Matrix") - + batch_btn.click( - fn=batch_embed, - inputs=batch_input, - outputs=[batch_status, batch_result] + fn=batch_embed, inputs=batch_input, outputs=[batch_status, batch_result] ) - + gr.Examples( examples=[ "Python is a programming language.\nJava is also a programming language.\nCoffee is a beverage.", @@ -378,13 +396,9 @@ Generate embeddings, compare text similarity, and explore vector representations ], inputs=batch_input, ) - + create_footer() if __name__ == "__main__": - demo.launch( - server_name="0.0.0.0", - server_port=7860, - show_error=True - ) + demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True) diff --git a/llm.py b/llm.py index 3d0fe3b..6c210f2 100644 --- a/llm.py +++ b/llm.py @@ -9,10 +9,10 @@ Features: - Token usage and latency metrics - Chat history management """ + import os import time import logging -import json import gradio as gr import httpx @@ -65,7 +65,9 @@ try: _mlflow_run_id = _mlflow_run.info.run_id _mlflow_step = 0 MLFLOW_ENABLED = True - logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) + logger.info( + "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id + ) except Exception as exc: logger.warning("MLflow tracking disabled: %s", exc) _mlflow_client = None @@ -95,18 +97,25 @@ def _log_llm_metrics( _mlflow_run_id, metrics=[ mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), - mlflow.entities.Metric("prompt_tokens", prompt_tokens, ts, _mlflow_step), - mlflow.entities.Metric("completion_tokens", completion_tokens, ts, _mlflow_step), + mlflow.entities.Metric( + "prompt_tokens", prompt_tokens, ts, _mlflow_step + ), + mlflow.entities.Metric( + "completion_tokens", completion_tokens, ts, _mlflow_step + ), mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step), mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step), mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step), - mlflow.entities.Metric("max_tokens_requested", max_tokens, ts, _mlflow_step), + mlflow.entities.Metric( + "max_tokens_requested", max_tokens, ts, _mlflow_step + ), mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step), ], ) except Exception: logger.debug("MLflow log failed", exc_info=True) + DEFAULT_SYSTEM_PROMPT = ( "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. " "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). " @@ -273,10 +282,10 @@ def single_prompt( metrics = f""" **Generation Metrics:** - Latency: {latency:.1f}s -- Prompt tokens: {usage.get('prompt_tokens', 'N/A')} -- Completion tokens: {usage.get('completion_tokens', 'N/A')} -- Total tokens: {usage.get('total_tokens', 'N/A')} -- Model: {result.get('model', 'N/A')} +- Prompt tokens: {usage.get("prompt_tokens", "N/A")} +- Completion tokens: {usage.get("completion_tokens", "N/A")} +- Total tokens: {usage.get("total_tokens", "N/A")} +- Model: {result.get("model", "N/A")} """ return text, metrics @@ -360,9 +369,13 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm). gr.Examples( examples=[ - ["Summarise the key differences between CUDA and ROCm for ML workloads."], + [ + "Summarise the key differences between CUDA and ROCm for ML workloads." + ], ["Write a haiku about Kubernetes."], - ["Explain Ray Serve in one paragraph for someone new to ML serving."], + [ + "Explain Ray Serve in one paragraph for someone new to ML serving." + ], ["List 5 creative uses for a homelab GPU cluster."], ], inputs=[prompt_input], diff --git a/stt.py b/stt.py index cc555ea..8d223aa 100644 --- a/stt.py +++ b/stt.py @@ -9,11 +9,11 @@ Features: - Translation mode - MLflow metrics logging """ + import os import time import logging import io -import tempfile import gradio as gr import httpx @@ -30,11 +30,10 @@ logger = logging.getLogger("stt-demo") STT_URL = os.environ.get( "STT_URL", # Default: Ray Serve whisper endpoint - "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper" + "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/whisper", ) MLFLOW_TRACKING_URI = os.environ.get( - "MLFLOW_TRACKING_URI", - "http://mlflow.mlflow.svc.cluster.local:80" + "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80" ) # ─── MLflow experiment tracking ────────────────────────────────────────── @@ -62,7 +61,9 @@ try: _mlflow_run_id = _mlflow_run.info.run_id _mlflow_step = 0 MLFLOW_ENABLED = True - logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) + logger.info( + "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id + ) except Exception as exc: logger.warning("MLflow tracking disabled: %s", exc) _mlflow_client = None @@ -72,7 +73,10 @@ except Exception as exc: def _log_stt_metrics( - latency: float, audio_duration: float, word_count: int, task: str, + latency: float, + audio_duration: float, + word_count: int, + task: str, ) -> None: """Log STT inference metrics to MLflow (non-blocking best-effort).""" global _mlflow_step @@ -86,11 +90,15 @@ def _log_stt_metrics( _mlflow_run_id, metrics=[ mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), - mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step), + mlflow.entities.Metric( + "audio_duration_s", audio_duration, ts, _mlflow_step + ), mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step), mlflow.entities.Metric("word_count", word_count, ts, _mlflow_step), ], - params=[] if _mlflow_step > 1 else [ + params=[] + if _mlflow_step > 1 + else [ mlflow.entities.Param("task", task), ], ) @@ -124,57 +132,55 @@ LANGUAGES = { def transcribe_audio( - audio_input: tuple[int, np.ndarray] | str | None, - language: str, - task: str + audio_input: tuple[int, np.ndarray] | str | None, language: str, task: str ) -> tuple[str, str, str]: """Transcribe audio using the Whisper STT service.""" if audio_input is None: return "❌ Please provide audio input", "", "" - + try: start_time = time.time() - + # Handle different input types if isinstance(audio_input, tuple): # Microphone input: (sample_rate, audio_data) sample_rate, audio_data = audio_input - + # Convert to WAV bytes audio_buffer = io.BytesIO() - sf.write(audio_buffer, audio_data, sample_rate, format='WAV') + sf.write(audio_buffer, audio_data, sample_rate, format="WAV") audio_bytes = audio_buffer.getvalue() audio_duration = len(audio_data) / sample_rate else: # File path - with open(audio_input, 'rb') as f: + with open(audio_input, "rb") as f: audio_bytes = f.read() # Get duration audio_data, sample_rate = sf.read(audio_input) audio_duration = len(audio_data) / sample_rate - + # Prepare request lang_code = LANGUAGES.get(language) - + files = {"file": ("audio.wav", audio_bytes, "audio/wav")} data = {"response_format": "json"} - + if lang_code: data["language"] = lang_code - + # Choose endpoint based on task if task == "Translate to English": endpoint = f"{STT_URL}/v1/audio/translations" else: endpoint = f"{STT_URL}/v1/audio/transcriptions" - + # Send request response = client.post(endpoint, files=files, data=data) response.raise_for_status() - + latency = time.time() - start_time result = response.json() - + text = result.get("text", "") detected_language = result.get("language", "unknown") @@ -185,24 +191,26 @@ def transcribe_audio( word_count=len(text.split()), task=task, ) - + # Status message - status = f"✅ Transcribed {audio_duration:.1f}s of audio in {latency*1000:.0f}ms" - + status = ( + f"✅ Transcribed {audio_duration:.1f}s of audio in {latency * 1000:.0f}ms" + ) + # Metrics metrics = f""" **Transcription Statistics:** - Audio Duration: {audio_duration:.2f} seconds -- Processing Time: {latency*1000:.0f}ms -- Real-time Factor: {latency/audio_duration:.2f}x +- Processing Time: {latency * 1000:.0f}ms +- Real-time Factor: {latency / audio_duration:.2f}x - Detected Language: {detected_language} - Task: {task} - Word Count: {len(text.split())} - Character Count: {len(text)} """ - + return status, text, metrics - + except httpx.HTTPStatusError as e: logger.exception("STT request failed") return f"❌ STT service error: {e.response.status_code}", "", "" @@ -217,12 +225,12 @@ def check_service_health() -> str: response = client.get(f"{STT_URL}/health", timeout=5.0) if response.status_code == 200: return "🟢 Service is healthy" - + # Try v1/models endpoint (OpenAI-compatible) response = client.get(f"{STT_URL}/v1/models", timeout=5.0) if response.status_code == 200: return "🟢 Service is healthy" - + return f"🟡 Service returned status {response.status_code}" except Exception as e: return f"🔴 Service unavailable: {str(e)}" @@ -236,99 +244,89 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="STT Demo") as demo: Test the **Whisper** speech-to-text service. Transcribe audio from microphone or file upload with support for 100+ languages. """) - + # Service status with gr.Row(): health_btn = gr.Button("🔄 Check Service", size="sm") health_status = gr.Textbox(label="Service Status", interactive=False) - + health_btn.click(fn=check_service_health, outputs=health_status) - + with gr.Tabs(): # Tab 1: Microphone Input with gr.TabItem("🎤 Microphone"): with gr.Row(): with gr.Column(): mic_input = gr.Audio( - label="Record Audio", - sources=["microphone"], - type="numpy" + label="Record Audio", sources=["microphone"], type="numpy" ) - + with gr.Row(): mic_language = gr.Dropdown( choices=list(LANGUAGES.keys()), value="Auto-detect", - label="Language" + label="Language", ) mic_task = gr.Radio( choices=["Transcribe", "Translate to English"], value="Transcribe", - label="Task" + label="Task", ) - + mic_btn = gr.Button("🎯 Transcribe", variant="primary") - + with gr.Column(): mic_status = gr.Textbox(label="Status", interactive=False) mic_metrics = gr.Markdown(label="Metrics") - - mic_output = gr.Textbox( - label="Transcription", - lines=5 - ) - + + mic_output = gr.Textbox(label="Transcription", lines=5) + mic_btn.click( fn=transcribe_audio, inputs=[mic_input, mic_language, mic_task], - outputs=[mic_status, mic_output, mic_metrics] + outputs=[mic_status, mic_output, mic_metrics], ) - + # Tab 2: File Upload with gr.TabItem("📁 File Upload"): with gr.Row(): with gr.Column(): file_input = gr.Audio( - label="Upload Audio File", - sources=["upload"], - type="filepath" + label="Upload Audio File", sources=["upload"], type="filepath" ) - + with gr.Row(): file_language = gr.Dropdown( choices=list(LANGUAGES.keys()), value="Auto-detect", - label="Language" + label="Language", ) file_task = gr.Radio( choices=["Transcribe", "Translate to English"], value="Transcribe", - label="Task" + label="Task", ) - + file_btn = gr.Button("🎯 Transcribe", variant="primary") - + with gr.Column(): file_status = gr.Textbox(label="Status", interactive=False) file_metrics = gr.Markdown(label="Metrics") - - file_output = gr.Textbox( - label="Transcription", - lines=5 - ) - + + file_output = gr.Textbox(label="Transcription", lines=5) + file_btn.click( fn=transcribe_audio, inputs=[file_input, file_language, file_task], - outputs=[file_status, file_output, file_metrics] + outputs=[file_status, file_output, file_metrics], ) - + gr.Markdown(""" **Supported formats:** WAV, MP3, FLAC, OGG, M4A, WEBM *For best results, use clear audio with minimal background noise.* """) - + # Tab 3: Translation with gr.TabItem("🌍 Translation"): gr.Markdown(""" @@ -337,40 +335,33 @@ or file upload with support for 100+ languages. Upload or record audio in any language and get English translation. Whisper will automatically detect the source language. """) - + with gr.Row(): with gr.Column(): trans_input = gr.Audio( label="Audio Input", sources=["microphone", "upload"], - type="numpy" + type="numpy", ) trans_btn = gr.Button("🌍 Translate to English", variant="primary") - + with gr.Column(): trans_status = gr.Textbox(label="Status", interactive=False) trans_metrics = gr.Markdown(label="Metrics") - - trans_output = gr.Textbox( - label="English Translation", - lines=5 - ) - + + trans_output = gr.Textbox(label="English Translation", lines=5) + def translate_audio(audio): return transcribe_audio(audio, "Auto-detect", "Translate to English") - + trans_btn.click( fn=translate_audio, inputs=trans_input, - outputs=[trans_status, trans_output, trans_metrics] + outputs=[trans_status, trans_output, trans_metrics], ) - + create_footer() if __name__ == "__main__": - demo.launch( - server_name="0.0.0.0", - server_port=7860, - show_error=True - ) + demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True) diff --git a/theme.py b/theme.py index ddbdec6..30830c9 100644 --- a/theme.py +++ b/theme.py @@ -3,6 +3,7 @@ Shared Gradio theme for Davies Tech Labs AI demos. Consistent styling across all demo applications. Cyberpunk aesthetic - dark with yellow/gold accents. """ + import gradio as gr @@ -25,7 +26,12 @@ def get_lab_theme() -> gr.Theme: primary_hue=gr.themes.colors.yellow, secondary_hue=gr.themes.colors.amber, neutral_hue=gr.themes.colors.zinc, - font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"], + font=[ + gr.themes.GoogleFont("Space Grotesk"), + "ui-sans-serif", + "system-ui", + "sans-serif", + ], font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"], ).set( # Background colors diff --git a/tts.py b/tts.py index 742d535..8e65aea 100644 --- a/tts.py +++ b/tts.py @@ -9,11 +9,11 @@ Features: - MLflow metrics logging - Multiple TTS backends support (Coqui XTTS, Piper, etc.) """ + import os import time import logging import io -import base64 import gradio as gr import httpx @@ -30,11 +30,10 @@ logger = logging.getLogger("tts-demo") TTS_URL = os.environ.get( "TTS_URL", # Default: Ray Serve TTS endpoint - "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts" + "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/tts", ) MLFLOW_TRACKING_URI = os.environ.get( - "MLFLOW_TRACKING_URI", - "http://mlflow.mlflow.svc.cluster.local:80" + "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80" ) # ─── MLflow experiment tracking ────────────────────────────────────────── @@ -62,7 +61,9 @@ try: _mlflow_run_id = _mlflow_run.info.run_id _mlflow_step = 0 MLFLOW_ENABLED = True - logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) + logger.info( + "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id + ) except Exception as exc: logger.warning("MLflow tracking disabled: %s", exc) _mlflow_client = None @@ -72,7 +73,10 @@ except Exception as exc: def _log_tts_metrics( - latency: float, audio_duration: float, text_chars: int, language: str, + latency: float, + audio_duration: float, + text_chars: int, + language: str, ) -> None: """Log TTS inference metrics to MLflow (non-blocking best-effort).""" global _mlflow_step @@ -87,7 +91,9 @@ def _log_tts_metrics( _mlflow_run_id, metrics=[ mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), - mlflow.entities.Metric("audio_duration_s", audio_duration, ts, _mlflow_step), + mlflow.entities.Metric( + "audio_duration_s", audio_duration, ts, _mlflow_step + ), mlflow.entities.Metric("realtime_factor", rtf, ts, _mlflow_step), mlflow.entities.Metric("chars_per_second", cps, ts, _mlflow_step), mlflow.entities.Metric("text_chars", text_chars, ts, _mlflow_step), @@ -121,38 +127,39 @@ LANGUAGES = { } -def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndarray] | None, str]: +def synthesize_speech( + text: str, language: str +) -> tuple[str, tuple[int, np.ndarray] | None, str]: """Synthesize speech from text using the TTS service.""" if not text.strip(): return "❌ Please enter some text", None, "" - + lang_code = LANGUAGES.get(language, "en") - + try: start_time = time.time() - + # Call TTS service (Coqui XTTS API format) response = client.get( - f"{TTS_URL}/api/tts", - params={"text": text, "language_id": lang_code} + f"{TTS_URL}/api/tts", params={"text": text, "language_id": lang_code} ) response.raise_for_status() - + latency = time.time() - start_time audio_bytes = response.content - + # Parse audio data audio_io = io.BytesIO(audio_bytes) audio_data, sample_rate = sf.read(audio_io) - + # Calculate duration if len(audio_data.shape) == 1: duration = len(audio_data) / sample_rate else: duration = len(audio_data) / sample_rate - + # Status message - status = f"✅ Generated {duration:.2f}s of audio in {latency*1000:.0f}ms" + status = f"✅ Generated {duration:.2f}s of audio in {latency * 1000:.0f}ms" # Log to MLflow _log_tts_metrics( @@ -161,22 +168,22 @@ def synthesize_speech(text: str, language: str) -> tuple[str, tuple[int, np.ndar text_chars=len(text), language=lang_code, ) - + # Metrics metrics = f""" **Audio Statistics:** - Duration: {duration:.2f} seconds - Sample Rate: {sample_rate} Hz - Size: {len(audio_bytes) / 1024:.1f} KB -- Generation Time: {latency*1000:.0f}ms -- Real-time Factor: {latency/duration:.2f}x +- Generation Time: {latency * 1000:.0f}ms +- Real-time Factor: {latency / duration:.2f}x - Language: {language} ({lang_code}) - Characters: {len(text)} -- Chars/sec: {len(text)/latency:.1f} +- Chars/sec: {len(text) / latency:.1f} """ - + return status, (sample_rate, audio_data), metrics - + except httpx.HTTPStatusError as e: logger.exception("TTS request failed") return f"❌ TTS service error: {e.response.status_code}", None, "" @@ -192,12 +199,12 @@ def check_service_health() -> str: response = client.get(f"{TTS_URL}/health", timeout=5.0) if response.status_code == 200: return "🟢 Service is healthy" - + # Fall back to root endpoint response = client.get(f"{TTS_URL}/", timeout=5.0) if response.status_code == 200: return "🟢 Service is responding" - + return f"🟡 Service returned status {response.status_code}" except Exception as e: return f"🔴 Service unavailable: {str(e)}" @@ -211,14 +218,14 @@ with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="TTS Demo") as demo: Test the **Coqui XTTS** text-to-speech service. Convert text to natural-sounding speech in multiple languages. """) - + # Service status with gr.Row(): health_btn = gr.Button("🔄 Check Service", size="sm") health_status = gr.Textbox(label="Service Status", interactive=False) - + health_btn.click(fn=check_service_health, outputs=health_status) - + with gr.Tabs(): # Tab 1: Basic TTS with gr.TabItem("🎤 Text to Speech"): @@ -228,114 +235,120 @@ in multiple languages. label="Text to Synthesize", placeholder="Enter text to convert to speech...", lines=5, - max_lines=10 + max_lines=10, ) - + with gr.Row(): language = gr.Dropdown( choices=list(LANGUAGES.keys()), value="English", - label="Language" + label="Language", ) - synthesize_btn = gr.Button("🔊 Synthesize", variant="primary", scale=2) - + synthesize_btn = gr.Button( + "🔊 Synthesize", variant="primary", scale=2 + ) + with gr.Column(scale=1): status_output = gr.Textbox(label="Status", interactive=False) metrics_output = gr.Markdown(label="Metrics") - + audio_output = gr.Audio(label="Generated Audio", type="numpy") - + synthesize_btn.click( fn=synthesize_speech, inputs=[text_input, language], - outputs=[status_output, audio_output, metrics_output] + outputs=[status_output, audio_output, metrics_output], ) - + # Example texts gr.Examples( examples=[ - ["Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", "English"], - ["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", "English"], - ["Bonjour! Bienvenue au laboratoire technique de Davies.", "French"], + [ + "Hello! Welcome to Davies Tech Labs. This is a demonstration of our text-to-speech system.", + "English", + ], + [ + "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet.", + "English", + ], + [ + "Bonjour! Bienvenue au laboratoire technique de Davies.", + "French", + ], ["Hola! Bienvenido al laboratorio de tecnología.", "Spanish"], ["Guten Tag! Willkommen im Techniklabor.", "German"], ], inputs=[text_input, language], ) - + # Tab 2: Comparison with gr.TabItem("🔄 Language Comparison"): gr.Markdown("Compare the same text in different languages.") - + compare_text = gr.Textbox( - label="Text to Compare", - value="Hello, how are you today?", - lines=2 + label="Text to Compare", value="Hello, how are you today?", lines=2 ) - + with gr.Row(): - lang1 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="English", label="Language 1") - lang2 = gr.Dropdown(choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2") - + lang1 = gr.Dropdown( + choices=list(LANGUAGES.keys()), value="English", label="Language 1" + ) + lang2 = gr.Dropdown( + choices=list(LANGUAGES.keys()), value="Spanish", label="Language 2" + ) + compare_btn = gr.Button("Compare Languages", variant="primary") - + with gr.Row(): with gr.Column(): gr.Markdown("### Language 1") audio1 = gr.Audio(label="Audio 1", type="numpy") status1 = gr.Textbox(label="Status", interactive=False) - + with gr.Column(): gr.Markdown("### Language 2") audio2 = gr.Audio(label="Audio 2", type="numpy") status2 = gr.Textbox(label="Status", interactive=False) - + def compare_languages(text, l1, l2): s1, a1, _ = synthesize_speech(text, l1) s2, a2, _ = synthesize_speech(text, l2) return s1, a1, s2, a2 - + compare_btn.click( fn=compare_languages, inputs=[compare_text, lang1, lang2], - outputs=[status1, audio1, status2, audio2] + outputs=[status1, audio1, status2, audio2], ) - + # Tab 3: Batch Processing with gr.TabItem("📚 Batch Synthesis"): gr.Markdown("Synthesize multiple texts at once (one per line).") - + batch_input = gr.Textbox( label="Texts (one per line)", placeholder="Enter multiple texts, one per line...", - lines=6 + lines=6, ) batch_lang = gr.Dropdown( - choices=list(LANGUAGES.keys()), - value="English", - label="Language" + choices=list(LANGUAGES.keys()), value="English", label="Language" ) batch_btn = gr.Button("Synthesize All", variant="primary") - + batch_status = gr.Textbox(label="Status", interactive=False) batch_audios = gr.Dataset( - components=[gr.Audio(type="numpy")], - label="Generated Audio Files" + components=[gr.Audio(type="numpy")], label="Generated Audio Files" ) - + # Note: Batch processing would need more complex handling # This is a simplified version gr.Markdown(""" *Note: For batch processing of many texts, consider using the API directly or the Kubeflow pipeline for better throughput.* """) - + create_footer() if __name__ == "__main__": - demo.launch( - server_name="0.0.0.0", - server_port=7860, - show_error=True - ) + demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)