fix caformer onnx run (#16222)

This commit is contained in:
George Hotz 2026-05-15 15:08:36 -07:00 committed by GitHub
commit 2549b14ec2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View file

@ -21,6 +21,8 @@ def compile(onnx_file):
# TODO this seems dumb
input_types = {k:(dtypes.float32 if v is dtypes.float16 else v) for k,v in input_types.items()}
Tensor.manual_seed(100)
# replace symbolic dimensions (e.g. 'b' for dynamic batch) with 1
input_shapes = {k:tuple(s if isinstance(s, int) else 1 for s in shp) for k,shp in input_shapes.items()}
inputs = {k:Tensor(Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize().numpy(), device='NPY') for k,shp in sorted(input_shapes.items())}
if not getenv("NPY_IMG"):
inputs = {k:Tensor(v.numpy(), device=Device.DEFAULT).realize() if 'img' in k else v for k,v in inputs.items()}

View file

@ -908,12 +908,13 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2))
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon)
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor|None=None, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
assert stash_type == 1, "only float32 is supported"
axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = (x32:=x.cast(dtypes.float)).mean(axis=axes, keepdim=True)
inv_std_dev = (x32.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
return (x32.sub(mean)*inv_std_dev).cast(x.dtype).mul(scale).add(bias), mean, inv_std_dev
ret = (x32.sub(mean)*inv_std_dev).cast(x.dtype).mul(scale)
return (ret.add(bias) if bias is not None else ret), mean, inv_std_dev
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
x = x + skip
if bias is not None: x = x + bias