mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix AffineGrid fusion (#16439)
This commit is contained in:
parent
ef50a49693
commit
8ac62b28e5
1 changed files with 1 additions and 1 deletions
|
|
@ -1000,7 +1000,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
|||
if align_corners: return Tensor.linspace(-1, 1, steps, device=theta.device)
|
||||
return Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
|
||||
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
|
||||
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
|
||||
base_grid = Tensor.stack(*reversed(grids), grids[0].const_like(1), dim=-1)
|
||||
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
|
||||
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue