Some checks failed
Build and Push Images / build-nvidia (push) Failing after 7m25s
Build and Push Images / build-rdna2 (push) Failing after 7m29s
Build and Push Images / build-strixhalo (push) Failing after 6m45s
Build and Push Images / build-intel (push) Failing after 6m22s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
Build and Publish ray-serve-apps / lint (push) Failing after 3m9s
Build and Publish ray-serve-apps / publish (push) Has been skipped
- Restructure ray-serve as proper Python package (ray_serve/) - Add pyproject.toml with hatch build system - Add CI workflow to publish to Gitea PyPI - Add py.typed for PEP 561 compliance - Aligns with ADR-0019 handler deployment strategy
109 lines
3.7 KiB
Python
109 lines
3.7 KiB
Python
"""
|
|
Ray Serve deployment for vLLM with OpenAI-compatible API.
|
|
Runs on: khelben (Strix Halo 64GB, ROCm)
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from ray import serve
|
|
|
|
|
|
@serve.deployment(name="LLMDeployment", num_replicas=1)
class LLMDeployment:
    """Serve a vLLM engine behind an OpenAI-compatible chat-completion API.

    Configuration comes from environment variables:
        MODEL_ID                (default "meta-llama/Llama-3.1-70B-Instruct")
        MAX_MODEL_LEN           (default 8192)
        GPU_MEMORY_UTILIZATION  (default 0.9)
    """

    def __init__(self):
        # Imported lazily so this module stays importable on hosts without
        # vLLM installed (e.g. while Serve builds the deployment graph).
        from vllm import LLM, SamplingParams

        self.model_id = os.environ.get("MODEL_ID", "meta-llama/Llama-3.1-70B-Instruct")
        self.max_model_len = int(os.environ.get("MAX_MODEL_LEN", "8192"))
        self.gpu_memory_utilization = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))

        print(f"Loading vLLM model: {self.model_id}")
        print(f"Max model length: {self.max_model_len}")
        print(f"GPU memory utilization: {self.gpu_memory_utilization}")

        self.llm = LLM(
            model=self.model_id,
            max_model_len=self.max_model_len,
            gpu_memory_utilization=self.gpu_memory_utilization,
            trust_remote_code=True,
        )
        # Kept as an attribute so __call__ can build sampling params without
        # re-importing vllm on every request.
        self.SamplingParams = SamplingParams
        print(f"Model {self.model_id} loaded successfully")

    async def __call__(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an OpenAI-compatible chat completion request.

        Expected request format:
            {
                "model": "model-name",
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0.7,
                "max_tokens": 256,
                "top_p": 1.0,
                "stream": false
            }

        Note: "stream" is accepted but ignored — the response is always a
        single non-streaming completion object.
        """
        messages = request.get("messages", [])
        temperature = request.get("temperature", 0.7)
        max_tokens = request.get("max_tokens", 256)
        top_p = request.get("top_p", 1.0)
        stop = request.get("stop", None)

        # Convert the chat history into a single Llama-3-formatted prompt.
        prompt = self._format_messages(messages)

        sampling_params = self.SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop,
        )

        # NOTE(review): LLM.generate() is synchronous — it blocks this
        # replica's asyncio event loop for the whole generation. If per-replica
        # request concurrency matters, consider vLLM's AsyncLLMEngine.
        outputs = self.llm.generate([prompt], sampling_params)
        result = outputs[0]
        completion = result.outputs[0]
        generated_text = completion.text

        # Use the engine's real token counts rather than whitespace word
        # counts, which undercount for subword tokenizers.
        prompt_tokens = len(result.prompt_token_ids)
        completion_tokens = len(completion.token_ids)

        # Return an OpenAI-compatible response payload.
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": self.model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text,
                    },
                    # Report the engine's actual finish reason ("stop",
                    # "length", ...) instead of hard-coding "stop".
                    "finish_reason": completion.finish_reason or "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a Llama-3 prompt string.

        Fix: the original emitted <|begin_of_text|> only inside the
        system-message branch, so conversations without a system message
        produced a prompt missing the begin-of-text marker. It is now
        emitted exactly once, unconditionally, at the start of the prompt.
        Messages with unrecognized roles are skipped, as before.
        """
        formatted = "<|begin_of_text|>"
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role in ("system", "user", "assistant"):
                formatted += (
                    f"<|start_header_id|>{role}<|end_header_id|>"
                    f"\n\n{content}<|eot_id|>"
                )
        # Open the assistant turn so the model generates the reply.
        formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return formatted
|
|
|
|
|
|
# Ray Serve application graph entrypoint; deploy with `serve run <module>:app`.
app = LLMDeployment.bind()
|