type annotation for layernorm (#6883)

This commit is contained in:
chenyu 2024-10-04 09:03:56 -04:00 committed by GitHub
commit 4c3895744e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3042,7 +3042,7 @@ class Tensor:
"""
return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
def layernorm(self, axis:Union[int,Tuple[int,...]]=-1, eps:float=1e-5) -> Tensor:
"""
Applies Layer Normalization over a mini-batch of inputs.