remove tensor_map

This commit is contained in:
George Hotz 2026-02-19 18:23:12 +08:00
commit 4e3547e5cb
2 changed files with 17 additions and 22 deletions

View file

@ -56,10 +56,12 @@ class TestOnnxRunner(unittest.TestCase):
output = runner({'inp': Tensor([1, 2, 3, 4])})['output']
_check_ast_count(0, output)
@unittest.skip("const folding is removed")
def test_const_fold_from_disk(self):
self._test_const_fold_unary_op(True)
self._test_const_fold_binary_op(True)
@unittest.skip("const folding is removed")
def test_const_fold_from_memory(self):
self._test_const_fold_unary_op(False)
# TODO: understand this and fix this, bitcast related

View file

@ -111,7 +111,7 @@ pm_post_sched_cache = PatternMatcher([
])
# rewrite all contiguous to assign
def apply_buffer_map(ctx:dict[UOp,UOp], x:UOp):
def _apply_buffer_map(ctx:dict[UOp,UOp], x:UOp):
if x.op is Ops.AFTER:
ctx[x] = x.src[0]
return None
@ -127,14 +127,9 @@ def apply_buffer_map(ctx:dict[UOp,UOp], x:UOp):
buffer = ctx[x].buf_uop.base.reshape(x.max_shard_shape)
if isinstance(x.device, tuple) and x.axis is not None: buffer = buffer.multi(x.axis)
return buffer.assign(x.src[0] if x.op is Ops.CONTIGUOUS else x.rtag())
pm_apply_buffer_map = PatternMatcher([ (UPat(GroupOp.All, name="x"), apply_buffer_map), ])
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
pm_apply_buffer_map = PatternMatcher([ (UPat(GroupOp.All, name="x"), _apply_buffer_map), ])
def _get_apply_buffer_map(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
# precreate all buffers
buffer_map: dict[UOp, UOp] = {}
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
@ -153,6 +148,14 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# apply buffer map, do a few simple rewrites
big_sink = graph_rewrite(big_sink, pm_apply_buffer_map, ctx=buffer_map, bottom_up=True, name="apply buffer map")
big_sink = graph_rewrite(big_sink, _remove_all_tags)
return big_sink, buffer_map
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
big_sink, buffer_map = _get_apply_buffer_map(big_sink)
# replace BUFFERs with PARAMs, CONSTs UNIQUE with LUNIQUE, strip BIND values for cache key, extract var_vals
input_buffers: dict[UOp, UOp] = {}
@ -179,23 +182,15 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
pre_schedule, buf_uops_sink = create_schedule(big_sink)
# save in schedule cache (include AFTERs in tensor_map so we don't need big_sink)
after_map = [(u, u.buf_uop) for u in big_sink.toposort() if u.op is Ops.AFTER]
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]), *flatten(after_map))
combined_sink = UOp.sink(tensor_map_sink, buf_uops_sink)
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, combined_sink)
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
else:
# schedule cache hit
del big_sink_cache
pre_schedule, combined_sink = sc_ret
pre_schedule, buf_uops_sink = sc_ret
del big_sink_cache
# replace all the PARAMs/LUNIQUEs back (single graph_rewrite for everything)
input_buffers_inverse = {v:k for k,v in input_buffers.items()}
combined = graph_rewrite(combined_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
tensor_map_sink, buf_uops_sink = combined.src
tm_src = tensor_map_sink.src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=input_buffers_inverse, name="unrewrite combined")
# add bufs to pre_schedule
schedule: list[ExecItem] = []
@ -223,6 +218,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
f" | {len(UOpMetaClass.ucache)} uops in cache")
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
#return tensor_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
# tensor_map isn't used anymore
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}