small changes from new codegen (#16681)

* small changes from new codegen

* revert that
This commit is contained in:
George Hotz 2026-06-19 18:29:01 -07:00 committed by GitHub
commit 30830850a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 4 deletions

View file

@ -1062,9 +1062,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.BUFFER, dtype.ptr(prod(shape), addrspace), src=(shape_to_shape_arg(buf_shape),), arg=ParamArg(slot, addrspace=addrspace))
if len(shape) > 1: ret = ret.reshape(shape + ((dtype.count,) if addrspace in (AddrSpace.LOCAL, AddrSpace.REG) and dtype.count > 1 else ()))
return ret
def placeholder_like(self, slot:int):
def placeholder_like(self, slot:int, addrspace=AddrSpace.GLOBAL):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.max_shard_shape, self.dtype, slot)
return UOp.placeholder(self.max_shard_shape, self.dtype, slot, addrspace)
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]|list[UOp]=()) -> UOp:

View file

@ -149,8 +149,7 @@ spec_tensor = PatternMatcher([
# movement ops
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat())), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"),
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(), UPat()), name="x"), lambda x: x.src[1].shape == x.src[2].shape),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering