mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
use_tensor_cores bugfix (#9969)
This commit is contained in:
parent
5294c32279
commit
0e79aee706
1 changed files with 4 additions and 2 deletions
|
|
@ -302,6 +302,7 @@ class Kernel:
|
|||
0: will disable any tensor core matching
|
||||
1: enable tensor cores
|
||||
2: apply tensor core shape but don't use UOp.WMMA
|
||||
3: emulate tensor cores with local memory
|
||||
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
|
||||
tc_select -- specifies which tensor core(s) to use for optimization (default -1)
|
||||
-1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
|
||||
|
|
@ -313,7 +314,7 @@ class Kernel:
|
|||
"""
|
||||
if tc_select is None: tc_select = TC_SELECT.value
|
||||
if tc_opt is None: tc_opt = TC_OPT.value
|
||||
if not self.opts.tensor_cores and use_tensor_cores != 2: return False
|
||||
if not self.opts.tensor_cores: return False
|
||||
try: # check TC first and apply hand-coded opts if successful
|
||||
self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
|
||||
|
||||
|
|
@ -343,11 +344,12 @@ class Kernel:
|
|||
|
||||
if opt.op is OptOps.TC:
|
||||
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
|
||||
check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
|
||||
check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
|
||||
check(opt.axis is not None, "tensor core opts must have an axis")
|
||||
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
|
||||
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
|
||||
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
|
||||
check(0 < (use_tensor_cores:=USE_TC.value) <= 3, "use_tensor_cores value is not valid")
|
||||
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
|
||||
self.applied_opts.append(opt)
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue