fix: apply ruff fixes to ray_serve package
[ray-serve only]
- Fix whitespace in docstrings
- Add strict=True to zip() calls
- Use ternary operators where appropriate
- Rename unused loop variables
This commit is contained in:
@@ -6,7 +6,7 @@ Runs on: khelben (Strix Halo 64GB, ROCm)
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from ray import serve
|
||||
|
||||
@@ -15,15 +15,15 @@ from ray import serve
|
||||
class LLMDeployment:
|
||||
def __init__(self):
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
self.model_id = os.environ.get("MODEL_ID", "meta-llama/Llama-3.1-70B-Instruct")
|
||||
self.max_model_len = int(os.environ.get("MAX_MODEL_LEN", "8192"))
|
||||
self.gpu_memory_utilization = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))
|
||||
|
||||
|
||||
print(f"Loading vLLM model: {self.model_id}")
|
||||
print(f"Max model length: {self.max_model_len}")
|
||||
print(f"GPU memory utilization: {self.gpu_memory_utilization}")
|
||||
|
||||
|
||||
self.llm = LLM(
|
||||
model=self.model_id,
|
||||
max_model_len=self.max_model_len,
|
||||
@@ -33,10 +33,10 @@ class LLMDeployment:
|
||||
self.SamplingParams = SamplingParams
|
||||
print(f"Model {self.model_id} loaded successfully")
|
||||
|
||||
async def __call__(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Handle OpenAI-compatible chat completion requests.
|
||||
|
||||
|
||||
Expected request format:
|
||||
{
|
||||
"model": "model-name",
|
||||
@@ -51,21 +51,21 @@ class LLMDeployment:
|
||||
temperature = request.get("temperature", 0.7)
|
||||
max_tokens = request.get("max_tokens", 256)
|
||||
top_p = request.get("top_p", 1.0)
|
||||
stop = request.get("stop", None)
|
||||
|
||||
stop = request.get("stop")
|
||||
|
||||
# Convert messages to prompt
|
||||
prompt = self._format_messages(messages)
|
||||
|
||||
|
||||
sampling_params = self.SamplingParams(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
|
||||
outputs = self.llm.generate([prompt], sampling_params)
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
|
||||
|
||||
# Return OpenAI-compatible response
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
@@ -89,7 +89,7 @@ class LLMDeployment:
|
||||
},
|
||||
}
|
||||
|
||||
def _format_messages(self, messages: List[Dict[str, str]]) -> str:
|
||||
def _format_messages(self, messages: list[dict[str, str]]) -> str:
|
||||
"""Format chat messages into a prompt string."""
|
||||
formatted = ""
|
||||
for msg in messages:
|
||||
|
||||
Reference in New Issue
Block a user