move simplify into views_to_indexed_uops (#9999)

* move simplify into views_to_indexed_uops

* cache that
This commit is contained in:
George Hotz 2025-04-23 13:50:27 +01:00 committed by GitHub
commit cc1087d2ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 13 additions and 16 deletions

View file

@ -1,6 +1,6 @@
import gc
from tinygrad import Tensor, UOp, Device
from tinygrad.shape.shapetracker import views_to_indexed_uops, folded_upcast
from tinygrad.shape.shapetracker import views_to_indexed_uops, upcast
from tinygrad.engine.realize import method_cache, get_kernel
def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()])
@ -63,7 +63,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
method_cache.clear()
views_to_indexed_uops.cache_clear()
folded_upcast.cache_clear()
upcast.cache_clear()
new_uops = uops_allocated()
gc.collect()

View file

@ -7,12 +7,13 @@ from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape, unravel
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
from tinygrad.codegen.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
@functools.cache
def upcast(u: UOp) -> UOp:
srcs = tuple(upcast(_src) for _src in u.src)
if u.dtype.scalar() is dtypes.int:
@ -24,29 +25,26 @@ def upcast(u: UOp) -> UOp:
if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
return u.replace(src=tuple(srcs))
# pooling op may overflow before folding causing unnecessary upcast
@functools.cache
def folded_upcast(u: UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0):
return upcast(graph_rewrite(u, sym, {}))
@functools.cache
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
idx, valid = views[-1].to_indexed_uops(_idxs)
for view in reversed(views[0:-1]):
view = view.minify()
idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
return idx, valid
# symbolic
idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat).src
# simplify
if (newvalid:=simplify_valid(valid)) is not None: valid = graph_rewrite(newvalid, symbolic_flat)
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
# upcast if needed
return upcast(idx), upcast(valid)
@functools.cache
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
# NOTE: if a stride is not always valid, it will be None
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
ret: list[Optional[sint]] = [None] * len(views[-1].shape)
idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
# TODO: always apply these in to_indexed_uops?
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
idx, valid = views_to_indexed_uops(views)
for c in split_uop(idx, Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
@ -92,8 +90,7 @@ class ShapeTracker:
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
return folded_upcast(idx), folded_upcast(valid)
return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
# upper bound on buffer size required to fit this shapetracker
def real_size(self) -> int: