mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
l2
This commit is contained in:
parent
4a49d05a3f
commit
5ce951fb34
3 changed files with 75 additions and 2 deletions
|
|
@ -2,9 +2,11 @@ import numpy as np
|
|||
import unittest
|
||||
from dataclasses import replace
|
||||
from tinygrad import Tensor, Context, Device, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item
|
||||
from tinygrad.engine.search import bufs_from_lin
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
|
||||
N = 512
|
||||
|
||||
|
|
@ -234,5 +236,75 @@ class TestQuantizeOnnx(unittest.TestCase):
|
|||
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
|
||||
sexec(out, opts)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
class TestDSPCache(unittest.TestCase):
  """Timing experiment: run a fixed quantized kernel on DSP with an L2 prefetch patched into its source."""

  def test_cache_speed(self):
    """Lower a hard-coded AST, inject an l2fetch call into the rendered source, run it and print the time.

    The test fails if the injection anchor is not present in the rendered source, because
    in that case the kernel that gets timed would not contain the prefetch at all.
    """
    # string because this breaks the Python language server's syntax highlighting for some reason
    ast = eval("""UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(25088), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 896, 32, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.uchar, arg=None, src=(
UOp(Ops.XOR, dtypes.int, arg=None, src=(
UOp(Ops.MAX, dtypes.int, arg=None, src=(
UOp(Ops.XOR, dtypes.int, arg=None, src=(
UOp(Ops.MAX, dtypes.int, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.LOAD, dtypes.uchar, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(150528), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 5376, 192, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=0.012368360534310341, src=(
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.LOAD, dtypes.char, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.char.ptr(6144), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 48, 4), strides=(4, 128, 1), offset=0, mask=None, contiguous=False), View(shape=(1, 28, 28, 32, 192), strides=(0, 0, 0, 192, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=0.007441135589033365, src=(
x22,)),)),)),)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(32), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.CONST, dtypes.float, arg=9.203465015161783e-05, src=(
x36:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 28, 28, 32, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.CONST, dtypes.float, arg=33.812857328652136, src=(
x36,)),)),
UOp(Ops.CONST, dtypes.float, arg=0.4999999, src=(
x36,)),)),
UOp(Ops.CONST, dtypes.float, arg=136.0, src=(
x36,)),)),)),
UOp(Ops.CONST, dtypes.int, arg=0, src=(
x36,)),)),
x41:=UOp(Ops.CONST, dtypes.int, arg=-1, src=(
x36,)),)),
UOp(Ops.CONST, dtypes.int, arg=-256, src=(
x36,)),)),
x41,)),)),)),))""")
    opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
    # build and lower the kernel with the context overrides applied
    with Context(DEVECTORIZE=0, QUANTIZE=1):
      k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
      for opt in opts: k.apply_opt(opt)
      prg = k.to_program()
    src = prg.src
    # str.replace is a silent no-op when the anchor is missing, so check for it explicitly:
    # otherwise the unpatched kernel would be timed and reported as if it had the prefetch
    anchor = "int32 acc3 = cast0;\n"
    self.assertIn(anchor, src, "l2fetch injection anchor not found in rendered kernel source")
    src = src.replace(anchor, "int32 acc3 = cast0;\n __builtin_HEXAGON_Y4_l2fetch(data2, 6144);\n")
    print(src)
    # swap the patched source into the program before compiling and running it
    prg = replace(prg, src=src)
    rt = CompiledRunner(prg)
    ei = ExecItem(rt, bufs_from_lin(k))
    tm = ei.run(wait=True)
    print(f"final time {tm*1e6:.2f} us")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -668,6 +668,7 @@ class Kernel:
|
|||
for i,(buf,st) in enumerate([(buf,st) for buf,st in zip(self.bufs, self.sts) if buf.op not in {Ops.CONST, Ops.VALID}]):
|
||||
print(f"{i:2d}: {str(st.shape):25s} {str(buf.src[0].dtype).replace('dtypes.',''):20s}", st.real_strides())
|
||||
print(self.applied_opts)
|
||||
if DEBUG >= 5: print(self.ast)
|
||||
# verify AST matches the spec after applying opts
|
||||
if __debug__: type_verify(list(modified_ast.toposort))
|
||||
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ def add_to_mul(c:UOp, x:UOp):
|
|||
dsp_pm_late = PatternMatcher([
|
||||
# prefetch L1
|
||||
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), name="ld"),
|
||||
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].index(UOp.const(dtypes.int, 16)),),
|
||||
lambda ld: ld.replace(src=ld.src+(UOp(Ops.CUSTOM, dtypes.void, src=(ld.src[0].src[0].index(UOp.const(dtypes.int, 8)),),
|
||||
arg="__builtin_HEXAGON_Y2_dcfetch({0});"),)) if ld.src[-1].op is not Ops.CUSTOM else None),
|
||||
|
||||
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue