mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
1884f021e3
commit
08706c2ea4
1 changed files with 3 additions and 2 deletions
|
|
@ -502,7 +502,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|||
else: had_counter = True
|
||||
|
||||
# if shape has 0, return zero tensor
|
||||
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
num = ceildiv(numel * dtype.itemsize, 4)
|
||||
|
||||
# increment rng counter for devices
|
||||
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
|
||||
|
|
@ -520,7 +521,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|||
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
|
||||
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
|
||||
# bitcast back to the original dtype and reshape
|
||||
out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape)
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
|
||||
|
||||
# move back to the original device if we were using MOCKGPU
|
||||
if getenv("MOCKGPU") and _device: out = out.to(_device)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue