mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
tensor_map cleanups [pr] (#8754)
* tensor_map cleanups [pr] * update test_schedule too
This commit is contained in:
parent
b53fe7c2fc
commit
ac70f63d4b
2 changed files with 19 additions and 19 deletions
|
|
@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
|||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
|||
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
|
|
|
|||
|
|
@ -81,7 +81,6 @@ class ScheduleContext:
|
|||
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
|
||||
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
|
||||
ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
|
||||
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
|
||||
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
|
|
@ -353,12 +352,12 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
|||
case _: return None
|
||||
return reduce.const_like(ret)
|
||||
|
||||
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
|
||||
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
|
||||
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
|
||||
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
||||
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
|
||||
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
|
||||
new_src = list(alu.src)
|
||||
for i,s in enumerate(alu.src):
|
||||
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
|
||||
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
|
||||
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
|
|
@ -490,11 +489,22 @@ remove_movement_ops = merge_views+PatternMatcher([
|
|||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
|
||||
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
|
||||
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for k,v in tensor_map.items():
|
||||
# NOOP
|
||||
if k.base is v.base: continue
|
||||
# NOTE: only the base tensors get a BUFFER UOp
|
||||
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
|
||||
# otherwise if it simplified to a CONST the UOp just becomes that CONST
|
||||
elif v.op is Ops.CONST: becomes_map[k] = v
|
||||
|
||||
# we group the rest of UOps into ScheduleItems
|
||||
rev_tensor_map: dict[UOp, list[UOp]] = {}
|
||||
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
|
||||
# add BUFFER uops
|
||||
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
|
||||
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
|
||||
# add realizes
|
||||
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
|
||||
# group realizes into kernels
|
||||
|
|
@ -502,7 +512,6 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
|
|||
graph_rewrite(sink, break_sched, ctx)
|
||||
# create schedule items + map buffers to realized tensors
|
||||
prescheduled: list[ScheduleItem] = []
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for store_uops in store_groups:
|
||||
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
|
||||
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
|
||||
|
|
@ -513,15 +522,6 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
|
|||
# increment refcount for this buffer
|
||||
buf_uop.buffer.ref(1)
|
||||
|
||||
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
|
||||
for k,v in tensor_map.items():
|
||||
# NOOP
|
||||
if k.base is v.base: continue
|
||||
# NOTE: only the base tensors get a BUFFER UOp
|
||||
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
|
||||
# otherwise if it simplified to a CONST the UOp just becomes that CONST
|
||||
elif v.op is Ops.CONST: becomes_map[k] = v
|
||||
|
||||
# add kernel children
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue