# File: gradio-ui/llm.py — Python, 270 lines, 8.4 KiB
# Last modified: 2026-02-12 05:36:15 -05:00
#!/usr/bin/env python3
"""
LLM Chat Demo - Gradio UI for testing vLLM inference service.
Features:
- Multi-turn chat with streaming responses
- Configurable temperature, max tokens, top-p
- System prompt customisation
- Token usage and latency metrics
- Chat history management
"""
import os
import time
import logging
import json
import gradio as gr
import httpx
from theme import get_lab_theme, CUSTOM_CSS, create_footer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("llm-demo")

# Configuration
# LLM_URL can be overridden via the environment; the default targets the
# in-cluster Ray Serve vLLM service DNS name (only resolvable inside the
# Kubernetes cluster).
LLM_URL = os.environ.get(
    "LLM_URL",
    # Default: Ray Serve LLM endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
)

# System prompt pre-filled in the UI; users can edit it per session.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
    "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
    "Be concise and helpful."
)

# Use async client for streaming
# 300 s read timeout accommodates slow 70B generations; 30 s connect timeout
# makes an unreachable service fail fast instead of hanging the UI.
async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
# Short-timeout synchronous client, used only by the health probe.
sync_client = httpx.Client(timeout=10.0)
async def chat_stream(
    message: str,
    history: list[dict[str, str]],
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
):
    """Send a chat request to the vLLM endpoint and yield the reply progressively.

    Args:
        message: Latest user message; a blank message yields "" and returns.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format).
        system_prompt: Optional system message, skipped when blank.
        temperature: Sampling temperature forwarded to the server.
        max_tokens: Completion-length cap forwarded to the server.
        top_p: Nucleus-sampling cutoff forwarded to the server.

    Yields:
        Growing prefixes of the assistant reply. Note this is pseudo-streaming:
        the full response is fetched in a single request, then re-emitted a few
        words at a time. On failure a single ``❌ ...`` message is yielded.
    """
    if not message.strip():
        yield ""
        return

    # Build the OpenAI-style message list: system prompt, prior turns, new turn.
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    for entry in history:
        messages.append({"role": entry["role"], "content": entry["content"]})
    messages.append({"role": "user", "content": message})

    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }
    start_time = time.time()
    try:
        response = await async_client.post(LLM_URL, json=payload)
        response.raise_for_status()
        result = response.json()
        text = result["choices"][0]["message"]["content"]
        latency = time.time() - start_time
        usage = result.get("usage", {})
        logger.info(
            "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
            usage.get("total_tokens", 0),
            latency,
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0),
        )
        # Yield text progressively for a nicer streaming feel.
        chunk_size = 4
        words = text.split(" ")
        partial = ""
        for i, word in enumerate(words):
            partial += ("" if i == 0 else " ") + word
            if i % chunk_size == 0 or i == len(words) - 1:
                yield partial
    except httpx.HTTPStatusError as e:
        logger.exception("LLM request failed")
        # Fix: status code and response body were concatenated with no
        # separator (e.g. "error: 500Internal..."); add ": " between them.
        yield f"❌ LLM service error: {e.response.status_code}: {e.response.text[:200]}"
    except httpx.ConnectError:
        yield "❌ Cannot connect to LLM service. Is the Ray Serve cluster running?"
    except Exception as e:
        logger.exception("LLM chat failed")
        yield f"❌ Error: {e}"
def check_service_health() -> str:
    """Probe the LLM endpoint and report its status as a short string.

    Returns a human-readable status line prefixed with a traffic-light emoji:
    🟢 healthy, 🟡 reachable but non-200, 🔴 unreachable or errored.
    """
    # Minimal one-token, temperature-0 request keeps the probe cheap and
    # deterministic on the server side.
    probe = {
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 1,
        "temperature": 0.0,
    }
    try:
        resp = sync_client.post(LLM_URL, json=probe)
    except httpx.ConnectError:
        return "🔴 Cannot connect to LLM service"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"
    if resp.status_code != 200:
        return f"🟡 LLM responded with status {resp.status_code}"
    return "🟢 LLM service is healthy"
def single_prompt(
    prompt: str,
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> tuple[str, str]:
    """Send a single prompt (non-chat mode) and return output + metrics.

    Args:
        prompt: User prompt; a blank prompt short-circuits with an error tuple.
        system_prompt: Optional system message, skipped when blank.
        temperature: Sampling temperature forwarded to the server.
        max_tokens: Completion-length cap forwarded to the server.
        top_p: Nucleus-sampling cutoff forwarded to the server.

    Returns:
        ``(response_text, metrics_markdown)``. On failure the first element is
        a ``❌ ...`` message and the second is an empty string.
    """
    if not prompt.strip():
        return "❌ Please enter a prompt", ""

    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }
    start_time = time.time()
    try:
        # Fix: the client was created per call and never closed, leaking a
        # connection pool on every request; a context manager closes it.
        with httpx.Client(timeout=300.0) as client:
            response = client.post(LLM_URL, json=payload)
            response.raise_for_status()
            result = response.json()
        latency = time.time() - start_time
        text = result["choices"][0]["message"]["content"]
        usage = result.get("usage", {})
        metrics = f"""
**Generation Metrics:**
- Latency: {latency:.1f}s
- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}
- Completion tokens: {usage.get('completion_tokens', 'N/A')}
- Total tokens: {usage.get('total_tokens', 'N/A')}
- Model: {result.get('model', 'N/A')}
"""
        return text, metrics
    except httpx.HTTPStatusError as e:
        return f"❌ Error {e.response.status_code}: {e.response.text[:300]}", ""
    except httpx.ConnectError:
        return "❌ Cannot connect to LLM service", ""
    except Exception as e:
        # Fix: the bare f"{e}" lacked the ❌ prefix used by every other error
        # path in this module; make the failure mode consistent.
        return f"❌ Error: {e}", ""
# ─── Build the Gradio app ────────────────────────────────────────────────
# NOTE: Gradio component creation order and context nesting determine layout,
# so statements here must stay in this exact order.
with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="LLM Chat Demo") as demo:
    gr.Markdown(
        """
# 🧠 LLM Chat Demo
Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
"""
    )
    # Service status: manual probe button + read-only status textbox.
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)
    health_btn.click(fn=check_service_health, outputs=health_status)
    # Shared parameters — referenced by both the Chat and Single Prompt tabs.
    with gr.Accordion("⚙️ Parameters", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            max_lines=6,
        )
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
            top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
    with gr.Tabs():
        # Tab 1: Multi-turn Chat — streamed via chat_stream; the shared
        # parameter components are wired in as additional inputs.
        with gr.TabItem("💬 Chat"):
            chatbot = gr.ChatInterface(
                fn=chat_stream,
                additional_inputs=[system_prompt, temperature, max_tokens, top_p],
                examples=[
                    ["Hello! What can you tell me about yourself?"],
                    ["Explain how a GPU executes a matrix multiplication."],
                    ["Write a Python function to compute the Fibonacci sequence."],
                    ["What are the pros and cons of running LLMs on AMD GPUs?"],
                ],
                chatbot=gr.Chatbot(
                    height=520,
                    placeholder="Type a message to start chatting...",
                ),
            )
        # Tab 2: Single Prompt — one-shot generation via single_prompt,
        # with a markdown metrics panel below the response.
        with gr.TabItem("📝 Single Prompt"):
            gr.Markdown("Send a one-shot prompt without conversation history.")
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=4,
                max_lines=10,
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary")
            output_text = gr.Textbox(label="Response", lines=12, interactive=False)
            output_metrics = gr.Markdown(label="Metrics")
            generate_btn.click(
                fn=single_prompt,
                inputs=[prompt_input, system_prompt, temperature, max_tokens, top_p],
                outputs=[output_text, output_metrics],
            )
            gr.Examples(
                examples=[
                    ["Summarise the key differences between CUDA and ROCm for ML workloads."],
                    ["Write a haiku about Kubernetes."],
                    ["Explain Ray Serve in one paragraph for someone new to ML serving."],
                    ["List 5 creative uses for a homelab GPU cluster."],
                ],
                inputs=[prompt_input],
            )
    # Shared footer from the local theme module.
    create_footer()

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from inside the cluster;
    # 7860 is the conventional Gradio port.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)