mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
d8f86be613
commit
689ab6a49f
14 changed files with 33 additions and 28 deletions
|
|
@ -28,8 +28,8 @@ pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)]
|
|||
|
||||
pm_replace_params = PatternMatcher([
|
||||
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.INDEX, name="addr"),), name="bv"),
|
||||
lambda ctx, bv, addr: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, bv.arg[1] * bv.dtype.itemsize)),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
|
||||
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg)),
|
||||
])
|
||||
|
||||
# **************** graph-only passes ****************
|
||||
|
|
|
|||
|
|
@ -291,7 +291,8 @@ def bufferize_kernargs(ctx:HCQ2LowerCtx, target:UOp, buf_node:UOp) -> UOp:
|
|||
dctx = ctx.dev_ctx[dev]
|
||||
isz = dctx.kernargs_host.dtype.base.itemsize
|
||||
off = dctx.kernargs_allocator.alloc(buf_node.arg, 16)
|
||||
hbufs.append(UOp(Ops.BUFFER_VIEW, dctx.kernargs_host.dtype, src=(dctx.kernargs_host,), arg=(buf_node.arg // isz, off // isz)))
|
||||
hbufs.append(UOp(Ops.BUFFER_VIEW, dctx.kernargs_host.dtype,
|
||||
src=(dctx.kernargs_host, UOp.const(dtypes.weakint, off)), arg=buf_node.arg // isz))
|
||||
addrs.append(dctx.kernargs_gpu + off)
|
||||
return _maybe_mstack(tuple(addrs)).after(*_lower_stores(_maybe_mstack(tuple(hbufs)), buf_node, target.src[1:]))
|
||||
|
||||
|
|
@ -370,7 +371,7 @@ def resolve_getaddr(ctx:HCQ2LowerCtx, m:UOp) -> UOp:
|
|||
pm_resolve_patches = symbolic + PatternMatcher([
|
||||
# resolve getaddrs
|
||||
(UPat(Ops.GETADDR, src=(UPat(Ops.BUFFER_VIEW, name="bv"),)), # getaddr(buffer_view(x)) -> offset+getaddr(x)
|
||||
lambda ctx, bv: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0],)) + UOp.const(dtypes.uint64, bv.arg[1] * bv.dtype.itemsize)),
|
||||
lambda ctx, bv: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0],)) + UOp.const(dtypes.uint64, bv.src[1].arg)),
|
||||
(UPat(Ops.GETADDR, src=(UPat((Ops.BUFFER, Ops.MSTACK), name="m"),)), resolve_getaddr), # getaddr(buffer|mstack) -> addr_table load|const
|
||||
(UPat(Ops.GETADDR, src=(UPat.cvar("const"),)), lambda ctx, const: const), # getaddr(const) -> const
|
||||
|
||||
|
|
@ -391,7 +392,7 @@ def parametrize_host_buffer(ctx:HCQ2LowerCtx, buf:UOp) -> UOp:
|
|||
pm_parametrize_host_buffers = PatternMatcher([
|
||||
# resolve buffer views to parametrize only root buffers
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER_VIEW, name="bv"), UPat.var("idx")), name="bi"),
|
||||
lambda bv, idx, bi: bi.replace(src=(bv.src[0], idx + bv.arg[1]))),
|
||||
lambda bv, idx, bi: bi.replace(src=(bv.src[0], idx + bv.src[1].arg // bv.src[0].dtype.itemsize))),
|
||||
|
||||
# parametrize host buffers
|
||||
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK)),), allow_any_len=True, name="buf"), parametrize_host_buffer),
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -53,8 +53,6 @@ All nodes in the tinygrad graph are \textbf{UOps}. A UOp is a tuple $(\mathrm{op
|
|||
Placeholder with shape $\mathbf{s}$. Substituted in \op{Function}. \\[4pt]
|
||||
\op{Buffer} & () & size, dtype, device, addrspace &
|
||||
Shape $(n \cdot \textit{size},)$ if device is $n$-tuple, else $(\textit{size},)$. \\
|
||||
\op{BufferView} & (buf,) & size, dtype, offset &
|
||||
Typed access into a buffer. Zero-copy $(\textit{size},)$ slice at offset; inherits addrspace. \\
|
||||
\op{Const} & () & value, dtype &
|
||||
A scalar constant with shape $(\ )$. \\
|
||||
& & & Form vector consts with \op{Stack} \\
|
||||
|
|
@ -66,7 +64,7 @@ All nodes in the tinygrad graph are \textbf{UOps}. A UOp is a tuple $(\mathrm{op
|
|||
A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \texttt{REG}.
|
||||
|
||||
%% ============================================================
|
||||
\subsection*{{\color{movgreen}Movement Ops} \normalfont\small--- no arithmetic, shapes are $(k,)$-shaped UOps with dtype \texttt{index} in src}
|
||||
\subsection*{{\color{movgreen}Movement Ops} \normalfont\small--- no arithmetic; view, indexing, and reinterpretation only}
|
||||
|
||||
\begin{tabular}{@{}l l l l@{}}
|
||||
\toprule
|
||||
|
|
@ -81,6 +79,7 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
|
|||
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
|
||||
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
|
||||
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
|
||||
\op{BufferView} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is \texttt{weakint} bytes. \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def check_assign(buffer_lists, copies=None):
|
|||
for orig_si, new_si in zip(linear.src, result.src):
|
||||
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
||||
if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map:
|
||||
replace_map[id(orig)] = (new.src[0], new.arg[1] * new.dtype.itemsize, new.arg[0] * new.dtype.itemsize)
|
||||
replace_map[id(orig)] = (new.src[0], new.src[1].arg, new.arg * new.dtype.itemsize)
|
||||
|
||||
# verify pinned buffers are not planned
|
||||
for buf in held_bufs:
|
||||
|
|
|
|||
|
|
@ -166,8 +166,8 @@ class TestMultiOutputGradient(unittest.TestCase):
|
|||
c, d, e, _, _ = Tensor.custom_kernel(Tensor.empty(4, 4), Tensor.empty(4, 4), Tensor.empty(4, 4), a, b,
|
||||
fxn=addmulsub_kernel, grad_fxn=backward_addmulsub)
|
||||
(c.sum() + d.sum() + e.sum()).backward()
|
||||
np.testing.assert_allclose(a.grad.numpy(), a_ref.grad.numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), b_ref.grad.numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(a.grad.numpy(), a_ref.grad.numpy(), atol=1e-6, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), b_ref.grad.numpy(), atol=1e-6, rtol=1e-5)
|
||||
|
||||
class TestViewGradient(unittest.TestCase):
|
||||
def test_expand(self):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from tinygrad import Device
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run")
|
||||
|
|
@ -15,7 +15,7 @@ class TestMetalGraph(unittest.TestCase):
|
|||
buf = MagicMock()
|
||||
if offset > 0:
|
||||
buf.op = Ops.BUFFER_VIEW
|
||||
buf.arg = (None, offset)
|
||||
buf.src = (None, UOp.const(dtypes.weakint, offset))
|
||||
buf.dtype = dtypes.uint8
|
||||
else:
|
||||
buf.op = Ops.BUFFER
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import VIZ, pluralize, all_int
|
||||
|
||||
|
|
@ -59,8 +60,9 @@ def _make_buffer_view(src:UOp) -> UOp|None:
|
|||
"""If movement ops on src collapse to a contiguous range, return BUFFER_VIEW.reshape(src.shape). Otherwise None."""
|
||||
if (offset := src.contiguous_view_offset()) is None: return None
|
||||
buf = src.base
|
||||
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
|
||||
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (src.numel(), offset)).reshape(src.shape)
|
||||
offset *= src.dtype.itemsize
|
||||
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.src[1].arg, buf.src[0]
|
||||
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf, UOp.const(dtypes.weakint, offset)), src.numel()).reshape(src.shape)
|
||||
|
||||
def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
|
||||
|
|
@ -184,7 +186,7 @@ pm_replace_buf = PatternMatcher([
|
|||
# replace BUFFER with PARAM for cache key normalization
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# replace BUFFER_VIEW with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="b"), replace_input_buffer),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def get_call_name(call:UOp, bufs:list[Buffer], var_vals:dict[str, int]|None=None
|
|||
|
||||
ast, arg_uops = call.src[0], get_call_arg_uops(call)
|
||||
if ast.op is Ops.PROGRAM: return ast.arg.name
|
||||
if ast.op is Ops.BUFFER_VIEW: return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {ast.arg[1] * arg_uops[1].dtype.itemsize:<10d}", "yellow")
|
||||
if ast.op is Ops.BUFFER_VIEW: return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {ast.src[1].arg:<10d}", "yellow")
|
||||
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return colored(f"enc/dec {_uop_sz_to_str(arg_uops[0])}", "yellow")
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": return colored(f"batched {len(ast.src[0].src)}", "cyan")
|
||||
|
|
@ -149,7 +149,7 @@ def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], d
|
|||
def exec_view(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
||||
resolved = resolve_params(call, ctx.input_uops)
|
||||
bufs = [cast(Buffer, b.buffer) for b in resolved]
|
||||
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize)
|
||||
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.src[1].arg)
|
||||
with track_stats(ctx, call, bv.device, [bv, bufs[1]], ctx.var_vals): buffers[resolved[0]] = bv
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -109,5 +109,5 @@ class MetalGraph(GraphRunner):
|
|||
@staticmethod
|
||||
def supports_uop(batch_devs, new_call:UOp) -> bool:
|
||||
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
|
||||
if any(b.op is Ops.BUFFER_VIEW and b.arg[1] * b.dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
||||
if any(b.op is Ops.BUFFER_VIEW and b.src[1].arg > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
||||
return GraphRunner.supports_uop(batch_devs, new_call)
|
||||
|
|
|
|||
|
|
@ -56,8 +56,7 @@ def memory_plan_rewrite(linear:UOp, held_bufs:set[UOp]|None=None) -> UOp:
|
|||
arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()}
|
||||
replace_map:dict[UOp, UOp] = {}
|
||||
for buf_uop, offset in offsets.items():
|
||||
assert offset % buf_uop.dtype.itemsize == 0, f"offset {offset} not aligned to {buf_uop.dtype.itemsize}"
|
||||
replace_map[buf_uop] = UOp(Ops.BUFFER_VIEW, buf_uop.dtype, (arenas[_key(buf_uop)],), (buf_uop.arg, offset // buf_uop.dtype.itemsize))
|
||||
replace_map[buf_uop] = UOp(Ops.BUFFER_VIEW, buf_uop.dtype, (arenas[_key(buf_uop)], UOp.const(dtypes.weakint, offset)), buf_uop.arg)
|
||||
|
||||
if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6):
|
||||
print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs")
|
||||
|
|
|
|||
|
|
@ -350,8 +350,9 @@ def late_buffer_view(t:UOp, b:UOp):
|
|||
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
offset *= x.base.dtype.itemsize
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset)), b.src[1]))
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base, UOp.const(dtypes.weakint, offset)), size), b.src[1]))
|
||||
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
|
|||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.STAGE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
||||
Ops.COPY: 2, Ops.BUFFER_VIEW: 1, Ops.LINEAR: 0}
|
||||
Ops.COPY: 2, Ops.BUFFER_VIEW: 2, Ops.LINEAR: 0}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
|
||||
|
|
@ -269,7 +269,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.BUFFER_VIEW:
|
||||
# HACK: BUFFER_VIEW is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
return (self.arg[0],)
|
||||
return (self.arg,)
|
||||
case Ops.CUSTOM_FUNCTION: return None
|
||||
case Ops.STAGE:
|
||||
# STAGE adds the existing shape to the front, opposite of INDEX
|
||||
|
|
@ -796,13 +796,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.BUFFER_VIEW:
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
buf = self.src[0].buffer
|
||||
offset = self.src[1].arg
|
||||
if isinstance(buf, MultiBuffer):
|
||||
mbuf = MultiBuffer.__new__(MultiBuffer)
|
||||
mbuf.bufs = [b.view(self.arg[0], self.dtype, self.arg[1] * self.dtype.itemsize) for b in buf.bufs]
|
||||
mbuf.bufs = [b.view(self.arg, self.dtype, offset) for b in buf.bufs]
|
||||
buffers[self] = mbuf
|
||||
return mbuf
|
||||
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
|
||||
buffers[self] = bv = buf.view(self.arg[0], self.dtype, self.arg[1] * self.dtype.itemsize)
|
||||
buffers[self] = bv = buf.view(self.arg, self.dtype, offset)
|
||||
return bv
|
||||
if self.op is Ops.MSELECT:
|
||||
ret = self.src[0].buffer
|
||||
|
|
|
|||
|
|
@ -223,10 +223,12 @@ spec_program = PatternMatcher([
|
|||
# these are intermediate ops. everything should be deleted from here
|
||||
spec_full = PatternMatcher([
|
||||
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.PARAM)),)), lambda: True),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.PARAM)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||
lambda bv: isinstance(bv.arg, int)),
|
||||
|
||||
# TODO: BUFFER_VIEW shouldn't go on INDEX. why is this allowed? remove these both
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX,)),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.INDEX,)), UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
||||
lambda bv: isinstance(bv.arg, int)),
|
||||
(UPat(Ops.CALL, src=(UPat((Ops.BUFFER_VIEW,)),), allow_any_len=True), lambda: True),
|
||||
|
||||
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue