Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
d30e3b3cad fix dim resolution order in split/chunk
Ensure dim_sz is retrieved after dim is resolved, not before.
The previous one-liner evaluated self.shape[dim] with the original
unresolved dim value.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-16 12:37:41 -04:00
George Hotz
41e8a3f2b2 update CLAUDE.md: add pre-commit and no-amend rules
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-16 12:30:00 -04:00
George Hotz
3fafc4d670 support symbolic shapes in split/chunk when split dim is concrete
Previously split() and chunk() required all dimensions to be concrete.
Now they only require the dimension being split to be concrete, allowing
them to work with tensors that have symbolic shapes in other dimensions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-16 12:28:24 -04:00
3 changed files with 41 additions and 6 deletions

View file

@ -95,6 +95,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
## Workflow Rules
- **NEVER commit without explicit user approval** - always show the diff and wait for approval
- **NEVER amend commits** - always create a new commit instead
- Run `pre-commit run --all-files` before committing to catch linting/type errors
- Run tests before proposing commits
- Test with `SPEC=2` when modifying UOp-related code

View file

@ -95,6 +95,37 @@ class TestTensorVariable(unittest.TestCase):
assert t.uop.base.buffer.size == 30
assert t.uop.shape == (3, vb)
def test_symbolic_chunk(self):
# chunk should work when split dimension is concrete, even if other dims are symbolic
vv = Variable("a", 1, 10).bind(4)
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
chunks = t.chunk(2, dim=-1) # split along concrete dim 8
assert len(chunks) == 2
assert chunks[0].shape[1] == 4
assert chunks[1].shape[1] == 4
# verify the values by shrinking to concrete shape first
np.testing.assert_equal(chunks[0].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
np.testing.assert_equal(chunks[1].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
def test_symbolic_split(self):
# split should work when split dimension is concrete, even if other dims are symbolic
vv = Variable("a", 1, 10).bind(3)
t = Tensor.arange(30).reshape(10, 3).contiguous()[:, :vv] # shape (10, vv)
splits = t.split(5, dim=0) # split along concrete dim 10
assert len(splits) == 2
assert splits[0].shape[0] == 5
assert splits[1].shape[0] == 5
# verify the values by shrinking to concrete shape first
np.testing.assert_equal(splits[0].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[:5, :3])
np.testing.assert_equal(splits[1].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[5:, :3])
def test_symbolic_chunk_error_on_symbolic_dim(self):
# chunk should fail when trying to split along a symbolic dimension
vv = Variable("a", 1, 10).bind(4)
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
with self.assertRaises(AssertionError):
t.chunk(2, dim=0) # can't split along symbolic dim
if __name__ == '__main__':
unittest.main()

View file

@ -1334,10 +1334,11 @@ class Tensor(OpMixin):
print("\\n".join([repr(x.numpy()) for x in split]))
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim = self._resolve_dim(dim)
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
dim_sz = self.shape[dim]
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
if isinstance(sizes, int): sizes = [min(sizes, dim_sz-i) for i in range(0, max(1, dim_sz), max(1, sizes))]
assert sum(sizes) == dim_sz, f"expect sizes to sum exactly to {dim_sz}, but got {sum(sizes)}"
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
@ -1359,10 +1360,11 @@ class Tensor(OpMixin):
print("\\n".join([repr(x.numpy()) for x in chunked]))
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
dim = self._resolve_dim(dim)
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
dim_sz = self.shape[dim]
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
return list(self.split(ceildiv(dim_sz, chunks) if dim_sz else [0]*chunks, dim=dim))
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
"""