Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
872ca998db simpler reduce that just removes ones 2025-07-25 19:17:38 -07:00

View file

@ -257,18 +257,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@staticmethod @staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx) def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
def r(self, op:Ops, axis:tuple[int, ...], permute=True): def r(self, op:Ops, axis:tuple[int, ...], permute=True):
assert permute
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)])) axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self if len(axis) == 0: return self
# move any non reduce axis before the first reduce axis ret = self.permute(tuple([i for i in range(len(self.shape)) if i not in axis])+axis)
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1)) ret = ret.reshape(tuple([s for s in ret.shape if resolve(s != 1)]))
if move_early and permute: ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, tuple(range(len(ret.shape)-len(axis), len(ret.shape)))))
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
ret = self.permute(permaxis)
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
assert len(axis) == len(new_axis)
else:
ret, new_axis = self, axis
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)])) return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs) def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self): return self.alu(Ops.CONTIGUOUS) def contiguous(self): return self.alu(Ops.CONTIGUOUS)