Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
4662fb413f works 2026-03-04 13:14:16 +08:00
George Hotz
ccb5dcf3b8 fix 2026-03-04 13:01:35 +08:00
George Hotz
6b82b51759 close 2026-03-04 12:59:17 +08:00
8 changed files with 32 additions and 19 deletions

View file

@ -32,8 +32,8 @@ def apply_after(ctx:AllocCtx, u:UOp):
# CONTIGUOUS and ASSIGN + parents are the only nodes that get updated
add_tags = PatternMatcher([
(UPat(Ops.COPY, name="u"), disk_copy_is_buffer),
# no tag on copies that are assigned
(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.COPY, name="c")), name="a"),
# no tag on copies/allreduces that are assigned
(UPat(Ops.ASSIGN, src=(UPat(), UPat((Ops.COPY, Ops.ALLREDUCE), name="c")), name="a"),
lambda a,c: a.replace(src=(a.src[0], c.rtag(())), tag=a.tag+c.tag) if a.tag and c.tag else None),
(UPat(Ops.AFTER, name="u"), apply_after),
(UPat({Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), tag_uop),

View file

@ -131,8 +131,17 @@ def lower_sink_to_linear(function:UOp) -> UOp|None:
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
return linear
def soft_allreduce(c:UOp, a:UOp):
from tinygrad.schedule.multi import handle_allreduce
to = c.src[1].param_like(0)
src = c.src[2].param_like(1)
red = UOp(Ops.ALLREDUCE, dtype=a.arg, src=(src, a.src[1]), arg=a.arg)
return to.assign(handle_allreduce(src, red)).sink().call(*c.src[1:])
pm_schedule = PatternMatcher([
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),
# soft handler of allreduce
(UPat(Ops.CALL, src=(UPat(Ops.ALLREDUCE, name="a"),), allow_any_len=True, name="c"), soft_allreduce),
])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")

View file

@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
@ -18,8 +18,8 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_assign_src(ctx:dict[UOp, None], buf:UOp, x:UOp):
# don't realize COPY/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
if x.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \
# don't realize COPY/ALLREDUCE/BUFFER_VIEW/ENCDEC when they are the direct source of ASSIGN — the ASSIGN target buffer is the output
if x.op in {Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER_VIEW, Ops.ENCDEC} and x in ctx \
and not buf.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
del ctx[x]
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
@ -29,9 +29,9 @@ pm_generate_realize_map = PatternMatcher([
# always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
(UPat({Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.ASSIGN, Ops.ENCDEC}, name="tr"), realize),
# realize srcs of these
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
(UPat((Ops.COPY, Ops.ALLREDUCE, Ops.MSELECT, Ops.MSTACK, Ops.ENCDEC), name="rb"), realize_srcs),
# sometimes realize src of assign
(UPat(Ops.ASSIGN, src=(UPat.var("buf"), UPat.var("x"))), realize_assign_src),
])
@ -71,8 +71,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
del ctx.realize_map[s]
else:
# the Bufferize before a COPY is not removable. there should be a better way to do this
removable = x.op is not Ops.COPY and s.op not in ALWAYS_CONTIGUOUS
# the Bufferize before a COPY/ALLREDUCE is not removable. there should be a better way to do this
removable = x.op not in {Ops.COPY, Ops.ALLREDUCE} and s.op not in ALWAYS_CONTIGUOUS
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)

View file

@ -71,7 +71,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
return ms.replace(src=tuple(ret))
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
#(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),

View file

@ -119,8 +119,8 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# ** copy rules **
# COPY and source size need to match
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
# COPY/ALLREDUCE and source size need to match
(UPat((Ops.COPY, Ops.ALLREDUCE), src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if r.size != r.base.size else None),
# copy only to different device
@ -153,7 +153,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# *****************
# 3.5 cleanups
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ALLREDUCE, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
def cleanup_dead_axes(b:UOp):
@ -263,6 +263,8 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
# copy on CONST is CONST
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
# allreduce on CONST is CONST
(UPat(Ops.ALLREDUCE, src=(UPat.cvar("x"), UPat()), name="copy", arg=Ops.ADD), lambda copy,x: copy.const_like(x.arg)*len(x.device)),
# hack if a noop turned to a const
(UPat(Ops.NOOP, src=(UPat.cvar("c"),), name="noop"), lambda c,noop: c),
# mstack on CONST is CONST
@ -490,7 +492,7 @@ def split_store(x:UOp) -> UOp|None:
if ret.op is Ops.STORE: stored = ret.src[1]
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
else: raise RuntimeError(f"unknown kernel type {ret.op}")
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
if stored.op in {Ops.COPY, Ops.ALLREDUCE, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
elif stored.op is Ops.ENCDEC: ret = stored
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))

View file

@ -26,7 +26,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.BUFFER_VIEW: 1}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.COPY: 2, Ops.ALLREDUCE: 2, Ops.BUFFER_VIEW: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@ -925,7 +925,7 @@ class CallInfo:
def should_resolve_call(c:UOp) -> bool:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.ALLREDUCE}: return False
return True
# ******** ops in python ********
@ -1537,7 +1537,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
Ops.BUFFER, Ops.COPY, Ops.ALLREDUCE, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
to_render: set[UOp] = {ast}
for u in lst:

View file

@ -209,9 +209,11 @@ kernel_spec = PatternMatcher([
# reduce must be on ranges
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype in (dtypes.index, dtypes.int) for y in x.src[1:])),
# COPY/BUFFER_VIEW can have ranges appended
# COPY/ALLREDUCE/BUFFER_VIEW can have ranges appended
(UPat(Ops.COPY, name="x", src=(UPat.var("s"), UPat(Ops.DEVICE)), allow_any_len=True, arg=None),
lambda x,s: x.dtype == s.dtype and all(u.op is Ops.RANGE for u in x.src[2:])),
(UPat(Ops.ALLREDUCE, name="x", src=(UPat.var("s"), UPat(Ops.DEVICE)), allow_any_len=True),
lambda x,s: x.dtype == s.dtype and isinstance(x.arg, Ops) and all(u.op is Ops.RANGE for u in x.src[2:])),
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX, Ops.LOAD)),), allow_any_len=True, name="x"),
lambda x: all(u.op is Ops.RANGE for u in x.src[1:])),
])+movement_ops+shared_codegen_spec+shared_spec

View file

@ -129,7 +129,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u._shape is not None:
label += f"\n{shape_to_str(u.shape)}"
if u.op is Ops.CALL:
label += f"\n{u.src[0].key.hex()[:8]}"
label += f"\n{u.src[0].key.hex()[:8]} {u.src[0].op}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
if len(u.toposort()) < 30: label += f"\n{u.render()}"
ranges: list[UOp] = []