update shifts in torch backend (#15622)

This commit is contained in:
chenyu 2026-04-06 14:08:33 -04:00 committed by GitHub
commit 6e30a5f5ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -511,7 +511,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.fmod.Tensor_out": lambda input,other: input-input.div(other, rounding_mode="trunc")*other,
# TODO: this might result in overflow issues
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,
# TODO: support this in tinygrad
# TODO: support this in tinygrad. shift by Tensor not supported
"aten.bitwise_left_shift.Tensor_out": lambda x,y: x*(2**y),
"aten.bitwise_right_shift.Tensor_out": lambda x,y: x//(2**y),
# not in tinygrad. are there decomps for these?
@ -555,11 +555,10 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.remainder.Scalar_Tensor": lambda x,y: x%y,
"aten.floor_divide": lambda x,y: x//y,
"aten.floor_divide_.Tensor": lambda x,y: x//y,
# TODO: use tinygrad methods, but they require x to be unsigned
"aten.__lshift__.Scalar": lambda x,y: x*(2**y),
"aten.__ilshift__.Scalar": lambda x,y: x*(2**y),
"aten.__rshift__.Scalar": lambda x,y: x//(2**y),
"aten.__irshift__.Scalar": lambda x,y: x//(2**y),
"aten.__lshift__.Scalar": lambda x,y: x<<y,
"aten.__ilshift__.Scalar": lambda x,y: x<<y,
"aten.__rshift__.Scalar": lambda x,y: x>>y,
"aten.__irshift__.Scalar": lambda x,y: x>>y,
# inplace ops using replace for fusion
"aten.zero_": lambda x: x.zeros_like(),
"aten.fill_.Scalar": lambda x, y: x.full_like(y),