clean up wgsl_matcher [pr] (#8015)

use more UPat syntatic sugar and remove unneeded rules
This commit is contained in:
chenyu 2024-12-03 11:55:03 -05:00 committed by GitHub
commit a5af4e5596
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -31,20 +31,15 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:Optional[UOp]=None):
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
wgsl_matcher = PatternMatcher([
(UPat(Ops.CMPLT, src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
lambda a,b,c: UOp(c.op, c.dtype, (a.cast(dtypes.int), b.cast(dtypes.int)))),
(UPat(Ops.XOR, dtype=dtypes.bool, src=(UPat(name="a"), UPat(name="b")), name="c"),
lambda a,b,c: UOp(c.op, dtypes.int, (a.cast(dtypes.int), b.cast(dtypes.int))).cast(dtypes.bool)),
*[(UPat(a, src=(UPat(name="b", dtype=(dtypes.uint, dtypes.int, dtypes.bool))), name="a"),
lambda a,b: UOp(a, dtypes.float, (b.cast(dtypes.float),)).cast(b.dtype)) for a in (Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.SQRT)],
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat(name="b")), name="c"),
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype.itemsize < 4 else None),
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(unpack_map[l.dtype])) if l.dtype.itemsize < 4 else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")), lambda bidx,var: packed_store(bidx,var) if var.dtype.itemsize < 4 else None),
(UPat(Ops.MUL, name="m", src=(UPat(name="a"), UPat(Ops.WHERE, src=(UPat.var("g"),
UPat(op=Ops.CONST, name="c1"), UPat(op=Ops.CONST, name="c2"))))),
lambda m,a,g,c1,c2: UOp(Ops.WHERE, dtype=m.dtype, src=(g, UOp.const(dtype=dtypes.float, b=float('nan')), a))
if math.isnan(c1.arg) and c2.arg == 1.0 else None),
# TODO: why is this needed, and only for this MUL order
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
]) + extra_pm
type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",