diff --git a/ray_serve/serve_llm.py b/ray_serve/serve_llm.py
index c14c65f..80fc9a9 100644
--- a/ray_serve/serve_llm.py
+++ b/ray_serve/serve_llm.py
@@ -14,7 +14,9 @@ from ray import serve
 @serve.deployment(name="LLMDeployment", num_replicas=1)
 class LLMDeployment:
     def __init__(self):
-        from vllm import LLM, SamplingParams
+        from vllm import SamplingParams
+        from vllm.engine.arg_utils import AsyncEngineArgs
+        from vllm.engine.async_llm_engine import AsyncLLMEngine
 
         self.model_id = os.environ.get("MODEL_ID", "meta-llama/Llama-3.1-70B-Instruct")
         self.max_model_len = int(os.environ.get("MAX_MODEL_LEN", "8192"))
@@ -24,14 +26,16 @@ class LLMDeployment:
         print(f"Max model length: {self.max_model_len}")
         print(f"GPU memory utilization: {self.gpu_memory_utilization}")
 
-        self.llm = LLM(
+        engine_args = AsyncEngineArgs(
             model=self.model_id,
             max_model_len=self.max_model_len,
             gpu_memory_utilization=self.gpu_memory_utilization,
             trust_remote_code=True,
+            disable_log_stats=True,
         )
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.SamplingParams = SamplingParams
-        print(f"Model {self.model_id} loaded successfully")
+        print(f"Model {self.model_id} async engine created")
 
     async def __call__(self, request: dict[str, Any]) -> dict[str, Any]:
         """
@@ -63,8 +67,11 @@ class LLMDeployment:
             stop=stop,
         )
 
-        outputs = self.llm.generate([prompt], sampling_params)
-        generated_text = outputs[0].outputs[0].text
+        request_id = uuid.uuid4().hex
+        final_result = None
+        async for result in self.engine.generate(prompt, sampling_params, request_id):
+            final_result = result
+        generated_text = final_result.outputs[0].text
 
         # Return OpenAI-compatible response
         return {
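
For reference, a minimal standalone sketch of the async generation pattern the new `__call__` body relies on. It assumes a vLLM version that exposes `AsyncLLMEngine`/`AsyncEngineArgs`, and that `import uuid` already exists at the top of serve_llm.py (the diff uses `uuid.uuid4()` but does not add the import). The model id here is illustrative, not the one the deployment serves.

    import asyncio
    import uuid

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    async def main() -> None:
        # Small illustrative model; the deployment uses MODEL_ID from the environment.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m", disable_log_stats=True)
        )
        params = SamplingParams(max_tokens=64, temperature=0.7)

        # engine.generate() is an async generator; each yielded RequestOutput
        # carries the cumulative text so far, so the last result holds the
        # complete completion, mirroring the loop in __call__.
        final = None
        async for result in engine.generate("Hello, my name is", params, uuid.uuid4().hex):
            final = result
        print(final.outputs[0].text)


    asyncio.run(main())

Iterating to the last yielded result reproduces the blocking behavior of the old `LLM.generate()` call while keeping the replica's event loop free; streaming responses could instead forward each intermediate result as it arrives.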