handle float16 overflow in PYTHON (#5022)

* handle float16 overflow in PYTHON

Use `truncate` when constructing a tensor from a list to make sure all values are packable (might be slow, but should be correct). Add `truncate_fp16` to cast overflowed values to inf/-inf.

* all valid fmts support truncate
This commit is contained in:
chenyu 2024-06-17 21:12:52 -04:00 committed by GitHub
commit 03b367c014
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 34 additions and 8 deletions

View file

@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
@@ -110,6 +110,8 @@ class PythonProgram:
if dtypes.is_int(dtype):
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
elif dtypes.is_float(dtype):
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
elif uop is UOps.LOAD:
if isinstance(dtp[0], ImageDType):