Compare commits

...

7 commits

Author SHA1 Message Date
George Hotz
bb7aa19d67
Merge branch 'master' into replace_if_with_range 2025-10-28 23:02:15 +08:00
George Hotz
117f37ae5f don't remove the gate 2025-10-28 19:28:38 +08:00
George Hotz
ba23b097c1 fix image 2025-10-28 19:22:59 +08:00
George Hotz
3341272771 cleanup patterns 2025-10-28 19:15:01 +08:00
George Hotz
975f5ccc99 tests pass 2025-10-28 19:12:21 +08:00
George Hotz
d0de209ad0 don't brick on that 2025-10-28 18:49:00 +08:00
George Hotz
2d87d89202 replace if with range 2025-10-28 18:30:11 +08:00
7 changed files with 28 additions and 37 deletions

View file

@ -402,7 +402,7 @@ class TestLinearizer(unittest.TestCase):
# # check the children's vins # # check the children's vins
# TODO: src ALU are not the same, should it? # TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores) # assert barrier.src == tuple(local_stores)
assert len([u for u in uops if u.op is Ops.IF]) #assert len([u for u in uops if u.op is Ops.IF])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")

View file

@ -272,6 +272,7 @@ class TestConstantFolding(unittest.TestCase):
si = t.schedule() si = t.schedule()
assert len(si) == 0 assert len(si) == 0
@unittest.skip("no more if statements")
class TestGatedStoreRewrite(unittest.TestCase): class TestGatedStoreRewrite(unittest.TestCase):
def test_tiny_gate_store(self): def test_tiny_gate_store(self):
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)

View file

@ -1,10 +1,8 @@
from typing import cast import itertools
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.helpers import panic
# import all pattern matchers here # import all pattern matchers here
from tinygrad.codegen.quantize import pm_quant from tinygrad.codegen.quantize import pm_quant
@ -17,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
from tinygrad.codegen.opt.postrange import apply_opts from tinygrad.codegen.opt.postrange import apply_opts
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.linearizer import CFGContext, pm_prepare_control_flow, pm_add_control_flow, linearize
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp: def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer() if ren is None: ren = Renderer()
@ -85,35 +83,18 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# final rules for the renderer (without sym) # final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends pm_final_rewrite = pm_decomp+pm_render+extra_matcher
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite") sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
# prepare for control flow
sink = graph_rewrite(sink, pm_prepare_control_flow, ctx=itertools.count(10000), name="split ends + add if ranges")
# this was the linearizer # this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True) sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
# return the rewritten sink # return the rewritten sink
return sink return sink
# inject IF/ENDIF. only needed if device doesn't support gated stores
pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat()),
allow_any_len=True), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
])
# requires lst be toposorted. like graph rewrite, but for lines
def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
newlst = []
replaced: dict[UOp, UOp] = {}
for u in lst:
nu = u.replace(src=tuple([replaced[x] for x in u.src]))
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu)) or (nu, [nu])
replaced[u] = ret[0]
newlst.extend(ret[1])
return newlst
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]: def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
""" """
Function to transform the Kernel UOp graph into a linearized program. Function to transform the Kernel UOp graph into a linearized program.
@ -128,6 +109,6 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None) full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
assert len(full_sink.ranges) == 0, "all ranges must end by the sink" assert len(full_sink.ranges) == 0, "all ranges must end by the sink"
lst = line_rewrite(linearize(full_sink), pm_linearize_cleanups) lst = linearize(full_sink)
if SPEC: type_verify(lst, program_spec) if SPEC: type_verify(lst, program_spec)
return lst return lst

View file

@ -1,6 +1,7 @@
import heapq import heapq
from collections import defaultdict from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat from tinygrad.dtype import dtypes
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, AxisType, GroupOp
def linearize(u:UOp) -> list[UOp]: def linearize(u:UOp) -> list[UOp]:
# this is a toposort with priority # this is a toposort with priority
@ -75,7 +76,11 @@ def do_split_ends(e:UOp):
for r in list(UOp.sink(*e.src[1:]).ranges)[::-1]: ret = ret.end(r) for r in list(UOp.sink(*e.src[1:]).ranges)[::-1]: ret = ret.end(r)
return ret return ret
pm_split_ends = PatternMatcher([ pm_prepare_control_flow = PatternMatcher([
# split the ends # split the ends
(UPat(Ops.END, name="e"), do_split_ends), (UPat(Ops.END, name="e"), do_split_ends),
# add if ranges
(UPat(GroupOp.Defines, name="buf").index(UPat.var("idx"), UPat(name="gate", dtype=dtypes.bool)).or_casted("cast").store(UPat.var("val")),
lambda ctx,buf,idx,gate,cast,val:
buf.after(r:=UOp.range(gate.cast(dtypes.int), next(ctx), AxisType.IF, dtype=dtypes.int)).index(idx, gate).cast(cast.dtype).store(val).end(r)),
]) ])

View file

@ -37,6 +37,8 @@ class Estimates:
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort()) if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
elif u.op is Ops.IF: elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort()) dont_count = dont_count.union(u.src[0].toposort())
elif u.op is Ops.RANGE:
dont_count = dont_count.union(u.src[0].toposort())
for u in uops: for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}: if u.op in {Ops.LOAD, Ops.STORE}:
buf = u buf = u
@ -45,18 +47,18 @@ class Estimates:
mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize
if u.op is Ops.RANGE: if u.op is Ops.RANGE:
mult_stack.append(mults) mult_stack.append(mults)
mults *= cast(sint, u.src[0].ssimplify()) mults = cast(sint, (mults*u.src[0]).ssimplify())
# SPECIAL are already counted in mults # SPECIAL are already counted in mults
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.END: mults = mult_stack.pop(-1) elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these elif u.op is Ops.SPECIAL: mults = cast(sint, (mults*u.src[0]).ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults lds += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values())) return Estimates(ssimplify(flops), ssimplify(lds), sum(mem.values()))
@dataclass @dataclass
class ProgramSpec: class ProgramSpec:

View file

@ -157,7 +157,8 @@ class CStyleLanguage(Renderer):
# mark buffers that we store to writable # mark buffers that we store to writable
if u.op is Ops.STORE: if u.op is Ops.STORE:
for up in u.src[0].toposort(): # NOTE: we gate on RANGE to not follow it back
for up in u.src[0].toposort(lambda x: x.op is not Ops.RANGE):
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True)) if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
# naming # naming

View file

@ -15,11 +15,12 @@ if TYPE_CHECKING:
class AxisType(Enum): class AxisType(Enum):
def __repr__(self): return str(self) def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702 GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto() THREAD = auto(); IF = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u", axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"} AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.IF: "I"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"} AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
AxisType.IF: "green"}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1} range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}