mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
early simplify
This commit is contained in:
parent
3568f9f5a6
commit
657d9972c2
1 changed files with 7 additions and 1 deletions
|
|
@ -3,7 +3,7 @@ import functools, operator, itertools
|
|||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
|
||||
from tinygrad.uop.symbolic import sym, symbolic
|
||||
from tinygrad.uop.symbolic import sym, symbolic, symbolic_simple
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
|
|
@ -102,6 +102,11 @@ pm_apply_rangeify = PatternMatcher([
|
|||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
|
||||
])
|
||||
|
||||
pm_simplify_rangeify = PatternMatcher([
|
||||
# no buffers for const. NOTE: we rely on the PAD being applied by a WHERE elsewhere
|
||||
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True).f(Ops.INDEX, allow_any_len=True), lambda c: c),
|
||||
])
|
||||
|
||||
# this is the definition of the movement ops
|
||||
def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]:
|
||||
match x.op:
|
||||
|
|
@ -223,4 +228,5 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
rctx.range_map[x] = (rngs, out_rngs)
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
|
||||
tsink = graph_rewrite(tsink, symbolic_simple+pm_simplify_rangeify, name="early rangeify simplify")
|
||||
return tsink, rctx
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue