mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
revert layernorm to have axis param
This commit is contained in:
parent
dc80bf6f85
commit
dec5334da9
1 changed files with 1 additions and 2 deletions
|
|
@ -312,8 +312,7 @@ class Tensor:
|
|||
|
||||
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
|
||||
|
||||
def layernorm(self, eps=1e-5):
|
||||
axis = range(1, len(self.shape))
|
||||
def layernorm(self, axis=-1, eps=1e-5):
|
||||
y = (self - self.mean(axis=axis, keepdim=True))
|
||||
return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt())
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue