make remove bufferize fast (#12555)

* add more uop gc test

* make remove bufferize fast

* substitute is fast too

* fix tests
This commit is contained in:
George Hotz 2025-10-09 15:20:02 +08:00 committed by GitHub
commit e7aa26ed29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 19 additions and 13 deletions

View file

@ -124,7 +124,7 @@ def cleanup_dead_axes(b:UOp):
# skip for symbolic. TODO: fix this
if rng.op is Ops.RANGE and rng.src[0].op is not Ops.CONST: return None
# CONSTs are already dead axes
if rng.op is Ops.CONST or (rng.op is Ops.RANGE and rng not in b.src[0].backward_slice_with_self):
if rng.op is Ops.CONST or (rng.op is Ops.RANGE and rng not in b.src[0].ranges):
reshape.append(1)
hit = True
else:
@ -149,23 +149,29 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
# *** here is where we compute the cost ***
# if we return None, the bufferize is kept
accessed_buffers = []
accessed_buffers: list[UOp] = []
reduces: list[UOp] = []
def red_gate(x:UOp):
if x.op is Ops.INDEX:
accessed_buffers.append(x)
return False
if x.op is Ops.REDUCE: reduces.append(x)
return True
ran = src.toposort(gate=red_gate)
src.toposort(gate=red_gate)
del red_gate
# if this is generated from more than two distinct buffers, don't remove this buffer
if len(dedup([x.src[0] for x in accessed_buffers])) > 2: return None
# const reduce is okay
# TODO: move the reduce folder to before this to prevent the need for this
def okay_reduce(x:UOp): return all(y.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.COPY} for y in x.backward_slice_with_self)
# always run this list of ops
if any(x.op is Ops.REDUCE and not okay_reduce(x) for x in ran): return None
# if any reduces access a buffer, don't remove this buffer
buffer_in_reduce = False
def buf_gate(x:UOp):
nonlocal buffer_in_reduce
if x.op in {Ops.BUFFER, Ops.BUFFERIZE}: buffer_in_reduce = True
return not buffer_in_reduce
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
del buf_gate
if buffer_in_reduce: return None
# if it makes it here, the bufferize is removed
# this is the ranges replaced
@ -465,9 +471,9 @@ def do_sub_recurse(s:UOp):
return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v))
# here we actually do the SUBSTITUTE
if x in keys: return values[keys.index(x)]
# we filter any keys that aren't in the backward slice. this keeps the algorithm O(output graph size)
# NOTE: if k was x, it would trigger above, so self doesn't have to be included in backward_slice
new_kv = {k:v for k,v in zip(keys,values) if k in x.backward_slice}
# we filter any keys where the ranges don't overlap. this keeps the algorithm O(output graph size)
x_ranges = x.ranges
new_kv = {k:v for k,v in zip(keys,values) if any(r in x_ranges for r in k.ranges)}
# if there's no SUBSTITUTEs left, we can just return x
if len(new_kv) == 0: return x
# then we add SUBSTITUTE to all parents

View file

@ -236,7 +236,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
# determine what ranges this is in
@functools.cached_property
@recursive_property
def _ranges(self) -> dict[UOp, None]:
ret: dict[UOp, None] = {}
if self.op in range_start.keys():