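Refactor PythonProgram's uop loop to read u.op, u.dtype, u.src and u.arg directly instead of unpacking them into the locals uop, dtype, srcs and arg.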

This commit is contained in:
George Hotz 2026-05-29 19:05:20 -07:00
commit 5d13dd1123

View file

@ -57,27 +57,26 @@ class PythonProgram:
i = 0
while i < len(self.uops):
u = self.uops[i]
uop, dtype, srcs, arg = u.op, u.dtype, (() if u.op is Ops.SPECIAL else u.src), u.arg
src_values = [values[v] for v in srcs if v.op not in void_ops]
src_dtypes = [v.dtype for v in srcs if v.op not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
if uop is Ops.END:
i = self.uop_to_index[srcs[1]]
src_values = [values[v] for v in u.src if v.op not in void_ops]
src_dtypes = [v.dtype for v in u.src if v.op not in void_ops]
if getenv("TRACE"): print(i, u.op, u.dtype, u.arg, src_values, src_dtypes)
if u.op is Ops.END:
i = self.uop_to_index[u.src[1]]
continue
if uop is Ops.IF:
if u.op is Ops.IF:
exec_masks.append([x and y for x,y in zip(exec_masks[-1], src_values[0])])
i += 1
continue
if uop is Ops.ENDIF:
if u.op is Ops.ENDIF:
exec_masks.pop()
i += 1
continue
if uop in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
if u.op in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
# in the python emulator, the warp is always in sync
i += 1
continue
assert dtype is not None, f"{uop} is missing a dtype"
if uop is Ops.STORE:
assert u.dtype is not None, f"{u.op} is missing a dtype"
if u.op is Ops.STORE:
assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}"
store_gate = exec_masks[-1]
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
@ -85,25 +84,25 @@ class PythonProgram:
if g: _store(m, o+j, v, src_dtypes[1].scalar())
i += 1
continue
if uop is Ops.AFTER: values[u] = src_values[0]
elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
if u.op is Ops.AFTER: values[u] = src_values[0]
elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(u.dtype, PtrDType), u.dtype
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if uop is Ops.DEFINE_REG:
if u.op is Ops.DEFINE_REG:
# REGs are per thread
values[u] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else:
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0)
buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
values[u] = [buf.cast(storage_fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
elif u.op is Ops.DEFINE_VAR:
values[u] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:
if arg[0] == 'g': values[u] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': values[u] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: values[u] = [arg] * warp_size
elif uop is Ops.INDEX:
elif u.op is Ops.SPECIAL:
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
elif u.op is Ops.INDEX:
ret:list = []
if isinstance(src_dtypes[0], ImageDType):
assert len(src_values) == 3, "image index must be 3 srcs"
@ -114,9 +113,9 @@ class PythonProgram:
assert len(src_values) == 2, "non-image index must be 2 srcs"
for m,o in zip(*src_values): ret.append((m,o))
values[u] = ret
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
values[u] = src_values[0]
elif uop is Ops.RANGE:
elif u.op is Ops.RANGE:
if u not in values: values[u] = [0] * warp_size
else:
for j in range(len(values[u])):
@ -125,21 +124,21 @@ class PythonProgram:
del values[u]
i = self.loop_ends[u] + 1
continue
elif uop is Ops.STACK: values[u] = src_values
elif uop is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
elif uop is Ops.CAST:
values[u] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
elif u.op is Ops.STACK: values[u] = src_values
elif u.op is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], u.dtype) for x in src_values[0]]
elif u.op is Ops.CAST:
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
elif u.op is Ops.LOAD:
if u.dtype.count > 1:
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
for k in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)]
else:
values[u] = load(src_values, 0, dtype)
elif uop is Ops.GEP: values[u] = src_values[0][get_single_element(arg)]
elif uop is Ops.WMMA:
first_src_dtype = srcs[0].dtype
values[u] = load(src_values, 0, u.dtype)
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
elif u.op is Ops.WMMA:
first_src_dtype = u.src[0].dtype
assert isinstance(first_src_dtype, DType) # mypy
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
dims, dtype_in, device, threads = u.arg[1], first_src_dtype.scalar(), u.arg[4], u.arg[5]
wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
# TODO: refactor these to a shared TensorCoreLayout
if device == "METAL":
@ -191,7 +190,7 @@ class PythonProgram:
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
elif device == "INTEL":
# A (16 elements on 8 threads)
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
@ -204,12 +203,12 @@ class PythonProgram:
def elem(x, col, row, _): return x[col+row][0] # k is always 0
def c_map(lane, elem): return (elem%16, elem//16)
values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
values[u] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
assert u in values, (uop, dtype, srcs, arg)
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
elif u.op in GroupOp.ALU:
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {u.op}"
assert all_same([u.dtype] + src_dtypes) or u.op in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {u.op}"
values[u] = [exec_alu(u.op, u.dtype, p) for p in zip(*src_values)]
assert u in values, u
i += 1
return time.perf_counter() - st