Fixing vLLM.
Some checks failed
Build and Push Images / determine-version (push) Successful in 5s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 39s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 42s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 20s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 23s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
This commit is contained in:
@@ -1,22 +1,28 @@
|
||||
"""
|
||||
Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory.
|
||||
Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
|
||||
|
||||
On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns
|
||||
GTT (system RAM accessible to GPU) numbers instead of actual VRAM:
|
||||
- Total: 128 GiB (GTT pool, NOT real VRAM)
|
||||
- Free: 29 GiB (GTT free, NOT VRAM free)
|
||||
On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
|
||||
accessible to GPU) instead of actual VRAM:
|
||||
|
||||
torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total) WRONG
|
||||
torch.cuda.get_device_properties().total_memory → 128 GiB WRONG
|
||||
|
||||
Meanwhile sysfs reports the real VRAM:
|
||||
- Total: 96 GiB (actual dedicated VRAM from BIOS)
|
||||
- Free: 95.8 GiB (actual VRAM free)
|
||||
/sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
|
||||
/sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
|
||||
|
||||
vLLM uses torch.cuda.mem_get_info() to calculate how much memory to
|
||||
pre-allocate. With wrong numbers it either OOMs or refuses to start.
|
||||
vLLM uses mem_get_info() and get_device_properties() to decide how much
|
||||
memory to pre-allocate. With wrong numbers it either OOMs or refuses to
|
||||
start ("Free memory less than desired GPU memory utilization").
|
||||
|
||||
This module patches mem_get_info to return sysfs VRAM values, which
|
||||
are the physically meaningful numbers for allocation decisions.
|
||||
This module patches both APIs to return sysfs VRAM values instead.
|
||||
|
||||
Installed as a .pth hook so it runs before any user code.
|
||||
|
||||
IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is
|
||||
set only during torch import to prevent infinite recursion when torch
|
||||
spawns offload-arch. It is CLEARED afterward so child processes (e.g.
|
||||
vLLM EngineCore subprocesses) can apply their own patch.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -63,6 +69,8 @@ def _should_skip() -> bool:
|
||||
# Re-entry guard: importing torch triggers subprocess calls to
|
||||
# offload-arch (a Python script), which re-enters this .pth hook.
|
||||
# Without this guard it creates an infinite fork bomb.
|
||||
# NOTE: This is only set transiently during _apply_patch() and
|
||||
# cleared afterward — child processes will NOT see it.
|
||||
if os.environ.get(_GUARD_ENV):
|
||||
return True
|
||||
|
||||
@@ -83,8 +91,31 @@ def _should_skip() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class _VRAMDeviceProperties:
|
||||
"""Proxy that overrides total_memory on torch device properties."""
|
||||
|
||||
def __init__(self, original, vram_total: int):
|
||||
object.__setattr__(self, "_original", original)
|
||||
object.__setattr__(self, "_vram_total", vram_total)
|
||||
|
||||
@property
|
||||
def total_memory(self) -> int:
|
||||
return object.__getattribute__(self, "_vram_total")
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
return getattr(object.__getattribute__(self, "_original"), name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
orig = object.__getattribute__(self, "_original")
|
||||
vram = object.__getattribute__(self, "_vram_total")
|
||||
return repr(orig).replace(
|
||||
f"total_memory={orig.total_memory}",
|
||||
f"total_memory={vram}",
|
||||
)
|
||||
|
||||
|
||||
def _apply_patch() -> None:
|
||||
"""Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
|
||||
"""Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
|
||||
if _should_skip():
|
||||
return
|
||||
|
||||
@@ -101,13 +132,18 @@ def _apply_patch() -> None:
|
||||
return
|
||||
except ImportError:
|
||||
return
|
||||
finally:
|
||||
# CRITICAL: Clear the guard so child processes (vLLM EngineCore
|
||||
# subprocesses, Ray actor workers, etc.) can apply their own patch.
|
||||
# The guard only needs to live during the torch import above to
|
||||
# prevent the offload-arch → .pth → torch import recursion.
|
||||
os.environ.pop(_GUARD_ENV, None)
|
||||
|
||||
vram_info = _get_real_vram()
|
||||
if vram_info is None:
|
||||
return
|
||||
|
||||
vram_total, vram_used = vram_info
|
||||
vram_free = vram_total - vram_used
|
||||
|
||||
# Only patch if PyTorch total differs significantly from sysfs VRAM
|
||||
# (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
|
||||
@@ -120,13 +156,14 @@ def _apply_patch() -> None:
|
||||
if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
|
||||
return
|
||||
|
||||
original_fn = torch.cuda.mem_get_info
|
||||
# --- Patch 1: torch.cuda.mem_get_info ---
|
||||
original_mem_get_info = torch.cuda.mem_get_info
|
||||
|
||||
def _patched_mem_get_info(device=None):
|
||||
"""Return real VRAM from sysfs instead of GTT numbers."""
|
||||
real = _get_real_vram()
|
||||
if real is None:
|
||||
return original_fn(device)
|
||||
return original_mem_get_info(device)
|
||||
total, used = real
|
||||
# Account for PyTorch's own allocations on top of sysfs baseline
|
||||
pt_allocated = torch.cuda.memory_allocated(device or 0)
|
||||
@@ -134,9 +171,25 @@ def _apply_patch() -> None:
|
||||
return (max(free, 0), total)
|
||||
|
||||
torch.cuda.mem_get_info = _patched_mem_get_info
|
||||
|
||||
# --- Patch 2: torch.cuda.get_device_properties ---
|
||||
# total_memory is a read-only C property, so we wrap the return value
|
||||
# in a proxy that overrides it with the real VRAM total.
|
||||
original_get_device_properties = torch.cuda.get_device_properties
|
||||
|
||||
def _patched_get_device_properties(device=None):
|
||||
props = original_get_device_properties(device)
|
||||
real = _get_real_vram()
|
||||
if real is None:
|
||||
return props
|
||||
return _VRAMDeviceProperties(props, real[0])
|
||||
|
||||
torch.cuda.get_device_properties = _patched_get_device_properties
|
||||
|
||||
logger.info(
|
||||
"strixhalo_vram_fix: patched torch.cuda.mem_get_info "
|
||||
"(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
|
||||
"strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
|
||||
"get_device_properties (PyTorch reported %d GiB total, "
|
||||
"sysfs VRAM is %d GiB)",
|
||||
pt_total // (1024**3),
|
||||
vram_total // (1024**3),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user