mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove allow_shape_mismatch in Tensor.replace (#14536)
move all logic to torch_backend and not hacking Tensor method
This commit is contained in:
parent
ec2b6bbda8
commit
9052db678f
4 changed files with 92 additions and 4 deletions
|
|
@ -29,6 +29,10 @@ def wrap(x: Tensor) -> torch.Tensor:
|
|||
x._strides = strides_for_shape(x.shape) # always recalculate
|
||||
if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
|
||||
return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
|
||||
def _update_torch_metadata(tensor: torch.Tensor, tiny: Tensor) -> None:
  """
  Recompute the strides and storage offset of `tiny`, cache them on it, and push them together
  with its shape onto the torch-side `tensor` through the extension's `update_metadata`.
  """
  # strides are stored on `tiny` before the offset is computed, keeping the original ordering
  tiny._strides = strides_for_shape(tiny.shape)
  tiny._storage_offset = calculate_storage_offset(tiny)
  shape, strides, offset = tiny.shape, tiny._strides, tiny._storage_offset
  mod.update_metadata(tensor, shape, strides, offset)
|
||||
def unwrap(x:torch.Tensor) -> Tensor:
  """
  Return what the extension's `unwrap` yields for the torch tensor `x` (annotated as a tinygrad Tensor).

  Raises AssertionError when `x` is not a `torch.Tensor`.
  """
  # the message names the expected type and reports the type that was actually received
  assert isinstance(x, torch.Tensor), f"x isn't a torch.Tensor, got {type(x)}"
  return mod.unwrap(x)
|
||||
|
|
@ -611,7 +615,10 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
|||
"aten.fill_.Tensor": lambda self, value: Tensor.full(self.shape, value.reshape(()).item(), device=self.device, dtype=self.dtype),
|
||||
"aten.flip": Tensor.flip,
|
||||
"aten.scatter_reduce.two": Tensor.scatter_reduce,
|
||||
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), # TODO: inplace view op, here?
|
||||
"aten.squeeze_.dim": Tensor.squeeze,
|
||||
"aten.unsqueeze_": Tensor.unsqueeze,
|
||||
"aten.transpose_": Tensor.transpose,
|
||||
"aten.t_": Tensor.transpose,
|
||||
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
|
||||
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
|
||||
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
|
||||
|
|
@ -655,6 +662,13 @@ inplace_ops = {
|
|||
"aten.masked_fill_.Tensor",
|
||||
}
|
||||
|
||||
# in-place view ops: at registration time these are routed to wrap_inplace_view_op
inplace_view_ops = {"aten.squeeze_.dim", "aten.unsqueeze_", "aten.transpose_", "aten.t_"}
|
||||
|
||||
def wrap_fxn(k,f):
|
||||
def nf(*args, **kwargs):
|
||||
if TORCH_DEBUG:
|
||||
|
|
@ -675,8 +689,42 @@ def wrap_inplace(k,f):
|
|||
return orig
|
||||
return nf
|
||||
|
||||
def wrap_inplace_view_op(k,f):
  """
  Build the dispatcher wrapper for an in-place view op.

  `f` is the tinygrad implementation: it is called with the unwrapped arguments and the uop of its
  result is installed on the unwrapped first argument. `k` (the op name) is accepted but not used
  in this body. The returned callable always returns the first argument it was called with.
  """
  def nf(*args, **kwargs):
    # keep the caller's first argument: its torch-side metadata is refreshed and it is the return value
    orig = args[0]
    args, kwargs = unwrap_args(args, kwargs)
    target = args[0]
    new_view = f(*args, **kwargs)
    # f handed back the same tensor (or the same uop): nothing to rewire, only refresh the metadata
    if new_view is target or new_view.uop is target.uop:
      _update_torch_metadata(orig, target)
      return orig
    # NOTE(review): canonical_base is defined elsewhere; presumably the root tensor that target views - confirm
    base = canonical_base(target)
    # record of this op: the function, its non-self positional args, and its kwargs
    op = (f, args[1:], kwargs)
    if target is base:
      views = derived_views(base)
      if views:
        # target is itself a base with views: a fresh Tensor on the same uop/device takes over as base
        old_base = Tensor(base.uop, device=base.device)
        old_base.requires_grad = base.requires_grad
        old_base._views = getattr(base, "_views", set())
        # existing views are re-pointed at the stand-in base
        for v in views: v._view_base = old_base
        # the former base starts a new, empty view set and becomes a view of old_base with this single op
        base._views = set()
        base._view_base = old_base
        base._view_ops = [op]
        old_base._views.add(weakref.ref(base))
    else:
      # target is not the base: register it (by weak reference) on the base and append this op to its op list
      target._view_base = base
      base._views = getattr(base, "_views", set())
      base._views.add(weakref.ref(target))
      target._view_ops = _get_view_ops(target) + [op]
    # install the new view's uop on target, then sync shape/strides/offset to the torch tensor
    target.uop = new_view.uop
    _update_torch_metadata(orig, target)
    return orig
  return nf
|
||||
|
||||
# Register every implementation in tiny_backend under the "privateuseone" dispatch key.
for k,v in tiny_backend.items():
  # in-place view ops are checked first, then the other in-place ops; everything else gets the plain wrapper
  if k in inplace_view_ops: wrapper = wrap_inplace_view_op
  elif k in inplace_ops: wrapper = wrap_inplace
  else: wrapper = wrap_fxn
  # keys are written "aten.<op>"; the registration call is given the "aten::<op>" spelling
  torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrapper(k,v))
|
||||
|
||||
@torch.library.impl("aten::equal", "privateuseone")
|
||||
|
|
|
|||
|
|
@ -67,5 +67,37 @@ class TestTorchBackendInplace(unittest.TestCase):
|
|||
d += torch.arange(4)
|
||||
np.testing.assert_array_equal(a.cpu(), torch.arange(4).cpu())
|
||||
|
||||
def test_inplace_view_metadata(self):
  # each in-place view op must hand back the very same tensor object, with its shape updated
  a = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)
  steps = (
    (a.squeeze_, (0,), [2, 3]),
    (a.unsqueeze_, (1,), [2, 1, 3]),
    (a.transpose_, (0, 2), [3, 1, 2]),
  )
  for op, op_args, want_shape in steps:
    ret = op(*op_args)
    self.assertIs(ret, a)
    self.assertEqual(a.shape, torch.Size(want_shape))
|
||||
|
||||
def test_t_inplace_metadata(self):
  # t_() must return the same object, report the transposed shape, and hold the transposed values
  def fresh(): return torch.arange(6, dtype=torch.float32).reshape(2, 3)
  a = fresh()
  ret = a.t_()
  self.assertIs(ret, a)
  self.assertEqual(a.shape, torch.Size([3, 2]))
  expected = fresh().t()
  np.testing.assert_array_equal(a.cpu().numpy(), expected.cpu().numpy())
|
||||
|
||||
def test_squeeze_matmul(self):
  # squeeze_ is used internally by PyTorch for vector-matrix matmul (unsqueeze -> mm -> squeeze_)
  rows, cols = 65, 45
  vec = torch.arange(rows, dtype=torch.float32)
  mat = torch.arange(rows*cols, dtype=torch.float32).reshape(rows, cols)
  result = vec.matmul(mat)
  self.assertEqual(result.shape, torch.Size([cols]))
  # reference result computed with the same inputs created on the cpu device
  vec_cpu = torch.arange(rows, dtype=torch.float32, device='cpu')
  mat_cpu = torch.arange(rows*cols, dtype=torch.float32, device='cpu').reshape(rows, cols)
  expected = vec_cpu.matmul(mat_cpu)
  np.testing.assert_allclose(result.cpu().numpy(), expected.numpy(), rtol=1e-4, atol=1e-4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -131,7 +131,15 @@ py::object unwrap_tensor(const at::Tensor &tensor) {
|
|||
return py::reinterpret_borrow<py::object>(tiny->ptr(getPyInterpreter()));
|
||||
}
|
||||
|
||||
// Overwrite the sizes, strides and storage offset stored on `tensor`'s TensorImpl.
// Metadata changes are explicitly enabled on the impl before the new values are set.
// Bound to Python under the name "update_metadata" in the module definition.
void update_metadata(const at::Tensor &tensor, const std::vector<int64_t> &sizes,
                     const std::vector<int64_t> &strides, int64_t storage_offset) {
  auto* impl = tensor.unsafeGetTensorImpl();
  impl->set_allow_tensor_metadata_change(true);
  impl->set_sizes_and_strides(sizes, strides, storage_offset);
}
|
||||
|
||||
// Python bindings for this extension: each m.def exposes one C++ function under the given name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("wrap", &wrap_tensor);
  m.def("unwrap", &unwrap_tensor);
  m.def("update_metadata", &update_metadata);
}
|
||||
|
|
|
|||
|
|
@ -276,12 +276,12 @@ class Tensor(OpMixin):
|
|||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
return self
|
||||
|
||||
def replace(self, x:Tensor) -> Tensor:
  """
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.

  The tensor takes over `x.uop` and `self` is returned. A shape mismatch fails the assertion.
  """
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
  self.uop = x.uop
  return self
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue