use_tensor_cores bugfix (#9969)

This commit is contained in:
Ignacio Sica 2025-04-21 22:58:17 -03:00 committed by GitHub
commit 0e79aee706
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -302,6 +302,7 @@ class Kernel:
0: will disable any tensor core matching
1: enable tensor cores
2: apply tensor core shape but don't use UOp.WMMA
3: emulate tensor cores with local memory
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
tc_select -- specifies which tensor core(s) to use for optimization (default -1)
-1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
@@ -313,7 +314,7 @@ class Kernel:
"""
if tc_select is None: tc_select = TC_SELECT.value
if tc_opt is None: tc_opt = TC_OPT.value
if not self.opts.tensor_cores and use_tensor_cores != 2: return False
if not self.opts.tensor_cores: return False
try: # check TC first and apply hand-coded opts if successful
self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
@@ -343,11 +344,12 @@ class Kernel:
if opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
check(len(self.opts.tensor_cores) > 0, "must have tensor cores")
check(opt.axis is not None, "tensor core opts must have an axis")
check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 2, "tensor core opts must have tc_select and tc_opt")
check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
check(0 < (use_tensor_cores:=USE_TC.value) <= 3, "use_tensor_cores value is not valid")
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt), "no tensor core available")
self.applied_opts.append(opt)
return