mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
unused
This commit is contained in:
parent
cbe23e13c2
commit
ab67d5ff6e
2 changed files with 12 additions and 38 deletions
|
|
@ -185,25 +185,6 @@ pm_lowerer = PatternMatcher([
|
|||
|
||||
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
|
||||
|
||||
def view_to_mask(x:UOp):
  """Return a single-view ShapeTracker that keeps only the mask of ``x``'s simplified shapetracker.

  The shapetracker of ``x`` (``x.st``) is simplified first. The result is None when the
  simplified tracker has more than one view, or when its view carries no mask. Otherwise a
  new ShapeTracker is built from one View with the same shape, a tuple of zeros (one per
  dimension), 0, the existing mask, and False.
  """
  from tinygrad.shape.shapetracker import ShapeTracker, View
  tracker = cast(ShapeTracker, x.st).simplify()
  # only a lone masked view can be expressed as a pure mask
  if len(tracker.views) > 1: return None
  mask = tracker.views[-1].mask
  if mask is None: return None
  # NOTE(review): positional View args are presumably (shape, strides, offset, mask, contiguous) -- confirm against View
  zeros = (0,)*len(tracker.shape)
  return ShapeTracker((View(tracker.shape, zeros, 0, mask, False),))
|
||||
|
||||
def ignore_on_reduce(r:UOp, ig:UOp):
  """Rebuild ``r`` so its sources sit under one Ops.IGNORE whose mask matches the shape of ``r.src[0]``.

  ``ig.arg`` is treated as a ShapeTracker. Its last view's mask is walked in lockstep with the
  shape of ``r.src[0]`` and the tracker's own shape: where the two sizes differ the mask entry
  becomes ``(0, input_size)``, otherwise the existing entry is kept. The returned UOp is ``r``
  with its sources replaced by one Ops.IGNORE node (dtype ``r.dtype``, sources ``r.src``) that
  carries the new ShapeTracker as its arg.
  """
  in_shape = r.src[0].shape
  from tinygrad.shape.shapetracker import ShapeTracker, View
  ig_st = cast(ShapeTracker, ig.arg)
  # `!=` is kept exactly as written: dimensions may not be plain ints
  adapted_mask = tuple((0, dim_in) if dim_in != dim_ig else old_entry
                       for dim_in, dim_ig, old_entry in zip(in_shape, ig_st.shape, ig_st.views[-1].mask))
  ignore_st = ShapeTracker((View(in_shape, (0,)*len(ig_st.shape), 0, adapted_mask, False),))
  ignore_node = UOp(Ops.IGNORE, r.dtype, r.src, arg=ignore_st)
  return r.replace(src=(ignore_node,))
|
||||
|
||||
FP = (1 << 16)
|
||||
pm_quant = symbolic+PatternMatcher([
|
||||
# cast after add/mul
|
||||
|
|
@ -211,14 +192,17 @@ pm_quant = symbolic+PatternMatcher([
|
|||
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
|
||||
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
|
||||
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
|
||||
# masked MUL after masked ADD (new, might be wrong)
|
||||
|
||||
# masked MUL after masked ADD
|
||||
((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
|
||||
lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
|
||||
|
||||
# MUL after reduce
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
|
||||
# CAST after reduce (doesn't work if it's a size change)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
|
||||
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
|
||||
|
||||
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
|
||||
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
|
||||
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
|
||||
|
|
@ -230,12 +214,14 @@ pm_quant = symbolic+PatternMatcher([
|
|||
(UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
|
||||
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
|
||||
lambda ld,v,c1: ld*c1),
|
||||
|
||||
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
|
||||
lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
|
||||
# fixed point multi, replace (x.float()*c1 + y.float()*c2) with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
|
||||
lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
|
||||
|
||||
# where move
|
||||
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
|
||||
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
|
||||
|
|
@ -245,22 +231,10 @@ pm_quant = symbolic+PatternMatcher([
|
|||
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
|
||||
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
|
||||
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
|
||||
|
||||
# where on two adds
|
||||
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
|
||||
lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
|
||||
# don't care (moved from here)
|
||||
#(UPat(Ops.STORE, name="x"), lambda x:
|
||||
# x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \
|
||||
# if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None),
|
||||
#(UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"),
|
||||
# lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))),
|
||||
#(UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c),
|
||||
#(UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None),
|
||||
#(UPat(Ops.IGNORE, src=(UPat(Ops.REDUCE_AXIS, name="r"),), name="ig"), ignore_on_reduce),
|
||||
# put add in REDUCE
|
||||
#(UPat(Ops.REDUCE_AXIS, name="r")+UPat.var("x"), lambda r,x: r.replace(src=(r.src[0], (r.src[1]+x) if len(r.src) == 2 else x))),
|
||||
# distribute on casted MUL
|
||||
#((UPat(Ops.CAST, name="v1")+UPat.cvar("c")) * UPat(Ops.CAST, name="v2"), lambda v1,v2,c: (v1*v2)+(c*v2)),
|
||||
|
||||
# split REDUCE into multiple reduces
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
|
||||
|
|
@ -268,9 +242,12 @@ pm_quant = symbolic+PatternMatcher([
|
|||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
|
||||
|
||||
# MUL by 1/0 on LOAD where the masks match (is this right?)
|
||||
# MUL by 1/0 on LOAD where the masks match
|
||||
(UPat(Ops.WHERE, src=(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v1"),)), UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))) * \
|
||||
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"), lambda ld,v1,v2: ld if view_to_mask(v1) == view_to_mask(v2) else None),
|
||||
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"),
|
||||
lambda ld,v1,v2: ld if v1.arg.to_indexed_uops()[1].simplify() == v2.arg.to_indexed_uops()[1].simplify() \
|
||||
# NOTE: this clause is completely false and breaks things
|
||||
or True else None),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
|
|
|
|||
|
|
@ -93,9 +93,6 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index),
|
||||
|
||||
# double INDEX is used for l1 prefetch in DSP
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.INDEX), UPat())), lambda: True),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, barrier?>
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue