Revert "modify for size 8"

This reverts commit 3ef0904bd9.
This commit is contained in:
p4sscode 2024-07-21 15:49:24 -03:00
commit 0b00fcfb40
3 changed files with 7 additions and 9 deletions

View file

@ -693,7 +693,7 @@ class Kernel:
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
elif self.opts.device == "CLANG":
reduce_axes, fix_st1, fix_st2 = [], None, None
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted)
elif self.opts.device in {"CUDA", "NV"}:
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)

View file

@ -29,7 +29,7 @@ def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var):
# ***** float4 handling *****
def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
if len(ex.src) not in [4, 8]: return None
if len(ex.src) != 4: return None
if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
if idx2 is not None: idx = idx + idx2
@ -40,7 +40,7 @@ def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
return UOp(UOps.EXPAND, load.dtype, tuple(UOp(UOps.GEP, load.dtype, (vec_load,), i) for i in range(len(ex.src))), ex.arg)
def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtypes.int, 0), idx2=None, idx3=None):
if len(ex.src) not in [2, 4, 8]: return None
if len(ex.src) not in [2, 4]: return None
if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
if idx2 is not None: idx = idx + idx2
@ -140,8 +140,8 @@ constant_folder = PatternMatcher([
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
# VECTORIZE/GEP
(UOp(UOps.GEP, src=(UOp(UOps.VECTORIZE).name("cast"),)).name("gep"), lambda gep, cast: cast.src[gep.arg]),
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j) for j in range(i))), lambda x: x) \
for i in [2, 4, 8]],
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j)
for j in range(i))), lambda x: x) for i in [2, 4, 8]],
# tensor core with a 0 input is acc
(UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
(UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
@ -269,8 +269,6 @@ constant_folder = PatternMatcher([
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
lambda root, val, v0, v1, v2, v3, v4, v5, v6, v7: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3, v4, v5, v6, v7))))),
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
@ -327,7 +325,7 @@ def do_expand(root:UOp):
new_src: List[UOp] = []
for src in root.src:
if src.op is UOps.EXPAND:
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})%8] for lrpk in lrpks]
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
if len(dont_expand_args):
# TODO: is this right for UOps.WMMA? all lnew_src should be the same
new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))

View file

@ -187,7 +187,7 @@ class ClangRenderer(CStyleLanguage):
float4 = "make_float4"
has_local = False
global_max = None
tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,8),(1,8)], thread_local_sizes=[[8],[8],[8,8]], dtype_in=dtypes.float, dtype_out=dtypes.float)]
tensor_cores = [TensorCore(dims=(4,4,4), threads=[(0,4),(1,4)], thread_local_sizes=[[4],[4],[4,4]], dtype_in=dtypes.float, dtype_out=dtypes.float)]
# language options
buffer_suffix = " restrict"