mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
temp fix for symbolic shape view add [pr] (#8337)
something is still wrong with symbolic shape shrink, but it should not recurse forever
This commit is contained in:
parent
791a80a1c7
commit
2bf47b75da
2 changed files with 9 additions and 2 deletions
|
|
@ -31,7 +31,13 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_merge_view_recursion_err(self):
|
||||
vm2 = View(shape=(Variable('j', 1, 10),), strides=(0,), offset=0, mask=None, contiguous=False)
|
||||
vm1 = View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True)
|
||||
vm2.__add__(vm1)
|
||||
self.assertEqual(vm2+vm1, vm1)
|
||||
|
||||
def test_merge_view_recursion_err2(self):
|
||||
vm2 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(0,), offset=0, mask=None, contiguous=False)
|
||||
vm1 = View(shape=(Variable('a', 1, 10).bind(4),), strides=(1,), offset=0, mask=((0, Variable('a', 1, 10).bind(4)),), contiguous=False)
|
||||
# TODO: this should not be None?
|
||||
self.assertEqual(vm2+vm1, None)
|
||||
|
||||
def test_cat_dim0_strides(self):
|
||||
i = Variable("i", 1, 5).bind(3)
|
||||
|
|
|
|||
|
|
@ -161,7 +161,8 @@ class View:
|
|||
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
|
||||
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
|
||||
if vm1.mask:
|
||||
if (merged := vm2 + vm1.shrink(vm1.mask)) is None: return None
|
||||
# TODO: why is shrink no changing the view for symbolic shape
|
||||
if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
|
||||
return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
|
||||
if not all_int(vm1.shape): return None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue