fix passing sum_acc_dtype="" to Tensor.sum should fail (#7748)

This commit is contained in:
chenyu 2024-11-17 10:58:41 -05:00 committed by GitHub
commit 55707fd00d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 1 deletions

View file

@ -1011,6 +1011,12 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
def test_sum_acc_dtype(self):
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(acc_dtype=dtypes.float32))
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(acc_dtype=dtypes.float64))
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(acc_dtype="")
def test_sum_with_zeros_shape(self):
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)))

View file

@ -1526,7 +1526,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.sum(axis=1).numpy())
```
"""
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):