This commit is contained in:
George Hotz 2025-03-21 11:14:12 +08:00
commit 5ce951fb34
3 changed files with 75 additions and 2 deletions

View file

@ -2,9 +2,11 @@ import numpy as np
import unittest
from dataclasses import replace
from tinygrad import Tensor, Context, Device, dtypes
from tinygrad.ops import Ops
from tinygrad.ops import Ops, UOp
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item
from tinygrad.engine.search import bufs_from_lin
from tinygrad.shape.shapetracker import ShapeTracker, View
N = 512
@ -234,5 +236,75 @@ class TestQuantizeOnnx(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
sexec(out, opts)
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
class TestDSPCache(unittest.TestCase):
def test_cache_speed(self):
# string becuase this breaks Python language server for syntax highlight for some reason
ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(25088), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 896, 32, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
UOp(Ops.XOR, dtypes.int, arg=None, src=(
UOp(Ops.MAX, dtypes.int, arg=None, src=(
UOp(Ops.XOR, dtypes.int, arg=None, src=(
UOp(Ops.MAX, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(150528), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 5376, 192, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=0.012368360534310341, src=(
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.LOAD, dtypes.char, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.char.ptr(6144), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 48, 4), strides=(4, 128, 1), offset=0, mask=None, contiguous=False), View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 192, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=0.007441135589033365, src=(
x22,)),)),)),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(32), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.CONST, dtypes.float, arg=9.203465015161783e-05, src=(
x36:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=33.812857328652136, src=(
x36,)),)),
UOp(Ops.CONST, dtypes.float, arg=0.4999999, src=(
x36,)),)),
UOp(Ops.CONST, dtypes.float, arg=136.0, src=(
x36,)),)),)),
UOp(Ops.CONST, dtypes.int, arg=0, src=(
x36,)),)),
x41:=UOp(Ops.CONST, dtypes.int, arg=-1, src=(
x36,)),)),
UOp(Ops.CONST, dtypes.int, arg=-256, src=(
x36,)),)),
x41,)),)),)),))""")
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
with Context(DEVECTORIZE=0, QUANTIZE=1):
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
src = prg.src
src = src.replace("int32 acc3 = cast0;\n", "int32 acc3 = cast0;\n __builtin_HEXAGON_Y4_l2fetch(data2, 6144);\n")
print(src)
prg = replace(prg, src=src)
rt = CompiledRunner(prg)
ei = ExecItem(rt, bufs_from_lin(k))
tm = ei.run(wait=True)
print(f"final time {tm*1e6:.2f} us")
if __name__ == "__main__":
unittest.main()

View file

@ -668,6 +668,7 @@ class Kernel:
for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
print(self.applied_opts)
if DEBUG >= 5: print(self.ast)
# verify AST matches the spec after applying opts
if __debug__: type_verify(list(modified_ast.toposort))
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this

View file

@ -82,7 +82,7 @@ def add_to_mul(c:UOp, x:UOp):
dsp_pm_late = PatternMatcher([
# prefetch L1
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), name="ld"),
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].index(UOp.const(dtypes.int, 16)),),
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].index(UOp.const(dtypes.int, 8)),),
arg="__builtin_HEXAGON_Y2_dcfetch({0});"),)) if ld.src[-1].op is not Ops.CUSTOM else None),
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),