mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
simpler_re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
872ca998db |
1 changed files with 4 additions and 10 deletions
|
|
@ -257,18 +257,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||||
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
|
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
|
||||||
|
assert permute
|
||||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||||
if len(axis) == 0: return self
|
if len(axis) == 0: return self
|
||||||
# move any non reduce axis before the first reduce axis
|
ret = self.permute(tuple([i for i in range(len(self.shape)) if i not in axis])+axis)
|
||||||
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
|
ret = ret.reshape(tuple([s for s in ret.shape if resolve(s != 1)]))
|
||||||
if move_early and permute:
|
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, tuple(range(len(ret.shape)-len(axis), len(ret.shape)))))
|
||||||
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
|
|
||||||
ret = self.permute(permaxis)
|
|
||||||
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
|
|
||||||
assert len(axis) == len(new_axis)
|
|
||||||
else:
|
|
||||||
ret, new_axis = self, axis
|
|
||||||
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
|
||||||
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
||||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||||
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue