Some checks failed
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Has been cancelled
Build and Push Images / Release (push) Has been cancelled
Build and Push Images / Notify (push) Has been cancelled
Build and Push Images / determine-version (push) Has been cancelled
strixhalo_vram_fix.py: compute effective VRAM as min(GTT_pool, physical_RAM) - 4GB OS reserve instead of raw sysfs VRAM. Prevents OOM when carve-out < model size and prevents kernel OOM when GTT > physical RAM.
271 lines
9.2 KiB
Python
271 lines
9.2 KiB
Python
"""
|
|
Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
|
|
|
|
On Strix Halo (and other AMD APUs with dynamic VRAM), the BIOS carve-out
|
|
for VRAM may be smaller than what the GPU can actually use. With the
|
|
``amdgpu.gttsize`` kernel parameter the GPU can dynamically claim host
|
|
RAM through the GTT (Graphics Translation Table) pool, so the *effective*
|
|
VRAM is much larger than the static carve-out reported by sysfs.
|
|
|
|
Example with BIOS at 32 GB, gttsize=131072 (128 GiB):
|
|
|
|
sysfs mem_info_vram_total → 32 GiB (static carve-out)
|
|
torch.cuda.mem_get_info() total → 128 GiB (GTT pool — correct!)
|
|
|
|
However in some configurations PyTorch may report only the static
|
|
carve-out, and vLLM will refuse to load a model that exceeds it. In
|
|
other configurations PyTorch may report the *full* GTT pool, which can
|
|
be larger than actual physical RAM and cause the kernel OOM killer to
|
|
fire during allocation.
|
|
|
|
This module patches ``torch.cuda.mem_get_info()`` and
|
|
``torch.cuda.get_device_properties()`` so that the reported total equals
|
|
the *effective* GPU memory: the smaller of the GTT pool size and the
|
|
total physical RAM, minus a safety reserve for the OS.
|
|
|
|
Installed as a .pth hook so it runs before any user code.
|
|
|
|
IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is
|
|
set only during torch import to prevent infinite recursion when torch
|
|
spawns offload-arch. It is CLEARED afterward so child processes (e.g.
|
|
vLLM EngineCore subprocesses) can apply their own patch.
|
|
"""
|
|
|
|
import os
|
|
import glob
|
|
import logging
|
|
|
|
logger = logging.getLogger("strixhalo_vram_fix")
|
|
|
|
|
|
def _read_sysfs_int(path: str) -> int | None:
|
|
try:
|
|
with open(path) as f:
|
|
return int(f.read().strip())
|
|
except (OSError, ValueError):
|
|
return None
|
|
|
|
|
|
def _get_total_physical_ram() -> int | None:
|
|
"""Read total physical RAM from /proc/meminfo (in bytes)."""
|
|
try:
|
|
with open("/proc/meminfo") as f:
|
|
for line in f:
|
|
if line.startswith("MemTotal:"):
|
|
# MemTotal: 98765432 kB
|
|
return int(line.split()[1]) * 1024
|
|
except (OSError, ValueError):
|
|
pass
|
|
return None
|
|
|
|
|
|
def _get_gtt_size() -> int | None:
    """Return the GTT (Graphics Translation Table) pool size in bytes.

    Scans DRM card devices in sorted order and returns the first readable
    ``mem_info_gtt_total`` value, or None when none can be read.
    """
    candidates = sorted(glob.glob("/sys/class/drm/card[0-9]*/device"))
    for device_dir in candidates:
        size = _read_sysfs_int(os.path.join(device_dir, "mem_info_gtt_total"))
        if size is not None:
            return size
    return None
|
|
|
|
|
|
def _get_sysfs_vram() -> tuple[int, int] | None:
    """Return (total, used) VRAM in bytes for the first AMD GPU, or None.

    Walks DRM card devices in sorted order, skipping non-AMD cards, and
    returns the first card whose VRAM total and used counters both read
    successfully.
    """
    for device_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
        vendor_file = os.path.join(device_dir, "vendor")
        if not os.path.exists(vendor_file):
            continue
        try:
            with open(vendor_file) as fh:
                vendor_id = fh.read().strip()
        except OSError:
            continue
        # 0x1002 is AMD's PCI vendor ID; skip everything else.
        if vendor_id != "0x1002":
            continue

        total = _read_sysfs_int(os.path.join(device_dir, "mem_info_vram_total"))
        used = _read_sysfs_int(os.path.join(device_dir, "mem_info_vram_used"))
        if total is not None and used is not None:
            return (total, used)
    return None
|
|
|
|
|
|
# Safety reserve subtracted from the effective VRAM total to leave headroom
# for the OS + kubelet + system pods (4 GiB).
_OS_RESERVE_BYTES = 4 * 1024**3
|
|
|
|
|
|
def _get_effective_vram() -> tuple[int, int] | None:
    """Compute the effective GPU memory for a unified-memory APU.

    On APUs with dynamic VRAM (GTT) the effective total is::

        min(gtt_size, physical_ram) - os_reserve

    clamped so it never drops below the static BIOS carve-out.  This keeps
    vLLM from trying to use more memory than physically exists (which would
    trigger the kernel OOM killer) while still allowing more than the small
    BIOS carve-out.

    Returns:
        (total, used) in bytes, or None if no AMD GPU is detected.
    """
    vram = _get_sysfs_vram()
    if vram is None:
        return None

    carveout_total, carveout_used = vram
    gtt = _get_gtt_size()
    ram = _get_total_physical_ram()

    # Without both the GTT pool size and the physical RAM total we cannot
    # do better than the raw sysfs numbers.
    if gtt is None or ram is None:
        return vram

    usable = min(gtt, ram) - _OS_RESERVE_BYTES
    # Never report less than the static carve-out itself.
    usable = max(usable, carveout_total)

    logger.debug(
        "strixhalo_vram_fix: sysfs_total=%d GiB, gtt=%d GiB, "
        "phys_ram=%d GiB, effective=%d GiB",
        carveout_total // (1024**3),
        gtt // (1024**3),
        ram // (1024**3),
        usable // (1024**3),
    )

    return (usable, carveout_used)
|
|
|
|
|
|
_GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"
|
|
|
|
|
|
def _should_skip() -> bool:
|
|
"""Check if we should skip the patch (re-entry guard, init containers)."""
|
|
# Re-entry guard: importing torch triggers subprocess calls to
|
|
# offload-arch (a Python script), which re-enters this .pth hook.
|
|
# Without this guard it creates an infinite fork bomb.
|
|
# NOTE: This is only set transiently during _apply_patch() and
|
|
# cleared afterward — child processes will NOT see it.
|
|
if os.environ.get(_GUARD_ENV):
|
|
return True
|
|
|
|
# Check cgroup memory limit — if under 512Mi, skip the expensive
|
|
# torch/ROCm import. KubeRay's wait-gcs-ready init container has
|
|
# only 256Mi and importing torch+ROCm would OOMKill it.
|
|
for cgroup_mem_path in (
|
|
"/sys/fs/cgroup/memory.max", # cgroup v2
|
|
"/sys/fs/cgroup/memory/memory.limit_in_bytes", # cgroup v1
|
|
):
|
|
try:
|
|
with open(cgroup_mem_path) as f:
|
|
val = f.read().strip()
|
|
if val != "max" and int(val) < 512 * 1024 * 1024:
|
|
return True
|
|
except (OSError, ValueError):
|
|
continue
|
|
return False
|
|
|
|
|
|
class _VRAMDeviceProperties:
|
|
"""Proxy that overrides total_memory on torch device properties."""
|
|
|
|
def __init__(self, original, vram_total: int):
|
|
object.__setattr__(self, "_original", original)
|
|
object.__setattr__(self, "_vram_total", vram_total)
|
|
|
|
@property
|
|
def total_memory(self) -> int:
|
|
return object.__getattribute__(self, "_vram_total")
|
|
|
|
def __getattr__(self, name: str):
|
|
return getattr(object.__getattribute__(self, "_original"), name)
|
|
|
|
def __repr__(self) -> str:
|
|
orig = object.__getattribute__(self, "_original")
|
|
vram = object.__getattribute__(self, "_vram_total")
|
|
return repr(orig).replace(
|
|
f"total_memory={orig.total_memory}",
|
|
f"total_memory={vram}",
|
|
)
|
|
|
|
|
|
def _apply_patch() -> None:
    """Patch torch.cuda memory APIs for correct unified-memory reporting.

    Order matters:
      1. Cheap skip checks first (re-entry guard, tiny-cgroup containers).
      2. Bail out before importing torch if no AMD GPU is visible in sysfs.
      3. Import torch with the re-entry guard env var set, clearing it in
         a ``finally`` so children always see it unset.
      4. Patch only when PyTorch's reported total differs by >= 10% from
         the computed effective VRAM.
    """
    if _should_skip():
        return

    # No AMD GPU visible in sysfs — nothing to do, and skipping here
    # avoids the expensive torch import entirely.
    if _get_sysfs_vram() is None:
        return

    # Set guard BEFORE importing torch — torch init spawns offload-arch
    # (a Python script) which would re-enter this .pth hook without it.
    os.environ[_GUARD_ENV] = "1"

    try:
        import torch
        if not hasattr(torch, "cuda") or not torch.cuda.is_available():
            return
    except (ImportError, OSError):
        # torch missing or its native libs failed to load — nothing to patch.
        return
    finally:
        # CRITICAL: Clear the guard so child processes (vLLM EngineCore
        # subprocesses, Ray actor workers, etc.) can apply their own patch.
        # The guard only needs to live during the torch import above to
        # prevent the offload-arch → .pth → torch import recursion.
        os.environ.pop(_GUARD_ENV, None)

    vram_info = _get_effective_vram()
    if vram_info is None:
        return

    effective_total, vram_used = vram_info

    # Only patch if PyTorch total differs significantly from effective VRAM
    try:
        pt_free, pt_total = torch.cuda.mem_get_info(0)
    except Exception:
        return

    # If they're within 10% of each other, no patch needed
    # (max(pt_total, 1) guards against division by zero).
    if abs(pt_total - effective_total) / max(pt_total, 1) < 0.10:
        return

    # --- Patch 1: torch.cuda.mem_get_info ---
    original_mem_get_info = torch.cuda.mem_get_info

    def _patched_mem_get_info(device=None):
        """Return effective VRAM instead of raw PyTorch numbers."""
        # Re-read sysfs on every call so `used` tracks live allocations.
        info = _get_effective_vram()
        if info is None:
            return original_mem_get_info(device)
        total, used = info
        # Account for PyTorch's own allocations on top of sysfs baseline.
        # NOTE(review): on some kernels sysfs mem_info_vram_used may already
        # include PyTorch's allocations — confirm this does not double-count.
        pt_allocated = torch.cuda.memory_allocated(device or 0)
        free = total - used - pt_allocated
        return (max(free, 0), total)

    torch.cuda.mem_get_info = _patched_mem_get_info

    # --- Patch 2: torch.cuda.get_device_properties ---
    # total_memory is a read-only C property, so we wrap the return value
    # in a proxy that overrides it with the effective VRAM total.
    original_get_device_properties = torch.cuda.get_device_properties

    def _patched_get_device_properties(device=None):
        props = original_get_device_properties(device)
        info = _get_effective_vram()
        if info is None:
            return props
        return _VRAMDeviceProperties(props, info[0])

    torch.cuda.get_device_properties = _patched_get_device_properties

    logger.info(
        "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
        "get_device_properties (PyTorch reported %d GiB total, "
        "effective VRAM is %d GiB)",
        pt_total // (1024**3),
        effective_total // (1024**3),
    )
|
|
|
|
|
|
# Runs at import time: this module is loaded via a .pth hook before user code.
_apply_patch()
|