fixing vllm.
Some checks failed
Build and Push Images / determine-version (push) Successful in 5s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 39s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 42s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 20s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 23s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
Some checks failed
Build and Push Images / determine-version (push) Successful in 5s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 39s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 42s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 20s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 23s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
This commit is contained in:
@@ -1,22 +1,28 @@
|
|||||||
"""
|
"""
|
||||||
Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory.
|
Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
|
||||||
|
|
||||||
On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns
|
On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
|
||||||
GTT (system RAM accessible to GPU) numbers instead of actual VRAM:
|
accessible to GPU) instead of actual VRAM:
|
||||||
- Total: 128 GiB (GTT pool, NOT real VRAM)
|
|
||||||
- Free: 29 GiB (GTT free, NOT VRAM free)
|
torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total) WRONG
|
||||||
|
torch.cuda.get_device_properties().total_memory → 128 GiB WRONG
|
||||||
|
|
||||||
Meanwhile sysfs reports the real VRAM:
|
Meanwhile sysfs reports the real VRAM:
|
||||||
- Total: 96 GiB (actual dedicated VRAM from BIOS)
|
/sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
|
||||||
- Free: 95.8 GiB (actual VRAM free)
|
/sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
|
||||||
|
|
||||||
vLLM uses torch.cuda.mem_get_info() to calculate how much memory to
|
vLLM uses mem_get_info() and get_device_properties() to decide how much
|
||||||
pre-allocate. With wrong numbers it either OOMs or refuses to start.
|
memory to pre-allocate. With wrong numbers it either OOMs or refuses to
|
||||||
|
start ("Free memory less than desired GPU memory utilization").
|
||||||
|
|
||||||
This module patches mem_get_info to return sysfs VRAM values, which
|
This module patches both APIs to return sysfs VRAM values instead.
|
||||||
are the physically meaningful numbers for allocation decisions.
|
|
||||||
|
|
||||||
Installed as a .pth hook so it runs before any user code.
|
Installed as a .pth hook so it runs before any user code.
|
||||||
|
|
||||||
|
IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is
|
||||||
|
set only during torch import to prevent infinite recursion when torch
|
||||||
|
spawns offload-arch. It is CLEARED afterward so child processes (e.g.
|
||||||
|
vLLM EngineCore subprocesses) can apply their own patch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -63,6 +69,8 @@ def _should_skip() -> bool:
|
|||||||
# Re-entry guard: importing torch triggers subprocess calls to
|
# Re-entry guard: importing torch triggers subprocess calls to
|
||||||
# offload-arch (a Python script), which re-enters this .pth hook.
|
# offload-arch (a Python script), which re-enters this .pth hook.
|
||||||
# Without this guard it creates an infinite fork bomb.
|
# Without this guard it creates an infinite fork bomb.
|
||||||
|
# NOTE: This is only set transiently during _apply_patch() and
|
||||||
|
# cleared afterward — processes spawned after that point (e.g. vLLM workers) will NOT see it.
|
||||||
if os.environ.get(_GUARD_ENV):
|
if os.environ.get(_GUARD_ENV):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -83,8 +91,31 @@ def _should_skip() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _VRAMDeviceProperties:
|
||||||
|
"""Proxy that overrides total_memory on torch device properties."""
|
||||||
|
|
||||||
|
def __init__(self, original, vram_total: int):
|
||||||
|
object.__setattr__(self, "_original", original)
|
||||||
|
object.__setattr__(self, "_vram_total", vram_total)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_memory(self) -> int:
|
||||||
|
return object.__getattribute__(self, "_vram_total")
|
||||||
|
|
||||||
|
def __getattr__(self, name: str):
|
||||||
|
return getattr(object.__getattribute__(self, "_original"), name)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
orig = object.__getattribute__(self, "_original")
|
||||||
|
vram = object.__getattribute__(self, "_vram_total")
|
||||||
|
return repr(orig).replace(
|
||||||
|
f"total_memory={orig.total_memory}",
|
||||||
|
f"total_memory={vram}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _apply_patch() -> None:
|
def _apply_patch() -> None:
|
||||||
"""Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
|
"""Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
|
||||||
if _should_skip():
|
if _should_skip():
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -101,13 +132,18 @@ def _apply_patch() -> None:
|
|||||||
return
|
return
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
# CRITICAL: Clear the guard so child processes (vLLM EngineCore
|
||||||
|
# subprocesses, Ray actor workers, etc.) can apply their own patch.
|
||||||
|
# The guard only needs to live during the torch import above to
|
||||||
|
# prevent the offload-arch → .pth → torch import recursion.
|
||||||
|
os.environ.pop(_GUARD_ENV, None)
|
||||||
|
|
||||||
vram_info = _get_real_vram()
|
vram_info = _get_real_vram()
|
||||||
if vram_info is None:
|
if vram_info is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
vram_total, vram_used = vram_info
|
vram_total, vram_used = vram_info
|
||||||
vram_free = vram_total - vram_used
|
|
||||||
|
|
||||||
# Only patch if PyTorch total differs significantly from sysfs VRAM
|
# Only patch if PyTorch total differs significantly from sysfs VRAM
|
||||||
# (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
|
# (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
|
||||||
@@ -120,13 +156,14 @@ def _apply_patch() -> None:
|
|||||||
if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
|
if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
|
||||||
return
|
return
|
||||||
|
|
||||||
original_fn = torch.cuda.mem_get_info
|
# --- Patch 1: torch.cuda.mem_get_info ---
|
||||||
|
original_mem_get_info = torch.cuda.mem_get_info
|
||||||
|
|
||||||
def _patched_mem_get_info(device=None):
|
def _patched_mem_get_info(device=None):
|
||||||
"""Return real VRAM from sysfs instead of GTT numbers."""
|
"""Return real VRAM from sysfs instead of GTT numbers."""
|
||||||
real = _get_real_vram()
|
real = _get_real_vram()
|
||||||
if real is None:
|
if real is None:
|
||||||
return original_fn(device)
|
return original_mem_get_info(device)
|
||||||
total, used = real
|
total, used = real
|
||||||
# Account for PyTorch's own allocations on top of sysfs baseline
|
# Account for PyTorch's own allocations on top of sysfs baseline
|
||||||
pt_allocated = torch.cuda.memory_allocated(device or 0)
|
pt_allocated = torch.cuda.memory_allocated(device or 0)
|
||||||
@@ -134,9 +171,25 @@ def _apply_patch() -> None:
|
|||||||
return (max(free, 0), total)
|
return (max(free, 0), total)
|
||||||
|
|
||||||
torch.cuda.mem_get_info = _patched_mem_get_info
|
torch.cuda.mem_get_info = _patched_mem_get_info
|
||||||
|
|
||||||
|
# --- Patch 2: torch.cuda.get_device_properties ---
|
||||||
|
# total_memory is a read-only C property, so we wrap the return value
|
||||||
|
# in a proxy that overrides it with the real VRAM total.
|
||||||
|
original_get_device_properties = torch.cuda.get_device_properties
|
||||||
|
|
||||||
|
def _patched_get_device_properties(device=None):
|
||||||
|
props = original_get_device_properties(device)
|
||||||
|
real = _get_real_vram()
|
||||||
|
if real is None:
|
||||||
|
return props
|
||||||
|
return _VRAMDeviceProperties(props, real[0])
|
||||||
|
|
||||||
|
torch.cuda.get_device_properties = _patched_get_device_properties
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"strixhalo_vram_fix: patched torch.cuda.mem_get_info "
|
"strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
|
||||||
"(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
|
"get_device_properties (PyTorch reported %d GiB total, "
|
||||||
|
"sysfs VRAM is %d GiB)",
|
||||||
pt_total // (1024**3),
|
pt_total // (1024**3),
|
||||||
vram_total // (1024**3),
|
vram_total // (1024**3),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user