mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
moveleftri
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ae02dea19 |
||
|
|
0e91b6fd30 | ||
|
|
823dfbde70 |
5 changed files with 124 additions and 111 deletions
|
|
@ -15,7 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
|
||||
from tinygrad.schedule.kernelize import get_kernelize_map, Kernel
|
||||
from tinygrad.opt.swizzler import merge_views
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
|
|||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
from tinygrad.schedule.kernelize import view_left
|
||||
from tinygrad.opt.swizzler import view_left, view_right
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
|
|
@ -52,6 +52,8 @@ class TensorCoreOptions:
|
|||
class Kernel:
|
||||
def __init__(self, ast:UOp, opts:Renderer|None=None):
|
||||
assert ast.op is Ops.SINK, ast.op
|
||||
ast = graph_rewrite(ast, view_left, name="Main View Left")
|
||||
ast = graph_rewrite(ast, view_right, name="Main View Right")
|
||||
self.ast = ast
|
||||
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
|
|
@ -73,7 +75,7 @@ class Kernel:
|
|||
self.sts.append(unwrap(x.src[0].st))
|
||||
|
||||
# add a shapetracker to the end to track the full shape, with 0 strides so it can merge
|
||||
full_shape = ast.full_shape
|
||||
full_shape = self.ast.full_shape
|
||||
self.sts.append(ShapeTracker.from_shape(full_shape, (0,)*len(full_shape)))
|
||||
|
||||
# parameters for optimization
|
||||
|
|
|
|||
108
tinygrad/opt/swizzler.py
Normal file
108
tinygrad/opt/swizzler.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.helpers import unwrap, prod, all_same
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
||||
|
||||
# **** swizzler
|
||||
|
||||
# merge_views: rewrite rules that turn movement ops into VIEWs and simplify or collapse VIEW nodes.
merge_views = PatternMatcher([
|
||||
# merge adjacent views: VIEW(VIEW(..)) keeps the inner VIEW, with arg v1.arg+v2.arg
# merge adjacent views
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# replace MovementOps with VIEW: the result is a view of x.base carrying mop.st
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
|
||||
# remove NOOP views: a contiguous view with the same shape as its source is replaced by the source
# (not applied when the source has no st or is one of GroupOp.Defines)
|
||||
(UPat.var("x").view(name="view"),
|
||||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
# a view (of anything but DEFINE_GLOBAL) whose last mask has a zero-width range on any axis becomes const_like(0)
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmasked VIEW on CONST (or DEFINE_VAR) replaces the ShapeTracker:
# the combined x.st+view.st is written to the arg of x.src[0] when none of its views has a mask
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
# VIEW on SINK is SINK
# NOTE(review): the replacement sinks the VIEW's source directly (v.src[0].sink()); presumably the pattern
# matches a SINK whose source is a VIEW - confirm against UPat.sink()
|
||||
(UPat(Ops.VIEW, name="v").sink(), lambda v: v.src[0].sink()),
|
||||
])
|
||||
|
||||
# Rewrite VIEW(REDUCE_AXIS(src)) so that the 1s the view adds are applied to src before the reduce.
# Bound by the view_left pattern: src is the reduce input, r the REDUCE_AXIS, view the VIEW on top of r.
# Returns a replacement REDUCE_AXIS whose shape equals view.shape, or None when the rewrite does not apply.
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
# contiguous, expand, and the same with ones removed
# i.e. the view is contiguous, has more dims than the reduce output, and both shapes agree once 1s are dropped
|
||||
if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
|
||||
tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
|
||||
new_shape: list[sint] = []
|
||||
new_reduce_axis = []
|
||||
# contraction[i] holds the view.shape indices grouped under axis i of r.shape; give up if none is found
if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
|
||||
for i,pairs in enumerate(contraction):
|
||||
new_shape_chunk = [view.shape[p] for p in pairs]
|
||||
if i in r.arg[1]:
|
||||
# if this is a reduce axis, we need a 1 in the view here to put it
|
||||
assert len(new_shape_chunk) > 0
|
||||
# pad with 1s for all but the last slot of the group, then put the unreduced input size src.shape[i] last
new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
|
||||
# the reduce axis is the slot that was just written
new_reduce_axis.append(len(new_shape)-1)
|
||||
else:
|
||||
# otherwise, pass through the new_shape_chunk
|
||||
new_shape += new_shape_chunk
|
||||
# keep the reduce op (r.arg[0]) and any trailing arg entries (r.arg[2:]); swap in the reshaped source and remapped axes
ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
|
||||
assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
|
||||
return ret
|
||||
return None
|
||||
|
||||
# view_left: merge_views plus rules that move a VIEW from the output of an op onto the op's sources.
view_left = merge_views+PatternMatcher([
|
||||
# view before elementwise and buffer ops
# the op e replaces the VIEW, and the VIEW's ShapeTracker is applied to every source of e instead
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
# if there's ones added after reduce, put this before the reduce
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
|
||||
])
|
||||
|
||||
# Run the view_left rewrite rules over u and return the rewritten UOp; the rewrite pass is named "Sub View Left".
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
|
||||
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
# Bound by the view_right pattern: r is the REDUCE_AXIS, src its input, view the VIEW applied on top of r.
# That pattern binds only r, src and view, so fuse keeps its default of False unless a caller passes it.
# Returns None when nothing is rewritten here, otherwise a new REDUCE_AXIS reshaped to view.shape.
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
# contiguous and same size can push to children
|
||||
# if there's a reduce child, shapes match with ones removed
|
||||
if unwrap(view.st).contiguous and view.size == r.size and \
|
||||
(not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
|
||||
tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
|
||||
return None
|
||||
# swizzle the input
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
# permute so the reduced axes (r.axis_arg) come last; the non-reduced axes keep their relative order
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
# rshape: sizes of the reduced axes (the trailing dims of tmp); prshape: their product
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
||||
strides = strides_for_shape(rshape)
|
||||
# extend every View of view.st with the reduced axes: strides and offset are scaled by prshape,
# rshape is appended with its own strides, and an existing mask is extended with full (0,s) ranges
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
|
||||
new_view = tmp + ShapeTracker(tuple(nv))
|
||||
# apply the combined view to src and push it left through view_left
swizzled_input = apply_swizzle(src.view(new_view))
|
||||
# create a new reduceop
# the reduced axes now sit after all of view.shape's axes
|
||||
new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
|
||||
# with fuse set, the input is wrapped via .fuse() and the arg carries True as its third element (the fuse marker)
if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
|
||||
else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
|
||||
return red.reshape(view.shape)
|
||||
|
||||
# Rewrite REDUCE_AXIS(VIEW(src)) by reducing src directly and reshaping the result to r.shape.
# Bound by the view_right pattern: src is the VIEW's source, v the VIEW, r the REDUCE_AXIS.
# Asserts that the view is contiguous and has the same size as src.
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
# reduce every axis where src.shape and r.shape differ (zip stops at the shorter of the two shapes)
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
|
||||
# same reduce op (r.arg[0]) applied to src over the recomputed axes
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
# Move VIEWs on the sources of root to after root: root is rebuilt on new sources and reshaped back to root.shape.
# Returns None when no source of root is a VIEW whose base op is outside ALWAYS_CONTIGUOUS.
def elementwise_view_right(root:UOp):
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
|
||||
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
||||
# place view after applying the elementwise op
# the target shape is the base shape of the first swizzled source
|
||||
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
|
||||
# sources whose base already has the target shape are used as x.base; the rest are viewed with new_st
# and pushed left via apply_swizzle
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
|
||||
# reshape to match downstream shapes
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
# push VIEW to children
# view_right: merge_views plus rules that move VIEWs from the inputs of an op to after the op.
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
# apply view after reduceops
# (only for VIEWs whose source op is not in ALWAYS_CONTIGUOUS)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
# (matches every op except SINK and REDUCE_AXIS; elementwise_view_right returns None when no source qualifies)
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
# only when both reduces use the same op (identity check on arg[0]); the inner reduce is kept
# with the outer axes followed by its own
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
# add VIEW to any DEFINE_GLOBAL that somehow lost its view
# the STORE's first source d becomes d.view(d.st); its remaining sources are left untouched
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.DEFINE_GLOBAL, name="d"),), name="x", allow_any_len=True), lambda d,x: x.replace(src=(d.view(d.st),)+x.src[1:])),
|
||||
])
|
||||
|
|
@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
|
|||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
|
||||
|
||||
# **** Grouper decides which of the UOps realize
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup,
|
|||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
|
||||
# creation can recurse a lot
|
||||
|
|
@ -148,105 +147,6 @@ create_kernels = PatternMatcher([
|
|||
lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
|
||||
])
|
||||
|
||||
# **** swizzler
|
||||
|
||||
merge_views = PatternMatcher([
|
||||
# merge adjacent views
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# replace MovementOps with VIEW
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
|
||||
# remove NOOP views
|
||||
(UPat.var("x").view(name="view"),
|
||||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmasked VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
])
|
||||
|
||||
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
# contiguous, expand, and the same with ones removed
|
||||
if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
|
||||
tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
|
||||
new_shape: list[sint] = []
|
||||
new_reduce_axis = []
|
||||
if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
|
||||
for i,pairs in enumerate(contraction):
|
||||
new_shape_chunk = [view.shape[p] for p in pairs]
|
||||
if i in r.arg[1]:
|
||||
# if this is a reduce axis, we need a 1 in the view here to put it
|
||||
assert len(new_shape_chunk) > 0
|
||||
new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
|
||||
new_reduce_axis.append(len(new_shape)-1)
|
||||
else:
|
||||
# otherwise, pass through the new_shape_chunk
|
||||
new_shape += new_shape_chunk
|
||||
ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
|
||||
assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
|
||||
return ret
|
||||
return None
|
||||
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before elementwise and buffer ops
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
# if there's ones added after reduce, put this before the reduce
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
|
||||
])
|
||||
|
||||
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
|
||||
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
# contiguous and same size can push to children
|
||||
# if there's a reduce child, shapes match with ones removed
|
||||
if unwrap(view.st).contiguous and view.size == r.size and \
|
||||
(not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
|
||||
tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
|
||||
return None
|
||||
# swizzle the input
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
||||
strides = strides_for_shape(rshape)
|
||||
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
|
||||
new_view = tmp + ShapeTracker(tuple(nv))
|
||||
swizzled_input = apply_swizzle(src.view(new_view))
|
||||
# create a new reduceop
|
||||
new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
|
||||
if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
|
||||
else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
|
||||
return red.reshape(view.shape)
|
||||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
|
||||
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
def elementwise_view_right(root:UOp):
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
|
||||
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
||||
# place view after applying the elementwise op
|
||||
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
|
||||
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
|
||||
# reshape to match downstream shapes
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
# apply view after reduceops
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
])
|
||||
|
||||
# **** fix kernel AST
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
|
|
@ -294,8 +194,6 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
|
||||
# replace global memory ops with the BUFFER they write to
|
||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||
# push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
||||
# replace buffer with define_global + add load/store last
|
||||
bufs = []
|
||||
for s in k.src:
|
||||
|
|
@ -303,7 +201,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
||||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||
bufs.append(s)
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||
ast = graph_rewrite(ast, add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
|
@ -383,7 +281,7 @@ def fuse_arange(root:UOp):
|
|||
return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
|
||||
|
||||
do_fuse = PatternMatcher([
|
||||
(UPat(Ops.FUSE, name="x"), do_fusion),
|
||||
#(UPat(Ops.FUSE, name="x"), do_fusion),
|
||||
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
|
||||
])
|
||||
|
||||
|
|
@ -418,6 +316,12 @@ finalize_contiguous = PatternMatcher([
|
|||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
# new_fixups: move a COPY below a movement op - the movement op's source is copied,
# then the same movement op is re-applied to the copy.
new_fixups = PatternMatcher([
|
||||
# COPY(RESHAPE(x), d) -> COPY(x, d).reshape(r.arg)
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
|
||||
# TODO: this should be BUFFER_VIEW
# COPY(SHRINK(x), d) -> COPY(x, d).shrink(r.arg)
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
|
|
@ -431,7 +335,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
|
@ -441,8 +345,6 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
|
||||
|
||||
# TODO: move view_left/view_right here
|
||||
|
||||
# group into kernels (this is context-free)
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue