mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
change buffer to not be pointer [pr] (#8302)
This commit is contained in:
parent
4e2d98638d
commit
801e199196
4 changed files with 20 additions and 20 deletions
|
|
@ -1761,7 +1761,7 @@ class TestSwizzle(unittest.TestCase):
|
|||
|
||||
def test_permute_rewrite(self):
|
||||
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x1:=UOp(Ops.BUFFER, dtypes.float, arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||
x1,
|
||||
|
|
@ -1773,15 +1773,15 @@ class TestSwizzle(unittest.TestCase):
|
|||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2,)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11,)),)),)),)),))
|
||||
|
|
@ -1793,7 +1793,7 @@ class TestSwizzle(unittest.TestCase):
|
|||
# fuse (relu bw, conv2d, conv2d bw, relu)
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 128, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 128, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
|
|
@ -1808,16 +1808,16 @@ class TestSwizzle(unittest.TestCase):
|
|||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(9, ('METAL', 96, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(9, ('METAL', 96, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(48, 0, 0, 4, 1, 16, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(16, ('METAL', 432, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(16, ('METAL', 432, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(0, 0, 27, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
|
||||
x6,)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(18, ('METAL', 128, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(18, ('METAL', 128, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 3, 2, 3), strides=(64, 4, 2, 0, 1, 0), offset=0, mask=((0, 2), (0, 16), (0, 2), (0, 1), (0, 2), (0, 1)), contiguous=False), View(shape=(1, 2, 1, 16, 3, 2, 3, 2), strides=(0, 576, 0, 36, 12, 6, 2, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),))
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
|
@ -1826,7 +1826,7 @@ class TestSwizzle(unittest.TestCase):
|
|||
def test_swizzle_failure_permute(self):
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(20, ('METAL', 65, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
|
|
@ -1834,7 +1834,7 @@ class TestSwizzle(unittest.TestCase):
|
|||
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()),
|
||||
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
|
|
@ -1854,12 +1854,12 @@ class TestSwizzle(unittest.TestCase):
|
|||
x15,)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()),
|
||||
x10,)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(4, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
|
@ -1952,7 +1952,7 @@ class TestBigGraph(unittest.TestCase):
|
|||
|
||||
def test_sink_childless_const_alt(self):
|
||||
x = UOp.const(dtypes.int, 0)
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int, (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
big_graph = big_graph_rewrite(UOp.sink(x, y), ctx:=ScheduleContext())
|
||||
self.assertIs(big_graph, UOp(Ops.NOOP))
|
||||
self.assertEqual(len(ctx.realizes), 0)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
|||
# UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
|
||||
# UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
|
||||
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
|
||||
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(-1, 'METAL', 1), src=()),
|
||||
# UOp(Ops.BUFFER, dtypes.float, arg=(-1, 'METAL', 1), src=()),
|
||||
# UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),)),)),))
|
||||
# expected:
|
||||
# UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
|
||||
|
|
@ -55,14 +55,14 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
|||
# currently, COPY has an extra BUFFER on the output
|
||||
# current:
|
||||
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
|
||||
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, 'TEST', 3), src=()),
|
||||
# UOp(Ops.BUFFER, dtypes.float, arg=(2, 'TEST', 3), src=()),
|
||||
# UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=(
|
||||
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
|
||||
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, 'METAL', 3), src=()),)),)),))
|
||||
# UOp(Ops.BUFFER, dtypes.float, arg=(1, 'METAL', 3), src=()),)),)),))
|
||||
# expected:
|
||||
# UOp(Ops.COPY, dtypes.float, arg=('TEST', False), src=(
|
||||
# UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=(
|
||||
# UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, 'METAL', 3), src=()),))
|
||||
# UOp(Ops.BUFFER, dtypes.float, arg=(1, 'METAL', 3), src=()),))
|
||||
@unittest.expectedFailure
|
||||
def test_copyin(self):
|
||||
a = Tensor([1.,2,3]).realize()
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
|
|||
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
ctx.bufs.append(x)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(), (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
|
|
|
|||
|
|
@ -438,7 +438,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
if op is Ops.CONST:
|
||||
# NOTE: we embed device on CONST with a fake BUFFER uop
|
||||
fake = UOp(Ops.BUFFER, dtype.ptr(), (UOp(Ops.DEVICE, arg=device),), (-1, 1))
|
||||
fake = UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (-1, 1))
|
||||
# NOTE: BIND stays BIND, UOp.const unbinds here
|
||||
const_uop = arg if isinstance(arg, UOp) else UOp.const(dtype, unwrap(arg))
|
||||
return UOp(Ops.VIEW, dtype, (fake, const_uop), ShapeTracker.from_shape(())).reshape((1,)*len(shape)).expand(shape)
|
||||
|
|
@ -505,7 +505,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
buffer_num = itertools.count(0)
|
||||
@staticmethod
|
||||
def new_buffer(device:str, size:int, dtype:DType) -> UOp:
|
||||
return UOp(Ops.BUFFER, dtype.ptr(), (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
|
||||
return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
|
||||
@property
|
||||
def device(self) -> str: return unwrap(self._device)
|
||||
@functools.cached_property
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue