mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
refactor to u.
This commit is contained in:
parent
058259e0ea
commit
5d13dd1123
1 changed files with 45 additions and 46 deletions
|
|
@ -57,27 +57,26 @@ class PythonProgram:
|
|||
i = 0
|
||||
while i < len(self.uops):
|
||||
u = self.uops[i]
|
||||
uop, dtype, srcs, arg = u.op, u.dtype, (() if u.op is Ops.SPECIAL else u.src), u.arg
|
||||
src_values = [values[v] for v in srcs if v.op not in void_ops]
|
||||
src_dtypes = [v.dtype for v in srcs if v.op not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, src_values, src_dtypes)
|
||||
if uop is Ops.END:
|
||||
i = self.uop_to_index[srcs[1]]
|
||||
src_values = [values[v] for v in u.src if v.op not in void_ops]
|
||||
src_dtypes = [v.dtype for v in u.src if v.op not in void_ops]
|
||||
if getenv("TRACE"): print(i, u.op, u.dtype, u.arg, src_values, src_dtypes)
|
||||
if u.op is Ops.END:
|
||||
i = self.uop_to_index[u.src[1]]
|
||||
continue
|
||||
if uop is Ops.IF:
|
||||
if u.op is Ops.IF:
|
||||
exec_masks.append([x and y for x,y in zip(exec_masks[-1], src_values[0])])
|
||||
i += 1
|
||||
continue
|
||||
if uop is Ops.ENDIF:
|
||||
if u.op is Ops.ENDIF:
|
||||
exec_masks.pop()
|
||||
i += 1
|
||||
continue
|
||||
if uop in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
|
||||
if u.op in (Ops.BARRIER, Ops.SINK, Ops.NOOP, Ops.GROUP):
|
||||
# in the python emulator, the warp is always in sync
|
||||
i += 1
|
||||
continue
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
if uop is Ops.STORE:
|
||||
assert u.dtype is not None, f"{u.op} is missing a dtype"
|
||||
if u.op is Ops.STORE:
|
||||
assert len(src_values) == 2, f"STORE must be lowered to 2 srcs, got {len(src_values)}"
|
||||
store_gate = exec_masks[-1]
|
||||
for j,val in enumerate(src_values[1] if src_dtypes[1].count > 1 else [src_values[1]]):
|
||||
|
|
@ -85,25 +84,25 @@ class PythonProgram:
|
|||
if g: _store(m, o+j, v, src_dtypes[1].scalar())
|
||||
i += 1
|
||||
continue
|
||||
if uop is Ops.AFTER: values[u] = src_values[0]
|
||||
elif uop in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
assert isinstance(dtype, PtrDType), dtype
|
||||
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
|
||||
if u.op is Ops.AFTER: values[u] = src_values[0]
|
||||
elif u.op in {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
assert isinstance(u.dtype, PtrDType), u.dtype
|
||||
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
|
||||
if uop is Ops.DEFINE_REG:
|
||||
if u.op is Ops.DEFINE_REG:
|
||||
# REGs are per thread
|
||||
values[u] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
values[u] = [memoryview(bytearray(u.dtype.size*u.dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
|
||||
else:
|
||||
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.PARAM else pbufs.pop(0)
|
||||
buf = memoryview(bytearray(u.dtype.size*u.dtype.itemsize)) if u.op is not Ops.PARAM else pbufs.pop(0)
|
||||
values[u] = [buf.cast(storage_fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_VAR:
|
||||
elif u.op is Ops.DEFINE_VAR:
|
||||
values[u] = [pvals.pop(0)] * warp_size
|
||||
elif uop is Ops.SPECIAL:
|
||||
if arg[0] == 'g': values[u] = [idxs[2-int(arg[-1])]] * warp_size
|
||||
elif arg[0] == 'l': values[u] = [x[2-int(arg[-1])] for x in warp]
|
||||
elif uop is Ops.CONST: values[u] = [arg] * warp_size
|
||||
elif uop is Ops.INDEX:
|
||||
elif u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'g': values[u] = [idxs[2-int(u.arg[-1])]] * warp_size
|
||||
elif u.arg[0] == 'l': values[u] = [x[2-int(u.arg[-1])] for x in warp]
|
||||
elif u.op is Ops.CONST: values[u] = [u.arg] * warp_size
|
||||
elif u.op is Ops.INDEX:
|
||||
ret:list = []
|
||||
if isinstance(src_dtypes[0], ImageDType):
|
||||
assert len(src_values) == 3, "image index must be 3 srcs"
|
||||
|
|
@ -114,9 +113,9 @@ class PythonProgram:
|
|||
assert len(src_values) == 2, "non-image index must be 2 srcs"
|
||||
for m,o in zip(*src_values): ret.append((m,o))
|
||||
values[u] = ret
|
||||
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType):
|
||||
values[u] = src_values[0]
|
||||
elif uop is Ops.RANGE:
|
||||
elif u.op is Ops.RANGE:
|
||||
if u not in values: values[u] = [0] * warp_size
|
||||
else:
|
||||
for j in range(len(values[u])):
|
||||
|
|
@ -125,21 +124,21 @@ class PythonProgram:
|
|||
del values[u]
|
||||
i = self.loop_ends[u] + 1
|
||||
continue
|
||||
elif uop is Ops.STACK: values[u] = src_values
|
||||
elif uop is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], dtype) for x in src_values[0]]
|
||||
elif uop is Ops.CAST:
|
||||
values[u] = [truncate.get(dtype, lambda dt: dt)(dtype.const(x)) for x in src_values[0]]
|
||||
elif uop is Ops.LOAD:
|
||||
if dtype.count > 1:
|
||||
elif u.op is Ops.STACK: values[u] = src_values
|
||||
elif u.op is Ops.BITCAST: values[u] = [bitcast(x, src_dtypes[0], u.dtype) for x in src_values[0]]
|
||||
elif u.op is Ops.CAST:
|
||||
values[u] = [truncate.get(u.dtype, lambda dt: dt)(u.dtype.const(x)) for x in src_values[0]]
|
||||
elif u.op is Ops.LOAD:
|
||||
if u.dtype.count > 1:
|
||||
values[u] = [load([src_values[k][j] if k != 0 and src_dtypes[k].count > 1 else src_values[k] \
|
||||
for k in range(len(src_values))], j, dtype.scalar()) for j in range(dtype.count)]
|
||||
for k in range(len(src_values))], j, u.dtype.scalar()) for j in range(u.dtype.count)]
|
||||
else:
|
||||
values[u] = load(src_values, 0, dtype)
|
||||
elif uop is Ops.GEP: values[u] = src_values[0][get_single_element(arg)]
|
||||
elif uop is Ops.WMMA:
|
||||
first_src_dtype = srcs[0].dtype
|
||||
values[u] = load(src_values, 0, u.dtype)
|
||||
elif u.op is Ops.GEP: values[u] = src_values[0][get_single_element(u.arg)]
|
||||
elif u.op is Ops.WMMA:
|
||||
first_src_dtype = u.src[0].dtype
|
||||
assert isinstance(first_src_dtype, DType) # mypy
|
||||
dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
|
||||
dims, dtype_in, device, threads = u.arg[1], first_src_dtype.scalar(), u.arg[4], u.arg[5]
|
||||
wmma_helper = functools.partial(generic_wmma_helper, src_values, warp_size)
|
||||
# TODO: refactor these to a shared TensorCoreLayout
|
||||
if device == "METAL":
|
||||
|
|
@ -191,7 +190,7 @@ class PythonProgram:
|
|||
def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
|
||||
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
|
||||
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
|
||||
elif device == "INTEL":
|
||||
# A (16 elements on 8 threads)
|
||||
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
|
||||
|
|
@ -204,12 +203,12 @@ class PythonProgram:
|
|||
def elem(x, col, row, _): return x[col+row][0] # k is always 0
|
||||
def c_map(lane, elem): return (elem%16, elem//16)
|
||||
values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
||||
elif uop in GroupOp.ALU:
|
||||
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {uop}"
|
||||
assert all_same([dtype] + src_dtypes) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
|
||||
values[u] = [exec_alu(uop, dtype, p) for p in zip(*src_values)]
|
||||
assert u in values, (uop, dtype, srcs, arg)
|
||||
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
|
||||
elif u.op in GroupOp.ALU:
|
||||
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {u.op}"
|
||||
assert all_same([u.dtype] + src_dtypes) or u.op in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {u.op}"
|
||||
values[u] = [exec_alu(u.op, u.dtype, p) for p in zip(*src_values)]
|
||||
assert u in values, u
|
||||
i += 1
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue