mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
oversized expand for HLOP convs
This commit is contained in:
parent
2edfe64512
commit
f8f026e8bb
3 changed files with 10 additions and 2 deletions
|
|
@ -91,6 +91,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
|
||||
def test_softplus(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
|
||||
@unittest.skip("not supported in older pytorch")
|
||||
def test_gelu(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
|
||||
def test_quick_gelu(self):
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ class TestSymbolic(unittest.TestCase):
|
|||
@unittest.skip("mod max is wrong")
|
||||
def test_mod_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "(((a*100)+(b*50))%100)")
|
||||
|
||||
@unittest.skip("this doesn't work yet")
|
||||
def test_mod_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 6)*10)%9, 0, 6, "a")
|
||||
|
||||
def test_sum_0(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import functools, itertools
|
||||
import math, functools, itertools
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
|
||||
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
|
||||
|
|
@ -284,7 +284,10 @@ class Tensor:
|
|||
oy = (iy - dy * (ky-1) - 1)//sy + 1
|
||||
ox = (ix - dx * (kx-1) - 1)//sx + 1
|
||||
# duplicate the inputs for each of the kernels
|
||||
xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
|
||||
#xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
|
||||
# NOTE: oversizing the expand (ey, ex copies instead of ky, kx) keeps the slice below within bounds, which avoids the ZeroView creation. Remove this once the optimizer can fix it on its own.
|
||||
ey, ex = math.ceil(ky*(iy+dy) / iy), math.ceil(kx*(ix+dx) / ix)
|
||||
xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ey, iy, ex, ix).reshape(bs, c, ey*iy, ex*ix)
|
||||
# slide by dilation
|
||||
xup = xup.slice(((0,bs), (0,c), (0,ky*(iy+dy)), (0,kx*(ix+dx))))
|
||||
xup = xup.reshape(bs, c, ky, iy+dy, kx, ix+dx)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue