mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
compact get_late_rewrite_patterns [pr] (#9116)
This commit is contained in:
parent
2e97022e5e
commit
c1dfe5c00d
1 changed files with 5 additions and 9 deletions
|
|
@ -132,19 +132,15 @@ def get_late_rewrite_patterns(ops, force_transcendental=False):
|
|||
# rewrite SQRT to xpow 0.5
|
||||
if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
|
||||
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
|
||||
if Ops.AND in ops:
|
||||
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
|
||||
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
|
||||
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
|
||||
if Ops.SHL in ops and Ops.SHR in ops:
|
||||
pat += [
|
||||
(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << powers_of_two[c.arg] if c.arg in powers_of_two else None),
|
||||
(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two and resolve(x>=0,False) else None)
|
||||
]
|
||||
if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
|
||||
if Ops.SHR in ops:
|
||||
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) and resolve(x>=0,False) else None)]
|
||||
if Ops.NEG in ops:
|
||||
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
|
||||
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
|
||||
if Ops.MULACC in ops:
|
||||
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
|
||||
if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
|
||||
return PatternMatcher(pat)
|
||||
|
||||
# ***** threefry *****
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue