mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cleanup tensor cores, expose exclude local upcast (#2064)
* expose exclude_local_upcast * convert apply tensor cores to ops * update comment * put LOCAL back to what it was, BEAM is better than way
This commit is contained in:
parent
6e4a12ab68
commit
4124cf1df5
3 changed files with 68 additions and 66 deletions
|
|
@ -8,16 +8,28 @@ if __name__ == "__main__":
|
|||
tactions = set()
|
||||
for ast_str in tqdm(ast_strs):
|
||||
lin = ast_str_to_lin(ast_str)
|
||||
if True or not lin.apply_tensor_cores(): lin.hand_coded_optimizations()
|
||||
if lin.apply_tensor_cores_old():
|
||||
linr = ast_str_to_lin(ast_str)
|
||||
linr.apply_tensor_cores()
|
||||
else:
|
||||
lin.hand_coded_optimizations()
|
||||
continue
|
||||
#lin.hand_coded_optimizations_old()
|
||||
#linr = ast_str_to_lin(ast_str)
|
||||
|
||||
"""
|
||||
if True or not lin.apply_tensor_cores():
|
||||
lin.hand_coded_optimizations()
|
||||
linr = Linearizer(lin.ast)
|
||||
for o in lin.applied_opts:
|
||||
assert o in actions
|
||||
tactions.add(o)
|
||||
linr.apply_opt(o)
|
||||
"""
|
||||
|
||||
assert len(lin.sts) == len(linr.sts)
|
||||
for st1,st2 in zip(lin.sts, linr.sts):
|
||||
assert st1 == st2
|
||||
assert st1 == st2, f"{st1} != {st2}"
|
||||
|
||||
#lin.linearize()
|
||||
#linr.linearize()
|
||||
|
|
|
|||
|
|
@ -131,8 +131,9 @@ class Kernel:
|
|||
@property
|
||||
def global_dims(self) -> int: return self.first_reduce-self.local_dims
|
||||
|
||||
# there's seven chunks of the shape
|
||||
# there's eight chunks of the shape
|
||||
# blue -- global dims
|
||||
# CYAN -- excluded local dims (non-warp)
|
||||
# cyan -- local dims
|
||||
# *** self.first_reduce
|
||||
# green -- reduce-local dims
|
||||
|
|
@ -142,10 +143,12 @@ class Kernel:
|
|||
# purple -- reduce upcasted
|
||||
# yellow -- normal upcasted dimensions
|
||||
def colors(self) -> List[str]:
|
||||
# up to first_reduce, they are all global (blue)
|
||||
# first non local non reduce dims are global (blue)
|
||||
colors = ["blue"] * self.global_dims
|
||||
# some special local_dims are excluded from the local upcast
|
||||
colors += ["CYAN"] * self.exclude_local_upcast
|
||||
# except the local_dims, these are non-reduce locals (cyan)
|
||||
colors += ["cyan"] * self.local_dims
|
||||
colors += ["cyan"] * (self.local_dims - self.exclude_local_upcast)
|
||||
# between first_reduce and first_reduce + group_for_reduce, they are either local (cyan), or late upcasted (green)
|
||||
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
|
||||
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
|
|||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
|
||||
UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
|
||||
UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
|
||||
def __lt__(self, x:OptOps): return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
|
|
@ -65,14 +65,15 @@ class OptimizedKernel(Kernel):
|
|||
|
||||
# ******************** complex simplifiers ********************
|
||||
|
||||
def simplify_ones(self):
|
||||
def simplify_ones(self) -> bool:
|
||||
# remove places where the shape is all ones
|
||||
# TODO: this should be factored in to multi shape stride
|
||||
if self.shape_len == 0: return
|
||||
if self.shape_len == 0: return False
|
||||
all_ones = [s==1 for s in self.full_shape]
|
||||
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
|
||||
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
|
||||
return any(all_ones)
|
||||
|
||||
def simplify_merge_adjacent(self):
|
||||
if self.shape_len == 0: return
|
||||
|
|
@ -232,82 +233,68 @@ class OptimizedKernel(Kernel):
|
|||
buf1_strides = self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
|
||||
#optim_conv2d = (self.shape_len-self.first_reduce) == 3 and self.full_shape[self.first_reduce+1]%2 == 1 and self.full_shape[self.first_reduce+2]%2 == 1 and max(self.full_shape[self.first_reduce+1:self.first_reduce+3]) < 21
|
||||
# enabling this gives wrong answers!! https://github.com/tinygrad/tinygrad/issues/1967
|
||||
# TODO: WMMA must be a lot better before reenabling things like this
|
||||
optim_conv2d = False
|
||||
if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and (self.shape_len-self.first_reduce == 1 or optim_conv2d):
|
||||
if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and self.shape_len-self.first_reduce == 1:
|
||||
if DEBUG >= 3: print("METAL TENSOR CORES", axis_buf0, axis_buf1)
|
||||
self.use_tensor_cores = use_tensor_cores == 1 # TC=2 will do the shape ops without the WMMA
|
||||
|
||||
# TODO: select axis in smart way
|
||||
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0]
|
||||
global_count = self.first_reduce
|
||||
s0_exists, s1_exists = True, True
|
||||
assert s0 != s1 and self.full_shape[s0]%8 == 0 and self.full_shape[s1]%8 == 0
|
||||
def fix(needed, ax):
|
||||
nonlocal s0, s1, s0_exists, s1_exists
|
||||
if not needed: return
|
||||
if s0_exists and ax == s0:
|
||||
if s1_exists and s0 < s1: s1 -= 1
|
||||
s0_exists = False
|
||||
elif s1_exists and ax == s1:
|
||||
if s0_exists and s1 < s0: s0 -= 1
|
||||
s1_exists = False
|
||||
|
||||
# upcast first
|
||||
if self.full_shape[self.first_reduce] > 8: self.shift_to(self.first_reduce, 8)
|
||||
self.upcast()
|
||||
# tensor core (6 ops) -- creates the (2,2,4,2) pattern, an upcasted 2, and a unrolled 8
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
self.apply_opt(Opt(OptOps.UPCAST, s0, 2))
|
||||
fix(self.apply_opt(Opt(OptOps.LOCAL, s1, 8)), s1)
|
||||
fix(self.apply_opt(Opt(OptOps.LOCAL, s0, 4)), s0)
|
||||
self.apply_opt(Opt(OptOps.LOCAL, self.global_dims, 4))
|
||||
self.apply_opt(Opt(OptOps.LOCAL, self.global_dims+1, 2))
|
||||
|
||||
# 2 locals
|
||||
self.shift_to(s1, 8, insert_before=self.first_reduce) # axis 2
|
||||
self.shift_to(s0, 8, insert_before=self.first_reduce) # axis 3
|
||||
# final optional global upcast
|
||||
if s1_exists:
|
||||
s1_div = [upc for upc in [4,3,2,1] if self.full_shape[s1]%upc == 0][0]
|
||||
fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
|
||||
if s0_exists:
|
||||
s0_div = [upc for upc in [4,3,2,1] if self.full_shape[s0]%upc == 0][0]
|
||||
fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
|
||||
|
||||
# permuted+upcast for tensor cores
|
||||
self.shift_to(global_count, 4, insert_before=self.first_reduce)
|
||||
self.shift_to(global_count+1, 4, insert_before=self.first_reduce)
|
||||
self.shift_to(self.first_reduce-1, 2)
|
||||
self.upcast()
|
||||
|
||||
# final global upcast
|
||||
if not optim_conv2d:
|
||||
for ax in [s1, s0]:
|
||||
for upc in [4,3,2]:
|
||||
if self.full_shape[ax]%upc == 0:
|
||||
self.shift_to(ax, upc)
|
||||
self.upcast()
|
||||
break
|
||||
# very late (optional) upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
|
||||
self.use_tensor_cores = use_tensor_cores == 1 # TC=2 will do the shape ops without the WMMA
|
||||
if self.use_tensor_cores and s0_exists and self.full_shape[s0] % 2 == 0:
|
||||
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, 2))
|
||||
self.exclude_local_upcast += 1
|
||||
|
||||
# alias buffer
|
||||
self.local_dims = self.first_reduce - global_count
|
||||
alias_pattern = [0]*global_count + [2] * self.local_dims + [0] * (self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3] * (self.upcasted-2)
|
||||
alias_pattern = [0]*(self.global_dims+self.exclude_local_upcast) + [2]*(self.local_dims-self.exclude_local_upcast) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
|
||||
self.alias_buffer(buf0, alias_pattern)
|
||||
self.alias_buffer(buf1, alias_pattern)
|
||||
|
||||
if self.use_tensor_cores and optim_conv2d:
|
||||
self.upcast()
|
||||
|
||||
if max(self.full_shape[self.first_reduce+1:self.first_reduce+3]) < 5:
|
||||
self.upcast()
|
||||
for upc in range(8, 1, -1):
|
||||
if self.full_shape[global_count-2]%upc == 0:
|
||||
self.shift_to(global_count-2, upc)
|
||||
self.upcast()
|
||||
break
|
||||
else:
|
||||
for upc in range(16, 1, -1):
|
||||
if self.full_shape[global_count-1]%upc == 0:
|
||||
self.shift_to(global_count-1, upc)
|
||||
self.upcast()
|
||||
break
|
||||
|
||||
# very late upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
|
||||
if self.use_tensor_cores and self.full_shape[s0] % 2 == 0:
|
||||
self.shift_to(s0, 2, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
self.exclude_local_upcast += 1
|
||||
|
||||
# early exit
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_opt(self, opt:Opt):
|
||||
assert opt.amt is not None and opt.amt != 0, "amount can't be 0"
|
||||
if opt.amt == 1: return False
|
||||
self.applied_opts.append(opt)
|
||||
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else 0)
|
||||
assert self.full_shape[axis] % opt.amt == 0, "no longer valid shift"
|
||||
if opt.op == OptOps.LOCAL: # cyan
|
||||
assert axis < (self.first_reduce-self.local_dims), "can't local a local or reduce"
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
self.shift_to(axis, opt.amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
elif opt.op == OptOps.LASTLOCAL: # cyan
|
||||
assert axis < self.first_reduce, "can't local a reduce"
|
||||
self.shift_to(axis, opt.amt, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
# TOOD: include exclude_local_upcast here
|
||||
elif opt.op == OptOps.GROUP: # green
|
||||
self.shift_to(axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce))
|
||||
self.group_for_reduce.append(opt.amt)
|
||||
|
|
@ -322,7 +309,7 @@ class OptimizedKernel(Kernel):
|
|||
assert axis < self.first_reduce, "upcast is for non-reduce"
|
||||
self.shift_to(axis, opt.amt, insert_before=None)
|
||||
self.upcast()
|
||||
self.simplify_ones()
|
||||
return self.simplify_ones()
|
||||
|
||||
def required_optimizations(self, early_only=False):
|
||||
for buf_index,buf in enumerate(self.bufs):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue