""" Pure-Python amdsmi shim using sysfs for ROCm platform detection. Replaces the native amdsmi package (which requires libamd_smi.so linked against glibc 2.38) with a lightweight implementation that reads GPU information from sysfs (/sys/class/drm/*). This is needed because the Ray base image (Ubuntu 22.04, glibc 2.35) cannot load the ROCm 7.1 vendor libamd_smi.so (built for Ubuntu 24.04, glibc 2.38). Provides the subset of the amdsmi API used by: - vLLM (platform detection, device info, topology) - PyTorch (device counting, memory/power/temp monitoring) """ from __future__ import annotations import glob import logging import os from enum import IntEnum from pathlib import Path from typing import Any logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Exception # --------------------------------------------------------------------------- class AmdSmiException(Exception): """Shim for amdsmi.AmdSmiException.""" def __init__(self, message: str = "", err_code: int = -1): super().__init__(message) self.err_code = err_code # --------------------------------------------------------------------------- # Enums used by PyTorch monitoring helpers # --------------------------------------------------------------------------- class AmdSmiTemperatureType(IntEnum): EDGE = 0 JUNCTION = 1 VRAM = 2 class AmdSmiTemperatureMetric(IntEnum): CURRENT = 0 MAX = 1 MIN = 2 class AmdSmiClkType(IntEnum): GFX = 0 MEM = 1 # --------------------------------------------------------------------------- # sysfs helpers # --------------------------------------------------------------------------- _DRM_BASE = "/sys/class/drm" def _read_sysfs(path: str, default: str = "") -> str: """Read a sysfs file, returning *default* on any error.""" try: return Path(path).read_text().strip() except (OSError, IOError): return default def _read_sysfs_int(path: str, default: int = 0) -> int: val = _read_sysfs(path) try: if val.startswith("0x"): return int(val, 16) return int(val) except (ValueError, TypeError): return default def _discover_cards() -> list[str]: """Return sorted list of card directory names, e.g. ['card0', 'card1'].""" cards = [] for entry in sorted(glob.glob(os.path.join(_DRM_BASE, "card[0-9]*"))): name = os.path.basename(entry) # Exclude render nodes and connector entries like card0-DP-1 if name.startswith("card") and "-" not in name: vendor = _read_sysfs(os.path.join(entry, "device", "vendor")) # 0x1002 = AMD if vendor == "0x1002": cards.append(name) return cards def _find_hwmon(card: str) -> str | None: """Return the first hwmon directory path for a card, or None.""" hwmon_base = os.path.join(_DRM_BASE, card, "device", "hwmon") entries = sorted(glob.glob(os.path.join(hwmon_base, "hwmon*"))) return entries[0] if entries else None # --------------------------------------------------------------------------- # Handle management # --------------------------------------------------------------------------- # Each "handle" is just a dict with cached paths _handles: list[dict[str, Any]] = [] _initialized = False def amdsmi_init() -> None: """Initialise the shim by discovering AMD GPUs via sysfs.""" global _handles, _initialized if _initialized: return cards = _discover_cards() _handles = [] for card in cards: dev = os.path.join(_DRM_BASE, card, "device") _handles.append( { "card": card, "dev": dev, "hwmon": _find_hwmon(card), "vendor": _read_sysfs(os.path.join(dev, "vendor")), "device_id": _read_sysfs(os.path.join(dev, "device")), } ) _initialized = True logger.debug("amdsmi-shim: discovered %d AMD GPU(s) via sysfs", len(_handles)) def amdsmi_shut_down() -> None: """Reset internal state.""" global _handles, _initialized _handles = [] _initialized = False def amdsmi_get_processor_handles() -> list[dict[str, Any]]: """Return a list of opaque handles (one per GPU).""" if not _initialized: raise AmdSmiException("amdsmi not initialised", err_code=1) return list(_handles) # --------------------------------------------------------------------------- # Device info (used by vLLM get_device_name + PyTorch UUID helpers) # --------------------------------------------------------------------------- def amdsmi_get_gpu_asic_info(handle: dict[str, Any]) -> dict[str, str]: """ Return ASIC information for a GPU handle. Keys returned: device_id, market_name, asic_serial, vendor_id, rev_id """ dev = handle["dev"] device_id_hex = _read_sysfs(os.path.join(dev, "device"), "0x0000") # Read marketing / product name (may not exist on all kernels) market = ( _read_sysfs(os.path.join(dev, "product_name")) or _read_sysfs(os.path.join(dev, "marketing_name")) or f"AMD GPU {device_id_hex}" ) # Unique ID for serial (not always available) serial = _read_sysfs(os.path.join(dev, "unique_id"), "0000000000000000") return { "device_id": device_id_hex, "market_name": market, "asic_serial": serial, "vendor_id": _read_sysfs(os.path.join(dev, "vendor"), "0x1002"), "rev_id": _read_sysfs(os.path.join(dev, "revision"), "0x00"), } # --------------------------------------------------------------------------- # Topology (used by vLLM is_fully_connected) # --------------------------------------------------------------------------- def amdsmi_topo_get_link_type( handle: dict[str, Any], peer_handle: dict[str, Any] ) -> dict[str, int]: """ Return link topology between two GPUs. For single-GPU systems this is essentially a no-op. type: 2 = XGMI, 1 = PCIe; hops = number of hops. """ if handle is peer_handle or handle["card"] == peer_handle["card"]: return {"type": 2, "hops": 0} # same device # Default: not XGMI-connected return {"type": 1, "hops": 2} # --------------------------------------------------------------------------- # Memory (used by PyTorch _get_amdsmi_device_memory_used) # --------------------------------------------------------------------------- def amdsmi_get_gpu_vram_usage(handle: dict[str, Any]) -> dict[str, int]: """ Return VRAM usage in megabytes. Keys: vram_used, vram_total """ dev = handle["dev"] total_bytes = _read_sysfs_int(os.path.join(dev, "mem_info_vram_total")) used_bytes = _read_sysfs_int(os.path.join(dev, "mem_info_vram_used")) return { "vram_used": used_bytes // (1024 * 1024), "vram_total": total_bytes // (1024 * 1024), } # --------------------------------------------------------------------------- # Activity (used by PyTorch _get_amdsmi_utilization / memory_usage) # --------------------------------------------------------------------------- def amdsmi_get_gpu_activity(handle: dict[str, Any]) -> dict[str, int]: """ Return GPU and memory controller activity percentages. Keys: gfx_activity, umc_activity, mm_activity """ dev = handle["dev"] gfx = _read_sysfs_int(os.path.join(dev, "gpu_busy_percent")) # mem_busy_percent is not always available umc = _read_sysfs_int(os.path.join(dev, "mem_busy_percent")) return {"gfx_activity": gfx, "umc_activity": umc, "mm_activity": 0} # --------------------------------------------------------------------------- # Power (used by PyTorch _get_amdsmi_power_draw) # --------------------------------------------------------------------------- def amdsmi_get_power_info(handle: dict[str, Any]) -> dict[str, int]: """ Return power info in watts. Keys: average_socket_power, current_socket_power """ hwmon = handle.get("hwmon") if not hwmon: return {"average_socket_power": 0, "current_socket_power": 0} # sysfs reports in microwatts avg_uw = _read_sysfs_int(os.path.join(hwmon, "power1_average")) cur_uw = _read_sysfs_int(os.path.join(hwmon, "power1_input"), avg_uw) return { "average_socket_power": avg_uw // 1_000_000, "current_socket_power": cur_uw // 1_000_000, } # --------------------------------------------------------------------------- # Temperature (used by PyTorch _get_amdsmi_temperature) # --------------------------------------------------------------------------- def amdsmi_get_temp_metric( handle: dict[str, Any], sensor_type: AmdSmiTemperatureType = AmdSmiTemperatureType.JUNCTION, metric: AmdSmiTemperatureMetric = AmdSmiTemperatureMetric.CURRENT, ) -> int: """ Return temperature in millidegrees Celsius. PyTorch uses JUNCTION / CURRENT and the result is divided by 1000. """ hwmon = handle.get("hwmon") if not hwmon: return 0 # Map sensor type to sysfs file: temp1=edge, temp2=junction, temp3=vram idx = { AmdSmiTemperatureType.EDGE: 1, AmdSmiTemperatureType.JUNCTION: 2, AmdSmiTemperatureType.VRAM: 3, }.get(sensor_type, 1) suffix = { AmdSmiTemperatureMetric.CURRENT: "input", AmdSmiTemperatureMetric.MAX: "max", AmdSmiTemperatureMetric.MIN: "min", }.get(metric, "input") # sysfs already reports millidegrees return _read_sysfs_int(os.path.join(hwmon, f"temp{idx}_{suffix}")) # --------------------------------------------------------------------------- # Clock (used by PyTorch _get_amdsmi_clock_rate via vLLM reference) # --------------------------------------------------------------------------- def amdsmi_get_clock_info( handle: dict[str, Any], clk_type: AmdSmiClkType = AmdSmiClkType.GFX, ) -> dict[str, int]: """ Return clock info in MHz. Keys: clk (current), max_clk, min_clk """ dev = handle["dev"] sysfs_file = { AmdSmiClkType.GFX: "pp_dpm_sclk", AmdSmiClkType.MEM: "pp_dpm_mclk", }.get(clk_type, "pp_dpm_sclk") raw = _read_sysfs(os.path.join(dev, sysfs_file)) current_mhz = 0 min_mhz = 0 max_mhz = 0 for line in raw.splitlines(): # Lines look like: "0: 600Mhz *" parts = line.split() if len(parts) >= 2: try: freq = int(parts[1].replace("Mhz", "").replace("MHz", "")) except ValueError: continue if min_mhz == 0 or freq < min_mhz: min_mhz = freq if freq > max_mhz: max_mhz = freq if "*" in line: current_mhz = freq return {"clk": current_mhz, "max_clk": max_mhz, "min_clk": min_mhz}