#!/usr/bin/env python3 """ Embeddings Demo - Gradio UI for testing BGE embeddings service. Features: - Text input for generating embeddings - Batch embedding support - Similarity comparison between texts - MLflow metrics logging - Visual embedding dimension display """ import os import time import logging import json import gradio as gr import httpx import numpy as np from theme import get_lab_theme, CUSTOM_CSS, create_footer # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("embeddings-demo") # Configuration EMBEDDINGS_URL = os.environ.get( "EMBEDDINGS_URL", # Default: Ray Serve Embeddings endpoint "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/embeddings" ) # ─── MLflow experiment tracking ────────────────────────────────────────── try: import mlflow from mlflow.tracking import MlflowClient MLFLOW_TRACKING_URI = os.environ.get( "MLFLOW_TRACKING_URI", "http://mlflow.mlflow.svc.cluster.local:80", ) mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) _mlflow_client = MlflowClient() _experiment = _mlflow_client.get_experiment_by_name("gradio-embeddings-tuning") if _experiment is None: _experiment_id = _mlflow_client.create_experiment( "gradio-embeddings-tuning", artifact_location="/mlflow/artifacts/gradio-embeddings-tuning", ) else: _experiment_id = _experiment.experiment_id _mlflow_run = mlflow.start_run( experiment_id=_experiment_id, run_name=f"gradio-embeddings-{os.environ.get('HOSTNAME', 'local')}", tags={"service": "gradio-embeddings", "endpoint": EMBEDDINGS_URL}, ) _mlflow_run_id = _mlflow_run.info.run_id _mlflow_step = 0 MLFLOW_ENABLED = True logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id) except Exception as exc: logger.warning("MLflow tracking disabled: %s", exc) _mlflow_client = None _mlflow_run_id = None _mlflow_step = 0 MLFLOW_ENABLED = False def _log_embedding_metrics(latency: float, batch_size: int, embedding_dims: int = 0) -> None: """Log embedding inference metrics to MLflow (non-blocking best-effort).""" global _mlflow_step if not MLFLOW_ENABLED or _mlflow_client is None: return try: _mlflow_step += 1 ts = int(time.time() * 1000) _mlflow_client.log_batch( _mlflow_run_id, metrics=[ mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step), mlflow.entities.Metric("batch_size", batch_size, ts, _mlflow_step), mlflow.entities.Metric("embedding_dims", embedding_dims, ts, _mlflow_step), mlflow.entities.Metric("latency_per_text_ms", (latency * 1000 / batch_size) if batch_size > 0 else 0, ts, _mlflow_step), ], ) except Exception: logger.debug("MLflow log failed", exc_info=True) # HTTP client client = httpx.Client(timeout=60.0) def get_embeddings(texts: list[str]) -> tuple[list[list[float]], float]: """Get embeddings from the embeddings service.""" start_time = time.time() response = client.post( f"{EMBEDDINGS_URL}/embeddings", json={"input": texts, "model": "bge"} ) response.raise_for_status() latency = time.time() - start_time result = response.json() embeddings = [d["embedding"] for d in result.get("data", [])] return embeddings, latency def cosine_similarity(a: list[float], b: list[float]) -> float: """Compute cosine similarity between two vectors.""" a = np.array(a) b = np.array(b) return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))) def generate_single_embedding(text: str) -> tuple[str, str, str]: """Generate embedding for a single text.""" if not text.strip(): return "❌ Please enter some text", "", "" try: embeddings, latency = get_embeddings([text]) if not embeddings: return "❌ No embedding returned", "", "" embedding = embeddings[0] dims = len(embedding) # Log to MLflow _log_embedding_metrics(latency, batch_size=1, embedding_dims=dims) # Format output status = f"✅ Generated {dims}-dimensional embedding in {latency*1000:.1f}ms" # Show first/last few dimensions preview = f"Dimensions: {dims}\n\n" preview += "First 10 values:\n" preview += json.dumps(embedding[:10], indent=2) preview += "\n\n...\n\nLast 10 values:\n" preview += json.dumps(embedding[-10:], indent=2) # Stats stats = f""" **Embedding Statistics:** - Dimensions: {dims} - Min value: {min(embedding):.6f} - Max value: {max(embedding):.6f} - Mean: {np.mean(embedding):.6f} - Std: {np.std(embedding):.6f} - L2 Norm: {np.linalg.norm(embedding):.6f} - Latency: {latency*1000:.1f}ms """ return status, preview, stats except Exception as e: logger.exception("Embedding generation failed") return f"❌ Error: {str(e)}", "", "" def compare_texts(text1: str, text2: str) -> tuple[str, str]: """Compare similarity between two texts.""" if not text1.strip() or not text2.strip(): return "❌ Please enter both texts", "" try: embeddings, latency = get_embeddings([text1, text2]) if len(embeddings) != 2: return "❌ Failed to get embeddings for both texts", "" similarity = cosine_similarity(embeddings[0], embeddings[1]) # Log to MLflow _log_embedding_metrics(latency, batch_size=2, embedding_dims=len(embeddings[0])) # Determine similarity level if similarity > 0.9: level = "🟢 Very High" desc = "These texts are semantically very similar" elif similarity > 0.7: level = "🟡 High" desc = "These texts share significant semantic meaning" elif similarity > 0.5: level = "🟠 Moderate" desc = "These texts have some semantic overlap" else: level = "🔴 Low" desc = "These texts are semantically different" result = f""" ## Similarity Score: {similarity:.4f} **Level:** {level} {desc} --- *Computed in {latency*1000:.1f}ms* """ # Create a simple visual bar bar_length = 50 filled = int(similarity * bar_length) bar = "█" * filled + "░" * (bar_length - filled) visual = f"[{bar}] {similarity*100:.1f}%" return result, visual except Exception as e: logger.exception("Comparison failed") return f"❌ Error: {str(e)}", "" def batch_embed(texts_input: str) -> tuple[str, str]: """Generate embeddings for multiple texts (one per line).""" texts = [t.strip() for t in texts_input.strip().split("\n") if t.strip()] if not texts: return "❌ Please enter at least one text (one per line)", "" try: embeddings, latency = get_embeddings(texts) # Log to MLflow _log_embedding_metrics(latency, batch_size=len(embeddings), embedding_dims=len(embeddings[0]) if embeddings else 0) status = f"✅ Generated {len(embeddings)} embeddings in {latency*1000:.1f}ms" status += f" ({latency*1000/len(texts):.1f}ms per text)" # Build similarity matrix n = len(embeddings) matrix = [] for i in range(n): row = [] for j in range(n): sim = cosine_similarity(embeddings[i], embeddings[j]) row.append(f"{sim:.3f}") matrix.append(row) # Format as table header = "| | " + " | ".join([f"Text {i+1}" for i in range(n)]) + " |" separator = "|---" + "|---" * n + "|" rows = [] for i, row in enumerate(matrix): rows.append(f"| **Text {i+1}** | " + " | ".join(row) + " |") table = "\n".join([header, separator] + rows) result = f""" ## Similarity Matrix {table} --- **Texts processed:** """ for i, text in enumerate(texts): result += f"\n{i+1}. {text[:50]}{'...' if len(text) > 50 else ''}" return status, result except Exception as e: logger.exception("Batch embedding failed") return f"❌ Error: {str(e)}", "" def check_service_health() -> str: """Check if the embeddings service is healthy.""" try: response = client.get(f"{EMBEDDINGS_URL}/health", timeout=5.0) if response.status_code == 200: return "🟢 Service is healthy" else: return f"🟡 Service returned status {response.status_code}" except Exception as e: return f"🔴 Service unavailable: {str(e)}" # Build the Gradio app with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="Embeddings Demo") as demo: gr.Markdown(""" # 🔢 Embeddings Demo Test the **BGE Embeddings** service for semantic text encoding. Generate embeddings, compare text similarity, and explore vector representations. """) # Service status with gr.Row(): health_btn = gr.Button("🔄 Check Service", size="sm") health_status = gr.Textbox(label="Service Status", interactive=False) health_btn.click(fn=check_service_health, outputs=health_status) with gr.Tabs(): # Tab 1: Single Embedding with gr.TabItem("📝 Single Text"): with gr.Row(): with gr.Column(): single_input = gr.Textbox( label="Input Text", placeholder="Enter text to generate embeddings...", lines=3 ) single_btn = gr.Button("Generate Embedding", variant="primary") with gr.Column(): single_status = gr.Textbox(label="Status", interactive=False) single_stats = gr.Markdown(label="Statistics") single_preview = gr.Code(label="Embedding Preview", language="json") single_btn.click( fn=generate_single_embedding, inputs=single_input, outputs=[single_status, single_preview, single_stats] ) # Tab 2: Compare Texts with gr.TabItem("⚖️ Compare Texts"): gr.Markdown("Compare the semantic similarity between two texts.") with gr.Row(): compare_text1 = gr.Textbox(label="Text 1", lines=3) compare_text2 = gr.Textbox(label="Text 2", lines=3) compare_btn = gr.Button("Compare Similarity", variant="primary") with gr.Row(): compare_result = gr.Markdown(label="Result") compare_visual = gr.Textbox(label="Similarity Bar", interactive=False) compare_btn.click( fn=compare_texts, inputs=[compare_text1, compare_text2], outputs=[compare_result, compare_visual] ) # Example pairs gr.Examples( examples=[ ["The cat sat on the mat.", "A feline was resting on the rug."], ["Machine learning is a subset of AI.", "Deep learning uses neural networks."], ["I love pizza.", "The stock market crashed today."], ], inputs=[compare_text1, compare_text2], ) # Tab 3: Batch Embeddings with gr.TabItem("📚 Batch Processing"): gr.Markdown("Generate embeddings for multiple texts and see their similarity matrix.") batch_input = gr.Textbox( label="Texts (one per line)", placeholder="Enter multiple texts, one per line...", lines=6 ) batch_btn = gr.Button("Process Batch", variant="primary") batch_status = gr.Textbox(label="Status", interactive=False) batch_result = gr.Markdown(label="Similarity Matrix") batch_btn.click( fn=batch_embed, inputs=batch_input, outputs=[batch_status, batch_result] ) gr.Examples( examples=[ "Python is a programming language.\nJava is also a programming language.\nCoffee is a beverage.", "The quick brown fox jumps over the lazy dog.\nA fast auburn fox leaps above a sleepy canine.\nThe weather is nice today.", ], inputs=batch_input, ) create_footer() if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, show_error=True )