more cleanups

This commit is contained in:
George Hotz 2025-03-26 17:59:26 +08:00
commit 290ba9ee37
5 changed files with 7 additions and 65 deletions

View file

@ -59,7 +59,7 @@ if __name__ == "__main__":
return {"input": img.numpy()}
quantize_static(model_fp32, fn, ImagenetReader(), quant_format=QuantFormat.QDQ, per_channel=False,
activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8,
extra_options={"ActivationSymmetric": False, "WeightSymmetric": True})
extra_options={"ActivationSymmetric": False})
run_onnx_jit, input_specs = load_onnx_model(fetch(fn))
t_name, t_spec = list(input_specs.items())[0]

View file

@ -158,11 +158,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
must_divide = True
if ctx is not None and ctx.device == "DSP":
lengths = [128,64,32,16,8,4]
#if ls.dtype.count in [128+64, 128*2+64, 128*4+64]: return None # leave 192 alone
if ls.dtype.count in [192, 288, 160, 96]: return None # leave 192 alone
# we really want stores to be 128 for fast casting
#if ls.op is Ops.LOAD: lengths = [192]+lengths
#if ls.op is Ops.LOAD: lengths = [1536,1024,512,384,256,192,96]+lengths
if ls.dtype.count in [192, 288, 160, 96]: return None # leave these as loads
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
@ -304,35 +300,11 @@ pm_reduce = PatternMatcher([
(UPat(Ops.REDUCE, name="x"), reduce_to_acc)
])
def move_load_mask(ld:UOp, idx:UOp):
  """Rebuild a load so its mask is applied to the loaded value instead of living on the index.

  `idx` is only handled when it has exactly three sources; the third one is
  treated as the mask. The index is rebuilt from its first two sources, that
  new index is substituted for `idx` inside `ld`, and the result is multiplied
  by the mask broadcast to `ld.dtype.count` and cast to `ld.dtype`.

  Returns None when `idx` does not have exactly three sources.
  """
  # without a third source there is nothing to hoist out of the index
  if len(idx.src) != 3: return None
  gate = idx.src[2]
  # same index, minus the trailing mask source
  unmasked_idx = idx.replace(src=idx.src[0:2])
  unmasked_load = ld.substitute({idx: unmasked_idx})
  # the multiply takes over the mask's job (presumably masked-off lanes become 0 -- confirm mask is 0/1 valued)
  return unmasked_load * gate.broadcast(ld.dtype.count).cast(ld.dtype)
# Rules that hand a LOAD (bound as "ld") and its INDEX source (bound as "idx") to move_load_mask.
# Both rules require the LOAD to have a single source: the first takes the INDEX pattern directly,
# the second takes the same INDEX pattern with .cast() applied to it.
# NOTE(review): the only visible use of this matcher is a commented-out graph_rewrite call -- confirm it is still needed.
pm_move_load_masks = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx"),), name="ld"), move_load_mask),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld"), move_load_mask),
])
def fix_range(rng:UOp):
  """Add 1 to the first source of selected ranges.

  Only applies when `rng.arg` is 1 or 2 and the first source's `arg` equals 0;
  in that case `rng` is rebuilt with `src[0] + 1` as its first source and the
  second source unchanged. Every other range yields None.
  """
  # only the ranges tagged 1 or 2 are touched
  if rng.arg not in (1, 2): return None
  first = rng.src[0]  # presumably the range start -- confirm against the RANGE op layout
  if first.arg == 0:
    return rng.replace(src=(first+1, rng.src[1]))
  return None
# Single-rule matcher: every RANGE op (bound as "rng") is passed to fix_range.
# NOTE(review): the only visible use of this matcher is a commented-out graph_rewrite call -- confirm it is still needed.
pm_ranges = PatternMatcher([
(UPat(Ops.RANGE, name="rng"), fix_range)
])
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) -> UOp:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
# we can move the load masks to after the load
#sink = graph_rewrite(sink, pm_move_load_masks, name="move_load_masks")
#sink = graph_rewrite(sink, pm_ranges)
# devectorize is optional
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
@ -346,16 +318,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) ->
else:
sink = graph_rewrite(sink, opts.pre_matcher)
# late unroll
"""
late_unroll = PatternMatcher([(UPat(Ops.RANGE, arg=2, name="r"),
lambda r: UOp(Ops.UNROLL, dtypes.int, src=(UOp.const(dtypes.int.vec(3), (0,1,2)),), arg=((2,3),)) )])
sink = graph_rewrite(sink, late_unroll, name="late_unroll")
from tinygrad.codegen.expander import expander
sink = graph_rewrite(sink, sym+expander, name="late_expand")
sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
"""
# remove reduce
sink = graph_rewrite(sink, pm_reduce, ctx=ReduceContext(), name="remove_reduce")

View file

@ -123,18 +123,5 @@ def expand_rewrite(sink:UOp) -> UOp:
# initial symbolic + migrate indexing (remove this)
sink = graph_rewrite(sink, sym+migrate_indexing)
"""
# late pad
def contract_test(x):
ret = UOp(Ops.CONTRACT, dtypes.int.vec(12), src=(x,), arg=((2,3), (3,4)))
#return ret.gep(3) + ret.gep(7) + ret.gep(11) +
return ret.gep((0,1,2,4,5,6,8,9,10))
late_unroll = gep_pushing+PatternMatcher([(UPat(Ops.UNROLL, arg=((3,3),)),
lambda: UOp(Ops.UNROLL, dtypes.int, src=(UOp.const(dtypes.int.vec(4), (0,1,2,3)),), arg=((3,4),))),
(UPat(Ops.CONTRACT, src=(UPat.var("x"),), arg=((2,3), (3,3))), contract_test)])
sink = graph_rewrite(sink, late_unroll, name="late_unroll")
"""
# expand
return graph_rewrite(sink, sym+expander)

View file

@ -387,7 +387,7 @@ class Kernel:
self.group_for_reduces += 1
elif opt.op is OptOps.UNROLL: # purple
check(axis < self.first_upcast, "can't upcasted already upcasted")
#check(amt <= 32, "don't unroll more than 32")
check(amt <= 32, "don't unroll more than 32")
# TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores
#upcast_count = sum(x == y for x,y in zip(self.full_shape[-self.upcasted:], self.output_shape[-self.upcasted:])) if self.upcasted else 0
#self.shift_to(axis, amt, insert_before=None if upcast_count == 0 else self.shape_len-upcast_count)

View file

@ -12,12 +12,7 @@ from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer import ProgramSpec
actions = []
actions += [(Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0))]
actions += [(Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0))]
actions += [(Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0))]
actions += [(Opt(op=OptOps.PADTO, axis=axis, arg=amt), Opt(op=OptOps.UPCAST, axis=axis, arg=amt)) for amt in [128] for axis in range(7)]
actions += [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7,16,32,64,128] for axis in range(6)]
actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
@ -114,18 +109,16 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
for i, action in enumerate(kernel_actions):
if not isinstance(action, tuple) and action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
# replace every tc_action with default tc with one tc_action for each available tc
kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
for i,a in enumerate(kernel_actions):
if not isinstance(a, tuple) and a.axis is not None and a.op is not OptOps.TC:
if a.axis is not None and a.op is not OptOps.TC:
if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
lin2 = lin.copy()
try:
if isinstance(a, tuple):
for aa in a: lin2.apply_opt(aa)
else: lin2.apply_opt(a)
lin2.apply_opt(a)
up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
for s,c in zip(lin2.full_shape, lin2.colors()):
if c in {"magenta", "yellow"}: up *= s