mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
USB driver for custom ASM firmware (#15597)
* USB driver for custom ASM firmware * timeout * fix mypy * pcie mem read * flip in f/w * one tx * little endian * autodetect custom * mock bypass * lint * clean
This commit is contained in:
parent
810d7c00cd
commit
2b01ca59dd
3 changed files with 135 additions and 7 deletions
|
|
@ -202,7 +202,8 @@ class MockASM24State:
|
|||
return None
|
||||
|
||||
class MockUSB3:
|
||||
def __init__(self, *args, **kwargs): pass
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.product, self.is_custom = "", False
|
||||
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
|
||||
assert _mock_usb_state is not None
|
||||
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch
|
|||
from tinygrad.runtime.autogen import libc, pci, vfio, iokit, corefoundation
|
||||
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer, hcq_filter_visible_devices
|
||||
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace, BumpAllocator
|
||||
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
|
||||
from tinygrad.runtime.support.usb import USB3, CustomASM24Controller, ASM24Controller, USBMMIOInterface
|
||||
|
||||
MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000
|
||||
MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
|
||||
|
|
@ -82,7 +82,7 @@ class _System:
|
|||
cl, pcibus = hcq_filter_visible_devices(self.list_devices(vendor, devices, base_class))[dev_id]
|
||||
return cl(devpref, pcibus)
|
||||
|
||||
def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]:
|
||||
def pci_setup_usb_bars(self, usb:CustomASM24Controller|ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]:
|
||||
for bus in range(gpu_bus):
|
||||
# All 3 values must be written at the same time.
|
||||
buses = (0 << 0) | ((bus+1) << 8) | ((gpu_bus) << 16)
|
||||
|
|
@ -213,7 +213,9 @@ class PCIDevice:
|
|||
class USBPCIDevice(PCIDevice):
|
||||
def __init__(self, devpref:str, pcibus:str):
|
||||
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")
|
||||
self.usb = ASM24Controller()
|
||||
usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04)
|
||||
if DEBUG >= 1: print(f"am usb: product string: {usb.product!r}")
|
||||
self.usb: CustomASM24Controller | ASM24Controller = CustomASM24Controller(usb) if usb.is_custom else ASM24Controller(usb)
|
||||
self.pcibus, self._bar_info = pcibus, System.pci_setup_usb_bars(self.usb, gpu_bus=4, mem_base=0x10000000, pref_mem_base=(32 << 30))
|
||||
self.sram = BumpAllocator(size=0x80000, wrap=False) # asm24 controller sram
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import ctypes, struct, dataclasses, array, itertools
|
||||
import ctypes, struct, dataclasses, array, itertools, time
|
||||
from typing import Sequence
|
||||
from tinygrad.runtime.autogen import libusb
|
||||
from tinygrad.helpers import DEBUG, to_mv, round_up, OSX, getenv
|
||||
|
|
@ -17,6 +17,15 @@ class USB3:
|
|||
self.handle = libusb.libusb_open_device_with_vid_pid(self.ctx, self.vendor, self.dev)
|
||||
if not self.handle: raise RuntimeError(f"device {self.vendor:04x}:{self.dev:04x} not found. sudo required?")
|
||||
|
||||
# Read product string descriptor
|
||||
_buf = (ctypes.c_ubyte * 256)()
|
||||
_desc = libusb.struct_libusb_device_descriptor()
|
||||
libusb.libusb_get_device_descriptor(libusb.libusb_get_device(self.handle), ctypes.byref(_desc))
|
||||
_ret = libusb.libusb_get_string_descriptor_ascii(self.handle, _desc.iProduct, _buf, 256)
|
||||
self.product = bytes(_buf[:max(_ret, 0)]).decode("ascii", errors="replace") if _ret > 0 else ""
|
||||
self.is_custom = self.product.startswith("custom")
|
||||
if self.is_custom: self.use_bot = use_bot = True
|
||||
|
||||
# Detach kernel driver if needed
|
||||
if libusb.libusb_kernel_driver_active(self.handle, 0):
|
||||
libusb.libusb_detach_kernel_driver(self.handle, 0)
|
||||
|
|
@ -168,9 +177,121 @@ class ReadOp: addr:int; size:int # noqa: E702
|
|||
@dataclasses.dataclass(frozen=True)
class ScsiWriteOp:
  """Immutable description of one SCSI write: the payload and its target `lba` (presumably the logical block address)."""
  data: bytes
  lba: int = 0
|
||||
|
||||
class CustomASM24Controller:
  """Host-side driver for an ASM24 USB bridge running the custom firmware.

  Everything goes through vendor control transfers, with the bulk endpoints used for streamed payloads:
    0xE4 / 0xE5 -- read / write chip XDATA
    0xF0        -- PCIe TLP engine (mode 0: single request, mode 1: streaming write, mode 2: streaming read)
    0xF2        -- SRAM write set-up, followed by a bulk OUT carrying the data
  """
  def __init__(self, usb:USB3|None=None):
    """Attach to `usb` (or open the default 0xADD1:0x0001 device in BOT mode) and check that the PCIe link is up.

    Raises:
      RuntimeError: if XDATA register 0xB450 (LTSSM state) does not read 0x78.
    """
    self.usb = usb or USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=True)
    # (base, size) windows in which a repeated identical dword write may be skipped; nothing in this class fills it,
    # so it is presumably populated by the code that owns the controller.
    self._pci_cacheable: list[tuple[int, int]] = []
    # last dword written per address by pcie_request; None means "unknown, always send"
    self._pci_cache: dict[int, int|None] = {}

    # Verify custom firmware is running and PCIe link is up (LTSSM=0x78).
    ltssm = self.read(0xB450, 1)[0]
    if ltssm != 0x78: raise RuntimeError(f"PCIe link not up (LTSSM=0x{ltssm:02X}), custom firmware not ready")

  # === PCIe TLP via 0xF0 vendor command ===

  def _f0_out(self, fmt_type:int, byte_en:int, address:int, value:int, mode:int=0):
    """Send 0xF0 OUT control transfer: configure TLP engine. 12-byte DATA_OUT = addr_lo[4 LE] + addr_hi[4 LE] + value[4 LE].

    wValue carries `fmt_type` in its low byte with `byte_en` above it; wIndex carries the low two bits of `mode`.
    """
    wval = fmt_type | (byte_en << 8)
    widx = mode & 0x03
    payload = struct.pack('<III', address & 0xFFFFFFFF, address >> 32, value)
    buf = (ctypes.c_ubyte * 12)(*payload)
    ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF0, wval, widx, buf, 12, 5000)
    assert ret == 12, f"F0 OUT failed: {ret}"

  def _f0_in(self) -> tuple[int, int, int]:
    """Read 0xF0 IN: 8 bytes = data[4 LE] + cpl_hdr[2] + compl_status[1] + ret_status[1]. Returns (data, compl_status, ret_status)."""
    buf = (ctypes.c_ubyte * 8)()
    ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xF0, 0, 0, buf, 8, 5000)
    assert ret == 8, f"F0 IN failed: {ret}"
    data = struct.unpack('<I', bytes(buf[0:4]))[0]
    cpl_status = (buf[4] >> 5) & 0x7  # completion status from CPL_HDR_HI bits [7:5]
    return data, cpl_status, buf[7]

  def _is_pci_cacheable(self, addr:int) -> bool:
    """Return True when `addr` lies inside one of the registered (base, size) windows."""
    # Half-open interval: base + size is the first address *past* the window.
    return any(base <= addr < base + sz for base, sz in self._pci_cacheable)

  def pcie_request(self, fmt_type:int, address:int, value:int|None=None, size:int=4, cnt:int=10) -> int|None:
    """Issue a single PCIe TLP and, for requests that complete, return the result.

    Args:
      fmt_type: TLP Fmt/Type byte (this class uses 0x20/0x60 for memory read/write and 0x04/0x05/0x44/0x45 for config).
      address: byte address; the dword-aligned address is sent and the low two bits select the byte lanes.
      value: data to write, or None for a read.
      size: access width in bytes, 1..4.
      cnt: how many times to retry when the firmware reports a non-zero return status.

    Returns:
      The value read (masked to `size` bytes) when `value` is None and a completion is received, otherwise None.

    Raises:
      RuntimeError: if the firmware still reports an error after all retries, or the completion status is non-zero.
    """
    # A dword write that repeats the last value written to a cacheable address is dropped.
    if fmt_type == 0x60 and size == 4 and self._is_pci_cacheable(address) and self._pci_cache.get(address) == value: return None
    assert size > 0 and size <= 4, f"Invalid size {size}"

    offset = address & 0x3
    aligned = address & ~0x3
    byte_en = ((1 << size) - 1) << offset
    shifted = (value << (8 * offset)) if value is not None else 0
    is_cached_write = fmt_type == 0x60 and size == 4

    # Forget what we knew about this location before touching it. The containing dword is invalidated too, so a
    # partial write at a non-zero byte offset cannot leave a stale full-dword entry behind.
    self._pci_cache[address] = self._pci_cache[aligned] = None

    while True:
      if DEBUG >= 5: print("pcie_request", hex(fmt_type), hex(address), value, size)
      self._f0_out(fmt_type, byte_en, aligned, shifted)
      # Only remember the value once it has actually been handed to the device.
      if is_cached_write: self._pci_cache[address] = value

      # Fast path: memory writes and messages don't return completions (same logic as ASM24Controller).
      if ((fmt_type & 0b11011111) == 0b01000000) or ((fmt_type & 0b10111000) == 0b00110000): return None

      # Read TLPs and config writes: read completion via 0xF0 IN. Retry on error/timeout.
      data, cpl_status, ret_status = self._f0_in()
      if ret_status == 0: break
      time.sleep(0.001)  # TODO: this sleep is very picky
      if cnt <= 0: raise RuntimeError(f"TLP error after retries: ret_status={ret_status}, address={address:#x}")
      cnt -= 1

    if cpl_status:
      status_map = {0b001: f"Unsupported Request: {address:#x}", 0b100: "Completer Abort", 0b010: "Config Retry"}
      raise RuntimeError(f"TLP completion status: {status_map.get(cpl_status, f'Reserved (0b{cpl_status:03b})')}")

    if value is None: return (data >> (8 * offset)) & ((1 << (8 * size)) - 1)
    return None

  def pcie_cfg_req(self, byte_addr:int, bus:int=1, dev:int=0, fn:int=0, value:int|None=None, size:int=4) -> int|None:
    """Read (value is None) or write a config-space register of bus/dev/fn at `byte_addr` (0..0xFFF).

    Fmt/Type is 0x04 (read) or 0x44 (write), with bit 0 set when `bus` is non-zero.
    """
    assert byte_addr >> 12 == 0 and bus >> 8 == 0 and dev >> 5 == 0 and fn >> 3 == 0
    fmt_type = (0x44 if value is not None else 0x4) | int(bus > 0)
    address = (bus << 24) | (dev << 19) | (fn << 16) | (byte_addr & 0xfff)
    return self.pcie_request(fmt_type, address, value, size)

  def pcie_mem_req(self, address:int, value:int|None=None, size:int=4) -> int|None:
    """Single memory access: a write (0x60) when `value` is given, otherwise a read (0x20) whose result is returned."""
    return self.pcie_request(0x60 if value is not None else 0x20, address, value, size)

  def pcie_mem_write(self, address:int, values:list[int], size:int):
    """Streaming PCIe memory write via 0xF0 mode 1 + bulk OUT. Data is little-endian dwords on the wire.

    Args:
      address: destination byte address of the first element.
      values: elements to write; nothing is sent when empty.
      size: width of each element in bytes (assumed to mean the same as `size` in pcie_mem_req -- TODO confirm
        against callers). Only dword elements can be streamed; other widths are written one request at a time.
    """
    if not values: return
    if size != 4:
      # The stream always emits one dword per value, which would put narrower elements at the wrong addresses.
      for i, value in enumerate(values): self.pcie_mem_req(address + i * size, value, size)
      return

    # Streamed data bypasses pcie_request, so cached dwords overlapping the written range are no longer trustworthy.
    end = address + 4 * len(values)
    if any(address < base + sz and base < end for base, sz in self._pci_cacheable):
      for stale in [a for a in self._pci_cache if address - 4 < a < end]: self._pci_cache[stale] = None

    self._f0_out(0x60, 0x0F, address, len(values), mode=1)
    self.usb._bulk_out(0x02, struct.pack(f'<{len(values)}I', *values))

  def pcie_mem_read(self, address:int, nbytes:int) -> bytes:
    """Streaming PCIe memory read via 0xF0 mode 2 + bulk IN. Returns little-endian bytes."""
    assert nbytes % 4 == 0, f"pcie_mem_read requires 4-byte aligned size, got {nbytes}"
    # Mirror pcie_mem_write: don't start a zero-length stream.
    if nbytes == 0: return b''
    self._f0_out(0x20, 0x0F, address, nbytes // 4, mode=2)
    return self.usb._bulk_in(0x81, nbytes, timeout=30000)

  # === XDATA read/write (0xE4/0xE5 vendor control transfers) ===

  def read(self, base_addr:int, length:int, **kwargs) -> bytes:
    """Read from chip XDATA via vendor control IN (bRequest=0xE4). wValue=addr, wLength=size.

    The read is split into transfers of at most 0xFF bytes. Extra keyword arguments are accepted and ignored.
    """
    chunks: list[bytes] = []
    for off in range(0, length, 0xFF):
      chunk = min(0xFF, length - off)
      buf = (ctypes.c_ubyte * chunk)()
      ret = libusb.libusb_control_transfer(self.usb.handle, 0xC0, 0xE4, base_addr + off, 0, buf, chunk, 1000)
      assert ret == chunk, f"read(0x{base_addr + off:04X}, {chunk}) failed: {ret}"
      chunks.append(bytes(buf[:ret]))
    return b''.join(chunks)

  def write(self, base_addr:int, data:bytes, **kwargs):
    """Write to chip XDATA via vendor control OUT (bRequest=0xE5). wValue=addr, wIndex=val.

    One control transfer is issued per byte. Extra keyword arguments are accepted and ignored.
    """
    for off, val in enumerate(data):
      ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xE5, base_addr + off, val, None, 0, 1000)
      assert ret >= 0, f"write(0x{base_addr + off:04X}, 0x{val:02X}) failed: {ret}"

  def scsi_write(self, buf:bytes, lba:int=0):
    """Write to SRAM via 0xF2 vendor command + bulk OUT.

    `buf` is zero-padded to a multiple of 512 bytes. `lba` is accepted but not used: the start slot sent to the
    device is always 0.
    """
    buf_padded = buf + b'\x00' * (round_up(len(buf), 512) - len(buf))
    sectors = len(buf_padded) // 512
    num_slots = round_up(len(buf_padded), 0x4000) // 0x4000  # 16KB per slot
    # The slot count travels in a single byte; refuse rather than let the mask below wrap it silently.
    assert num_slots <= 0xFF, f"scsi_write too large: {len(buf)} bytes"
    # 0xF2 OUT: wValue=sectors, wIndex=start_slot|(num_slots<<8)
    windex = (num_slots & 0xFF) << 8
    ret = libusb.libusb_control_transfer(self.usb.handle, 0x40, 0xF2, sectors, windex, None, 0, 1000)
    assert ret >= 0, f"F2 setup failed: {ret}"
    self.usb._bulk_out(0x02, buf_padded)
|
||||
|
||||
|
||||
class ASM24Controller:
|
||||
def __init__(self):
|
||||
self.usb = USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=bool(getenv("USE_BOT", 0)))
|
||||
def __init__(self, usb:USB3|None=None):
|
||||
self.usb = usb or USB3(0xADD1, 0x0001, 0x81, 0x83, 0x02, 0x04, use_bot=bool(getenv("USE_BOT", 0)))
|
||||
self._cache: dict[int, int|None] = {}
|
||||
self._pci_cacheable: list[tuple[int, int]] = []
|
||||
self._pci_cache: dict[int, int|None] = {}
|
||||
|
|
@ -310,6 +431,10 @@ class USBMMIOInterface(MMIOInterface):
|
|||
if not self.pcimem:
|
||||
return int.from_bytes(self.usb.read(self.addr + off, sz), "little") if sz == self.el_sz else self.usb.read(self.addr + off, sz)
|
||||
|
||||
# Fast path: streaming PCIe read if controller supports it
|
||||
if hasattr(self.usb, 'pcie_mem_read') and sz >= 4 and sz % 4 == 0:
|
||||
return self.usb.pcie_mem_read(self.addr + off, sz)
|
||||
|
||||
acc, acc_size = self._acc_size(sz)
|
||||
return bytes(array.array(acc, [self._acc_one(off + i * acc_size, acc_size) for i in range(sz // acc_size)]))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue