v1.0.26: dynamic VRAM via GTT for 32GB carve-out
Some checks failed
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Has been cancelled
Build and Push Images / Release (push) Has been cancelled
Build and Push Images / Notify (push) Has been cancelled
Build and Push Images / determine-version (push) Has been cancelled
strixhalo_vram_fix.py: compute effective VRAM as min(GTT_pool, physical_RAM) - 4 GiB OS reserve instead of raw sysfs VRAM. Prevents vLLM refusing to start when the carve-out is smaller than the model, and prevents the kernel OOM killer when GTT exceeds physical RAM.
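To make the formula concrete, here is the arithmetic as a small Python sketch. The numbers are illustrative only: the 32 GiB carve-out from the title, plus an assumed 128 GiB GTT pool and 96 GiB of physical RAM.

    GIB = 1024**3
    gtt_total  = 128 * GIB  # sysfs mem_info_gtt_total
    phys_ram   = 96 * GIB   # /proc/meminfo MemTotal
    carve_out  = 32 * GIB   # sysfs mem_info_vram_total
    os_reserve = 4 * GIB

    effective = max(min(gtt_total, phys_ram) - os_reserve, carve_out)
    print(effective // GIB)  # 92: capped by physical RAM, not by the GTT pool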
strixhalo_vram_fix.py
@@ -1,21 +1,27 @@
 """
 Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.

-On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
-accessible to GPU) instead of actual VRAM:
+On Strix Halo (and other AMD APUs with dynamic VRAM), the BIOS carve-out
+for VRAM may be smaller than what the GPU can actually use. With the
+``amdgpu.gttsize`` kernel parameter the GPU can dynamically claim host
+RAM through the GTT (Graphics Translation Table) pool, so the *effective*
+VRAM is much larger than the static carve-out reported by sysfs.

-    torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total)   WRONG
-    torch.cuda.get_device_properties().total_memory → 128 GiB  WRONG
+Example with BIOS at 32 GB, gttsize=131072 (128 GiB):

-Meanwhile sysfs reports the real VRAM:
-    /sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
-    /sys/class/drm/cardN/device/mem_info_vram_used  → ~0.2 GiB
+    sysfs mem_info_vram_total → 32 GiB  (static carve-out)
+    torch.cuda.mem_get_info() total → 128 GiB  (GTT pool — correct!)

-vLLM uses mem_get_info() and get_device_properties() to decide how much
-memory to pre-allocate. With wrong numbers it either OOMs or refuses to
-start ("Free memory less than desired GPU memory utilization").
+However in some configurations PyTorch may report only the static
+carve-out, and vLLM will refuse to load a model that exceeds it. In
+other configurations PyTorch may report the *full* GTT pool, which can
+be larger than actual physical RAM and cause the kernel OOM killer to
+fire during allocation.

-This module patches both APIs to return sysfs VRAM values instead.
+This module patches ``torch.cuda.mem_get_info()`` and
+``torch.cuda.get_device_properties()`` so that the reported total equals
+the *effective* GPU memory: the smaller of the GTT pool size and the
+total physical RAM, minus a safety reserve for the OS.

 Installed as a .pth hook so it runs before any user code.
 """
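A note on the last docstring line: Python's site module executes any line beginning with "import" in a site-packages .pth file at interpreter startup, which is what lets this patch run before user code. A minimal sketch of such a hook, with a hypothetical file and module name (the actual names are not shown in this diff):

    # site-packages/zz_strixhalo_vram_fix.pth (one line, executed by site.py)
    import strixhalo_vram_fix

Whether the module calls _apply_patch() at import time or the .pth line invokes it explicitly is not visible in these hunks.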
@@ -40,8 +46,31 @@ def _read_sysfs_int(path: str) -> int | None:
     return None


-def _get_real_vram() -> tuple[int, int] | None:
-    """Read real VRAM total/used from sysfs for the first AMD GPU."""
+def _get_total_physical_ram() -> int | None:
+    """Read total physical RAM from /proc/meminfo (in bytes)."""
+    try:
+        with open("/proc/meminfo") as f:
+            for line in f:
+                if line.startswith("MemTotal:"):
+                    # MemTotal:       98765432 kB
+                    return int(line.split()[1]) * 1024
+    except (OSError, ValueError):
+        pass
+    return None
+
+
+def _get_gtt_size() -> int | None:
+    """Read the GTT (Graphics Translation Table) pool size from sysfs."""
+    for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
+        gtt_path = os.path.join(card_dir, "mem_info_gtt_total")
+        val = _read_sysfs_int(gtt_path)
+        if val is not None:
+            return val
+    return None
+
+
+def _get_sysfs_vram() -> tuple[int, int] | None:
+    """Read sysfs VRAM total/used for the first AMD GPU."""
     for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
         vendor_path = os.path.join(card_dir, "vendor")
         if not os.path.exists(vendor_path):
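The three helpers above read standard amdgpu sysfs nodes plus /proc/meminfo. A quick way to inspect the same inputs by hand (the card index varies per machine; card1 here is an assumption):

    # Values are plain byte counts; MemTotal is in kB.
    print(open("/sys/class/drm/card1/device/mem_info_gtt_total").read())
    print(open("/sys/class/drm/card1/device/mem_info_vram_total").read())
    with open("/proc/meminfo") as f:
        print(next(line for line in f if line.startswith("MemTotal:")))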
@@ -61,6 +90,50 @@ def _get_real_vram() -> tuple[int, int] | None:
     return None


+# Safety reserve for OS + kubelet + system pods (4 GiB)
+_OS_RESERVE_BYTES = 4 * 1024**3
+
+
+def _get_effective_vram() -> tuple[int, int] | None:
+    """Compute the effective GPU memory for a unified-memory APU.
+
+    On APUs with dynamic VRAM (GTT), the effective VRAM is:
+
+        min(gtt_size, physical_ram) - os_reserve
+
+    This prevents vLLM from trying to use more memory than physically
+    exists (which would trigger the kernel OOM killer), while still
+    allowing it to use more than the small BIOS carve-out.
+
+    Returns (total, used) in bytes, or None if detection fails.
+    """
+    sysfs = _get_sysfs_vram()
+    if sysfs is None:
+        return None
+
+    sysfs_total, sysfs_used = sysfs
+    gtt_total = _get_gtt_size()
+    phys_ram = _get_total_physical_ram()
+
+    if gtt_total is None or phys_ram is None:
+        # Can't compute effective VRAM — fall back to sysfs
+        return sysfs
+
+    effective_total = min(gtt_total, phys_ram) - _OS_RESERVE_BYTES
+    effective_total = max(effective_total, sysfs_total)  # never below carve-out
+
+    logger.debug(
+        "strixhalo_vram_fix: sysfs_total=%d GiB, gtt=%d GiB, "
+        "phys_ram=%d GiB, effective=%d GiB",
+        sysfs_total // (1024**3),
+        gtt_total // (1024**3),
+        phys_ram // (1024**3),
+        effective_total // (1024**3),
+    )
+
+    return (effective_total, sysfs_used)
+
+
 _GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"

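A smoke test for the new helper might look like the following sketch, assuming the module imports as strixhalo_vram_fix (the installed module name is not shown in the diff):

    import strixhalo_vram_fix as fix

    info = fix._get_effective_vram()
    if info is not None:
        total, used = info
        print(f"effective: {total / 1024**3:.1f} GiB total, {used / 1024**3:.1f} GiB used")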
@@ -115,11 +188,11 @@ class _VRAMDeviceProperties:


 def _apply_patch() -> None:
-    """Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
+    """Patch torch.cuda memory APIs for correct unified-memory reporting."""
     if _should_skip():
         return

-    if _get_real_vram() is None:
+    if _get_sysfs_vram() is None:
         return

     # Set guard BEFORE importing torch — torch init spawns offload-arch
@@ -139,32 +212,31 @@ def _apply_patch() -> None:
     # prevent the offload-arch → .pth → torch import recursion.
     os.environ.pop(_GUARD_ENV, None)

-    vram_info = _get_real_vram()
+    vram_info = _get_effective_vram()
     if vram_info is None:
         return

-    vram_total, vram_used = vram_info
+    effective_total, vram_used = vram_info

-    # Only patch if PyTorch total differs significantly from sysfs VRAM
-    # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
+    # Only patch if PyTorch total differs significantly from effective VRAM
     try:
         pt_free, pt_total = torch.cuda.mem_get_info(0)
     except Exception:
         return

     # If they're within 10% of each other, no patch needed
-    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
+    if abs(pt_total - effective_total) / max(pt_total, 1) < 0.10:
         return

     # --- Patch 1: torch.cuda.mem_get_info ---
     original_mem_get_info = torch.cuda.mem_get_info

     def _patched_mem_get_info(device=None):
-        """Return real VRAM from sysfs instead of GTT numbers."""
-        real = _get_real_vram()
-        if real is None:
+        """Return effective VRAM instead of raw PyTorch numbers."""
+        info = _get_effective_vram()
+        if info is None:
             return original_mem_get_info(device)
-        total, used = real
+        total, used = info
         # Account for PyTorch's own allocations on top of sysfs baseline
         pt_allocated = torch.cuda.memory_allocated(device or 0)
         free = total - used - pt_allocated
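Spelling out the free-space arithmetic in _patched_mem_get_info with illustrative numbers (92 GiB effective total, ~0.2 GiB sysfs usage baseline, 10 GiB already allocated by PyTorch):

    GIB = 1024**3
    total = 92 * GIB            # from _get_effective_vram()
    used = int(0.2 * GIB)       # sysfs mem_info_vram_used baseline
    pt_allocated = 10 * GIB     # torch.cuda.memory_allocated(device)
    free = total - used - pt_allocated
    print(round(free / GIB, 1)) # ~81.8 GiB reported as free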
@@ -174,24 +246,24 @@ def _apply_patch() -> None:

     # --- Patch 2: torch.cuda.get_device_properties ---
     # total_memory is a read-only C property, so we wrap the return value
-    # in a proxy that overrides it with the real VRAM total.
+    # in a proxy that overrides it with the effective VRAM total.
     original_get_device_properties = torch.cuda.get_device_properties

     def _patched_get_device_properties(device=None):
         props = original_get_device_properties(device)
-        real = _get_real_vram()
-        if real is None:
+        info = _get_effective_vram()
+        if info is None:
             return props
-        return _VRAMDeviceProperties(props, real[0])
+        return _VRAMDeviceProperties(props, info[0])

     torch.cuda.get_device_properties = _patched_get_device_properties

     logger.info(
         "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
         "get_device_properties (PyTorch reported %d GiB total, "
-        "sysfs VRAM is %d GiB)",
+        "effective VRAM is %d GiB)",
         pt_total // (1024**3),
-        vram_total // (1024**3),
+        effective_total // (1024**3),
     )
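Once the patch is applied, both torch APIs report the same corrected total, which is what vLLM's memory sizing reads. A sanity check one might run inside the worker (device index 0 assumed):

    import torch

    free, total = torch.cuda.mem_get_info(0)
    props = torch.cuda.get_device_properties(0)
    # Both should now report the effective total, not the raw GTT pool.
    print(f"{total / 1024**3:.0f} GiB", f"{props.total_memory / 1024**3:.0f} GiB")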