llm streaming outputs, bumped up images.
This commit is contained in:
@@ -8,3 +8,8 @@ resources:
|
|||||||
- llm.yaml
|
- llm.yaml
|
||||||
- tts.yaml
|
- tts.yaml
|
||||||
- stt.yaml
|
- stt.yaml
|
||||||
|
|
||||||
|
images:
|
||||||
|
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
|
||||||
|
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
|
||||||
|
newTag: "0.0.7"
|
||||||
|
|||||||
138
llm.py
138
llm.py
@@ -3,13 +3,14 @@
|
|||||||
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
LLM Chat Demo - Gradio UI for testing vLLM inference service.
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- Multi-turn chat with streaming responses
|
- Multi-turn chat with true SSE streaming responses
|
||||||
- Configurable temperature, max tokens, top-p
|
- Configurable temperature, max tokens, top-p
|
||||||
- System prompt customisation
|
- System prompt customisation
|
||||||
- Token usage and latency metrics
|
- Token usage and latency metrics
|
||||||
- Chat history management
|
- Chat history management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
@@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
|
|||||||
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content(content) -> str:
|
||||||
|
"""Extract plain text from message content.
|
||||||
|
|
||||||
|
Handles both plain strings and Gradio 6.x content-parts format:
|
||||||
|
[{"type": "text", "text": "..."}] or [{"text": "..."}]
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
parts.append(item.get("text", item.get("content", str(item))))
|
||||||
|
elif isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
else:
|
||||||
|
parts.append(str(item))
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
message: str,
|
message: str,
|
||||||
history: list[dict[str, str]],
|
history: list[dict[str, str]],
|
||||||
@@ -135,18 +157,21 @@ async def chat_stream(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
):
|
):
|
||||||
"""Stream chat responses from the vLLM endpoint."""
|
"""Stream chat responses from the vLLM endpoint via SSE."""
|
||||||
if not message.strip():
|
if not message.strip():
|
||||||
yield ""
|
yield ""
|
||||||
return
|
return
|
||||||
|
|
||||||
# Build message list from history
|
# Build message list from history, normalising content-parts
|
||||||
messages = []
|
messages = []
|
||||||
if system_prompt.strip():
|
if system_prompt.strip():
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
for entry in history:
|
for entry in history:
|
||||||
messages.append({"role": entry["role"], "content": entry["content"]})
|
messages.append({
|
||||||
|
"role": entry["role"],
|
||||||
|
"content": _extract_content(entry["content"]),
|
||||||
|
})
|
||||||
|
|
||||||
messages.append({"role": "user", "content": message})
|
messages.append({"role": "user", "content": message})
|
||||||
|
|
||||||
@@ -155,45 +180,86 @@ async def chat_stream(
|
|||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await async_client.post(LLM_URL, json=payload)
|
# Try true SSE streaming first
|
||||||
response.raise_for_status()
|
async with async_client.stream("POST", LLM_URL, json=payload) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
|
||||||
result = response.json()
|
if "text/event-stream" in content_type:
|
||||||
text = result["choices"][0]["message"]["content"]
|
# SSE streaming — accumulate deltas
|
||||||
latency = time.time() - start_time
|
full_text = ""
|
||||||
usage = result.get("usage", {})
|
async for line in response.aiter_lines():
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data = line[6:]
|
||||||
|
if data.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data)
|
||||||
|
delta = (
|
||||||
|
chunk.get("choices", [{}])[0]
|
||||||
|
.get("delta", {})
|
||||||
|
.get("content", "")
|
||||||
|
)
|
||||||
|
if delta:
|
||||||
|
full_text += delta
|
||||||
|
yield full_text
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
logger.info(
|
latency = time.time() - start_time
|
||||||
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
logger.info("LLM streamed response: %d chars in %.1fs", len(full_text), latency)
|
||||||
usage.get("total_tokens", 0),
|
|
||||||
latency,
|
|
||||||
usage.get("prompt_tokens", 0),
|
|
||||||
usage.get("completion_tokens", 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log to MLflow
|
# Best-effort metrics: usage is not parsed from the stream, so completion tokens are approximated by word count
|
||||||
_log_llm_metrics(
|
_log_llm_metrics(
|
||||||
latency=latency,
|
latency=latency,
|
||||||
prompt_tokens=usage.get("prompt_tokens", 0),
|
prompt_tokens=0,
|
||||||
completion_tokens=usage.get("completion_tokens", 0),
|
completion_tokens=len(full_text.split()),
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# Non-streaming fallback (endpoint doesn't support stream)
|
||||||
|
body = await response.aread()
|
||||||
|
result = json.loads(body)
|
||||||
|
text = _extract_content(
|
||||||
|
result["choices"][0]["message"]["content"]
|
||||||
|
)
|
||||||
|
latency = time.time() - start_time
|
||||||
|
usage = result.get("usage", {})
|
||||||
|
|
||||||
# Yield text progressively for a nicer streaming feel
|
logger.info(
|
||||||
chunk_size = 4
|
"LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)",
|
||||||
words = text.split(" ")
|
usage.get("total_tokens", 0),
|
||||||
partial = ""
|
latency,
|
||||||
for i, word in enumerate(words):
|
usage.get("prompt_tokens", 0),
|
||||||
partial += ("" if i == 0 else " ") + word
|
usage.get("completion_tokens", 0),
|
||||||
if i % chunk_size == 0 or i == len(words) - 1:
|
)
|
||||||
yield partial
|
|
||||||
|
_log_llm_metrics(
|
||||||
|
latency=latency,
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield text progressively for a nicer feel
|
||||||
|
chunk_size = 4
|
||||||
|
words = text.split(" ")
|
||||||
|
partial = ""
|
||||||
|
for i, word in enumerate(words):
|
||||||
|
partial += ("" if i == 0 else " ") + word
|
||||||
|
if i % chunk_size == 0 or i == len(words) - 1:
|
||||||
|
yield partial
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.exception("LLM request failed")
|
logger.exception("LLM request failed")
|
||||||
@@ -266,7 +332,7 @@ def single_prompt(
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
latency = time.time() - start_time
|
latency = time.time() - start_time
|
||||||
|
|
||||||
text = result["choices"][0]["message"]["content"]
|
text = _extract_content(result["choices"][0]["message"]["content"])
|
||||||
usage = result.get("usage", {})
|
usage = result.get("usage", {})
|
||||||
|
|
||||||
# Log to MLflow
|
# Log to MLflow
|
||||||
@@ -325,7 +391,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
|
||||||
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
|
max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
|
||||||
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
|
|||||||
Reference in New Issue
Block a user