Remove the toCPU copy (#2445)

* Remove the rawbuffer copy in runtime/lib.py on line 44

* remove buffer view

* added metadata back, oops

* delayed cpu testcase

* whitespace

* whitespace

* buffer behavior as is

* Update test_jit.py
This commit is contained in:
qtkite 2023-11-28 15:37:13 +11:00 committed by GitHub
commit cb507a9389
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 6 deletions

View file

@ -240,5 +240,18 @@ class TestJit(unittest.TestCase):
assert len(cache.good_jitted.jit_cache) == 1
assert len(cache.bad_jitted.jit_cache) == 1
def test_jit_buffer_behavior(self):
@TinyJit
def foo(x) -> Tensor: return x.sum().realize()
result_1 = foo(Tensor([1] * 2))
result_2 = foo(Tensor([2] * 2))
result_3 = foo(Tensor([3] * 2))
# expect the buffer to share underlying buffer
np.testing.assert_allclose(result_1.numpy(), [2], atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(result_2.numpy(), [6], atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(result_3.numpy(), [6], atol=1e-4, rtol=1e-5)
if __name__ == '__main__':
unittest.main()
unittest.main()

View file

@ -40,9 +40,8 @@ class RawBufferCopyIn(RawBuffer):
class RawBufferMapped(RawBufferCopyIn):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
def toCPU(self) -> np.ndarray: return self.buffer_view().astype(self.dtype.np, copy=True) # Need a copy (for now), since jit will write to the same buffer.
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.buffer_view(), x.reshape(-1))
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size)
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
@classmethod
def fromBuffer(cls, src, shape, dtype, **kwargs):

View file

@ -123,7 +123,7 @@ class MetalGraph:
icb_command.setBarrier()
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
self.command_buffer: Any = None
self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to
self.int_buf_view = self.int_buf.toCPU() # TODO: this is metal syncing when it doesn't need to
def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running

View file

@ -123,7 +123,7 @@ class Tensor:
def numpy(self) -> np.ndarray:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().astype(self.dtype.np, copy=True).reshape(self.shape)
def item(self) -> Union[float, int]: return self.numpy().item()
def to(self, device:Optional[str]) -> Tensor: