mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix div rules (#12567)
* group div rules * merge those pattern matchers * revert
This commit is contained in:
parent
8a1c3dc1bf
commit
840d2bf1ea
2 changed files with 13 additions and 8 deletions
|
|
@ -745,8 +745,8 @@ class UPat(MathTrait):
|
|||
def var(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None): return UPat(dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True):
|
||||
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
|
||||
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
|
||||
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
|
||||
@staticmethod
|
||||
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
|
|
|
|||
|
|
@ -51,8 +51,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
||||
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
||||
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
|
||||
# 4 variations of (x%c)+(x//c)*c = x TODO: add sorting to remove some variations
|
||||
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
|
|
@ -76,10 +74,6 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.var("x") % UPat.var("x"), lambda x: x.const_like(0)), # x%x -> 0
|
||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.index)) != UPat.var("x"),
|
||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# ** constant folding **
|
||||
# TODO: add const folding for Ops.THREEFRY
|
||||
(UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
|
||||
|
|
@ -91,6 +85,17 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
|
||||
# *** div rules ***
|
||||
(UPat.cvar('x', arg=0) / 0, lambda x: x.const_like(float('nan'))), # 0/0 -> nan
|
||||
((UPat.var("x") * 0) / 0, lambda x: x.const_like(float('nan'))), # (x*0)/0 -> nan
|
||||
# can be wrong if x or x2 is 0
|
||||
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST
|
||||
and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# *** cast/bitcast ***
|
||||
(UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue