feat: add SSE streaming support to LLM endpoint
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 2m9s
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 2m9s
This commit is contained in:
@@ -3,12 +3,14 @@ Ray Serve deployment for vLLM with OpenAI-compatible API.
|
||||
Runs on: khelben (Strix Halo 64GB, ROCm)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from ray import serve
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
|
||||
@serve.deployment(name="LLMDeployment", num_replicas=1)
|
||||
@@ -120,7 +122,7 @@ class LLMDeployment:
|
||||
except ImportError:
|
||||
self._mlflow = None
|
||||
|
||||
async def __call__(self, request) -> dict[str, Any]:
|
||||
async def __call__(self, request) -> dict[str, Any] | StreamingResponse:
|
||||
"""
|
||||
Handle OpenAI-compatible chat completion requests.
|
||||
|
||||
@@ -151,6 +153,7 @@ class LLMDeployment:
|
||||
)
|
||||
top_p = body.get("top_p", 1.0)
|
||||
stop = body.get("stop")
|
||||
stream = body.get("stream", False)
|
||||
|
||||
# Convert messages to prompt
|
||||
prompt = self._format_messages(messages)
|
||||
@@ -163,6 +166,16 @@ class LLMDeployment:
|
||||
stop_token_ids=self._stop_token_ids,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
self._stream_generate(prompt, sampling_params),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
request_id = uuid.uuid4().hex
|
||||
final_result = None
|
||||
@@ -211,6 +224,60 @@ class LLMDeployment:
|
||||
},
|
||||
}
|
||||
|
||||
async def _stream_generate(self, prompt: str, sampling_params):
|
||||
"""Yield OpenAI-compatible SSE chunks as vLLM generates tokens."""
|
||||
request_id = uuid.uuid4().hex
|
||||
previous_text = ""
|
||||
start_time = time.time()
|
||||
completion_tokens = 0
|
||||
|
||||
async for result in self.engine.generate(prompt, sampling_params, request_id):
|
||||
current_text = result.outputs[0].text
|
||||
delta = current_text[len(previous_text):]
|
||||
if delta:
|
||||
completion_tokens += 1
|
||||
chunk = {
|
||||
"id": f"chatcmpl-{request_id[:8]}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": self.model_id,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": delta},
|
||||
"finish_reason": None,
|
||||
}],
|
||||
}
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
previous_text = current_text
|
||||
|
||||
# Final chunk
|
||||
final_chunk = {
|
||||
"id": f"chatcmpl-{request_id[:8]}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": self.model_id,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
}
|
||||
yield f"data: {json.dumps(final_chunk)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# Log metrics
|
||||
latency = time.time() - start_time
|
||||
if self._mlflow:
|
||||
self._mlflow.log_request(
|
||||
latency_s=latency,
|
||||
prompt_tokens=len(prompt.split()),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=len(prompt.split()) + completion_tokens,
|
||||
tokens_per_second=completion_tokens / latency if latency > 0 else 0,
|
||||
temperature=sampling_params.temperature,
|
||||
max_tokens_requested=sampling_params.max_tokens,
|
||||
)
|
||||
|
||||
def _format_messages(self, messages: list[dict[str, str]]) -> str:
|
||||
"""Format chat messages into a prompt string."""
|
||||
formatted = ""
|
||||
|
||||
Reference in New Issue
Block a user