refactor to use uop

This commit is contained in:
George Hotz 2026-05-29 19:01:10 -07:00
commit 058259e0ea

View file

@ -41,29 +41,28 @@ def generic_wmma_helper(inp, warp_size, WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_
class PythonProgram:
def __init__(self, name:str, lib:bytes, **kwargs):
self.prg: list[UOp] = pickle.loads(lib)
# the value of SPECIAL comes from local/global_size, not form its source
self.uops: list[tuple[Ops, DType, list[int], Any]] = \
[(u.op, u.dtype, [self.prg.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in self.prg]
self.uops: list[UOp] = pickle.loads(lib)
self.uop_to_index: dict[UOp, int] = {u:i for i,u in enumerate(self.uops)}
self.loop_ends: dict[UOp, int] = {u.src[1]:i for i, u in enumerate(self.uops) if u.op == Ops.END}
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False, **kw):
st = time.perf_counter()
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
warp_size = len(warp)
void_ops = {Ops.END, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.GROUP, Ops.STORE}
loop_ends: dict[int, int] = {srcs[1]:i for i, (uop, _, srcs, _) in enumerate(self.uops) if uop == Ops.END}
for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
values: dict[int, Any] = {}
values: dict[UOp, Any] = {}
pbufs: list[memoryview] = list(bufs)
pvals: list[int] = list(vals)
exec_masks = [[True] * warp_size]
i = 0
while i < len(self.uops):
uop, dtype, srcs, arg = self.uops[i]
src_values = [values[v] for v in srcs if self.uops[v][0] not in void_ops]
src_dtypes = [self.uops[v][1] for v in srcs if self.uops[v][0] not in void_ops]
u = self.uops[i]
uop, dtype, srcs, arg = u.op, u.dtype, (() if u.op is Ops.SPECIAL else u.src), u.arg
src_values = [values[v] for v in srcs if v.op not in void_ops]
src_dtypes = [v.dtype for v in srcs if v.op not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
if uop is Ops.END:
i = srcs[1]
i = self.uop_to_index[srcs[1]]
continue
if uop is Ops.IF:
exec_masks.append([x and y for x,y in zip(exec_masks[-1], src_values[0])])
@ -86,7 +85,7 @@ class PythonProgram:
if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1
continue
if uop is Ops.AFTER: values[i] = src_values[0]
if uop is Ops.AFTER: values[u] = src_values[0]
elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
@ -94,16 +93,16 @@ class PythonProgram:
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if uop is Ops.DEFINE_REG:
# REGs are per thread
values[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
values[u] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else:
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0)
values[i] = [buf.cast(storage_fmt)] * warp_size
values[u] = [buf.cast(storage_fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
values[i] = [pvals.pop(0)] * warp_size
values[u] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:
if arg[0] == 'g': values[i] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': values[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[i] = [arg] * warp_size
if arg[0] == 'g': values[u] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': values[u] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[u] = [arg] * warp_size
elif uop is Ops.INDEX:
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
@ -114,31 +113,31 @@ class PythonProgram:
else:
assert len(src_values) == 2, "non-image index must be 2 srcs"
for m,o in zip(*src_values): ret.append((m,o))
values[i] = ret
values[u] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
values[i] = src_values[0]
values[u] = src_values[0]
elif uop is Ops.RANGE:
if i not in values: values[i] = [0] * warp_size
if u not in values: values[u] = [0] * warp_size
else:
for j in range(len(values[i])):
values[i][j] += 1
if values[i][0] == src_values[0][0]:
del values[i]
i = loop_ends[i] + 1
for j in range(len(values[u])):
values[u][j] += 1
if values[u][0] == src_values[0][0]:
del values[u]
i = self.loop_ends[u] + 1
continue
elif uop is Ops.STACK: values[i] = src_values
elif uop is Ops.BITCAST: values[i] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
elif uop is Ops.STACK: values[u] = src_values
elif uop is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
elif uop is Ops.CAST:
values[i] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
values[u] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
values[i] = [load([src_values[i][j] if i != 0 and src_dtypes[i].count > 1 else src_values[i] \
for i in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
for k in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
else:
values[i] = load(src_values, 0, dtype)
elif uop is Ops.GEP: values[i] = src_values[0][get_single_element(arg)]
values[u] = load(src_values, 0, dtype)
elif uop is Ops.GEP: values[u] = src_values[0][get_single_element(arg)]
elif uop is Ops.WMMA:
first_src_dtype = self.uops[srcs[0]][1]
first_src_dtype = srcs[0].dtype
assert isinstance(first_src_dtype, DType) # mypy
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
@ -148,17 +147,17 @@ class PythonProgram:
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
values[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
values[u] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
elif device == "AMD" and threads == 64:
def a_elem(x, k, row, goff): return x[k%(dims[2]//4)][goff + (k//(dims[2]//4))*16 + row]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
values[i] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
values[u] = wmma_helper(64, dims[2], len(src_values[0]), len(src_values[1]), len(src_values[2]), a_elem, b_elem, c_map)
elif device == "AMD" and len(src_values[0]) == 8: # RDNA4
def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
values[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
elif device == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, k, row, goff):
@ -167,7 +166,7 @@ class PythonProgram:
# B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
values[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CUDA":
# (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
@ -175,22 +174,22 @@ class PythonProgram:
if dims == (8,16,16):
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
values[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
elif dims == (8,16,32):
def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
values[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
elif dims == (8,16,8) and dtype_in == dtypes.half:
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
elif dims == (8,16,8) and dtype_in == dtypes.float:
def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
values[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif device == "INTEL":
@ -200,17 +199,17 @@ class PythonProgram:
def b_elem(x, col, k, goff): return x[k][goff+col]
# C, D (8 elements on 8 threads)
def c_map(lane, elem): return (lane, elem)
values[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
values[u] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CPU":
def elem(x, col, row, _): return x[col+row][0] # k is always 0
def c_map(lane, elem): return (elem%16, elem//16)
values[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
values[i] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
assert i in values, (uop, dtype, srcs, arg)
values[u] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
assert u in values, (uop, dtype, srcs, arg)
i += 1
return time.perf_counter() - st