""" Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory. On Strix Halo (and other AMD APUs with dynamic VRAM), the BIOS carve-out for VRAM may be smaller than what the GPU can actually use. With the ``amdgpu.gttsize`` kernel parameter the GPU can dynamically claim host RAM through the GTT (Graphics Translation Table) pool, so the *effective* VRAM is much larger than the static carve-out reported by sysfs. Example with BIOS at 32 GB, gttsize=131072 (128 GiB): sysfs mem_info_vram_total → 32 GiB (static carve-out) torch.cuda.mem_get_info() total → 128 GiB (GTT pool — correct!) However in some configurations PyTorch may report only the static carve-out, and vLLM will refuse to load a model that exceeds it. In other configurations PyTorch may report the *full* GTT pool, which can be larger than actual physical RAM and cause the kernel OOM killer to fire during allocation. This module patches ``torch.cuda.mem_get_info()`` and ``torch.cuda.get_device_properties()`` so that the reported total equals the *effective* GPU memory: the smaller of the GTT pool size and the total physical RAM, minus a safety reserve for the OS. Installed as a .pth hook so it runs before any user code. IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is set only during torch import to prevent infinite recursion when torch spawns offload-arch. It is CLEARED afterward so child processes (e.g. vLLM EngineCore subprocesses) can apply their own patch. """ import os import glob import logging logger = logging.getLogger("strixhalo_vram_fix") def _read_sysfs_int(path: str) -> int | None: try: with open(path) as f: return int(f.read().strip()) except (OSError, ValueError): return None def _get_total_physical_ram() -> int | None: """Read total physical RAM from /proc/meminfo (in bytes).""" try: with open("/proc/meminfo") as f: for line in f: if line.startswith("MemTotal:"): # MemTotal: 98765432 kB return int(line.split()[1]) * 1024 except (OSError, ValueError): pass return None def _get_gtt_size() -> int | None: """Read the GTT (Graphics Translation Table) pool size from sysfs.""" for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")): gtt_path = os.path.join(card_dir, "mem_info_gtt_total") val = _read_sysfs_int(gtt_path) if val is not None: return val return None def _get_sysfs_vram() -> tuple[int, int] | None: """Read sysfs VRAM total/used for the first AMD GPU.""" for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")): vendor_path = os.path.join(card_dir, "vendor") if not os.path.exists(vendor_path): continue try: with open(vendor_path) as f: vendor = f.read().strip() except OSError: continue if vendor != "0x1002": # AMD continue total = _read_sysfs_int(os.path.join(card_dir, "mem_info_vram_total")) used = _read_sysfs_int(os.path.join(card_dir, "mem_info_vram_used")) if total is not None and used is not None: return (total, used) return None # Safety reserve for OS + kubelet + system pods (4 GiB) _OS_RESERVE_BYTES = 4 * 1024**3 def _get_effective_vram() -> tuple[int, int] | None: """Compute the effective GPU memory for a unified-memory APU. On APUs with dynamic VRAM (GTT), the effective VRAM is: min(gtt_size, physical_ram) - os_reserve This prevents vLLM from trying to use more memory than physically exists (which would trigger the kernel OOM killer), while still allowing it to use more than the small BIOS carve-out. Returns (total, used) in bytes, or None if detection fails. """ sysfs = _get_sysfs_vram() if sysfs is None: return None sysfs_total, sysfs_used = sysfs gtt_total = _get_gtt_size() phys_ram = _get_total_physical_ram() if gtt_total is None or phys_ram is None: # Can't compute effective VRAM — fall back to sysfs return sysfs effective_total = min(gtt_total, phys_ram) - _OS_RESERVE_BYTES effective_total = max(effective_total, sysfs_total) # never below carve-out logger.debug( "strixhalo_vram_fix: sysfs_total=%d GiB, gtt=%d GiB, " "phys_ram=%d GiB, effective=%d GiB", sysfs_total // (1024**3), gtt_total // (1024**3), phys_ram // (1024**3), effective_total // (1024**3), ) return (effective_total, sysfs_used) _GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE" def _should_skip() -> bool: """Check if we should skip the patch (re-entry guard, init containers).""" # Re-entry guard: importing torch triggers subprocess calls to # offload-arch (a Python script), which re-enters this .pth hook. # Without this guard it creates an infinite fork bomb. # NOTE: This is only set transiently during _apply_patch() and # cleared afterward — child processes will NOT see it. if os.environ.get(_GUARD_ENV): return True # Check cgroup memory limit — if under 512Mi, skip the expensive # torch/ROCm import. KubeRay's wait-gcs-ready init container has # only 256Mi and importing torch+ROCm would OOMKill it. for cgroup_mem_path in ( "/sys/fs/cgroup/memory.max", # cgroup v2 "/sys/fs/cgroup/memory/memory.limit_in_bytes", # cgroup v1 ): try: with open(cgroup_mem_path) as f: val = f.read().strip() if val != "max" and int(val) < 512 * 1024 * 1024: return True except (OSError, ValueError): continue return False class _VRAMDeviceProperties: """Proxy that overrides total_memory on torch device properties.""" def __init__(self, original, vram_total: int): object.__setattr__(self, "_original", original) object.__setattr__(self, "_vram_total", vram_total) @property def total_memory(self) -> int: return object.__getattribute__(self, "_vram_total") def __getattr__(self, name: str): return getattr(object.__getattribute__(self, "_original"), name) def __repr__(self) -> str: orig = object.__getattribute__(self, "_original") vram = object.__getattribute__(self, "_vram_total") return repr(orig).replace( f"total_memory={orig.total_memory}", f"total_memory={vram}", ) def _apply_patch() -> None: """Patch torch.cuda memory APIs for correct unified-memory reporting.""" if _should_skip(): return if _get_sysfs_vram() is None: return # Set guard BEFORE importing torch — torch init spawns offload-arch # (a Python script) which would re-enter this .pth hook without it. os.environ[_GUARD_ENV] = "1" try: import torch if not hasattr(torch, "cuda") or not torch.cuda.is_available(): return except (ImportError, OSError): return finally: # CRITICAL: Clear the guard so child processes (vLLM EngineCore # subprocesses, Ray actor workers, etc.) can apply their own patch. # The guard only needs to live during the torch import above to # prevent the offload-arch → .pth → torch import recursion. os.environ.pop(_GUARD_ENV, None) vram_info = _get_effective_vram() if vram_info is None: return effective_total, vram_used = vram_info # Only patch if PyTorch total differs significantly from effective VRAM try: pt_free, pt_total = torch.cuda.mem_get_info(0) except Exception: return # If they're within 10% of each other, no patch needed if abs(pt_total - effective_total) / max(pt_total, 1) < 0.10: return # --- Patch 1: torch.cuda.mem_get_info --- original_mem_get_info = torch.cuda.mem_get_info def _patched_mem_get_info(device=None): """Return effective VRAM instead of raw PyTorch numbers.""" info = _get_effective_vram() if info is None: return original_mem_get_info(device) total, used = info # Account for PyTorch's own allocations on top of sysfs baseline pt_allocated = torch.cuda.memory_allocated(device or 0) free = total - used - pt_allocated return (max(free, 0), total) torch.cuda.mem_get_info = _patched_mem_get_info # --- Patch 2: torch.cuda.get_device_properties --- # total_memory is a read-only C property, so we wrap the return value # in a proxy that overrides it with the effective VRAM total. original_get_device_properties = torch.cuda.get_device_properties def _patched_get_device_properties(device=None): props = original_get_device_properties(device) info = _get_effective_vram() if info is None: return props return _VRAMDeviceProperties(props, info[0]) torch.cuda.get_device_properties = _patched_get_device_properties logger.info( "strixhalo_vram_fix: patched torch.cuda.mem_get_info and " "get_device_properties (PyTorch reported %d GiB total, " "effective VRAM is %d GiB)", pt_total // (1024**3), effective_total // (1024**3), ) _apply_patch()