fix AffineGrid fusion (#16439)

This commit is contained in:
chenyu 2026-05-29 17:59:47 -04:00 committed by GitHub
commit 8ac62b28e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1000,7 +1000,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device)
return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
base_grid = Tensor.stack(*reversed(grids), grids[0].const_like(1), dim=-1)
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)