""" Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory. On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns GTT (system RAM accessible to GPU) numbers instead of actual VRAM: - Total: 128 GiB (GTT pool, NOT real VRAM) - Free: 29 GiB (GTT free, NOT VRAM free) Meanwhile sysfs reports the real VRAM: - Total: 96 GiB (actual dedicated VRAM from BIOS) - Free: 95.8 GiB (actual VRAM free) vLLM uses torch.cuda.mem_get_info() to calculate how much memory to pre-allocate. With wrong numbers it either OOMs or refuses to start. This module patches mem_get_info to return sysfs VRAM values, which are the physically meaningful numbers for allocation decisions. Installed as a .pth hook so it runs before any user code. """ import os import glob import logging logger = logging.getLogger("strixhalo_vram_fix") def _read_sysfs_int(path: str) -> int | None: try: with open(path) as f: return int(f.read().strip()) except (OSError, ValueError): return None def _get_real_vram() -> tuple[int, int] | None: """Read real VRAM total/used from sysfs for the first AMD GPU.""" for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")): vendor_path = os.path.join(card_dir, "vendor") if not os.path.exists(vendor_path): continue try: with open(vendor_path) as f: vendor = f.read().strip() except OSError: continue if vendor != "0x1002": # AMD continue total = _read_sysfs_int(os.path.join(card_dir, "mem_info_vram_total")) used = _read_sysfs_int(os.path.join(card_dir, "mem_info_vram_used")) if total is not None and used is not None: return (total, used) return None def _should_skip() -> bool: """Check if we should skip the patch (lightweight/init containers).""" # Check cgroup memory limit — if under 512Mi, skip the expensive # torch/ROCm import. KubeRay's wait-gcs-ready init container has # only 256Mi and importing torch+ROCm would OOMKill it. for cgroup_mem_path in ( "/sys/fs/cgroup/memory.max", # cgroup v2 "/sys/fs/cgroup/memory/memory.limit_in_bytes", # cgroup v1 ): try: with open(cgroup_mem_path) as f: val = f.read().strip() if val != "max" and int(val) < 512 * 1024 * 1024: return True except (OSError, ValueError): continue return False def _apply_patch() -> None: """Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting.""" if _should_skip(): return if _get_real_vram() is None: return try: import torch if not hasattr(torch, "cuda") or not torch.cuda.is_available(): return except ImportError: return vram_info = _get_real_vram() if vram_info is None: return vram_total, vram_used = vram_info vram_free = vram_total - vram_used # Only patch if PyTorch total differs significantly from sysfs VRAM # (i.e. 
    try:
        pt_free, pt_total = torch.cuda.mem_get_info(0)
    except Exception:
        return

    # If the two totals are within 10% of each other, no patch is needed.
    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
        return

    original_fn = torch.cuda.mem_get_info

    def _patched_mem_get_info(device=None):
        """Return real VRAM from sysfs instead of GTT numbers."""
        real = _get_real_vram()
        if real is None:
            return original_fn(device)
        total, used = real
        # Account for PyTorch's own allocations on top of the sysfs baseline.
        pt_allocated = torch.cuda.memory_allocated(device or 0)
        free = total - used - pt_allocated
        return (max(free, 0), total)

    torch.cuda.mem_get_info = _patched_mem_get_info
    logger.info(
        "strixhalo_vram_fix: patched torch.cuda.mem_get_info "
        "(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
        pt_total // (1024**3),
        vram_total // (1024**3),
    )


_apply_patch()
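
# Manual sanity check when this file is executed directly rather than loaded
# via the .pth hook. A sketch: it assumes a single visible GPU at index 0 and
# relies on _apply_patch() above having already run at import time.
if __name__ == "__main__":
    print("sysfs VRAM (total, used):", _get_real_vram())
    try:
        import torch

        # With the patch active, this should track the sysfs VRAM numbers
        # rather than the GTT pool.
        print("torch.cuda.mem_get_info(0):", torch.cuda.mem_get_info(0))
    except Exception as exc:  # torch missing, or no GPU visible in this env
        print("torch check skipped:", exc)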