reshape mask cleanups [pr] (#8064)

don't need canonicalize_st because we always merge 1 in `_merge_dims`
This commit is contained in:
chenyu 2024-12-05 20:20:43 -05:00 committed by GitHub
commit aefdff4ef5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -331,14 +331,12 @@ class View:
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
strides.append(new_stride)
if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
if resolve(acc != merged_dim): break
else:
strides += [0,] * (len(new_shape) - len(strides))
new_mask = _reshape_mask(self.mask, self.shape, new_shape)
if new_mask is not None:
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
if resolve(acc != merged_dim): return None
if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
new_strides = (0,) * (len(new_shape) - len(strides)) + tuple(strides[::-1])
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
return None