mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc68b7ef96 | ||
|
|
9997f79c0a | ||
|
|
386bbf311c |
1 changed files with 4 additions and 3 deletions
|
|
@ -284,9 +284,10 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
|
||||
# if we have a range
|
||||
if len(reduce_range) != 0:
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG),
|
||||
(red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
lst = [acc.load()] + lst # put acc as the first element
|
||||
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
|
||||
in_loop = functools.reduce(operator.or_, [x.ne(x.const_like(0)) for x in reduce_range]).broadcast(red.dtype.count)
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), src=(identity,), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
lst = [in_loop.where(acc.load(*reduce_range), identity)] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue