Each UI now logs per-request metrics to MLflow:
- llm.py: latency, tokens/sec, prompt/completion tokens (gradio-llm-tuning)
- embeddings.py: latency, text length, batch size (gradio-embeddings-tuning)
- stt.py: latency, audio duration, real-time factor (gradio-stt-tuning)
- tts.py: latency, text length, audio duration (gradio-tts-tuning)

Uses try/except-guarded imports so the UIs still work if MLflow is unreachable.
One persistent run per Gradio instance; metrics are batched via MlflowClient.log_batch().

367 lines · 12 KiB · Python
#!/usr/bin/env python3
|
|
"""
|
|
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
|
|
|
Features:
|
|
- Multi-turn chat with streaming responses
|
|
- Configurable temperature, max tokens, top-p
|
|
- System prompt customisation
|
|
- Token usage and latency metrics
|
|
- Chat history management
|
|
"""
|
|
import os
|
|
import time
|
|
import logging
|
|
import json
|
|
|
|
import gradio as gr
|
|
import httpx
|
|
|
|
from theme import get_lab_theme, CUSTOM_CSS, create_footer
|
|
|
|
# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level logger for this demo UI; the name shows up in log records.
logger = logging.getLogger("llm-demo")

# Configuration
# Base URL of the vLLM chat endpoint; override with the LLM_URL env var.
LLM_URL = os.environ.get(
    "LLM_URL",
    # Default: Ray Serve LLM endpoint
    "http://ai-inference-serve-svc.ai-ml.svc.cluster.local:8000/llm",
)
|
|
|
|
# ─── MLflow experiment tracking ──────────────────────────────────────────
# Best-effort setup: if MLflow cannot be imported or the tracking server is
# unreachable, the whole block falls through to the except clause and the UI
# keeps working with MLFLOW_ENABLED = False.
try:
    import mlflow
    from mlflow.tracking import MlflowClient

    # Tracking server URL; override with the MLFLOW_TRACKING_URI env var.
    MLFLOW_TRACKING_URI = os.environ.get(
        "MLFLOW_TRACKING_URI",
        "http://mlflow.mlflow.svc.cluster.local:80",
    )
    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
    _mlflow_client = MlflowClient()

    # Ensure experiment exists
    _experiment = _mlflow_client.get_experiment_by_name("gradio-llm-tuning")
    if _experiment is None:
        _experiment_id = _mlflow_client.create_experiment(
            "gradio-llm-tuning",
            artifact_location="/mlflow/artifacts/gradio-llm-tuning",
        )
    else:
        _experiment_id = _experiment.experiment_id

    # One persistent run per Gradio instance
    # HOSTNAME distinguishes pods when several replicas log to the same
    # experiment; falls back to "local" outside Kubernetes.
    _mlflow_run = mlflow.start_run(
        experiment_id=_experiment_id,
        run_name=f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
        tags={
            "service": "gradio-llm",
            "endpoint": LLM_URL,
            "mlflow.runName": f"gradio-llm-{os.environ.get('HOSTNAME', 'local')}",
        },
    )
    _mlflow_run_id = _mlflow_run.info.run_id
    # Monotonically increasing step counter used by _log_llm_metrics.
    _mlflow_step = 0
    MLFLOW_ENABLED = True
    logger.info("MLflow tracking enabled: experiment=%s run=%s", _experiment_id, _mlflow_run_id)
except Exception as exc:
    # Any failure (import error, DNS, HTTP) disables tracking entirely;
    # the rest of the app checks MLFLOW_ENABLED before logging.
    logger.warning("MLflow tracking disabled: %s", exc)
    _mlflow_client = None
    _mlflow_run_id = None
    _mlflow_step = 0
    MLFLOW_ENABLED = False
|
|
|
|
|
|
def _log_llm_metrics(
    latency: float,
    prompt_tokens: int,
    completion_tokens: int,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> None:
    """Log inference metrics to MLflow (non-blocking best-effort).

    Silently does nothing when MLflow tracking was not set up; any logging
    failure is swallowed and recorded at DEBUG level so the UI never breaks.
    """
    global _mlflow_step
    if not MLFLOW_ENABLED or _mlflow_client is None:
        return
    try:
        _mlflow_step += 1
        timestamp_ms = int(time.time() * 1000)
        # Metric name -> value; insertion order matches the logged batch.
        values = {
            "latency_s": latency,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "tokens_per_second": completion_tokens / latency if latency > 0 else 0,
            "temperature": temperature,
            "max_tokens_requested": max_tokens,
            "top_p": top_p,
        }
        _mlflow_client.log_batch(
            _mlflow_run_id,
            metrics=[
                mlflow.entities.Metric(name, value, timestamp_ms, _mlflow_step)
                for name, value in values.items()
            ],
        )
    except Exception:
        logger.debug("MLflow log failed", exc_info=True)
|
|
|
|
# Default system prompt shown (and editable) in the UI's Parameters panel.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant running on Davies Tech Labs homelab infrastructure. "
    "You are powered by Llama 3.1 70B served via vLLM on AMD Strix Halo (ROCm). "
    "Be concise and helpful."
)

# Use async client for streaming
# Long read timeout (300s) to allow slow generations; 30s to connect.
async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
# Short-timeout sync client used only for the quick health-check probe.
sync_client = httpx.Client(timeout=10.0)
|
|
|
|
|
|
async def chat_stream(
    message: str,
    history: list[dict[str, str]],
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
):
    """Stream chat responses from the vLLM endpoint.

    Yields progressively longer prefixes of the completed response so the
    Gradio ChatInterface renders a streaming effect; on failure yields a
    single user-facing error string instead.
    """
    if not message.strip():
        yield ""
        return

    # Assemble the OpenAI-style message list: optional system prompt,
    # prior turns from history, then the new user turn.
    messages: list[dict[str, str]] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in history
    )
    messages.append({"role": "user", "content": message})

    request_body = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }

    started = time.time()

    try:
        response = await async_client.post(LLM_URL, json=request_body)
        response.raise_for_status()

        result = response.json()
        text = result["choices"][0]["message"]["content"]
        latency = time.time() - started
        usage = result.get("usage", {})

        logger.info(
            "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
            usage.get("total_tokens", 0),
            latency,
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0),
        )

        # Log to MLflow (best-effort; no-op when tracking is disabled)
        _log_llm_metrics(
            latency=latency,
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        # The endpoint returns the full text at once; fake streaming by
        # yielding the accumulated prefix every 4 words (and at the end).
        words = text.split(" ")
        last_index = len(words) - 1
        partial = ""
        for index, word in enumerate(words):
            partial = word if index == 0 else partial + " " + word
            if index % 4 == 0 or index == last_index:
                yield partial

    except httpx.HTTPStatusError as e:
        logger.exception("LLM request failed")
        yield f"❌ LLM service error: {e.response.status_code} — {e.response.text[:200]}"
    except httpx.ConnectError:
        yield "❌ Cannot connect to LLM service. Is the Ray Serve cluster running?"
    except Exception as e:
        logger.exception("LLM chat failed")
        yield f"❌ Error: {e}"
|
|
|
|
|
|
def check_service_health() -> str:
    """Check if the LLM service is reachable.

    Sends a minimal one-token "ping" request and maps the outcome to a
    traffic-light status string for the UI.
    """
    probe = {
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 1,
        "temperature": 0.0,
    }
    try:
        response = sync_client.post(LLM_URL, json=probe)
    except httpx.ConnectError:
        return "🔴 Cannot connect to LLM service"
    except Exception as e:
        return f"🔴 Service unavailable: {e}"
    if response.status_code == 200:
        return "🟢 LLM service is healthy"
    return f"🟡 LLM responded with status {response.status_code}"
|
|
|
|
|
|
def single_prompt(
    prompt: str,
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
) -> tuple[str, str]:
    """Send a single prompt (non-chat mode) and return output + metrics.

    Args:
        prompt: User prompt text; blank/whitespace input short-circuits
            with an error message.
        system_prompt: Optional system message, prepended when non-blank.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Maximum completion tokens requested.
        top_p: Nucleus-sampling parameter forwarded to the endpoint.

    Returns:
        (response_text, metrics_markdown); on failure the first element is
        a user-facing error string and the second is empty.
    """
    if not prompt.strip():
        return "❌ Please enter a prompt", ""

    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }

    start_time = time.time()

    try:
        # Fix: the previous version created an httpx.Client per call and never
        # closed it, leaking a connection pool on every request. A context
        # manager closes it deterministically. A dedicated long-timeout client
        # is still needed here — the module-level sync_client allows only 10s,
        # too short for generation. (httpx reads the body eagerly for
        # non-streaming requests, so .json() is safe after the client closes.)
        with httpx.Client(timeout=300.0) as client:
            response = client.post(LLM_URL, json=payload)
        response.raise_for_status()
        result = response.json()
        latency = time.time() - start_time

        text = result["choices"][0]["message"]["content"]
        usage = result.get("usage", {})

        # Log to MLflow (best-effort; no-op when tracking is disabled)
        _log_llm_metrics(
            latency=latency,
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        metrics = f"""
**Generation Metrics:**
- Latency: {latency:.1f}s
- Prompt tokens: {usage.get('prompt_tokens', 'N/A')}
- Completion tokens: {usage.get('completion_tokens', 'N/A')}
- Total tokens: {usage.get('total_tokens', 'N/A')}
- Model: {result.get('model', 'N/A')}
"""
        return text, metrics

    except httpx.HTTPStatusError as e:
        return f"❌ Error {e.response.status_code}: {e.response.text[:300]}", ""
    except httpx.ConnectError:
        return "❌ Cannot connect to LLM service", ""
    except Exception as e:
        return f"❌ {e}", ""
|
|
|
|
|
|
# ─── Build the Gradio app ────────────────────────────────────────────────

with gr.Blocks(theme=get_lab_theme(), css=CUSTOM_CSS, title="LLM Chat Demo") as demo:
    # Page header
    gr.Markdown(
        """
        # 🧠 LLM Chat Demo

        Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
        """
    )

    # Service status
    with gr.Row():
        health_btn = gr.Button("🔄 Check Service", size="sm")
        health_status = gr.Textbox(label="Service Status", interactive=False)

    health_btn.click(fn=check_service_health, outputs=health_status)

    # Shared parameters
    # These components feed BOTH tabs (chat and single-prompt) below.
    with gr.Accordion("⚙️ Parameters", open=False):
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
            max_lines=6,
        )
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
            top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")

    with gr.Tabs():
        # Tab 1: Multi-turn Chat
        with gr.TabItem("💬 Chat"):
            # ChatInterface passes (message, history, *additional_inputs)
            # to chat_stream and renders its yielded prefixes as streaming.
            chatbot = gr.ChatInterface(
                fn=chat_stream,
                additional_inputs=[system_prompt, temperature, max_tokens, top_p],
                examples=[
                    ["Hello! What can you tell me about yourself?"],
                    ["Explain how a GPU executes a matrix multiplication."],
                    ["Write a Python function to compute the Fibonacci sequence."],
                    ["What are the pros and cons of running LLMs on AMD GPUs?"],
                ],
                chatbot=gr.Chatbot(
                    height=520,
                    placeholder="Type a message to start chatting...",
                ),
            )

        # Tab 2: Single Prompt
        with gr.TabItem("📝 Single Prompt"):
            gr.Markdown("Send a one-shot prompt without conversation history.")

            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=4,
                max_lines=10,
            )
            generate_btn = gr.Button("🚀 Generate", variant="primary")

            output_text = gr.Textbox(label="Response", lines=12, interactive=False)
            output_metrics = gr.Markdown(label="Metrics")

            # single_prompt returns (text, metrics_markdown) for the two outputs.
            generate_btn.click(
                fn=single_prompt,
                inputs=[prompt_input, system_prompt, temperature, max_tokens, top_p],
                outputs=[output_text, output_metrics],
            )

            gr.Examples(
                examples=[
                    ["Summarise the key differences between CUDA and ROCm for ML workloads."],
                    ["Write a haiku about Kubernetes."],
                    ["Explain Ray Serve in one paragraph for someone new to ML serving."],
                    ["List 5 creative uses for a homelab GPU cluster."],
                ],
                inputs=[prompt_input],
            )

    # Shared footer from theme.py
    create_footer()
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces so the in-cluster Service can reach the UI pod.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|