mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
sym
This commit is contained in:
parent
b5d7d339f4
commit
59bfab8a9b
1 changed files with 4 additions and 3 deletions
|
|
@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
|||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.uop.symbolic import symbolic_simple, sym
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
|
|
@ -56,7 +56,7 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
|
|||
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
|
||||
return base.copy_to_device(copy.device).view(view.arg)
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
kernelize_sym = symbolic_simple+PatternMatcher([
|
||||
# UOp with size 0 is zero
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
|
||||
|
|
@ -340,13 +340,14 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
# testing
|
||||
# NOTE: graph_rewrite_map with bottom_up is broken
|
||||
rsink = graph_rewrite(tensor_map[sink], rangeify_fixups, bottom_up=True, name="* contiguous")
|
||||
rsink = graph_rewrite(rsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="* children")
|
||||
rsink = graph_rewrite(rsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="* rangeify")
|
||||
rsink = graph_rewrite(rsink, sym, name="* symbolic")
|
||||
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, name="* buffer")
|
||||
|
||||
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue