faster wino compile by catting consts across data expand dim (#3293)

* PoC faster wino compile by catting consts across data expand dim

* fix fusions

* faster + golf it

* noqa 501

* implicit broadcast

* Revert "implicit broadcast"

This reverts commit 5915a9083d045ec1e6be84dcb492333325d48666.

* shorter

* shorter

* oops

* 216 upcasts is probably fine

* wino kernel count test

* test winograd number of sts

* specify device for apply_matrix mat elements
This commit is contained in:
David Hou 2024-02-02 00:47:45 -08:00 committed by GitHub
commit aebaab011f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 15 additions and 7 deletions

View file

@ -332,28 +332,32 @@ class TestHandCodedOpts(unittest.TestCase):
def test_masked_upcast_wino_full(self):
with Context(WINO=1):
x,w = Tensor.rand(1,4,9,9, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
out = Tensor.conv2d(x,w, padding=1)
upcasts = []
wino_schedule = out.lazydata.schedule()
# collect upcasts of tile transform kernels
for i, si in enumerate(out.lazydata.schedule()):
for i, si in enumerate(wino_schedule):
k = Linearizer(si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 100: continue # not a tile transform kernel (there's a permute kernel at the end)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
assert len(upcasts) == 3 # 3 transformation matrices
# TODO: what did this fix?
assert len(wino_schedule) <= 4 # 4 kernels
# this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
out.mean().backward()
for si in x.grad.lazydata.schedule() + w.grad.lazydata.schedule():
backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
for si in backward_schedule:
k = Linearizer(si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
# heuristic number to make sure that at least some upcasts but not too many upcasts are being done
assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 49
assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216
assert len(backward_schedule) <= 13 # just the current number, but it could be better
def test_masked_upcast_many(self):
layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4))

View file

@ -28,6 +28,7 @@ class TestWinograd(unittest.TestCase):
l = Linearizer(s.ast)
l.hand_coded_optimizations()
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
for st in l.sts:
assert len(st.views) <= 2, "too many views in winograd"

View file

@ -653,7 +653,10 @@ class Tensor:
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW)))
# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308
def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) # noqa: E501
def apply_matrix(mat, t, dims=len(HW)):
t_ = t.reshape(t.shape[:dims]+(1,)*dims+t.shape[dims:]).expand(t.shape[:dims]+(len(mat),)*dims+t.shape[dims:])
matcols = [[Tensor.cat(*[Tensor(float(m[k]), device=t.device).reshape((1,) * len(t.shape)).expand(t_.shape[dims:dims+dim]+(1,)+t_.shape[dims+dim+1:]) for m in mat], dim=dim) for k in range(len(mat[0]))] for dim in range(dims)] # noqa: E501
return sum(prod([matcols[dim][mat_is[dim]] for dim in range(dims)]) * t_[mat_is] for mat_is in itertools.product(*[range(len(mat[0])) for _ in range(dims)])) # noqa: E501
HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles
winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]]
winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]]