Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
fec7fbb824 fixup realize 2025-08-05 11:59:53 -07:00
George Hotz
6e39f041b7 don't merge views until we are in kernel land 2025-08-05 11:42:18 -07:00
6 changed files with 20 additions and 7 deletions

View file

@ -2,7 +2,7 @@ from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType, PatternMatcher, UPat
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.dtype import AddrSpace
from tinygrad.schedule.kernelize import merge_views, view_left
from tinygrad.opt.swizzler import merge_views, view_left
from tinygrad.helpers import getenv, colored, prod, unwrap
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.view import strides_for_shape

View file

@ -160,6 +160,7 @@ class TestOps(unittest.TestCase):
b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)
@unittest.skip("undefined behavior")
def test_empty_0(self):
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)

View file

@ -15,7 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
from tinygrad.schedule.kernelize import get_kernelize_map, Kernel
from tinygrad.opt.swizzler import merge_views
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule

View file

@ -1,7 +1,8 @@
import unittest
from tinygrad import Tensor
from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
from tinygrad.schedule.kernelize import sym, merge_views
from tinygrad.schedule.kernelize import sym
from tinygrad.opt.swizzler import merge_views
class TestRewriteTrackedChildren(unittest.TestCase):
@unittest.skip("track_children no longer supported")

View file

@ -28,7 +28,11 @@ do_realize = PatternMatcher([
# always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
(UPat(Ops.EXPAND, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),)), lambda ctx,tr:
realize(ctx,tr) if not DONT_REALIZE_EXPAND and tr.base.op not in ALWAYS_CONTIGUOUS else None),
(UPat(Ops.PAD, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),)), lambda ctx,tr:
realize(ctx,tr) if not can_pad(tr, ctx) and tr.base.op not in ALWAYS_CONTIGUOUS else None),
#(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
# realize parents of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
@ -60,7 +64,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
children: dict[UOp, dict[UOp, None]] = {}
assigns: dict[UOp, None] = {}
for u in (toposort:=sink.toposort()):
if u.op in {Ops.VIEW, Ops.SINK}: continue
if u.op in GroupOp.Movement.union({Ops.VIEW, Ops.SINK}): continue
if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
for s in u.src: children.setdefault(s.base, {})[u] = None

View file

@ -8,7 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes
from tinygrad.schedule.multi import multi_pm
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop
from tinygrad.opt.swizzler import view_left, view_right, apply_swizzle, swizzle_reduceop
# creation can recurse a lot
import sys
@ -319,6 +319,12 @@ finalize_contiguous = PatternMatcher([
# Strips tags: the pattern matches any op in GroupOp.All; the lambda yields x.replace(tag=None)
# when x carries a tag, and None when x.tag is already None.
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# Rewrite rules for a COPY whose first source is a RESHAPE or SHRINK: the COPY is re-pointed at that
# op's own source and the same reshape/shrink (with the same arg) is re-applied to the COPY's result.
# In both rules `c` is the COPY, `r` is the RESHAPE/SHRINK in its first source, and `d` is its second source.
new_fixups = PatternMatcher([
# COPY(RESHAPE(x), d) -> COPY(x, d).reshape(r.arg)
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
# COPY(SHRINK(x), d) -> COPY(x, d).shrink(r.arg)
# TODO: this should be BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
@ -332,7 +338,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+sym+replace_contiguous, ctx={}, name="merge_views")
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")