mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fec7fbb824 | ||
|
|
6e39f041b7 |
6 changed files with 20 additions and 7 deletions
|
|
@ -2,7 +2,7 @@ from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
|
|||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType, PatternMatcher, UPat
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.schedule.kernelize import merge_views, view_left
|
||||
from tinygrad.opt.swizzler import merge_views, view_left
|
||||
from tinygrad.helpers import getenv, colored, prod, unwrap
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ class TestOps(unittest.TestCase):
|
|||
b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
|
||||
helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)
|
||||
|
||||
@unittest.skip("undefined behavior")
|
||||
def test_empty_0(self):
|
||||
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
|
||||
from tinygrad.schedule.kernelize import get_kernelize_map, Kernel
|
||||
from tinygrad.opt.swizzler import merge_views
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
|
||||
from tinygrad.schedule.kernelize import sym, merge_views
|
||||
from tinygrad.schedule.kernelize import sym
|
||||
from tinygrad.opt.swizzler import merge_views
|
||||
|
||||
class TestRewriteTrackedChildren(unittest.TestCase):
|
||||
@unittest.skip("track_children no longer supported")
|
||||
|
|
|
|||
|
|
@ -28,7 +28,11 @@ do_realize = PatternMatcher([
|
|||
# always realize ASSIGN/CONTIGUOUS/GroupOp.Meta
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
|
||||
(UPat(Ops.EXPAND, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),)), lambda ctx,tr:
|
||||
realize(ctx,tr) if not DONT_REALIZE_EXPAND and tr.base.op not in ALWAYS_CONTIGUOUS else None),
|
||||
(UPat(Ops.PAD, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),)), lambda ctx,tr:
|
||||
realize(ctx,tr) if not can_pad(tr, ctx) and tr.base.op not in ALWAYS_CONTIGUOUS else None),
|
||||
#(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
|
||||
# realize parents of COPY, MSELECT, MSTACK
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
|
||||
])
|
||||
|
|
@ -60,7 +64,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
|
|||
children: dict[UOp, dict[UOp, None]] = {}
|
||||
assigns: dict[UOp, None] = {}
|
||||
for u in (toposort:=sink.toposort()):
|
||||
if u.op in {Ops.VIEW, Ops.SINK}: continue
|
||||
if u.op in GroupOp.Movement.union({Ops.VIEW, Ops.SINK}): continue
|
||||
if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
|
||||
for s in u.src: children.setdefault(s.base, {})[u] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from tinygrad.dtype import ImageDType, dtypes
|
|||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.opt.swizzler import merge_views, view_left, view_right, apply_swizzle, swizzle_reduceop
|
||||
from tinygrad.opt.swizzler import view_left, view_right, apply_swizzle, swizzle_reduceop
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
|
|
@ -319,6 +319,12 @@ finalize_contiguous = PatternMatcher([
|
|||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
new_fixups = PatternMatcher([
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
|
||||
# TODO: this should be BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
|
|
@ -332,7 +338,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue