tensor_map cleanups [pr] (#8754)

* tensor_map cleanups [pr]

* update test_schedule too
This commit is contained in:
qazal 2025-01-26 04:41:54 -05:00 committed by GitHub
commit ac70f63d4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 19 additions and 19 deletions

View file

@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
# Helper that runs graph_rewrite over big_sink with the combined remove_movement_ops+sym matcher and
# returns the result; the call is wrapped by track_rewrites(named=True).
# NOTE(review): this is rendered diff output without +/- markers, so two versions of the definition
# appear back to back; they differ only in the context object handed to graph_rewrite.
@track_rewrites(named=True)
# variant that passes a fresh ScheduleContext() as the rewrite context
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
# variant that passes an empty dict as the rewrite context
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
class TestSchedule(unittest.TestCase):
def test_basic_binop_fusion(self):

View file

@ -81,7 +81,6 @@ class ScheduleContext:
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
@ -353,12 +352,12 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
case _: return None
return reduce.const_like(ret)
# NOTE(review): this span is rendered diff output with the +/- markers and indentation stripped, so the
# ScheduleContext-based and dict-based versions of both helpers are interleaved below.
#
# found_contiguous: when unwrap(src.st).invert(src.base.shape) yields a non-None value, record
# `contig` viewed through that value under the key `src.base`; otherwise nothing is recorded.
# Variant taking a ScheduleContext: the entry is stored in ctx.contiguous.
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
# Variant taking a plain dict: ctx itself is the mapping that receives the entry.
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
# replace_contiguous: substitute every source of `alu` that has an entry in the mapping with that entry.
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
# start from a mutable copy of the sources so individual positions can be overwritten
new_src = list(alu.src)
for i,s in enumerate(alu.src):
# lookup through the ScheduleContext attribute (pairs with the ScheduleContext signature)
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
# lookup directly in the dict (pairs with the dict signature)
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
# only build a replacement UOp when at least one source changed; otherwise fall through (implicit None)
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
@ -490,11 +489,22 @@ remove_movement_ops = merge_views+PatternMatcher([
@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: becomes_map[k] = v
# we group the rest of UOps into ScheduleItems
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
# add BUFFER uops
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
# add realizes
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# group realizes into kernels
@ -502,7 +512,6 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
graph_rewrite(sink, break_sched, ctx)
# create schedule items + map buffers to realized tensors
prescheduled: list[ScheduleItem] = []
becomes_map: dict[UOp, UOp] = {}
for store_uops in store_groups:
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
@ -513,15 +522,6 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
# increment refcount for this buffer
buf_uop.buffer.ref(1)
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: becomes_map[k] = v
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)