mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
jit: memplan before compile (#16560)
This commit is contained in:
parent
34481830f1
commit
2c9d2c0d31
1 changed files with 1 additions and 1 deletions
|
|
@ -69,8 +69,8 @@ def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
|
|||
|
||||
# parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index
|
||||
linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True)
|
||||
linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value))
|
||||
linear = memory_plan_rewrite(linear, held_bufs)
|
||||
linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value))
|
||||
if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value)
|
||||
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear")
|
||||
return linear
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue