mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
9 commits
master
...
simpler_fu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3c8f09579 | ||
|
|
3969d8574d |
||
|
|
174747efb3 | ||
|
|
aaec715fa8 | ||
|
|
de7b4b10af |
||
|
|
807392be8b |
||
|
|
27c3c67e7c |
||
|
|
0dbaa6293c |
||
|
|
625ecb9fec |
2 changed files with 19 additions and 73 deletions
|
|
@ -3,11 +3,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
||||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||||
from tinygrad.uop.symbolic import symbolic_simple
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, DEBUG, SPLIT_REDUCEOP, flatten
|
||||||
from tinygrad.dtype import ImageDType
|
from tinygrad.dtype import ImageDType
|
||||||
from tinygrad.schedule.multi import multi_pm
|
from tinygrad.schedule.multi import multi_pm
|
||||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||||
from tinygrad.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
from tinygrad.opt.swizzler import merge_views
|
||||||
|
|
||||||
# creation can recurse a lot
|
# creation can recurse a lot
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -209,75 +209,8 @@ def append_metadata(root:UOp, k:UOp):
|
||||||
|
|
||||||
replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
|
replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
|
||||||
|
|
||||||
pm_fuse = PatternMatcher([
|
|
||||||
# FUSE on CONTIGUOUS removes FUSE
|
|
||||||
(UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),
|
|
||||||
|
|
||||||
# FUSE triggers swizzle on reduceop
|
|
||||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
|
|
||||||
lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),
|
|
||||||
|
|
||||||
# FUSE on reduce (without view) adds fuse marker to grouper
|
|
||||||
(UPat(Ops.REDUCE_AXIS, name="r").fuse(),
|
|
||||||
lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),
|
|
||||||
|
|
||||||
# remove FUSE and insert CONTIGUOUS if it's an unsafe pad
|
|
||||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
|
|
||||||
lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
|
|
||||||
|
|
||||||
# FUSE elementwise.
|
|
||||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
|
|
||||||
lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),
|
|
||||||
|
|
||||||
# push FUSE through to srcs
|
|
||||||
(UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
|
|
||||||
])
|
|
||||||
|
|
||||||
def do_fusion(x:UOp):
|
|
||||||
found_contiguous = {}
|
|
||||||
def gate_contiguous(x):
|
|
||||||
if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
|
|
||||||
return not is_contiguous
|
|
||||||
x.toposort(gate=gate_contiguous)
|
|
||||||
del gate_contiguous
|
|
||||||
return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()})
|
|
||||||
|
|
||||||
def fuse_arange(root:UOp):
|
|
||||||
# skip if root is arange
|
|
||||||
if not FUSE_ARANGE or root.src[0].base.op is Ops.CONST: return None
|
|
||||||
# gather all local aranges (including any fused ones)
|
|
||||||
local_arange: list[UOp] = []
|
|
||||||
def gate_reduce(u):
|
|
||||||
if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
|
|
||||||
return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
|
|
||||||
toposort = root.toposort(gate=gate_reduce)
|
|
||||||
if not local_arange: return None
|
|
||||||
# fuse the nearest expand child of arange
|
|
||||||
local_children: dict[UOp, list[UOp]] = {}
|
|
||||||
for u in toposort:
|
|
||||||
for s in u.src: local_children.setdefault(s, []).append(u)
|
|
||||||
fuse_rep: dict[UOp, UOp] = {}
|
|
||||||
for r in local_arange:
|
|
||||||
# skip if already fused
|
|
||||||
if len(r.arg) > 2: continue
|
|
||||||
q = list(local_children[r])
|
|
||||||
while q:
|
|
||||||
u = q.pop()
|
|
||||||
if not (curr_children:=local_children.get(u, [])): continue
|
|
||||||
for child in curr_children:
|
|
||||||
other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
|
|
||||||
fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
|
|
||||||
if other_paths: break
|
|
||||||
else: q.extend(curr_children)
|
|
||||||
return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
|
|
||||||
|
|
||||||
do_fuse = PatternMatcher([
|
|
||||||
(UPat(Ops.FUSE, name="x"), do_fusion),
|
|
||||||
(UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
|
|
||||||
])
|
|
||||||
|
|
||||||
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
|
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
|
||||||
lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
|
lambda ctx,x: x.replace(tag=1).contiguous(tag=3 if x in ctx[1] else 2) if x in ctx[0] and x.tag is None else None)])
|
||||||
|
|
||||||
# TODO: get this from the device through GrouperOpts
|
# TODO: get this from the device through GrouperOpts
|
||||||
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
|
DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
|
||||||
|
|
@ -301,6 +234,16 @@ def view_add_srcs(x:UOp):
|
||||||
return x.replace(src=x.src+tuple(avars))
|
return x.replace(src=x.src+tuple(avars))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
new_fusion = PatternMatcher([
|
||||||
|
# FUSE removes a CONTIGUOUS with tag=2 and keeps propagating to its source; on any other CONTIGUOUS the FUSE is dropped
|
||||||
|
(UPat(Ops.FUSE, src=(UPat(Ops.CONTIGUOUS, name="c"),)), lambda c: c.src[0].replace(tag=None).fuse() if c.tag == 2 else c),
|
||||||
|
(UPat(Ops.FUSE, src=(UPat(name="s"),)), lambda s: s.replace(src=tuple([y.fuse() for y in s.src]))),
|
||||||
|
# remove CONTIGUOUS if there's no BUFFER upstream
|
||||||
|
(UPat(Ops.CONTIGUOUS, name="c"),
|
||||||
|
lambda c: None if c.tag != 2 or c.src[0].op is Ops.COPY or
|
||||||
|
any(x.op in GroupOp.UnsafePad.union({Ops.BUFFER}) for x in c.toposort()) else c.src[0].replace(tag=None)),
|
||||||
|
])
|
||||||
|
|
||||||
finalize_contiguous = PatternMatcher([
|
finalize_contiguous = PatternMatcher([
|
||||||
# if an op takes more than one input, check combined LOADs don't exceed device limits
|
# if an op takes more than one input, check combined LOADs don't exceed device limits
|
||||||
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
|
(UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
|
||||||
|
|
@ -327,14 +270,17 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# multi + merge_views + simplify
|
# multi + merge_views + simplify
|
||||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
tensor_map = graph_rewrite_map(sink, multi_pm+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||||
|
|
||||||
# display the cleaned up tensor graph
|
# display the cleaned up tensor graph
|
||||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||||
|
|
||||||
# insert contiguous in places determined by the realize map
|
# insert contiguous in places determined by the realize map
|
||||||
|
forced_realize = flatten([x.base.src if x.base.op is Ops.MSTACK else [x.base] for x in tensor_map[sink].src])
|
||||||
realize_map = group_realizes(tensor_map[sink])
|
realize_map = group_realizes(tensor_map[sink])
|
||||||
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
|
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=(realize_map, forced_realize),
|
||||||
|
bottom_up=True, input_map=tensor_map, name="add_contiguous")
|
||||||
|
tensor_map = graph_rewrite_map(tensor_map[sink], new_fusion, input_map=tensor_map, name="new_fusion")
|
||||||
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
|
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
|
||||||
|
|
||||||
# group into kernels (this is context-free)
|
# group into kernels (this is context-free)
|
||||||
|
|
|
||||||
|
|
@ -275,7 +275,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||||
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
|
||||||
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
||||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||||
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
|
def contiguous(self, **kwargs): return self.alu(Ops.CONTIGUOUS, **kwargs)
|
||||||
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
||||||
def fuse(self): return self.alu(Ops.FUSE)
|
def fuse(self): return self.alu(Ops.FUSE)
|
||||||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue