fixing vllm.
Some checks failed
Build and Push Images / determine-version (push) Successful in 5s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 39s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 42s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 20s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 23s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s

This commit is contained in:
2026-02-08 16:53:16 -05:00
parent 9042460736
commit f297deca9d


@@ -1,22 +1,28 @@
"""
Monkey-patch torch.cuda.mem_get_info for AMD APUs with unified memory.
Monkey-patch torch.cuda memory reporting for AMD APUs with unified memory.
On Strix Halo (and other AMD APUs), torch.cuda.mem_get_info() returns
GTT (system RAM accessible to GPU) numbers instead of actual VRAM:
- Total: 128 GiB (GTT pool, NOT real VRAM)
- Free: 29 GiB (GTT free, NOT VRAM free)
On Strix Halo (and other AMD APUs), PyTorch reports GTT (system RAM
accessible to GPU) instead of actual VRAM:
torch.cuda.mem_get_info() → (29 GiB free, 128 GiB total) WRONG
torch.cuda.get_device_properties().total_memory → 128 GiB WRONG
Meanwhile sysfs reports the real VRAM:
- Total: 96 GiB (actual dedicated VRAM from BIOS)
- Free: 95.8 GiB (actual VRAM free)
/sys/class/drm/cardN/device/mem_info_vram_total → 96 GiB
/sys/class/drm/cardN/device/mem_info_vram_used → ~0.2 GiB
vLLM uses torch.cuda.mem_get_info() to calculate how much memory to
pre-allocate. With wrong numbers it either OOMs or refuses to start.
vLLM uses mem_get_info() and get_device_properties() to decide how much
memory to pre-allocate. With wrong numbers it either OOMs or refuses to
start ("Free memory less than desired GPU memory utilization").
This module patches mem_get_info to return sysfs VRAM values, which
are the physically meaningful numbers for allocation decisions.
This module patches both APIs to return sysfs VRAM values instead.
Installed as a .pth hook so it runs before any user code.
IMPORTANT: The re-entry guard (_STRIXHALO_VRAM_FIX_ACTIVE env var) is
set only during torch import to prevent infinite recursion when torch
spawns offload-arch. It is CLEARED afterward so child processes (e.g.
vLLM EngineCore subprocesses) can apply their own patch.
"""
import os
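The sysfs paths named in the docstring are read by a helper, `_get_real_vram()`, whose body falls outside the hunks shown. A minimal sketch of what such a reader might look like, assuming the standard amdgpu sysfs layout (the name `read_sysfs_vram` and the `card0` default are illustrative, not the module's actual code):

```python
from pathlib import Path


def read_sysfs_vram(card: str = "card0"):
    """Sketch: read real VRAM figures from amdgpu sysfs, in bytes.

    Returns (total, used), or None when the files are missing or
    unreadable (non-AMD GPU, no DRM device, container without /sys).
    """
    base = Path("/sys/class/drm") / card / "device"
    try:
        total = int((base / "mem_info_vram_total").read_text())
        used = int((base / "mem_info_vram_used").read_text())
    except (OSError, ValueError):
        return None
    return total, used
```

Returning `None` on any failure matters: every caller in the patched functions below falls back to the original torch behavior when sysfs is unavailable.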
@@ -63,6 +69,8 @@ def _should_skip() -> bool:
     # Re-entry guard: importing torch triggers subprocess calls to
     # offload-arch (a Python script), which re-enters this .pth hook.
     # Without this guard it creates an infinite fork bomb.
+    # NOTE: This is only set transiently during _apply_patch() and
+    # cleared afterward — child processes will NOT see it.
     if os.environ.get(_GUARD_ENV):
         return True
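The guard's lifecycle can be shown in isolation. This is a sketch of the set/run/clear-in-`finally` pattern the comments describe, not the module's actual code; `guarded_once` is a hypothetical wrapper:

```python
import os

# Environment-variable name taken from the module's docstring.
_GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"


def guarded_once(work):
    """Run `work` unless a parent invocation is already in flight.

    A re-entrant call (e.g. the offload-arch subprocess re-triggering
    the .pth hook) sees the env var and bails out; the finally-clause
    clears it so *later* subprocesses inherit a clean environment and
    can apply their own patch.
    """
    if os.environ.get(_GUARD_ENV):
        return "skipped"  # re-entered while guard was active
    os.environ[_GUARD_ENV] = "1"
    try:
        return work()
    finally:
        os.environ.pop(_GUARD_ENV, None)  # always cleared on exit
```

An env var (rather than a module-level flag) is the natural choice here because the recursion crosses process boundaries: the child inherits the parent's environment, not its Python state.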
@@ -83,8 +91,31 @@ def _should_skip() -> bool:
     return False


+class _VRAMDeviceProperties:
+    """Proxy that overrides total_memory on torch device properties."""
+
+    def __init__(self, original, vram_total: int):
+        object.__setattr__(self, "_original", original)
+        object.__setattr__(self, "_vram_total", vram_total)
+
+    @property
+    def total_memory(self) -> int:
+        return object.__getattribute__(self, "_vram_total")
+
+    def __getattr__(self, name: str):
+        return getattr(object.__getattribute__(self, "_original"), name)
+
+    def __repr__(self) -> str:
+        orig = object.__getattribute__(self, "_original")
+        vram = object.__getattribute__(self, "_vram_total")
+        return repr(orig).replace(
+            f"total_memory={orig.total_memory}",
+            f"total_memory={vram}",
+        )
+
+
 def _apply_patch() -> None:
-    """Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
+    """Patch torch.cuda memory APIs if we detect unified memory mis-reporting."""
     if _should_skip():
         return
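The delegation trick in `_VRAMDeviceProperties` relies on `__getattr__` only firing when normal attribute lookup fails: `total_memory` is found as a class-level property and never reaches the wrapped object, while every other attribute falls through. A self-contained restatement of the pattern (`TotalMemoryProxy` is an illustrative name, and a `SimpleNamespace` stands in for torch's read-only C properties object):

```python
from types import SimpleNamespace


class TotalMemoryProxy:
    """Override one read-only attribute; delegate everything else."""

    def __init__(self, original, vram_total):
        object.__setattr__(self, "_original", original)
        object.__setattr__(self, "_vram_total", vram_total)

    @property
    def total_memory(self):
        # Found on the class, so __getattr__ is never consulted.
        return object.__getattribute__(self, "_vram_total")

    def __getattr__(self, name):
        # Only reached when normal lookup fails: delegate to the original.
        return getattr(object.__getattribute__(self, "_original"), name)


# Stand-in for torch's _CudaDeviceProperties (whose fields are read-only).
fake = SimpleNamespace(
    name="AMD Radeon Graphics",
    total_memory=128 * 1024**3,   # the GTT mis-report
    multi_processor_count=20,
)
proxy = TotalMemoryProxy(fake, 96 * 1024**3)  # real VRAM total
```

Wrapping is needed precisely because `total_memory` on the real properties object is a read-only C attribute: it cannot be assigned in place, so the patched `get_device_properties` returns a proxy instead.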
@@ -101,13 +132,18 @@ def _apply_patch() -> None:
             return
     except ImportError:
         return
+    finally:
+        # CRITICAL: Clear the guard so child processes (vLLM EngineCore
+        # subprocesses, Ray actor workers, etc.) can apply their own patch.
+        # The guard only needs to live during the torch import above to
+        # prevent the offload-arch → .pth → torch import recursion.
+        os.environ.pop(_GUARD_ENV, None)

     vram_info = _get_real_vram()
     if vram_info is None:
         return
     vram_total, vram_used = vram_info
     vram_free = vram_total - vram_used

     # Only patch if PyTorch total differs significantly from sysfs VRAM
     # (i.e. PyTorch is reporting GTT/unified memory, not real VRAM)
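The "differs significantly" check can be read as a pure function over the two totals. A sketch (the name `looks_like_gtt` is hypothetical; it returns whether to patch, i.e. the inverse of the code's early-return condition):

```python
def looks_like_gtt(pt_total: int, vram_total: int,
                   threshold: float = 0.10) -> bool:
    """Patch only when PyTorch's reported total deviates from the sysfs
    VRAM total by at least `threshold`, relative to PyTorch's figure.
    max(pt_total, 1) guards against a zero denominator."""
    return abs(pt_total - vram_total) / max(pt_total, 1) >= threshold


GiB = 1024**3
```

On the Strix Halo numbers from the docstring, the relative difference is |128 − 96| / 128 = 0.25, well past the 10% threshold, so the patch applies; on a discrete GPU where both sources agree, it stays out of the way.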
@@ -120,13 +156,14 @@ def _apply_patch() -> None:
     if abs(pt_total - vram_total) / max(pt_total, 1) < 0.10:
         return

-    original_fn = torch.cuda.mem_get_info
+    # --- Patch 1: torch.cuda.mem_get_info ---
+    original_mem_get_info = torch.cuda.mem_get_info

     def _patched_mem_get_info(device=None):
         """Return real VRAM from sysfs instead of GTT numbers."""
         real = _get_real_vram()
         if real is None:
-            return original_fn(device)
+            return original_mem_get_info(device)
         total, used = real
         # Account for PyTorch's own allocations on top of sysfs baseline
         pt_allocated = torch.cuda.memory_allocated(device or 0)
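The line that actually computes `free` falls between the hunks shown, so the exact formula below is an assumption inferred from the surrounding comments and the `return (max(free, 0), total)` that follows: sysfs free minus PyTorch's own live allocations, clamped at zero.

```python
def effective_free(vram_total: int, vram_used: int,
                   pt_allocated: int) -> tuple:
    """ASSUMED formula (the real line is not in the diff shown):
    free = sysfs total - sysfs used - torch.cuda.memory_allocated(),
    clamped so a transient over-count never yields a negative value."""
    free = vram_total - vram_used - pt_allocated
    return (max(free, 0), vram_total)


GiB = 1024**3
```

The clamp mirrors the diff's `return (max(free, 0), total)`: `mem_get_info` is documented to return non-negative `(free, total)` byte counts, and callers like vLLM divide and compare these values directly.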
@@ -134,9 +171,25 @@ def _apply_patch() -> None:
         return (max(free, 0), total)

     torch.cuda.mem_get_info = _patched_mem_get_info

+    # --- Patch 2: torch.cuda.get_device_properties ---
+    # total_memory is a read-only C property, so we wrap the return value
+    # in a proxy that overrides it with the real VRAM total.
+    original_get_device_properties = torch.cuda.get_device_properties
+
+    def _patched_get_device_properties(device=None):
+        props = original_get_device_properties(device)
+        real = _get_real_vram()
+        if real is None:
+            return props
+        return _VRAMDeviceProperties(props, real[0])
+
+    torch.cuda.get_device_properties = _patched_get_device_properties
+
     logger.info(
-        "strixhalo_vram_fix: patched torch.cuda.mem_get_info "
-        "(PyTorch reported %d GiB total, sysfs VRAM is %d GiB)",
+        "strixhalo_vram_fix: patched torch.cuda.mem_get_info and "
+        "get_device_properties (PyTorch reported %d GiB total, "
+        "sysfs VRAM is %d GiB)",
         pt_total // (1024**3),
         vram_total // (1024**3),
     )