feat(strixhalo): patch torch.cuda.mem_get_info for unified memory APU
Some checks failed
Build and Push Images / determine-version (push) Successful in 4s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 25s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 28s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 23s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 26s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
Some checks failed
Build and Push Images / determine-version (push) Successful in 4s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 25s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 28s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 23s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 26s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
On Strix Halo, PyTorch reports GTT pool (128 GiB) as device memory instead of real VRAM (96 GiB from BIOS). vLLM uses mem_get_info() to pre-allocate and refuses to start when free GTT (29 GiB) < requested. The strixhalo_vram_fix.pth hook auto-patches mem_get_info on Python startup to read real VRAM total/used from /sys/class/drm sysfs. Only activates when PyTorch total differs >10% from sysfs VRAM.
This commit is contained in:
107
amdsmi-shim/strixhalo_vram_fix.py
Normal file
107
amdsmi-shim/strixhalo_vram_fix.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory.
|
||||
|
||||
On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns
|
||||
GTT (system RAM accessible to GPU) numbers instead of actual VRAM:
|
||||
- Total: 128 GiB (GTT pool, NOT real VRAM)
|
||||
- Free: 29 GiB (GTT free, NOT VRAM free)
|
||||
|
||||
Meanwhile sysfs reports the real VRAM:
|
||||
- Total: 96 GiB (actual dedicated VRAM from BIOS)
|
||||
- Free: 95.8 GiB (actual VRAM free)
|
||||
|
||||
vLLM uses torch.cuda.mem_get_info() to calculate how much memory to
|
||||
pre-allocate. With wrong numbers it either OOMs or refuses to start.
|
||||
|
||||
This module patches mem_get_info to return sysfs VRAM values, which
|
||||
are the physically meaningful numbers for allocation decisions.
|
||||
|
||||
Installed as a .pth hook so it runs before any user code.
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("strixhalo_vram_fix")
|
||||
|
||||
|
||||
def _read_sysfs_int(path: str) -> int | None:
|
||||
try:
|
||||
with open(path) as f:
|
||||
return int(f.read().strip())
|
||||
except (OSError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _get_real_vram() -> tuple[int, int] | None:
    """Return (total_bytes, used_bytes) of real VRAM for the first AMD GPU.

    Scans /sys/class/drm for cards whose PCI vendor is AMD/ATI (0x1002)
    and reads the amdgpu mem_info_vram_* counters. Returns None when no
    AMD card with readable counters is found.
    """
    for device_dir in sorted(glob.glob("/sys/class/drm/card[0-9]*/device")):
        vendor_file = os.path.join(device_dir, "vendor")
        if not os.path.exists(vendor_file):
            continue
        try:
            with open(vendor_file) as fh:
                vendor_id = fh.read().strip()
        except OSError:
            continue
        if vendor_id != "0x1002":  # PCI vendor ID for AMD/ATI
            continue

        vram_total = _read_sysfs_int(os.path.join(device_dir, "mem_info_vram_total"))
        vram_used = _read_sysfs_int(os.path.join(device_dir, "mem_info_vram_used"))
        # Both counters must parse; otherwise keep looking at later cards.
        if vram_total is not None and vram_used is not None:
            return (vram_total, vram_used)
    return None
def _apply_patch() -> None:
|
||||
"""Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
|
||||
try:
|
||||
import torch
|
||||
if not hasattr(torch, "cuda") or not torch.cuda.is_available():
|
||||
return
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
vram_info = _get_real_vram()
|
||||
if vram_info is None:
|
||||
return
|
||||
|
||||
vram_total, vram_used = vram_info
|
||||
vram_free = vram_total - vram_used
|
||||
|
||||
# Only patch if PyTorch total differs significantly from sysfs VRAM
|
||||
# (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
|
||||
try:
|
||||
pt_free, pt_total = torch.cuda.mem_get_info(0)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# If they're within 10% of each other, no patch needed
|
||||
if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
|
||||
return
|
||||
|
||||
original_fn = torch.cuda.mem_get_info
|
||||
|
||||
def _patched_mem_get_info(device=None):
|
||||
"""Return real VRAM from sysfs instead of GTT numbers."""
|
||||
real = _get_real_vram()
|
||||
if real is None:
|
||||
return original_fn(device)
|
||||
total, used = real
|
||||
# Account for PyTorch's own allocations on top of sysfs baseline
|
||||
pt_allocated = torch.cuda.memory_allocated(device or 0)
|
||||
free = total - used - pt_allocated
|
||||
return (max(free, 0), total)
|
||||
|
||||
torch.cuda.mem_get_info = _patched_mem_get_info
|
||||
logger.info(
|
||||
"strixhalo_vram_fix: patched torch.cuda.mem_get_info "
|
||||
"(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
|
||||
pt_total // (1024**3),
|
||||
vram_total // (1024**3),
|
||||
)
|
||||
|
||||
|
||||
# Run at import time: this module is loaded via a .pth hook, so the patch
# (when applicable) is installed before any user code executes.
_apply_patch()
Reference in New Issue
Block a user