LLM chat: switch to true SSE streaming responses; bump gradio-ui image tag to 0.0.7 and raise max-token limits.
Some checks failed
CI / Lint (push) Failing after 1m35s
CI / Release (push) Has been skipped
CI / Docker Build & Push (push) Has been skipped
CI / Deploy to Kubernetes (push) Has been skipped
CI / Notify (push) Successful in 1s

This commit is contained in:
2026-02-20 16:53:37 -05:00
parent c050d11ab4
commit f5a2545ac8
2 changed files with 107 additions and 36 deletions

View File

@@ -8,3 +8,8 @@ resources:
- llm.yaml
- tts.yaml
- stt.yaml
images:
- name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui
newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui
newTag: "0.0.7"

88
llm.py
View File

@@ -3,13 +3,14 @@
LLM Chat Demo - Gradio UI for testing vLLM inference service.
Features:
- Multi-turn chat with streaming responses
- Multi-turn chat with true SSE streaming responses
- Configurable temperature, max tokens, top-p
- System prompt customisation
- Token usage and latency metrics
- Chat history management
"""
import json
import os
import time
import logging
@@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))
def _extract_content(content) -> str:
"""Extract plain text from message content.
Handles both plain strings and Gradio 6.x content-parts format:
[{"type": "text", "text": "..."}] or [{"text": "..."}]
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
parts.append(item.get("text", item.get("content", str(item))))
elif isinstance(item, str):
parts.append(item)
else:
parts.append(str(item))
return "".join(parts)
return str(content)
async def chat_stream(
message: str,
history: list[dict[str, str]],
@@ -135,18 +157,21 @@ async def chat_stream(
max_tokens: int,
top_p: float,
):
"""Stream chat responses from the vLLM endpoint."""
"""Stream chat responses from the vLLM endpoint via SSE."""
if not message.strip():
yield ""
return
# Build message list from history
# Build message list from history, normalising content-parts
messages = []
if system_prompt.strip():
messages.append({"role": "system", "content": system_prompt})
for entry in history:
messages.append({"role": entry["role"], "content": entry["content"]})
messages.append({
"role": entry["role"],
"content": _extract_content(entry["content"]),
})
messages.append({"role": "user", "content": message})
@@ -155,16 +180,58 @@ async def chat_stream(
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stream": True,
}
start_time = time.time()
try:
response = await async_client.post(LLM_URL, json=payload)
# Try true SSE streaming first
async with async_client.stream("POST", LLM_URL, json=payload) as response:
response.raise_for_status()
content_type = response.headers.get("content-type", "")
result = response.json()
text = result["choices"][0]["message"]["content"]
if "text/event-stream" in content_type:
# SSE streaming — accumulate deltas
full_text = ""
async for line in response.aiter_lines():
if not line.startswith("data: "):
continue
data = line[6:]
if data.strip() == "[DONE]":
break
try:
chunk = json.loads(data)
delta = (
chunk.get("choices", [{}])[0]
.get("delta", {})
.get("content", "")
)
if delta:
full_text += delta
yield full_text
except json.JSONDecodeError:
continue
latency = time.time() - start_time
logger.info("LLM streamed response: %d chars in %.1fs", len(full_text), latency)
# Best-effort metrics from the final SSE payload
_log_llm_metrics(
latency=latency,
prompt_tokens=0,
completion_tokens=len(full_text.split()),
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
else:
# Non-streaming fallback (endpoint doesn't support stream)
body = await response.aread()
result = json.loads(body)
text = _extract_content(
result["choices"][0]["message"]["content"]
)
latency = time.time() - start_time
usage = result.get("usage", {})
@@ -176,7 +243,6 @@ async def chat_stream(
usage.get("completion_tokens", 0),
)
# Log to MLflow
_log_llm_metrics(
latency=latency,
prompt_tokens=usage.get("prompt_tokens", 0),
@@ -186,7 +252,7 @@ async def chat_stream(
top_p=top_p,
)
# Yield text progressively for a nicer streaming feel
# Yield text progressively for a nicer feel
chunk_size = 4
words = text.split(" ")
partial = ""
@@ -266,7 +332,7 @@ def single_prompt(
result = response.json()
latency = time.time() - start_time
text = result["choices"][0]["message"]["content"]
text = _extract_content(result["choices"][0]["message"]["content"])
usage = result.get("usage", {})
# Log to MLflow
@@ -325,7 +391,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm).
)
with gr.Row():
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens")
max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens")
top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
with gr.Tabs():