flip Ops.END srcs (#12882)

* flip Ops.END srcs

* backward

* late end split
This commit is contained in:
George Hotz 2025-10-23 12:47:50 +08:00 committed by GitHub
commit e85cee0aad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 44 additions and 43 deletions

View file

@ -17,7 +17,7 @@ from tinygrad.renderer.ptx import PTXRenderer
def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
sink = UOp.group(*u)
for r in sink.ranges: sink = r.end(sink)
for r in sink.ranges: sink = sink.end(r)
# we strip the SINK here for legacy reasons
ret = full_rewrite(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
assert ret[-1].op is Ops.SINK

View file

@ -14,7 +14,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_reduce_simplify, pm_flatten_range, pm_split_ranges
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
from tinygrad.codegen.late.control_flow import CFGContext, pm_add_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.control_flow import CFGContext, pm_add_ends, pm_split_ends, pm_add_control_flow, linearize
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
@ -79,6 +79,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
# this was the linearizer
sink = graph_rewrite(sink, pm_split_ends, name="split ends of ranges")
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# return the rewritten sink

View file

@ -1,4 +1,4 @@
import heapq, functools
import heapq
from typing import cast
from collections import defaultdict
from tinygrad.dtype import dtypes
@ -81,7 +81,7 @@ class CFGContext:
for s in u.src: deps[u] |= deps[s]
if u.op in (Ops.END, Ops.SINK):
nesting |= {x:u for x in deps[u] if x.op is Ops.END and (u.op is Ops.SINK or u.src[0] in deps[x]) and x not in nesting}
nesting |= {x:u for x in deps[u] if x.op is Ops.END and (u.op is Ops.SINK or u.src[1] in deps[x]) and x not in nesting}
if u.op in (Ops.RANGE, Ops.END): deps[u][u] = None
self.edges: dict[UOp, UOp] = {}
@ -90,18 +90,23 @@ class CFGContext:
for k,v in siblings.items():
# range/if that have dependencies on other siblings need to run after them
order = sorted(v, key=lambda x: len([u for u in v if u in deps[x]]))
zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[0]] + order, order)
zipped = zip(order, order[1:]) if k.op is Ops.SINK else zip([k.src[1]] + order, order)
for x,y in zipped:
# TODO: is this check correct?
if y.src[0] not in x.backward_slice_with_self:
self.edges[y.src[0]] = x
if y.src[1] not in x.backward_slice_with_self:
self.edges[y.src[1]] = x
pm_add_control_flow = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(src=x.src+(y,)) if (y:=ctx.edges.get(x)) is not None else None),
])
pm_split_ends = PatternMatcher([
# split the ends
(UPat(Ops.END, name="e"), lambda e: e.src[0].end(e.src[-1]).end(*e.src[1:-1]) if len(e.src) > 2 else None),
])
# NOTE: this can be done whenever
pm_add_ends = PatternMatcher([
# put the end on the store
(UPat(Ops.STORE, name="s"), lambda s:
functools.reduce(lambda x,y: y.end(x), [x for x in s.src[2:] if x.op is Ops.RANGE][::-1], s.replace(src=s.src[:2]))),
])
(UPat(Ops.STORE, name="s"), lambda s: s.replace(src=s.src[:2]).end(*[x for x in s.src[2:] if x.op is Ops.RANGE])),
])

View file

@ -106,15 +106,15 @@ base_rewrite = PatternMatcher([
f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
# range
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
f" br label %loop_entry_{range_str(x)}\nloop_entry_{range_str(x)}:\n"
f" br label %loop_body_{range_str(x)}\nloop_body_{range_str(x)}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{range_str(x)} ], [ {ctx[x]}phi, %loop_latch_{range_str(x)} ]"),
(UPat(Ops.END, name="x"), lambda ctx,x:
f" br label %loop_latch_{range_str(x.src[0])}\nloop_latch_{range_str(x.src[0])}:\n"
f" {ctx[x.src[0]]}phi = add {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, 1\n"
f" {ctx[x]} = icmp ult {ldt(x.src[0].dtype)} {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
f" br i1 {ctx[x]}, label %loop_body_{range_str(x.src[0])}, label %loop_exit_{range_str(x.src[0])}\nloop_exit_{range_str(x.src[0])}:"),
(UPat(Ops.RANGE, name="r"), lambda ctx,r:
f" br label %loop_entry_{range_str(r)}\nloop_entry_{range_str(r)}:\n"
f" br label %loop_body_{range_str(r)}\nloop_body_{range_str(r)}:\n"
f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_latch_{range_str(r)} ]"),
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r")), name="x"), lambda ctx,x,r:
f" br label %loop_latch_{range_str(r)}\nloop_latch_{range_str(r)}:\n"
f" {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n"
f" {ctx[x]} = icmp ult {ldt(r.dtype)} {ctx[r]}phi, {ctx[r.src[0]]}\n"
f" br i1 {ctx[x]}, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\nloop_exit_{range_str(r)}:"),
# if
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),

View file

@ -187,8 +187,9 @@ class NIRRenderer(Renderer):
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
elif u.op == Ops.END:
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[u.src[0]], nimm(self.b, 1, u.src[0].dtype)), self.r[u.src[0].src[0]]),
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, u.src[0].dtype), lambda: njump(self.b, mesa.nir_jump_break))
r = u.src[1]
nif(self.b, nalu(self.b, "ilt", x:=nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)), self.r[r.src[0]]),
functools.partial(nstore, self.b, AddrSpace.REG, ranges.pop(), x, r.dtype), lambda: njump(self.b, mesa.nir_jump_break))
mesa.nir_pop_loop(self.b, None)
else:
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")

View file

@ -119,11 +119,11 @@ string_rewrite = PatternMatcher([
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
# simple
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
(UPat(Ops.END, name="x", src=(UPat.var("src0"),), allow_any_len=True), lambda ctx, x, src0: [
ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [f"mov.u32 {ctx.r[r]}, 0;", "LOOP_" + f"{ctx.r[r][1:]}:"]),
(UPat(Ops.END, name="x", src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda ctx, x, r: [
ctx.code_for_op[Ops.ADD](ctx.r[r], ctx.r[r], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[r], ctx.r[r.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[r][1:]};"]),
(UPat(Ops.DEFINE_LOCAL, name="x"),
lambda ctx, x: [f".shared .align 16 .b8 local{x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, local{x.arg}[0];"]),
(UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),

View file

@ -57,8 +57,8 @@ class PythonProgram:
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is Ops.END:
loop_ends[idp[0]] = i
i = idp[0]
loop_ends[idp[1]] = i
i = idp[1]
continue
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
# in the python emulator, the warp is always in sync

View file

@ -17,7 +17,7 @@ class AxisType(Enum):
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto()
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@ -270,13 +270,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def ended_ranges(self):
# copy of range_start
match self.op:
case Ops.REDUCE | Ops.BUFFERIZE: return self.src[1:]
case Ops.STORE: return self.src[2:]
case Ops.WMMA: return self.src[3:]
case Ops.END: return self.src[:1]
case _: return ()
if self.op in range_start: return self.src[range_start[self.op]:]
return ()
# determine what ranges this is in
@recursive_property
@ -359,7 +354,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self,)+src, **kwargs)
def end(self, *src:UOp):
assert self.op is Ops.RANGE, "end only ends ranges"
if len(src) == 0: return self
assert all(x.op is Ops.RANGE for x in src), "end only ends ranges"
return UOp(Ops.END, src=(self,)+src)
def after(self, *src:UOp): return UOp(Ops.AFTER, self.dtype, (self,)+src)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
@ -1188,10 +1184,8 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
(UPat((Ops.STORE, Ops.LOAD), src=(UPat(), UPat(), UPat().cast(dtypes.index)), allow_any_len=True, name="s"),
lambda s: s.replace(src=s.src[:2]+tuple(u.src[0] for u in s.src[2:]))),
# TODO: this is only triggering if they are all casts, correct?
(UPat((Ops.SINK, Ops.NOOP, Ops.END), src=UPat().cast(dtypes.index), name="n"), lambda n: n.replace(src=tuple(s.src[0] for s in n.src))),
# no CAST on END
(UPat(Ops.END, src=(UPat(Ops.CAST),), allow_any_len=True, name="e"), lambda e: e.replace(src=(e.src[0].src[0],)+e.src[1:])),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.index else s for s in n.src))),
])
def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]

View file

@ -121,7 +121,7 @@ program_spec = PatternMatcher([
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.END, src=(UPat(Ops.RANGE), UPat()), dtype=dtypes.void), lambda: True),
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
# make sure all index dtypes have been lowered
(UPat(GroupOp.All, dtype=dtypes.index), lambda: False),

View file

@ -83,7 +83,7 @@ def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
label += f"\n{u.render()}"
if u.op is Ops.END:
label += f"\n{colored(u.src[0].arg[0], axis_colors[u.src[0].arg[-1]])}({u.src[0].vmax+1})"
label += "\n"+' '.join([f"{colored(s.arg[0], axis_colors[s.arg[-1]])}({s.vmax+1})" for s in u.src[1:]])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"