write custom_sum with set and after (#13045)

This commit is contained in:
chenyu 2025-11-01 10:45:30 -04:00 committed by GitHub
commit bebec73471
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, UOp, Context
from tinygrad.uop.ops import KernelInfo, AxisType, Ops
from tinygrad.uop.ops import KernelInfo, AxisType
# **** kernels ****
@ -33,9 +33,10 @@ def custom_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
return prog.sink(arg=KernelInfo(name=f"custom_gemm_{C.shape[0]}_{C.shape[1]}_{A.shape[1]}", opts_to_apply=()))
def custom_sum(B:UOp, A:UOp) -> UOp:
# TODO: write with set and after?
i = UOp.range(A.shape[0], 0, axis_type=AxisType.REDUCE)
return B[0].store(A[i].reduce(i, arg=Ops.ADD)).sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
B = B[0].set(0.0)
B = B[0].set(B.after(i)[0] + A[i], end=i)
return B.sink(arg=KernelInfo(name=f"custom_sum_{A.shape[0]}", opts_to_apply=()))
def flip_contract_kernel(dest:UOp, src:UOp):
assert dest.size%4 == 0