test_sd_big_conv

This commit is contained in:
George Hotz 2022-10-01 13:26:05 -04:00
commit 7a61dc7ee9
2 changed files with 4 additions and 3 deletions

View file

@ -207,12 +207,12 @@ class TestOps(unittest.TestCase):
arg = (4,3,2,6)
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
@unittest.skip
@unittest.skipUnless(Device.DEFAULT != "GPU", "GPU doesn't work with convs with virtual dimensions > 2**31")
def test_sd_big_conv(self):
# internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int
helper_test_op([(1,512,64,64), (512,512,3,3)],
helper_test_op([(1,256,64,64), (512,256,3,3)],
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-4)
lambda x,w: x.conv2d(w), atol=1e-2)
def test_large_bs_conv(self):
# large batch size can cause OpenCL image to exceed max image height on macOS

View file

@ -120,6 +120,7 @@ class GPUBuffer:
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
if x._base_shape == (1,) and x._backing is not None:
# this function doesn't need a memory access
return f"inline float get_{name}(int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? {x._backing[0]} : 0.0;}}", None, f"get_{name}(idx);"
else:
return x.contiguous_view(name), f"__global const float *{name}_g", f"get_{name}({name}_g, idx);"