fix tinyfs (#14974)

* fix tinyfs

* fix that
This commit is contained in:
chenyu 2026-02-24 08:50:53 -05:00 committed by GitHub
commit 5fd4fc0c6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 9 additions and 16 deletions

View file

@ -520,9 +520,8 @@ jobs:
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run full CIFAR training steps w 6 GPUS
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
# this needs to be mocked and testable on a local machine
#- name: Test full tinyfs load
# run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
- name: Test full tinyfs load
run: TINYFS_ENDPOINT=10.0.52.11:6767 PYTHONPATH=. python extra/tinyfs/fetch_file.py --hash d734f5e3be9f1e9d863bfaa4fc6c1ef2 --len 175866113 --dest mapping.json --check
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View file

@ -45,37 +45,31 @@ class TestTinyFS(unittest.TestCase):
cls._server.shutdown()
cls._server.server_close()
@unittest.expectedFailure
def test_store(self):
h = Tensor([1.0, 2.0, 3.0, 4.0]).fs_store().realize()
self.assertEqual(h.shape, (16,))
self.assertEqual(h.dtype, dtypes.uint8)
@unittest.expectedFailure
def test_store_deterministic(self):
a = Tensor([1.0, 2.0, 3.0, 4.0]).fs_store().realize()
b = Tensor([1.0, 2.0, 3.0, 4.0]).fs_store().realize()
np.testing.assert_array_equal(a.numpy(), b.numpy())
@unittest.expectedFailure
def test_store_different_data(self):
a = Tensor([1.0, 2.0, 3.0, 4.0]).fs_store().realize()
b = Tensor([5.0, 6.0, 7.0, 8.0]).fs_store().realize()
self.assertNotEqual(a.tolist(), b.tolist())
@unittest.expectedFailure
def test_roundtrip_uint8(self):
arr = np.arange(256, dtype=np.uint8)
loaded = Tensor(arr).fs_store().realize().fs_load(len(arr))
loaded = Tensor(arr).fs_store().realize().fs_load(len(arr)).to("CPU")
np.testing.assert_array_equal(loaded.numpy(), arr)
@unittest.expectedFailure
def test_roundtrip_multichunk_uint8(self):
arr = np.random.default_rng(42).integers(0, 256, size=Tensor.CHUNK_SIZE + 1024, dtype=np.uint8)
loaded = Tensor(arr).fs_store().realize().fs_load(len(arr))
loaded = Tensor(arr).fs_store().realize().fs_load(len(arr)).to("CPU")
np.testing.assert_array_equal(loaded.numpy(), arr)
@unittest.expectedFailure
def test_hash_matches_python_impl(self):
arr = np.arange(256, dtype=np.uint8)
h = Tensor(arr).fs_store().realize()

View file

@ -18,10 +18,10 @@ def tag_uop(ctx:AllocCtx, x:UOp):
def disk_copy_is_buffer(ctx:AllocCtx, u:UOp):
# copies to disk are replaced with the disk buffer
to_disk = isinstance(u._device, str) and u._device.startswith("DISK")
to_disk = isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS"))
if to_disk: ctx.buffer_map[u] = UOp.new_buffer(u.device, u.shard_size, u.dtype).reshape(u.max_shard_shape)
# all copies from disk/numpy are realized into a real buffer
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK", "PYTHON"])
from_creation = isinstance(u.src[0]._device, str) and any(u.src[0]._device.startswith(x) for x in ["NPY", "DISK", "PYTHON", "TINYFS"])
if from_creation: return tag_uop(ctx, u)
def apply_after(ctx:AllocCtx, u:UOp):
@ -41,8 +41,8 @@ add_tags = PatternMatcher([
def replace_contig_with_assign(u:UOp):
# if size is 0, remove the contig
if u.size == 0: return u.src[0]
# no real contig for DISK tensors, they are left alone
if isinstance(u._device, str) and u._device.startswith("DISK"): return u.rtag(None)
# no real contig for DISK/TINYFS tensors, they are left alone
if isinstance(u._device, str) and u._device.startswith(("DISK", "TINYFS")): return u.rtag(None)
dtype = u.dtype
if isinstance(dtype, ImageDType):
if prod(dtype.shape) != prod(u.max_shard_shape) or ([x for x in u.max_shard_shape if x != 1] or [1])[-1] % 4 != 0:
@ -113,7 +113,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
pm_finalize_call = PatternMatcher([
(UPat(Ops.ASSIGN, name="x"), untag_and_append),
(UPat(Ops.AFTER, name="x"), append_after),
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
# replace UNIQUE with LUNIQUE for CONST cache key normalization
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
])