Buffer.as_buffer -> Buffer.as_memoryview [pr] (#14535)

it casts to memoryview. also inline the as_typed_buffer checks to Tensor._data
This commit is contained in:
chenyu 2026-02-04 11:31:11 -05:00 committed by GitHub
commit d57d24c7d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 90 additions and 92 deletions

View file

@ -72,7 +72,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
#storage_tensor._copyin(img_tensor.numpy())
# faster
X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
X[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = img.tobytes()
# ideal
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
@ -264,8 +264,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
x = random_brightness_augmentation(x)
x = gaussian_noise(x)
X[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
Y[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
X[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = x.tobytes()
Y[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = y.tobytes()
queue_out.put(idx)
queue_out.put(None)
@ -379,12 +379,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
clipped_match_idxs = np.clip(match_idxs, 0, None)
clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]
boxes[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()
boxes[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = anchor.tobytes()
imgs[idx].contiguous().realize().uop.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
imgs[idx].contiguous().realize().uop.base.realized.as_memoryview(force_zero_copy=True)[:] = img.tobytes()
queue_out.put(idx)
queue_out.put(None)

View file

@ -48,7 +48,7 @@ def prepare_browser_chunks(model):
weight_metadata = metadata.get(name, default)
weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
metadata[name] = weight_metadata
data = bytes(state_dict[name].uop.base.realized.as_buffer())
data = bytes(state_dict[name].uop.base.realized.as_memoryview())
data = data if not offsets else data[offsets[0]:offsets[1]]
writer.write(data)
cursor += size

View file

@ -1120,8 +1120,8 @@ class WaveState:
self.n_lanes = n_lanes
self.vgpr_buf = Buffer('CPU', VGPR_SIZE, dtypes.uint32).ensure_allocated()
self.sgpr_buf = Buffer('CPU', SGPR_COUNT, dtypes.uint32).ensure_allocated()
self._vgpr_mv = self.vgpr_buf.as_buffer(force_zero_copy=True).cast('I')
self._sgpr_mv = self.sgpr_buf.as_buffer(force_zero_copy=True).cast('I')
self._vgpr_mv = self.vgpr_buf.as_memoryview(force_zero_copy=True).cast('I')
self._sgpr_mv = self.sgpr_buf.as_memoryview(force_zero_copy=True).cast('I')
# Zero memory using ctypes memset (much faster than Python loops)
ctypes.memset(self.vgpr_buf._buf.va_addr, 0, VGPR_SIZE * 4)
ctypes.memset(self.sgpr_buf._buf.va_addr, 0, SGPR_COUNT * 4)

View file

@ -37,7 +37,7 @@ b.copyin(row.data)
c.copyin(mat.data)
ret = prog(a._buf, b._buf, c._buf, global_size=[1,1,1], local_size=[8,1,1], wait=True)
print(ret)
out = np.frombuffer(a.as_buffer(), np.float32)
out = np.frombuffer(a.as_memoryview(), np.float32)
real = row.astype(np.float32)@mat.T.astype(np.float32)
print("out:", out)
print("real", real)

View file

@ -98,10 +98,10 @@ if __name__ == "__main__":
# check correctness
if getenv("VERIFY"):
from tinygrad.engine.realize import run_schedule
triton_buf = np.frombuffer(si.bufs[0].as_buffer(), np.float16).reshape(M,N)
triton_buf = np.frombuffer(si.bufs[0].as_memoryview(), np.float16).reshape(M,N)
print(triton_buf)
run_schedule(sched)
tinygrad_buf = np.frombuffer(si.bufs[0].as_buffer(), np.float16).reshape(M,N)
tinygrad_buf = np.frombuffer(si.bufs[0].as_memoryview(), np.float16).reshape(M,N)
print(tinygrad_buf)
np.testing.assert_allclose(triton_buf, tinygrad_buf)
print("correct!")

View file

@ -18,7 +18,7 @@ prg = dev.runtime("write_ones", mbin)
prg(buf0._buf, global_size=(1,65537,1), local_size=(1,1,1), wait=True)
import numpy as np
def to_np(buf): return np.frombuffer(buf.as_buffer().cast(buf.dtype.base.fmt), dtype=_to_np_dtype(buf.dtype.base))
def to_np(buf): return np.frombuffer(buf.as_memoryview().cast(buf.dtype.base.fmt), dtype=_to_np_dtype(buf.dtype.base))
big = to_np(buf0)
print(big)

View file

@ -119,7 +119,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_exec_2_kernels_100_times(self):
@ -135,7 +135,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0, {virt_val.expr: TestHCQ.d0.timeline_value})
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_memoryview().cast("f")[0]
assert val == 200.0, f"got val {val}"
@unittest.skipIf(Device.DEFAULT in {"CPU"}, "No globals/locals on LLVM/CPU")
@ -151,9 +151,9 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(Device.DEFAULT in {"CPU"}, "No globals/locals on LLVM/CPU")
@ -186,7 +186,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
res_sum = sum(x for x in zt.as_buffer().cast("I"))
res_sum = sum(x for x in zt.as_memoryview().cast("I"))
assert x * y * z == res_sum, f"want {x * y * z}, got {res_sum}"
# Test copy
@ -200,7 +200,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 1.0, f"got val {val}"
def test_copy_long(self):
@ -218,7 +218,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
mv_buf1 = buf1.as_buffer().cast('Q')
mv_buf1 = buf1.as_memoryview().cast('Q')
assert libc.memcmp(mv_address(mv_buf1), buf2._buf.va_addr, sz) == 0
@slow
@ -242,7 +242,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
mv_buf1 = buf1.as_buffer()
mv_buf1 = buf1.as_memoryview()
assert libc.memcmp(mv_address(mv_buf1), buf2._buf.va_addr, sz) == 0
def test_update_copy(self):
@ -260,7 +260,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 1.0, f"got val {val}"
def test_update_copy_long(self):
@ -283,7 +283,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
mv_buf1 = buf1.as_buffer().cast('Q')
mv_buf1 = buf1.as_memoryview().cast('Q')
for i in range(sz//8): assert mv_buf1[i] == 0x0101010101010101, f"offset {i*8} differs, not all copied, got {hex(mv_buf1[i])}"
# Test bind api
@ -421,7 +421,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert buf1.as_buffer()[0] == i
assert buf1.as_memoryview()[0] == i
def test_small_copies_from_host_buf_intercopy(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
@ -440,7 +440,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert buf2.as_buffer()[0] == i
assert buf2.as_memoryview()[0] == i
def test_small_copies_from_host_buf_transfer(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
@ -463,7 +463,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert buf2.as_buffer()[0] == i
assert buf2.as_memoryview()[0] == i
def test_memory_barrier(self):
a = Tensor([0, 1], device=Device.DEFAULT, dtype=dtypes.int8).realize()
@ -486,7 +486,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert buf1.as_buffer()[0] == (i + 1), f"has {buf1.as_buffer()[0]}, need {i + 1}"
assert buf1.as_memoryview()[0] == (i + 1), f"has {buf1.as_memoryview()[0]}, need {i + 1}"
def test_memory_barrier_before_copy(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
@ -511,7 +511,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert buf2.as_buffer()[0] == i
assert buf2.as_memoryview()[0] == i
def test_map_cpu_buffer_to_device(self):
if Device[Device.DEFAULT].hw_copy_queue_t is None: self.skipTest("skip device without copy queue")

View file

@ -20,7 +20,7 @@ class TestAMD(unittest.TestCase):
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
TestAMD.d0_runner.clprg(TestAMD.a.uop.buffer._buf, TestAMD.b.uop.buffer._buf,
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
val = TestAMD.a.uop.buffer.as_buffer().cast("f")[0]
val = TestAMD.a.uop.buffer.as_memoryview().cast("f")[0]
assert val == 4000.0, f"got val {val}"
if __name__ == "__main__":

View file

@ -65,7 +65,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_memoryview().cast("f")[0]
assert val == 2000.0, f"got val {val}"
def test_run_1000_times(self):
@ -81,7 +81,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_memoryview().cast("f")[0]
assert val == 2000.0, f"got val {val}"
def test_run_to_3(self):
@ -95,7 +95,7 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 3.0, f"got val {val}"
def test_update_exec(self):
@ -106,9 +106,9 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@ -126,7 +126,7 @@ class TestHCQ(unittest.TestCase):
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_memoryview().cast("f")[0]
assert val == 2000.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@ -141,9 +141,9 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(CI, "Can't handle async update on CPU")
@ -174,7 +174,7 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_submit_empty_queues(self):
@ -206,7 +206,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_copy_1000_times(self):
@ -221,7 +221,7 @@ class TestHCQ(unittest.TestCase):
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 0.0, f"got val {val}"
def test_copy(self):
@ -231,7 +231,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 1.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
@ -248,7 +248,7 @@ class TestHCQ(unittest.TestCase):
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[1]
assert val == 0.0, f"got val {val}"
def test_copy_bandwidth(self):
@ -288,7 +288,7 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.a.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_cross_device_signal(self):
@ -319,7 +319,7 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[0]
val = TestHCQ.b.uop.buffer.as_memoryview().cast("f")[0]
assert val == 1.0, f"got val {val}"
if __name__ == "__main__":

View file

@ -30,7 +30,7 @@ def alloc_rawbuffer(device, fill=False):
if fill:
with Context(DEBUG=0):
data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer())
rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_memoryview())
return rawbuf
def gen_kernel_ji(device, deps):
@ -93,7 +93,7 @@ def run_jit(jis, all_buffers, input_buffers, var_vals):
with Context(DEBUG=0):
res_buffers = []
for rawbuf in all_buffers: res_buffers.append(rawbuf.as_buffer())
for rawbuf in all_buffers: res_buffers.append(rawbuf.as_memoryview())
return res_buffers
def fuzz_graph(jis, all_buffers, input_buffers):

View file

@ -149,10 +149,10 @@ class TestTensorCores(unittest.TestCase):
if _to_np_dtype(real_bufs[0].dtype) is None: continue
real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled
prg.exec(real_bufs)
result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype))
result = np.frombuffer(real_bufs[0].as_memoryview(), _to_np_dtype(real_bufs[0].dtype))
# ensure the results for each choice of axis matches
if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype))
if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_memoryview(), _to_np_dtype(real_bufs[0].dtype))
np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.2)
@Context(ALLOW_TF32=1)

View file

@ -40,7 +40,7 @@ def helper_alloc_rawbuffer(device, fill=False):
if fill:
with Context(DEBUG=0):
data = np.random.randint(-10000, 10000, size=rawbuf.size, dtype=_to_np_dtype(rawbuf.dtype))
rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_buffer())
rawbuf.copyin(Tensor(data).realize().uop.base.realized.as_memoryview())
return rawbuf
def helper_create_offset_rawbuffer(base, offset=0):
@ -54,7 +54,7 @@ def helper_run_jit(jis, bufs, out_buffers):
rawbuf.copyin(mv)
for ei in jis: ei.run({}, jit=True)
return [rawbuf.as_buffer() for rawbuf in bufs]
return [rawbuf.as_memoryview() for rawbuf in bufs]
def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT):
reg_ji = []

View file

@ -12,7 +12,7 @@ class TestImageCopy(unittest.TestCase):
def test_image_copyout_1x8(self, img_type=dtypes.imagef):
it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
buf = it.uop.buffer
out = buf.as_buffer()
out = buf.as_memoryview()
np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32))
@unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half")
@ -26,14 +26,14 @@ class TestImageCopy(unittest.TestCase):
def test_image_copyout_2x4(self):
it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize()
buf = it.uop.buffer
out = buf.as_buffer()
out = buf.as_memoryview()
np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4))
def test_image_roundtrip(self):
sz = (4,2,4)
it = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
buf = it.uop.buffer
out = buf.as_buffer()
out = buf.as_memoryview()
it2 = Tensor.rand(prod(sz)).cast(dtypes.imagef(sz)).realize()
buf2 = it2.uop.buffer
@ -190,7 +190,7 @@ class TestImageDType(unittest.TestCase):
for s in sched:
s.run()
if s.bufs[0].dtype == dtypes.float:
lst = s.bufs[0].as_buffer().cast("f").tolist()
lst = s.bufs[0].as_memoryview().cast("f").tolist()
print(lst)
assert not np.any(np.isnan(lst))
# NOTE: the w1 grad must realize to a separate kernel

View file

@ -504,7 +504,7 @@ def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
return realized_ast
def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
return [np.frombuffer(x.as_memoryview(), _to_np_dtype(x.dtype)) for x in outbufs]
def reset_bufs(bufs:list[Buffer]):
for buf in bufs: buf.copyin(np.zeros((buf.size*buf.dtype.itemsize,), dtype=np.uint8).data)

View file

@ -76,7 +76,7 @@ class TestPickle(unittest.TestCase):
del a
del buffer
a2:UOp = pickle.loads(s)
self.assertListEqual(a2.base.realized.as_buffer().cast("I").tolist(), [0, 1, 2, 3])
self.assertListEqual(a2.base.realized.as_memoryview().cast("I").tolist(), [0, 1, 2, 3])
def test_pickle_unrealized_tensor(self):
t = Tensor.ones(10, 10)

View file

@ -22,7 +22,7 @@ def _test_uop_result(inputs:list[Tensor], prg, local_size=None):
if local_size is not None: prg = replace(prg, local_size=local_size)
ei = CompiledRunner(prg)
ei.exec(outbufs+inbufs)
return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
return [np.frombuffer(x.as_memoryview(), _to_np_dtype(x.dtype)) for x in outbufs]
def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
dtype = alu_src_uops[0].dtype

View file

@ -2124,14 +2124,14 @@ class TestCopyFolding(unittest.TestCase):
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_memoryview())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
b = a.reshape(2, 2).permute(1, 0).to("CPU")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk_contiguous(self):
with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_buffer())
with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_memoryview())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
b.realize()
@ -2145,7 +2145,7 @@ class TestCopyFolding(unittest.TestCase):
# NOTE: disk permute must come after COPY
def test_permute_after_shrink_on_disk(self):
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_buffer())
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_memoryview())
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
b.realize()

View file

@ -13,26 +13,26 @@ class TestSubBuffer(unittest.TestCase):
def test_subbuffer(self):
vbuf = self.buf.view(2, dtypes.uint8, offset=3).ensure_allocated()
tst = vbuf.as_buffer().tolist()
tst = vbuf.as_memoryview().tolist()
assert tst == [3, 4]
def test_subbuffer_cast(self):
# NOTE: bitcast depends on endianness
vbuf = self.buf.view(2, dtypes.uint16, offset=3).ensure_allocated()
tst = vbuf.as_buffer().cast("H").tolist()
tst = vbuf.as_memoryview().cast("H").tolist()
assert tst == [3|(4<<8), 5|(6<<8)]
def test_subbuffer_double(self):
vbuf = self.buf.view(4, dtypes.uint8, offset=3).ensure_allocated()
vvbuf = vbuf.view(2, dtypes.uint8, offset=1).ensure_allocated()
tst = vvbuf.as_buffer().tolist()
tst = vvbuf.as_memoryview().tolist()
assert tst == [4, 5]
def test_subbuffer_len(self):
vbuf = self.buf.view(5, dtypes.uint8, 2).ensure_allocated()
mv = vbuf.as_buffer()
mv = vbuf.as_memoryview()
assert len(mv) == 5
mv = vbuf.as_buffer(allow_zero_copy=True)
mv = vbuf.as_memoryview(allow_zero_copy=True)
assert len(mv) == 5
def test_subbuffer_used(self):
@ -63,7 +63,7 @@ class TestSubBuffer(unittest.TestCase):
vbuf.ensure_allocated()
tst = vbuf.as_buffer().tolist()
tst = vbuf.as_memoryview().tolist()
assert tst == [13, 14]
def test_subbuffer_is_allocated(self):
@ -112,8 +112,8 @@ class TestSubBuffer(unittest.TestCase):
sub_buf.copyout(memoryview(data_out_sub))
assert data_out_sub == bytearray(range(3, 6))
sub_buf.copyin(memoryview(bytearray(range(3))))
assert sub_buf.as_buffer().tolist() == list(range(3))
assert self.buf.as_buffer().tolist()[3:6] == list(range(3))
assert sub_buf.as_memoryview().tolist() == list(range(3))
assert self.buf.as_memoryview().tolist()[3:6] == list(range(3))
sub_buf.copyout(memoryview(data_out_sub))
assert data_out_sub == bytearray(range(3))
data_out_base = bytearray([0]*10)
@ -145,17 +145,17 @@ class TestSubBuffer(unittest.TestCase):
sub_buf = self.buf.view(4, dtypes.int8, offset=3)
sub_buf.allocate()
sub_buf.copyin(memoryview(bytearray(range(10, 14))))
assert self.buf.as_buffer().tolist()[3:7] == sub_buf.as_buffer().tolist()
assert self.buf.as_memoryview().tolist()[3:7] == sub_buf.as_memoryview().tolist()
sub_buf = self.buf_unalloc.view(4, dtypes.int8, offset=3)
sub_buf.allocate()
sub_buf.copyin(memoryview(bytearray(range(10, 14))))
assert self.buf_unalloc.as_buffer().tolist()[3:7] == sub_buf.as_buffer().tolist()
assert self.buf_unalloc.as_memoryview().tolist()[3:7] == sub_buf.as_memoryview().tolist()
def test_subbuffer_dealloc(self):
sub_buf = self.buf.view(4, dtypes.int8, offset=3).ensure_allocated()
sub_buf.deallocate()
assert self.buf.as_buffer().tolist() == list(range(10))
assert self.buf.as_memoryview().tolist() == list(range(10))
def test_subbuffer_double_dealloc(self):
sub_buf = self.buf.view(3, dtypes.uint8, offset=4).ensure_allocated()
@ -168,17 +168,17 @@ class TestSubBuffer(unittest.TestCase):
def test_subbuffer_uaf(self):
sub_buf = self.buf.view(4, dtypes.int8, offset=3).ensure_allocated()
assert self.buf.as_buffer().tolist(), list(range(10))
assert self.buf.as_memoryview().tolist(), list(range(10))
sub_buf.deallocate()
with self.assertRaises(AssertionError):
sub_buf.as_buffer().tolist()
assert self.buf.as_buffer().tolist(), list(range(10))
sub_buf.as_memoryview().tolist()
assert self.buf.as_memoryview().tolist(), list(range(10))
sub_buf = self.buf.view(4, dtypes.int8, offset=3).ensure_allocated()
assert sub_buf.as_buffer().tolist(), list(range(3, 7))
assert sub_buf.as_memoryview().tolist(), list(range(3, 7))
self.buf.deallocate()
with self.assertRaises(AssertionError):
sub_buf.as_buffer().tolist()
sub_buf.as_memoryview().tolist()
if __name__ == '__main__':
unittest.main()

View file

@ -6,7 +6,7 @@ def time_tensor_numpy(out:Tensor):
times = []
for _ in range(5):
st = time.perf_counter()
out.uop.base.realized.as_buffer(allow_zero_copy=True)
out.uop.base.realized.as_memoryview(allow_zero_copy=True)
et = time.perf_counter() - st
times.append(et)
return min(times)

View file

@ -177,20 +177,16 @@ class Buffer:
def as_dmaref(self) -> DMARef:
assert hasattr(self.allocator, "_as_dmaref"), f"Device {self.device} doesn't support DMA"
return self.allocator._as_dmaref(self._buf)
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
def as_memoryview(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_memoryview (disabled by default due to use after free)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
return self.allocator._as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))
def as_typed_buffer(self, shape=None, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self.as_buffer(allow_zero_copy, force_zero_copy).cast(self.dtype.base.fmt, shape if shape is not None else (self.size,))
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
import numpy as np
assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
return np.frombuffer(self.as_buffer(), dtype=_to_np_dtype(self.dtype.base))
return np.frombuffer(self.as_memoryview(), dtype=_to_np_dtype(self.dtype.base))
def copyin(self, mv:memoryview):
mv = flat_mv(mv)
assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"

View file

@ -207,7 +207,7 @@ class CapturedJit(Generic[ReturnType]):
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
for old, new in asgn.items():
if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
if old.is_allocated(): new.ensure_allocated().copyin(old.as_memoryview())
self.__post_init__()
# jit exec
@ -219,7 +219,7 @@ class CapturedJit(Generic[ReturnType]):
# copy aliased inputs to prevent read-after-write hazard
for i, ib in enumerate(input_buffers):
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer:
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_buffer())
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

View file

@ -79,7 +79,7 @@ class BufferCopy(Runner):
# fast(ish) path, uses readinto in diskbuffers
src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
else:
dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
dest.copyin(src.as_memoryview(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
def __call__(self, rawbufs:list[Buffer], var_vals:dict[str, int], wait=False):
dest, src = rawbufs[0:2]
assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
@ -199,7 +199,7 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
bufs = [b for b in ei.bufs if b is not None]
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
for cpu_b, gpu_b in zip(nb, bufs):
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_memoryview())
# run on GPU
ei.run(var_vals, do_update_stats=do_update_stats)

View file

@ -316,7 +316,7 @@ class Tensor(OpMixin):
x = self.cast(self.dtype.base).contiguous()
if isinstance(self.device, tuple): x = x.to("CPU")
return cast(Buffer, x.realize().uop.base.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_buffer()
def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview:
"""
@ -329,7 +329,9 @@ class Tensor(OpMixin):
"""
if 0 in self.shape: return memoryview(bytearray(0)).cast(self.dtype.base.fmt)
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return self._buffer().as_typed_buffer(self.shape)
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
def item(self) -> PyConst:
"""