From e7642b86ddf100ee3edbe1b6ec2611e2d8b04c30 Mon Sep 17 00:00:00 2001 From: "Billy D." Date: Fri, 6 Feb 2026 16:29:46 -0500 Subject: [PATCH] feat(strixhalo): patch torch.cuda.mem_get_info for unified memory APU On Strix Halo, PyTorch reports GTT pool (128 GiB) as device memory instead of real VRAM (96 GiB from BIOS). vLLM uses mem_get_info() to pre-allocate and refuses to start when free GTT (29 GiB) < requested. The strixhalo_vram_fix.pth hook auto-patches mem_get_info on Python startup to read real VRAM total/used from /sys/class/drm sysfs. Only activates when PyTorch total differs >10% from sysfs VRAM. --- amdsmi-shim/strixhalo_vram_fix.py | 107 ++++++++++++++++++++ dockerfiles/Dockerfile.ray-worker-strixhalo | 10 ++ 2 files changed, 117 insertions(+) create mode 100644 amdsmi-shim/strixhalo_vram_fix.py diff --git a/amdsmi-shim/strixhalo_vram_fix.py b/amdsmi-shim/strixhalo_vram_fix.py new file mode 100644 index 0000000..4002289 --- /dev/null +++ b/amdsmi-shim/strixhalo_vram_fix.py @@ -0,0 +1,107 @@ +""" +Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory. + +On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns +GTT (system RAM accessible to GPU) numbers instead of actual VRAM: + - Total: 128 GiB (GTT pool, NOT real VRAM) + - Free: 29 GiB (GTT free, NOT VRAM free) + +Meanwhile sysfs reports the real VRAM: + - Total: 96 GiB (actual dedicated VRAM from BIOS) + - Free: 95.8 GiB (actual VRAM free) + +vLLM uses torch.cuda.mem_get_info() to calculate how much memory to +pre-allocate. With wrong numbers it either OOMs or refuses to start. + +This module patches mem_get_info to return sysfs VRAM values, which +are the physically meaningful numbers for allocation decisions. + +Installed as a .pth hook so it runs before any user code. 
+"""
+
+import os
+import glob
+import logging
+
+logger = logging.getLogger("strixhalo_vram_fix")
+
+
+def _read_sysfs_int(path: str) -> int | None:
+    # Best-effort sysfs read: returns the file's content as an int, or None
+    # on any I/O failure or non-numeric content.
+    try:
+        with open(path) as f:
+            return int(f.read().strip())
+    except (OSError, ValueError):
+        return None
+
+
+def _get_real_vram() -> tuple[int, int] | None:
+    """Read real VRAM total/used from sysfs for the first AMD GPU."""
+    # NOTE(review): returns on the FIRST AMD card found (vendor 0x1002) and
+    # never consults a device index — fine for a single-GPU APU like Strix
+    # Halo, but wrong on multi-GPU hosts; confirm deployment is single-GPU.
+    for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
+        vendor_path = os.path.join(card_dir, "vendor")
+        if not os.path.exists(vendor_path):
+            continue
+        try:
+            with open(vendor_path) as f:
+                vendor = f.read().strip()
+        except OSError:
+            continue
+        if vendor != "0x1002":  # AMD
+            continue
+
+        # Both attributes are byte counts exposed by the amdgpu driver.
+        total = _read_sysfs_int(os.path.join(card_dir, "mem_info_vram_total"))
+        used = _read_sysfs_int(os.path.join(card_dir, "mem_info_vram_used"))
+        if total is not None and used is not None:
+            return (total, used)
+    return None
+
+
+def _apply_patch() -> None:
+    """Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
+    # Import torch lazily so this .pth hook is a no-op in environments
+    # without torch or without a visible GPU.
+    try:
+        import torch
+        if not hasattr(torch, "cuda") or not torch.cuda.is_available():
+            return
+    except ImportError:
+        return
+
+    vram_info = _get_real_vram()
+    if vram_info is None:
+        return
+
+    vram_total, vram_used = vram_info
+    # NOTE(review): vram_free here and pt_free below are computed but never
+    # used; candidates for removal in a follow-up.
+    vram_free = vram_total - vram_used
+
+    # Only patch if PyTorch total differs significantly from sysfs VRAM
+    # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
+    try:
+        pt_free, pt_total = torch.cuda.mem_get_info(0)
+    except Exception:
+        return
+
+    # If they're within 10% of each other, no patch needed
+    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
+        return
+
+    # Capture the original BEFORE reassigning the attribute; the closure
+    # below falls back to it when sysfs becomes unreadable at call time.
+    original_fn = torch.cuda.mem_get_info
+
+    def _patched_mem_get_info(device=None):
+        """Return real VRAM from sysfs instead of GTT numbers."""
+        real = _get_real_vram()
+        if real is None:
+            return original_fn(device)
+        total, used = real
+        # Account for PyTorch's own allocations on top of sysfs baseline
+        # NOTE(review): if the driver already includes PyTorch's allocations
+        # in mem_info_vram_used, subtracting memory_allocated() double-counts
+        # them and under-reports free memory — confirm on the target kernel.
+        # Also: `device or 0` maps device=None to device 0, consistent with
+        # _get_real_vram() only ever reading the first AMD card.
+        pt_allocated = torch.cuda.memory_allocated(device or 0)
+        free = total - used - pt_allocated
+        # Clamp at zero so callers never see a negative free value.
+        return (max(free, 0), total)
+
+    torch.cuda.mem_get_info = _patched_mem_get_info
+    logger.info(
+        "strixhalo_vram_fix: patched torch.cuda.mem_get_info "
+        "(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
+        pt_total // (1024**3),
+        vram_total // (1024**3),
+    )
+
+
+_apply_patch()
diff --git a/dockerfiles/Dockerfile.ray-worker-strixhalo b/dockerfiles/Dockerfile.ray-worker-strixhalo
index 3db0bc2..eedb983 100644
--- a/dockerfiles/Dockerfile.ray-worker-strixhalo
+++ b/dockerfiles/Dockerfile.ray-worker-strixhalo
@@ -93,6 +93,16 @@ COPY --chown=1000:100 amdsmi-shim /tmp/amdsmi-shim
 RUN --mount=type=cache,target=/home/ray/.cache/uv,uid=1000,gid=1000 \
     uv pip install --system /tmp/amdsmi-shim && rm -rf /tmp/amdsmi-shim
 
+# FIX: Patch torch.cuda.mem_get_info for unified memory APUs.
+# On Strix Halo, PyTorch reports GTT (128 GiB) instead of real VRAM (96 GiB)
+# from sysfs. vLLM uses mem_get_info to pre-allocate, so wrong numbers cause
+# OOM or "insufficient GPU memory" at startup. The .pth file auto-patches
+# mem_get_info on Python startup to return sysfs VRAM values.
+# NOTE(review): the site-packages path below hardcodes python3.11; this will
+# silently stop working if the base image's Python version changes — confirm,
+# or derive the path at build time from `python3 -c "import site; ..."`.
+COPY --chown=1000:100 amdsmi-shim/strixhalo_vram_fix.py \
+    /home/ray/anaconda3/lib/python3.11/site-packages/strixhalo_vram_fix.py
+# A .pth line starting with "import" is executed by the `site` module at
+# interpreter startup, which is what triggers _apply_patch() in every process.
+RUN echo "import strixhalo_vram_fix" > \
+    /home/ray/anaconda3/lib/python3.11/site-packages/strixhalo_vram_fix.pth
+
 # Pre-download common models for faster cold starts (optional, increases image size)
 # RUN python3 -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('BAAI/bge-large-en-v1.5')"