mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
simpler_re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
872ca998db |
1 changed files with 4 additions and 10 deletions
|
|
@ -257,18 +257,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
|
||||
assert permute
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
# move any non reduce axis before the first reduce axis
|
||||
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
|
||||
if move_early and permute:
|
||||
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
|
||||
ret = self.permute(permaxis)
|
||||
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
|
||||
assert len(axis) == len(new_axis)
|
||||
else:
|
||||
ret, new_axis = self, axis
|
||||
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
||||
ret = self.permute(tuple([i for i in range(len(self.shape)) if i not in axis])+axis)
|
||||
ret = ret.reshape(tuple([s for s in ret.shape if resolve(s != 1)]))
|
||||
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, tuple(range(len(ret.shape)-len(axis), len(ret.shape)))))
|
||||
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue