Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
dc68b7ef96 in_loop 2025-07-24 14:19:19 -07:00
George Hotz
9997f79c0a this is the right way to do this 2025-07-24 14:09:43 -07:00
George Hotz
386bbf311c Ops.BEFORE finally fixes the DEFINE_REG hacks 2025-07-24 14:02:45 -07:00

View file

@@ -284,9 +284,10 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
-acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG),
-(red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
-lst = [acc.load()] + lst # put acc as the first element
+identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
+in_loop = functools.reduce(operator.or_, [x.ne(x.const_like(0)) for x in reduce_range]).broadcast(red.dtype.count)
+acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), src=(identity,), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
+lst = [in_loop.where(acc.load(*reduce_range), identity)] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret