mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
use truncate in onnx read_int64 [pr] (#11720)
This commit is contained in:
parent
50e789e290
commit
b67345caa3
1 changed files with 2 additions and 4 deletions
|
|
@ -5,7 +5,7 @@ from io import BufferedReader
|
|||
from tinygrad.nn.state import TensorIO
|
||||
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN
|
||||
from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype
|
||||
from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate
|
||||
from tinygrad.device import is_dtype_supported, Device
|
||||
|
||||
# ***** protobuf definitions ******
|
||||
|
|
@ -105,9 +105,7 @@ class PBBufferedReader(BufferedReader):
|
|||
def read_bytes(self) -> Tensor: return self.read_delimited(use_tensor=True)
|
||||
def read_float(self) -> float: return struct.unpack("<f", self.read(4))[0]
|
||||
def read_packed_floats(self) -> Tensor: return self.read_delimited(use_tensor=True)
|
||||
def read_int64(self) -> int:
|
||||
val = self.decode_varint()
|
||||
return val - 2**64 if val & (1 << 63) else val
|
||||
def read_int64(self) -> int: return truncate[dtypes.int64](self.decode_varint())
|
||||
def read_packed_int64s(self) -> list[int]:
|
||||
total_bytes_len = self.decode_varint()
|
||||
old_pos = self.tell()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue