diff --git a/amdsmi-shim/strixhalo_vram_fix.py b/amdsmi-shim/strixhalo_vram_fix.py
index e296d7c..daf7bb0 100644
--- a/amdsmi-shim/strixhalo_vram_fix.py
+++ b/amdsmi-shim/strixhalo_vram_fix.py
@@ -55,8 +55,17 @@ def _get_real_vram() -> tuple[int, int] | None:
     return None
 
 
+_GUARD_ENV = "_STRIXHALO_VRAM_FIX_ACTIVE"
+
+
 def _should_skip() -> bool:
-    """Check if we should skip the patch (lightweight/init containers)."""
+    """Check if we should skip the patch (re-entry guard, init containers)."""
+    # Re-entry guard: importing torch triggers subprocess calls to
+    # offload-arch (a Python script), which re-enters this .pth hook.
+    # Without this guard it creates an infinite fork bomb.
+    if os.environ.get(_GUARD_ENV):
+        return True
+
     # Check cgroup memory limit — if under 512Mi, skip the expensive
     # torch/ROCm import. KubeRay's wait-gcs-ready init container has
     # only 256Mi and importing torch+ROCm would OOMKill it.
@@ -82,6 +91,10 @@ def _apply_patch() -> None:
     if _get_real_vram() is None:
         return
 
+    # Set guard BEFORE importing torch — torch init spawns offload-arch
+    # (a Python script) which would re-enter this .pth hook without it.
+    os.environ[_GUARD_ENV] = "1"
+
     try:
         import torch
         if not hasattr(torch, "cuda") or not torch.cuda.is_available():