feat: add SSE streaming support to LLM endpoint
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 2m9s
This commit is contained in:
@@ -3,12 +3,14 @@ Ray Serve deployment for vLLM with OpenAI-compatible API.
|
|||||||
Runs on: khelben (Strix Halo 64GB, ROCm)
|
Runs on: khelben (Strix Halo 64GB, ROCm)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ray import serve
|
from ray import serve
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
@serve.deployment(name="LLMDeployment", num_replicas=1)
|
@serve.deployment(name="LLMDeployment", num_replicas=1)
|
||||||
@@ -120,7 +122,7 @@ class LLMDeployment:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
self._mlflow = None
|
self._mlflow = None
|
||||||
|
|
||||||
async def __call__(self, request) -> dict[str, Any]:
|
async def __call__(self, request) -> dict[str, Any] | StreamingResponse:
|
||||||
"""
|
"""
|
||||||
Handle OpenAI-compatible chat completion requests.
|
Handle OpenAI-compatible chat completion requests.
|
||||||
|
|
||||||
@@ -151,6 +153,7 @@ class LLMDeployment:
|
|||||||
)
|
)
|
||||||
top_p = body.get("top_p", 1.0)
|
top_p = body.get("top_p", 1.0)
|
||||||
stop = body.get("stop")
|
stop = body.get("stop")
|
||||||
|
stream = body.get("stream", False)
|
||||||
|
|
||||||
# Convert messages to prompt
|
# Convert messages to prompt
|
||||||
prompt = self._format_messages(messages)
|
prompt = self._format_messages(messages)
|
||||||
@@ -163,6 +166,16 @@ class LLMDeployment:
|
|||||||
stop_token_ids=self._stop_token_ids,
|
stop_token_ids=self._stop_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
self._stream_generate(prompt, sampling_params),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
request_id = uuid.uuid4().hex
|
request_id = uuid.uuid4().hex
|
||||||
final_result = None
|
final_result = None
|
||||||
@@ -211,6 +224,60 @@ class LLMDeployment:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _stream_generate(self, prompt: str, sampling_params):
    """Yield OpenAI-compatible SSE chunks as vLLM generates tokens.

    Emits ``chat.completion.chunk`` events as ``data: {...}\n\n`` frames,
    then a final chunk carrying the finish reason, then the ``data: [DONE]``
    sentinel. After the stream completes, logs request metrics via
    ``self._mlflow`` when it is available.

    Args:
        prompt: Fully formatted prompt string for the engine.
        sampling_params: vLLM ``SamplingParams`` for this request
            (``temperature`` / ``max_tokens`` are read for metrics logging).

    Yields:
        str: SSE-framed JSON chunk lines.
    """
    request_id = uuid.uuid4().hex
    previous_text = ""
    start_time = time.time()
    # Keep the last CompletionOutput so we can report the exact token count
    # and the real finish reason after the loop.
    last_output = None

    async for result in self.engine.generate(prompt, sampling_params, request_id):
        last_output = result.outputs[0]
        current_text = last_output.text
        # vLLM yields cumulative text; forward only the newly generated suffix.
        delta = current_text[len(previous_text):]
        if delta:
            chunk = {
                "id": f"chatcmpl-{request_id[:8]}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": self.model_id,
                "choices": [{
                    "index": 0,
                    "delta": {"content": delta},
                    "finish_reason": None,
                }],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            previous_text = current_text

    # Report the true finish reason ("length" when max_tokens was hit)
    # rather than hard-coding "stop".
    finish_reason = "stop"
    if last_output is not None and getattr(last_output, "finish_reason", None):
        finish_reason = last_output.finish_reason

    # Exact completion token count from the engine output; counting one per
    # emitted delta undercounts when a step yields multiple tokens.
    completion_tokens = len(last_output.token_ids) if last_output is not None else 0

    final_chunk = {
        "id": f"chatcmpl-{request_id[:8]}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": self.model_id,
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": finish_reason,
        }],
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
    yield "data: [DONE]\n\n"

    # Log metrics (best effort; skipped when MLflow was not importable).
    latency = time.time() - start_time
    if self._mlflow:
        # NOTE: prompt token count is a whitespace-split approximation,
        # matching the non-streaming path.
        prompt_tokens = len(prompt.split())
        self._mlflow.log_request(
            latency_s=latency,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            tokens_per_second=completion_tokens / latency if latency > 0 else 0,
            temperature=sampling_params.temperature,
            max_tokens_requested=sampling_params.max_tokens,
        )
|
||||||
|
|
||||||
def _format_messages(self, messages: list[dict[str, str]]) -> str:
|
def _format_messages(self, messages: list[dict[str, str]]) -> str:
|
||||||
"""Format chat messages into a prompt string."""
|
"""Format chat messages into a prompt string."""
|
||||||
formatted = ""
|
formatted = ""
|
||||||
|
|||||||
Reference in New Issue
Block a user