llm streaming outputs, bumped up images.
This commit is contained in:
@@ -8,3 +8,8 @@ resources:
|
||||
- llm.yaml
|
||||
- tts.yaml
|
||||
- stt.yaml
|
||||
|
||||
images:
|
||||
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
|
||||
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
|
||||
newTag: "0.0.7"
|
||||
|
||||
88
llm.py
88
llm.py
@@ -3,13 +3,14 @@
|
||||
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
||||
|
||||
Features:
|
||||
- Multi-turn chat with streaming responses
|
||||
- Multi-turn chat with true SSE streaming responses
|
||||
- Configurable temperature, max tokens, top-p
|
||||
- System prompt customisation
|
||||
- Token usage and latency metrics
|
||||
- Chat history management
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
@@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
|
||||
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
||||
|
||||
|
||||
def _extract_content(content) -> str:
|
||||
"""Extract plain text from message content.
|
||||
|
||||
Handles both plain strings and Gradio 6.x content-parts format:
|
||||
[{"type": "text", "text": "..."}] or [{"text": "..."}]
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(item.get("text", item.get("content", str(item))))
|
||||
elif isinstance(item, str):
|
||||
parts.append(item)
|
||||
else:
|
||||
parts.append(str(item))
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def chat_stream(
|
||||
message: str,
|
||||
history: list[dict[str, str]],
|
||||
@@ -135,18 +157,21 @@ async def chat_stream(
|
||||
max_tokens: int,
|
||||
top_p: float,
|
||||
):
|
||||
"""Stream chat responses from the vLLM endpoint."""
|
||||
"""Stream chat responses from the vLLM endpoint via SSE."""
|
||||
if not message.strip():
|
||||
yield ""
|
||||
return
|
||||
|
||||
# Build message list from history
|
||||
# Build message list from history, normalising content-parts
|
||||
messages = []
|
||||
if system_prompt.strip():
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
for entry in history:
|
||||
messages.append({"role": entry["role"], "content": entry["content"]})
|
||||
messages.append({
|
||||
"role": entry["role"],
|
||||
"content": _extract_content(entry["content"]),
|
||||
})
|
||||
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
@@ -155,16 +180,58 @@ async def chat_stream(
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await async_client.post(LLM_URL, json=payload)
|
||||
# Try true SSE streaming first
|
||||
async with async_client.stream("POST", LLM_URL, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
result = response.json()
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
if "text/event-stream" in content_type:
|
||||
# SSE streaming — accumulate deltas
|
||||
full_text = ""
|
||||
async for line in response.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:]
|
||||
if data.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
delta = (
|
||||
chunk.get("choices", [{}])[0]
|
||||
.get("delta", {})
|
||||
.get("content", "")
|
||||
)
|
||||
if delta:
|
||||
full_text += delta
|
||||
yield full_text
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
latency = time.time() - start_time
|
||||
logger.info("LLM streamed response: %d chars in %.1fs", len(full_text), latency)
|
||||
|
||||
# Best-effort metrics from the final SSE payload
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=len(full_text.split()),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
)
|
||||
else:
|
||||
# Non-streaming fallback (endpoint doesn't support stream)
|
||||
body = await response.aread()
|
||||
result = json.loads(body)
|
||||
text = _extract_content(
|
||||
result["choices"][0]["message"]["content"]
|
||||
)
|
||||
latency = time.time() - start_time
|
||||
usage = result.get("usage", {})
|
||||
|
||||
@@ -176,7 +243,6 @@ async def chat_stream(
|
||||
usage.get("completion_tokens", 0),
|
||||
)
|
||||
|
||||
# Log to MLflow
|
||||
_log_llm_metrics(
|
||||
latency=latency,
|
||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||
@@ -186,7 +252,7 @@ async def chat_stream(
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
# Yield text progressively for a nicer streaming feel
|
||||
# Yield text progressively for a nicer feel
|
||||
chunk_size = 4
|
||||
words = text.split(" ")
|
||||
partial = ""
|
||||
@@ -266,7 +332,7 @@ def single_prompt(
|
||||
result = response.json()
|
||||
latency = time.time() - start_time
|
||||
|
||||
text = result["choices"][0]["message"]["content"]
|
||||
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||
usage = result.get("usage", {})
|
||||
|
||||
# Log to MLflow
|
||||
@@ -325,7 +391,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
||||
)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
||||
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
|
||||
max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
|
||||
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
||||
|
||||
with gr.Tabs():
|
||||
|
||||
Reference in New Issue
Block a user