mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rename BUFFER_VIEW to SLICE (#16391)
* rename BUFFER_VIEW to SLICE * fix comments
This commit is contained in:
parent
3adf7f5d95
commit
156a4438d9
20 changed files with 63 additions and 70 deletions
|
|
@ -28,7 +28,7 @@ pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)]
|
|||
|
||||
pm_replace_params = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
|
||||
(UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
|
||||
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ class HCQEncoder:
|
|||
def __init__(self): self.blob, self.patches = b'', []
|
||||
|
||||
def get_dev_addr(self, uop:UOp) -> UOp:
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(uop,)) if unwrap_after(uop).op in (Ops.BUFFER, Ops.BUFFER_VIEW, Ops.BINARY, Ops.MSTACK, Ops.MSELECT) else uop
|
||||
return UOp(Ops.GETADDR, dtypes.uint64, src=(uop,)) if unwrap_after(uop).op in (Ops.BUFFER, Ops.SLICE, Ops.BINARY, Ops.MSTACK, Ops.MSELECT) else uop
|
||||
|
||||
def append(self, *data, dtype=dtypes.uint32):
|
||||
for d in data:
|
||||
|
|
@ -291,7 +291,7 @@ def bufferize_kernargs(ctx:HCQ2LowerCtx, target:UOp, buf_node:UOp) -> UOp:
|
|||
dctx = ctx.dev_ctx[dev]
|
||||
isz = dctx.kernargs_host.dtype.base.itemsize
|
||||
off = dctx.kernargs_allocator.alloc(buf_node.arg, 16)
|
||||
hbufs.append(UOp(Ops.BUFFER_VIEW, dctx.kernargs_host.dtype,
|
||||
hbufs.append(UOp(Ops.SLICE, dctx.kernargs_host.dtype,
|
||||
src=(dctx.kernargs_host, UOp.const(dtypes.weakint, off // isz)), arg=buf_node.arg // isz))
|
||||
addrs.append(dctx.kernargs_gpu + off)
|
||||
return _maybe_mstack(tuple(addrs)).after(*_lower_stores(_maybe_mstack(tuple(hbufs)), buf_node, target.src[1:]))
|
||||
|
|
@ -352,7 +352,7 @@ def fold_blob_store(ctx:HCQ2LowerCtx, buf:UOp, blob:UOp) -> UOp:
|
|||
def resolve_getaddr(ctx:HCQ2LowerCtx, m:UOp) -> UOp:
|
||||
srcs = m.src if m.op is Ops.MSTACK else (m,)
|
||||
for s in srcs:
|
||||
if s.op in (Ops.BUFFER, Ops.BUFFER_VIEW) and s not in ctx.holds: ctx.holds.append(s)
|
||||
if s.op in (Ops.BUFFER, Ops.SLICE) and s not in ctx.holds: ctx.holds.append(s)
|
||||
addrs = [s.arg if s.op is Ops.CONST else s.buffer.get_buf(s.device).va_addr for s in srcs]
|
||||
|
||||
# fast-path: all per-dev VAs equal -> just a const
|
||||
|
|
@ -365,14 +365,14 @@ def resolve_getaddr(ctx:HCQ2LowerCtx, m:UOp) -> UOp:
|
|||
|
||||
pm_resolve_patches = symbolic + PatternMatcher([
|
||||
# resolve getaddrs
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.BUFFER_VIEW, name="bv"),)), # getaddr(buffer_view(x)) -> offset+getaddr(x)
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"),)), # getaddr(buffer_view(x)) -> offset+getaddr(x)
|
||||
lambda ctx, bv: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0],)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
|
||||
(UPat(Ops.GETADDR, src=(UPat((Ops.BUFFER, Ops.MSTACK), name="m"),)), resolve_getaddr), # getaddr(buffer|mstack) -> addr_table load|const
|
||||
(UPat(Ops.GETADDR, src=(UPat.cvar("const"),)), lambda ctx, const: const), # getaddr(const) -> const
|
||||
|
||||
# write consts and binaries directly into the buffer (BUFFER or MSTACK of BUFFERs)
|
||||
(UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK), name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK), name="buf").index(UPat.cvar("off")).or_casted()
|
||||
(UPat((Ops.BUFFER, Ops.SLICE, Ops.MSTACK), name="buf").store(UPat(Ops.BINARY, name="blob")), fold_blob_store),
|
||||
(UPat((Ops.BUFFER, Ops.SLICE, Ops.MSTACK), name="buf").index(UPat.cvar("off")).or_casted()
|
||||
.store(UPat.any(UPat.cvar("val"), UPat(Ops.MSTACK, src=UPat.cvar(), name="val"))), fold_const_store),
|
||||
])
|
||||
|
||||
|
|
@ -386,12 +386,12 @@ def parametrize_host_buffer(ctx:HCQ2LowerCtx, buf:UOp) -> UOp:
|
|||
|
||||
pm_parametrize_host_buffers = PatternMatcher([
|
||||
# resolve buffer views to parametrize only root buffers
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER_VIEW, name="bv"), UPat.var("idx")), name="bi"),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SLICE, name="bv"), UPat.var("idx")), name="bi"),
|
||||
lambda bv, idx, bi: bi.replace(src=(bv.src[0], idx * bv.dtype.itemsize // bv.src[0].dtype.itemsize + bv.src[1].arg))),
|
||||
|
||||
# parametrize host buffers
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK)),), allow_any_len=True, name="buf"), parametrize_host_buffer),
|
||||
(UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK), name="buf"), parametrize_host_buffer),
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.SLICE, Ops.MSTACK)),), allow_any_len=True, name="buf"), parametrize_host_buffer),
|
||||
(UPat((Ops.BUFFER, Ops.SLICE, Ops.MSTACK), name="buf"), parametrize_host_buffer),
|
||||
|
||||
# remove UNIQUE/DEVICE to dedup CONST
|
||||
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -79,7 +79,7 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
|
|||
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
|
||||
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
|
||||
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
|
||||
\op{BufferView} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is \texttt{weakint} elems of $T$ dtype. \\
|
||||
\op{Slice} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is elems of $T$ dtype. \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
|
||||
|
|
|
|||
|
|
@ -972,13 +972,6 @@ class TestSchedule(unittest.TestCase):
|
|||
run_linear(*check_schedule(out, 2))
|
||||
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
@unittest.skip("BUFFER_VIEW no longer supported on non-disk devices")
|
||||
def test_arange_view_op(self):
|
||||
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
|
||||
sched = run_linear(*check_schedule(a, 1))
|
||||
self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW)
|
||||
np.testing.assert_equal(a.numpy(), [[4, 5]])
|
||||
|
||||
@unittest.skipUnless(dtypes.half in supported_dtypes, "need half")
|
||||
def test_precompute_freqs_cis(self):
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def _make_linear(buffer_lists, copies=None):
|
|||
def _get_arena(buf, linear, result):
|
||||
for orig_si, new_si in zip(linear.src, result.src):
|
||||
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
||||
if orig is buf and new.op is Ops.BUFFER_VIEW: return new.src[0]
|
||||
if orig is buf and new.op is Ops.SLICE: return new.src[0]
|
||||
return None
|
||||
|
||||
def check_assign(buffer_lists, copies=None):
|
||||
|
|
@ -37,7 +37,7 @@ def check_assign(buffer_lists, copies=None):
|
|||
replace_map: dict[int, tuple[UOp, int, int]] = {}
|
||||
for orig_si, new_si in zip(linear.src, result.src):
|
||||
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
||||
if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map:
|
||||
if new.op is Ops.SLICE and id(orig) not in replace_map:
|
||||
replace_map[id(orig)] = (new.src[0], new.src[1].arg * new.src[0].dtype.itemsize, new.arg * new.dtype.itemsize)
|
||||
|
||||
# verify pinned buffers are not planned
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ class TestDataset(unittest.TestCase):
|
|||
X_train[0].contiguous().realize()
|
||||
GlobalCounters.reset()
|
||||
X_train[0].contiguous().realize()
|
||||
self.assertLessEqual(GlobalCounters.kernel_count, 1) # 0 if BUFFER_VIEW (zero-copy), 1 otherwise
|
||||
self.assertLessEqual(GlobalCounters.kernel_count, 1) # 0 if SLICE (zero-copy), 1 otherwise
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class TestContiguous(unittest.TestCase):
|
|||
def test_size_change_buffer_view(self):
|
||||
a = Tensor.empty(4)
|
||||
b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
|
||||
check_schedule(b, 0) # contiguous shrink of a realized buffer is a zero-copy BUFFER_VIEW
|
||||
check_schedule(b, 0) # contiguous shrink of a realized buffer is a zero-copy SLICE
|
||||
|
||||
def test_double_contiguous_realizes_once(self):
|
||||
a = Tensor.empty(4, 1)
|
||||
|
|
@ -1210,10 +1210,10 @@ class TestFusionOp(unittest.TestCase):
|
|||
self.assertEqual(len(linear.src), 1)
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
|
||||
# NOTE: the NULL backend supports BUFFER_VIEW
|
||||
# NOTE: the NULL backend supports SLICE
|
||||
class TestBufferView(unittest.TestCase):
|
||||
def test_shrink_contiguous_is_buffer_view(self):
|
||||
# simple 1D shrink of a realized buffer should be BUFFER_VIEW, not a copy kernel
|
||||
# simple 1D shrink of a realized buffer should be SLICE, not a copy kernel
|
||||
a = Tensor.arange(100).clone().realize()
|
||||
b = a.shrink(((10, 50),)).contiguous()
|
||||
run_linear(*check_schedule(b, 0))
|
||||
|
|
@ -1229,7 +1229,7 @@ class TestBufferView(unittest.TestCase):
|
|||
run_linear(*check_schedule(b, 0))
|
||||
|
||||
def test_shrink_non_shard_axis_is_buffer_view_multi(self):
|
||||
# indexing a non-shard axis of a realized sharded tensor should be BUFFER_VIEW on each device, not copy kernels
|
||||
# indexing a non-shard axis of a realized sharded tensor should be SLICE on each device, not copy kernels
|
||||
# this is the flat_llama pattern: weight[layer_idx] where weight is (n_layers, out, dim) sharded on axis=1
|
||||
devices = ("NULL:1", "NULL:2")
|
||||
a = Tensor.arange(8*4*10).reshape(8, 4, 10).clone().shard(devices, axis=1).realize()
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class TestMetalGraph(unittest.TestCase):
|
|||
def metal_buf(self, offset):
|
||||
buf = MagicMock()
|
||||
if offset > 0:
|
||||
buf.op = Ops.BUFFER_VIEW
|
||||
buf.op = Ops.SLICE
|
||||
src = MagicMock()
|
||||
src.dtype = dtypes.uint8
|
||||
buf.src = (src, UOp.const(dtypes.weakint, offset))
|
||||
|
|
@ -36,7 +36,7 @@ class TestMetalGraph(unittest.TestCase):
|
|||
assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False
|
||||
|
||||
def test_supports_uop_nonmetal_buf(self):
|
||||
# non-BUFFER_VIEW ops should not be checked for offset
|
||||
# non-SLICE ops should not be checked for offset
|
||||
buf = MagicMock()
|
||||
buf.op = Ops.BUFFER
|
||||
buf.device = Device.DEFAULT
|
||||
|
|
|
|||
|
|
@ -57,21 +57,21 @@ def replace_store_after_with_contig(u:UOp, src:UOp):
|
|||
if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
|
||||
|
||||
def _make_buffer_view(src:UOp) -> UOp|None:
|
||||
"""If movement ops on src collapse to a contiguous range, return BUFFER_VIEW.reshape(src.shape). Otherwise None."""
|
||||
"""If movement ops on src collapse to a contiguous range, return SLICE.reshape(src.shape). Otherwise None."""
|
||||
if (offset := src.contiguous_view_offset()) is None: return None
|
||||
buf = src.base
|
||||
if buf.op is Ops.BUFFER_VIEW:
|
||||
if buf.op is Ops.SLICE:
|
||||
byte_offset = buf.src[1].arg * buf.src[0].dtype.itemsize + offset * src.dtype.itemsize
|
||||
buf = buf.src[0]
|
||||
if byte_offset % buf.dtype.itemsize != 0: return None
|
||||
offset = byte_offset // buf.dtype.itemsize
|
||||
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf, UOp.const(dtypes.weakint, offset)), src.numel()).reshape(src.shape)
|
||||
return UOp(Ops.SLICE, src.dtype, (buf, UOp.const(dtypes.weakint, offset)), src.numel()).reshape(src.shape)
|
||||
|
||||
def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
|
||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(SLICE) when movement ops collapse to a contiguous range."""
|
||||
buf = src.base
|
||||
if buf.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
if src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
if buf.op not in {Ops.BUFFER, Ops.SLICE}: return None
|
||||
if src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.SLICE}: return None
|
||||
|
||||
# no symbolic shape
|
||||
if not all_int(c.shape): return None
|
||||
|
|
@ -84,11 +84,11 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
|
|||
|
||||
x = src
|
||||
while x.op in GroupOp.Movement: x = x.src[0]
|
||||
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
|
||||
# NOTE: this contiguous is removed because this SLICE/RESHAPE has_buffer_identity
|
||||
if x.op is not Ops.MULTI and (view := _make_buffer_view(src)) is not None:
|
||||
return view.contiguous(tag=c.tag)
|
||||
|
||||
# for MULTI tensors, use multi_pm to resolve per-shard movement ops, then create BUFFER_VIEW on the resolved result
|
||||
# for MULTI tensors, use multi_pm to resolve per-shard movement ops, then create SLICE on the resolved result
|
||||
if not isinstance(c.device, str):
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
resolved = graph_rewrite(src, multi_pm, name="multi_buffer_view")
|
||||
|
|
@ -142,7 +142,7 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
|||
# resolve TUPLE+GETTUPLE (for precompiled calls)
|
||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||
|
||||
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
|
||||
# CONTIGUOUS(MOPS(BUFFER/SLICE)) → CONTIGUOUS(SLICE) when movement ops collapse to contiguous range
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view),
|
||||
|
||||
# add CONTIGUOUS to tagged UOps
|
||||
|
|
@ -186,8 +186,8 @@ pm_finalize_call = PatternMatcher([
|
|||
pm_replace_buf = PatternMatcher([
|
||||
# replace BUFFER with PARAM for cache key normalization
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# replace BUFFER_VIEW with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
||||
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
||||
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
|
|||
current_batch, current_batch_devs = [], []
|
||||
|
||||
for si in linear.src:
|
||||
if si.src[0].op is Ops.BUFFER_VIEW: continue
|
||||
if si.src[0].op is Ops.SLICE: continue
|
||||
|
||||
devs = dedup([Device[x] for b in si.src[1:] if b.op is not Ops.BIND for x in (b.device if isinstance(b.device, tuple) else (b.device,))])
|
||||
graph_t = graph_class(devs[0]) if devs[0].graph is not None else None
|
||||
|
|
@ -193,7 +193,7 @@ class CapturedJit(Generic[ReturnType]):
|
|||
if call.op is not Ops.CALL: continue
|
||||
arg_uops = get_call_arg_uops(call)
|
||||
outs, ins = get_call_outs_ins(call)
|
||||
out |= {arg_uops[k] for k in set(outs) - set(ins) if arg_uops[k].op in (Ops.BUFFER, Ops.BUFFER_VIEW)}
|
||||
out |= {arg_uops[k] for k in set(outs) - set(ins) if arg_uops[k].op in (Ops.BUFFER, Ops.SLICE)}
|
||||
return out
|
||||
|
||||
def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def get_call_arg_uops(call:UOp) -> tuple[UOp, ...]: return tuple(s for s in call
|
|||
def get_call_outs_ins(call:UOp) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
ast = call.src[0]
|
||||
if ast.op is Ops.PROGRAM: return tuple(ast.arg.outs), tuple(ast.arg.ins)
|
||||
if ast.op in (Ops.COPY, Ops.BUFFER_VIEW): return (0,), (1,)
|
||||
if ast.op in (Ops.COPY, Ops.SLICE): return (0,), (1,)
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return (0,), tuple(range(1, len(get_call_arg_uops(call))))
|
||||
return (), ()
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ def get_call_name(call:UOp, bufs:list[Buffer], var_vals:dict[str, int]|None=None
|
|||
|
||||
ast, arg_uops = call.src[0], get_call_arg_uops(call)
|
||||
if ast.op is Ops.PROGRAM: return ast.arg.name
|
||||
if ast.op is Ops.BUFFER_VIEW:
|
||||
if ast.op is Ops.SLICE:
|
||||
offset = ast.src[1].arg * arg_uops[1].dtype.itemsize
|
||||
return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {offset:<10d}", "yellow")
|
||||
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
|
||||
|
|
@ -137,7 +137,7 @@ class ExecContext:
|
|||
cache: bool = True
|
||||
|
||||
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
|
||||
if b.op in (Ops.BUFFER_VIEW, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg], *b.src[1:]))
|
||||
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg], *b.src[1:]))
|
||||
return inputs[b.arg] if b.op is Ops.PARAM else b
|
||||
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in get_call_arg_uops(call)]
|
||||
|
||||
|
|
@ -233,7 +233,7 @@ pm_optimize_local_size = PatternMatcher([
|
|||
])
|
||||
|
||||
pm_exec = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.BUFFER_VIEW, name="ast"),), name="call", allow_any_len=True), exec_view),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.SLICE, name="ast"),), name="call", allow_any_len=True), exec_view),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="ast"),), name="call", allow_any_len=True), exec_kernel),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec),
|
||||
|
|
|
|||
|
|
@ -109,5 +109,5 @@ class MetalGraph(GraphRunner):
|
|||
@staticmethod
|
||||
def supports_uop(batch_devs, new_call:UOp) -> bool:
|
||||
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
|
||||
if any(b.op is Ops.BUFFER_VIEW and b.src[1].arg * b.src[0].dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
||||
if any(b.op is Ops.SLICE and b.src[1].arg * b.src[0].dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
||||
return GraphRunner.supports_uop(batch_devs, new_call)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
|
|||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.SLICE,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.FUNCTION}
|
||||
|
||||
|
|
@ -18,8 +18,8 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_store_after_src(ctx:dict[UOp, None], dest:UOp, src:UOp):
|
||||
# don't realize COPY/BUFFER_VIEW when they are the direct source of STORE+AFTER — the target buffer is the output
|
||||
if src.op in {Ops.COPY, Ops.BUFFER_VIEW} and src in ctx \
|
||||
# don't realize COPY/SLICE when they are the direct source of STORE+AFTER — the target buffer is the output
|
||||
if src.op in {Ops.COPY, Ops.SLICE} and src in ctx \
|
||||
and not dest.op_in_backward_slice_with_self(Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.PAD):
|
||||
del ctx[src]
|
||||
# you don't usually have to do this for assign unless there's a WAR hazard like TestAssign.test_assign_double_diamond_reduce
|
||||
|
|
@ -58,7 +58,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
if s.op in {Ops.PARAM, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||
if s.op in {Ops.PARAM, Ops.BUFFER, Ops.SLICE, Ops.MSTACK, Ops.MSELECT, Ops.AFTER}:
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
elif s in ctx.realize_map:
|
||||
realized_ranges = ctx.realize_map[s]
|
||||
|
|
|
|||
|
|
@ -52,11 +52,11 @@ def memory_plan_rewrite(linear:UOp, held_bufs:set[UOp]|None=None) -> UOp:
|
|||
peaks[_key(buf)] = (max(peaks[_key(buf)][0], offsets[buf] + buf.arg * buf.dtype.itemsize), peaks[_key(buf)][1])
|
||||
arena_sizes = {key: round_up(peak, block_size) for key, (peak, _) in peaks.items()}
|
||||
|
||||
# build replace_map: each buffer becomes a BUFFER_VIEW into a shared per-device-lane arena
|
||||
# build replace_map: each buffer becomes a SLICE into a shared per-device-lane arena
|
||||
arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()}
|
||||
replace_map:dict[UOp, UOp] = {}
|
||||
for buf_uop, offset in offsets.items():
|
||||
replace_map[buf_uop] = UOp(Ops.BUFFER_VIEW, buf_uop.dtype, (arenas[_key(buf_uop)], UOp.const(dtypes.weakint, offset)), buf_uop.arg)
|
||||
replace_map[buf_uop] = UOp(Ops.SLICE, buf_uop.dtype, (arenas[_key(buf_uop)], UOp.const(dtypes.weakint, offset)), buf_uop.arg)
|
||||
|
||||
if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6):
|
||||
print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs")
|
||||
|
|
|
|||
|
|
@ -307,7 +307,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
|||
return src.substitute(replaced, extra_pm=pm_gate_substitute)
|
||||
|
||||
def remove_noop_bufferize(idx,b2):
|
||||
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None
|
||||
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.SLICE: return None
|
||||
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
|
||||
|
||||
pm_const_buffer_folding = pm_mops+PatternMatcher([
|
||||
|
|
@ -351,7 +351,7 @@ def late_buffer_view(t:UOp, b:UOp):
|
|||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base, UOp.const(dtypes.weakint, offset)), size), b.src[1]))
|
||||
return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.base, UOp.const(dtypes.weakint, offset)), size), b.src[1]))
|
||||
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
|
|
@ -559,11 +559,11 @@ def split_store(x:UOp) -> UOp|None:
|
|||
lctx = LocalAddBufferContext()
|
||||
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)
|
||||
|
||||
# SINK requires all buffers on the same device, but COPY/BUFFER_VIEW are cross-device or special hardware ops
|
||||
# SINK requires all buffers on the same device, but COPY/SLICE are cross-device or special hardware ops
|
||||
if ret.op is Ops.STORE: stored = ret.src[1]
|
||||
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
|
||||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
||||
if stored.op in {Ops.COPY, Ops.SLICE}: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class Ops(FastEnum):
|
|||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
|
||||
|
||||
# buffer ops
|
||||
STAGE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
|
||||
STAGE = auto(); COPY = auto(); BUFFER = auto(); SLICE = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
|
||||
|
||||
# the core 6 movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
|
|||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.STAGE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
||||
Ops.COPY: 2, Ops.BUFFER_VIEW: 2, Ops.LINEAR: 0}
|
||||
Ops.COPY: 2, Ops.SLICE: 2, Ops.LINEAR: 0}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
|
||||
|
|
@ -105,7 +105,7 @@ class UOpMetaClass(type):
|
|||
return created
|
||||
|
||||
# some uops map to other stuff
|
||||
buffers:weakref.WeakKeyDictionary[UOp, Buffer|MultiBuffer] = weakref.WeakKeyDictionary() # this maps BUFFER/BUFFER_VIEW uops to their device Buffers
|
||||
buffers:weakref.WeakKeyDictionary[UOp, Buffer|MultiBuffer] = weakref.WeakKeyDictionary() # this maps BUFFER/SLICE uops to their device Buffers
|
||||
all_metadata:weakref.WeakKeyDictionary[UOp, tuple[Metadata, ...]] = weakref.WeakKeyDictionary() # TODO: should this be here?
|
||||
|
||||
# recursive_property replaces functools.cached_property in recursive UOp functions to prevent RecursionError
|
||||
|
|
@ -266,8 +266,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
case Ops.BINARY: return (len(self.arg),)
|
||||
case Ops.BUFFER: return (self.arg,)
|
||||
case Ops.BUFFER_VIEW:
|
||||
# HACK: BUFFER_VIEW is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
case Ops.SLICE:
|
||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
return (self.arg,)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
|
|
@ -763,7 +763,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
"""Check if this UOp has a concrete buffer identity in the graph (RESHAPE/MULTI -> BUFFER chain)."""
|
||||
if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
|
||||
if self.op is Ops.GETTUPLE and self.src[0].op is Ops.TUPLE: return self.src[0].src[self.arg].has_buffer_identity()
|
||||
return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
|
||||
return self.op in {Ops.BUFFER, Ops.SLICE, Ops.PARAM}
|
||||
|
||||
def _base_buffer_is_realized(self) -> bool:
|
||||
"""Walk through AFTER chain to find if the underlying buffer is realized (has allocated memory)."""
|
||||
|
|
@ -785,7 +785,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
buf = self.src[0].buffer
|
||||
assert isinstance(buf, Buffer), "must be a Buffer for BITCAST"
|
||||
return buf.view(prod(self.max_shape), self.dtype, 0)
|
||||
if self.op is Ops.BUFFER_VIEW:
|
||||
if self.op is Ops.SLICE:
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
buf = self.src[0].buffer
|
||||
offset = self.src[1].arg
|
||||
|
|
@ -794,7 +794,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
mbuf.bufs = [b.view(self.arg, self.dtype, offset * self.src[0].dtype.itemsize) for b in buf.bufs]
|
||||
buffers[self] = mbuf
|
||||
return mbuf
|
||||
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
|
||||
assert isinstance(buf, Buffer), "must be a Buffer for SLICE"
|
||||
buffers[self] = bv = buf.view(self.arg, self.dtype, offset * self.src[0].dtype.itemsize)
|
||||
return bv
|
||||
if self.op is Ops.MSELECT:
|
||||
|
|
@ -1012,7 +1012,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return p
|
||||
|
||||
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
|
||||
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION}
|
||||
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION}
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
|
||||
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
|
|
|
|||
|
|
@ -221,14 +221,14 @@ spec_program = PatternMatcher([
|
|||
|
||||
# these are intermediate ops. everything should be deleted from here
|
||||
spec_full = PatternMatcher([
|
||||
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||
# SLICE on BUFFER is allowed if BUFFER is
|
||||
(UPat(Ops.SLICE, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||
lambda bv: isinstance(bv.arg, int)),
|
||||
|
||||
# TODO: BUFFER_VIEW shouldn't go on INDEX. why is this allowed? remove these both
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX,)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||
# TODO: SLICE shouldn't go on INDEX. why is this allowed? remove these both
|
||||
(UPat(Ops.SLICE, src=(UPat((Ops.INDEX,)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||
lambda bv: isinstance(bv.arg, int)),
|
||||
(UPat(Ops.CALL, src=(UPat((Ops.BUFFER_VIEW,)),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True),
|
||||
|
||||
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
|
||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.SLICE: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.GETADDR: "#9DB1F0", Ops.COPY: "#a040a0", Ops.CUSTOM_FUNCTION: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
|
||||
Ops.LINEAR: "#7DF4FF",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue