minor: big refactoring changes.
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 12s

This commit is contained in:
2026-02-12 18:47:50 -05:00
parent 297b0d8ebd
commit 6a391147a6

View File

@@ -14,6 +14,30 @@ from ray import serve
@serve.deployment(name="LLMDeployment", num_replicas=1)
class LLMDeployment:
def __init__(self):
# Workaround: vLLM's rocm.py:verify_quantization unconditionally
# sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is
# outside the if-block). Triton AWQ kernels can't compile for
# gfx1151, so we must keep it at 0. Monkey-patch before engine
# creation to prevent the override.
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
try:
from vllm.platforms.rocm import RocmPlatform
_orig_verify = RocmPlatform.verify_quantization.__func__
@classmethod # type: ignore[misc]
def _patched_verify(cls, quant: str) -> None:
saved = os.environ.get("VLLM_USE_TRITON_AWQ")
_orig_verify(cls, quant)
# Restore our value after vLLM clobbers it
if saved is not None:
os.environ["VLLM_USE_TRITON_AWQ"] = saved
RocmPlatform.verify_quantization = _patched_verify
print("Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ=0")
except Exception as exc:
print(f"Could not patch RocmPlatform: {exc}")
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -26,13 +50,33 @@ class LLMDeployment:
print(f"Max model length: {self.max_model_len}")
print(f"GPU memory utilization: {self.gpu_memory_utilization}")
self.enable_prefix_caching = os.environ.get("ENABLE_PREFIX_CACHING", "true").lower() == "true"
self.enable_chunked_prefill = os.environ.get("ENABLE_CHUNKED_PREFILL", "true").lower() == "true"
self.num_speculative_tokens = int(os.environ.get("NUM_SPECULATIVE_TOKENS", "0"))
self.ngram_prompt_lookup_max = int(os.environ.get("NGRAM_PROMPT_LOOKUP_MAX", "0"))

engine_kwargs: dict[str, Any] = {
"model": self.model_id,
"max_model_len": self.max_model_len,
"gpu_memory_utilization": self.gpu_memory_utilization,
"trust_remote_code": True,
"enable_prefix_caching": self.enable_prefix_caching,
"enable_chunked_prefill": self.enable_chunked_prefill,
}
# n-gram speculative decoding (no draft model needed)
if self.num_speculative_tokens > 0 and self.ngram_prompt_lookup_max > 0:
engine_kwargs["speculative_config"] = {
"method": "ngram",
"num_speculative_tokens": self.num_speculative_tokens,
"ngram_prompt_lookup_max": self.ngram_prompt_lookup_max,
}
print(f"Speculative decoding: ngram n={self.ngram_prompt_lookup_max}, k={self.num_speculative_tokens}")
print(f"Prefix caching: {self.enable_prefix_caching}")
print(f"Chunked prefill: {self.enable_chunked_prefill}")
engine_args = AsyncEngineArgs(**engine_kwargs)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
self.SamplingParams = SamplingParams
print(f"Model {self.model_id} async engine created")
@@ -54,6 +98,10 @@ class LLMDeployment:
"model_id": self.model_id,
"max_model_len": str(self.max_model_len),
"gpu_memory_utilization": str(self.gpu_memory_utilization),
"enable_prefix_caching": str(self.enable_prefix_caching),
"enable_chunked_prefill": str(self.enable_chunked_prefill),
"num_speculative_tokens": str(self.num_speculative_tokens),
"ngram_prompt_lookup_max": str(self.ngram_prompt_lookup_max),
}
)
except ImportError:
@@ -97,8 +145,10 @@ class LLMDeployment:
generated_text = final_result.outputs[0].text
latency = time.time() - start_time
# Use actual output token count from vLLM when available
vllm_output = final_result.outputs[0]
completion_tokens = len(vllm_output.token_ids) if hasattr(vllm_output, "token_ids") else len(generated_text.split())
prompt_tokens = len(final_result.prompt_token_ids) if hasattr(final_result, "prompt_token_ids") else len(prompt.split())
# Log to MLflow
if self._mlflow: