feat: add SSE streaming support to LLM endpoint
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 2m9s

This commit is contained in:
2026-02-20 16:52:08 -05:00
parent a973768aee
commit 59655e3dcf

View File

@@ -3,12 +3,14 @@ Ray Serve deployment for vLLM with OpenAI-compatible API.
Runs on: khelben (Strix Halo 64GB, ROCm)
"""
import json
import os
import time
import uuid
from typing import Any
from ray import serve
from starlette.responses import StreamingResponse
@serve.deployment(name="LLMDeployment", num_replicas=1)
@@ -120,7 +122,7 @@ class LLMDeployment:
except ImportError:
self._mlflow = None
async def __call__(self, request) -> dict[str, Any]:
async def __call__(self, request) -> dict[str, Any] | StreamingResponse:
"""
Handle OpenAI-compatible chat completion requests.
@@ -151,6 +153,7 @@ class LLMDeployment:
)
top_p = body.get("top_p", 1.0)
stop = body.get("stop")
stream = body.get("stream", False)
# Convert messages to prompt
prompt = self._format_messages(messages)
@@ -163,6 +166,16 @@ class LLMDeployment:
stop_token_ids=self._stop_token_ids,
)
if stream:
return StreamingResponse(
self._stream_generate(prompt, sampling_params),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
start_time = time.time()
request_id = uuid.uuid4().hex
final_result = None
@@ -211,6 +224,60 @@ class LLMDeployment:
},
}
async def _stream_generate(self, prompt: str, sampling_params):
"""Yield OpenAI-compatible SSE chunks as vLLM generates tokens."""
request_id = uuid.uuid4().hex
previous_text = ""
start_time = time.time()
completion_tokens = 0
async for result in self.engine.generate(prompt, sampling_params, request_id):
current_text = result.outputs[0].text
delta = current_text[len(previous_text):]
if delta:
completion_tokens += 1
chunk = {
"id": f"chatcmpl-{request_id[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model_id,
"choices": [{
"index": 0,
"delta": {"content": delta},
"finish_reason": None,
}],
}
yield f"data: {json.dumps(chunk)}\n\n"
previous_text = current_text
# Final chunk
final_chunk = {
"id": f"chatcmpl-{request_id[:8]}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model_id,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop",
}],
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
# Log metrics
latency = time.time() - start_time
if self._mlflow:
self._mlflow.log_request(
latency_s=latency,
prompt_tokens=len(prompt.split()),
completion_tokens=completion_tokens,
total_tokens=len(prompt.split()) + completion_tokens,
tokens_per_second=completion_tokens / latency if latency > 0 else 0,
temperature=sampling_params.temperature,
max_tokens_requested=sampling_params.max_tokens,
)
def _format_messages(self, messages: list[dict[str, str]]) -> str:
"""Format chat messages into a prompt string."""
formatted = ""