mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
no dtypes.count in renderer, use shape
This commit is contained in:
parent
6815f28849
commit
ded5cdf2ea
7 changed files with 80 additions and 76 deletions
|
|
@ -55,7 +55,7 @@ class Estimates:
|
|||
lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.max_numel()
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ base_rewrite = PatternMatcher([
|
|||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.max_numel() > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
|
||||
|
|
@ -53,7 +53,7 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
(f"[{x.arg[0]}]" if x.src[0].max_numel() > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -153,15 +153,15 @@ extra_matcher = PatternMatcher([
|
|||
# no int8 mul or cmove, cast to int16
|
||||
(UPat.var("a", dtypes.int8s) * UPat.var("b"), lambda a,b: (a.cast(dtypes.int16) * b.cast(dtypes.int16)).cast(a.dtype)),
|
||||
(UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")),
|
||||
lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.dtype.count == 1 else None),
|
||||
lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.max_numel() == 1 else None),
|
||||
# float16 alus are done in float32
|
||||
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.dtype.count),
|
||||
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.max_numel()),
|
||||
tuple(s.cast(dtypes.float) if s.dtype != dtypes.bool else s for s in x.src)).cast(x.dtype)),
|
||||
(UPat(GroupOp.Comparison, src=(UPat.var("a", dtypes.float16), UPat.var("b")), name="x"),
|
||||
lambda x,a,b: UOp(x.op, x.dtype, (a.cast(dtypes.float32), b.cast(dtypes.float32))).cast(x.dtype)),
|
||||
# no cmpne for packed ints, y != x => !(y==x)
|
||||
(UPat(Ops.CMPNE, src=(UPat.var("y", dtypes.ints), UPat.var("x")), name="cmp"),
|
||||
lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.dtype.count > 1 else None),
|
||||
lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.max_numel() > 1 else None),
|
||||
# float where expects a mask
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a", dtypes.floats), UPat.var("b")),
|
||||
lambda m,a,b: m.cast(a.dtype).ne(0).where(a, b) if m.src[0].dtype not in dtypes.floats else None),
|
||||
|
|
@ -174,26 +174,26 @@ extra_matcher = PatternMatcher([
|
|||
# ***** X86 pre instruction selection *****
|
||||
|
||||
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.max_numel(), AddrSpace.LOCAL), arg=next(ctx))
|
||||
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
|
||||
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.max_numel() == 1 else local).store(alt))
|
||||
return ptr.cast(cast.dtype).load(dtype=x.dtype)
|
||||
|
||||
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
|
||||
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.max_numel(), AddrSpace.LOCAL), arg=-1)
|
||||
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
|
||||
return ptr.cast(cast.dtype).store(val)
|
||||
|
||||
# these must be done in a separate matcher because they violate the spec
|
||||
pre_isel_matcher = PatternMatcher([
|
||||
# zero extending scalar 32bit int is a noop
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.max_numel() == 1 else None),
|
||||
# cast between signed and unsigned int is a noop
|
||||
(UPat.var("y", dtypes.ints+(dtypes.bool,)).cast(dtypes.ints, name="x"),
|
||||
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize == y.dtype.itemsize else None),
|
||||
# cast to < scalar int is a noop
|
||||
(UPat.var("y", dtypes.ints).cast(dtypes.ints, name="x"),
|
||||
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.dtype.count == 1 else None),
|
||||
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.max_numel() == 1 else None),
|
||||
# bitcasts between scalar floats and ints are real, rest are noops
|
||||
(UPat.var("y").bitcast().named("x"), lambda y,x: None if y.dtype in dtypes.floats and x.dtype in dtypes.ints or \
|
||||
y.dtype in dtypes.ints and x.dtype in dtypes.floats else x.replace(op=Ops.NOOP)),
|
||||
|
|
@ -208,7 +208,7 @@ pre_isel_matcher = PatternMatcher([
|
|||
# TODO: remove this once we allow all flag producing ops in cmove
|
||||
# if gate in scalar int cmove is not a comparison need to add one to set the flag
|
||||
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
|
||||
lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.dtype.count == 1 else None),
|
||||
lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.max_numel() == 1 else None),
|
||||
])
|
||||
|
||||
# ***** X86 registers *****
|
||||
|
|
@ -252,8 +252,8 @@ def cmp(x:UOp) -> UOp:
|
|||
return x.ins(X86Ops.CMP, dtype=dtypes.void) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.CMPi, dtype=dtypes.void, src=(x.src[0], i))
|
||||
def vcmp(x:UOp) -> UOp:
|
||||
v = imm(dtypes.uint8, {Ops.CMPLT: 1, Ops.CMPNE: 4, Ops.CMPEQ: 0}[x.op])
|
||||
if x.dtype.scalar() is dtypes.float32: return x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (v,))
|
||||
return x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (v,))
|
||||
if x.dtype.scalar() is dtypes.float32: return x.ins(X86Ops.VCMPSS if x.max_numel() == 1 else X86Ops.VCMPPS, src=x.src + (v,))
|
||||
return x.ins(X86Ops.VCMPSD if x.max_numel() == 1 else X86Ops.VCMPPD, src=x.src + (v,))
|
||||
|
||||
# vshufps xmm2, xmm0, xmm1, imm
|
||||
# for 128 bit xmm2 selects its lower 2 32 bits from xmm0 and its upper 2 32 bits from xmm1 according to imm
|
||||
|
|
@ -351,7 +351,7 @@ def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
|
|||
defs = []
|
||||
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
|
||||
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
|
||||
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
|
||||
elif x.dtype.scalar() in dtypes.floats or x.max_numel() > 1: defs = [ctx.vreg(XMM)]
|
||||
# TODO: add this once the scheduler can track register pressure
|
||||
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
|
||||
return x.replace(tag=tuple(defs))
|
||||
|
|
@ -390,16 +390,16 @@ isel_matcher = PatternMatcher([
|
|||
UOp.const(dt:=to_int(x.dtype), struct.unpack(dt.fmt, struct.pack(x.dtype.fmt, x.arg))[0]).bitcast(x.dtype) if not x.tag else None),
|
||||
# TODO: these should use a.maximum(b) / a.minimum(b)
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float32), UPat.var("a")), lambda a,b:
|
||||
a.ins(X86Ops.VMAXSS if a.dtype.count == 1 else X86Ops.VMAXPS, src=(a, b))),
|
||||
a.ins(X86Ops.VMAXSS if a.max_numel() == 1 else X86Ops.VMAXPS, src=(a, b))),
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float64), UPat.var("a")), lambda a,b:
|
||||
a.ins(X86Ops.VMAXSD if a.dtype.count == 1 else X86Ops.VMAXPD, src=(a, b))),
|
||||
a.ins(X86Ops.VMAXSD if a.max_numel() == 1 else X86Ops.VMAXPD, src=(a, b))),
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda a,b:
|
||||
a.ins(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, src=(a, b))),
|
||||
a.ins(X86Ops.VMINSS if a.max_numel() == 1 else X86Ops.VMINPS, src=(a, b))),
|
||||
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda a,b:
|
||||
a.ins(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, src=(a, b))),
|
||||
a.ins(X86Ops.VMINSD if a.max_numel() == 1 else X86Ops.VMINPD, src=(a, b))),
|
||||
# conditional moves that use masks NOTE: these currently assume a mask producing cmp exists
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None),
|
||||
a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.max_numel() > 1 else None),
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.VBLENDVPS, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
|
||||
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b:
|
||||
|
|
@ -434,12 +434,12 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTD, src=(b, a))),
|
||||
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTQ, src=(b, a))),
|
||||
# float unary
|
||||
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPS)),
|
||||
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPD)),
|
||||
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.max_numel() == 1 else x.ins(X86Ops.VSQRTPS)),
|
||||
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.max_numel() == 1 else x.ins(X86Ops.VSQRTPD)),
|
||||
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x:
|
||||
x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
|
||||
x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.max_numel() == 1 else x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
|
||||
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x:
|
||||
x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
|
||||
x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.max_numel() == 1 else x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
|
||||
# shufles
|
||||
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VBROADCASTSS, src=(y,))),
|
||||
# for float16 we route the srcs through gprs unless we can fold them, this is suboptimal for values in xmms, in that case we want vpunpcklwd
|
||||
|
|
@ -458,29 +458,29 @@ isel_matcher = PatternMatcher([
|
|||
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
|
||||
# fused multiply add
|
||||
((UPat(Ops.MUL, dtypes.float32, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
|
||||
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
|
||||
a.ins(X86Ops.VFMADD213SS if a.max_numel() == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
|
||||
((UPat(Ops.MUL, dtypes.float64, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
|
||||
a.ins(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
|
||||
a.ins(X86Ops.VFMADD213SD if a.max_numel() == 1 else X86Ops.VFMADD213PD, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
|
||||
# packed bitwise
|
||||
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.dtype.count > 1 else None),
|
||||
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.dtype.count > 1 else None),
|
||||
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.dtype.count > 1 else None),
|
||||
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.max_numel() > 1 else None),
|
||||
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.max_numel() > 1 else None),
|
||||
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.max_numel() > 1 else None),
|
||||
# packed int binary
|
||||
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.dtype.count > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.dtype.count > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.max_numel() > 1 else None),
|
||||
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.max_numel() > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.max_numel() > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.max_numel() > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.max_numel() > 1 else None),
|
||||
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.max_numel() > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.max_numel() > 1 else None),
|
||||
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.max_numel() > 1 else None),
|
||||
# scalar int binary
|
||||
((UPat(dtype=dtypes.ints).alu(Ops.CDIV, UPat())).named("x"), idiv),
|
||||
# scalar int binary with immediate
|
||||
|
|
@ -504,21 +504,21 @@ isel_matcher = PatternMatcher([
|
|||
(UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b"), lambda a,b: a.ins(X86Ops.XOR, src=(a, b))),
|
||||
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b"))), lambda a,b: a.ins(X86Ops.SUB, src=(a, b))),
|
||||
# float binary
|
||||
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
|
||||
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
|
||||
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
|
||||
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
|
||||
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
|
||||
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
|
||||
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
|
||||
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
|
||||
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.max_numel() == 1 else X86Ops.VADDPS)),
|
||||
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.max_numel() == 1 else X86Ops.VADDPD)),
|
||||
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.max_numel() == 1 else X86Ops.VMULPS)),
|
||||
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.max_numel() == 1 else X86Ops.VMULPD)),
|
||||
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.max_numel() == 1 else X86Ops.VSUBPS)),
|
||||
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.max_numel() == 1 else X86Ops.VSUBPD)),
|
||||
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.max_numel() == 1 else X86Ops.VDIVPS)),
|
||||
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.max_numel() == 1 else X86Ops.VDIVPD)),
|
||||
# casts
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.max_numel() > 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.max_numel() > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.max_numel() > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.max_numel() > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.max_numel() > 1 else None),
|
||||
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.max_numel() > 1 else None),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
|
||||
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPH2PS)),
|
||||
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSS2SI)),
|
||||
|
|
@ -527,9 +527,9 @@ isel_matcher = PatternMatcher([
|
|||
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSD2SS, src=(y, y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))),
|
||||
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))),
|
||||
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX) if x.dtype.count == 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD) if x.dtype.count == 1 else None),
|
||||
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX) if x.dtype.count == 1 else None),
|
||||
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX) if x.max_numel() == 1 else None),
|
||||
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD) if x.max_numel() == 1 else None),
|
||||
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX) if x.max_numel() == 1 else None),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBW)),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBD)),
|
||||
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBQ)),
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ base_rewrite = PatternMatcher([
|
|||
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
|
||||
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
|
||||
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
|
||||
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
|
||||
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.max_numel()} x i32> zeroinitializer"),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
|
||||
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
|
||||
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
|
||||
|
|
|
|||
|
|
@ -86,9 +86,10 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda dtype,num_components:num_components, bs=lambda dtype,num_components:dtype.bitsize//num_components,
|
||||
num_components=lambda dtype,num_components:num_components,
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
lambda b, space, addr, dtype, num_components: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
||||
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
|
||||
|
|
@ -152,10 +153,11 @@ class NIRRenderer(Renderer):
|
|||
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype, x.max_numel()),
|
||||
lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype, x.max_numel())),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
||||
|
|
@ -203,7 +205,7 @@ class NIRRenderer(Renderer):
|
|||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
mesa.nir_push_loop(self.b)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype, u.max_numel())
|
||||
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||
elif u.op == Ops.END:
|
||||
r = u.src[1]
|
||||
|
|
|
|||
|
|
@ -102,18 +102,18 @@ string_rewrite = PatternMatcher([
|
|||
# store / gated load / load
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("var")), allow_any_len=True),
|
||||
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
|
||||
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
|
||||
f"{f'.v{cnt}' if ((cnt:=var.max_numel())>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.max_numel() > 1 else ctx.r[var]};"),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx, x, loc, alt, gate, buf: flatten([
|
||||
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
]) if alt.dtype.count > 1 else [
|
||||
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.max_numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
|
||||
]) if x.max_numel() > 1 else [
|
||||
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
|
||||
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
|
||||
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),)),
|
||||
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
|
||||
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.max_numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
|
||||
if x.max_numel() > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
|
||||
# simple
|
||||
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
|
||||
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [
|
||||
|
|
@ -218,14 +218,14 @@ class PTXRenderer(Renderer):
|
|||
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.max_numel())] if u.max_numel() > 1 else ssa('val', u)
|
||||
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg}", u.dtype))
|
||||
elif u.op is Ops.WMMA:
|
||||
# registers for packing/unpacking input and acc
|
||||
self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
|
||||
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
|
||||
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
|
||||
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
|
||||
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.max_numel())]
|
||||
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
|
||||
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
|
||||
Ops.PARAM: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
|
||||
|
|
|
|||
|
|
@ -368,6 +368,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@property
|
||||
def max_shape(self) -> tuple[int, ...]: return to_max_shape(self.shape)
|
||||
|
||||
def max_numel(self) -> int: return prod(self.max_shape)
|
||||
|
||||
@property
|
||||
def shard_shape(self) -> tuple[sint, ...]:
|
||||
if not isinstance(self.device, tuple) or self.axis is None: return self.shape
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue