more passing

This commit is contained in:
George Hotz 2026-06-16 12:54:54 -07:00
commit 1ad72dff08
2 changed files with 15 additions and 8 deletions

View file

@ -3,7 +3,7 @@ from dataclasses import replace
import itertools
import functools
from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same
from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic, all_same, flatten
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo, GroupOp
from tinygrad.uop.ops import ParamArg, AxisType, _align_left, _broadcast_shape, identity_element
from tinygrad.uop.render import pyrender
@ -121,10 +121,14 @@ devectorizer2 = pm_mops+PatternMatcher([
def reduce_ranges_to_acc(ctx:ReduceContext, r:UOp):
acc = UOp.placeholder_like(r, ctx.acc_num, AddrSpace.REG)
ctx.acc_num += 1
acc_initted = acc.after(acc.store(identity_element(r.arg[0], r.dtype.scalar())), *r.src[1:])
topo = r.src[0].toposort()
ended_ranges = flatten([x.ended_ranges for x in topo if x.op is Ops.END])
input_ranges = tuple(x for x in topo if x.op is Ops.RANGE and x not in r.src[1:] and x not in ended_ranges)
acc_init = acc.after(*input_ranges).store(identity_element(r.arg[0], r.dtype.scalar()))
acc_initted = acc.after(acc_init, *r.src[1:])
inp = r.src[0].reduce(arg=r.arg) if r.arg[1] else r.src[0]
acc_out = acc_initted.store(acc_initted.alu(r.arg[0], inp)).end(*r.src[1:])
return acc_initted.after(acc_out)
return acc.after(acc_out)
def expand_horizontal_reduce(r:UOp):
axes = r.arg[1]
@ -186,22 +190,25 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# split ends
sink = graph_rewrite(sink, pm_split_ends, name="split ends")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# ***** make it rendererable (outside spec, transform) *****
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
sink = graph_rewrite(sink, pm_move_gates_from_index, name="*** move gates from index")
# unbroadcast
sink = graph_rewrite(sink, unbroadcast, name="unbroadcast")
# devectorizer
sink = graph_rewrite(sink, devectorizer2, name="devectorizer")
sink = graph_rewrite(sink, symbolic_simple+devectorizer2, name="devectorizer")
# remove all weakints
sink = graph_rewrite(sink, pm_lower_weakints, name="lower weakints", bottom_up=True)
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# final symbolic
sink = graph_rewrite(sink, sym, name="post devectorizer sym")
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST")
if SPEC: type_verify(sink, spec_program)

View file

@ -215,7 +215,7 @@ spec_program = PatternMatcher([
lambda x: False if x.dtype.count > 1 and (x.dtype.count,) != x.shape else None),
# STACK/GEP in program. TODO: this should match Tensor
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 or len(x.src) == 0),
(UPat(Ops.STACK, name="x"), lambda x: True),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),