This commit is contained in:
George Hotz 2025-08-13 17:49:54 -07:00
commit 59bfab8a9b

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass, field
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.uop.symbolic import symbolic_simple, sym
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
@ -56,7 +56,7 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
sym = symbolic_simple+PatternMatcher([
kernelize_sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
@ -340,13 +340,14 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+sym+replace_contiguous, ctx={}, name="merge_views")
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
# testing
# NOTE: graph_rewrite_map with bottom_up is broken
rsink = graph_rewrite(tensor_map[sink], rangeify_fixups, bottom_up=True, name="* contiguous")
rsink = graph_rewrite(rsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="* children")
rsink = graph_rewrite(rsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="* rangeify")
rsink = graph_rewrite(rsink, sym, name="* symbolic")
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, name="* buffer")
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext