mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
use placeholder in tests (#16672)
This commit is contained in:
parent
05249466ed
commit
925c49ce99
3 changed files with 9 additions and 6 deletions
|
|
@ -177,7 +177,8 @@ class TestLocalAccess(unittest.TestCase):
|
|||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_basic(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
smem = UOp.placeholder((16,), dtypes.float32, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
uops.append(smem)
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0), ptr=True),))
|
||||
|
|
@ -187,7 +188,8 @@ class TestLocalAccess(unittest.TestCase):
|
|||
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
|
||||
def test_local_packed(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
smem = UOp.placeholder((16,), dtypes.uint8, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
uops.append(smem)
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = smem.after(barr).index(uop(uops, Ops.CONST, dtypes.int32, (), 0))
|
||||
|
|
@ -199,7 +201,7 @@ class TestLocalAccess(unittest.TestCase):
|
|||
_dtypes = [dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.half]
|
||||
size = 16
|
||||
for dtype in _dtypes:
|
||||
temp = UOp(Ops.DEFINE_LOCAL, dtype.ptr(size=size, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
temp = UOp.placeholder((size,), dtype, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
uops = to_uops_list([temp], ren=Device[Device.DEFAULT].renderer)
|
||||
out = Device[Device.DEFAULT].renderer.render(uops)
|
||||
# half is supported in wgsl, so it doesn't have to be packed
|
||||
|
|
@ -211,7 +213,8 @@ class TestLocalAccess(unittest.TestCase):
|
|||
@unittest.skip("tinygrad doesn't support this behavior")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
smem = UOp.placeholder((16,), dtypes.int32, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
uops.append(smem)
|
||||
st1 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
|
||||
st2 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 2)), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))
|
||||
|
|
|
|||
|
|
@ -541,7 +541,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp.param(0, dtypes.int.ptr(16))
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
smem = UOp.placeholder((18,), dtypes.int, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
lidx = UOp.special(16, "lidx0", dtypes.int)
|
||||
st = smem.index(lidx, ptr=True).store(glbl0.index(lidx, ptr=True).load())
|
||||
barrier = st.barrier()
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class TestValidateOOB(unittest.TestCase):
|
|||
with Context(CHECK_OOB=1):
|
||||
# Define buffers
|
||||
gbuf = UOp.param(0, dtypes.uint.ptr(400))
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
|
||||
sbuf = UOp.placeholder((8,), dtypes.uint, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
# Define indices, valids and barrier
|
||||
gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue