minor cleanup of View.reshape (#3088)

* minor cleanup of View.reshape

removed some redundant logic

* new_strides

* revert that
This commit is contained in:
chenyu 2024-01-11 13:05:54 -05:00 committed by GitHub
commit f502c9b08f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -181,18 +181,20 @@ class View:
if self.contiguous: return View.create(new_shape)
strides, r_new_shape = [], reversed(new_shape)
for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
acc, new_stride = 1, s
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
acc = 1
# TODO: this <= and != is for symbolic!?
while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
strides.append(new_stride if new_dim != 1 else 0)
if new_dim == 1: continue
new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
strides.append(new_stride)
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
if acc != merged_dim: break
else:
strides += [0,] * (len(new_shape) - len(strides))
mask, extra = _reshape_mask(self, new_shape)
cstrides = canonicalize_strides(tuple(e-b for b,e in mask) if mask else new_shape, tuple(reversed(strides)))
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - (sum(m[0] * s for m,s in zip(mask, cstrides)) if mask else 0) # noqa: E501
if not extra: return View.create(new_shape, cstrides, self.offset + extra_offset, mask)
new_mask, extra = _reshape_mask(self, new_shape)
if not extra:
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
return None