mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
unused_per
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d39b854a2d |
1 changed files with 2 additions and 2 deletions
|
|
@ -256,12 +256,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def valid(self): return UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||
def r(self, op:Ops, axis:tuple[int, ...], permute=True):
|
||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
# move any non reduce axis before the first reduce axis
|
||||
move_early, rest = partition(range(axis[0], len(self.shape)), lambda i: i not in axis and resolve(self.shape[i] != 1))
|
||||
if move_early and permute:
|
||||
if move_early:
|
||||
permaxis = tuple(range(axis[0])) + tuple(move_early) + tuple(rest)
|
||||
ret = self.permute(permaxis)
|
||||
new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue