add ConstantOfShape

This commit is contained in:
George Hotz 2023-02-27 10:57:50 -08:00
commit 9d6b63f043
2 changed files with 8 additions and 2 deletions

View file

@ -29,12 +29,14 @@ ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model):
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
def buffer_parse(inp):
if inp.data_type in (1,10,7):
if inp.data_type in (1,10,6,7):
# TODO: this is shared with below
if len(inp.float_data) > 0:
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int32_data) > 0:
ret = Tensor(np.array(inp.int32_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
else:
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
else:
@ -42,7 +44,7 @@ def get_run_onnx(onnx_model):
return ret
def attribute_parse(a):
if a.type == 7: return tuple([int(x) for x in a.ints])
if a.type in [6,7]: return tuple([int(x) for x in a.ints])
elif a.type == 4: return buffer_parse(a.t) # TENSOR
elif a.type == 3: return str(a.s)
elif a.type == 2: return int(a.i)

View file

@ -152,6 +152,10 @@ def Tile(input, repeats):
def Range(start, limit, delta): return Tensor.arange(safe_numpy(limit)[0], safe_numpy(start)[0], safe_numpy(delta)[0])
def Where(condition, X, Y): return condition*X + (1-condition)*Y
def ConstantOfShape(input, value=0.0):
shape = [int(x) for x in safe_numpy(input)]
return Tensor.ones(*shape) * value
# NOTE: since we only have one type, this is valid!
def CastLike(input, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"