mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
late_allre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4662fb413f | ||
|
|
ccb5dcf3b8 | ||
|
|
6b82b51759 |
8 changed files with 32 additions and 19 deletions
|
|
@ -32,8 +32,8 @@ def apply_after(ctx:AllocCtx, u:UOp):
|
|||
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
|
||||
add_tags = PatternMatcher([
|
||||
(UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
|
||||
# no tag on copies that are assigned
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
|
||||
# no tag on copies/allreduces that are assigned
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat((Ops.COPY, Ops.ALLREDUCE), name="c")), name="a"),
|
||||
lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
|
||||
(UPat(Ops.AFTER, name="u"), apply_after),
|
||||
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),
|
||||
|
|
|
|||
|
|
@ -131,8 +131,17 @@ def lower_sink_to_linear(function:UOp) -> UOp|None:
|
|||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||
return linear
|
||||
|
||||
def soft_allreduce(c:UOp, a:UOp):
|
||||
from tinygrad.schedule.multi import handle_allreduce
|
||||
to = c.src[1].param_like(0)
|
||||
src = c.src[2].param_like(1)
|
||||
red = UOp(Ops.ALLREDUCE, dtype=a.arg, src=(src, a.src[1]), arg=a.arg)
|
||||
return to.assign(handle_allreduce(src, red)).sink().call(*c.src[1:])
|
||||
|
||||
pm_schedule = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),
|
||||
# soft handler of allreduce
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.ALLREDUCE, name="a"),), allow_any_len=True, name="c"), soft_allreduce),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
|||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
|
||||
|
||||
|
|
@ -18,8 +18,8 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
|
||||
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
|
||||
if x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \
|
||||
# don't realize COPY/ALLREDUCE/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
|
||||
if x.op in {Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \
|
||||
and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
|
||||
del ctx[x]
|
||||
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
|
||||
|
|
@ -29,9 +29,9 @@ pm_generate_realize_map = PatternMatcher([
|
|||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize
|
||||
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
|
||||
(UPat({Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
|
||||
# realize srcs of these
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
|
||||
(UPat((Ops.COPY, Ops.ALLREDUCE, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
|
||||
# sometimes realize src of assign
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
|
||||
])
|
||||
|
|
@ -71,8 +71,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
|
||||
del ctx.realize_map[s]
|
||||
else:
|
||||
# the Bufferize before a COPY is not removable. there should be a better way to do this
|
||||
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
|
||||
# the Bufferize before a COPY/ALLREDUCE is not removable. there should be a better way to do this
|
||||
removable = x.op not in {Ops.COPY, Ops.ALLREDUCE} and s.op not in ALWAYS_CONTIGUOUS
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
|
||||
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
|
|||
return ms.replace(src=tuple(ret))
|
||||
|
||||
replace_allreduce = PatternMatcher([
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
|
||||
#(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
|
||||
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
|
||||
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
|
||||
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
|
||||
|
|
|
|||
|
|
@ -119,8 +119,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
|
||||
# ** copy rules **
|
||||
|
||||
# COPY and source size need to match
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
|
||||
# COPY/ALLREDUCE and source size need to match
|
||||
(UPat((Ops.COPY, Ops.ALLREDUCE), src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
|
||||
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
|
||||
|
||||
# copy only to different device
|
||||
|
|
@ -153,7 +153,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
# *****************
|
||||
# 3.5 cleanups
|
||||
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ALLREDUCE, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
|
||||
|
||||
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
|
||||
def cleanup_dead_axes(b:UOp):
|
||||
|
|
@ -263,6 +263,8 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
|
||||
# copy on CONST is CONST
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
|
||||
# allreduce on CONST is CONST
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.cvar("x"), UPat()), name="copy", arg=Ops.ADD), lambda copy,x: copy.const_like(x.arg)*len(x.device)),
|
||||
# hack if a noop turned to a const
|
||||
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
|
||||
# mstack on CONST is CONST
|
||||
|
|
@ -490,7 +492,7 @@ def split_store(x:UOp) -> UOp|None:
|
|||
if ret.op is Ops.STORE: stored = ret.src[1]
|
||||
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
|
||||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
||||
if stored.op in {Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
||||
elif stored.op is Ops.ENCDEC: ret = stored
|
||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
|
|||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.ALLREDUCE: 2, Ops.BUFFER_VIEW: 1}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
|
@ -925,7 +925,7 @@ class CallInfo:
|
|||
def should_resolve_call(c:UOp) -> bool:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
|
||||
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False
|
||||
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.ALLREDUCE}: return False
|
||||
return True
|
||||
|
||||
# ******** ops in python ********
|
||||
|
|
@ -1537,7 +1537,7 @@ def pyrender(ast:UOp) -> str:
|
|||
cmap = consumer_map_from_toposort(lst)
|
||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
Ops.BUFFER, Ops.COPY, Ops.ALLREDUCE, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
|
||||
to_render: set[UOp] = {ast}
|
||||
for u in lst:
|
||||
|
|
|
|||
|
|
@ -209,9 +209,11 @@ kernel_spec = PatternMatcher([
|
|||
# reduce must be on ranges
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
|
||||
|
||||
# COPY/BUFFER_VIEW can have ranges appended
|
||||
# COPY/ALLREDUCE/BUFFER_VIEW can have ranges appended
|
||||
(UPat(Ops.COPY, name="x", src=(UPat.var("s"), UPat(Ops.DEVICE)), allow_any_len=True, arg=None),
|
||||
lambda x,s: x.dtype == s.dtype and all(u.op is Ops.RANGE for u in x.src[2:])),
|
||||
(UPat(Ops.ALLREDUCE, name="x", src=(UPat.var("s"), UPat(Ops.DEVICE)), allow_any_len=True),
|
||||
lambda x,s: x.dtype == s.dtype and isinstance(x.arg, Ops) and all(u.op is Ops.RANGE for u in x.src[2:])),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),), allow_any_len=True, name="x"),
|
||||
lambda x: all(u.op is Ops.RANGE for u in x.src[1:])),
|
||||
])+movement_ops+shared_codegen_spec+shared_spec
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
if u._shape is not None:
|
||||
label += f"\n{shape_to_str(u.shape)}"
|
||||
if u.op is Ops.CALL:
|
||||
label += f"\n{u.src[0].key.hex()[:8]}"
|
||||
label += f"\n{u.src[0].key.hex()[:8]} {u.src[0].op}"
|
||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||
if len(u.toposort()) < 30: label += f"\n{u.render()}"
|
||||
ranges: list[UOp] = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue