oversized expand for HLOP convs

This commit is contained in:
George Hotz 2023-02-24 21:48:47 -08:00
commit f8f026e8bb
3 changed files with 10 additions and 2 deletions

View file

@@ -91,6 +91,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
def test_softplus(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
@unittest.skip("not supported in older pytorch")
def test_gelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
def test_quick_gelu(self):

View file

@@ -44,6 +44,10 @@ class TestSymbolic(unittest.TestCase):
@unittest.skip("mod max is wrong")
def test_mod_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "(((a*100)+(b*50))%100)")
@unittest.skip("this doesn't work yet")
def test_mod_mul(self):
self.helper_test_variable((Variable("a", 0, 6)*10)%9, 0, 6, "a")
def test_sum_0(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")

View file

@@ -1,6 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import functools, itertools
import math, functools, itertools
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG
@@ -284,7 +284,10 @@ class Tensor:
oy = (iy - dy * (ky-1) - 1)//sy + 1
ox = (ix - dx * (kx-1) - 1)//sx + 1
# duplicate the inputs for each of the kernels
xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
#xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ky, iy, kx, ix).reshape(bs, c, ky*iy, kx*ix)
# NOTE: if you oversize this, you can avoid the ZeroView creation. remove when optimizer can fix
ey, ex = math.ceil(ky*(iy+dy) / iy), math.ceil(kx*(ix+dx) / ix)
xup = self.reshape(bs, c, 1, iy, 1, ix).expand(bs, c, ey, iy, ex, ix).reshape(bs, c, ey*iy, ex*ix)
# slide by dilation
xup = xup.slice(((0,bs), (0,c), (0,ky*(iy+dy)), (0,kx*(ix+dx))))
xup = xup.reshape(bs, c, ky, iy+dy, kx, ix+dx)