Compare commits

...

7 commits

Author SHA1 Message Date
George Hotz
bb7aa19d67
Merge branch 'master' into replace_if_with_range 2025-10-28 23:02:15 +08:00
George Hotz
117f37ae5f don't remove the gate 2025-10-28 19:28:38 +08:00
George Hotz
ba23b097c1 fix image 2025-10-28 19:22:59 +08:00
George Hotz
3341272771 cleanup patterns 2025-10-28 19:15:01 +08:00
George Hotz
975f5ccc99 tests pass 2025-10-28 19:12:21 +08:00
George Hotz
d0de209ad0 don't brick on that 2025-10-28 18:49:00 +08:00
George Hotz
2d87d89202 replace if with range 2025-10-28 18:30:11 +08:00
7 changed files with 28 additions and 37 deletions

View file

@ -402,7 +402,7 @@ class TestLinearizer(unittest.TestCase):
# # check the children's vins
# TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores)
assert len([u for u in uops if u.op is Ops.IF])
#assert len([u for u in uops if u.op is Ops.IF])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")

View file

@ -272,6 +272,7 @@ class TestConstantFolding(unittest.TestCase):
si = t.schedule()
assert len(si) == 0
@unittest.skip("no more if statements")
class TestGatedStoreRewrite(unittest.TestCase):
def test_tiny_gate_store(self):
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)

View file

@ -1,10 +1,8 @@
from typing import cast
import itertools
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.helpers import panic
# import all pattern matchers here
from tinygrad.codegen.quantize import pm_quant
@ -17,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.codegen.late.linearizer import CFGContext, pm_prepare_control_flow, pm_add_control_flow, linearize
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
@ -85,35 +83,18 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
# prepare for control flow
sink = graph_rewrite(sink, pm_prepare_control_flow, ctx=itertools.count(10000), name="split ends + add if ranges")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# return the rewritten sink
return sink
# inject IF/ENDIF. only needed if device doesn't support gated stores
pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat()),
allow_any_len=True), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
])
# requires lst be toposorted. like graph rewrite, but for lines
def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
newlst = []
replaced: dict[UOp, UOp] = {}
for u in lst:
nu = u.replace(src=tuple([replaced[x] for x in u.src]))
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu)) or (nu, [nu])
replaced[u] = ret[0]
newlst.extend(ret[1])
return newlst
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
"""
Function to transform the Kernel UOp graph into a linearized program.
@ -128,6 +109,6 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
assert len(full_sink.ranges) == 0, "all ranges must end by the sink"
lst = line_rewrite(linearize(full_sink), pm_linearize_cleanups)
lst = linearize(full_sink)
if SPEC: type_verify(lst, program_spec)
return lst

View file

@ -1,6 +1,7 @@
import heapq
from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, AxisType, GroupOp
def linearize(u:UOp) -> list[UOp]:
# this is a toposort with priority
@ -75,7 +76,11 @@ def do_split_ends(e:UOp):
for r in list(UOp.sink(*e.src[1:]).ranges)[::-1]: ret = ret.end(r)
return ret
pm_split_ends = PatternMatcher([
pm_prepare_control_flow = PatternMatcher([
# split the ends
(UPat(Ops.END, name="e"), do_split_ends),
# add if ranges
(UPat(GroupOp.Defines, name="buf").index(UPat.var("idx"), UPat(name="gate", dtype=dtypes.bool)).or_casted("cast").store(UPat.var("val")),
lambda ctx,buf,idx,gate,cast,val:
buf.after(r:=UOp.range(gate.cast(dtypes.int), next(ctx), AxisType.IF, dtype=dtypes.int)).index(idx, gate).cast(cast.dtype).store(val).end(r)),
])

View file

@ -37,6 +37,8 @@ class Estimates:
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort())
elif u.op is Ops.RANGE:
dont_count = dont_count.union(u.src[0].toposort())
for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}:
buf = u
@ -45,18 +47,18 @@ class Estimates:
mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= cast(sint, u.src[0].ssimplify())
mults = cast(sint, (mults*u.src[0]).ssimplify())
# SPECIAL are already counted in mults
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.SPECIAL: mults = cast(sint, (mults*u.src[0]).ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))
return Estimates(ssimplify(flops), ssimplify(lds), sum(mem.values()))
@dataclass
class ProgramSpec:

View file

@ -157,7 +157,8 @@ class CStyleLanguage(Renderer):
# mark buffers that we store to writable
if u.op is Ops.STORE:
for up in u.src[0].toposort():
# NOTE: we gate on RANGE to not follow it back
for up in u.src[0].toposort(lambda x: x.op is not Ops.RANGE):
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
# naming

View file

@ -15,11 +15,12 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto()
THREAD = auto(); IF = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.IF: "I"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
AxisType.IF: "green"}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}