#!/usr/bin/env python3
"""
LLM Chat Demo - Gradio UI for testing vLLM inference service.

Features:
- Multi-turn chat with true SSE streaming responses
- Configurable temperature, max tokens, top-p
- System prompt customisation
- Token usage and latency metrics
- Chat history management
"""

import json
import os
import time
import logging

import gradio as gr
import httpx

from theme import get_lab_theme, CUSTOM_CSS, create_footer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("llm-demo")

# Configuration
LLM_URL = os.environ.get(
    "LLM_URL",
    # Default: Ray Serve LLM endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
)

# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort: if MLflow (or its server) is unavailable, the demo still
# works — metrics logging is simply disabled.
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    MLFLOW_TRACKING_URI = os.environ.get(
        "MLFLOW_TRACKING_URI",
        "http://mlflow.mlflow.svc.cluster.local:80",
    )
    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Ensure experiment exists
    _experiment = _mlflow_client.get_experiment_by_name("gradio-llm-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-llm-tuning",
            artifact_location="/mlflow/artifacts/gradio-llm-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id

    # One persistent run per Gradio instance
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
        tags={
            "service": "gradio-llm",
            "endpoint": LLM_URL,
            "mlflow.runName": f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
        },
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info(
        "MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id
    )
except Exception as exc:
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False


def _log_llm_metrics(
    latency: float,
    prompt_tokens: int,
    completion_tokens: int,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> None:
    """Log inference metrics to MLflow (non-blocking best-effort).

    Args:
        latency: Wall-clock request time in seconds.
        prompt_tokens: Prompt token count (0 when unknown, e.g. SSE path).
        completion_tokens: Completion token count (may be a word-count
            approximation on the streaming path).
        temperature: Sampling temperature used for the request.
        max_tokens: ``max_tokens`` requested from the endpoint.
        top_p: Nucleus-sampling parameter used for the request.

    Never raises: any MLflow failure is logged at DEBUG and swallowed so
    metrics problems cannot break chat responses.
    """
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        ts = int(time.time() * 1000)
        total_tokens = prompt_tokens + completion_tokens
        # Guard against division by zero on pathologically fast responses.
        tps = completion_tokens / latency if latency > 0 else 0
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                mlflow.entities.Metric("latency_s", latency, ts, _mlflow_step),
                mlflow.entities.Metric(
                    "prompt_tokens", prompt_tokens, ts, _mlflow_step
                ),
                mlflow.entities.Metric(
                    "completion_tokens", completion_tokens, ts, _mlflow_step
                ),
                mlflow.entities.Metric("total_tokens", total_tokens, ts, _mlflow_step),
                mlflow.entities.Metric("tokens_per_second", tps, ts, _mlflow_step),
                mlflow.entities.Metric("temperature", temperature, ts, _mlflow_step),
                mlflow.entities.Metric(
                    "max_tokens_requested", max_tokens, ts, _mlflow_step
                ),
                mlflow.entities.Metric("top_p", top_p, ts, _mlflow_step),
            ],
        )
    except Exception:
        # Best-effort by design: never let metrics logging break inference.
        logger.debug("MLflow log failed", exc_info=True)


DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
    "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
    "Be concise and helpful."
)

# Use async client for streaming; long read timeout for slow generations.
async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))


def _extract_content(content) -> str:
    """Extract plain text from message content.

    Handles both plain strings and Gradio 6.x content-parts format:
    [{"type": "text", "text": "..."}] or [{"text": "..."}]
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                parts.append(item.get("text", item.get("content", str(item))))
            elif isinstance(item, str):
                parts.append(item)
            else:
                parts.append(str(item))
        return "".join(parts)
    return str(content)


async def chat_stream(
    message: str,
    history: list[dict[str, str]],
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
):
    """Stream chat responses from the vLLM endpoint via SSE.

    Yields the accumulated response text after each delta so Gradio can
    render a growing message. Falls back to a simulated word-by-word
    stream when the endpoint returns a non-SSE (JSON) response.
    """
    if not message.strip():
        yield ""
        return

    # Build message list from history, normalising content-parts
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    for entry in history:
        messages.append({
            "role": entry["role"],
            "content": _extract_content(entry["content"]),
        })
    messages.append({"role": "user", "content": message})

    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "stream": True,
    }

    start_time = time.time()
    try:
        # Try true SSE streaming first
        async with async_client.stream("POST", LLM_URL, json=payload) as response:
            response.raise_for_status()
            content_type = response.headers.get("content-type", "")

            if "text/event-stream" in content_type:
                # SSE streaming — accumulate deltas
                full_text = ""
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[6:]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        delta = (
                            chunk.get("choices", [{}])[0]
                            .get("delta", {})
                            .get("content", "")
                        )
                        if delta:
                            full_text += delta
                            yield full_text
                    except json.JSONDecodeError:
                        # Ignore malformed SSE payloads (keep-alives etc.)
                        continue

                latency = time.time() - start_time
                logger.info("LLM streamed response: %d chars in %.1fs", len(full_text), latency)
                # Best-effort metrics from the final SSE payload.
                # NOTE: prompt tokens unknown here; completion estimated by
                # word count since SSE deltas carry no usage block.
                _log_llm_metrics(
                    latency=latency,
                    prompt_tokens=0,
                    completion_tokens=len(full_text.split()),
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            else:
                # Non-streaming fallback (endpoint doesn't support stream)
                body = await response.aread()
                result = json.loads(body)
                text = _extract_content(
                    result["choices"][0]["message"]["content"]
                )
                latency = time.time() - start_time
                usage = result.get("usage", {})
                logger.info(
                    "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
                    usage.get("total_tokens", 0),
                    latency,
                    usage.get("prompt_tokens", 0),
                    usage.get("completion_tokens", 0),
                )
                _log_llm_metrics(
                    latency=latency,
                    prompt_tokens=usage.get("prompt_tokens", 0),
                    completion_tokens=usage.get("completion_tokens", 0),
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
                # Yield text progressively for a nicer feel
                chunk_size = 4
                words = text.split(" ")
                partial = ""
                for i, word in enumerate(words):
                    partial += ("" if i == 0 else " ") + word
                    if i % chunk_size == 0 or i == len(words) - 1:
                        yield partial

    except httpx.HTTPStatusError as e:
        logger.exception("LLM request failed")
        yield f"❌ LLM service error: {e.response.status_code} — {e.response.text[:200]}"
    except httpx.ConnectError:
        yield "❌ Cannot connect to LLM service. Is the Ray Serve cluster running?"
    except Exception as e:
        logger.exception("LLM chat failed")
        yield f"❌ Error: {e}"


def check_service_health() -> str:
    """Check if the LLM service is reachable.

    Returns a human-readable status string with a traffic-light emoji.
    """
    try:
        # Try a lightweight GET against the Ray Serve base first.
        # This avoids burning GPU time on a full inference round-trip.
        base_url = LLM_URL.rsplit("/", 1)[0]  # strip /llm path
        response = sync_client.get(f"{base_url}/-/routes")
        if response.status_code == 200:
            return "🟢 LLM service is healthy"

        # Fall back to a minimal inference probe
        response = sync_client.post(
            LLM_URL,
            json={
                "messages": [{"role": "user", "content": "ping"}],
                "max_tokens": 1,
                "temperature": 0.0,
            },
        )
        if response.status_code == 200:
            return "🟢 LLM service is healthy"
        return f"🟡 LLM responded with status {response.status_code}"
    except httpx.ConnectError:
        return "🔴 Cannot connect to LLM service"
    except httpx.TimeoutException:
        return "🟡 LLM service is reachable but slow to respond"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"


def single_prompt(
    prompt: str,
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> tuple[str, str]:
    """Send a single prompt (non-chat mode) and return output + metrics.

    Args:
        prompt: The user prompt to send.
        system_prompt: System prompt prepended when non-blank.
        temperature: Sampling temperature.
        max_tokens: Maximum completion tokens to request.
        top_p: Nucleus-sampling parameter.

    Returns:
        A ``(response_text, metrics_markdown)`` tuple; on failure the
        first element carries an error message and the second is empty.
    """
    if not prompt.strip():
        return "❌ Please enter a prompt", ""

    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }

    start_time = time.time()
    try:
        # Per-call client with a long timeout (generation can be slow);
        # the context manager guarantees the connection pool is closed —
        # the previous version leaked a client per request.
        with httpx.Client(timeout=300.0) as client:
            response = client.post(LLM_URL, json=payload)
            response.raise_for_status()
            result = response.json()

        latency = time.time() - start_time
        text = _extract_content(result["choices"][0]["message"]["content"])
        usage = result.get("usage", {})

        # Log to MLflow
        _log_llm_metrics(
            latency=latency,
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        metrics = f"""
**Generation Metrics:**
- Latency: {latency:.1f}s
- Prompt tokens: {usage.get("prompt_tokens", "N/A")}
- Completion tokens: {usage.get("completion_tokens", "N/A")}
- Total tokens: {usage.get("total_tokens", "N/A")}
- Model: {result.get("model", "N/A")}
"""
        return text, metrics
    except httpx.HTTPStatusError as e:
        return f"❌ Error {e.response.status_code}: {e.response.text[:300]}", ""
    except httpx.ConnectError:
        return "❌ Cannot connect to LLM service", ""
    except Exception as e:
        return f"❌ {e}", ""


# ─── Build the Gradio app ────────────────────────────────────────────────
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="LLM Chat Demo") as demo:
    gr.Markdown(
        """
        # 🧠 LLM Chat Demo

        Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
        """
    )

    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)

    # Shared parameters
    with gr.Accordion("⚙️ Parameters", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            max_lines=6,
        )
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
            top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")

    with gr.Tabs():
        # Tab 1: Multi-turn Chat
        with gr.TabItem("💬 Chat"):
            chatbot = gr.ChatInterface(
                fn=chat_stream,
                additional_inputs=[system_prompt, temperature, max_tokens, top_p],
                examples=[
                    ["Hello! What can you tell me about yourself?"],
                    ["Explain how a GPU executes a matrix multiplication."],
                    ["Write a Python function to compute the Fibonacci sequence."],
                    ["What are the pros and cons of running LLMs on AMD GPUs?"],
                ],
                chatbot=gr.Chatbot(
                    height=520,
                    placeholder="Type a message to start chatting...",
                ),
            )

        # Tab 2: Single Prompt
        with gr.TabItem("📝 Single Prompt"):
            gr.Markdown("Send a one-shot prompt without conversation history.")
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=4,
                max_lines=10,
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary")
            output_text = gr.Textbox(label="Response", lines=12, interactive=False)
            output_metrics = gr.Markdown(label="Metrics")

            generate_btn.click(
                fn=single_prompt,
                inputs=[prompt_input, system_prompt, temperature, max_tokens, top_p],
                outputs=[output_text, output_metrics],
            )

            gr.Examples(
                examples=[
                    [
                        "Summarise the key differences between CUDA and ROCm for ML workloads."
                    ],
                    ["Write a haiku about Kubernetes."],
                    [
                        "Explain Ray Serve in one paragraph for someone new to ML serving."
                    ],
                    ["List 5 creative uses for a homelab GPU cluster."],
                ],
                inputs=[prompt_input],
            )

    create_footer()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)