lazy srcs shape mistmatch assert + fix ASSIGN [pr] (#8014)

* lazy srcs shape mistmatch assert [pr]

* duplicate assert

* base it later

* keep the assert
This commit is contained in:
qazal 2024-12-03 15:40:37 -05:00 committed by GitHub
commit 099364ed32
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 2 deletions

View file

@ -32,6 +32,7 @@ class LazyBuffer(MathTrait):
# properties on base
self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"
assert all_same([x.st.shape for x in self.srcs]), f"src shape mismatch! {self.srcs}"
if self.op is Ops.BUFFER_VIEW:
# some LazyBuffers can be processed with only a view, no AST required
@ -81,7 +82,7 @@ class LazyBuffer(MathTrait):
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
assert self.is_realized, f"assign target must be realized {self}"
return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=None if self.st.contiguous else self.st,
src=(self.base, x), enable_cache=True)
src=(self, x), enable_cache=True) # NOTE: assign to VIEW is fine
def can_view(self):
return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and

View file

@ -70,7 +70,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.buf_uop)
ctx.assigns.add(ubuf:=target.base.buf_uop)
op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
else:
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)