usbgpu: use BOT interface for patch.py (#13644)

* BOT usage

* cleanup

* fix lint

* fix ruff

* fix -7?
This commit is contained in:
Robbe Derks 2026-02-02 04:54:46 +01:00 committed by GitHub
commit d75a1b0d5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 107 additions and 51 deletions

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
import sys, os, zlib, struct, hashlib
from tinygrad.helpers import DEBUG, getenv, fetch
import os, zlib, struct, hashlib
from tinygrad.helpers import getenv
from tinygrad.runtime.support.usb import USB3
SUPPORTED_CONTROLLERS = [
@ -50,7 +50,7 @@ patched_fw = patch(file_path, file_hash, patches)
dev = None
for vendor, device in SUPPORTED_CONTROLLERS:
try:
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04)
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04, use_bot=True)
break
except RuntimeError: pass
if dev is None:

View file

@ -5,10 +5,10 @@ from tinygrad.helpers import DEBUG, to_mv, round_up, OSX
from tinygrad.runtime.support.hcq import MMIOInterface
class USB3:
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31):
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
self.vendor, self.dev = vendor, dev
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
self.max_streams = max_streams
self.max_streams, self.use_bot = max_streams, use_bot
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
@ -25,30 +25,34 @@ class USB3:
# Set configuration and claim interface
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
# Clear any stalled endpoints
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
if use_bot:
self._tag = 0
else:
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
# Allocate streams
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
raise RuntimeError(f"alloc_streams failed: {rc}")
# Clear any stalled endpoints
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
# Base cmd
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
# Allocate streams
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
raise RuntimeError(f"alloc_streams failed: {rc}")
# Init pools
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
# Base cmd
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
# Init pools
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
def _prep_transfer(self, tr, ep, stream_id, buf, length):
tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf
@ -68,38 +72,90 @@ class USB3:
if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
  """Synchronously write `payload` to the bulk endpoint `ep`.

  Args:
    ep: endpoint address handed to libusb_bulk_transfer.
    payload: data to send; it is copied into a temporary ctypes buffer first.
    timeout: timeout value forwarded unchanged to libusb_bulk_transfer.

  Raises:
    RuntimeError: if libusb returns a non-zero code, or if fewer bytes than len(payload) were transferred.
  """
  data = bytes(payload)
  # one bulk copy instead of unpacking every byte into a positional argument
  buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
  transferred = ctypes.c_int(0)
  rc = libusb.libusb_bulk_transfer(self.handle, ep, buf, len(data), ctypes.byref(transferred), timeout)
  # explicit raises rather than asserts: the checks must survive `python -O`
  if rc != 0: raise RuntimeError(f"bulk OUT 0x{ep:02X} failed: {rc}")
  if transferred.value != len(data):
    raise RuntimeError(f"bulk OUT short write on 0x{ep:02X}: {transferred.value}/{len(data)} bytes")
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> bytes:
  """Synchronously read up to `length` bytes from the bulk endpoint `ep`.

  Args:
    ep: endpoint address handed to libusb_bulk_transfer.
    length: size of the receive buffer, i.e. the maximum number of bytes to read.
    timeout: timeout value forwarded unchanged to libusb_bulk_transfer.

  Returns:
    The bytes actually received; this may be shorter than `length`.

  Raises:
    RuntimeError: if libusb returns a non-zero code.
  """
  buf, transferred = (ctypes.c_ubyte * length)(), ctypes.c_int(0)
  rc = libusb.libusb_bulk_transfer(self.handle, ep, buf, length, ctypes.byref(transferred), timeout)
  # explicit raise rather than assert: the check must survive `python -O`
  if rc != 0: raise RuntimeError(f"bulk IN 0x{ep:02X} failed: {rc}")
  # slice through a memoryview so only the received prefix is copied (no intermediate list of ints)
  return bytes(memoryview(buf)[:transferred.value])
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
results, tr_window, op_window = [], [], []
results:list[bytes|None] = []
tr_window, op_window = [], []
for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
# allocate slot and stream. stream is 1-based
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
if self.use_bot:
dir_in = rlen > 0
data_len = rlen if dir_in else (len(send_data) if send_data is not None else 0)
assert (data_len == 0) if dir_in else (rlen == 0), "BOT mode only supports either read or write per command"
# build cmd packet
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
# CBW
self._tag += 1
flags = 0x80 if dir_in else 0x00
cbw = struct.pack("<IIIBBB", 0x43425355, self._tag, data_len, flags, 0, len(cdb)) + cdb + b"\x00" * (16 - len(cdb))
self._bulk_out(self.ep_data_out, cbw)
# cmd + stat transfers
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
# DAT
if dir_in:
results.append(self._bulk_in(self.ep_data_in, rlen))
else:
if send_data is not None:
self._bulk_out(self.ep_data_out, send_data)
results.append(None)
if rlen:
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
# CSW
sig, rtag, residue, status = struct.unpack("<IIIB", self._bulk_in(self.ep_data_in, 13, timeout=2000))
assert sig == 0x53425355, f"Bad CSW signature 0x{sig:08X}, expected 0x53425355"
assert rtag == self._tag, f"CSW tag mismatch: got {rtag}, expected {self._tag}"
assert status == 0, f"SCSI command failed, CSW status=0x{status:02X}, residue={residue}"
else:
# allocate slot and stream. stream is 1-based
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
if send_data is not None:
if len(send_data) > len(self.buf_data_out[slot]):
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
# build cmd packet
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
# cmd + stat transfers
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
op_window.append((idx, slot, rlen))
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
self._submit_and_wait(tr_window)
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
tr_window = []
if rlen:
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
if send_data is not None:
if len(send_data) > len(self.buf_data_out[slot]):
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
op_window.append((idx, slot, rlen))
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
self._submit_and_wait(tr_window)
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
tr_window = []
return results