use placeholder in tests (#16672)

This commit is contained in:
George Hotz 2026-06-18 20:51:44 -07:00 committed by GitHub
commit 925c49ce99
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 9 additions and 6 deletions

View file

@ -177,7 +177,8 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_basic(self):
uops = []
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
smem = UOp.placeholder((16,), dtypes.float32, slot=0, addrspace=AddrSpace.LOCAL)
uops.append(smem)
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),))
@ -187,7 +188,8 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
def test_local_packed(self):
uops = []
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
smem = UOp.placeholder((16,), dtypes.uint8, slot=0, addrspace=AddrSpace.LOCAL)
uops.append(smem)
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
@ -199,7 +201,7 @@ class TestLocalAccess(unittest.TestCase):
_dtypes = [dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.half]
size = 16
for dtype in _dtypes:
temp = UOp(Ops.DEFINE_LOCAL, dtype.ptr(size=size, addrspace=AddrSpace.LOCAL), (), 'smem')
temp = UOp.placeholder((size,), dtype, slot=0, addrspace=AddrSpace.LOCAL)
uops = to_uops_list([temp], ren=Device[Device.DEFAULT].renderer)
out = Device[Device.DEFAULT].renderer.render(uops)
# half is supported in wgsl, so it doesn't have to be packed
@ -211,7 +213,8 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skip("tinygrad doesn't support this behavior")
def test_local_indirect(self):
uops = []
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
smem = UOp.placeholder((16,), dtypes.int32, slot=0, addrspace=AddrSpace.LOCAL)
uops.append(smem)
st1 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 2)), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))

View file

@ -541,7 +541,7 @@ class TestUOpGraph(unittest.TestCase):
def test_fold_gated_load_local(self):
glbl0 = UOp.param(0, dtypes.int.ptr(16))
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
smem = UOp.placeholder((18,), dtypes.int, slot=0, addrspace=AddrSpace.LOCAL)
lidx = UOp.special(16, "lidx0", dtypes.int)
st = smem.index(lidx, ptr=True).store(glbl0.index(lidx, ptr=True).load())
barrier = st.barrier()

View file

@ -146,7 +146,7 @@ class TestValidateOOB(unittest.TestCase):
with Context(CHECK_OOB=1):
# Define buffers
gbuf = UOp.param(0, dtypes.uint.ptr(400))
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
sbuf = UOp.placeholder((8,), dtypes.uint, slot=0, addrspace=AddrSpace.LOCAL)
# Define indices, valids and barrier
gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")