identity store for DEFINE_REG

This commit is contained in:
George Hotz 2025-07-24 16:13:43 -07:00
commit 984a0edbc9

View file

@ -284,7 +284,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
input_ranges = tuple([x for x in inp.toposort(gate=lambda x: x.op is not Ops.STORE) if x.op is Ops.RANGE and x not in reduce_range])
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), (identity,), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
lst = [acc.store(identity, UOp(Ops.NOOP, src=input_ranges)).load(*reduce_range)] + lst # put acc as the first element