Ops.BEFORE finally fixes the DEFINE_REG hacks

This commit is contained in:
George Hotz 2025-07-24 14:02:45 -07:00
commit 386bbf311c
4 changed files with 8 additions and 5 deletions

View file

@@ -284,9 +284,11 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG),
(red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
lst = [acc.load()] + lst # put acc as the first element
input_ranges = tuple([x for x in inp.toposort() if x.op is Ops.RANGE and x not in reduce_range])
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), src=(identity,), arg=(ctx.acc_num,))
acc = acc.index(UOp.const(dtypes.int, 0)).store(identity, UOp(Ops.BEFORE, src=input_ranges))
lst = [acc.load(*reduce_range)] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret

View file

@@ -135,6 +135,7 @@ class CStyleLanguage(Renderer):
c: defaultdict[str, int] = defaultdict(int)
name = "test"
for u in uops:
if u.op is Ops.BEFORE: continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

View file

@@ -9,7 +9,7 @@ class FastEnum(IntEnum):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto() # noqa: E702
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); BEFORE = auto() # noqa: E702
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702

View file

@@ -15,7 +15,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D"}
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D", Ops.BEFORE: "#f8a0e0"}
# VIZ API