mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
7 commits
master
...
replace_if
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb7aa19d67 |
||
|
|
117f37ae5f | ||
|
|
ba23b097c1 | ||
|
|
3341272771 | ||
|
|
975f5ccc99 | ||
|
|
d0de209ad0 | ||
|
|
2d87d89202 |
7 changed files with 28 additions and 37 deletions
|
|
@ -402,7 +402,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
# # check the children's vins
|
||||
# TODO: src ALU are not the same, should it?
|
||||
# assert barrier.src == tuple(local_stores)
|
||||
assert len([u for u in uops if u.op is Ops.IF])
|
||||
#assert len([u for u in uops if u.op is Ops.IF])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
|
|
|
|||
|
|
@ -272,6 +272,7 @@ class TestConstantFolding(unittest.TestCase):
|
|||
si = t.schedule()
|
||||
assert len(si) == 0
|
||||
|
||||
@unittest.skip("no more if statements")
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import panic
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.quantize import pm_quant
|
||||
|
|
@ -17,7 +15,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
|||
from tinygrad.codegen.opt.postrange import apply_opts
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_prepare_control_flow, pm_add_control_flow, linearize
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
|
|
@ -85,35 +83,18 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
|||
|
||||
# final rules for the renderer (without sym)
|
||||
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren.device, name="final rewrite")
|
||||
|
||||
# prepare for control flow
|
||||
sink = graph_rewrite(sink, pm_prepare_control_flow, ctx=itertools.count(10000), name="split ends + add if ranges")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
# return the rewritten sink
|
||||
return sink
|
||||
|
||||
# inject IF/ENDIF. only needed if device doesn't support gated stores
|
||||
pm_linearize_cleanups = PatternMatcher([
|
||||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
|
||||
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat()),
|
||||
allow_any_len=True), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
# requires lst be toposorted. like graph rewrite, but for lines
|
||||
def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]:
|
||||
newlst = []
|
||||
replaced: dict[UOp, UOp] = {}
|
||||
for u in lst:
|
||||
nu = u.replace(src=tuple([replaced[x] for x in u.src]))
|
||||
ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu)) or (nu, [nu])
|
||||
replaced[u] = ret[0]
|
||||
newlst.extend(ret[1])
|
||||
return newlst
|
||||
|
||||
def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
|
||||
"""
|
||||
Function to transform the Kernel UOp graph into a linearized program.
|
||||
|
|
@ -128,6 +109,6 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
|
|||
|
||||
full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
|
||||
assert len(full_sink.ranges) == 0, "all ranges must end by the sink"
|
||||
lst = line_rewrite(linearize(full_sink), pm_linearize_cleanups)
|
||||
lst = linearize(full_sink)
|
||||
if SPEC: type_verify(lst, program_spec)
|
||||
return lst
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import heapq
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, AxisType, GroupOp
|
||||
|
||||
def linearize(u:UOp) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
|
|
@ -75,7 +76,11 @@ def do_split_ends(e:UOp):
|
|||
for r in list(UOp.sink(*e.src[1:]).ranges)[::-1]: ret = ret.end(r)
|
||||
return ret
|
||||
|
||||
pm_split_ends = PatternMatcher([
|
||||
pm_prepare_control_flow = PatternMatcher([
|
||||
# split the ends
|
||||
(UPat(Ops.END, name="e"), do_split_ends),
|
||||
# add if ranges
|
||||
(UPat(GroupOp.Defines, name="buf").index(UPat.var("idx"), UPat(name="gate", dtype=dtypes.bool)).or_casted("cast").store(UPat.var("val")),
|
||||
lambda ctx,buf,idx,gate,cast,val:
|
||||
buf.after(r:=UOp.range(gate.cast(dtypes.int), next(ctx), AxisType.IF, dtype=dtypes.int)).index(idx, gate).cast(cast.dtype).store(val).end(r)),
|
||||
])
|
||||
|
|
@ -37,6 +37,8 @@ class Estimates:
|
|||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort())
|
||||
elif u.op is Ops.IF:
|
||||
dont_count = dont_count.union(u.src[0].toposort())
|
||||
elif u.op is Ops.RANGE:
|
||||
dont_count = dont_count.union(u.src[0].toposort())
|
||||
for u in uops:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
buf = u
|
||||
|
|
@ -45,18 +47,18 @@ class Estimates:
|
|||
mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= cast(sint, u.src[0].ssimplify())
|
||||
mults = cast(sint, (mults*u.src[0]).ssimplify())
|
||||
# SPECIAL are already counted in mults
|
||||
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
||||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.SPECIAL: mults = cast(sint, (mults*u.src[0]).ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
return Estimates(ssimplify(flops), ssimplify(lds), sum(mem.values()))
|
||||
|
||||
@dataclass
|
||||
class ProgramSpec:
|
||||
|
|
|
|||
|
|
@ -157,7 +157,8 @@ class CStyleLanguage(Renderer):
|
|||
|
||||
# mark buffers that we store to writable
|
||||
if u.op is Ops.STORE:
|
||||
for up in u.src[0].toposort():
|
||||
# NOTE: we gate on RANGE to not follow it back
|
||||
for up in u.src[0].toposort(lambda x: x.op is not Ops.RANGE):
|
||||
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
|
||||
|
||||
# naming
|
||||
|
|
|
|||
|
|
@ -15,11 +15,12 @@ if TYPE_CHECKING:
|
|||
class AxisType(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto()
|
||||
THREAD = auto(); IF = auto() # noqa: E702
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.IF: "I"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
|
||||
AxisType.IF: "green"}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue