diff --git a/amdsmi-shim/strixhalo_vram_fix.py b/amdsmi-shim/strixhalo_vram_fix.py index f5417b7..52cb2d5 100644 --- a/amdsmi-shim/strixhalo_vram_fix.py +++ b/amdsmi-shim/strixhalo_vram_fix.py @@ -1,21 +1,27 @@ """ Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory. -On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM -accessible to GPU) instead of actual VRAM: +On Strix Halo (and other AMD APUs with dynamic VRAM), the BIOS carve-out +for VRAM may be smaller than what the GPU can actually use. With the +``amdgpu.gttsize`` kernel parameter the GPU can dynamically claim host +RAM through the GTT (Graphics Translation Table) pool, so the *effective* +VRAM is much larger than the static carve-out reported by sysfs. - torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total) WRONG - torch.cuda.get_device_properties().total_memory → 128 GiB WRONG +Example with BIOS at 32 GB, gttsize=131072 (128 GiB): -Meanwhile sysfs reports the real VRAM: - /sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB - /sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB + sysfs mem_info_vram_total → 32 GiB (static carve-out) + torch.cuda.mem_get_info() total → 128 GiB (GTT pool — correct!) -vLLM uses mem_get_info() and get_device_properties() to decide how much -memory to pre-allocate. With wrong numbers it either OOMs or refuses to -start ("Free memory less than desired GPU memory utilization"). +However in some configurations PyTorch may report only the static +carve-out, and vLLM will refuse to load a model that exceeds it. In +other configurations PyTorch may report the *full* GTT pool, which can +be larger than actual physical RAM and cause the kernel OOM killer to +fire during allocation. -This module patches both APIs to return sysfs VRAM values instead. +This module patches ``torch.cuda.mem_get_info()`` and +``torch.cuda.get_device_properties()`` so that the reported total equals +the *effective* GPU memory: the smaller of the GTT pool size and the +total physical RAM, minus a safety reserve for the OS. Installed as a .pth hook so it runs before any user code. @@ -40,8 +46,31 @@ def _read_sysfs_int(path: str) -> int | None: return None -def _get_real_vram() -> tuple[int, int] | None: - """Read real VRAM total/used from sysfs for the first AMD GPU.""" +def _get_total_physical_ram() -> int | None: + """Read total physical RAM from /proc/meminfo (in bytes).""" + try: + with open("/proc/meminfo") as f: + for line in f: + if line.startswith("MemTotal:"): + # MemTotal: 98765432 kB + return int(line.split()[1]) * 1024 + except (OSError, ValueError): + pass + return None + + +def _get_gtt_size() -> int | None: + """Read the GTT (Graphics Translation Table) pool size from sysfs.""" + for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")): + gtt_path = os.path.join(card_dir, "mem_info_gtt_total") + val = _read_sysfs_int(gtt_path) + if val is not None: + return val + return None + + +def _get_sysfs_vram() -> tuple[int, int] | None: + """Read sysfs VRAM total/used for the first AMD GPU.""" for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")): vendor_path = os.path.join(card_dir, "vendor") if not os.path.exists(vendor_path): @@ -61,6 +90,50 @@ def _get_real_vram() -> tuple[int, int] | None: return None +# Safety reserve for OS + kubelet + system pods (4 GiB) +_OS_RESERVE_BYTES = 4 * 1024**3 + + +def _get_effective_vram() -> tuple[int, int] | None: + """Compute the effective GPU memory for a unified-memory APU. + + On APUs with dynamic VRAM (GTT), the effective VRAM is: + + min(gtt_size, physical_ram) - os_reserve + + This prevents vLLM from trying to use more memory than physically + exists (which would trigger the kernel OOM killer), while still + allowing it to use more than the small BIOS carve-out. + + Returns (total, used) in bytes, or None if detection fails. + """ + sysfs = _get_sysfs_vram() + if sysfs is None: + return None + + sysfs_total, sysfs_used = sysfs + gtt_total = _get_gtt_size() + phys_ram = _get_total_physical_ram() + + if gtt_total is None or phys_ram is None: + # Can't compute effective VRAM — fall back to sysfs + return sysfs + + effective_total = min(gtt_total, phys_ram) - _OS_RESERVE_BYTES + effective_total = max(effective_total, sysfs_total) # never below carve-out + + logger.debug( + "strixhalo_vram_fix: sysfs_total=%d GiB, gtt=%d GiB, " + "phys_ram=%d GiB, effective=%d GiB", + sysfs_total // (1024**3), + gtt_total // (1024**3), + phys_ram // (1024**3), + effective_total // (1024**3), + ) + + return (effective_total, sysfs_used) + + _GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE" @@ -115,11 +188,11 @@ class _VRAMDeviceProperties: def _apply_patch() -> None: - """Patch torch.cuda memory APIs if we detect unified memory mis-reporting.""" + """Patch torch.cuda memory APIs for correct unified-memory reporting.""" if _should_skip(): return - if _get_real_vram() is None: + if _get_sysfs_vram() is None: return # Set guard BEFORE importing torch — torch init spawns offload-arch @@ -139,32 +212,31 @@ def _apply_patch() -> None: # prevent the offload-arch → .pth → torch import recursion. os.environ.pop(_GUARD_ENV, None) - vram_info = _get_real_vram() + vram_info = _get_effective_vram() if vram_info is None: return - vram_total, vram_used = vram_info + effective_total, vram_used = vram_info - # Only patch if PyTorch total differs significantly from sysfs VRAM - # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM) + # Only patch if PyTorch total differs significantly from effective VRAM try: pt_free, pt_total = torch.cuda.mem_get_info(0) except Exception: return # If they're within 10% of each other, no patch needed - if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10: + if abs(pt_total - effective_total) / max(pt_total, 1) < 0.10: return # --- Patch 1: torch.cuda.mem_get_info --- original_mem_get_info = torch.cuda.mem_get_info def _patched_mem_get_info(device=None): - """Return real VRAM from sysfs instead of GTT numbers.""" - real = _get_real_vram() - if real is None: + """Return effective VRAM instead of raw PyTorch numbers.""" + info = _get_effective_vram() + if info is None: return original_mem_get_info(device) - total, used = real + total, used = info # Account for PyTorch's own allocations on top of sysfs baseline pt_allocated = torch.cuda.memory_allocated(device or 0) free = total - used - pt_allocated @@ -174,24 +246,24 @@ def _apply_patch() -> None: # --- Patch 2: torch.cuda.get_device_properties --- # total_memory is a read-only C property, so we wrap the return value - # in a proxy that overrides it with the real VRAM total. + # in a proxy that overrides it with the effective VRAM total. original_get_device_properties = torch.cuda.get_device_properties def _patched_get_device_properties(device=None): props = original_get_device_properties(device) - real = _get_real_vram() - if real is None: + info = _get_effective_vram() + if info is None: return props - return _VRAMDeviceProperties(props, real[0]) + return _VRAMDeviceProperties(props, info[0]) torch.cuda.get_device_properties = _patched_get_device_properties logger.info( "strixhalo_vram_fix: patched torch.cuda.mem_get_info and " "get_device_properties (PyTorch reported %d GiB total, " - "sysfs VRAM is %d GiB)", + "effective VRAM is %d GiB)", pt_total // (1024**3), - vram_total // (1024**3), + effective_total // (1024**3), ) diff --git a/dockerfiles/Dockerfile.ray-worker-strixhalo b/dockerfiles/Dockerfile.ray-worker-strixhalo index 550ff74..d16fe57 100644 --- a/dockerfiles/Dockerfile.ray-worker-strixhalo +++ b/dockerfiles/Dockerfile.ray-worker-strixhalo @@ -2,28 +2,29 @@ # Used for: vLLM (Llama 3.1 70B) # # Build: -# docker build -t registry.lab.daviestechlabs.io/daviestechlabs/ray-worker-strixhalo:v1.0.21 \ +# docker build -t registry.lab.daviestechlabs.io/daviestechlabs/ray-worker-strixhalo:v2.0.0 \ # -f dockerfiles/Dockerfile.ray-worker-strixhalo . # # STRATEGY: Full source build of vLLM on AMD's vendor PyTorch image. # -# The vendor image (rocm/pytorch ROCm 7.0.2 / Ubuntu 24.04 / Python 3.12) -# ships torch 2.9.1 compiled by AMD CI against the exact ROCm libraries in -# the image. Pre-built vLLM torch wheels (wheels.vllm.ai) carry a custom -# torch 2.9.1+git8907517 that segfaults in libhsa-runtime64.so on gfx1151 -# during HSA queue creation. By keeping the vendor torch and compiling vLLM -# from source we guarantee ABI compatibility across the entire stack. +# The vendor image (rocm/pytorch ROCm 7.2 / Ubuntu 24.04 / Python 3.12) +# ships torch 2.9.1+rocm7.2.0 compiled by AMD CI against the exact ROCm +# libraries in the image. ROCm 7.2 includes the hsakmt VGPR count fix +# for gfx1151 (TheRock #2991) — ROCm 7.0.x/7.1.x segfault during HSA +# queue creation due to incorrect VGPR sizing. By keeping the vendor +# torch and compiling vLLM from source we guarantee ABI compatibility +# across the entire stack. # -# gfx1151 is mapped to gfx1100 at runtime via HSA_OVERRIDE_GFX_VERSION=11.0.0, -# so all HIP kernels are compiled for the gfx1100 target. +# ROCm 7.2 supports gfx1151 natively — no HSA_OVERRIDE_GFX_VERSION needed. +# HIP kernels are compiled directly for the gfx1151 target. # # Note: AITER is gfx9-only. On gfx11, vLLM defaults to TRITON_ATTN backend. -FROM docker.io/rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.9.1 +FROM docker.io/rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.9.1 # ── Build arguments ───────────────────────────────────────────────────── ARG VLLM_VERSION=v0.15.1 -ARG PYTORCH_ROCM_ARCH="gfx1100" +ARG PYTORCH_ROCM_ARCH="gfx1151" ARG MAX_JOBS=16 # ── OCI labels ────────────────────────────────────────────────────────── @@ -32,7 +33,7 @@ LABEL org.opencontainers.image.description="Ray Serve worker for AMD Strix Halo LABEL org.opencontainers.image.vendor="DaviesTechLabs" LABEL org.opencontainers.image.source="https://git.daviestechlabs.io/daviestechlabs/kuberay-images" LABEL org.opencontainers.image.licenses="MIT" -LABEL gpu.target="amd-rocm-7.0.2-gfx1151" +LABEL gpu.target="amd-rocm-7.2-gfx1151" LABEL ray.version="2.53.0" LABEL vllm.build="source" @@ -52,17 +53,16 @@ ENV PATH="/opt/venv/bin:/opt/rocm/bin:/opt/rocm/llvm/bin:/home/ray/.local/bin:/u HIP_VISIBLE_DEVICES=0 \ HSA_ENABLE_SDMA=0 \ PYTORCH_ALLOC_CONF="max_split_size_mb:512" \ - HSA_OVERRIDE_GFX_VERSION="11.0.0" \ - ROCM_TARGET_LST="gfx1151,gfx1100" + ROCM_TARGET_LST="gfx1151" # ── System setup ───────────────────────────────────────────────────────── -# The vendor image already ships ALL needed packages: -# cmake 4.0, hipcc 7.0.2, clang++ 20.0 (AMD ROCm LLVM), git, -# libelf, libnuma, libdrm, libopenmpi3, and HIP dev headers/cmake configs. +# The vendor image ships hipcc 7.2, clang++ (AMD ROCm LLVM), git, +# libelf, libnuma, libdrm, libopenmpi3, and HIP dev headers/cmake configs. +# cmake is NOT in the 7.2 image — installed via pip below. # # CRITICAL: Do NOT run apt-get upgrade or install ANY packages from apt. # Even installing ccache triggers a dependency cascade that pulls in -# Ubuntu's hipcc 5.7.1 (which overwrites the vendor hipcc 7.0.2) and +# Ubuntu's hipcc 5.7.1 (which overwrites the vendor hipcc 7.2) and # a broken /usr/bin/hipconfig.pl that makes cmake find_package(hip) # report version 0.0.0 → "Can't find CUDA or HIP installation." # @@ -81,9 +81,9 @@ RUN (groupadd -g 100 -o users 2>/dev/null || true) \ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv # ── Python build dependencies ────────────────────────────────────────── -# CRITICAL: vLLM requires cmake<4. The vendor image ships cmake 4.0.0 -# which changed find_package(MODULE) behaviour and breaks FindHIP.cmake -# (reports HIP version 0.0.0). Downgrade to 3.x per vLLM's rocm-build.txt. +# CRITICAL: vLLM requires cmake<4. cmake 4.0+ changed find_package(MODULE) +# behaviour and breaks FindHIP.cmake (reports HIP version 0.0.0). +# The ROCm 7.2 image does not ship cmake, so we install 3.x here. RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --python /opt/venv/bin/python3 \ 'cmake>=3.26.1,<4' \ @@ -234,10 +234,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # ── Verify vendor torch survived ─────────────────────────────────────── # Fail early if any install step accidentally replaced the vendor torch. +# ROCm 7.2 vendor torch version: 2.9.1+rocm7.2.0.git7e1940d4 RUN python3 -c "\ import torch; \ v = torch.__version__; \ -assert '+git' not in v, f'vLLM torch detected ({v}) — vendor torch was overwritten!'; \ +assert 'rocm7.2' in v, f'Expected ROCm 7.2 vendor torch, got {v}'; \ print(f'torch {v} (vendor) OK')" # ── amdsmi sysfs shim ──────────────────────────────────────────────────