mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
usbgpu: use BOT interface for patch.py (#13644)
* BOT usage * cleanup * fix lint * fix ruff * fix -7?
This commit is contained in:
parent
2931b52875
commit
d75a1b0d5a
2 changed files with 107 additions and 51 deletions
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys, os, zlib, struct, hashlib
|
||||
from tinygrad.helpers import DEBUG, getenv, fetch
|
||||
import os, zlib, struct, hashlib
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.runtime.support.usb import USB3
|
||||
|
||||
SUPPORTED_CONTROLLERS = [
|
||||
|
|
@ -50,7 +50,7 @@ patched_fw = patch(file_path, file_hash, patches)
|
|||
dev = None
|
||||
for vendor, device in SUPPORTED_CONTROLLERS:
|
||||
try:
|
||||
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04)
|
||||
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04, use_bot=True)
|
||||
break
|
||||
except RuntimeError: pass
|
||||
if dev is None:
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ from tinygrad.helpers import DEBUG, to_mv, round_up, OSX
|
|||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
|
||||
class USB3:
|
||||
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31):
|
||||
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
|
||||
self.vendor, self.dev = vendor, dev
|
||||
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
|
||||
self.max_streams = max_streams
|
||||
self.max_streams, self.use_bot = max_streams, use_bot
|
||||
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
|
||||
|
||||
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
|
||||
|
|
@ -25,30 +25,34 @@ class USB3:
|
|||
# Set configuration and claim interface
|
||||
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
|
||||
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
|
||||
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
|
||||
|
||||
# Clear any stalled endpoints
|
||||
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
|
||||
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
|
||||
if use_bot:
|
||||
self._tag = 0
|
||||
else:
|
||||
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
|
||||
|
||||
# Allocate streams
|
||||
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
|
||||
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
|
||||
raise RuntimeError(f"alloc_streams failed: {rc}")
|
||||
# Clear any stalled endpoints
|
||||
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
|
||||
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
|
||||
|
||||
# Base cmd
|
||||
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
|
||||
# Allocate streams
|
||||
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
|
||||
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
|
||||
raise RuntimeError(f"alloc_streams failed: {rc}")
|
||||
|
||||
# Init pools
|
||||
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
|
||||
# Base cmd
|
||||
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
|
||||
|
||||
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
|
||||
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
|
||||
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
|
||||
# Init pools
|
||||
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
|
||||
|
||||
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
|
||||
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
|
||||
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
|
||||
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
|
||||
|
||||
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
|
||||
|
||||
def _prep_transfer(self, tr, ep, stream_id, buf, length):
|
||||
tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf
|
||||
|
|
@ -68,38 +72,90 @@ class USB3:
|
|||
if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
|
||||
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
|
||||
|
||||
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
  """Synchronously write `payload` to the bulk OUT endpoint `ep`.

  Performs one blocking `libusb_bulk_transfer` on `self.handle` and verifies that the
  whole payload was accepted.

  Args:
    ep: endpoint address to write to.
    payload: bytes to send.
    timeout: passed straight to `libusb_bulk_transfer` (milliseconds per the libusb API).

  Raises:
    RuntimeError: if libusb returns a non-zero code or fewer bytes than requested were written.
  """
  data = bytes(payload)
  # one memcpy into the ctypes buffer instead of unpacking every byte into a separate Python argument
  buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
  transferred = ctypes.c_int(0)
  rc = libusb.libusb_bulk_transfer(self.handle, ep, buf, len(data), ctypes.byref(transferred), timeout)
  # explicit raises (not assert) so that failures are still detected under `python -O`
  if rc != 0: raise RuntimeError(f"bulk OUT 0x{ep:02X} failed: {rc}")
  if transferred.value != len(data):
    raise RuntimeError(f"bulk OUT short write on 0x{ep:02X}: {transferred.value}/{len(data)} bytes")
|
||||
|
||||
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> bytes:
  """Synchronously read up to `length` bytes from the bulk IN endpoint `ep`.

  Performs one blocking `libusb_bulk_transfer` on `self.handle`.

  Args:
    ep: endpoint address to read from.
    length: size of the receive buffer, i.e. the maximum number of bytes to read.
    timeout: passed straight to `libusb_bulk_transfer` (milliseconds per the libusb API).

  Returns:
    The bytes actually received; may be shorter than `length`.

  Raises:
    RuntimeError: if libusb returns a non-zero code.
  """
  buf, transferred = (ctypes.c_ubyte * length)(), ctypes.c_int(0)
  rc = libusb.libusb_bulk_transfer(self.handle, ep, buf, length, ctypes.byref(transferred), timeout)
  # explicit raise (not assert) so that failures are still detected under `python -O`
  if rc != 0: raise RuntimeError(f"bulk IN 0x{ep:02X} failed: {rc}")
  # copy only the received prefix straight out of the buffer, without building a list of ints first
  return ctypes.string_at(ctypes.addressof(buf), transferred.value)
|
||||
|
||||
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
|
||||
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
|
||||
results, tr_window, op_window = [], [], []
|
||||
results:list[bytes|None] = []
|
||||
tr_window, op_window = [], []
|
||||
|
||||
for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
|
||||
# allocate slot and stream. stream is 1-based
|
||||
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
|
||||
if self.use_bot:
|
||||
dir_in = rlen > 0
|
||||
data_len = rlen if dir_in else (len(send_data) if send_data is not None else 0)
|
||||
assert (data_len == 0) if dir_in else (rlen == 0), "BOT mode only supports either read or write per command"
|
||||
|
||||
# build cmd packet
|
||||
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
|
||||
# CBW
|
||||
self._tag += 1
|
||||
flags = 0x80 if dir_in else 0x00
|
||||
cbw = struct.pack("<IIIBBB", 0x43425355, self._tag, data_len, flags, 0, len(cdb)) + cdb + b"\x00" * (16 - len(cdb))
|
||||
self._bulk_out(self.ep_data_out, cbw)
|
||||
|
||||
# cmd + stat transfers
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
|
||||
# DAT
|
||||
if dir_in:
|
||||
results.append(self._bulk_in(self.ep_data_in, rlen))
|
||||
else:
|
||||
if send_data is not None:
|
||||
self._bulk_out(self.ep_data_out, send_data)
|
||||
results.append(None)
|
||||
|
||||
if rlen:
|
||||
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
|
||||
# CSW
|
||||
sig, rtag, residue, status = struct.unpack("<IIIB", self._bulk_in(self.ep_data_in, 13, timeout=2000))
|
||||
assert sig == 0x53425355, f"Bad CSW signature 0x{sig:08X}, expected 0x53425355"
|
||||
assert rtag == self._tag, f"CSW tag mismatch: got {rtag}, expected {self._tag}"
|
||||
assert status == 0, f"SCSI command failed, CSW status=0x{status:02X}, residue={residue}"
|
||||
else:
|
||||
# allocate slot and stream. stream is 1-based
|
||||
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
|
||||
|
||||
if send_data is not None:
|
||||
if len(send_data) > len(self.buf_data_out[slot]):
|
||||
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
|
||||
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
|
||||
# build cmd packet
|
||||
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
|
||||
|
||||
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
|
||||
# cmd + stat transfers
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
|
||||
|
||||
op_window.append((idx, slot, rlen))
|
||||
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
|
||||
self._submit_and_wait(tr_window)
|
||||
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
|
||||
tr_window = []
|
||||
if rlen:
|
||||
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
|
||||
|
||||
if send_data is not None:
|
||||
if len(send_data) > len(self.buf_data_out[slot]):
|
||||
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
|
||||
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
|
||||
|
||||
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
|
||||
|
||||
op_window.append((idx, slot, rlen))
|
||||
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
|
||||
self._submit_and_wait(tr_window)
|
||||
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
|
||||
tr_window = []
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue