mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
minor cleanup of View.reshape (#3088)
* minor cleanup of View.reshape removed some redundant logic * new_strides * revert that
This commit is contained in:
parent
f40299c3fe
commit
f502c9b08f
1 changed files with 11 additions and 9 deletions
|
|
@ -181,18 +181,20 @@ class View:
|
|||
if self.contiguous: return View.create(new_shape)
|
||||
|
||||
strides, r_new_shape = [], reversed(new_shape)
|
||||
for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
|
||||
acc, new_stride = 1, s
|
||||
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
|
||||
acc = 1
|
||||
# TODO: this <= and != is for symbolic!?
|
||||
while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
|
||||
strides.append(new_stride if new_dim != 1 else 0)
|
||||
if new_dim == 1: continue
|
||||
new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
|
||||
strides.append(new_stride)
|
||||
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
|
||||
if acc != merged_dim: break
|
||||
else:
|
||||
strides += [0,] * (len(new_shape) - len(strides))
|
||||
mask, extra = _reshape_mask(self, new_shape)
|
||||
cstrides = canonicalize_strides(tuple(e-b for b,e in mask) if mask else new_shape, tuple(reversed(strides)))
|
||||
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - (sum(m[0] * s for m,s in zip(mask, cstrides)) if mask else 0) # noqa: E501
|
||||
if not extra: return View.create(new_shape, cstrides, self.offset + extra_offset, mask)
|
||||
new_mask, extra = _reshape_mask(self, new_shape)
|
||||
if not extra:
|
||||
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
|
||||
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
|
||||
(sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
|
||||
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
|
||||
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue