fix bitwise_not for signed int (#8117)

-1 is correct because 2**32-1 is not within int32 range, so in some case clang casts the whole thing into uint32
This commit is contained in:
chenyu 2024-12-09 02:02:51 -05:00 committed by GitHub
commit c814de2dd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View file

@ -851,6 +851,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
# regression test for bitwise_not then argmax
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])
def test_argmin(self):
# check if it returns the first index for multiple occurences

View file

@ -3152,7 +3152,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
def lshift(self, x:int):
"""