mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
more cleanups
This commit is contained in:
parent
e0d63696d7
commit
290ba9ee37
5 changed files with 7 additions and 65 deletions
|
|
@ -59,7 +59,7 @@ if __name__ == "__main__":
|
|||
return {"input": img.numpy()}
|
||||
quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
|
||||
activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8,
|
||||
extra_options={"ActivationSymmetric": False, "WeightSymmetric": True})
|
||||
extra_options={"ActivationSymmetric": False})
|
||||
|
||||
run_onnx_jit, input_specs = load_onnx_model(fetch(fn))
|
||||
t_name, t_spec = list(input_specs.items())[0]
|
||||
|
|
|
|||
|
|
@ -158,11 +158,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
must_divide = True
|
||||
if ctx is not None and ctx.device == "DSP":
|
||||
lengths = [128,64,32,16,8,4]
|
||||
#if ls.dtype.count in [128+64, 128*2+64, 128*4+64]: return None # leave 192 alone
|
||||
if ls.dtype.count in [192, 288, 160, 96]: return None # leave 192 alone
|
||||
# we really want stores to be 128 for fast casting
|
||||
#if ls.op is Ops.LOAD: lengths = [192]+lengths
|
||||
#if ls.op is Ops.LOAD: lengths = [1536,1024,512,384,256,192,96]+lengths
|
||||
if ls.dtype.count in [192, 288, 160, 96]: return None # leave these as loads
|
||||
must_divide = False
|
||||
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
|
||||
pass
|
||||
|
|
@ -304,35 +300,11 @@ pm_reduce = PatternMatcher([
|
|||
(UPat(Ops.REDUCE, name="x"), reduce_to_acc)
|
||||
])
|
||||
|
||||
def move_load_mask(ld:UOp, idx:UOp):
|
||||
if len(idx.src) != 3: return None
|
||||
mask = idx.src[2]
|
||||
new_idx = idx.replace(src=idx.src[0:2])
|
||||
return ld.substitute({idx:new_idx}) * mask.broadcast(ld.dtype.count).cast(ld.dtype)
|
||||
|
||||
pm_move_load_masks = PatternMatcher([
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx"),), name="ld"), move_load_mask),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld"), move_load_mask),
|
||||
])
|
||||
|
||||
def fix_range(rng:UOp):
|
||||
if rng.arg in [1,2] and rng.src[0].arg == 0:
|
||||
return rng.replace(src=(rng.src[0]+1, rng.src[1]))
|
||||
|
||||
pm_ranges = PatternMatcher([
|
||||
(UPat(Ops.RANGE, name="rng"), fix_range)
|
||||
])
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) -> UOp:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# we can move the load masks to after the load
|
||||
#sink = graph_rewrite(sink, pm_move_load_masks, name="move_load_masks")
|
||||
|
||||
#sink = graph_rewrite(sink, pm_ranges)
|
||||
|
||||
# devectorize is optional
|
||||
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
|
||||
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
|
|
@ -346,16 +318,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) ->
|
|||
else:
|
||||
sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
|
||||
# late unroll
|
||||
"""
|
||||
late_unroll = PatternMatcher([(UPat(Ops.RANGE, arg=2, name="r"),
|
||||
lambda r: UOp(Ops.UNROLL, dtypes.int, src=(UOp.const(dtypes.int.vec(3), (0,1,2)),), arg=((2,3),)) )])
|
||||
sink = graph_rewrite(sink, late_unroll, name="late_unroll")
|
||||
from tinygrad.codegen.expander import expander
|
||||
sink = graph_rewrite(sink, sym+expander, name="late_expand")
|
||||
sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
|
||||
"""
|
||||
|
||||
# remove reduce
|
||||
sink = graph_rewrite(sink, pm_reduce, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
|
|
|
|||
|
|
@ -123,18 +123,5 @@ def expand_rewrite(sink:UOp) -> UOp:
|
|||
# initial symbolic + migrate indexing (remove this)
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing)
|
||||
|
||||
"""
|
||||
# late pad
|
||||
def contract_test(x):
|
||||
ret = UOp(Ops.CONTRACT, dtypes.int.vec(12), src=(x,), arg=((2,3), (3,4)))
|
||||
#return ret.gep(3) + ret.gep(7) + ret.gep(11) +
|
||||
return ret.gep((0,1,2,4,5,6,8,9,10))
|
||||
|
||||
late_unroll = gep_pushing+PatternMatcher([(UPat(Ops.UNROLL, arg=((3,3),)),
|
||||
lambda: UOp(Ops.UNROLL, dtypes.int, src=(UOp.const(dtypes.int.vec(4), (0,1,2,3)),), arg=((3,4),))),
|
||||
(UPat(Ops.CONTRACT, src=(UPat.var("x"),), arg=((2,3), (3,3))), contract_test)])
|
||||
sink = graph_rewrite(sink, late_unroll, name="late_unroll")
|
||||
"""
|
||||
|
||||
# expand
|
||||
return graph_rewrite(sink, sym+expander)
|
||||
|
|
|
|||
|
|
@ -387,7 +387,7 @@ class Kernel:
|
|||
self.group_for_reduces += 1
|
||||
elif opt.op is OptOps.UNROLL: # purple
|
||||
check(axis < self.first_upcast, "can't upcasted already upcasted")
|
||||
#check(amt <= 32, "don't unroll more than 32")
|
||||
check(amt <= 32, "don't unroll more than 32")
|
||||
# TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
|
||||
#upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
|
||||
#self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)
|
||||
|
|
|
|||
|
|
@ -12,12 +12,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
|
||||
actions = []
|
||||
actions += [(Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0))]
|
||||
actions += [(Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0))]
|
||||
actions += [(Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0))]
|
||||
actions += [(Opt(op=OptOps.PADTO, axis=axis, arg=amt), Opt(op=OptOps.UPCAST, axis=axis, arg=amt)) for amt in [128] for axis in range(7)]
|
||||
actions += [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7,16,32,64,128] for axis in range(6)]
|
||||
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
|
||||
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
|
||||
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
|
||||
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
|
||||
|
|
@ -114,18 +109,16 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
|
|||
|
||||
if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
|
||||
for i, action in enumerate(kernel_actions):
|
||||
if not isinstance(action, tuple) and action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
|
||||
if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
|
||||
# replace every tc_action with default tc with one tc_action for each available tc
|
||||
kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
|
||||
|
||||
for i,a in enumerate(kernel_actions):
|
||||
if not isinstance(a, tuple) and a.axis is not None and a.op is not OptOps.TC:
|
||||
if a.axis is not None and a.op is not OptOps.TC:
|
||||
if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
if isinstance(a, tuple):
|
||||
for aa in a: lin2.apply_opt(aa)
|
||||
else: lin2.apply_opt(a)
|
||||
lin2.apply_opt(a)
|
||||
up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
|
||||
for s,c in zip(lin2.full_shape, lin2.colors()):
|
||||
if c in {"magenta", "yellow"}: up *= s
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue