mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove tensor_map
This commit is contained in:
parent
f9a569c0fd
commit
4e3547e5cb
2 changed files with 17 additions and 22 deletions
2
test/external/external_test_onnx_runner.py
vendored
2
test/external/external_test_onnx_runner.py
vendored
|
|
@ -56,10 +56,12 @@ class TestOnnxRunner(unittest.TestCase):
|
|||
output = runner({'inp': Tensor([1, 2, 3, 4])})['output']
|
||||
_check_ast_count(0, output)
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_const_fold_from_disk(self):
|
||||
self._test_const_fold_unary_op(True)
|
||||
self._test_const_fold_binary_op(True)
|
||||
|
||||
@unittest.skip("const folding is removed")
|
||||
def test_const_fold_from_memory(self):
|
||||
self._test_const_fold_unary_op(False)
|
||||
# TODO: understand this and fix this, bitcast related
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ pm_post_sched_cache = PatternMatcher([
|
|||
])
|
||||
|
||||
# rewrite all contiguous to assign
|
||||
def apply_buffer_map(ctx:dict[UOp,UOp], x:UOp):
|
||||
def _apply_buffer_map(ctx:dict[UOp,UOp], x:UOp):
|
||||
if x.op is Ops.AFTER:
|
||||
ctx[x] = x.src[0]
|
||||
return None
|
||||
|
|
@ -127,14 +127,9 @@ def apply_buffer_map(ctx:dict[UOp,UOp], x:UOp):
|
|||
buffer = ctx[x].buf_uop.base.reshape(x.max_shard_shape)
|
||||
if isinstance(x.device, tuple) and x.axis is not None: buffer = buffer.multi(x.axis)
|
||||
return buffer.assign(x.src[0] if x.op is Ops.CONTIGUOUS else x.rtag())
|
||||
pm_apply_buffer_map = PatternMatcher([ (UPat(GroupOp.All, name="x"), apply_buffer_map), ])
|
||||
|
||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
pm_apply_buffer_map = PatternMatcher([ (UPat(GroupOp.All, name="x"), _apply_buffer_map), ])
|
||||
|
||||
def _get_apply_buffer_map(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
|
||||
# precreate all buffers
|
||||
buffer_map: dict[UOp, UOp] = {}
|
||||
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
|
||||
|
|
@ -153,6 +148,14 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
# apply buffer map, do a few simple rewrites
|
||||
big_sink = graph_rewrite(big_sink, pm_apply_buffer_map, ctx=buffer_map, bottom_up=True, name="apply buffer map")
|
||||
big_sink = graph_rewrite(big_sink, _remove_all_tags)
|
||||
return big_sink, buffer_map
|
||||
|
||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
big_sink, buffer_map = _get_apply_buffer_map(big_sink)
|
||||
|
||||
# replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals
|
||||
input_buffers: dict[UOp, UOp] = {}
|
||||
|
|
@ -179,23 +182,15 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
|
||||
|
||||
pre_schedule, buf_uops_sink = create_schedule(big_sink)
|
||||
|
||||
# save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
|
||||
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
|
||||
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
|
||||
combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
|
||||
else:
|
||||
# schedule cache hit
|
||||
del big_sink_cache
|
||||
pre_schedule, combined_sink = sc_ret
|
||||
pre_schedule, buf_uops_sink = sc_ret
|
||||
del big_sink_cache
|
||||
|
||||
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
|
||||
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
|
||||
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
|
||||
tensor_map_sink, buf_uops_sink = combined.src
|
||||
tm_src = tensor_map_sink.src
|
||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
|
||||
|
||||
# add bufs to pre_schedule
|
||||
schedule: list[ExecItem] = []
|
||||
|
|
@ -223,6 +218,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
|
||||
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
|
||||
#return tensor_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
# tensor_map isn't used anymore
|
||||
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue