mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
write custom_sum with set and after (#13045)
This commit is contained in:
parent
e98506735b
commit
bebec73471
1 changed files with 4 additions and 3 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, UOp, Context
|
||||
from tinygrad.uop.ops import KernelInfo, AxisType, Ops
|
||||
from tinygrad.uop.ops import KernelInfo, AxisType
|
||||
|
||||
# **** kernels ****
|
||||
|
||||
|
|
@ -33,9 +33,10 @@ def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
|||
return prog.sink(arg=KernelInfo(name=f"custom_gemm_{C.shape[0]}_{C.shape[1]}_{A.shape[1]}", opts_to_apply=()))
|
||||
|
||||
def custom_sum(B:UOp, A:UOp) -> UOp:
|
||||
# TODO: write with set and after?
|
||||
i = UOp.range(A.shape[0], 0, axis_type=AxisType.REDUCE)
|
||||
return B[0].store(A[i].reduce(i, arg=Ops.ADD)).sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
|
||||
B = B[0].set(0.0)
|
||||
B = B[0].set(B.after(i)[0] + A[i], end=i)
|
||||
return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
|
||||
|
||||
def flip_contract_kernel(dest:UOp, src:UOp):
|
||||
assert dest.size%4 == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue