mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
knum 4
This commit is contained in:
parent
71c7c455a6
commit
dbb50e4a00
2 changed files with 27 additions and 11 deletions
|
|
@ -105,23 +105,34 @@ if __name__ == "__main__":
|
|||
k.apply_opt(Opt(OptOps.UPCAST, 1, 64))
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 2))
|
||||
"""
|
||||
if False:
|
||||
pass
|
||||
#if knum in [7, 11, 14, 18]:
|
||||
# alignment issue?
|
||||
#pass
|
||||
if knum == 4:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 96))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
elif knum == 37:
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 384))
|
||||
else:
|
||||
full_shape = k.full_shape
|
||||
out_shape = k.sts[0].shape
|
||||
out_strides = k.sts[0].real_strides()
|
||||
if len(out_strides) == 3:
|
||||
if full_shape[2] <= 32: k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
else: k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, out_strides[0]))
|
||||
if out_strides[0] < 128:
|
||||
upcast_0 = 128//out_strides[0]
|
||||
if out_shape[0]%upcast_0 == 0 and upcast_0 != 1: k.apply_opt(Opt(OptOps.UPCAST, 0, upcast_0))
|
||||
if full_shape[1] < 128:
|
||||
if full_shape[2] <= 32: k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
else: k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, full_shape[1]))
|
||||
if out_strides[0] < 128:
|
||||
upcast_0 = 128//out_strides[0]
|
||||
if out_shape[0]%upcast_0 == 0 and upcast_0 != 1: k.apply_opt(Opt(OptOps.UPCAST, 0, upcast_0))
|
||||
elif full_shape[1] % 128 == 0:
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
|
||||
elif len(out_strides) == 1:
|
||||
assert full_shape[0]%128 == 0
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
#print("here", out_shape, out_strides, k.name)
|
||||
#if full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
pass
|
||||
#print("here", out_shape, out_strides, k.name)
|
||||
#k.hand_coded_optimizations()
|
||||
#if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
|
||||
p2 = k.to_program()
|
||||
|
|
|
|||
|
|
@ -187,6 +187,11 @@ gep_pushing = PatternMatcher([
|
|||
if not isinstance(x.dtype, PtrDType) else None),
|
||||
# VECTORIZE on same GEP
|
||||
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
||||
# CAST on multi GEP
|
||||
(UPat(Ops.CAST, src=(UPat(Ops.GEP, name="g"),), name="c"),
|
||||
lambda c,g: g.src[0].gep(g.arg[0]).cast(c.dtype.scalar()).broadcast(len(g.arg)) if len(g.arg) > 1 and all_same(g.arg) else None),
|
||||
# VECTORIZE/CONST
|
||||
(UPat(Ops.VECTORIZE, src=UPat.var("x"))+UPat.cvar("c", vec=False), lambda x,c: (x+c.arg).broadcast(c.dtype.count)),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue