mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
ConstantOfShape ONNX test fixed. (#890)
* ConstantOfShape ONNX test fixed. * removed redundant if statement * value is optional and should default to a float32 tensor with value of 0 * fixed: default parameters are created at function definition, bad for mutable objects.
This commit is contained in:
parent
5feee9c94b
commit
301f7b54c6
2 changed files with 4 additions and 3 deletions
|
|
@ -38,7 +38,7 @@ def get_run_onnx(onnx_model: ModelProto):
|
|||
elif len(inp.int64_data) > 0:
|
||||
ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int32_data) > 0:
|
||||
ret = Tensor(np.array(inp.int32_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
|
||||
else:
|
||||
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -205,9 +205,10 @@ def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(
|
|||
def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
|
||||
def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
|
||||
|
||||
def ConstantOfShape(input, value=0.0):
|
||||
def ConstantOfShape(input, value:Tensor=None):
|
||||
if value is None: value=Tensor([0.0])
|
||||
shape = [int(x) for x in safe_numpy(input)]
|
||||
return Tensor.ones(*shape) * value
|
||||
return Tensor.ones(*shape, dtype=value.dtype) * (value if input.shape !=(0,) else 1)
|
||||
|
||||
# this is obviously wrong, but since we don't have types, it's better than nothing
|
||||
def Cast(input, to):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue