mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move simplify into views_to_indexed_uops (#9999)
* move simplify into views_to_indexed_uops * cache that
This commit is contained in:
parent
c39128133c
commit
cc1087d2ec
2 changed files with 13 additions and 16 deletions
4
test/external/external_uop_gc.py
vendored
4
test/external/external_uop_gc.py
vendored
|
|
@ -1,6 +1,6 @@
|
|||
import gc
|
||||
from tinygrad import Tensor, UOp, Device
|
||||
from tinygrad.shape.shapetracker import views_to_indexed_uops, folded_upcast
|
||||
from tinygrad.shape.shapetracker import views_to_indexed_uops, upcast
|
||||
from tinygrad.engine.realize import method_cache, get_kernel
|
||||
|
||||
def uops_allocated():
  # number of UOp instances among the objects currently tracked by the garbage collector
  return sum(1 for obj in gc.get_objects() if isinstance(obj, UOp))
|
||||
|
|
@ -63,7 +63,7 @@ if __name__ == "__main__":
|
|||
# these caches will keep uops alive
|
||||
method_cache.clear()
|
||||
views_to_indexed_uops.cache_clear()
|
||||
folded_upcast.cache_clear()
|
||||
upcast.cache_clear()
|
||||
|
||||
new_uops = uops_allocated()
|
||||
gc.collect()
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ from tinygrad.helpers import merge_dicts, getenv
|
|||
from tinygrad.shape.view import View, strides_for_shape, unravel
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
|
||||
from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
||||
from tinygrad.codegen.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
||||
|
||||
def overflow(u: UOp):
  # True if u.vmax is above the dtypes.int maximum or u.vmin is below the dtypes.int minimum
  above_max = u.vmax > dtypes.max(dtypes.int)
  return above_max or u.vmin < dtypes.min(dtypes.int)
|
||||
|
||||
# If a node overflows, its srcs need to be checked to see if this overflow is the result of an ALU operation,
|
||||
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
|
||||
@functools.cache
|
||||
def upcast(u: UOp) -> UOp:
|
||||
srcs = tuple(upcast(_src) for _src in u.src)
|
||||
if u.dtype.scalar() is dtypes.int:
|
||||
|
|
@ -24,29 +25,26 @@ def upcast(u: UOp) -> UOp:
|
|||
if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
|
||||
return u.replace(src=tuple(srcs))
|
||||
|
||||
# pooling op may overflow before folding causing unnecessary upcast
|
||||
@functools.cache
|
||||
def folded_upcast(u: UOp) -> UOp:
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return upcast(graph_rewrite(u, sym, {}))
|
||||
|
||||
@functools.cache
|
||||
def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
||||
idx, valid = views[-1].to_indexed_uops(_idxs)
|
||||
for view in reversed(views[0:-1]):
|
||||
view = view.minify()
|
||||
idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
|
||||
return idx, valid
|
||||
# symbolic
|
||||
idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat).src
|
||||
# simplify
|
||||
if (newvalid:=simplify_valid(valid)) is not None: valid = graph_rewrite(newvalid, symbolic_flat)
|
||||
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
|
||||
# upcast if needed
|
||||
return upcast(idx), upcast(valid)
|
||||
|
||||
@functools.cache
|
||||
def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
|
||||
# NOTE: if a stride is not always valid, it will be None
|
||||
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
|
||||
ret: list[Optional[sint]] = [None] * len(views[-1].shape)
|
||||
idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
|
||||
# TODO: always apply these in to_indexed_uops?
|
||||
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
|
||||
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
|
||||
idx, valid = views_to_indexed_uops(views)
|
||||
for c in split_uop(idx, Ops.ADD):
|
||||
if c.op is Ops.RANGE: ret[c.arg] = 1
|
||||
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
|
||||
|
|
@ -92,8 +90,7 @@ class ShapeTracker:
|
|||
|
||||
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
|
||||
def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
|
||||
idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
|
||||
return folded_upcast(idx), folded_upcast(valid)
|
||||
return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
|
||||
|
||||
# upper bound on buffer size required to fit this shapetracker
|
||||
def real_size(self) -> int:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue