Files
kuberay-images/amdsmi-shim/strixhalo_vram_fix.py
Billy D. a20a5d2ccd
Some checks failed
Build and Push Images / determine-version (push) Successful in 6s
Build and Push Images / Release (push) Has been cancelled
Build and Push Images / Notify (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Has been cancelled
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 49s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 1m25s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Has been cancelled
mo fixes.
2026-02-09 11:46:10 -05:00

199 lines
6.8 KiB
Python

"""
Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
accessible to GPU) instead of actual VRAM:
torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total) WRONG
torch.cuda.get_device_properties().total_memory → 128 GiB WRONG
Meanwhile sysfs reports the real VRAM:
/sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
/sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
vLLM uses mem_get_info() and get_device_properties() to decide how much
memory to pre-allocate. With wrong numbers it either OOMs or refuses to
start ("Free memory less than desired GPU memory utilization").
This module patches both APIs to return sysfs VRAM values instead.
Installed as a .pth hook so it runs before any user code.
IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is
set only during torch import to prevent infinite recursion when torch
spawns offload-arch. It is CLEARED afterward so child processes (e.g.
vLLM EngineCore subprocesses) can apply their own patch.
"""
import os
import glob
import logging
# Module-level logger, registered under the fixed name "strixhalo_vram_fix".
logger = logging.getLogger("strixhalo_vram_fix")
def _read_sysfs_int(path: str) -> int | None:
try:
with open(path) as f:
return int(f.read().strip())
except (OSError, ValueError):
return None
def _get_real_vram() -> tuple[int, int] | None:
    """Look up VRAM total/used in sysfs for the first usable AMD card.

    Candidate directories matching ``/sys/class/drm/card[0-9]*/device`` are
    visited in sorted (lexicographic) path order. A candidate is used only
    if its ``vendor`` file reads ``0x1002`` and both ``mem_info_vram_total``
    and ``mem_info_vram_used`` parse as integers.

    Returns:
        ``(total, used)`` as read from the first matching card, or ``None``
        when no candidate qualifies.
    """
    for device_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
        # A missing or unreadable vendor file simply disqualifies this entry.
        try:
            with open(os.path.join(device_dir, "vendor")) as handle:
                vendor_id = handle.read().strip()
        except OSError:
            continue

        if vendor_id != "0x1002":  # AMD
            continue

        vram_total = _read_sysfs_int(
            os.path.join(device_dir, "mem_info_vram_total")
        )
        vram_used = _read_sysfs_int(
            os.path.join(device_dir, "mem_info_vram_used")
        )
        if vram_total is None or vram_used is None:
            # Keep scanning: a later card may expose both counters.
            continue
        return (vram_total, vram_used)

    return None
# Environment variable used as a re-entry guard while torch is being imported.
_GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"


def _should_skip() -> bool:
    """Decide whether the patch must not be attempted in this process.

    Returns:
        ``True`` when either
        * the re-entry guard variable is set to a non-empty value, or
        * a cgroup memory limit below 512 MiB is found;
        ``False`` otherwise.
    """
    # Re-entry guard: importing torch triggers subprocess calls to
    # offload-arch (a Python script), which re-enters this .pth hook.
    # Without the guard that turns into an infinite fork bomb.
    # NOTE: the variable is only set transiently during _apply_patch() and
    # cleared afterward, so child processes will NOT see it.
    if os.environ.get(_GUARD_ENV):
        return True

    # Below this cgroup limit the torch/ROCm import is too expensive:
    # KubeRay's wait-gcs-ready init container has only 256Mi and importing
    # torch+ROCm would OOMKill it.
    min_bytes = 512 * 1024 * 1024
    limit_files = (
        "/sys/fs/cgroup/memory.max",  # cgroup v2
        "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroup v1
    )

    for limit_file in limit_files:
        try:
            with open(limit_file) as handle:
                raw_limit = handle.read().strip()
        except (OSError, ValueError):
            continue

        # "max" means no limit is configured in this file.
        if raw_limit == "max":
            continue

        try:
            limit = int(raw_limit)
        except ValueError:
            continue

        if limit < min_bytes:
            return True

    return False
class _VRAMDeviceProperties:
    """Stand-in for a torch device-properties object with a corrected total.

    ``total_memory`` yields the VRAM total given at construction time; any
    attribute not found on the proxy itself is looked up on the wrapped
    object instead.
    """

    def __init__(self, original, vram_total: int):
        """Wrap *original*, reporting *vram_total* as its ``total_memory``."""
        store = object.__setattr__
        store(self, "_original", original)
        store(self, "_vram_total", vram_total)

    @property
    def total_memory(self) -> int:
        """The overriding VRAM total, not the wrapped object's value."""
        return object.__getattribute__(self, "_vram_total")

    def __getattr__(self, name: str):
        # Only invoked when normal lookup fails. Fetch the wrapped object via
        # object.__getattribute__ so a missing ``_original`` raises
        # AttributeError instead of re-entering this method.
        wrapped = object.__getattribute__(self, "_original")
        return getattr(wrapped, name)

    def __repr__(self) -> str:
        """Return the wrapped object's repr with ``total_memory`` substituted."""
        fetch = object.__getattribute__
        wrapped = fetch(self, "_original")
        reported = f"total_memory={wrapped.total_memory}"
        corrected = f"total_memory={fetch(self, '_vram_total')}"
        return repr(wrapped).replace(reported, corrected)
def _apply_patch() -> None:
    """Patch torch.cuda memory APIs if we detect unified memory mis-reporting.

    Steps, in order:

    1. Bail out if ``_should_skip()`` says so or sysfs exposes no AMD VRAM
       counters (both checks happen before the expensive torch import).
    2. Import torch with the re-entry guard environment variable set, and
       clear the guard again no matter how the import ends.
    3. Compare PyTorch's reported total for device 0 against the sysfs VRAM
       total; if they are within 10% of each other nothing is patched.
    4. Otherwise replace ``torch.cuda.mem_get_info`` and
       ``torch.cuda.get_device_properties`` with wrappers that report the
       sysfs values.

    Returns:
        None. The function only has side effects (monkey-patching
        ``torch.cuda`` and emitting one info-level log record).
    """
    if _should_skip():
        return
    if _get_real_vram() is None:
        return

    # Set guard BEFORE importing torch: torch init spawns offload-arch
    # (a Python script) which would re-enter this .pth hook without it.
    os.environ[_GUARD_ENV] = "1"
    try:
        import torch

        if not hasattr(torch, "cuda") or not torch.cuda.is_available():
            return
    except (ImportError, OSError):
        return
    finally:
        # CRITICAL: Clear the guard so child processes (vLLM EngineCore
        # subprocesses, Ray actor workers, etc.) can apply their own patch.
        # The guard only needs to live during the torch import above to
        # prevent the offload-arch -> .pth -> torch import recursion.
        os.environ.pop(_GUARD_ENV, None)

    vram_info = _get_real_vram()
    if vram_info is None:
        return
    vram_total = vram_info[0]

    # Only patch if PyTorch total differs significantly from sysfs VRAM
    # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM).
    try:
        _, pt_total = torch.cuda.mem_get_info(0)
    except Exception:
        # Best effort: if torch cannot report memory, leave it untouched.
        return

    # If they're within 10% of each other, no patch needed.
    if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
        return

    # --- Patch 1: torch.cuda.mem_get_info ---
    original_mem_get_info = torch.cuda.mem_get_info

    def _patched_mem_get_info(device=None):
        """Return real VRAM from sysfs instead of GTT numbers."""
        real = _get_real_vram()
        if real is None:
            return original_mem_get_info(device)
        total, used = real
        # Account for PyTorch's own allocations on top of sysfs baseline.
        # ``device`` is forwarded unchanged so that ``None`` keeps meaning
        # "the current device", matching the function being replaced,
        # rather than being coerced to device 0.
        pt_allocated = torch.cuda.memory_allocated(device)
        free = total - used - pt_allocated
        return (max(free, 0), total)

    torch.cuda.mem_get_info = _patched_mem_get_info

    # --- Patch 2: torch.cuda.get_device_properties ---
    # total_memory is a read-only C property, so we wrap the return value
    # in a proxy that overrides it with the real VRAM total.
    original_get_device_properties = torch.cuda.get_device_properties

    def _patched_get_device_properties(device=None):
        """Return device properties with ``total_memory`` taken from sysfs."""
        props = original_get_device_properties(device)
        real = _get_real_vram()
        if real is None:
            return props
        return _VRAMDeviceProperties(props, real[0])

    torch.cuda.get_device_properties = _patched_get_device_properties

    logger.info(
        "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
        "get_device_properties (PyTorch reported %d GiB total, "
        "sysfs VRAM is %d GiB)",
        pt_total // (1024**3),
        vram_total // (1024**3),
    )
# Runs at import time: per the module docstring this file is installed as a
# .pth hook, so importing the module is what applies the patch.
_apply_patch()