Scheduler.reduceops helper [pr] (#13162)

This commit is contained in:
chenyu 2025-11-07 18:59:46 -05:00 committed by GitHub
commit 6a509da7f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -219,8 +219,7 @@ class Scheduler:
return ret
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
if not (reduceops := self.reduceops): raise KernelOptError("no reduce ops for TensorCore")
reduceop = reduceops[0]
if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
@ -314,9 +313,10 @@ class Scheduler:
# helpers for hand_coded_optimizations
@property
def reduceops(self) -> list[UOp]: return [x for x in self.ast.backward_slice if x.op is Ops.REDUCE]
@property
def reduceop(self) -> UOp|None:
red = [x for x in self.ast.backward_slice if x.op is Ops.REDUCE]
if not len(red): return None
if not (red := self.reduceops): return None
return UOp(Ops.REDUCE_AXIS, red[0].dtype, red[0].src, (red[0].arg, ()))
@property
def bufs(self) -> list[UOp]: return [x for x in self.ast.toposort() if x.op is Ops.INDEX][::-1]