mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
ba1d3baae8
commit
2d8b802958
2 changed files with 13 additions and 22 deletions
|
|
@ -13,35 +13,26 @@ class TestWinograd(unittest.TestCase):
|
|||
def test_forward_kernels(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
out = Tensor.conv2d(x,w)
|
||||
self.assertEqual(len(out.schedule_linear().src), 2)
|
||||
self.assertEqual(len(out.schedule_linear().src), 4)
|
||||
|
||||
def test_backward_kernels(self):
|
||||
x,w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize()
|
||||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = x.grad.schedule_linear(w.grad)
|
||||
self.assertEqual(len(backward_schedule.src), 2)
|
||||
self.assertEqual(len(backward_schedule.src), 4)
|
||||
|
||||
@unittest.skip("this requires optimizations")
|
||||
def test_counters(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
|
||||
IC, OC, H = 64, 64, 28
|
||||
x,w = Tensor.empty(1,IC,H,H,device="NULL").realize(), Tensor.empty(OC,IC,3,3,device="NULL").realize()
|
||||
GlobalCounters.reset()
|
||||
with Context(WINO=1):
|
||||
Tensor.conv2d(x,w).realize()
|
||||
ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||
with Context(NOOPT=0, WINO=1): Tensor.conv2d(x,w).realize()
|
||||
ops_wino = GlobalCounters.global_ops
|
||||
GlobalCounters.reset()
|
||||
with Context(WINO=0):
|
||||
Tensor.conv2d(x,w).realize()
|
||||
ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
|
||||
|
||||
ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
|
||||
print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
|
||||
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
|
||||
|
||||
# TODO: what's optimal on this?
|
||||
self.assertLess(ops_ratio, 4.3)
|
||||
self.assertLess(mem_ratio, 4)
|
||||
with Context(NOOPT=0, WINO=0): Tensor.conv2d(x,w).realize()
|
||||
ops_normal = GlobalCounters.global_ops
|
||||
print(f"ops: normal {ops_normal} wino {ops_wino} ratio {ops_wino/ops_normal:.2f}")
|
||||
self.assertLess(ops_wino/ops_normal, 0.6)
|
||||
|
||||
def test_dtype(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
|
|
|||
|
|
@ -686,11 +686,11 @@ class Tensor(RandMixin):
|
|||
|
||||
g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front
|
||||
|
||||
# compute 6x6 winograd tiles: GgGt, BtdB
|
||||
# compute 6x6 winograd tiles: GgGt, BtdB. contiguous so the transforms are materialized once
|
||||
# (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1))
|
||||
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
||||
gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx)))
|
||||
# (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx)
|
||||
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||
dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx)
|
||||
|
||||
# matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx)
|
||||
ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue