diff --git a/kustomization.yaml b/kustomization.yaml index 2bc42eb..0ee3ec6 100644 --- a/kustomization.yaml +++ b/kustomization.yaml @@ -8,3 +8,8 @@ resources: - llm.yaml - tts.yaml - stt.yaml + +images: + - name: gitea-http.gitea.svc.cluster.local:3000/daviestechlabs/gradio-ui + newName: registry.lab.daviestechlabs.io/daviestechlabs/gradio-ui + newTag: "0.0.7" diff --git a/llm.py b/llm.py index 6c210f2..39d68b9 100644 --- a/llm.py +++ b/llm.py @@ -3,13 +3,14 @@ LLM Chat Demo - Gradio UI for testing vLLM inference service. Features: -- Multi-turn chat with streaming responses +- Multi-turn chat with true SSE streaming responses - Configurable temperature, max tokens, top-p - System prompt customisation - Token usage and latency metrics - Chat history management """ +import json import os import time import logging @@ -127,6 +128,27 @@ async_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0)) sync_client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) +def _extract_content(content) -> str: + """Extract plain text from message content. + + Handles both plain strings and Gradio 6.x content-parts format: + [{"type": "text", "text": "..."}] or [{"text": "..."}] + """ + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + parts.append(item.get("text", item.get("content", str(item)))) + elif isinstance(item, str): + parts.append(item) + else: + parts.append(str(item)) + return "".join(parts) + return str(content) + + async def chat_stream( message: str, history: list[dict[str, str]], @@ -135,18 +157,21 @@ async def chat_stream( max_tokens: int, top_p: float, ): - """Stream chat responses from the vLLM endpoint.""" + """Stream chat responses from the vLLM endpoint via SSE.""" if not message.strip(): yield "" return - # Build message list from history + # Build message list from history, normalising content-parts messages = [] if system_prompt.strip(): messages.append({"role": "system", "content": system_prompt}) for entry in history: - messages.append({"role": entry["role"], "content": entry["content"]}) + messages.append({ + "role": entry["role"], + "content": _extract_content(entry["content"]), + }) messages.append({"role": "user", "content": message}) @@ -155,45 +180,86 @@ async def chat_stream( "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, + "stream": True, } start_time = time.time() try: - response = await async_client.post(LLM_URL, json=payload) - response.raise_for_status() + # Try true SSE streaming first + async with async_client.stream("POST", LLM_URL, json=payload) as response: + response.raise_for_status() + content_type = response.headers.get("content-type", "") - result = response.json() - text = result["choices"][0]["message"]["content"] - latency = time.time() - start_time - usage = result.get("usage", {}) + if "text/event-stream" in content_type: + # SSE streaming — accumulate deltas + full_text = "" + async for line in response.aiter_lines(): + if not line.startswith("data: "): + continue + data = line[6:] + if data.strip() == "[DONE]": + break + try: + chunk = json.loads(data) + delta = ( + chunk.get("choices", [{}])[0] + .get("delta", {}) + .get("content", "") + ) + if delta: + full_text += delta + yield full_text + except json.JSONDecodeError: + continue - logger.info( - "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)", - usage.get("total_tokens", 0), - latency, - usage.get("prompt_tokens", 0), - usage.get("completion_tokens", 0), - ) + latency = time.time() - start_time + logger.info("LLM streamed response: %d chars in %.1fs", len(full_text), latency) - # Log to MLflow - _log_llm_metrics( - latency=latency, - prompt_tokens=usage.get("prompt_tokens", 0), - completion_tokens=usage.get("completion_tokens", 0), - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - ) + # Best-effort metrics from the final SSE payload + _log_llm_metrics( + latency=latency, + prompt_tokens=0, + completion_tokens=len(full_text.split()), + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + ) + else: + # Non-streaming fallback (endpoint doesn't support stream) + body = await response.aread() + result = json.loads(body) + text = _extract_content( + result["choices"][0]["message"]["content"] + ) + latency = time.time() - start_time + usage = result.get("usage", {}) - # Yield text progressively for a nicer streaming feel - chunk_size = 4 - words = text.split(" ") - partial = "" - for i, word in enumerate(words): - partial += ("" if i == 0 else " ") + word - if i % chunk_size == 0 or i == len(words) - 1: - yield partial + logger.info( + "LLM response: %d tokens in %.1fs (prompt=%d, completion=%d)", + usage.get("total_tokens", 0), + latency, + usage.get("prompt_tokens", 0), + usage.get("completion_tokens", 0), + ) + + _log_llm_metrics( + latency=latency, + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + ) + + # Yield text progressively for a nicer feel + chunk_size = 4 + words = text.split(" ") + partial = "" + for i, word in enumerate(words): + partial += ("" if i == 0 else " ") + word + if i % chunk_size == 0 or i == len(words) - 1: + yield partial except httpx.HTTPStatusError as e: logger.exception("LLM request failed") @@ -266,7 +332,7 @@ def single_prompt( result = response.json() latency = time.time() - start_time - text = result["choices"][0]["message"]["content"] + text = _extract_content(result["choices"][0]["message"]["content"]) usage = result.get("usage", {}) # Log to MLflow @@ -325,7 +391,7 @@ Chat with **Llama 3.1 70B** (AWQ INT4) served via vLLM on AMD Strix Halo (ROCm). ) with gr.Row(): temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature") - max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max Tokens") + max_tokens = gr.Slider(16, 8192, value=2048, step=16, label="Max Tokens") top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p") with gr.Tabs():