mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
more crap
This commit is contained in:
parent
ea3dd000d3
commit
cf3f67e0e5
3 changed files with 72 additions and 53 deletions
|
|
@ -315,13 +315,16 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,idx_y,idx_x,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
|
||||
lambda ctx,buf,idx_y,idx_x: f"read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]}))"),
|
||||
(UPat(Ops.LOAD,
|
||||
dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,idx_y,idx_x,var,gate:
|
||||
f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})" if x.max_numel() == 4 else None),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),), name="x"),
|
||||
lambda ctx,x,buf,idx_y,idx_x: f"read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]}))" if x.max_numel() == 4 else None),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),
|
||||
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda ctx,buf,idx_y,idx_x,var: f"write_imagef({ctx[buf]}, (int2)({ctx[idx_x]},{ctx[idx_y]}), {ctx[var]});"),
|
||||
UPat.var("var", dtypes.float)), allow_any_len=True),
|
||||
lambda ctx,buf,idx_y,idx_x,var:
|
||||
f"write_imagef({ctx[buf]}, (int2)({ctx[idx_x]},{ctx[idx_y]}), {ctx[var]});" if var.max_numel() == 4 else None),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
|
|
@ -532,9 +535,9 @@ class HIPRenderer(CStyleLanguage):
|
|||
float4 = "make_float4"
|
||||
type_map = {dtypes.bfloat16: "hip_bfloat16", dtypes.fp8e4m3: "hip_fp8", dtypes.fp8e5m2: "hip_bf8"}
|
||||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16, *dtypes.fp8s)) + PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
|
||||
lambda x: UOp(Ops.WMMA, x.dtype, (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(Ops.WMMA, x.dtype.scalar(), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
|
||||
x.src[2]), (*x.arg,)) if x.max_numel() == 4 and x.src[0].dtype.scalar() in dtypes.fp8_ocp and x.src[0].max_numel() == 8 else None),
|
||||
# bfloat16 constant casting
|
||||
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -155,8 +155,8 @@ extra_matcher = PatternMatcher([
|
|||
(UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")),
|
||||
lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.max_numel() == 1 else None),
|
||||
# float16 alus are done in float32
|
||||
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.max_numel()),
|
||||
tuple(s.cast(dtypes.float) if s.dtype != dtypes.bool else s for s in x.src)).cast(x.dtype)),
|
||||
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(Ops.CAST, x.dtype.scalar(), (UOp(x.op, dtypes.float,
|
||||
tuple(UOp(Ops.CAST, dtypes.float, (s,)) if s.dtype != dtypes.bool else s for s in x.src)),))),
|
||||
(UPat(GroupOp.Comparison, src=(UPat.var("a", dtypes.float16), UPat.var("b")), name="x"),
|
||||
lambda x,a,b: UOp(x.op, x.dtype, (a.cast(dtypes.float32), b.cast(dtypes.float32))).cast(x.dtype)),
|
||||
# no cmpne for packed ints, y != x => !(y==x)
|
||||
|
|
@ -296,7 +296,7 @@ def vpbroadcast(ctx:IselContext, x:UOp, y:UOp) -> UOp:
|
|||
n = x.ins({1: X86Ops.VPBROADCASTB, 2: X86Ops.VPBROADCASTW, 4: X86Ops.VPBROADCASTD, 8: X86Ops.VPBROADCASTQ}[y.dtype.itemsize], src=(y,))
|
||||
if y.op is Ops.LOAD and len(y.src) == 1 and is_foldable(ctx, n, y): return n
|
||||
# if there isn't a load we can fold we need to move y from gpr to xmm
|
||||
# this is hacky but required because int.vec(1) isn't supported
|
||||
# this is hacky but required because scalar int bitcasts need a float register type
|
||||
y = y if y.dtype.itemsize > 1 else y.cast(dtypes.int16)
|
||||
return n.replace(src=(y.bitcast({2:dtypes.float16, 4:dtypes.float32, 8:dtypes.float64}[y.dtype.itemsize]),))
|
||||
|
||||
|
|
@ -357,12 +357,6 @@ def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
|
|||
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
|
||||
return x.replace(tag=tuple(defs))
|
||||
|
||||
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
dt_16bit = tuple(dt.vec(l) for dt in dts for l in [2,1] if l*dt.itemsize == 2 and dt not in dtypes.int16s)
|
||||
dt_32bit = tuple(dt.vec(l) for dt in dts for l in [4,2,1] if l*dt.itemsize == 4 and dt not in dtypes.int32s)
|
||||
dt_64bit = tuple(dt.vec(l) for dt in dts for l in [8,4,2,1] if l*dt.itemsize == 8 and dt not in dtypes.int64s)
|
||||
dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize == 16)
|
||||
|
||||
isel_matcher = PatternMatcher([
|
||||
# **** Op -> Op ****
|
||||
# cast to pointer is a noop
|
||||
|
|
@ -377,7 +371,7 @@ isel_matcher = PatternMatcher([
|
|||
# add callee saved registers to the RET, these will be scheduled at the top of the kernel and will be saved/restored if they are used in regalloc
|
||||
# so regalloc builds the prologue/epilogue naturally
|
||||
(UPat(Ops.SINK, name="x"), lambda x:
|
||||
x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in CALLEE_SAVED)),)) \
|
||||
x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64, r) for r in CALLEE_SAVED)),)) \
|
||||
if not x.src or x.src[0].arg is not X86Ops.RET else None),
|
||||
# function abi constraints
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi),
|
||||
|
|
@ -446,8 +440,8 @@ isel_matcher = PatternMatcher([
|
|||
# for float16 we route the srcs through gprs unless we can fold them, this is suboptimal for values in xmms, in that case we want vpunpcklwd
|
||||
(UPat(Ops.STACK, dtypes.float16, name="x"), lambda ctx,x:
|
||||
vpins(x.replace(src=tuple(s if s.op is Ops.LOAD and is_foldable(ctx, x, s) else s.bitcast(dtypes.int16) for s in x.src)))),
|
||||
(UPat(Ops.STACK, (dtypes.float32.vec(4), dtypes.float32.vec(8)), name="x"), vshufps),
|
||||
(UPat(Ops.STACK, (dtypes.float64.vec(2), dtypes.float64.vec(4)), name="x"), vshufpd),
|
||||
(UPat(Ops.STACK, dtypes.float32, name="x"), lambda x: vshufps(x) if x.max_numel() in (4, 8) else None),
|
||||
(UPat(Ops.STACK, dtypes.float64, name="x"), lambda x: vshufpd(x) if x.max_numel() in (2, 4) else None),
|
||||
(UPat(Ops.STACK, dtypes.float32, name="x"), vinsertps),
|
||||
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
|
||||
(UPat(Ops.STACK, dtypes.ints+(dtypes.bool,), name="x"), vpins),
|
||||
|
|
@ -555,20 +549,32 @@ isel_matcher = PatternMatcher([
|
|||
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
|
||||
# copy, load, store
|
||||
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
|
||||
(UPat(Ops.COPY, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS)),
|
||||
(UPat(Ops.COPY, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD)),
|
||||
(UPat(Ops.COPY, dt_32bit+dt_16bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS)),
|
||||
(UPat(Ops.COPY, name="x"), lambda x: x.ins(X86Ops.VMOVUPS) if x.dtype.itemsize == 16 else None),
|
||||
(UPat(Ops.COPY, name="x"), lambda x: x.ins(X86Ops.VMOVSD) if x.dtype.itemsize == 8 and x.dtype.scalar() not in dtypes.int64s else None),
|
||||
(UPat(Ops.COPY, name="x"), lambda x:
|
||||
x.ins(X86Ops.VMOVSS) if x.dtype.itemsize in (2, 4) and
|
||||
(x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1)) else None),
|
||||
(UPat(Ops.COPY, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV)),
|
||||
(UPat(Ops.LOAD, dt_128bit, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVUPS, src=fold_address(a))),
|
||||
(UPat(Ops.LOAD, dt_64bit, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVSD, src=fold_address(a))),
|
||||
(UPat(Ops.LOAD, dt_32bit, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVSS, src=fold_address(a))),
|
||||
(UPat(Ops.LOAD, dt_16bit, src=(UPat(name="a"),), name="x"), lambda x,a:
|
||||
x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.tag),) + fold_address(a) + (imm(dtypes.uint8, 0),))),
|
||||
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVUPS, src=fold_address(a)) if x.dtype.itemsize == 16 else None),
|
||||
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a:
|
||||
x.ins(X86Ops.VMOVSD, src=fold_address(a)) if x.dtype.itemsize == 8 and x.dtype.scalar() not in dtypes.int64s else None),
|
||||
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a:
|
||||
x.ins(X86Ops.VMOVSS, src=fold_address(a))
|
||||
if x.dtype.itemsize == 4 and (x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1)) else None),
|
||||
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a:
|
||||
x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.tag),) + fold_address(a) + (imm(dtypes.uint8, 0),))
|
||||
if x.dtype.itemsize == 2 and (x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1)) else None),
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.MOV, src=fold_address(a))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_128bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVUPSm, src=fold_address(a) + (b,))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_64bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSDm, src=fold_address(a) + (b,))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_32bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSSm, src=fold_address(a) + (b,))),
|
||||
(UPat.var("a").store(UPat.var("b", dt_16bit), name="x"), lambda a,b,x: x.ins(X86Ops.VPEXTRW, src=fold_address(a) + (b, imm(dtypes.uint8, 0)))),
|
||||
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
|
||||
x.ins(X86Ops.VMOVUPSm, src=fold_address(a) + (b,)) if b.dtype.itemsize == 16 else None),
|
||||
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
|
||||
x.ins(X86Ops.VMOVSDm, src=fold_address(a) + (b,)) if b.dtype.itemsize == 8 and b.dtype.scalar() not in dtypes.int64s else None),
|
||||
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
|
||||
x.ins(X86Ops.VMOVSSm, src=fold_address(a) + (b,))
|
||||
if b.dtype.itemsize == 4 and (b.dtype.scalar() in dtypes.floats or (b._shape is not None and b.max_numel() > 1)) else None),
|
||||
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
|
||||
x.ins(X86Ops.VPEXTRW, src=fold_address(a) + (b, imm(dtypes.uint8, 0)))
|
||||
if b.dtype.itemsize == 2 and (b.dtype.scalar() in dtypes.floats or (b._shape is not None and b.max_numel() > 1)) else None),
|
||||
(UPat.var("a").store(UPat.var("b", dtypes.ints+(dtypes.bool,)), name="x"), lambda a,b,x:
|
||||
x.ins(X86Ops.MOVm, src=fold_address(a) + (b,)) if (i:=to_imm(b)) is None else x.ins(X86Ops.MOVi, src=fold_address(a) + (i,))),
|
||||
# **** X86Op -> X86Op ****
|
||||
|
|
|
|||
|
|
@ -253,10 +253,11 @@ class AMDLLVMRenderer(LLVMRenderer):
|
|||
f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
|
||||
]) + base_rewrite
|
||||
extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
|
||||
(UPat(Ops.CAST, dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
|
||||
lambda y: UOp(Ops.STACK, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
|
||||
lambda y: UOp(Ops.STACK, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.half, src=UPat.var("y", dtypes.half), name="x"),
|
||||
lambda x,y: UOp(Ops.STACK, dtypes.half, tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))
|
||||
if x.max_numel() == 16 and y.max_numel() == 8 else None),
|
||||
(UPat(Ops.CAST, dtype=dtypes.half, src=UPat.var("y", dtypes.half), name="x"),
|
||||
lambda x,y: UOp(Ops.STACK, dtypes.half, tuple(y.gep(i * 2) for i in range(8))) if x.max_numel() == 8 and y.max_numel() == 16 else None),
|
||||
# amd llvm intrinsics llvm.log2/llvm.exp2 don't support double
|
||||
(UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
|
||||
(UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),
|
||||
|
|
@ -291,28 +292,37 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc
|
|||
self.string_rewrite += PatternMatcher([(UPat(Ops.WMMA, name="wmma"), lambda ctx, wmma, cdna=self.is_cdna: render_wmma_amd(ctx, wmma, cdna))])
|
||||
if self.is_cdna:
|
||||
self.extra_matcher += PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint16.vec(4)), x.src[1].bitcast(dtypes.uint16.vec(4)),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float, (x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
|
||||
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16), x.src[2]), (*x.arg,))
|
||||
if x.max_numel() == 4 and x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 4 else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float, (x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint64),
|
||||
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint64), x.src[2]), (*x.arg,))
|
||||
if x.max_numel() == 4 and x.src[0].dtype.scalar() in dtypes.fp8_ocp and x.src[0].max_numel() == 8 else None),
|
||||
])
|
||||
if target.arch in {"gfx1100", "gfx1151"}:
|
||||
self.extra_matcher += PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))),
|
||||
(UPat(Ops.WMMA, name="x"), lambda x: UOp(Ops.WMMA, x.dtype, (x.src[0].bitcast(dtypes.uint16.vec(16)), x.src[1].bitcast(dtypes.uint16.vec(16)),
|
||||
x.src[2]), x.arg) if x.src[0].dtype == dtypes.bfloat16.vec(16) else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.half),
|
||||
lambda x: UOp(Ops.STACK, dtypes.half, tuple(UOp(Ops.WMMA, dtypes.half, (x.src[0], x.src[1],
|
||||
UOp(Ops.STACK, dtypes.half, tuple(x.src[2].gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
|
||||
(*x.arg,)).gep(i * 2) for i in range(8))) if x.max_numel() == 8 else None),
|
||||
(UPat(Ops.WMMA, name="x"), lambda x: UOp(Ops.WMMA, x.dtype.scalar(),
|
||||
(x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
|
||||
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16), x.src[2]), x.arg)
|
||||
if x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 16 else None),
|
||||
])
|
||||
if target.arch in {"gfx1200", "gfx1201"}:
|
||||
self.extra_matcher += PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.uint16.vec(8),
|
||||
(x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)), x.src[2].bitcast(dtypes.uint16.vec(8))), (*x.arg,))
|
||||
.bitcast(dtypes.bfloat16.vec(8)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(8)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16), lambda x: UOp(Ops.WMMA, dtypes.uint16,
|
||||
(x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
|
||||
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16),
|
||||
x.src[2].replace(dtype=x.src[2].dtype.scalar()).bitcast(dtypes.uint16)), (*x.arg,))
|
||||
.bitcast(dtypes.bfloat16) if x.max_numel() == 8 and x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 8 else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float, (x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
|
||||
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16), x.src[2]), (*x.arg,))
|
||||
if x.max_numel() == 8 and x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 8 else None)
|
||||
])
|
||||
|
||||
def supported_dtypes(self): return {d for d in super().supported_dtypes()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue