cleanup tensor cores, expose exclude local upcast (#2064)

* expose exclude_local_upcast

* convert apply tensor cores to ops

* update comment

* put LOCAL back to what it was, BEAM is better that way
This commit is contained in:
George Hotz 2023-10-14 09:21:03 -07:00 committed by GitHub
commit 4124cf1df5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 66 deletions

View file

@ -8,16 +8,28 @@ if __name__ == "__main__":
tactions = set()
for ast_str in tqdm(ast_strs):
lin = ast_str_to_lin(ast_str)
if True or not lin.apply_tensor_cores(): lin.hand_coded_optimizations()
if lin.apply_tensor_cores_old():
linr = ast_str_to_lin(ast_str)
linr.apply_tensor_cores()
else:
lin.hand_coded_optimizations()
continue
#lin.hand_coded_optimizations_old()
#linr = ast_str_to_lin(ast_str)
"""
if True or not lin.apply_tensor_cores():
lin.hand_coded_optimizations()
linr = Linearizer(lin.ast)
for o in lin.applied_opts:
assert o in actions
tactions.add(o)
linr.apply_opt(o)
"""
assert len(lin.sts) == len(linr.sts)
for st1,st2 in zip(lin.sts, linr.sts):
assert st1 == st2
assert st1 == st2, f"{st1} != {st2}"
#lin.linearize()
#linr.linearize()

View file

@ -131,8 +131,9 @@ class Kernel:
@property
def global_dims(self) -> int: return self.first_reduce-self.local_dims
# there's seven chunks of the shape
# there's eight chunks of the shape
# blue -- global dims
# CYAN -- excluded local dims (non-warp)
# cyan -- local dims
# *** self.first_reduce
# green -- reduce-local dims
@ -142,10 +143,12 @@ class Kernel:
# purple -- reduce upcasted
# yellow -- normal upcasted dimensions
def colors(self) -> List[str]:
# up to first_reduce, they are all global (blue)
# first non local non reduce dims are global (blue)
colors = ["blue"] * self.global_dims
# some special local_dims are excluded from the local upcast
colors += ["CYAN"] * self.exclude_local_upcast
# except the local_dims, these are non-reduce locals (cyan)
colors += ["cyan"] * self.local_dims
colors += ["cyan"] * (self.local_dims - self.exclude_local_upcast)
# between first_reduce and first_reduce + group_for_reduce, they are either upcasted in mid reduce (white), or reduce-local (green)
colors += ["white" if i in self.upcast_in_mid_reduce_axes else "green" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)

View file

@ -10,7 +10,7 @@ from tinygrad.shape.view import View, strides_for_shape
from enum import Enum, auto
class OptOps(Enum):
UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
UPCAST = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto(); GROUP = auto(); GROUPTOP = auto() # noqa: E702
def __lt__(self, x:OptOps): return self.value < x.value
@dataclass(frozen=True, order=True)
@ -65,14 +65,15 @@ class OptimizedKernel(Kernel):
# ******************** complex simplifiers ********************
def simplify_ones(self):
def simplify_ones(self) -> bool:
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
if self.shape_len == 0: return
if self.shape_len == 0: return False
all_ones = [s==1 for s in self.full_shape]
self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
return any(all_ones)
def simplify_merge_adjacent(self):
if self.shape_len == 0: return
@ -232,82 +233,68 @@ class OptimizedKernel(Kernel):
buf1_strides = self.sts[buf1].real_strides()
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides) if s == 0 and self.full_shape[i]%8 == 0 and i < self.first_reduce]
#optim_conv2d = (self.shape_len-self.first_reduce) == 3 and self.full_shape[self.first_reduce+1]%2 == 1 and self.full_shape[self.first_reduce+2]%2 == 1 and max(self.full_shape[self.first_reduce+1:self.first_reduce+3]) < 21
# enabling this gives wrong answers!! https://github.com/tinygrad/tinygrad/issues/1967
# TODO: WMMA must be a lot better before reenabling things like this
optim_conv2d = False
if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and (self.shape_len-self.first_reduce == 1 or optim_conv2d):
if axis_buf0 and axis_buf1 and self.full_shape[self.first_reduce]%8 == 0 and self.shape_len-self.first_reduce == 1:
if DEBUG >= 3: print("METAL TENSOR CORES", axis_buf0, axis_buf1)
self.use_tensor_cores = use_tensor_cores == 1 # TC=2 will do the shape ops without the WMMA
# TODO: select axis in smart way
s0, s1 = axis_buf0[-1][0], axis_buf1[-1][0]
global_count = self.first_reduce
s0_exists, s1_exists = True, True
assert s0 != s1 and self.full_shape[s0]%8 == 0 and self.full_shape[s1]%8 == 0
def fix(needed, ax):
nonlocal s0, s1, s0_exists, s1_exists
if not needed: return
if s0_exists and ax == s0:
if s1_exists and s0 < s1: s1 -= 1
s0_exists = False
elif s1_exists and ax == s1:
if s0_exists and s1 < s0: s0 -= 1
s1_exists = False
# upcast first
if self.full_shape[self.first_reduce] > 8: self.shift_to(self.first_reduce, 8)
self.upcast()
# tensor core (6 ops) -- creates the (2,2,4,2) pattern, an upcasted 2, and an unrolled 8
self.apply_opt(Opt(OptOps.UNROLL, 0, 8))
self.apply_opt(Opt(OptOps.UPCAST, s0, 2))
fix(self.apply_opt(Opt(OptOps.LOCAL, s1, 8)), s1)
fix(self.apply_opt(Opt(OptOps.LOCAL, s0, 4)), s0)
self.apply_opt(Opt(OptOps.LOCAL, self.global_dims, 4))
self.apply_opt(Opt(OptOps.LOCAL, self.global_dims+1, 2))
# 2 locals
self.shift_to(s1, 8, insert_before=self.first_reduce) # axis 2
self.shift_to(s0, 8, insert_before=self.first_reduce) # axis 3
# final optional global upcast
if s1_exists:
s1_div = [upc for upc in [4,3,2,1] if self.full_shape[s1]%upc == 0][0]
fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
if s0_exists:
s0_div = [upc for upc in [4,3,2,1] if self.full_shape[s0]%upc == 0][0]
fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
# permuted+upcast for tensor cores
self.shift_to(global_count, 4, insert_before=self.first_reduce)
self.shift_to(global_count+1, 4, insert_before=self.first_reduce)
self.shift_to(self.first_reduce-1, 2)
self.upcast()
# final global upcast
if not optim_conv2d:
for ax in [s1, s0]:
for upc in [4,3,2]:
if self.full_shape[ax]%upc == 0:
self.shift_to(ax, upc)
self.upcast()
break
# very late (optional) upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
self.use_tensor_cores = use_tensor_cores == 1 # TC=2 will do the shape ops without the WMMA
if self.use_tensor_cores and s0_exists and self.full_shape[s0] % 2 == 0:
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, 2))
self.exclude_local_upcast += 1
# alias buffer
self.local_dims = self.first_reduce - global_count
alias_pattern = [0]*global_count + [2] * self.local_dims + [0] * (self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3] * (self.upcasted-2)
alias_pattern = [0]*(self.global_dims+self.exclude_local_upcast) + [2]*(self.local_dims-self.exclude_local_upcast) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2)
self.alias_buffer(buf0, alias_pattern)
self.alias_buffer(buf1, alias_pattern)
if self.use_tensor_cores and optim_conv2d:
self.upcast()
if max(self.full_shape[self.first_reduce+1:self.first_reduce+3]) < 5:
self.upcast()
for upc in range(8, 1, -1):
if self.full_shape[global_count-2]%upc == 0:
self.shift_to(global_count-2, upc)
self.upcast()
break
else:
for upc in range(16, 1, -1):
if self.full_shape[global_count-1]%upc == 0:
self.shift_to(global_count-1, upc)
self.upcast()
break
# very late upcast to run group at the same time. only if actually using real tensor cores, otherwise local isn't a simdgroup
if self.use_tensor_cores and self.full_shape[s0] % 2 == 0:
self.shift_to(s0, 2, insert_before=self.first_reduce-self.local_dims)
self.local_dims += 1
self.exclude_local_upcast += 1
# early exit
return True
return False
def apply_opt(self, opt:Opt):
assert opt.amt is not None and opt.amt != 0, "amount can't be 0"
if opt.amt == 1: return False
self.applied_opts.append(opt)
axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else 0)
assert self.full_shape[axis] % opt.amt == 0, "no longer valid shift"
if opt.op == OptOps.LOCAL: # cyan
assert axis < (self.first_reduce-self.local_dims), "can't local a local or reduce"
assert axis < self.first_reduce, "can't local a reduce"
self.shift_to(axis, opt.amt, insert_before=self.first_reduce)
self.local_dims += 1
elif opt.op == OptOps.LASTLOCAL: # cyan
assert axis < self.first_reduce, "can't local a reduce"
self.shift_to(axis, opt.amt, insert_before=self.first_reduce-self.local_dims)
self.local_dims += 1
# TODO: include exclude_local_upcast here
elif opt.op == OptOps.GROUP: # green
self.shift_to(axis, opt.amt, insert_before=self.first_reduce + len(self.group_for_reduce))
self.group_for_reduce.append(opt.amt)
@ -322,7 +309,7 @@ class OptimizedKernel(Kernel):
assert axis < self.first_reduce, "upcast is for non-reduce"
self.shift_to(axis, opt.amt, insert_before=None)
self.upcast()
self.simplify_ones()
return self.simplify_ones()
def required_optimizations(self, early_only=False):
for buf_index,buf in enumerate(self.bufs):