mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
faster wino compile by catting consts across data expand dim (#3293)
* PoC faster wino compile by catting consts across data expand dim * fix fusions * faster + golf it * noqa 501 * implicit broadcast * Revert "implicit broadcast" This reverts commit 5915a9083d045ec1e6be84dcb492333325d48666. * shorter * shorter * oops * 216 upcasts is probably fine * wino kernel count test * test winograd number of sts * specify device for apply_matrix mat elements
This commit is contained in:
parent
cf6f478901
commit
aebaab011f
3 changed files with 15 additions and 7 deletions
|
|
@ -332,28 +332,32 @@ class TestHandCodedOpts(unittest.TestCase):
|
|||
|
||||
def test_masked_upcast_wino_full(self):
|
||||
with Context(WINO=1):
|
||||
x,w = Tensor.rand(1,4,9,9, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
|
||||
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
|
||||
out = Tensor.conv2d(x,w, padding=1)
|
||||
upcasts = []
|
||||
wino_schedule = out.lazydata.schedule()
|
||||
# collect upcasts of tile transform kernels
|
||||
for i, si in enumerate(out.lazydata.schedule()):
|
||||
for i, si in enumerate(wino_schedule):
|
||||
k = Linearizer(si.ast)
|
||||
k.hand_coded_optimizations()
|
||||
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
|
||||
if len(k.bufs) < 100: continue # not a tile transform kernel (there's a permute kernel at the end)
|
||||
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
|
||||
upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
|
||||
assert len(upcasts) == 3 # 3 transformation matrices
|
||||
# TODO: what did this fix?
|
||||
assert len(wino_schedule) <= 4 # 4 kernels
|
||||
# this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
|
||||
assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
|
||||
|
||||
out.mean().backward()
|
||||
for si in x.grad.lazydata.schedule() + w.grad.lazydata.schedule():
|
||||
backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
|
||||
for si in backward_schedule:
|
||||
k = Linearizer(si.ast)
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
if len(k.bufs) < 20: continue # not a tile transform kernel
|
||||
# heuristic number to make sure that at least some upcasts but not too many upcasts are being done
|
||||
assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 49
|
||||
assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216
|
||||
assert len(backward_schedule) <= 13 # just the current number, but it could be better
|
||||
|
||||
def test_masked_upcast_many(self):
|
||||
layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4))
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class TestWinograd(unittest.TestCase):
|
|||
l = Linearizer(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
assert len(l.sts) <= 256 # just the current value to prevent regression
|
||||
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
|
||||
for st in l.sts:
|
||||
assert len(st.views) <= 2, "too many views in winograd"
|
||||
|
|
|
|||
|
|
@ -653,7 +653,10 @@ class Tensor:
|
|||
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
|
||||
|
||||
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
|
||||
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
|
||||
def apply_matrix(mat, t, dims=len(HW)):
|
||||
t_ = t.reshape(t.shape[:dims]+(1,)*dims+t.shape[dims:]).expand(t.shape[:dims]+(len(mat),)*dims+t.shape[dims:])
|
||||
matcols = [[Tensor.cat(*[Tensor(float(m[k]), device=t.device).reshape((1,) * len(t.shape)).expand(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:]) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
|
||||
return sum(prod([matcols[dim][mat_is[dim]] for dim in range(dims)]) * t_[mat_is] for mat_is in itertools.product(*[range(len(mat[0])) for _ in range(dims)])) # noqa: E501
|
||||
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
|
||||
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
|
||||
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue