reintroduce merge views in update benchmark (#3279)

* Reapply "take merge views from corsix branch" (#3278)

This reverts commit d298916232.

* reintroduce merge views
This commit is contained in:
George Hotz 2024-01-30 09:47:20 -08:00 committed by GitHub
commit 09f2952dc3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 85 additions and 6 deletions

View file

@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Timing, CI, Profiling, WINO
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
@ -28,6 +28,12 @@ class TestWinograd(unittest.TestCase):
l = Linearizer(s.ast)
l.hand_coded_optimizations()
l.linearize()
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
for st in l.sts:
assert len(st.views) <= 2, "too many views in winograd"
if DEBUG >= 3:
print(f"{len(st.views):3d} views")
for v in st.views: print(v)
def test_profile(self):
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()

View file

@ -1,19 +1,86 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
import functools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Iterable, Union
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View
from tinygrad.shape.view import View, strides_for_shape
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
strides = strides_for_shape(shape)
result = []
for stride in strides:
here = offs // stride if stride else 0
result.append(here)
offs -= here * stride
return result
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
if vm2.contiguous: return vm1
if vm2.mask or vm1.offset != 0: return None # this isn't supported yet
if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask)
if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()):
return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask)
if vm1.mask:
for b,e in vm1.mask:
if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
return (merged := merge_views(vm2, vm1.shrink(vm1.mask))) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
# Project vm1's offset and strides on to vm2.
origin = un1d(vm2.shape, vm1.offset)
terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
strides: List[sint] = [0] * len(vm1.shape)
for d1, st in enumerate(vm1.strides):
if st == 0: continue
for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
if (s1 := s1 - o) == 0: continue
terms[d2].append((d1, s1))
strides[d1] += s1 * vm2.strides[d2]
# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, NumNode(0)
extents: List[Tuple[sint, Node]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
merged_size *= s
if not (merged_term >= merged_size) and not (merged_term < 0):
extents.append((merged_size, merged_term))
merged_size, merged_term = 1, NumNode(0)
if merged_term: return None
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
return (reshaped_vm2 := vm2.reshape(vm2_shape)) and merge_views(reshaped_vm2, vm1)
if vm2.mask:
# Try to project vm2's mask on to vm1.
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
if not (t.min < b or t.max >= e): continue
if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
bad = True
continue
term = terms[d2]
if len(term) != 1:
if not term and newe: newe[0] = 0
else: bad = True
continue
d1, s1 = term[0]
if not isinstance(s1, int) or not isinstance(newe[d1], int):
bad = True
continue
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
# If any of vm1 was masked off, try again with that mask in place.
for b, e, s in zip(newb, newe, vm1.shape):
if b != 0 or e != s:
return merge_views(vm2, View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))))
# Otherwise if vm2's mask was violated, then cannot merge.
if bad: return None
return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"

View file

@ -83,6 +83,12 @@ class View:
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
# then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
#if mask and any(elim := [isinstance(b, int) and isinstance(e, int) and b+1 >= e for b,e in mask]):
# if any(b >= e for b,e in mask): strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
# offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
# strides = tuple(0 if e else st for st,e in zip(strides, elim))
return View(shape, strides, offset, mask, contiguous)
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none