From 59655e3dcfa243fd4956905ddb69b36bb9cd95a4 Mon Sep 17 00:00:00 2001
From: "Billy D."
Date: Fri, 20 Feb 2026 16:52:08 -0500
Subject: [PATCH] feat: add SSE streaming support to LLM endpoint

---
 ray_serve/serve_llm.py | 69 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 68 insertions(+), 1 deletion(-)

diff --git a/ray_serve/serve_llm.py b/ray_serve/serve_llm.py
index bc27a20..b0b105b 100644
--- a/ray_serve/serve_llm.py
+++ b/ray_serve/serve_llm.py
@@ -3,12 +3,14 @@
 Ray Serve deployment for vLLM with OpenAI-compatible API.
 Runs on: khelben (Strix Halo 64GB, ROCm)
 """
+import json
 import os
 import time
 import uuid
 from typing import Any
 
 from ray import serve
+from starlette.responses import StreamingResponse
 
 
 @serve.deployment(name="LLMDeployment", num_replicas=1)
@@ -120,7 +122,7 @@ class LLMDeployment:
         except ImportError:
             self._mlflow = None
 
-    async def __call__(self, request) -> dict[str, Any]:
+    async def __call__(self, request) -> dict[str, Any] | StreamingResponse:
         """
         Handle OpenAI-compatible chat completion requests.
 
@@ -151,6 +153,7 @@ class LLMDeployment:
         )
         top_p = body.get("top_p", 1.0)
         stop = body.get("stop")
+        stream = body.get("stream", False)
 
         # Convert messages to prompt
         prompt = self._format_messages(messages)
@@ -163,6 +166,16 @@ class LLMDeployment:
             stop_token_ids=self._stop_token_ids,
         )
 
+        if stream:
+            return StreamingResponse(
+                self._stream_generate(prompt, sampling_params),
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "X-Accel-Buffering": "no",
+                },
+            )
+
         start_time = time.time()
         request_id = uuid.uuid4().hex
         final_result = None
@@ -211,6 +224,60 @@ class LLMDeployment:
             },
         }
 
+    async def _stream_generate(self, prompt: str, sampling_params):
+        """Yield OpenAI-compatible SSE chunks as vLLM generates tokens."""
+        request_id = uuid.uuid4().hex
+        previous_text = ""
+        start_time = time.time()
+        completion_tokens = 0
+
+        async for result in self.engine.generate(prompt, sampling_params, request_id):
+            current_text = result.outputs[0].text
+            delta = current_text[len(previous_text):]
+            if delta:
+                completion_tokens += 1
+                chunk = {
+                    "id": f"chatcmpl-{request_id[:8]}",
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": self.model_id,
+                    "choices": [{
+                        "index": 0,
+                        "delta": {"content": delta},
+                        "finish_reason": None,
+                    }],
+                }
+                yield f"data: {json.dumps(chunk)}\n\n"
+            previous_text = current_text
+
+        # Final chunk
+        final_chunk = {
+            "id": f"chatcmpl-{request_id[:8]}",
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "model": self.model_id,
+            "choices": [{
+                "index": 0,
+                "delta": {},
+                "finish_reason": "stop",
+            }],
+        }
+        yield f"data: {json.dumps(final_chunk)}\n\n"
+        yield "data: [DONE]\n\n"
+
+        # Log metrics
+        latency = time.time() - start_time
+        if self._mlflow:
+            self._mlflow.log_request(
+                latency_s=latency,
+                prompt_tokens=len(prompt.split()),
+                completion_tokens=completion_tokens,
+                total_tokens=len(prompt.split()) + completion_tokens,
+                tokens_per_second=completion_tokens / latency if latency > 0 else 0,
+                temperature=sampling_params.temperature,
+                max_tokens_requested=sampling_params.max_tokens,
+            )
+
     def _format_messages(self, messages: list[dict[str, str]]) -> str:
         """Format chat messages into a prompt string."""
         formatted = ""
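
Note (not part of the patch): a minimal client-side sketch for exercising the new streaming path. It assumes the deployment is reachable at Ray Serve's default HTTP address (http://127.0.0.1:8000/) and that the route accepts the same JSON body as the non-streaming path; the URL and the prompt content are placeholders, not taken from this diff.

    import json
    import requests  # third-party HTTP client, used here for its line-streaming support

    # Assumed endpoint; adjust to wherever LLMDeployment is actually bound.
    url = "http://127.0.0.1:8000/"
    payload = {
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": True,
    }

    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines():
            if not raw:
                continue  # skip the blank lines that separate SSE events
            line = raw.decode("utf-8")
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":  # terminator emitted by _stream_generate
                break
            chunk = json.loads(data)
            delta = chunk["choices"][0]["delta"].get("content", "")
            print(delta, end="", flush=True)
    print()

Each "data:" line carries an OpenAI-style chat.completion.chunk object, so OpenAI-compatible clients that understand SSE deltas should be able to consume this stream as well.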