mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
c12e11dc32
commit
0b00fcfb40
3 changed files with 7 additions and 9 deletions
|
|
@ -693,7 +693,7 @@ class Kernel:
|
|||
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
|
||||
elif self.opts.device == "CLANG":
|
||||
reduce_axes, fix_st1, fix_st2 = [], None, None
|
||||
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
|
||||
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted)
|
||||
elif self.opts.device in {"CUDA", "NV"}:
|
||||
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
|
||||
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var):
|
|||
# ***** float4 handling *****
|
||||
|
||||
def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
|
||||
if len(ex.src) not in [4, 8]: return None
|
||||
if len(ex.src) != 4: return None
|
||||
if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
|
||||
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
|
||||
if idx2 is not None: idx = idx + idx2
|
||||
|
|
@ -40,7 +40,7 @@ def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
|
|||
return UOp(UOps.EXPAND, load.dtype, tuple(UOp(UOps.GEP, load.dtype, (vec_load,), i) for i in range(len(ex.src))), ex.arg)
|
||||
|
||||
def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtypes.int, 0), idx2=None, idx3=None):
|
||||
if len(ex.src) not in [2, 4, 8]: return None
|
||||
if len(ex.src) not in [2, 4]: return None
|
||||
if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
|
||||
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
|
||||
if idx2 is not None: idx = idx + idx2
|
||||
|
|
@ -140,8 +140,8 @@ constant_folder = PatternMatcher([
|
|||
lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
|
||||
# VECTORIZE/GEP
|
||||
(UOp(UOps.GEP, src=(UOp(UOps.VECTORIZE).name("cast"),)).name("gep"), lambda gep, cast: cast.src[gep.arg]),
|
||||
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j) for j in range(i))), lambda x: x) \
|
||||
for i in [2, 4, 8]],
|
||||
*[(UOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=j)
|
||||
for j in range(i))), lambda x: x) for i in [2, 4, 8]],
|
||||
# tensor core with a 0 input is acc
|
||||
(UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
|
||||
(UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
|
||||
|
|
@ -269,8 +269,6 @@ constant_folder = PatternMatcher([
|
|||
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
# VECTORIZE-PHI-GEP -> PHI-VECTORIZE
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
|
||||
lambda root, val, v0, v1, v2, v3, v4, v5, v6, v7: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3, v4, v5, v6, v7))))),
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(4))).name("root"),
|
||||
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.VECTORIZE, val.dtype, (v0, v1, v2, v3))))),
|
||||
(UOp(UOps.VECTORIZE, src=tuple(UOp(UOps.PHI, src=(UOp(UOps.GEP, src=(UOp.var("val"),), arg=i), UOp.var(f"v{i}"))) for i in range(2))).name("root"),
|
||||
|
|
@ -327,7 +325,7 @@ def do_expand(root:UOp):
|
|||
new_src: List[UOp] = []
|
||||
for src in root.src:
|
||||
if src.op is UOps.EXPAND:
|
||||
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})%8] for lrpk in lrpks]
|
||||
lnew_src = [src.src[_expand_arg_to_idx(src.arg, {**rpk, **lrpk})] for lrpk in lrpks]
|
||||
if len(dont_expand_args):
|
||||
# TODO: is this right for UOps.WMMA? all lnew_src should be the same
|
||||
new_src.append(lnew_src[0] if root.op is UOps.WMMA else UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ class ClangRenderer(CStyleLanguage):
|
|||
float4 = "make_float4"
|
||||
has_local = False
|
||||
global_max = None
|
||||
tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,8),(1,8)], thread_local_sizes=[[8],[8],[8,8]], dtype_in=dtypes.float, dtype_out=dtypes.float)]
|
||||
tensor_cores = [TensorCore(dims=(4,4,4), threads=[(0,4),(1,4)], thread_local_sizes=[[4],[4],[4,4]], dtype_in=dtypes.float, dtype_out=dtypes.float)]
|
||||
|
||||
# language options
|
||||
buffer_suffix = " restrict"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue