mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
more passing
This commit is contained in:
parent
6f1eaa8d46
commit
1ad72dff08
2 changed files with 15 additions and 8 deletions
|
|
@ -3,7 +3,7 @@ from dataclasses import replace
|
|||
import itertools
|
||||
import functools
|
||||
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same
|
||||
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same, flatten
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
|
||||
from tinygrad.uop.ops import ParamArg, AxisType, _align_left, _broadcast_shape, identity_element
|
||||
from tinygrad.uop.render import pyrender
|
||||
|
|
@ -121,10 +121,14 @@ devectorizer2 = pm_mops+PatternMatcher([
|
|||
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
|
||||
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
|
||||
ctx.acc_num += 1
|
||||
acc_initted = acc.after(acc.store(identity_element(r.arg[0], r.dtype.scalar())), *r.src[1:])
|
||||
topo = r.src[0].toposort()
|
||||
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
|
||||
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
|
||||
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
|
||||
acc_initted = acc.after(acc_init, *r.src[1:])
|
||||
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
|
||||
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
|
||||
return acc_initted.after(acc_out)
|
||||
return acc.after(acc_out)
|
||||
|
||||
def expand_horizontal_reduce(r:UOp):
|
||||
axes = r.arg[1]
|
||||
|
|
@ -186,22 +190,25 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
# split ends
|
||||
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
# ***** make it rendererable (outside spec, transform) *****
|
||||
|
||||
# move gates from unrenderable INVALID where
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="*** move gates from index")
|
||||
|
||||
# unbroadcast
|
||||
sink = graph_rewrite(sink, unbroadcast, name="unbroadcast")
|
||||
|
||||
# devectorizer
|
||||
sink = graph_rewrite(sink, devectorizer2, name="devectorizer")
|
||||
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorizer")
|
||||
|
||||
# remove all weakints
|
||||
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
# final symbolic
|
||||
sink = graph_rewrite(sink, sym, name="post devectorizer sym")
|
||||
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
|
||||
if SPEC: type_verify(sink, spec_program)
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ spec_program = PatternMatcher([
|
|||
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
|
||||
|
||||
# STACK/GEP in program. TODO: this should match Tensor
|
||||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 or len(x.src) == 0),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: True),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue