From 6a391147a610ef40188b0cad3446df0c9700c131 Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Thu, 12 Feb 2026 18:47:50 -0500 Subject: [PATCH] minor: refactoring big changes. --- ray_serve/serve_llm.py | 68 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 9 deletions(-) diff --git a/ray_serve/serve_llm.py b/ray_serve/serve_llm.py index c067860..4ae11a6 100644 --- a/ray_serve/serve_llm.py +++ b/ray_serve/serve_llm.py @@ -14,6 +14,30 @@ from ray import serve @serve.deployment(name="LLMDeployment", num_replicas=1) class LLMDeployment: def __init__(self): + # Workaround: vLLM's rocm.py:verify_quantization unconditionally + # sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is + # outside the if-block). Triton AWQ kernels can't compile for + # gfx1151, so we must keep it at 0. Monkey-patch before engine + # creation to prevent the override. + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + try: + from vllm.platforms.rocm import RocmPlatform + + _orig_verify = RocmPlatform.verify_quantization.__func__ + + @classmethod # type: ignore[misc] + def _patched_verify(cls, quant: str) -> None: + saved = os.environ.get("VLLM_USE_TRITON_AWQ") + _orig_verify(cls, quant) + # Restore our value after vLLM clobbers it + if saved is not None: + os.environ["VLLM_USE_TRITON_AWQ"] = saved + + RocmPlatform.verify_quantization = _patched_verify + print("Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ=0") + except Exception as exc: + print(f"Could not patch RocmPlatform: {exc}") + from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -26,13 +50,33 @@ class LLMDeployment: print(f"Max model length: {self.max_model_len}") print(f"GPU memory utilization: {self.gpu_memory_utilization}") - engine_args = AsyncEngineArgs( - model=self.model_id, - max_model_len=self.max_model_len, - gpu_memory_utilization=self.gpu_memory_utilization, - trust_remote_code=True, - disable_log_stats=True, - ) + self.enable_prefix_caching = os.environ.get("ENABLE_PREFIX_CACHING", "true").lower() == "true" + self.enable_chunked_prefill = os.environ.get("ENABLE_CHUNKED_PREFILL", "true").lower() == "true" + self.num_speculative_tokens = int(os.environ.get("NUM_SPECULATIVE_TOKENS", "0")) + self.ngram_prompt_lookup_max = int(os.environ.get("NGRAM_PROMPT_LOOKUP_MAX", "0")) + + engine_kwargs: dict[str, Any] = { + "model": self.model_id, + "max_model_len": self.max_model_len, + "gpu_memory_utilization": self.gpu_memory_utilization, + "trust_remote_code": True, + "enable_prefix_caching": self.enable_prefix_caching, + "enable_chunked_prefill": self.enable_chunked_prefill, + } + + # n-gram speculative decoding (no draft model needed) + if self.num_speculative_tokens > 0 and self.ngram_prompt_lookup_max > 0: + engine_kwargs["speculative_config"] = { + "method": "ngram", + "num_speculative_tokens": self.num_speculative_tokens, + "ngram_prompt_lookup_max": self.ngram_prompt_lookup_max, + } + print(f"Speculative decoding: ngram n={self.ngram_prompt_lookup_max}, k={self.num_speculative_tokens}") + + print(f"Prefix caching: {self.enable_prefix_caching}") + print(f"Chunked prefill: {self.enable_chunked_prefill}") + + engine_args = AsyncEngineArgs(**engine_kwargs) self.engine = AsyncLLMEngine.from_engine_args(engine_args) self.SamplingParams = SamplingParams print(f"Model {self.model_id} async engine created") @@ -54,6 +98,10 @@ class LLMDeployment: "model_id": self.model_id, "max_model_len": str(self.max_model_len), "gpu_memory_utilization": str(self.gpu_memory_utilization), + "enable_prefix_caching": str(self.enable_prefix_caching), + "enable_chunked_prefill": str(self.enable_chunked_prefill), + "num_speculative_tokens": str(self.num_speculative_tokens), + "ngram_prompt_lookup_max": str(self.ngram_prompt_lookup_max), } ) except ImportError: @@ -97,8 +145,10 @@ class LLMDeployment: generated_text = final_result.outputs[0].text latency = time.time() - start_time - prompt_tokens = len(prompt.split()) - completion_tokens = len(generated_text.split()) + # Use actual output token count from vLLM when available + vllm_output = final_result.outputs[0] + completion_tokens = len(vllm_output.token_ids) if hasattr(vllm_output, "token_ids") else len(generated_text.split()) + prompt_tokens = len(final_result.prompt_token_ids) if hasattr(final_result, "prompt_token_ids") else len(prompt.split()) # Log to MLflow if self._mlflow: