mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
reshape mask cleanups [pr] (#8064)
don't need canonicalize_st because we always merge 1 in `_merge_dims`
This commit is contained in:
parent
05dba6e4ee
commit
aefdff4ef5
1 changed files with 7 additions and 9 deletions
|
|
@ -331,14 +331,12 @@ class View:
|
|||
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
|
||||
strides.append(new_stride)
|
||||
if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
|
||||
if resolve(acc != merged_dim): break
|
||||
else:
|
||||
strides += [0,] * (len(new_shape) - len(strides))
|
||||
new_mask = _reshape_mask(self.mask, self.shape, new_shape)
|
||||
if new_mask is not None:
|
||||
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
|
||||
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
|
||||
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
|
||||
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
|
||||
if resolve(acc != merged_dim): return None
|
||||
|
||||
if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
|
||||
new_strides = (0,) * (len(new_shape) - len(strides)) + tuple(strides[::-1])
|
||||
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
|
||||
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
|
||||
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
|
||||
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue