Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
d39b854a2d unused permute arg on r 2025-07-25 19:31:47 -07:00

View file

@ -256,12 +256,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def valid(self): return UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self
# move any non reduce axis before the first reduce axis
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
if move_early and permute:
if move_early:
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
ret = self.permute(permaxis)
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])