feat: add SSE streaming support to LLM endpoint
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 2m9s

This commit is contained in:
2026-02-20 16:52:08 -05:00
parent a973768aee
commit 59655e3dcf

View File

@@ -3,12 +3,14 @@ Ray Serve deployment for vLLM with OpenAI-compatible API.
Runs on: khelben (Strix Halo 64GB, ROCm) Runs on: khelben (Strix Halo 64GB, ROCm)
""" """
import json
import os import os
import time import time
import uuid import uuid
from typing import Any from typing import Any
from ray import serve from ray import serve
from starlette.responses import StreamingResponse
@serve.deployment(name="LLMDeployment", num_replicas=1) @serve.deployment(name="LLMDeployment", num_replicas=1)
@@ -120,7 +122,7 @@ class LLMDeployment:
except ImportError: except ImportError:
self._mlflow = None self._mlflow = None
async def __call__(self, request) -> dict[str, Any]: async def __call__(self, request) -> dict[str, Any] | StreamingResponse:
""" """
Handle OpenAI-compatible chat completion requests. Handle OpenAI-compatible chat completion requests.
@@ -151,6 +153,7 @@ class LLMDeployment:
) )
top_p = body.get("top_p", 1.0) top_p = body.get("top_p", 1.0)
stop = body.get("stop") stop = body.get("stop")
stream = body.get("stream", False)
# Convert messages to prompt # Convert messages to prompt
prompt = self._format_messages(messages) prompt = self._format_messages(messages)
@@ -163,6 +166,16 @@ class LLMDeployment:
stop_token_ids=self._stop_token_ids, stop_token_ids=self._stop_token_ids,
) )
if stream:
return StreamingResponse(
self._stream_generate(prompt, sampling_params),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)
start_time = time.time() start_time = time.time()
request_id = uuid.uuid4().hex request_id = uuid.uuid4().hex
final_result = None final_result = None
@@ -211,6 +224,60 @@ class LLMDeployment:
}, },
} }
async def _stream_generate(self, prompt: str, sampling_params):
    """Yield OpenAI-compatible SSE chunks as vLLM generates tokens.

    Emits one ``chat.completion.chunk`` event per new text delta, then a
    final chunk with ``finish_reason="stop"`` followed by the OpenAI
    ``[DONE]`` sentinel. Request metrics are logged in a ``finally``
    block so they are recorded even when the client disconnects
    mid-stream (which raises GeneratorExit inside this generator).

    Args:
        prompt: Fully formatted prompt string (built by _format_messages).
        sampling_params: vLLM SamplingParams for this request.

    Yields:
        SSE-framed strings of the form ``data: {json}\\n\\n``.
    """
    request_id = uuid.uuid4().hex
    # Hoisted: the chunk id is constant for the whole stream.
    chunk_id = f"chatcmpl-{request_id[:8]}"
    previous_text = ""
    start_time = time.time()
    completion_tokens = 0
    try:
        async for result in self.engine.generate(prompt, sampling_params, request_id):
            output = result.outputs[0]
            current_text = output.text
            delta = current_text[len(previous_text):]
            if delta:
                # Prefer the engine's actual token count: a single engine
                # step may emit more than one token, so counting deltas
                # undercounts. Fall back to delta counting if token_ids
                # is unavailable — TODO confirm against the vLLM version
                # pinned by this deployment.
                token_ids = getattr(output, "token_ids", None)
                if token_ids is not None:
                    completion_tokens = len(token_ids)
                else:
                    completion_tokens += 1
                chunk = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": self.model_id,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": delta},
                        "finish_reason": None,
                    }],
                }
                yield f"data: {json.dumps(chunk)}\n\n"
            previous_text = current_text
        # Terminal chunk: empty delta with finish_reason set, then the
        # [DONE] sentinel expected by OpenAI streaming clients.
        final_chunk = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop",
            }],
        }
        yield f"data: {json.dumps(final_chunk)}\n\n"
        yield "data: [DONE]\n\n"
    finally:
        # Runs on normal completion AND on client disconnect
        # (GeneratorExit), so metrics are never silently dropped.
        latency = time.time() - start_time
        if self._mlflow:
            # Whitespace split is only an approximation of prompt tokens;
            # computed once instead of twice as before.
            prompt_tokens = len(prompt.split())
            self._mlflow.log_request(
                latency_s=latency,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                tokens_per_second=completion_tokens / latency if latency > 0 else 0,
                temperature=sampling_params.temperature,
                max_tokens_requested=sampling_params.max_tokens,
            )
def _format_messages(self, messages: list[dict[str, str]]) -> str: def _format_messages(self, messages: list[dict[str, str]]) -> str:
"""Format chat messages into a prompt string.""" """Format chat messages into a prompt string."""
formatted = "" formatted = ""