mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Fix exponential complexity in _is_padding_okay [pr] (#7008)
* preliminary test * missed Optional * don't check for cache during recursion * match style from st_fixup... may be marginally faster? * pathological test case: strongly connected DAG * move to test_schedule as this isn't really a fusion * oops this shouldn't be edited * Revert "oops this shouldn't be edited" This reverts commit 487cb027dc. * Revert "move to test_schedule as this isn't really a fusion" This reverts commit 48d8c550ce. * move to test_schedule as this isn't really a fusion * ok no more merge error funny business
This commit is contained in:
parent
bd8ecf7fd6
commit
2ac5aec66b
2 changed files with 15 additions and 3 deletions
|
|
@@ -1599,6 +1599,16 @@ class TestIndexing(unittest.TestCase):
|
|||
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
|
||||
self.assertLess(et, 1e3)
|
||||
|
||||
def test_strongly_connected_DAG(self):
|
||||
val = 1.0
|
||||
a = Tensor(val).realize()
|
||||
def f(a):
|
||||
for _ in range(24): a = Tensor.stack(a, a)[0]
|
||||
return a.item()
|
||||
r, et = timeit(f, a)
|
||||
self.assertEqual(r, val)
|
||||
self.assertLess(et, 1e3)
|
||||
|
||||
def test_no_rewrite_elementwise(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)]
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
|
|
|
|||
|
|
@@ -238,11 +238,13 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
|||
if x.base.realized is None: children[x.base][buf] = None
|
||||
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
|
||||
|
||||
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
|
||||
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], cache:Dict[LazyBuffer, bool]) -> bool:
|
||||
if (n:=cache.get(buf)) is not None: return n
|
||||
if buf in realizes: return True
|
||||
# NOTE: this broke to_image_idx and coder with JIT
|
||||
if buf.op in UNSAFE_PAD_OPS: return False
|
||||
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
|
||||
cache[buf] = ret = all(_is_padding_okay(x.base, realizes, cache) for x in buf.srcs)
|
||||
return ret
|
||||
|
||||
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
|
||||
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
|
||||
|
|
@@ -292,7 +294,7 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
|
|||
|
||||
# check if we have to realize pads
|
||||
for p in simple_pads:
|
||||
if not _is_padding_okay(p, realizes):
|
||||
if not _is_padding_okay(p, realizes, {}):
|
||||
realizes[p] = None
|
||||
|
||||
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue