usb: multiple gpus and better error messages (#15900)

This commit is contained in:
Christopher Milan 2026-04-23 22:57:19 -07:00 committed by GitHub
commit cbf4946ea6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 76 additions and 47 deletions

View file

@ -202,6 +202,8 @@ class MockASM24State:
return None
class MockUSB3:
@classmethod
def list_devices(cls, vendor, dev): return [(0, "usb:mock")]
def __init__(self, *args, **kwargs):
self.product, self.is_custom = "", False
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:

View file

@ -8,7 +8,7 @@ from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filte
from tinygrad.uop.ops import sint
from tinygrad.device import Compiled, BufferSpec
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar, TracingKey
from tinygrad.helpers import VIZ, ceildiv, unwrap
from tinygrad.helpers import VIZ, ceildiv, unwrap, pluralize
from tinygrad.renderer.cstyle import HIPRenderer, HIPCCRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from tinygrad.runtime.autogen import kfd, hsa, sqtt, amdgpu_kd, amdgpu_drm
@ -17,6 +17,7 @@ from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.amd import AMDReg, AMDIP, import_module, import_soc, import_ip_offsets, import_pmc
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta, USBPCIDevice, MAP_FIXED, MAP_NORESERVE
from tinygrad.runtime.support.usb import USB3
from tinygrad.runtime.support.memory import AddrSpace
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
@ -644,7 +645,7 @@ class AMDProgram(HCQProgram):
class AMDAllocator(HCQAllocator['AMDDevice']):
def __init__(self, dev:AMDDevice):
super().__init__(dev, copy_bufs=getattr(dev.iface, 'copy_bufs', None), max_copyout_size=0x1000 if dev.is_usb() else None,
supports_copy_from_disk=dev.has_sdma_queue, supports_transfer=dev.has_sdma_queue)
supports_copy_from_disk=dev.has_sdma_queue, supports_transfer=dev.has_sdma_queue and not dev.is_usb())
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
  """Allocate `size` bytes through the device interface, honouring the flags in `options`.

  CPU access is requested whenever the spec asks for it or the device has no SDMA queue.
  """
  needs_cpu_access = options.cpu_access or not self.dev.has_sdma_queue
  return self.dev.iface.alloc(size, host=options.host, uncached=options.uncached, cpu_access=needs_cpu_access)
@ -912,10 +913,10 @@ class PCIIface(PCIIfaceBase):
def device_fini(self): self.dev_impl.fini()
class USBIface(PCIIface):
count = 1 # TODO: support multiple usbgpus, see usb.py
def __init__(self, dev, dev_id): # pylint: disable=super-init-not-called
self.dev, self.pci_dev, self.vram_bar = dev, USBPCIDevice(dev.__class__.__name__[:2], f"usb:{dev_id}"), 0
if dev_id >= len(visible:=hcq_filter_visible_devices(USB3.list_devices(0xADD1, 0x0001), "AMD")):
raise RuntimeError(f"AMD:{dev_id} does not exist ({pluralize('device', len(visible))} available)")
self.dev, self.pci_dev, self.vram_bar, self.count = dev, USBPCIDevice("AM", *visible[dev_id]), 0, len(visible)
self.dev_impl = AMDev(self.pci_dev)
self._compute_props()
self.pci_dev.usb._pci_cacheable += [self.pci_dev.bar_info(2)] # doorbell region is cacheable

View file

@ -126,6 +126,7 @@ class DLL(ctypes.CDLL):
def bind(self, restype, *argtypes):
def wrap(fn):
cfunc = None
@functools.wraps(fn)
def wrapper(*args):
nonlocal cfunc
if cfunc is None: (cfunc:=getattr(self, fn.__name__)).argtypes, cfunc.restype = argtypes, restype

View file

@ -214,12 +214,13 @@ class PCIDevice:
except OSError as e: raise RuntimeError(f"Cannot resize BAR {bar_idx}: {e}. Ensure the resizable BAR option is enabled.") from e
class USBPCIDevice(PCIDevice):
def __init__(self, devpref:str, pcibus:str):
def __init__(self, devpref:str, dev, pcibus):
self.pcibus, self.peer_group = pcibus, f"USBPCIDevice_{pcibus}"
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")
usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04)
if DEBUG >= 1: print(f"am usb: product string: {usb.product!r}")
usb = USB3(dev, 0x81, 0x83, 0x02, 0x04)
if DEBUG >= 1: print(f"am {self.pcibus}: product string: {usb.product!r}")
self.usb: CustomASM24Controller | ASM24Controller = CustomASM24Controller(usb) if usb.is_custom else ASM24Controller(usb)
self.pcibus, self._bar_info = pcibus, System.pci_setup_usb_bars(self.usb, gpu_bus=4, mem_base=0x10000000, pref_mem_base=(32 << 30))
self._bar_info = System.pci_setup_usb_bars(self.usb, gpu_bus=4, mem_base=0x10000000, pref_mem_base=(32 << 30))
self.sram = BumpAllocator(size=0x80000, wrap=False) # asm24 controller sram
def dma_view(self, ctrl_addr, size): return USBMMIOInterface(self.usb, ctrl_addr, size, fmt='B', pcimem=False)

View file

@ -1,59 +1,78 @@
import ctypes, struct, dataclasses, array, itertools, time
import ctypes, struct, dataclasses, array, itertools, time, functools
from typing import Sequence
from tinygrad.runtime.autogen import libusb
from tinygrad.helpers import DEBUG, DEV, to_mv, round_up, OSX, getenv, ceildiv
from tinygrad.runtime.support.hcq import MMIOInterface
from tinygrad.runtime.support import c
def alloc_cbuffer(sz:int) -> tuple[ctypes.Array, memoryview]:
  """Allocate a `sz`-byte ctypes array and return it together with a memoryview over the same memory.

  The array is returned alongside the view so the caller can keep the backing storage alive.
  """
  raw = (ctypes.c_ubyte * sz)()
  return raw, to_mv(ctypes.addressof(raw), sz)
def checked(fn, msg=None):
  """Wrap a libusb call so a negative return code raises instead of being returned.

  Args:
    fn: callable returning an integer status; values below zero are treated as errors.
    msg: optional prefix for the error text; defaults to `fn.__name__`.

  Returns:
    A wrapper that forwards its arguments to `fn` and returns the result unchanged when it is >= 0.

  Raises:
    RuntimeError: from the wrapper, when `fn` returns a negative code. The message is the prefix
      followed by the `libusb_strerror` text for that code.
  """
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    if (rc:=fn(*args, **kwargs)) < 0:
      raise RuntimeError(f"{msg or fn.__name__}: {ctypes.string_at(libusb.libusb_strerror(rc)).decode()}")
    return rc
  return wrapper
class USB3:
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
self.vendor, self.dev = vendor, dev
@staticmethod
@functools.cache
def ctx():
  """Return the libusb context, creating it on the first call.

  The result is memoized by functools.cache, so later calls return the same context object.
  """
  usb_ctx = c.init_c_var(ctypes.POINTER(libusb.struct_libusb_context), checked(libusb.libusb_init))
  # at DEBUG >= 6 also set libusb's LIBUSB_OPTION_LOG_LEVEL to 4
  if DEBUG >= 6: checked(libusb.libusb_set_option)(usb_ctx, libusb.LIBUSB_OPTION_LOG_LEVEL, 4)
  return usb_ctx
@classmethod
@functools.cache
def list_devices(cls, vendor:int, dev:int) -> list[tuple[c.POINTER[libusb.struct_libusb_device], str]]:
  """Enumerate USB devices whose descriptor matches `vendor`/`dev` (idVendor/idProduct).

  Returns a list of `(device, name)` tuples, where `device` is the result of `libusb_ref_device`
  on the matching entry and `name` is `"usb:<bus number>-<device address>"`. The result is
  memoized by functools.cache, so the bus is only scanned once per distinct argument set.

  Raises:
    RuntimeError: if listing devices or reading a descriptor fails (via `checked`).
  """
  devs = ctypes.POINTER(ctypes.POINTER(libusb.struct_libusb_device))()
  count = checked(libusb.libusb_get_device_list)(cls.ctx(), devs)
  ret = []
  try:
    for i in range(count):
      desc = c.init_c_var(libusb.struct_libusb_device_descriptor, lambda x: checked(libusb.libusb_get_device_descriptor)(devs[i], x))
      if (desc.idVendor, desc.idProduct) == (vendor, dev):
        # take our own reference: the list is freed below with unref_devices=1
        ret.append((libusb.libusb_ref_device(devs[i]), f"usb:{libusb.libusb_get_bus_number(devs[i])}-{libusb.libusb_get_device_address(devs[i])}"))
  finally:
    # always release the list, even when a descriptor read raised
    libusb.libusb_free_device_list(devs, 1)
  return ret
def __init__(self, dev:c.POINTER[libusb.struct_libusb_device], ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int,
max_streams:int=31, use_bot=False):
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
self.max_streams, self.use_bot = max_streams, use_bot
self._transferred = ctypes.c_int(0)
self._bulk_in_buf, self._bulk_in_mv = alloc_cbuffer(4 << 20)
self._bulk_out_buf, self._bulk_out_mv = alloc_cbuffer(4 << 20)
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
if DEBUG >= 6: libusb.libusb_set_option(self.ctx, libusb.LIBUSB_OPTION_LOG_LEVEL, 4)
self.handle = libusb.libusb_open_device_with_vid_pid(self.ctx, self.vendor, self.dev)
if not self.handle: raise RuntimeError(f"device {self.vendor:04x}:{self.dev:04x} not found. sudo required?")
self.handle = c.init_c_var(c.POINTER[libusb.struct_libusb_device_handle], lambda x: checked(libusb.libusb_open)(dev, x))
# Read product string descriptor
_buf = (ctypes.c_ubyte * 256)()
_desc = libusb.struct_libusb_device_descriptor()
libusb.libusb_get_device_descriptor(libusb.libusb_get_device(self.handle), ctypes.byref(_desc))
_ret = libusb.libusb_get_string_descriptor_ascii(self.handle, _desc.iProduct, _buf, 256)
self.product = bytes(_buf[:max(_ret, 0)]).decode("ascii", errors="replace") if _ret > 0 else ""
checked(libusb.libusb_get_device_descriptor)(libusb.libusb_get_device(self.handle), ctypes.byref(_desc))
_ret = checked(libusb.libusb_get_string_descriptor_ascii)(self.handle, _desc.iProduct, _buf, 256)
self.product = bytes(_buf[:_ret]).decode("ascii", errors="replace")
self.is_custom = self.product.startswith("custom")
if self.is_custom: self.use_bot = use_bot = True
# Detach kernel driver if needed
if libusb.libusb_kernel_driver_active(self.handle, 0):
libusb.libusb_detach_kernel_driver(self.handle, 0)
libusb.libusb_reset_device(self.handle)
if checked(libusb.libusb_kernel_driver_active)(self.handle, 0):
checked(libusb.libusb_detach_kernel_driver)(self.handle, 0)
checked(libusb.libusb_reset_device)(self.handle)
# Set configuration and claim interface
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
checked(libusb.libusb_set_configuration)(self.handle, 1)
checked(libusb.libusb_claim_interface)(self.handle, 0)
if use_bot:
libusb.libusb_set_interface_alt_setting(self.handle, 0, 0)
checked(libusb.libusb_set_interface_alt_setting)(self.handle, 0, 0)
self._tag = 0
else:
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
checked(libusb.libusb_set_interface_alt_setting)(self.handle, 0, 1)
# Clear any stalled endpoints
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
for ep in all_eps: checked(libusb.libusb_clear_halt)(self.handle, ep)
# Allocate streams
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
raise RuntimeError(f"alloc_streams failed: {rc}")
checked(libusb.libusb_alloc_streams)(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))
# Base cmd
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
@ -77,11 +96,11 @@ class USB3:
return tr
def _submit_and_wait(self, cmds):
for tr in cmds: libusb.libusb_submit_transfer(tr)
for tr in cmds: checked(libusb.libusb_submit_transfer)(tr)
running = len(cmds)
while running:
libusb.libusb_handle_events(self.ctx)
checked(libusb.libusb_handle_events)(USB3.ctx())
running = len(cmds)
for tr in cmds:
if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
@ -90,14 +109,12 @@ class USB3:
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
  """Send `payload` to endpoint `ep` with a single bulk transfer.

  The payload is copied into the reusable output buffer, which is reallocated when the payload
  is larger than the current buffer.

  Raises:
    RuntimeError: if libusb reports an error for the transfer (via `checked`).
    AssertionError: if fewer bytes than `len(payload)` were transferred.
  """
  if len(payload) > len(self._bulk_out_mv): self._bulk_out_buf, self._bulk_out_mv = alloc_cbuffer(len(payload))
  self._bulk_out_mv[:len(payload)] = payload
  # issue the transfer exactly once; checked() turns a negative return code into RuntimeError
  checked(libusb.libusb_bulk_transfer, f"bulk OUT 0x{ep:02X} failed")(self.handle, ep, self._bulk_out_buf, len(payload), self._transferred, timeout)
  assert self._transferred.value == len(payload), f"bulk OUT short write on 0x{ep:02X}: {self._transferred.value}/{len(payload)} bytes"
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> memoryview:
  """Read up to `length` bytes from endpoint `ep` with a single bulk transfer.

  Returns a slice of the reusable input buffer covering the bytes actually transferred. The
  slice is a view into that buffer, not a copy.

  Raises:
    RuntimeError: if libusb reports an error for the transfer (via `checked`).
  """
  if length > len(self._bulk_in_mv): self._bulk_in_buf, self._bulk_in_mv = alloc_cbuffer(length)
  # issue the transfer exactly once; checked() turns a negative return code into RuntimeError
  checked(libusb.libusb_bulk_transfer, f"bulk IN 0x{ep:02X} failed")(self.handle, ep, self._bulk_in_buf, length, self._transferred, timeout)
  return self._bulk_in_mv[:self._transferred.value]
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
@ -172,7 +189,11 @@ class ScsiWriteOp: data:bytes; lba:int=0 # noqa: E702
class CustomASM24Controller:
def __init__(self, usb:USB3|None=None):
self.usb = usb or USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=True)
if not usb:
devs = USB3.list_devices(0xADD1, 0x0001)
assert len(devs), "no ASM24 controller found"
self.usb = USB3(devs[0][0], 0x81, 0x83, 0x02, 0x04, use_bot=True)
else: self.usb = usb
self._pci_cacheable: list[tuple[int, int]] = []
self._pci_cache: dict[int, int|None] = {}
@ -186,8 +207,8 @@ class CustomASM24Controller:
if ltssm != 0x78: raise RuntimeError(f"PCIe link not up (LTSSM=0x{ltssm:02X}), custom firmware not ready")
def set_pcie_power(self, enabled:bool, timeout:int=10000):
  """Issue the 0xF3 vendor control transfer with wValue set to `int(enabled)`.

  Raises:
    RuntimeError: if the control transfer returns a negative code (via `checked`).
  """
  # send the request exactly once
  checked(libusb.libusb_control_transfer,
          f"F3 PCIe power {'on' if enabled else 'off'} failed")(self.usb.handle, 0x40, 0xF3, int(enabled), 0, None, 0, timeout)
# === PCIe TLP via 0xF0 vendor command ===
@ -267,8 +288,8 @@ class CustomASM24Controller:
def write(self, base_addr:int, data:bytes, **kwargs):
  """Write to chip XDATA via vendor control OUT (bRequest=0xE5). wValue=addr, wIndex=val.

  One control transfer is issued per byte of `data`, addressed at `base_addr + offset`.
  `kwargs` is accepted but not used here.

  Raises:
    RuntimeError: if any control transfer returns a negative code (via `checked`).
  """
  for off, val in enumerate(data):
    # each byte is written exactly once
    checked(libusb.libusb_control_transfer,
            f"write(0x{base_addr + off:04X}, 0x{val:02X}) failed")(self.usb.handle, 0x40, 0xE5, base_addr + off, val, None, 0, 1000)
def scsi_write(self, buf:bytes, lba:int=0):
"""Write to SRAM via 0xF2 vendor command + bulk OUT."""
@ -277,20 +298,23 @@ class CustomASM24Controller:
num_slots = round_up(len(buf_padded), 0x4000) // 0x4000 # 16KB per slot
# 0xF2 OUT: wValue=sectors, wIndex=start_slot|(num_slots<<8)
windex = (num_slots & 0xFF) << 8
ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF2, sectors, windex, None, 0, 1000)
assert ret >= 0, f"F2 setup failed: {ret}"
checked(libusb.libusb_control_transfer, "F2 setup failed")(self.usb.handle, 0x40, 0xF2, sectors, windex, None, 0, 1000)
self.usb._bulk_out(0x02, buf_padded)
def scsi_read_arm(self, size:int):
  """Issue the 0xF2 vendor control transfer that sets up a read of `size` bytes.

  wValue carries the 512-byte sector count (masked to 15 bits) with bit 15 set. wIndex carries
  the number of 0x4000-byte units in its high byte.

  Raises:
    RuntimeError: if the control transfer returns a negative code (via `checked`).
  """
  windex = (ceildiv(size, 0x4000) & 0xFF) << 8
  # send the request exactly once
  checked(libusb.libusb_control_transfer,
          "F2 read arm failed")(self.usb.handle, 0x40, 0xF2, (ceildiv(size, 512) & 0x7FFF) | 0x8000, windex, None, 0, 1000)
def scsi_read(self, size:int) -> memoryview:
  """Bulk-read from endpoint 0x81 and return the first `size` bytes.

  The request length is `size` rounded up to a multiple of 512, with a 10000 timeout.
  """
  padded_len = round_up(size, 512)
  return self.usb._bulk_in(0x81, padded_len, timeout=10000)[:size]
class ASM24Controller:
def __init__(self, usb:USB3|None=None):
self.usb = usb or USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=bool(getenv("USE_BOT", 0)))
if not usb:
devs = USB3.list_devices(0xADD1, 0x0001)
assert len(devs), "no ASM24 controller found"
self.usb = USB3(devs[0][0], 0x81, 0x83, 0x02, 0x04, use_bot=bool(getenv("USE_BOT", 0)))
else: self.usb = usb
self._cache: dict[int, int|None] = {}
self._pci_cacheable: list[tuple[int, int]] = []
self._pci_cache: dict[int, int|None] = {}