mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix bitwise_not for signed int (#8117)
-1 is correct because 2**32-1 is not within int32 range, so in some case clang casts the whole thing into uint32
This commit is contained in:
parent
e22d7b6fb0
commit
c814de2dd4
2 changed files with 3 additions and 1 deletions
|
|
@ -851,6 +851,8 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
|
||||
# regression test for bitwise_not then argmax
|
||||
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])
|
||||
|
||||
def test_argmin(self):
|
||||
# check if it returns the first index for multiple occurences
|
||||
|
|
|
|||
|
|
@ -3152,7 +3152,7 @@ class Tensor(SimpleMathTrait):
|
|||
```
|
||||
"""
|
||||
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
|
||||
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
|
||||
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
|
||||
|
||||
def lshift(self, x:int):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue