From d1b6d78c66f2013707d4635e1457de871ce4ef96 Mon Sep 17 00:00:00 2001
From: "Billy D."
Date: Fri, 6 Feb 2026 19:15:49 -0500
Subject: [PATCH] fix(strixhalo): skip VRAM patch in low-memory init containers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

KubeRay's auto-injected wait-gcs-ready init container has only 256Mi
memory limit. The .pth hook was unconditionally importing torch+ROCm
which requires >256Mi, causing OOMKill. Now checks cgroup memory limit
first — if under 512Mi, skips the expensive torch import entirely. The
VRAM patch is only needed by the main Ray worker process, not by
health-check init containers.
---
 amdsmi-shim/strixhalo_vram_fix.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/amdsmi-shim/strixhalo_vram_fix.py b/amdsmi-shim/strixhalo_vram_fix.py
index 4002289..e296d7c 100644
--- a/amdsmi-shim/strixhalo_vram_fix.py
+++ b/amdsmi-shim/strixhalo_vram_fix.py
@@ -55,8 +55,33 @@ def _get_real_vram() -> tuple[int, int] | None:
     return None
 
 
+def _should_skip() -> bool:
+    """Check if we should skip the patch (lightweight/init containers)."""
+    # Check cgroup memory limit — if under 512Mi, skip the expensive
+    # torch/ROCm import. KubeRay's wait-gcs-ready init container has
+    # only 256Mi and importing torch+ROCm would OOMKill it.
+    for cgroup_mem_path in (
+        "/sys/fs/cgroup/memory.max",  # cgroup v2
+        "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroup v1
+    ):
+        try:
+            with open(cgroup_mem_path) as f:
+                val = f.read().strip()
+            if val != "max" and int(val) < 512 * 1024 * 1024:
+                return True
+        except (OSError, ValueError):
+            continue
+    return False
+
+
 def _apply_patch() -> None:
     """Patch torch.cuda.mem_get_info if we detect unified memory mis-reporting."""
+    if _should_skip():
+        return
+
+    if _get_real_vram() is None:
+        return
+
     try:
         import torch
         if not hasattr(torch, "cuda") or not torch.cuda.is_available():