late remove ones

This commit is contained in:
George Hotz 2025-07-25 15:51:56 -07:00
commit dfb3e99b09

View file

@ -451,6 +451,8 @@ class Kernel:
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
if op.op in GroupOp.Buffer and op in self.bufs:
st = self.sts[self.bufs.index(op)]
# late remove all ones
st = st.reshape(tuple([x for x in st.shape if resolve(x != 1)]))
# NOTE: if CONST got masked after applying opts, we create a new VALID
if op.op is Ops.CONST and any(v.mask is not None for v in st.views): return op.view(st).valid()
# otherwise we just replace the VIEW source