mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add ConstantOfShape
This commit is contained in:
parent
082134952b
commit
9d6b63f043
2 changed files with 8 additions and 2 deletions
|
|
@ -29,12 +29,14 @@ ONNXLIMIT = getenv("ONNXLIMIT", -1)
|
|||
def get_run_onnx(onnx_model):
|
||||
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
|
||||
def buffer_parse(inp):
|
||||
if inp.data_type in (1,10,7):
|
||||
if inp.data_type in (1,10,6,7):
|
||||
# TODO: this is shared with below
|
||||
if len(inp.float_data) > 0:
|
||||
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int64_data) > 0:
|
||||
ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
elif len(inp.int32_data) > 0:
|
||||
ret = Tensor(np.array(inp.int32_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
|
||||
else:
|
||||
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
|
||||
else:
|
||||
|
|
@ -42,7 +44,7 @@ def get_run_onnx(onnx_model):
|
|||
return ret
|
||||
|
||||
def attribute_parse(a):
|
||||
if a.type == 7: return tuple([int(x) for x in a.ints])
|
||||
if a.type in [6,7]: return tuple([int(x) for x in a.ints])
|
||||
elif a.type == 4: return buffer_parse(a.t) # TENSOR
|
||||
elif a.type == 3: return str(a.s)
|
||||
elif a.type == 2: return int(a.i)
|
||||
|
|
|
|||
|
|
@ -152,6 +152,10 @@ def Tile(input, repeats):
|
|||
def Range(start, limit, delta): return Tensor.arange(safe_numpy(limit)[0], safe_numpy(start)[0], safe_numpy(delta)[0])
|
||||
def Where(condition, X, Y): return condition*X + (1-condition)*Y
|
||||
|
||||
def ConstantOfShape(input, value=0.0):
|
||||
shape = [int(x) for x in safe_numpy(input)]
|
||||
return Tensor.ones(*shape) * value
|
||||
|
||||
# NOTE: since we only have one type, this is valid!
|
||||
def CastLike(input, target_type):
|
||||
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue