mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
test tiny almost passes
This commit is contained in:
parent
530aed739d
commit
62c6c75657
1 changed files with 5 additions and 4 deletions
|
|
@ -24,7 +24,7 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
|
|||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite
|
||||
|
||||
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast
|
||||
from tinygrad.codegen.codegen2 import expander2, pm_move_regs, devectorizer2, unbroadcast, pm_reduce_local
|
||||
|
||||
pm_index_is_shrink = PatternMatcher([
|
||||
# rewrite non-image INDEX to SHRINK
|
||||
|
|
@ -80,7 +80,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = apply_opts(sink, ren, beam=ast.arg.beam)
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic")
|
||||
sink = graph_rewrite(sink, sym+pm_move_where_on_load+pm_flatten_range, name="postopt symbolic")
|
||||
|
||||
# expand
|
||||
#sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander")
|
||||
|
|
@ -91,7 +91,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
|
||||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||
#sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||
sink = graph_rewrite(sink, pm_reduce_local, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
# add gpu dims (late). this works after devectorize, but it's faster here
|
||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims")
|
||||
|
|
@ -110,7 +111,7 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
#sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing,
|
||||
# ctx=ren, name="devectorize")
|
||||
sink = graph_rewrite(sink, unbroadcast, name="*** unbroadcast")
|
||||
sink = graph_rewrite(sink, devectorizer2, name="devectorize2")
|
||||
sink = graph_rewrite(sink, sym+devectorizer2, name="devectorize2")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue