test tiny almost passes

This commit is contained in:
George Hotz 2026-06-20 16:46:01 -07:00
commit 62c6c75657

View file

@ -24,7 +24,7 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast, pm_reduce_local
pm_index_is_shrink = PatternMatcher([
# rewrite non-image INDEX to SHRINK
@ -80,7 +80,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = apply_opts(sink, ren, beam=ast.arg.beam)
# ** expander (expand_rewrite) **
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
# expand
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
@ -91,7 +91,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# ** devectorizer (full_graph_rewrite) **
# remove reduce
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
# add gpu dims (late). this works after devectorize, but it's faster here
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
@ -110,7 +111,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
# ctx=ren, name="devectorize")
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
sink = graph_rewrite(sink, devectorizer2, name="devectorize2")
sink = graph_rewrite(sink, sym+devectorizer2, name="devectorize2")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")