mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
clean up wgsl_matcher [pr] (#8015)
use more UPat syntatic sugar and remove unneeded rules
This commit is contained in:
parent
db330a3110
commit
a5af4e5596
1 changed files with 5 additions and 10 deletions
|
|
@ -31,20 +31,15 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
|
|||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
wgsl_matcher = PatternMatcher([
|
||||
(UPat(Ops.CMPLT, src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
|
||||
lambda a,b,c: UOp(c.op, c.dtype, (a.cast(dtypes.int), b.cast(dtypes.int)))),
|
||||
(UPat(Ops.XOR, dtype=dtypes.bool, src=(UPat(name="a"), UPat(name="b")), name="c"),
|
||||
lambda a,b,c: UOp(c.op, dtypes.int, (a.cast(dtypes.int), b.cast(dtypes.int))).cast(dtypes.bool)),
|
||||
*[(UPat(a, src=(UPat(name="b", dtype=(dtypes.uint, dtypes.int, dtypes.bool))), name="a"),
|
||||
lambda a,b: UOp(a, dtypes.float, (b.cast(dtypes.float),)).cast(b.dtype)) for a in (Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.SQRT)],
|
||||
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
|
||||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
|
||||
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
|
||||
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
|
||||
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"),
|
||||
UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))),
|
||||
lambda m,a,g,c1,c2: UOp(Ops.WHERE, dtype=m.dtype, src=(g, UOp.const(dtype=dtypes.float, b=float('nan')), a))
|
||||
if math.isnan(c1.arg) and c2.arg == 1.0 else None),
|
||||
# TODO: why is this needed, and only for this MUL order
|
||||
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
|
||||
lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
|
||||
]) + extra_pm
|
||||
|
||||
type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue