Fix exponential complexity in _is_padding_okay [pr] (#7008)

* preliminary test

* missed Optional

* don't check for cache during recursion

* match style from st_fixup... may be marginally faster?

* pathological test case: strongly connected DAG

* move to test_schedule as this isn't really a fusion

* oops this shouldn't be edited

* Revert "oops this shouldn't be edited"

This reverts commit 487cb027dc.

* Revert "move to test_schedule as this isn't really a fusion"

This reverts commit 48d8c550ce.

* move to test_schedule as this isn't really a fusion

* ok no more merge error funny business
This commit is contained in:
Louis Novy 2024-10-13 16:34:47 -07:00 committed by GitHub
commit 2ac5aec66b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View file

@ -1599,6 +1599,16 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertLess(et, 1e3)
def test_strongly_connected_DAG(self):
  """Regression test for a densely shared graph.

  Starting from a realized scalar tensor, the helper stacks the tensor with
  itself and takes element 0, 24 times in a row, so each step references the
  previous result twice. The final value must equal the starting value, and the
  elapsed time reported by `timeit` must stay below 1e3 (units are whatever the
  `timeit` helper returns).
  """
  expected = 1.0
  seed = Tensor(expected).realize()

  def chain(t):
    # every iteration uses the previous tensor as both inputs of the stack
    for _ in range(24):
      t = Tensor.stack(t, t)[0]
    return t.item()

  result, elapsed = timeit(chain, seed)
  self.assertEqual(result, expected)
  self.assertLess(elapsed, 1e3)
def test_no_rewrite_elementwise(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))

View file

@ -238,11 +238,13 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if x.base.realized is None: children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children, assign_targets, double_reduces)
# Decides whether padding `buf` is acceptable: True when `buf` is already in `realizes`, False when
# its op is in UNSAFE_PAD_OPS, otherwise True only if the base of every source passes the same check.
# NOTE(review): this span is a rendered diff hunk, so the pre-change and post-change lines of the
# function appear one after another; the two-argument signature below appears to be the pre-change line.
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
# Signature taking `cache`, a dict from buffer to the verdict already computed for it.
def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], cache:Dict[LazyBuffer, bool]) -> bool:
# return a stored verdict (True or False) if this buffer was already evaluated with this cache
if (n:=cache.get(buf)) is not None: return n
# the two early returns below do not store anything in `cache`
if buf in realizes: return True
# NOTE: this broke to_image_idx and coder with JIT
if buf.op in UNSAFE_PAD_OPS: return False
# uncached recursion over the sources; appears to be the pre-change line that the next two lines replace
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
# store the verdict before returning so a buffer reached through several paths is not re-traversed
cache[buf] = ret = all(_is_padding_okay(x.base, realizes, cache) for x in buf.srcs)
return ret
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
@ -292,7 +294,7 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
# check if we have to realize pads
for p in simple_pads:
if not _is_padding_okay(p, realizes):
if not _is_padding_okay(p, realizes, {}):
realizes[p] = None
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)