mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
symbolic-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d30e3b3cad | ||
|
|
41e8a3f2b2 | ||
|
|
3fafc4d670 |
3 changed files with 41 additions and 6 deletions
|
|
@ -95,6 +95,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
|
|||
## Workflow Rules
|
||||
|
||||
- **NEVER commit without explicit user approval** - always show the diff and wait for approval
|
||||
- **NEVER amend commits** - always create a new commit instead
|
||||
- Run `pre-commit run --all-files` before committing to catch linting/type errors
|
||||
- Run tests before proposing commits
|
||||
- Test with `SPEC=2` when modifying UOp-related code
|
||||
|
||||
|
|
|
|||
|
|
@ -95,6 +95,37 @@ class TestTensorVariable(unittest.TestCase):
|
|||
assert t.uop.base.buffer.size == 30
|
||||
assert t.uop.shape == (3, vb)
|
||||
|
||||
def test_symbolic_chunk(self):
|
||||
# chunk should work when split dimension is concrete, even if other dims are symbolic
|
||||
vv = Variable("a", 1, 10).bind(4)
|
||||
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
|
||||
chunks = t.chunk(2, dim=-1) # split along concrete dim 8
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].shape[1] == 4
|
||||
assert chunks[1].shape[1] == 4
|
||||
# verify the values by shrinking to concrete shape first
|
||||
np.testing.assert_equal(chunks[0].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
|
||||
np.testing.assert_equal(chunks[1].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
|
||||
|
||||
def test_symbolic_split(self):
|
||||
# split should work when split dimension is concrete, even if other dims are symbolic
|
||||
vv = Variable("a", 1, 10).bind(3)
|
||||
t = Tensor.arange(30).reshape(10, 3).contiguous()[:, :vv] # shape (10, vv)
|
||||
splits = t.split(5, dim=0) # split along concrete dim 10
|
||||
assert len(splits) == 2
|
||||
assert splits[0].shape[0] == 5
|
||||
assert splits[1].shape[0] == 5
|
||||
# verify the values by shrinking to concrete shape first
|
||||
np.testing.assert_equal(splits[0].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[:5, :3])
|
||||
np.testing.assert_equal(splits[1].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[5:, :3])
|
||||
|
||||
def test_symbolic_chunk_error_on_symbolic_dim(self):
|
||||
# chunk should fail when trying to split along a symbolic dimension
|
||||
vv = Variable("a", 1, 10).bind(4)
|
||||
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
|
||||
with self.assertRaises(AssertionError):
|
||||
t.chunk(2, dim=0) # can't split along symbolic dim
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1334,10 +1334,11 @@ class Tensor(OpMixin):
|
|||
print("\\n".join([repr(x.numpy()) for x in split]))
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
dim = self._resolve_dim(dim)
|
||||
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
|
||||
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
||||
dim_sz = self.shape[dim]
|
||||
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
|
||||
if isinstance(sizes, int): sizes = [min(sizes, dim_sz-i) for i in range(0, max(1, dim_sz), max(1, sizes))]
|
||||
assert sum(sizes) == dim_sz, f"expect sizes to sum exactly to {dim_sz}, but got {sum(sizes)}"
|
||||
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
||||
|
||||
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
|
||||
|
|
@ -1359,10 +1360,11 @@ class Tensor(OpMixin):
|
|||
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
||||
dim = self._resolve_dim(dim)
|
||||
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
||||
dim_sz = self.shape[dim]
|
||||
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
|
||||
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
||||
return list(self.split(ceildiv(dim_sz, chunks) if dim_sz else [0]*chunks, dim=dim))
|
||||
|
||||
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue