mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
make remove bufferize fast (#12555)
* add more uop gc test
* make remove bufferize fast
* substitute is fast too
* fix tests
This commit is contained in:
parent
cf8232ec6a
commit
e7aa26ed29
2 changed files with 19 additions and 13 deletions
|
|
@@ -124,7 +124,7 @@ def cleanup_dead_axes(b:UOp):
|
|||
# skip for symbolic. TODO: fix this
|
||||
if rng.op is Ops.RANGE and rng.src[0].op is not Ops.CONST: return None
|
||||
# CONSTs are already dead axes
|
||||
if rng.op is Ops.CONST or (rng.op is Ops.RANGE and rng not in b.src[0].backward_slice_with_self):
|
||||
if rng.op is Ops.CONST or (rng.op is Ops.RANGE and rng not in b.src[0].ranges):
|
||||
reshape.append(1)
|
||||
hit = True
|
||||
else:
|
||||
|
|
@@ -149,23 +149,29 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
|||
# *** here is where we compute the cost ***
|
||||
# if we return None, the bufferize is kept
|
||||
|
||||
accessed_buffers = []
|
||||
accessed_buffers: list[UOp] = []
|
||||
reduces: list[UOp] = []
|
||||
def red_gate(x:UOp):
|
||||
if x.op is Ops.INDEX:
|
||||
accessed_buffers.append(x)
|
||||
return False
|
||||
if x.op is Ops.REDUCE: reduces.append(x)
|
||||
return True
|
||||
ran = src.toposort(gate=red_gate)
|
||||
src.toposort(gate=red_gate)
|
||||
del red_gate
|
||||
|
||||
# if this is generated from multiple buffers, don't remove this buffer
|
||||
if len(dedup([x.src[0] for x in accessed_buffers])) > 2: return None
|
||||
|
||||
# const reduce is okay
|
||||
# TODO: move the reduce folder to before this to prevent the need for this
|
||||
def okay_reduce(x:UOp): return all(y.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.COPY} for y in x.backward_slice_with_self)
|
||||
|
||||
# always run this list of ops
|
||||
if any(x.op is Ops.REDUCE and not okay_reduce(x) for x in ran): return None
|
||||
# if any reduces access a buffer, don't remove this buffer
|
||||
buffer_in_reduce = False
|
||||
def buf_gate(x:UOp):
|
||||
nonlocal buffer_in_reduce
|
||||
if x.op in {Ops.BUFFER, Ops.BUFFERIZE}: buffer_in_reduce = True
|
||||
return not buffer_in_reduce
|
||||
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
|
||||
del buf_gate
|
||||
if buffer_in_reduce: return None
|
||||
|
||||
# if it makes it here, the bufferize is removed
|
||||
# this is the ranges replaced
|
||||
|
|
@@ -465,9 +471,9 @@ def do_sub_recurse(s:UOp):
|
|||
return UOp(Ops.SUBSTITUTE, dtype=x.dtype, src=(x.src[0], sub_k, sub_v))
|
||||
# here we actually do the SUBSTITUTE
|
||||
if x in keys: return values[keys.index(x)]
|
||||
# we filter any keys that aren't in the backward slice. this keeps the algorithm O(output graph size)
|
||||
# NOTE: if k was x, it would trigger above, so self doesn't have to be included in backward_slice
|
||||
new_kv = {k:v for k,v in zip(keys,values) if k in x.backward_slice}
|
||||
# we filter any keys where the ranges don't overlap. this keeps the algorithm O(output graph size)
|
||||
x_ranges = x.ranges
|
||||
new_kv = {k:v for k,v in zip(keys,values) if any(r in x_ranges for r in k.ranges)}
|
||||
# if there's no SUBSTITUTEs left, we can just return x
|
||||
if len(new_kv) == 0: return x
|
||||
# then we add SUBSTITUTE to all parents
|
||||
|
|
|
|||
|
|
@@ -236,7 +236,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
||||
|
||||
# determine what ranges this is in
|
||||
@functools.cached_property
|
||||
@recursive_property
|
||||
def _ranges(self) -> dict[UOp, None]:
|
||||
ret: dict[UOp, None] = {}
|
||||
if self.op in range_start.keys():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue