mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix passing sum_acc_dtype="" to Tensor.sum should fail (#7748)
This commit is contained in:
parent
f18296e23c
commit
55707fd00d
2 changed files with 7 additions and 1 deletions
|
|
@ -1011,6 +1011,12 @@ class TestOps(unittest.TestCase):
|
|||
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
|
||||
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
|
||||
|
||||
def test_sum_acc_dtype(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(acc_dtype=dtypes.float32))
|
||||
if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(acc_dtype=dtypes.float64))
|
||||
|
||||
with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(acc_dtype="")
|
||||
|
||||
def test_sum_with_zeros_shape(self):
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)))
|
||||
|
|
|
|||
|
|
@ -1526,7 +1526,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|||
print(t.sum(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
|
||||
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
|
||||
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
|
||||
|
||||
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue