early simplify

This commit is contained in:
George Hotz 2025-10-09 12:20:25 +08:00
commit 657d9972c2

View file

@ -3,7 +3,7 @@ import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
from tinygrad.uop.symbolic import sym, symbolic
from tinygrad.uop.symbolic import sym, symbolic, symbolic_simple
from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
@ -102,6 +102,11 @@ pm_apply_rangeify = PatternMatcher([
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
])
pm_simplify_rangeify = PatternMatcher([
# no buffers for const. NOTE: we rely on the PAD being applied by a WHERE elsewhere
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True).f(Ops.INDEX, allow_any_len=True), lambda c: c),
])
# this is the definition of the movement ops
def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]:
match x.op:
@ -223,4 +228,5 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rctx.range_map[x] = (rngs, out_rngs)
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
tsink = graph_rewrite(tsink, symbolic_simple+pm_simplify_rangeify, name="early rangeify simplify")
return tsink, rctx