v1.0.26: dynamic VRAM via GTT for 32GB carve-out

strixhalo_vram_fix.py: compute effective VRAM as
min(GTT_pool, physical_RAM) - 4GB OS reserve instead of
raw sysfs VRAM. Prevents OOM when carve-out < model size
and prevents kernel OOM when GTT > physical RAM.
2026-02-10 18:38:09 -05:00
parent 1efb94ddd7
commit 8902f6e616
2 changed files with 125 additions and 52 deletions
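The effective-VRAM arithmetic described in the commit message can be sketched as follows. This is an illustrative standalone function, not the patch's actual API; the name and argument order are assumptions.

```python
GiB = 1024**3

def effective_vram(gtt_bytes: int, phys_ram_bytes: int,
                   carveout_bytes: int, os_reserve: int = 4 * GiB) -> int:
    """min(GTT pool, physical RAM) - OS reserve, clamped to the carve-out."""
    total = min(gtt_bytes, phys_ram_bytes) - os_reserve
    # Never report less than the static BIOS carve-out
    return max(total, carveout_bytes)

# 128 GiB GTT pool, 96 GiB physical RAM, 32 GiB carve-out:
print(effective_vram(128 * GiB, 96 * GiB, 32 * GiB) // GiB)  # → 92
```

Note the clamp on the last line: if the GTT pool or physical RAM is small enough that the computed total falls below the BIOS carve-out, the carve-out wins, matching the "never below carve-out" behaviour in the diff.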

strixhalo_vram_fix.py

@@ -1,21 +1,27 @@
""" """
Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory. Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM On Strix Halo (and other AMD APUs with dynamic VRAM), the BIOS carve-out
accessible to GPU) instead of actual VRAM: for VRAM may be smaller than what the GPU can actually use. With the
``amdgpu.gttsize`` kernel parameter the GPU can dynamically claim host
RAM through the GTT (Graphics Translation Table) pool, so the *effective*
VRAM is much larger than the static carve-out reported by sysfs.
torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total) WRONG Example with BIOS at 32 GB, gttsize=131072 (128 GiB):
torch.cuda.get_device_properties().total_memory → 128 GiB WRONG
Meanwhile sysfs reports the real VRAM: sysfs mem_info_vram_total → 32 GiB (static carve-out)
/sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB torch.cuda.mem_get_info() total → 128 GiB (GTT pool — correct!)
/sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
vLLM uses mem_get_info() and get_device_properties() to decide how much However in some configurations PyTorch may report only the static
memory to pre-allocate. With wrong numbers it either OOMs or refuses to carve-out, and vLLM will refuse to load a model that exceeds it. In
start ("Free memory less than desired GPU memory utilization"). other configurations PyTorch may report the *full* GTT pool, which can
be larger than actual physical RAM and cause the kernel OOM killer to
fire during allocation.
This module patches both APIs to return sysfs VRAM values instead. This module patches ``torch.cuda.mem_get_info()`` and
``torch.cuda.get_device_properties()`` so that the reported total equals
the *effective* GPU memory: the smaller of the GTT pool size and the
total physical RAM, minus a safety reserve for the OS.
Installed as a .pth hook so it runs before any user code. Installed as a .pth hook so it runs before any user code.
@@ -40,8 +46,31 @@ def _read_sysfs_int(path: str) -> int | None:
     return None
 
 
-def _get_real_vram() -> tuple[int, int] | None:
-    """Read real VRAM total/used from sysfs for the first AMD GPU."""
+def _get_total_physical_ram() -> int | None:
+    """Read total physical RAM from /proc/meminfo (in bytes)."""
+    try:
+        with open("/proc/meminfo") as f:
+            for line in f:
+                if line.startswith("MemTotal:"):
+                    # MemTotal:       98765432 kB
+                    return int(line.split()[1]) * 1024
+    except (OSError, ValueError):
+        pass
+    return None
+
+
+def _get_gtt_size() -> int | None:
+    """Read the GTT (Graphics Translation Table) pool size from sysfs."""
+    for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
+        gtt_path = os.path.join(card_dir, "mem_info_gtt_total")
+        val = _read_sysfs_int(gtt_path)
+        if val is not None:
+            return val
+    return None
+
+
+def _get_sysfs_vram() -> tuple[int, int] | None:
+    """Read sysfs VRAM total/used for the first AMD GPU."""
     for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
         vendor_path = os.path.join(card_dir, "vendor")
         if not os.path.exists(vendor_path):
@@ -61,6 +90,50 @@ def _get_real_vram() -> tuple[int, int] | None:
     return None
 
 
+# Safety reserve for OS + kubelet + system pods (4 GiB)
+_OS_RESERVE_BYTES = 4 * 1024**3
+
+
+def _get_effective_vram() -> tuple[int, int] | None:
+    """Compute the effective GPU memory for a unified-memory APU.
+
+    On APUs with dynamic VRAM (GTT), the effective VRAM is:
+
+        min(gtt_size, physical_ram) - os_reserve
+
+    This prevents vLLM from trying to use more memory than physically
+    exists (which would trigger the kernel OOM killer), while still
+    allowing it to use more than the small BIOS carve-out.
+
+    Returns (total, used) in bytes, or None if detection fails.
+    """
+    sysfs = _get_sysfs_vram()
+    if sysfs is None:
+        return None
+    sysfs_total, sysfs_used = sysfs
+
+    gtt_total = _get_gtt_size()
+    phys_ram = _get_total_physical_ram()
+    if gtt_total is None or phys_ram is None:
+        # Can't compute effective VRAM — fall back to sysfs
+        return sysfs
+
+    effective_total = min(gtt_total, phys_ram) - _OS_RESERVE_BYTES
+    effective_total = max(effective_total, sysfs_total)  # never below carve-out
+
+    logger.debug(
+        "strixhalo_vram_fix: sysfs_total=%d GiB, gtt=%d GiB, "
+        "phys_ram=%d GiB, effective=%d GiB",
+        sysfs_total // (1024**3),
+        gtt_total // (1024**3),
+        phys_ram // (1024**3),
+        effective_total // (1024**3),
+    )
+    return (effective_total, sysfs_used)
+
+
 _GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"
@@ -115,11 +188,11 @@ class _VRAMDeviceProperties:
 def _apply_patch() -> None:
-    """Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
+    """Patch torch.cuda memory APIs for correct unified-memory reporting."""
     if _should_skip():
         return
 
-    if _get_real_vram() is None:
+    if _get_sysfs_vram() is None:
         return
 
     # Set guard BEFORE importing torch — torch init spawns offload-arch
@@ -139,32 +212,31 @@ def _apply_patch() -> None:
     # prevent the offload-arch → .pth → torch import recursion.
     os.environ.pop(_GUARD_ENV, None)
 
-    vram_info = _get_real_vram()
+    vram_info = _get_effective_vram()
     if vram_info is None:
         return
-    vram_total, vram_used = vram_info
+    effective_total, vram_used = vram_info
 
-    # Only patch if PyTorch total differs significantly from sysfs VRAM
-    # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
+    # Only patch if PyTorch total differs significantly from effective VRAM
     try:
         pt_free, pt_total = torch.cuda.mem_get_info(0)
     except Exception:
         return
 
     # If they're within 10% of each other, no patch needed
-    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
+    if abs(pt_total - effective_total) / max(pt_total, 1) < 0.10:
         return
 
     # --- Patch 1: torch.cuda.mem_get_info ---
     original_mem_get_info = torch.cuda.mem_get_info
 
     def _patched_mem_get_info(device=None):
-        """Return real VRAM from sysfs instead of GTT numbers."""
-        real = _get_real_vram()
-        if real is None:
+        """Return effective VRAM instead of raw PyTorch numbers."""
+        info = _get_effective_vram()
+        if info is None:
             return original_mem_get_info(device)
-        total, used = real
+        total, used = info
         # Account for PyTorch's own allocations on top of sysfs baseline
         pt_allocated = torch.cuda.memory_allocated(device or 0)
         free = total - used - pt_allocated
@@ -174,24 +246,24 @@ def _apply_patch() -> None:
     # --- Patch 2: torch.cuda.get_device_properties ---
     # total_memory is a read-only C property, so we wrap the return value
-    # in a proxy that overrides it with the real VRAM total.
+    # in a proxy that overrides it with the effective VRAM total.
     original_get_device_properties = torch.cuda.get_device_properties
 
     def _patched_get_device_properties(device=None):
         props = original_get_device_properties(device)
-        real = _get_real_vram()
-        if real is None:
+        info = _get_effective_vram()
+        if info is None:
             return props
-        return _VRAMDeviceProperties(props, real[0])
+        return _VRAMDeviceProperties(props, info[0])
 
     torch.cuda.get_device_properties = _patched_get_device_properties
 
     logger.info(
         "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
         "get_device_properties (PyTorch reported %d GiB total, "
-        "sysfs VRAM is %d GiB)",
+        "effective VRAM is %d GiB)",
         pt_total // (1024**3),
-        vram_total // (1024**3),
+        effective_total // (1024**3),
     )
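The `_VRAMDeviceProperties` class itself lies outside these hunks, but the wrapping technique the Patch-2 comment describes (a proxy that shadows one read-only attribute and delegates the rest) can be sketched generically. The class and attribute names below are illustrative, not the commit's actual code:

```python
class _TotalMemoryProxy:
    """Delegate every attribute to the wrapped props object,
    except total_memory, which is overridden with our own value."""

    def __init__(self, wrapped, total_memory: int):
        self._wrapped = wrapped
        self.total_memory = total_memory

    def __getattr__(self, name):
        # __getattr__ is called only for attributes NOT found on the
        # proxy itself, so total_memory above shadows the original.
        return getattr(self._wrapped, name)

# Demo with a stand-in for the real device-properties object
class _FakeProps:
    name = "gfx1151"
    total_memory = 32 * 1024**3  # static carve-out

proxy = _TotalMemoryProxy(_FakeProps(), 92 * 1024**3)
print(proxy.name, proxy.total_memory // 1024**3)  # → gfx1151 92
```

This sidesteps the fact that `total_memory` on the real object is a read-only C-level property: nothing is mutated, the caller just receives a Python object with the overridden value.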

dockerfiles/Dockerfile.ray-worker-strixhalo

@@ -2,28 +2,29 @@
 # Used for: vLLM (Llama 3.1 70B)
 #
 # Build:
-#   docker build -t registry.lab.daviestechlabs.io/daviestechlabs/ray-worker-strixhalo:v1.0.21 \
+#   docker build -t registry.lab.daviestechlabs.io/daviestechlabs/ray-worker-strixhalo:v2.0.0 \
 #     -f dockerfiles/Dockerfile.ray-worker-strixhalo .
 #
 # STRATEGY: Full source build of vLLM on AMD's vendor PyTorch image.
 #
-# The vendor image (rocm/pytorch ROCm 7.0.2 / Ubuntu 24.04 / Python 3.12)
-# ships torch 2.9.1 compiled by AMD CI against the exact ROCm libraries in
-# the image. Pre-built vLLM torch wheels (wheels.vllm.ai) carry a custom
-# torch 2.9.1+git8907517 that segfaults in libhsa-runtime64.so on gfx1151
-# during HSA queue creation. By keeping the vendor torch and compiling vLLM
-# from source we guarantee ABI compatibility across the entire stack.
+# The vendor image (rocm/pytorch ROCm 7.2 / Ubuntu 24.04 / Python 3.12)
+# ships torch 2.9.1+rocm7.2.0 compiled by AMD CI against the exact ROCm
+# libraries in the image. ROCm 7.2 includes the hsakmt VGPR count fix
+# for gfx1151 (TheRock #2991) — ROCm 7.0.x/7.1.x segfault during HSA
+# queue creation due to incorrect VGPR sizing. By keeping the vendor
+# torch and compiling vLLM from source we guarantee ABI compatibility
+# across the entire stack.
 #
-# gfx1151 is mapped to gfx1100 at runtime via HSA_OVERRIDE_GFX_VERSION=11.0.0,
-# so all HIP kernels are compiled for the gfx1100 target.
+# ROCm 7.2 supports gfx1151 natively — no HSA_OVERRIDE_GFX_VERSION needed.
+# HIP kernels are compiled directly for the gfx1151 target.
 #
 # Note: AITER is gfx9-only. On gfx11, vLLM defaults to TRITON_ATTN backend.
 
-FROM docker.io/rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.9.1
+FROM docker.io/rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.9.1
 
 # ── Build arguments ─────────────────────────────────────────────────────
 ARG VLLM_VERSION=v0.15.1
-ARG PYTORCH_ROCM_ARCH="gfx1100"
+ARG PYTORCH_ROCM_ARCH="gfx1151"
 ARG MAX_JOBS=16
 
 # ── OCI labels ──────────────────────────────────────────────────────────
@@ -32,7 +33,7 @@ LABEL org.opencontainers.image.description="Ray Serve worker for AMD Strix Halo
 LABEL org.opencontainers.image.vendor="DaviesTechLabs"
 LABEL org.opencontainers.image.source="https://git.daviestechlabs.io/daviestechlabs/kuberay-images"
 LABEL org.opencontainers.image.licenses="MIT"
-LABEL gpu.target="amd-rocm-7.0.2-gfx1151"
+LABEL gpu.target="amd-rocm-7.2-gfx1151"
 LABEL ray.version="2.53.0"
 LABEL vllm.build="source"
@@ -52,17 +53,16 @@ ENV PATH="/opt/venv/bin:/opt/rocm/bin:/opt/rocm/llvm/bin:/home/ray/.local/bin:/u
     HIP_VISIBLE_DEVICES=0 \
     HSA_ENABLE_SDMA=0 \
     PYTORCH_ALLOC_CONF="max_split_size_mb:512" \
-    HSA_OVERRIDE_GFX_VERSION="11.0.0" \
-    ROCM_TARGET_LST="gfx1151,gfx1100"
+    ROCM_TARGET_LST="gfx1151"
 
 # ── System setup ─────────────────────────────────────────────────────────
-# The vendor image already ships ALL needed packages:
-# cmake 4.0, hipcc 7.0.2, clang++ 20.0 (AMD ROCm LLVM), git,
+# The vendor image ships hipcc 7.2, clang++ (AMD ROCm LLVM), git,
 # libelf, libnuma, libdrm, libopenmpi3, and HIP dev headers/cmake configs.
+# cmake is NOT in the 7.2 image — installed via pip below.
 #
 # CRITICAL: Do NOT run apt-get upgrade or install ANY packages from apt.
 # Even installing ccache triggers a dependency cascade that pulls in
-# Ubuntu's hipcc 5.7.1 (which overwrites the vendor hipcc 7.0.2) and
+# Ubuntu's hipcc 5.7.1 (which overwrites the vendor hipcc 7.2) and
 # a broken /usr/bin/hipconfig.pl that makes cmake find_package(hip)
 # report version 0.0.0 → "Can't find CUDA or HIP installation."
 #
@@ -81,9 +81,9 @@ RUN (groupadd -g 100 -o users 2>/dev/null || true) \
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
 
 # ── Python build dependencies ──────────────────────────────────────────
-# CRITICAL: vLLM requires cmake<4. The vendor image ships cmake 4.0.0
-# which changed find_package(MODULE) behaviour and breaks FindHIP.cmake
-# (reports HIP version 0.0.0). Downgrade to 3.x per vLLM's rocm-build.txt.
+# CRITICAL: vLLM requires cmake<4. cmake 4.0+ changed find_package(MODULE)
+# behaviour and breaks FindHIP.cmake (reports HIP version 0.0.0).
+# The ROCm 7.2 image does not ship cmake, so we install 3.x here.
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --python /opt/venv/bin/python3 \
       'cmake>=3.26.1,<4' \
@@ -234,10 +234,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 # ── Verify vendor torch survived ───────────────────────────────────────
 # Fail early if any install step accidentally replaced the vendor torch.
+# ROCm 7.2 vendor torch version: 2.9.1+rocm7.2.0.git7e1940d4
 RUN python3 -c "\
 import torch; \
 v = torch.__version__; \
-assert '+git' not in v, f'vLLM torch detected ({v}) — vendor torch was overwritten!'; \
+assert 'rocm7.2' in v, f'Expected ROCm 7.2 vendor torch, got {v}'; \
 print(f'torch {v} (vendor) OK')"
 
 # ── amdsmi sysfs shim ──────────────────────────────────────────────────