v1.0.26: dynamic VRAM via GTT for 32GB carve-out
Some checks failed
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Has been cancelled
Build and Push Images / Release (push) Has been cancelled
Build and Push Images / Notify (push) Has been cancelled
Build and Push Images / determine-version (push) Has been cancelled
strixhalo_vram_fix.py: compute effective VRAM as min(GTT_pool, physical_RAM) - 4GB OS reserve instead of raw sysfs VRAM. Prevents OOM when carve-out < model size and prevents kernel OOM when GTT > physical RAM.
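The arithmetic this commit introduces is small enough to sketch as a standalone function. This is an illustration only: `effective_vram` and the sample numbers are hypothetical; the formula itself comes from the patch below.

```python
GIB = 1024**3
OS_RESERVE = 4 * GIB  # safety reserve for OS + kubelet + system pods

def effective_vram(gtt_total: int, phys_ram: int, sysfs_vram: int) -> int:
    """min(GTT pool, physical RAM) minus the OS reserve, clamped so the
    result never falls below the static BIOS carve-out."""
    total = min(gtt_total, phys_ram) - OS_RESERVE
    return max(total, sysfs_vram)

# gttsize=128 GiB but only 96 GiB physical RAM -> 92 GiB effective
print(effective_vram(128 * GIB, 96 * GIB, 32 * GIB) // GIB)  # 92
```

Note the clamp on the last line: if the GTT pool is tiny (or misdetected), the function still reports at least the carve-out the BIOS guarantees.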
@@ -1,21 +1,27 @@
 """
 Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
 
-On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
-accessible to GPU) instead of actual VRAM:
+On Strix Halo (and other AMD APUs with dynamic VRAM), the BIOS carve-out
+for VRAM may be smaller than what the GPU can actually use. With the
+``amdgpu.gttsize`` kernel parameter the GPU can dynamically claim host
+RAM through the GTT (Graphics Translation Table) pool, so the *effective*
+VRAM is much larger than the static carve-out reported by sysfs.
 
-    torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total)  WRONG
-    torch.cuda.get_device_properties().total_memory → 128 GiB  WRONG
+Example with BIOS at 32 GB, gttsize=131072 (128 GiB):
 
-Meanwhile sysfs reports the real VRAM:
-    /sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
-    /sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
+    sysfs mem_info_vram_total → 32 GiB (static carve-out)
+    torch.cuda.mem_get_info() total → 128 GiB (GTT pool — correct!)
 
-vLLM uses mem_get_info() and get_device_properties() to decide how much
-memory to pre-allocate. With wrong numbers it either OOMs or refuses to
-start ("Free memory less than desired GPU memory utilization").
+However in some configurations PyTorch may report only the static
+carve-out, and vLLM will refuse to load a model that exceeds it. In
+other configurations PyTorch may report the *full* GTT pool, which can
+be larger than actual physical RAM and cause the kernel OOM killer to
+fire during allocation.
 
-This module patches both APIs to return sysfs VRAM values instead.
+This module patches ``torch.cuda.mem_get_info()`` and
+``torch.cuda.get_device_properties()`` so that the reported total equals
+the *effective* GPU memory: the smaller of the GTT pool size and the
+total physical RAM, minus a safety reserve for the OS.
 
 Installed as a .pth hook so it runs before any user code.
@@ -40,8 +46,31 @@ def _read_sysfs_int(path: str) -> int | None:
     return None
 
 
-def _get_real_vram() -> tuple[int, int] | None:
-    """Read real VRAM total/used from sysfs for the first AMD GPU."""
+def _get_total_physical_ram() -> int | None:
+    """Read total physical RAM from /proc/meminfo (in bytes)."""
+    try:
+        with open("/proc/meminfo") as f:
+            for line in f:
+                if line.startswith("MemTotal:"):
+                    # MemTotal:       98765432 kB
+                    return int(line.split()[1]) * 1024
+    except (OSError, ValueError):
+        pass
+    return None
+
+
+def _get_gtt_size() -> int | None:
+    """Read the GTT (Graphics Translation Table) pool size from sysfs."""
+    for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
+        gtt_path = os.path.join(card_dir, "mem_info_gtt_total")
+        val = _read_sysfs_int(gtt_path)
+        if val is not None:
+            return val
+    return None
+
+
+def _get_sysfs_vram() -> tuple[int, int] | None:
+    """Read sysfs VRAM total/used for the first AMD GPU."""
     for card_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
         vendor_path = os.path.join(card_dir, "vendor")
         if not os.path.exists(vendor_path):
@@ -61,6 +90,50 @@ def _get_real_vram() -> tuple[int, int] | None:
     return None
 
 
+# Safety reserve for OS + kubelet + system pods (4 GiB)
+_OS_RESERVE_BYTES = 4 * 1024**3
+
+
+def _get_effective_vram() -> tuple[int, int] | None:
+    """Compute the effective GPU memory for a unified-memory APU.
+
+    On APUs with dynamic VRAM (GTT), the effective VRAM is:
+
+        min(gtt_size, physical_ram) - os_reserve
+
+    This prevents vLLM from trying to use more memory than physically
+    exists (which would trigger the kernel OOM killer), while still
+    allowing it to use more than the small BIOS carve-out.
+
+    Returns (total, used) in bytes, or None if detection fails.
+    """
+    sysfs = _get_sysfs_vram()
+    if sysfs is None:
+        return None
+
+    sysfs_total, sysfs_used = sysfs
+    gtt_total = _get_gtt_size()
+    phys_ram = _get_total_physical_ram()
+
+    if gtt_total is None or phys_ram is None:
+        # Can't compute effective VRAM — fall back to sysfs
+        return sysfs
+
+    effective_total = min(gtt_total, phys_ram) - _OS_RESERVE_BYTES
+    effective_total = max(effective_total, sysfs_total)  # never below carve-out
+
+    logger.debug(
+        "strixhalo_vram_fix: sysfs_total=%d GiB, gtt=%d GiB, "
+        "phys_ram=%d GiB, effective=%d GiB",
+        sysfs_total // (1024**3),
+        gtt_total // (1024**3),
+        phys_ram // (1024**3),
+        effective_total // (1024**3),
+    )
+
+    return (effective_total, sysfs_used)
+
+
 _GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"
 
 
@@ -115,11 +188,11 @@ class _VRAMDeviceProperties:
 
 
 def _apply_patch() -> None:
-    """Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
+    """Patch torch.cuda memory APIs for correct unified-memory reporting."""
     if _should_skip():
         return
 
-    if _get_real_vram() is None:
+    if _get_sysfs_vram() is None:
         return
 
     # Set guard BEFORE importing torch — torch init spawns offload-arch
@@ -139,32 +212,31 @@ def _apply_patch() -> None:
     # prevent the offload-arch → .pth → torch import recursion.
     os.environ.pop(_GUARD_ENV, None)
 
-    vram_info = _get_real_vram()
+    vram_info = _get_effective_vram()
     if vram_info is None:
         return
 
-    vram_total, vram_used = vram_info
+    effective_total, vram_used = vram_info
 
-    # Only patch if PyTorch total differs significantly from sysfs VRAM
-    # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
+    # Only patch if PyTorch total differs significantly from effective VRAM
     try:
         pt_free, pt_total = torch.cuda.mem_get_info(0)
     except Exception:
         return
 
     # If they're within 10% of each other, no patch needed
-    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
+    if abs(pt_total - effective_total) / max(pt_total, 1) < 0.10:
         return
 
     # --- Patch 1: torch.cuda.mem_get_info ---
     original_mem_get_info = torch.cuda.mem_get_info
 
     def _patched_mem_get_info(device=None):
-        """Return real VRAM from sysfs instead of GTT numbers."""
-        real = _get_real_vram()
-        if real is None:
+        """Return effective VRAM instead of raw PyTorch numbers."""
+        info = _get_effective_vram()
+        if info is None:
             return original_mem_get_info(device)
-        total, used = real
+        total, used = info
         # Account for PyTorch's own allocations on top of sysfs baseline
         pt_allocated = torch.cuda.memory_allocated(device or 0)
         free = total - used - pt_allocated
@@ -174,24 +246,24 @@ def _apply_patch() -> None:
 
     # --- Patch 2: torch.cuda.get_device_properties ---
     # total_memory is a read-only C property, so we wrap the return value
-    # in a proxy that overrides it with the real VRAM total.
+    # in a proxy that overrides it with the effective VRAM total.
    original_get_device_properties = torch.cuda.get_device_properties
 
     def _patched_get_device_properties(device=None):
         props = original_get_device_properties(device)
-        real = _get_real_vram()
-        if real is None:
+        info = _get_effective_vram()
+        if info is None:
             return props
-        return _VRAMDeviceProperties(props, real[0])
+        return _VRAMDeviceProperties(props, info[0])
 
     torch.cuda.get_device_properties = _patched_get_device_properties
 
     logger.info(
         "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
         "get_device_properties (PyTorch reported %d GiB total, "
-        "sysfs VRAM is %d GiB)",
+        "effective VRAM is %d GiB)",
         pt_total // (1024**3),
-        vram_total // (1024**3),
+        effective_total // (1024**3),
     )
 
 
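The patch is applied conditionally; the decision mirrors the 10% threshold in `_apply_patch()` above, sketched here as a pure function with illustrative numbers (`needs_patch` is a name invented for this sketch):

```python
GIB = 1024**3

def needs_patch(pt_total: int, effective_total: int) -> bool:
    """True when PyTorch's reported total differs from the computed
    effective total by 10% or more (same condition as _apply_patch)."""
    return abs(pt_total - effective_total) / max(pt_total, 1) >= 0.10

print(needs_patch(128 * GIB, 92 * GIB))  # True: PyTorch reports the full GTT pool
print(needs_patch(92 * GIB, 92 * GIB))   # False: already accurate, leave torch alone
```

Keeping the threshold relative (rather than an absolute byte count) means the same check works whether PyTorch over-reports by the whole GTT pool or under-reports by the carve-out.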
@@ -2,28 +2,29 @@
 # Used for: vLLM (Llama 3.1 70B)
 #
 # Build:
-#   docker build -t registry.lab.daviestechlabs.io/daviestechlabs/ray-worker-strixhalo:v1.0.21 \
+#   docker build -t registry.lab.daviestechlabs.io/daviestechlabs/ray-worker-strixhalo:v2.0.0 \
 #     -f dockerfiles/Dockerfile.ray-worker-strixhalo .
 #
 # STRATEGY: Full source build of vLLM on AMD's vendor PyTorch image.
 #
-# The vendor image (rocm/pytorch ROCm 7.0.2 / Ubuntu 24.04 / Python 3.12)
-# ships torch 2.9.1 compiled by AMD CI against the exact ROCm libraries in
-# the image. Pre-built vLLM torch wheels (wheels.vllm.ai) carry a custom
-# torch 2.9.1+git8907517 that segfaults in libhsa-runtime64.so on gfx1151
-# during HSA queue creation. By keeping the vendor torch and compiling vLLM
-# from source we guarantee ABI compatibility across the entire stack.
+# The vendor image (rocm/pytorch ROCm 7.2 / Ubuntu 24.04 / Python 3.12)
+# ships torch 2.9.1+rocm7.2.0 compiled by AMD CI against the exact ROCm
+# libraries in the image. ROCm 7.2 includes the hsakmt VGPR count fix
+# for gfx1151 (TheRock #2991) — ROCm 7.0.x/7.1.x segfault during HSA
+# queue creation due to incorrect VGPR sizing. By keeping the vendor
+# torch and compiling vLLM from source we guarantee ABI compatibility
+# across the entire stack.
 #
-# gfx1151 is mapped to gfx1100 at runtime via HSA_OVERRIDE_GFX_VERSION=11.0.0,
-# so all HIP kernels are compiled for the gfx1100 target.
+# ROCm 7.2 supports gfx1151 natively — no HSA_OVERRIDE_GFX_VERSION needed.
+# HIP kernels are compiled directly for the gfx1151 target.
 #
 # Note: AITER is gfx9-only. On gfx11, vLLM defaults to TRITON_ATTN backend.
 
-FROM docker.io/rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.9.1
+FROM docker.io/rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.9.1
 
 # ── Build arguments ─────────────────────────────────────────────────────
 ARG VLLM_VERSION=v0.15.1
-ARG PYTORCH_ROCM_ARCH="gfx1100"
+ARG PYTORCH_ROCM_ARCH="gfx1151"
 ARG MAX_JOBS=16
 
 # ── OCI labels ──────────────────────────────────────────────────────────
@@ -32,7 +33,7 @@ LABEL org.opencontainers.image.description="Ray Serve worker for AMD Strix Halo
 LABEL org.opencontainers.image.vendor="DaviesTechLabs"
 LABEL org.opencontainers.image.source="https://git.daviestechlabs.io/daviestechlabs/kuberay-images"
 LABEL org.opencontainers.image.licenses="MIT"
-LABEL gpu.target="amd-rocm-7.0.2-gfx1151"
+LABEL gpu.target="amd-rocm-7.2-gfx1151"
 LABEL ray.version="2.53.0"
 LABEL vllm.build="source"
 
@@ -52,17 +53,16 @@ ENV PATH="/opt/venv/bin:/opt/rocm/bin:/opt/rocm/llvm/bin:/home/ray/.local/bin:/u
     HIP_VISIBLE_DEVICES=0 \
     HSA_ENABLE_SDMA=0 \
     PYTORCH_ALLOC_CONF="max_split_size_mb:512" \
-    HSA_OVERRIDE_GFX_VERSION="11.0.0" \
-    ROCM_TARGET_LST="gfx1151,gfx1100"
+    ROCM_TARGET_LST="gfx1151"
 
 # ── System setup ─────────────────────────────────────────────────────────
-# The vendor image already ships ALL needed packages:
-# cmake 4.0, hipcc 7.0.2, clang++ 20.0 (AMD ROCm LLVM), git,
+# The vendor image ships hipcc 7.2, clang++ (AMD ROCm LLVM), git,
 # libelf, libnuma, libdrm, libopenmpi3, and HIP dev headers/cmake configs.
+# cmake is NOT in the 7.2 image — installed via pip below.
 #
 # CRITICAL: Do NOT run apt-get upgrade or install ANY packages from apt.
 # Even installing ccache triggers a dependency cascade that pulls in
-# Ubuntu's hipcc 5.7.1 (which overwrites the vendor hipcc 7.0.2) and
+# Ubuntu's hipcc 5.7.1 (which overwrites the vendor hipcc 7.2) and
 # a broken /usr/bin/hipconfig.pl that makes cmake find_package(hip)
 # report version 0.0.0 → "Can't find CUDA or HIP installation."
 #
@@ -81,9 +81,9 @@ RUN (groupadd -g 100 -o users 2>/dev/null || true) \
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
 
 # ── Python build dependencies ──────────────────────────────────────────
-# CRITICAL: vLLM requires cmake<4. The vendor image ships cmake 4.0.0
-# which changed find_package(MODULE) behaviour and breaks FindHIP.cmake
-# (reports HIP version 0.0.0). Downgrade to 3.x per vLLM's rocm-build.txt.
+# CRITICAL: vLLM requires cmake<4. cmake 4.0+ changed find_package(MODULE)
+# behaviour and breaks FindHIP.cmake (reports HIP version 0.0.0).
+# The ROCm 7.2 image does not ship cmake, so we install 3.x here.
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv pip install --python /opt/venv/bin/python3 \
     'cmake>=3.26.1,<4' \
@@ -234,10 +234,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 
 # ── Verify vendor torch survived ───────────────────────────────────────
 # Fail early if any install step accidentally replaced the vendor torch.
+# ROCm 7.2 vendor torch version: 2.9.1+rocm7.2.0.git7e1940d4
 RUN python3 -c "\
 import torch; \
 v = torch.__version__; \
-assert '+git' not in v, f'vLLM torch detected ({v}) — vendor torch was overwritten!'; \
+assert 'rocm7.2' in v, f'Expected ROCm 7.2 vendor torch, got {v}'; \
 print(f'torch {v} (vendor) OK')"
 
 # ── amdsmi sysfs shim ──────────────────────────────────────────────────