mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
reintroduce merge views in update benchmark (#3279)
* Reapply "take merge views from corsix branch" (#3278)
This reverts commit d298916232.
* reintroduce merge views
This commit is contained in:
parent
d298916232
commit
09f2952dc3
3 changed files with 85 additions and 6 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
|
||||
|
|
@ -28,6 +28,12 @@ class TestWinograd(unittest.TestCase):
|
|||
l = Linearizer(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
|
||||
for st in l.sts:
|
||||
assert len(st.views) <= 2, "too many views in winograd"
|
||||
if DEBUG >= 3:
|
||||
print(f"{len(st.views):3d} views")
|
||||
for v in st.views: print(v)
|
||||
|
||||
def test_profile(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
|
|
|
|||
|
|
@ -1,19 +1,86 @@
|
|||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import functools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast, Iterable, Union
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
|
||||
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  """Split the flat offset `offs` into one coordinate per dimension of `shape`.

  The strides used for the split are those returned by `strides_for_shape(shape)`.
  Dimensions are visited in stride order: each coordinate is the floor quotient of
  the remaining offset by that dimension's stride, and the consumed part is then
  subtracted from the remainder. A dimension whose stride is falsy (e.g. 0) always
  gets coordinate 0 and consumes nothing.

  Returns a list with one entry per stride.
  """
  coords: List[sint] = []
  remainder = offs
  for step in strides_for_shape(shape):
    # guard the division: a zero stride cannot absorb any part of the offset
    if step: quotient = remainder // step
    else: quotient = 0
    remainder -= quotient * step
    coords.append(quotient)
  return coords
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
  """Try to collapse the view pair (vm2, vm1) into one equivalent View.

  vm1's flat addresses are interpreted as positions inside vm2's shape, so the
  merged view (when one exists) has vm1's shape and strides expressed directly
  in terms of vm2's strides and offset.

  Returns the merged View, or None when the two views cannot be merged.
  Results are memoized through functools.lru_cache.
  """
  # Trivial cases: one of the two views adds nothing.
  if vm1.contiguous and vm1.shape == vm2.shape: return vm2
  if vm2.contiguous: return vm1
  # Fast path: vm2 has no mask, vm1 has no offset and the stacked pair has well-defined real strides.
  if not vm2.mask and vm1.offset == 0 and None not in (rstrides := ShapeTracker((vm2, vm1)).real_strides()):
    return View.create(vm1.shape, cast(Tuple[sint, ...], rstrides), vm2.offset, vm1.mask)
  if vm1.mask:
    # An empty range in any dimension means nothing is valid: return a fully masked-out view.
    for b,e in vm1.mask:
      if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
    # Otherwise merge the valid region only (shrink), then pad the result back out to vm1's shape.
    # `x and y` propagates None when the inner merge fails.
    return (merged := merge_views(vm2, vm1.shrink(vm1.mask))) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

  # Project vm1's offset and strides on to vm2.
  # origin[d2] is vm1's offset expressed as a coordinate in vm2's dimension d2.
  origin = un1d(vm2.shape, vm1.offset)
  # terms[d2] lists (d1, step): moving one step along vm1's dim d1 moves `step` along vm2's dim d2.
  terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
  strides: List[sint] = [0] * len(vm1.shape)
  for d1, st in enumerate(vm1.strides):
    if st == 0: continue
    for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
      # s1 becomes the coordinate delta in vm2's dim d2 caused by one step along vm1's dim d1
      if (s1 := s1 - o) == 0: continue
      terms[d2].append((d1, s1))
      strides[d1] += s1 * vm2.strides[d2]

  # Merge dimensions in vm2 if required.
  # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
  idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
  merged_size, merged_term = 1, NumNode(0)
  # extents collects (size, index expression) per group of vm2 dims, innermost group first.
  extents: List[Tuple[sint, Node]] = []
  for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
    merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
    merged_size *= s
    # Close the current group once the accumulated index expression passes both range checks.
    # NOTE(review): relies on the truthiness of symbolic comparison Nodes - confirm against Node.__bool__.
    if not (merged_term >= merged_size) and not (merged_term < 0):
      extents.append((merged_size, merged_term))
      merged_size, merged_term = 1, NumNode(0)
  # A leftover, unclosed group means the index expression does not fit vm2's dimensions.
  if merged_term: return None
  # If dims had to be grouped, reshape vm2 to the grouped shape and retry against it.
  if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
    return (reshaped_vm2 := vm2.reshape(vm2_shape)) and merge_views(reshaped_vm2, vm1)

  if vm2.mask:
    # Try to project vm2's mask on to vm1.
    # newb/newe hold the candidate [begin, end) mask per vm1 dim; bad records a violation that could not be projected.
    newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
    for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
      # Index expression stays inside [b, e): this dimension's mask is never violated.
      if not (t.min < b or t.max >= e): continue
      # Only concrete integer bounds and origins can be projected.
      if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
        bad = True
        continue
      term = terms[d2]
      if len(term) != 1:
        # No vm1 dim moves along d2, yet the mask is violated: mark everything invalid by emptying dim 0.
        if not term and newe: newe[0] = 0
        else: bad = True
        continue
      d1, s1 = term[0]
      if not isinstance(s1, int) or not isinstance(newe[d1], int):
        bad = True
        continue
      # Solve b <= o + idx*s1 < e for idx; which bound is lower/upper depends on the sign of s1.
      newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
      newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

    # If any of vm1 was masked off, try again with that mask in place.
    for b, e, s in zip(newb, newe, vm1.shape):
      if b != 0 or e != s:
        return merge_views(vm2, View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe))))
    # Otherwise if vm2's mask was violated, then cannot merge.
    if bad: return None

  return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
|
||||
|
||||
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
||||
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
||||
|
|
|
|||
|
|
@ -83,6 +83,12 @@ class View:
|
|||
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
  """Construct a View from a shape plus optional strides, offset and mask.

  When `strides` is truthy it is passed through `canonicalize_strides`; otherwise
  (None or empty) the strides come from `strides_for_shape(shape)`. The view is
  flagged contiguous only when the offset is 0, there is no mask, and the final
  strides equal `strides_for_shape(shape)`.
  """
  if strides: view_strides = canonicalize_strides(shape, strides)
  else: view_strides = strides_for_shape(shape)
  is_contiguous = offset == 0 and mask is None and view_strides == strides_for_shape(shape)
  # NOTE: the simplification below is currently disabled and kept only as commented-out code.
  # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
  # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
  #if mask and any(elim := [isinstance(b, int) and isinstance(e, int) and b+1 >= e for b,e in mask]):
  #  if any(b >= e for b,e in mask): strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
  #  offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
  #  strides = tuple(0 if e else st for st,e in zip(strides, elim))
  return View(shape, view_strides, offset, mask, is_contiguous)
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue