mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
test_sd_big_conv
This commit is contained in:
parent
178ba50c03
commit
7a61dc7ee9
2 changed files with 4 additions and 3 deletions
|
|
@ -207,12 +207,12 @@ class TestOps(unittest.TestCase):
|
|||
arg = (4,3,2,6)
|
||||
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
|
||||
|
||||
@unittest.skip
|
||||
@unittest.skipUnless(Device.DEFAULT != "GPU", "GPU doesn't work with convs with virtual dimensions > 2**31")
|
||||
def test_sd_big_conv(self):
|
||||
# internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows an int
|
||||
helper_test_op([(1,512,64,64), (512,512,3,3)],
|
||||
helper_test_op([(1,256,64,64), (512,256,3,3)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x, w),
|
||||
lambda x,w: x.conv2d(w), atol=1e-4)
|
||||
lambda x,w: x.conv2d(w), atol=1e-2)
|
||||
|
||||
def test_large_bs_conv(self):
|
||||
# large batch size can cause OpenCL image to exceed max image height on macOS
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ class GPUBuffer:
|
|||
|
||||
def contiguous_view_constant_fold(x, name:str) -> Tuple[str, Optional[str], str]:
|
||||
if x._base_shape == (1,) and x._backing is not None:
|
||||
# this function doesn't need a memory access
|
||||
return f"inline float get_{name}(int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? {x._backing[0]} : 0.0;}}", None, f"get_{name}(idx);"
|
||||
else:
|
||||
return x.contiguous_view(name), f"__global const float *{name}_g", f"get_{name}({name}_g, idx);"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue