mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rename BUFFERIZE to STAGE (#16125)
This commit is contained in:
parent
39ce780907
commit
daed602569
12 changed files with 38 additions and 38 deletions
2
test/external/external_benchmark_pyrender.py
vendored
2
test/external/external_benchmark_pyrender.py
vendored
|
|
@ -3,7 +3,7 @@ import functools, pickle
|
|||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.helpers import tqdm, temp, time_to_str, cpu_profile
|
||||
|
||||
BENCHMARK_OPS = {Ops.INDEX, Ops.BUFFERIZE}
|
||||
BENCHMARK_OPS = {Ops.INDEX, Ops.STAGE}
|
||||
|
||||
@functools.cache
|
||||
def create_uop(a:int) -> UOp:
|
||||
|
|
|
|||
|
|
@ -799,12 +799,12 @@ class TestConstBufferize(unittest.TestCase):
|
|||
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
|
||||
c = UOp.const(dtypes.float, 42.0)
|
||||
r1 = UOp.range(3, 0)
|
||||
bufferize_with_range = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
|
||||
bufferize_with_range = UOp(Ops.STAGE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
|
||||
self.assertEqual(len(bufferize_with_range.src), 2) # const + 1 range
|
||||
|
||||
result = graph_rewrite(bufferize_with_range, pm_const_buffer_folding, name='test')
|
||||
# STAGE (formerly BUFFERIZE) should be removed, result is const broadcast to shape
|
||||
self.assertNotEqual(result.op, Ops.BUFFERIZE)
|
||||
self.assertNotEqual(result.op, Ops.STAGE)
|
||||
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
|
||||
self.assertIn(42.0, const_vals)
|
||||
|
||||
|
|
@ -814,12 +814,12 @@ class TestConstBufferize(unittest.TestCase):
|
|||
c = UOp.const(dtypes.float, 3.14)
|
||||
r1 = UOp.range(3, 0)
|
||||
r2 = UOp.range(4, 1)
|
||||
bufferize_with_ranges = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
|
||||
bufferize_with_ranges = UOp(Ops.STAGE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
|
||||
self.assertEqual(len(bufferize_with_ranges.src), 3) # const + 2 ranges
|
||||
|
||||
result = graph_rewrite(bufferize_with_ranges, pm_const_buffer_folding, name='test')
|
||||
# STAGE (formerly BUFFERIZE) should be removed
|
||||
self.assertNotEqual(result.op, Ops.BUFFERIZE)
|
||||
self.assertNotEqual(result.op, Ops.STAGE)
|
||||
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
|
||||
self.assertIn(3.14, const_vals)
|
||||
|
||||
|
|
|
|||
|
|
@ -98,13 +98,13 @@ expander = PatternMatcher([
|
|||
# END on UNROLL ends the UNROLL
|
||||
(UPat(Ops.END, name="u"), end_unrolls),
|
||||
# STAGE puts UNROLLs for ranges as contract
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
|
||||
(UPat(Ops.STAGE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
|
||||
lambda x: x.replace(src=tuple(UOp(Ops.CONTRACT, dtype=s.dtype.vec(x.src[1].src[0].dtype.count), src=(s,), arg=x.src[1].arg) for s in x.src))),
|
||||
# double expand
|
||||
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
|
||||
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
|
||||
# do expansion
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.STAGE,
|
||||
Ops.STACK, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# empty UNROLL is NOOP
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class Scheduler:
|
|||
ret = [r for r in self._output_rngs() if r.arg[-1] == AxisType.LOOP]
|
||||
# exclude any output ranges from global that don't appear in all STAGE ops
|
||||
for x in self.ast.toposort():
|
||||
if x.op is Ops.BUFFERIZE:
|
||||
if x.op is Ops.STAGE:
|
||||
ret = [r for r in ret if r in x.ranges]
|
||||
return ret
|
||||
|
||||
|
|
@ -347,6 +347,6 @@ def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
|
|||
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
|
||||
if not any(u.op is Ops.BUFFERIZE for u in ast.backward_slice):
|
||||
if not any(u.op is Ops.STAGE for u in ast.backward_slice):
|
||||
k = hand_coded_optimizations(k)
|
||||
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class IndexingContext:
|
|||
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.weakint, 0)
|
||||
|
||||
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
||||
if x.op in {Ops.STAGE, Ops.INDEX}: return None
|
||||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
|
|
@ -74,7 +74,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
|
||||
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
|
||||
new_src = UOp(Ops.STAGE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
new_srcs.append(new_src)
|
||||
# NOTE: do we need this?
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
|||
indexes: list[UOp] = []
|
||||
reduces: list[UOp] = []
|
||||
def red_gate(x:UOp):
|
||||
if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
|
||||
if (x.op is Ops.STAGE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
|
||||
accessed_buffers.append(x)
|
||||
return False
|
||||
if x.op is Ops.STORE:
|
||||
|
|
@ -269,7 +269,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
|||
buffer_in_reduce = False
|
||||
def buf_gate(x:UOp):
|
||||
nonlocal buffer_in_reduce
|
||||
if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True
|
||||
if x.op in {Ops.PARAM, Ops.STAGE}: buffer_in_reduce = True
|
||||
return not buffer_in_reduce
|
||||
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
|
||||
del buf_gate
|
||||
|
|
@ -278,7 +278,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
|||
out_in_ratio = (prod(buf.shape)+1) / (sum([x.numel() for x in accessed_buffers])+1)
|
||||
if out_in_ratio < 10: return None
|
||||
# here we have to check the indexes, we might do a partial contig here
|
||||
local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
|
||||
local_indexes = [x for x in indexes if x.src[0].op is Ops.STAGE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
|
||||
exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
|
||||
subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
|
||||
# if it's bufferized or a reduce, it's pcontig
|
||||
|
|
@ -302,11 +302,11 @@ def remove_noop_bufferize(idx,b2):
|
|||
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
|
||||
|
||||
pm_const_buffer_folding = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
|
||||
(UPat(Ops.STAGE, name="b"), cleanup_dead_axes),
|
||||
# remove noop buffers. if we look at the next index we can remove even more of these
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.STAGE, allow_any_len=True, name="b2"), remove_noop_bufferize),
|
||||
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
|
||||
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
|
||||
(UPat(Ops.CONST, name='c').f(Ops.STAGE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
|
||||
# indexing a const is a const
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
|
||||
# copy on CONST is CONST
|
||||
|
|
@ -320,7 +320,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
|||
|
||||
pm_remove_bufferize = PatternMatcher([
|
||||
# remove reindexing with cost function
|
||||
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
(UPat.var("src").f(Ops.STAGE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
# STORE to self is NOOP
|
||||
(UPat.var("x").store(UPat.var("x")), lambda x: UOp(Ops.NOOP)),
|
||||
# END on NOOP is NOOP
|
||||
|
|
@ -345,7 +345,7 @@ def late_buffer_view(t:UOp, b:UOp):
|
|||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset)), b.src[1]))
|
||||
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
])
|
||||
|
||||
DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
|
||||
|
|
@ -357,7 +357,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
|
|||
bufs: set[UOp] = set()
|
||||
def gate_input(u:UOp):
|
||||
# TODO: add cache to fix n^2
|
||||
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
|
||||
if is_load:=(u.op in {Ops.STAGE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
|
||||
return not is_load
|
||||
root.toposort(gate=gate_input)
|
||||
|
||||
|
|
@ -394,7 +394,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
|
|||
ended_stores = []
|
||||
for store in stores:
|
||||
store_target = store.src[0]
|
||||
if store_target.src[0].op is Ops.BUFFERIZE and store_target.src[0].src[0].op is Ops.INDEX:
|
||||
if store_target.src[0].op is Ops.STAGE and store_target.src[0].src[0].op is Ops.INDEX:
|
||||
store_target = store_target.src[0].src[0]
|
||||
if store.src[1] is store_target: continue # skip self-assign
|
||||
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
|
||||
|
|
@ -423,10 +423,10 @@ def flatten_bufferize(x:UOp):
|
|||
sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
|
||||
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
|
||||
return ret
|
||||
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
|
||||
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.STAGE, name="x"), flatten_bufferize)])
|
||||
|
||||
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
|
||||
(UPat(Ops.STAGE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
|
||||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
|
|
@ -447,7 +447,7 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|||
])
|
||||
|
||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), bufferize_to_store),
|
||||
(UPat(Ops.STAGE, src=(UPat(), UPat(name="idx")), name="x"), bufferize_to_store),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
|
@ -506,7 +506,7 @@ to_define_global = PatternMatcher([
|
|||
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
|
||||
|
||||
# remove device from local STAGE
|
||||
(UPat(Ops.BUFFERIZE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
||||
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
||||
|
||||
# remove UNIQUE/DEVICE to dedup CONST
|
||||
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
|
@ -581,7 +581,7 @@ def get_kernel_graph(sink:UOp) -> UOp:
|
|||
|
||||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="stage to store")
|
||||
tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels")
|
||||
|
||||
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class Ops(FastEnum):
|
|||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
|
||||
|
||||
# buffer ops
|
||||
BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
|
||||
STAGE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
|
||||
|
||||
# the core 6 movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
|
|||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
||||
range_start = {Ops.STAGE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
||||
Ops.COPY: 2, Ops.BUFFER_VIEW: 1, Ops.LINEAR: 0}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
|
|
@ -246,7 +246,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.src[0].op is Ops.INDEX: return ()
|
||||
return (self.arg[0],)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.STAGE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
case Ops.PARAM:
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
|
|
@ -525,7 +525,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.CONTIGUOUS: return self
|
||||
if self.has_buffer_identity(): return self
|
||||
return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def bufferize(self, *args, **kwargs): return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def bufferize(self, *args, **kwargs): return UOp(Ops.STAGE, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
|
||||
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
||||
|
|
@ -674,7 +674,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@recursive_property
|
||||
def _device(self) -> str|tuple[str, ...]|None:
|
||||
if self.op is Ops.DEVICE: return self.arg
|
||||
if self.op is Ops.BUFFERIZE: return self.arg.device
|
||||
if self.op is Ops.STAGE: return self.arg.device
|
||||
if self.op is Ops.AFTER: return self.src[0]._device
|
||||
if self.op is Ops.MSELECT:
|
||||
assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}"
|
||||
|
|
@ -691,7 +691,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
|
||||
if self.base.op is Ops.AFTER: return self.base.src[0].buf_uop.base
|
||||
s = self
|
||||
while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
|
||||
while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.STAGE, Ops.MSTACK}: s = s.src[0]
|
||||
return s
|
||||
|
||||
def contiguous_view_offset(self) -> int|None:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ renderer = PatternMatcher([
|
|||
(UPat(Ops.CDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
|
||||
(UPat(Ops.CMOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
|
||||
(UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
|
||||
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
|
||||
(UPat((Ops.INDEX, Ops.STAGE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
|
||||
(UPat(Ops.STACK, name="x"),
|
||||
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not x.src or not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
|
||||
(UPat(GroupOp.All, name="x"), lambda x: str(x)),
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ kernel_spec = PatternMatcher([
|
|||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
|
||||
|
||||
# STAGE can be on anything
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
|
||||
|
||||
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering
|
||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"),
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("x", dtypes.weakint) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(dedup(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
|
||||
tuple(dedup(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.STAGE}
|
||||
else y.src for y in x.src[1:]]))))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
Ops.LINEAR: "#7DF4FF",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
Ops.STAGE: "#AC640D", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
|
@ -125,7 +125,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
wrap_len = 200 if u.op is Ops.SOURCE else 80
|
||||
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''), wrap=wrap_len)) if u.arg is not None else ''}"
|
||||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
|
||||
for idx,x in enumerate(u.src[:1] if u.op in {Ops.STAGE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
|
||||
if x in excluded:
|
||||
# walk through excluded movement ops to find the underlying CONST
|
||||
cx = x
|
||||
|
|
@ -139,7 +139,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
label += f"\n{shape_to_str(u.shape)}"
|
||||
if u.op in {Ops.CALL, Ops.FUNCTION}:
|
||||
label += f"\n{u.src[0].key.hex()[:8]}"
|
||||
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
|
||||
if u.op in {Ops.INDEX, Ops.STAGE}:
|
||||
if len(u.toposort()) < 30: label += f"\n{u.render()}"
|
||||
ranges: list[UOp] = []
|
||||
for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue