mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
handle float16 overflow in PYTHON (#5022)
* handle float16 overflow in PYTHON

  Use `truncate` when constructing a tensor from a list to make sure all values are packable (might be slow, but should be correct). Add `truncate_fp16` to cast overflowed values to inf/-inf.

* all valid fmt supports truncate
This commit is contained in:
parent
c0139b05d8
commit
03b367c014
5 changed files with 34 additions and 8 deletions
|
|
@ -7,7 +7,7 @@ from tinygrad.dtype import DType, dtypes, ImageDType
|
|||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Compiler, Allocator
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
|
||||
|
||||
|
|
@ -110,6 +110,8 @@ class PythonProgram:
|
|||
if dtypes.is_int(dtype):
|
||||
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
|
||||
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
|
||||
elif dtypes.is_float(dtype):
|
||||
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
|
||||
elif uop is UOps.LOAD:
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue