mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix DiskDevice reuse (#11039)
* fix DiskDevice reuse * fix mypy and DiskDevice.count * mypy * add test --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
parent
5628e2054c
commit
fcbefde8f5
2 changed files with 14 additions and 4 deletions
|
|
@ -1,4 +1,4 @@
|
|||
import pathlib, tempfile, unittest
|
||||
import os, pathlib, tempfile, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
|
|
@ -410,5 +410,13 @@ class TestPathTensor(unittest.TestCase):
|
|||
self.assertEqual(t_cpu.device, "CPU")
|
||||
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
def test_path_tensor_disk_device_bug(self):
|
||||
test_file = pathlib.Path(self.temp_dir.name) / "disk_device_bug"
|
||||
with open(test_file, "wb") as f: f.write(bytes(range(10)))
|
||||
os.chmod(test_file, 0o000)
|
||||
with self.assertRaises(PermissionError):
|
||||
Tensor(pathlib.Path(test_file)).tolist()
|
||||
os.chmod(test_file, 0o644)
|
||||
assert Tensor(pathlib.Path(test_file)).tolist(), list(range(10))
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -16,10 +16,11 @@ class DiskDevice(Compiled):
|
|||
self.fd: Optional[int] = None
|
||||
self.count = 0
|
||||
super().__init__(device, DiskAllocator(self), None, None, None)
|
||||
def _might_open(self, size):
|
||||
self.count += 1
|
||||
def _might_open(self, size:int):
|
||||
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
|
||||
if self.size is not None: return
|
||||
if self.size is not None and hasattr(self.device, "mem"):
|
||||
self.count += 1
|
||||
return
|
||||
filename = self.device[len("disk:"):]
|
||||
self.size = size
|
||||
|
||||
|
|
@ -34,6 +35,7 @@ class DiskDevice(Compiled):
|
|||
self.mem = mmap.mmap(self.fd, self.size)
|
||||
if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
|
||||
with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
|
||||
self.count += 1
|
||||
def _might_close(self):
|
||||
self.count -= 1
|
||||
if self.count == 0:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue