feat(strixhalo): patch torch.cuda.mem_get_info for unified memory APU
Some checks failed
Build and Push Images / determine-version (push) Successful in 4s
Build and Push Images / build (Dockerfile.ray-worker-nvidia, nvidia) (push) Failing after 25s
Build and Push Images / build (Dockerfile.ray-worker-intel, intel) (push) Failing after 28s
Build and Push Images / build (Dockerfile.ray-worker-strixhalo, strixhalo) (push) Failing after 23s
Build and Push Images / build (Dockerfile.ray-worker-rdna2, rdna2) (push) Failing after 26s
Build and Push Images / Release (push) Has been skipped
Build and Push Images / Notify (push) Successful in 1s
On Strix Halo, PyTorch reports GTT pool (128 GiB) as device memory instead of real VRAM (96 GiB from BIOS). vLLM uses mem_get_info() to pre-allocate and refuses to start when free GTT (29 GiB) < requested. The strixhalo_vram_fix.pth hook auto-patches mem_get_info on Python startup to read real VRAM total/used from /sys/class/drm sysfs. Only activates when PyTorch total differs >10% from sysfs VRAM.
@@ -93,6 +93,16 @@ COPY --chown=1000:100 amdsmi-shim /tmp/amdsmi-shim
RUN --mount=type=cache,target=/home/ray/.cache/uv,uid=1000,gid=1000 \
    uv pip install --system /tmp/amdsmi-shim && rm -rf /tmp/amdsmi-shim

# FIX: Patch torch.cuda.mem_get_info for unified memory APUs.
# On Strix Halo, PyTorch reports GTT (128 GiB) instead of real VRAM (96 GiB)
# from sysfs. vLLM uses mem_get_info to pre-allocate, so wrong numbers cause
# OOM or "insufficient GPU memory" at startup. The .pth file auto-patches
# mem_get_info on Python startup to return sysfs VRAM values.
COPY --chown=1000:100 amdsmi-shim/strixhalo_vram_fix.py \
    /home/ray/anaconda3/lib/python3.11/site-packages/strixhalo_vram_fix.py
RUN echo "import strixhalo_vram_fix" > \
    /home/ray/anaconda3/lib/python3.11/site-packages/strixhalo_vram_fix.pth

# Pre-download common models for faster cold starts (optional, increases image size)
# RUN python3 -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('BAAI/bge-large-en-v1.5')"