fix: respect VLLM_USE_TRITON_AWQ from runtime_env instead of hardcoding 0
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 14s
All checks were successful
Build and Publish ray-serve-apps / build-and-publish (push) Successful in 14s
The previous code unconditionally set VLLM_USE_TRITON_AWQ=0, overriding
the value from the RayService runtime_env env_vars. On gfx1151:
- Triton AWQ kernels work (TRITON_AWQ=1)
- C++ awq_dequantize op does NOT exist (TRITON_AWQ=0 → crash)
Changed to os.environ.setdefault('VLLM_USE_TRITON_AWQ', '1') so the
operator-configured value is preserved, defaulting to Triton AWQ.
This commit is contained in:
@@ -16,10 +16,10 @@ class LLMDeployment:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Workaround: vLLM's rocm.py:verify_quantization unconditionally
|
# Workaround: vLLM's rocm.py:verify_quantization unconditionally
|
||||||
# sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is
|
# sets VLLM_USE_TRITON_AWQ=1 (a bug — the os.environ line is
|
||||||
# outside the if-block). Triton AWQ kernels can't compile for
|
# outside the if-block). Monkey-patch before engine creation to
|
||||||
# gfx1151, so we must keep it at 0. Monkey-patch before engine
|
# preserve whatever the operator set in runtime_env env_vars.
|
||||||
# creation to prevent the override.
|
# On gfx1151 Triton AWQ works; the C++ awq_dequantize does NOT.
|
||||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
os.environ.setdefault("VLLM_USE_TRITON_AWQ", "1")
|
||||||
try:
|
try:
|
||||||
from vllm.platforms.rocm import RocmPlatform
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
|
|
||||||
@@ -34,7 +34,8 @@ class LLMDeployment:
|
|||||||
os.environ["VLLM_USE_TRITON_AWQ"] = saved
|
os.environ["VLLM_USE_TRITON_AWQ"] = saved
|
||||||
|
|
||||||
RocmPlatform.verify_quantization = _patched_verify
|
RocmPlatform.verify_quantization = _patched_verify
|
||||||
print("Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ=0")
|
triton_val = os.environ.get("VLLM_USE_TRITON_AWQ", "?")
|
||||||
|
print(f"Patched RocmPlatform.verify_quantization to preserve VLLM_USE_TRITON_AWQ={triton_val}")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print(f"Could not patch RocmPlatform: {exc}")
|
print(f"Could not patch RocmPlatform: {exc}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user