#!/usr/bin/env python3
"""
LLM Chat Demo - Gradio UI for testing vLLM inference service.

Features:
- Multi-turn chat with progressively rendered responses (the full reply is
  fetched in one request, then replayed a few words at a time)
- Configurable temperature, max tokens, top-p
- System prompt customisation
- Token usage and latency metrics
- Chat history management
"""

import os
import time
import logging
import json

import gradio as gr
import httpx

from theme import get_lab_theme, CUSTOM_CSS, create_footer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("llm-demo")

# Configuration
LLM_URL = os.environ.get(
    "LLM_URL",
    # Default: Ray Serve LLM endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
)

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
    "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
    "Be concise and helpful."
)

# Timeout for generation requests: a generous overall budget for long
# completions, but a shorter connect phase so an unreachable service fails fast.
GENERATION_TIMEOUT = httpx.Timeout(300.0, connect=30.0)

# Long-lived clients so connections are pooled and reused across requests.
# The sync client's 10s default suits the health probe; generation calls made
# through it override the timeout per request with GENERATION_TIMEOUT.
async_client = httpx.AsyncClient(timeout=GENERATION_TIMEOUT)
sync_client = httpx.Client(timeout=10.0)

# Number of words revealed per update when replaying a chat reply.
STREAM_CHUNK_WORDS = 4


def _build_messages(
    system_prompt: str,
    history: list[dict[str, str]],
    user_message: str,
) -> list[dict[str, str]]:
    """Assemble the ``messages`` list for a chat request.

    Order: optional system prompt (skipped when blank), then prior turns
    reduced to their ``role``/``content`` keys, then the new user turn.
    """
    messages: list[dict[str, str]] = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    # Forward only role/content; any other keys on history entries are dropped.
    messages.extend(
        {"role": entry["role"], "content": entry["content"]} for entry in history
    )
    messages.append({"role": "user", "content": user_message})
    return messages


def _build_payload(
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> dict:
    """Build the JSON request body POSTed to ``LLM_URL``."""
    return {
        "messages": messages,
        "temperature": temperature,
        # Sliders may hand back floats (e.g. 512.0); a token limit is a count.
        "max_tokens": int(max_tokens),
        "top_p": top_p,
    }


def _completion_text(result: dict) -> str:
    """Return the first choice's message content, or '' if it is null."""
    return result["choices"][0]["message"]["content"] or ""


def _progressive_prefixes(text: str, chunk_size: int = STREAM_CHUNK_WORDS):
    """Yield growing prefixes of *text*, one every *chunk_size* words.

    The first word and the complete text are always yielded. Splitting and
    re-joining on a single space is lossless, so the last value equals *text*.
    """
    step = max(1, chunk_size)
    words = text.split(" ")
    last = len(words) - 1
    partial = ""
    for i, word in enumerate(words):
        partial = word if i == 0 else f"{partial} {word}"
        if i % step == 0 or i == last:
            yield partial


async def chat_stream(
    message: str,
    history: list[dict[str, str]],
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
):
    """Answer one chat turn, rendering the reply progressively.

    The endpoint is called once, without server-side streaming. The complete
    reply is then replayed in word chunks so the chat UI appears to stream.

    Args:
        message: The new user message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
        system_prompt: Optional system prompt; ignored when blank.
        temperature: Sampling temperature.
        max_tokens: Maximum number of completion tokens.
        top_p: Nucleus-sampling threshold.

    Yields:
        Growing prefixes of the reply, or a single "❌ ..." error string.
    """
    if not message.strip():
        yield ""
        return

    payload = _build_payload(
        _build_messages(system_prompt, history, message),
        temperature,
        max_tokens,
        top_p,
    )

    start_time = time.perf_counter()
    try:
        response = await async_client.post(LLM_URL, json=payload)
        response.raise_for_status()
        result = response.json()
        text = _completion_text(result)
    except httpx.HTTPStatusError as e:
        logger.exception("LLM request failed")
        yield f"❌ LLM service error: {e.response.status_code} — {e.response.text[:200]}"
        return
    except httpx.ConnectError:
        logger.warning("Cannot connect to LLM service at %s", LLM_URL)
        yield "❌ Cannot connect to LLM service. Is the Ray Serve cluster running?"
        return
    except httpx.TimeoutException:
        logger.warning("LLM request timed out")
        yield "❌ LLM request timed out. Try lowering Max Tokens or retry later."
        return
    except Exception as e:
        logger.exception("LLM chat failed")
        yield f"❌ Error: {e}"
        return

    latency = time.perf_counter() - start_time
    # The server may omit usage or send it as null; fall back to zeros.
    usage = result.get("usage") or {}
    logger.info(
        "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
        usage.get("total_tokens") or 0,
        latency,
        usage.get("prompt_tokens") or 0,
        usage.get("completion_tokens") or 0,
    )

    # Replay the finished reply a few words at a time for a nicer streaming feel.
    for partial in _progressive_prefixes(text):
        yield partial


def check_service_health() -> str:
    """Probe the LLM service and return a one-line status for the UI.

    No dedicated health route is used: the probe is a minimal 1-token
    completion request, sent with the sync client's short timeout.
    """
    try:
        response = sync_client.post(
            LLM_URL,
            json={
                "messages": [{"role": "user", "content": "ping"}],
                "max_tokens": 1,
                "temperature": 0.0,
            },
        )
    except httpx.ConnectError:
        return "🔴 Cannot connect to LLM service"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"

    if response.status_code == 200:
        return "🟢 LLM service is healthy"
    return f"🟡 LLM responded with status {response.status_code}"


def single_prompt(
    prompt: str,
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> tuple[str, str]:
    """Send a single prompt (non-chat mode) and return output + metrics.

    Returns:
        ``(response_text, metrics_markdown)``. On failure, the first element
        is a "❌ ..." message and the second is empty.
    """
    if not prompt.strip():
        return "❌ Please enter a prompt", ""

    payload = _build_payload(
        _build_messages(system_prompt, [], prompt),
        temperature,
        max_tokens,
        top_p,
    )

    start_time = time.perf_counter()
    try:
        # Reuse the pooled module-level client instead of creating (and
        # leaking) a new one per call; override its short default timeout.
        response = sync_client.post(LLM_URL, json=payload, timeout=GENERATION_TIMEOUT)
        response.raise_for_status()
        result = response.json()
        text = _completion_text(result)
    except httpx.HTTPStatusError as e:
        logger.exception("LLM request failed")
        return f"❌ Error {e.response.status_code}: {e.response.text[:300]}", ""
    except httpx.ConnectError:
        logger.warning("Cannot connect to LLM service at %s", LLM_URL)
        return "❌ Cannot connect to LLM service", ""
    except httpx.TimeoutException:
        logger.warning("LLM request timed out")
        return "❌ LLM request timed out", ""
    except Exception as e:
        logger.exception("Single prompt failed")
        return f"❌ {e}", ""

    latency = time.perf_counter() - start_time
    usage = result.get("usage") or {}
    metrics = (
        "**Generation Metrics:**\n"
        f"- Latency: {latency:.1f}s\n"
        f"- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}\n"
        f"- Completion tokens: {usage.get('completion_tokens', 'N/A')}\n"
        f"- Total tokens: {usage.get('total_tokens', 'N/A')}\n"
        f"- Model: {result.get('model', 'N/A')}\n"
    )
    return text, metrics


# ─── Build the Gradio app ────────────────────────────────────────────────

with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="LLM Chat Demo") as demo:
    gr.Markdown(
        """
# 🧠 LLM Chat Demo

Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
"""
    )

    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)

    # Shared parameters (used by both tabs)
    with gr.Accordion("⚙️ Parameters", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            max_lines=6,
        )
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
            top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")

    with gr.Tabs():
        # Tab 1: Multi-turn Chat
        with gr.TabItem("💬 Chat"):
            chatbot = gr.ChatInterface(
                fn=chat_stream,
                type="messages",
                additional_inputs=[system_prompt, temperature, max_tokens, top_p],
                examples=[
                    "Hello! What can you tell me about yourself?",
                    "Explain how a GPU executes a matrix multiplication.",
                    "Write a Python function to compute the Fibonacci sequence.",
                    "What are the pros and cons of running LLMs on AMD GPUs?",
                ],
                chatbot=gr.Chatbot(
                    height=520,
                    type="messages",
                    show_copy_button=True,
                    placeholder="Type a message to start chatting...",
                ),
            )

        # Tab 2: Single Prompt
        with gr.TabItem("📝 Single Prompt"):
            gr.Markdown("Send a one-shot prompt without conversation history.")

            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=4,
                max_lines=10,
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary")

            output_text = gr.Textbox(label="Response", lines=12, interactive=False)
            output_metrics = gr.Markdown(label="Metrics")

            generate_btn.click(
                fn=single_prompt,
                inputs=[prompt_input, system_prompt, temperature, max_tokens, top_p],
                outputs=[output_text, output_metrics],
            )

            gr.Examples(
                examples=[
                    ["Summarise the key differences between CUDA and ROCm for ML workloads."],
                    ["Write a haiku about Kubernetes."],
                    ["Explain Ray Serve in one paragraph for someone new to ML serving."],
                    ["List 5 creative uses for a homelab GPU cluster."],
                ],
                inputs=[prompt_input],
            )

    create_footer()


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)