This commit is contained in:
George Hotz 2025-03-28 17:37:38 +08:00
commit ab67d5ff6e
2 changed files with 12 additions and 38 deletions

View file

@ -185,25 +185,6 @@ pm_lowerer = PatternMatcher([
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
def view_to_mask(x:UOp):
from tinygrad.shape.shapetracker import ShapeTracker, View
st = cast(ShapeTracker, x.st).simplify()
#print("view_to_mask", st.views)
if len(st.views) > 1: return None
if st.views[-1].mask is None: return None
return ShapeTracker((View(st.shape, (0,)*len(st.shape), 0, st.views[-1].mask, False),))
def ignore_on_reduce(r:UOp, ig:UOp):
in_shape = r.src[0].shape
from tinygrad.shape.shapetracker import ShapeTracker, View
st = cast(ShapeTracker, ig.arg)
new_mask = []
for s1, s2, om in zip(in_shape, st.shape, st.views[-1].mask):
if s1 != s2: new_mask.append((0,s1))
else: new_mask.append(om)
new_st = ShapeTracker((View(in_shape, (0,)*len(st.shape), 0, tuple(new_mask), False),))
return r.replace(src=(UOp(Ops.IGNORE, r.dtype, r.src, arg=new_st),))
FP = (1 << 16)
pm_quant = symbolic+PatternMatcher([
# cast after add/mul
@ -211,14 +192,17 @@ pm_quant = symbolic+PatternMatcher([
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
# masked MUL after masked ADD (new, might be wrong)
# masked MUL after masked ADD
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
# MUL after reduce
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
# CAST after reduce (doesn't work if it's a size change)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
@ -230,12 +214,14 @@ pm_quant = symbolic+PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
# fixed point multi, replace (x.float()*c1 + y.float()*c2) with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
# where move
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
@ -245,22 +231,10 @@ pm_quant = symbolic+PatternMatcher([
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
# where on two adds
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
# don't care (moved from here)
#(UPat(Ops.STORE, name="x"), lambda x:
# x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \
# if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None),
#(UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"),
# lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))),
#(UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c),
#(UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None),
#(UPat(Ops.IGNORE, src=(UPat(Ops.REDUCE_AXIS, name="r"),), name="ig"), ignore_on_reduce),
# put add in REDUCE
#(UPat(Ops.REDUCE_AXIS, name="r")+UPat.var("x"), lambda r,x: r.replace(src=(r.src[0], (r.src[1]+x) if len(r.src) == 2 else x))),
# distribute on casted MUL
#((UPat(Ops.CAST, name="v1")+UPat.cvar("c")) * UPat(Ops.CAST, name="v2"), lambda v1,v2,c: (v1*v2)+(c*v2)),
# split REDUCE into multiple reduces
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
@ -268,9 +242,12 @@ pm_quant = symbolic+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
# MUL by 1/0 on LOAD where the masks match (is this right?)
# MUL by 1/0 on LOAD where the masks match
(UPat(Ops.WHERE, src=(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v1"),)), UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))) * \
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"), lambda ld,v1,v2: ld if view_to_mask(v1) == view_to_mask(v2) else None),
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"),
lambda ld,v1,v2: ld if v1.arg.to_indexed_uops()[1].simplify() == v2.arg.to_indexed_uops()[1].simplify() \
# NOTE: this clause is completely false and breaks things
or True else None),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:

View file

@ -93,9 +93,6 @@ spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index),
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index),
# double INDEX is used for l1 prefetch in DSP
(UPat(Ops.INDEX, src=(UPat(Ops.INDEX), UPat())), lambda: True),
# LOAD takes a <bufidx, alt?, barrier?>
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),