remove allow_shape_mismatch in Tensor.replace (#14536)

move all logic to torch_backend and not hacking Tensor method
This commit is contained in:
chenyu 2026-02-04 12:38:18 -05:00 committed by GitHub
commit 9052db678f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 92 additions and 4 deletions

View file

@ -29,6 +29,10 @@ def wrap(x: Tensor) -> torch.Tensor:
x._strides = strides_for_shape(x.shape) # always recalculate
if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
def _update_torch_metadata(tensor: torch.Tensor, tiny: Tensor) -> None:
tiny._strides = strides_for_shape(tiny.shape)
tiny._storage_offset = calculate_storage_offset(tiny)
mod.update_metadata(tensor, tiny.shape, tiny._strides, tiny._storage_offset)
def unwrap(x:torch.Tensor) -> Tensor:
assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
return mod.unwrap(x)
@ -611,7 +615,10 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.fill_.Tensor": lambda self, value: Tensor.full(self.shape, value.reshape(()).item(), device=self.device, dtype=self.dtype),
"aten.flip": Tensor.flip,
"aten.scatter_reduce.two": Tensor.scatter_reduce,
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True), # TODO: inplace view op, here?
"aten.squeeze_.dim": Tensor.squeeze,
"aten.unsqueeze_": Tensor.unsqueeze,
"aten.transpose_": Tensor.transpose,
"aten.t_": Tensor.transpose,
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
@ -655,6 +662,13 @@ inplace_ops = {
"aten.masked_fill_.Tensor",
}
inplace_view_ops = {
"aten.squeeze_.dim",
"aten.unsqueeze_",
"aten.transpose_",
"aten.t_",
}
def wrap_fxn(k,f):
def nf(*args, **kwargs):
if TORCH_DEBUG:
@ -675,8 +689,42 @@ def wrap_inplace(k,f):
return orig
return nf
def wrap_inplace_view_op(k,f):
def nf(*args, **kwargs):
orig = args[0]
args, kwargs = unwrap_args(args, kwargs)
target = args[0]
new_view = f(*args, **kwargs)
if new_view is target or new_view.uop is target.uop:
_update_torch_metadata(orig, target)
return orig
base = canonical_base(target)
op = (f, args[1:], kwargs)
if target is base:
views = derived_views(base)
if views:
old_base = Tensor(base.uop, device=base.device)
old_base.requires_grad = base.requires_grad
old_base._views = getattr(base, "_views", set())
for v in views: v._view_base = old_base
base._views = set()
base._view_base = old_base
base._view_ops = [op]
old_base._views.add(weakref.ref(base))
else:
target._view_base = base
base._views = getattr(base, "_views", set())
base._views.add(weakref.ref(target))
target._view_ops = _get_view_ops(target) + [op]
target.uop = new_view.uop
_update_torch_metadata(orig, target)
return orig
return nf
for k,v in tiny_backend.items():
wrapper = wrap_inplace if k in inplace_ops else wrap_fxn
if k in inplace_view_ops: wrapper = wrap_inplace_view_op
elif k in inplace_ops: wrapper = wrap_inplace
else: wrapper = wrap_fxn
torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrapper(k,v))
@torch.library.impl("aten::equal", "privateuseone")

View file

@ -67,5 +67,37 @@ class TestTorchBackendInplace(unittest.TestCase):
d += torch.arange(4)
np.testing.assert_array_equal(a.cpu(), torch.arange(4).cpu())
def test_inplace_view_metadata(self):
a = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)
ret = a.squeeze_(0)
self.assertIs(ret, a)
self.assertEqual(a.shape, torch.Size([2, 3]))
ret = a.unsqueeze_(1)
self.assertIs(ret, a)
self.assertEqual(a.shape, torch.Size([2, 1, 3]))
ret = a.transpose_(0, 2)
self.assertIs(ret, a)
self.assertEqual(a.shape, torch.Size([3, 1, 2]))
def test_t_inplace_metadata(self):
a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
ret = a.t_()
self.assertIs(ret, a)
self.assertEqual(a.shape, torch.Size([3, 2]))
expected = torch.arange(6, dtype=torch.float32).reshape(2, 3).t()
np.testing.assert_array_equal(a.cpu().numpy(), expected.cpu().numpy())
def test_squeeze_matmul(self):
# squeeze_ is used internally by PyTorch for vector-matrix matmul (unsqueeze -> mm -> squeeze_)
a = torch.arange(65, dtype=torch.float32)
b = torch.arange(65*45, dtype=torch.float32).reshape(65, 45)
result = a.matmul(b)
self.assertEqual(result.shape, torch.Size([45]))
# verify correctness
a_cpu = torch.arange(65, dtype=torch.float32, device='cpu')
b_cpu = torch.arange(65*45, dtype=torch.float32, device='cpu').reshape(65, 45)
expected = a_cpu.matmul(b_cpu)
np.testing.assert_allclose(result.cpu().numpy(), expected.numpy(), rtol=1e-4, atol=1e-4)
if __name__ == "__main__":
unittest.main()

View file

@ -131,7 +131,15 @@ py::object unwrap_tensor(const at::Tensor &tensor) {
return py::reinterpret_borrow<py::object>(tiny->ptr(getPyInterpreter()));
}
void update_metadata(const at::Tensor &tensor, const std::vector<int64_t> &sizes,
const std::vector<int64_t> &strides, int64_t storage_offset) {
auto* impl = tensor.unsafeGetTensorImpl();
impl->set_allow_tensor_metadata_change(true);
impl->set_sizes_and_strides(sizes, strides, storage_offset);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wrap", &wrap_tensor);
m.def("unwrap", &unwrap_tensor);
m.def("update_metadata", &update_metadata);
}

View file

@ -276,12 +276,12 @@ class Tensor(OpMixin):
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
return self
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
def replace(self, x:Tensor) -> Tensor:
"""
Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
"""
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
assert self.shape == x.shape or allow_shape_mismatch, f"replace shape mismatch {self.shape} != {x.shape}"
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
self.uop = x.uop
return self