diff --git a/amdsmi-shim/strixhalo_vram_fix.py b/amdsmi-shim/strixhalo_vram_fix.py
index daf7bb0..946a9a3 100644
--- a/amdsmi-shim/strixhalo_vram_fix.py
+++ b/amdsmi-shim/strixhalo_vram_fix.py
@@ -1,22 +1,28 @@
 """
-Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory.
+Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
 
-On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns
-GTT (system RAM accessible to GPU) numbers instead of actual VRAM:
-  - Total: 128 GiB (GTT pool, NOT real VRAM)
-  - Free: 29 GiB (GTT free, NOT VRAM free)
+On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
+accessible to GPU) instead of actual VRAM:
+
+    torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total)  WRONG
+    torch.cuda.get_device_properties().total_memory → 128 GiB  WRONG
 
 Meanwhile sysfs reports the real VRAM:
-  - Total: 96 GiB (actual dedicated VRAM from BIOS)
-  - Free: 95.8 GiB (actual VRAM free)
+    /sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
+    /sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
 
-vLLM uses torch.cuda.mem_get_info() to calculate how much memory to
-pre-allocate. With wrong numbers it either OOMs or refuses to start.
+vLLM uses mem_get_info() and get_device_properties() to decide how much
+memory to pre-allocate. With wrong numbers it either OOMs or refuses to
+start ("Free memory less than desired GPU memory utilization").
 
-This module patches mem_get_info to return sysfs VRAM values, which
-are the physically meaningful numbers for allocation decisions.
+This module patches both APIs to return sysfs VRAM values instead.
 
 Installed as a .pth hook so it runs before any user code.
+
+IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is
+set only during torch import to prevent infinite recursion when torch
+spawns offload-arch. It is CLEARED afterward so child processes (e.g.
+vLLM EngineCore subprocesses) can apply their own patch.
 """
 
 import os
@@ -63,6 +69,8 @@ def _should_skip() -> bool:
     # Re-entry guard: importing torch triggers subprocess calls to
     # offload-arch (a Python script), which re-enters this .pth hook.
     # Without this guard it creates an infinite fork bomb.
+    # NOTE: This is only set transiently during _apply_patch() and
+    # cleared afterward — child processes will NOT see it.
     if os.environ.get(_GUARD_ENV):
         return True
 
@@ -83,8 +91,31 @@ def _should_skip() -> bool:
     return False
 
 
+class _VRAMDeviceProperties:
+    """Proxy that overrides total_memory on torch device properties."""
+
+    def __init__(self, original, vram_total: int):
+        object.__setattr__(self, "_original", original)
+        object.__setattr__(self, "_vram_total", vram_total)
+
+    @property
+    def total_memory(self) -> int:
+        return object.__getattribute__(self, "_vram_total")
+
+    def __getattr__(self, name: str):
+        return getattr(object.__getattribute__(self, "_original"), name)
+
+    def __repr__(self) -> str:
+        orig = object.__getattribute__(self, "_original")
+        vram = object.__getattribute__(self, "_vram_total")
+        return repr(orig).replace(
+            f"total_memory={orig.total_memory}",
+            f"total_memory={vram}",
+        )
+
+
 def _apply_patch() -> None:
-    """Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
+    """Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
     if _should_skip():
         return
 
@@ -101,13 +132,18 @@ def _apply_patch() -> None:
             return
     except ImportError:
         return
+    finally:
+        # CRITICAL: Clear the guard so child processes (vLLM EngineCore
+        # subprocesses, Ray actor workers, etc.) can apply their own patch.
+        # The guard only needs to live during the torch import above to
+        # prevent the offload-arch → .pth → torch import recursion.
+        os.environ.pop(_GUARD_ENV, None)
 
     vram_info = _get_real_vram()
     if vram_info is None:
         return
 
     vram_total, vram_used = vram_info
-    vram_free = vram_total - vram_used
 
     # Only patch if PyTorch total differs significantly from sysfs VRAM
     # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
@@ -120,13 +156,14 @@ def _apply_patch() -> None:
     if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
         return
 
-    original_fn = torch.cuda.mem_get_info
+    # --- Patch 1: torch.cuda.mem_get_info ---
+    original_mem_get_info = torch.cuda.mem_get_info
 
     def _patched_mem_get_info(device=None):
         """Return real VRAM from sysfs instead of GTT numbers."""
         real = _get_real_vram()
         if real is None:
-            return original_fn(device)
+            return original_mem_get_info(device)
         total, used = real
         # Account for PyTorch's own allocations on top of sysfs baseline
         pt_allocated = torch.cuda.memory_allocated(device or 0)
@@ -134,9 +171,25 @@ def _apply_patch() -> None:
         return (max(free, 0), total)
 
     torch.cuda.mem_get_info = _patched_mem_get_info
+
+    # --- Patch 2: torch.cuda.get_device_properties ---
+    # total_memory is a read-only C property, so we wrap the return value
+    # in a proxy that overrides it with the real VRAM total.
+    original_get_device_properties = torch.cuda.get_device_properties
+
+    def _patched_get_device_properties(device=None):
+        props = original_get_device_properties(device)
+        real = _get_real_vram()
+        if real is None:
+            return props
+        return _VRAMDeviceProperties(props, real[0])
+
+    torch.cuda.get_device_properties = _patched_get_device_properties
+
     logger.info(
-        "strixhalo_vram_fix: patched torch.cuda.mem_get_info "
-        "(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
+        "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
+        "get_device_properties (PyTorch reported %d GiB total, "
+        "sysfs VRAM is %d GiB)",
         pt_total // (1024**3),
         vram_total // (1024**3),
     )