Files
kuberay-images/amdsmi-shim/strixhalo_vram_fix.py
Billy D. 9042460736
Some checks failed
Build and Push Images / determine-version (push) Successful in 5s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 25s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 27s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 22s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 25s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 2s
fix(strixhalo): add re-entry guard to prevent offload-arch fork bomb
torch init calls offload-arch (a Python script) which re-enters the
.pth hook, triggering another import torch, creating an infinite fork
storm (1000+ processes). Set _STRIXHALO_VRAM_FIX_ACTIVE env var before
importing torch so child processes skip the patch.
2026-02-07 08:47:06 -05:00

146 lines
4.6 KiB
Python

"""
Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory.
On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns
GTT (system RAM accessible to GPU) numbers instead of actual VRAM:
- Total: 128 GiB (GTT pool, NOT real VRAM)
- Free: 29 GiB (GTT free, NOT VRAM free)
Meanwhile sysfs reports the real VRAM:
- Total: 96 GiB (actual dedicated VRAM from BIOS)
- Free: 95.8 GiB (actual VRAM free)
vLLM uses torch.cuda.mem_get_info() to calculate how much memory to
pre-allocate. With wrong numbers it either OOMs or refuses to start.
This module patches mem_get_info to return sysfs VRAM values, which
are the physically meaningful numbers for allocation decisions.
Installed as a .pth hook so it runs before any user code.
"""
import os
import glob
import logging
# Module-level logger; handlers/level are whatever the host process configures.
logger = logging.getLogger("strixhalo_vram_fix")
def _read_sysfs_int(path: str) -> int | None:
try:
with open(path) as f:
return int(f.read().strip())
except (OSError, ValueError):
return None
def _get_real_vram() -> tuple[int, int] | None:
    """Read real VRAM total/used from sysfs for the first AMD GPU."""
    for device_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
        # EAFP: a missing or unreadable vendor file just means "not this card".
        try:
            with open(os.path.join(device_dir, "vendor")) as fh:
                vendor_id = fh.read().strip()
        except OSError:
            continue
        if vendor_id != "0x1002":  # AMD's PCI vendor ID
            continue
        total = _read_sysfs_int(os.path.join(device_dir, "mem_info_vram_total"))
        used = _read_sysfs_int(os.path.join(device_dir, "mem_info_vram_used"))
        if total is not None and used is not None:
            return (total, used)
    return None
_GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"
def _should_skip() -> bool:
"""Check if we should skip the patch (re-entry guard, init containers)."""
# Re-entry guard: importing torch triggers subprocess calls to
# offload-arch (a Python script), which re-enters this .pth hook.
# Without this guard it creates an infinite fork bomb.
if os.environ.get(_GUARD_ENV):
return True
# Check cgroup memory limit — if under 512Mi, skip the expensive
# torch/ROCm import. KubeRay's wait-gcs-ready init container has
# only 256Mi and importing torch+ROCm would OOMKill it.
for cgroup_mem_path in (
"/sys/fs/cgroup/memory.max", # cgroup v2
"/sys/fs/cgroup/memory/memory.limit_in_bytes", # cgroup v1
):
try:
with open(cgroup_mem_path) as f:
val = f.read().strip()
if val != "max" and int(val) < 512 * 1024 * 1024:
return True
except (OSError, ValueError):
continue
return False
def _apply_patch() -> None:
    """Patch torch.cuda.mem_get_info if we detect unified-memory mis-reporting.

    Runs at interpreter startup via a .pth hook, so every exit path must be
    quiet and non-fatal.  Skips cheaply (before the expensive torch import)
    when the re-entry guard is set, when the cgroup memory limit is too
    small, or when sysfs shows no AMD GPU with VRAM counters.
    """
    if _should_skip():
        return
    if _get_real_vram() is None:
        return
    # Set guard BEFORE importing torch — torch init spawns offload-arch
    # (a Python script) which would re-enter this .pth hook without it.
    os.environ[_GUARD_ENV] = "1"
    try:
        import torch
        if not hasattr(torch, "cuda") or not torch.cuda.is_available():
            return
    except ImportError:
        return
    except Exception:
        # A .pth hook must never break interpreter startup: a broken ROCm
        # install can make torch.cuda.is_available() raise RuntimeError.
        logger.exception("strixhalo_vram_fix: torch probe failed; not patching")
        return
    # Re-read sysfs: importing torch/ROCm itself consumes some VRAM.
    vram_info = _get_real_vram()
    if vram_info is None:
        return
    vram_total, _vram_used = vram_info
    # Only patch if PyTorch's total differs significantly from sysfs VRAM
    # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM).
    try:
        _, pt_total = torch.cuda.mem_get_info(0)
    except Exception:
        return
    # If they're within 10% of each other, no patch needed.
    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
        return
    original_fn = torch.cuda.mem_get_info

    def _patched_mem_get_info(device=None):
        """Return real VRAM from sysfs instead of GTT numbers."""
        real = _get_real_vram()
        if real is None:
            # sysfs went away — fall back to the unpatched implementation.
            return original_fn(device)
        total, used = real
        # Account for PyTorch's own allocations on top of the sysfs baseline.
        pt_allocated = torch.cuda.memory_allocated(device or 0)
        free = total - used - pt_allocated
        return (max(free, 0), total)

    torch.cuda.mem_get_info = _patched_mem_get_info
    logger.info(
        "strixhalo_vram_fix: patched torch.cuda.mem_get_info "
        "(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
        pt_total // (1024**3),
        vram_total // (1024**3),
    )
_apply_patch()