All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 2m9s
298 lines
12 KiB
Python
298 lines
12 KiB
Python
"""
|
|
Ray Serve deployment for vLLM with OpenAI-compatible API.
|
|
Runs on: khelben (Strix Halo 64GB, ROCm)
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from ray import serve
|
|
from starlette.responses import StreamingResponse
|
|
|
|
|
|
@serve.deployment(name="LLMDeployment", num_replicas=1)
class LLMDeployment:
    """OpenAI-compatible chat-completions deployment backed by a vLLM async engine.

    All configuration is read from environment variables in ``__init__``:
    MODEL_ID, MAX_MODEL_LEN, GPU_MEMORY_UTILIZATION, ENABLE_PREFIX_CACHING,
    ENABLE_CHUNKED_PREFILL, NUM_SPECULATIVE_TOKENS, NGRAM_PROMPT_LOOKUP_MAX,
    ENFORCE_EAGER and DEFAULT_MAX_TOKENS.

    Prompts are rendered with a hard-coded Llama 3 chat format and Llama 3
    stop token IDs. NOTE(review): a non-Llama-3 MODEL_ID will likely need a
    different prompt format and stop tokens.

    Helpers are always reached through ``self`` (never via the class name),
    because ``@serve.deployment`` rebinds the module-level name
    ``LLMDeployment`` to a Ray Serve deployment object.
    """

    def __init__(self):
        # Must run before vLLM creates the engine so the patched ROCm
        # quantization check is the one that gets called.
        self._patch_rocm_triton_awq()

        from vllm import SamplingParams
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

        self.model_id = os.environ.get("MODEL_ID", "meta-llama/Llama-3.1-70B-Instruct")
        self.max_model_len = int(os.environ.get("MAX_MODEL_LEN", "8192"))
        self.gpu_memory_utilization = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))

        print(f"Loading vLLM model: {self.model_id}")
        print(f"Max model length: {self.max_model_len}")
        print(f"GPU memory utilization: {self.gpu_memory_utilization}")

        self.enable_prefix_caching = self._env_flag("ENABLE_PREFIX_CACHING", "true")
        self.enable_chunked_prefill = self._env_flag("ENABLE_CHUNKED_PREFILL", "true")
        self.num_speculative_tokens = int(os.environ.get("NUM_SPECULATIVE_TOKENS", "0"))
        self.ngram_prompt_lookup_max = int(os.environ.get("NGRAM_PROMPT_LOOKUP_MAX", "0"))
        self.enforce_eager = self._env_flag("ENFORCE_EAGER", "false")

        engine_kwargs: dict[str, Any] = {
            "model": self.model_id,
            "max_model_len": self.max_model_len,
            "gpu_memory_utilization": self.gpu_memory_utilization,
            "trust_remote_code": True,
            "enable_prefix_caching": self.enable_prefix_caching,
            "enable_chunked_prefill": self.enable_chunked_prefill,
            "enforce_eager": self.enforce_eager,
        }

        # n-gram speculative decoding (prompt lookup, no draft model needed).
        # vLLM's n-gram SpeculativeConfig field is "prompt_lookup_max"; the
        # "ngram_" prefix survives only in the env var and the MLflow param name.
        # NOTE(review): confirm the field name against the pinned vLLM version.
        if self.num_speculative_tokens > 0 and self.ngram_prompt_lookup_max > 0:
            engine_kwargs["speculative_config"] = {
                "method": "ngram",
                "num_speculative_tokens": self.num_speculative_tokens,
                "prompt_lookup_max": self.ngram_prompt_lookup_max,
            }
            print(f"Speculative decoding: ngram n={self.ngram_prompt_lookup_max}, k={self.num_speculative_tokens}")
        elif self.num_speculative_tokens > 0 or self.ngram_prompt_lookup_max > 0:
            # Half-configured speculation is silently ignored otherwise.
            print(
                "Speculative decoding disabled: NUM_SPECULATIVE_TOKENS and "
                "NGRAM_PROMPT_LOOKUP_MAX must both be > 0"
            )

        print(f"Prefix caching: {self.enable_prefix_caching}")
        print(f"Chunked prefill: {self.enable_chunked_prefill}")
        print(f"Enforce eager (no torch.compile): {self.enforce_eager}")

        engine_args = AsyncEngineArgs(**engine_kwargs)
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.SamplingParams = SamplingParams
        self.default_max_tokens = int(os.environ.get("DEFAULT_MAX_TOKENS", "256"))
        # Llama 3 generation-ending token IDs, passed explicitly on every request
        # so generation ends at the end of the assistant turn.
        self._stop_token_ids = [
            128001,  # <|end_of_text|>
            128009,  # <|eot_id|>
        ]
        print(f"Model {self.model_id} async engine created")
        print(f"Default max tokens: {self.default_max_tokens}")
        print(f"Stop token IDs: {self._stop_token_ids}")

        self._mlflow = self._init_mlflow()

    @staticmethod
    def _patch_rocm_triton_awq() -> None:
        """Preserve the operator's VLLM_USE_TRITON_AWQ across vLLM's ROCm check.

        Workaround: vLLM's rocm.py:verify_quantization unconditionally sets
        VLLM_USE_TRITON_AWQ=1 (a bug: the os.environ line is outside the
        if-block). The patch is applied before engine creation so that whatever
        the operator set in runtime_env env_vars survives. On gfx1151 Triton AWQ
        works; the C++ awq_dequantize does NOT.

        Best-effort: any failure to patch is printed and ignored.
        """
        os.environ.setdefault("VLLM_USE_TRITON_AWQ", "1")
        try:
            from vllm.platforms.rocm import RocmPlatform

            orig_verify = RocmPlatform.verify_quantization.__func__

            @classmethod  # type: ignore[misc]
            def _patched_verify(cls, quant: str) -> None:
                saved = os.environ.get("VLLM_USE_TRITON_AWQ")
                try:
                    orig_verify(cls, quant)
                finally:
                    # Restore our value after vLLM clobbers it (also on error).
                    if saved is not None:
                        os.environ["VLLM_USE_TRITON_AWQ"] = saved

            RocmPlatform.verify_quantization = _patched_verify
            triton_val = os.environ.get("VLLM_USE_TRITON_AWQ", "?")
            print(f"Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ={triton_val}")
        except Exception as exc:
            print(f"Could not patch RocmPlatform: {exc}")

    def _init_mlflow(self):
        """Create and initialize the MLflow InferenceLogger, or return None.

        The import is local to avoid cloudpickle serializing a module reference
        that fails on the worker (strixhalo uses py_executable, which bypasses
        the pip runtime_env). Metrics are best-effort: a missing package returns
        None silently, and a failing initialization is printed and returns None,
        instead of taking the replica down.
        """
        try:
            from ray_serve.mlflow_logger import InferenceLogger
        except ImportError:
            return None

        try:
            mlflow_logger = InferenceLogger(
                experiment_name="ray-serve-llm",
                run_name=f"llm-{self.model_id.split('/')[-1]}",
                tags={"model.name": self.model_id, "model.framework": "vllm", "gpu": "strixhalo"},
                flush_every=5,
            )
            mlflow_logger.initialize(
                params={
                    "model_id": self.model_id,
                    "max_model_len": str(self.max_model_len),
                    "gpu_memory_utilization": str(self.gpu_memory_utilization),
                    "enable_prefix_caching": str(self.enable_prefix_caching),
                    "enable_chunked_prefill": str(self.enable_chunked_prefill),
                    "num_speculative_tokens": str(self.num_speculative_tokens),
                    "ngram_prompt_lookup_max": str(self.ngram_prompt_lookup_max),
                    "enforce_eager": str(self.enforce_eager),
                }
            )
        except Exception as exc:
            print(f"MLflow logging disabled: {exc}")
            return None
        return mlflow_logger

    async def __call__(self, request) -> dict[str, Any] | StreamingResponse:
        """
        Handle OpenAI-compatible chat completion requests.

        Expected request format:
        {
            "model": "model-name",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "max_tokens": 256,
            "top_p": 1.0,
            "stream": false
        }

        ``max_completion_tokens`` is accepted and takes precedence over
        ``max_tokens``. The completion budget is capped at MAX_MODEL_LEN. A JSON
        ``null`` for temperature, top_p or the token budget falls back to the
        default rather than being passed to SamplingParams.

        Returns a ``chat.completion`` dict, or a Server-Sent Events
        StreamingResponse when ``"stream"`` is truthy.
        """
        # Ray Serve passes a Starlette Request for HTTP calls —
        # parse the JSON body so we actually read the user's payload.
        from starlette.requests import Request

        body = await request.json() if isinstance(request, Request) else request

        messages = body.get("messages") or []
        temperature = self._param(body, "temperature", 0.7)
        top_p = self._param(body, "top_p", 1.0)
        requested_max = self._param(
            body,
            "max_completion_tokens",
            self._param(body, "max_tokens", self.default_max_tokens),
        )
        max_tokens = min(requested_max, self.max_model_len)
        stop = body.get("stop")
        stream = body.get("stream", False)

        prompt = self._format_messages(messages)

        sampling_params = self.SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop,
            stop_token_ids=self._stop_token_ids,
        )

        if stream:
            return StreamingResponse(
                self._stream_generate(prompt, sampling_params),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "X-Accel-Buffering": "no",
                },
            )

        request_id = uuid.uuid4().hex
        start = time.perf_counter()
        final_result = None
        # Drain the engine's async generator; only the last output is kept.
        async for result in self.engine.generate(prompt, sampling_params, request_id):
            final_result = result
        latency = time.perf_counter() - start

        if final_result is None or not final_result.outputs:
            raise RuntimeError(f"vLLM returned no output for request {request_id}")

        output = final_result.outputs[0]
        generated_text = output.text
        # Prefer vLLM's real token counts; fall back to a word count.
        completion_tokens = self._count_tokens(getattr(output, "token_ids", None), generated_text)
        prompt_tokens = self._count_tokens(getattr(final_result, "prompt_token_ids", None), prompt)
        # Report why generation ended (e.g. "length" when max_tokens was hit).
        finish_reason = getattr(output, "finish_reason", None) or "stop"

        self._log_metrics(latency, prompt_tokens, completion_tokens, temperature, max_tokens)

        # Return OpenAI-compatible response
        return {
            "id": f"chatcmpl-{request_id[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text,
                    },
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    async def _stream_generate(self, prompt: str, sampling_params):
        """Yield OpenAI-compatible SSE chunks as vLLM generates tokens.

        Each content chunk carries only the text appended since the previous
        engine output. This treats the engine's output text as cumulative,
        which is what the slicing below relies on. After the engine finishes,
        the method sends a final empty-delta chunk carrying the engine's
        finish_reason, then the ``data: [DONE]`` sentinel. It then logs
        metrics using vLLM's token counts rather than the number of chunks
        sent.
        """
        request_id = uuid.uuid4().hex
        chunk_id = f"chatcmpl-{request_id[:8]}"
        start = time.perf_counter()
        previous_text = ""
        last_result = None

        async for result in self.engine.generate(prompt, sampling_params, request_id):
            last_result = result
            current_text = result.outputs[0].text
            delta = current_text[len(previous_text):]
            if delta:
                yield self._sse_chunk(chunk_id, {"content": delta}, None)
                previous_text = current_text

        output = last_result.outputs[0] if last_result is not None and last_result.outputs else None
        finish_reason = getattr(output, "finish_reason", None) or "stop"

        # Final chunk
        yield self._sse_chunk(chunk_id, {}, finish_reason)
        yield "data: [DONE]\n\n"

        # Log metrics
        latency = time.perf_counter() - start
        completion_tokens = self._count_tokens(getattr(output, "token_ids", None), previous_text)
        prompt_tokens = self._count_tokens(getattr(last_result, "prompt_token_ids", None), prompt)
        self._log_metrics(
            latency,
            prompt_tokens,
            completion_tokens,
            sampling_params.temperature,
            sampling_params.max_tokens,
        )

    def _sse_chunk(self, chunk_id: str, delta: dict[str, Any], finish_reason: str | None) -> str:
        """Serialize one ``chat.completion.chunk`` as an SSE ``data:`` event."""
        chunk = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(chunk)}\n\n"

    def _log_metrics(
        self,
        latency: float,
        prompt_tokens: int,
        completion_tokens: int,
        temperature: Any,
        max_tokens: Any,
    ) -> None:
        """Record one request in MLflow when enabled.

        A logging failure is printed instead of failing a request whose
        completion has already been produced.
        """
        if not self._mlflow:
            return
        try:
            self._mlflow.log_request(
                latency_s=latency,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                tokens_per_second=completion_tokens / latency if latency > 0 else 0,
                temperature=temperature,
                max_tokens_requested=max_tokens,
            )
        except Exception as exc:
            print(f"MLflow log_request failed: {exc}")

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Parse boolean env var *name*; "1"/"true"/"yes"/"on" (any case) mean True."""
        return os.environ.get(name, default).strip().lower() in {"1", "true", "yes", "on"}

    @staticmethod
    def _param(body: dict[str, Any], key: str, default: Any) -> Any:
        """Return ``body[key]``, or *default* when the key is missing or JSON null."""
        value = body.get(key)
        return default if value is None else value

    @staticmethod
    def _count_tokens(token_ids: Any, fallback_text: str) -> int:
        """Return ``len(token_ids)``, or a whitespace word count of *fallback_text* if IDs are None."""
        if token_ids is not None:
            return len(token_ids)
        return len(fallback_text.split())

    @staticmethod
    def _content_text(content: Any) -> str:
        """Flatten an OpenAI message ``content`` value to plain text.

        ``None`` becomes "". A list of content parts keeps only the
        ``{"type": "text", "text": ...}`` parts, joined by newlines. Anything
        else is converted with ``str()``.
        """
        if content is None:
            return ""
        if isinstance(content, list):
            return "\n".join(
                str(part.get("text", ""))
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return str(content)

    def _format_messages(self, messages: list[dict[str, Any]]) -> str:
        """Render chat messages into a Llama 3 chat prompt string.

        Only the roles system, user and assistant are rendered; other roles are
        skipped. The prompt ends with an open assistant header so the model
        writes the next assistant turn.

        No ``<|begin_of_text|>`` is emitted here. This assumes the tokenizer
        prepends BOS when vLLM encodes a text prompt, which is true for the HF
        Llama 3 tokenizers; confirm for other MODEL_IDs. Emitting it as well
        would give a duplicate BOS, or a mid-prompt BOS when the system message
        is not first.
        """
        parts = []
        for msg in messages:
            role = msg.get("role", "user")
            if role not in ("system", "user", "assistant"):
                continue
            content = self._content_text(msg.get("content"))
            parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>")
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
        return "".join(parts)
|
|
|
|
|
|
# Ray Serve application built from the deployment above. It is bound with no
# constructor arguments because __init__ reads all configuration from
# environment variables. Presumably this is the target of the Serve config's
# import_path ("<module>:app") — confirm against the deploy config.
app = LLMDeployment.bind()
|