mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
slice_like
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3535abe96 |
13 changed files with 42 additions and 26 deletions
|
|
@ -28,8 +28,8 @@ pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)]
|
||||||
|
|
||||||
pm_replace_params = PatternMatcher([
|
pm_replace_params = PatternMatcher([
|
||||||
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
|
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
|
||||||
(UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
|
(UPat(Ops.SLICE, src=(UPat(Ops.INDEX, name="addr"), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint)), name="bv"),
|
||||||
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
|
lambda ctx, bv, addr: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, bv.slice_offset() * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
|
||||||
])
|
])
|
||||||
|
|
||||||
# **************** graph-only passes ****************
|
# **************** graph-only passes ****************
|
||||||
|
|
|
||||||
|
|
@ -330,7 +330,7 @@ def fold_const_store(buf:UOp, off:UOp, val:UOp) -> UOp:
|
||||||
|
|
||||||
pm_resolve_patches = symbolic + PatternMatcher([
|
pm_resolve_patches = symbolic + PatternMatcher([
|
||||||
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
|
(UPat(Ops.GETADDR, src=(UPat(Ops.SLICE, name="bv"), UPat(Ops.DEVICE, name="dev"))), # getaddr(slice(x)) -> offset+getaddr(x)
|
||||||
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
|
lambda bv, dev: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0], dev)) + UOp.const(dtypes.uint64, bv.slice_offset() * bv.src[0].dtype.itemsize)),
|
||||||
(UPat(Ops.GETADDR, src=(UPat(Ops.BUFFER, name="buf"), UPat(Ops.DEVICE)), name="g"),
|
(UPat(Ops.GETADDR, src=(UPat(Ops.BUFFER, name="buf"), UPat(Ops.DEVICE)), name="g"),
|
||||||
lambda buf, g: UOp.const(dtypes.uint64, buf.buffer.get_buf(g.src[1].arg).va_addr)),
|
lambda buf, g: UOp.const(dtypes.uint64, buf.buffer.get_buf(g.src[1].arg).va_addr)),
|
||||||
(UPat(Ops.GETADDR, src=(UPat.cvar("const"), UPat())), lambda const: const),
|
(UPat(Ops.GETADDR, src=(UPat.cvar("const"), UPat())), lambda const: const),
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -74,12 +74,12 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
|
||||||
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
|
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
|
||||||
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
|
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
|
||||||
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
|
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
|
||||||
\op{Pad} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
|
\op{Pad} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Place $T$ at offset $o_k$ in an output of shape $s'_k$; outside $T$ is invalid. \\
|
||||||
\op{Shrink} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
|
\op{Shrink} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
|
||||||
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
|
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
|
||||||
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
|
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
|
||||||
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
|
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
|
||||||
\op{Slice} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is elems of $T$ dtype. \\
|
\op{Slice} & $(T, \mathbf{s'}, \mathbf{o})$ & dtype & Zero-copy \op{Shrink}; dtype may reinterpret storage. \\
|
||||||
\op{Bitcast} & $(T,)$ & dtype & Reinterpret storage as target dtype; preserve total bytes. \\
|
\op{Bitcast} & $(T,)$ & dtype & Reinterpret storage as target dtype; preserve total bytes. \\
|
||||||
\bottomrule
|
\bottomrule
|
||||||
\end{tabular}
|
\end{tabular}
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def check_assign(buffer_lists, copies=None):
|
||||||
for orig_si, new_si in zip(linear.src, result.src):
|
for orig_si, new_si in zip(linear.src, result.src):
|
||||||
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
|
||||||
if new.op is Ops.SLICE and id(orig) not in replace_map:
|
if new.op is Ops.SLICE and id(orig) not in replace_map:
|
||||||
replace_map[id(orig)] = (new.src[0], new.src[1].arg * new.src[0].dtype.itemsize, new.arg * new.dtype.itemsize)
|
replace_map[id(orig)] = (new.src[0], new.src[2].arg * new.src[0].dtype.itemsize, new.src[1].arg * new.dtype.itemsize)
|
||||||
|
|
||||||
# verify pinned buffers are not planned
|
# verify pinned buffers are not planned
|
||||||
for buf in held_bufs:
|
for buf in held_bufs:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ class TestMetalGraph(unittest.TestCase):
|
||||||
buf.op = Ops.SLICE
|
buf.op = Ops.SLICE
|
||||||
src = MagicMock()
|
src = MagicMock()
|
||||||
src.dtype = dtypes.uint8
|
src.dtype = dtypes.uint8
|
||||||
buf.src = (src, UOp.const(dtypes.weakint, offset))
|
buf.src = (src, UOp.const(dtypes.weakint, 1), UOp.const(dtypes.weakint, offset))
|
||||||
buf.dtype = dtypes.uint8
|
buf.dtype = dtypes.uint8
|
||||||
else:
|
else:
|
||||||
buf.op = Ops.BUFFER
|
buf.op = Ops.BUFFER
|
||||||
|
|
|
||||||
|
|
@ -61,11 +61,13 @@ def _make_buffer_view(src:UOp) -> UOp|None:
|
||||||
if (offset := src.contiguous_view_offset()) is None: return None
|
if (offset := src.contiguous_view_offset()) is None: return None
|
||||||
buf = src.base
|
buf = src.base
|
||||||
if buf.op is Ops.SLICE:
|
if buf.op is Ops.SLICE:
|
||||||
byte_offset = buf.src[1].arg * buf.src[0].dtype.itemsize + offset * src.dtype.itemsize
|
slice_offset = buf.slice_offset()
|
||||||
|
assert isinstance(slice_offset, int), "SLICE offset must be concrete for buffer views"
|
||||||
|
byte_offset = slice_offset * buf.src[0].dtype.itemsize + offset * src.dtype.itemsize
|
||||||
buf = buf.src[0]
|
buf = buf.src[0]
|
||||||
if byte_offset % buf.dtype.itemsize != 0: return None
|
if byte_offset % buf.dtype.itemsize != 0: return None
|
||||||
offset = byte_offset // buf.dtype.itemsize
|
offset = byte_offset // buf.dtype.itemsize
|
||||||
return UOp(Ops.SLICE, src.dtype, (buf, UOp.const(dtypes.weakint, offset)), src.numel()).reshape(src.shape)
|
return UOp(Ops.SLICE, src.dtype, (buf, UOp.const(dtypes.weakint, src.numel()), UOp.const(dtypes.weakint, offset))).reshape(src.shape)
|
||||||
|
|
||||||
def contiguous_mops_to_view(c:UOp, src:UOp):
|
def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(SLICE) when movement ops collapse to a contiguous range."""
|
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(SLICE) when movement ops collapse to a contiguous range."""
|
||||||
|
|
@ -189,7 +191,7 @@ pm_replace_buf = PatternMatcher([
|
||||||
# replace BUFFER with PARAM for cache key normalization
|
# replace BUFFER with PARAM for cache key normalization
|
||||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||||
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
||||||
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
||||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||||
])
|
])
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ def get_call_name(call:UOp, bufs:list[Buffer], var_vals:dict[str, int]|None=None
|
||||||
ast, arg_uops = call.src[0], get_call_arg_uops(call)
|
ast, arg_uops = call.src[0], get_call_arg_uops(call)
|
||||||
if ast.op is Ops.PROGRAM: return ast.arg.name
|
if ast.op is Ops.PROGRAM: return ast.arg.name
|
||||||
if ast.op is Ops.SLICE:
|
if ast.op is Ops.SLICE:
|
||||||
offset = ast.src[1].arg * arg_uops[1].dtype.itemsize
|
offset = ast.slice_offset() * arg_uops[1].dtype.itemsize
|
||||||
return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {offset:<10d}", "yellow")
|
return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {offset:<10d}", "yellow")
|
||||||
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
|
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
|
||||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return colored(f"enc/dec {_uop_sz_to_str(arg_uops[0])}", "yellow")
|
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return colored(f"enc/dec {_uop_sz_to_str(arg_uops[0])}", "yellow")
|
||||||
|
|
@ -151,7 +151,9 @@ def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], d
|
||||||
def exec_view(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
def exec_view(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
||||||
resolved = resolve_params(call, ctx.input_uops)
|
resolved = resolve_params(call, ctx.input_uops)
|
||||||
bufs = [cast(Buffer, b.buffer) for b in resolved]
|
bufs = [cast(Buffer, b.buffer) for b in resolved]
|
||||||
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.src[1].arg*bufs[1].dtype.itemsize)
|
slice_off = ast.slice_offset()
|
||||||
|
assert isinstance(slice_off, int), "SLICE offset must be concrete for realized buffers"
|
||||||
|
bv = bufs[1].view(prod(ast.max_shape), ast.dtype, slice_off*bufs[1].dtype.itemsize)
|
||||||
with track_stats(ctx, call, bv.device, [bv, bufs[1]], ctx.var_vals): buffers[resolved[0]] = bv
|
with track_stats(ctx, call, bv.device, [bv, bufs[1]], ctx.var_vals): buffers[resolved[0]] = bv
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -109,5 +109,5 @@ class MetalGraph(GraphRunner):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def supports_uop(batch_devs, new_call:UOp) -> bool:
|
def supports_uop(batch_devs, new_call:UOp) -> bool:
|
||||||
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
|
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
|
||||||
if any(b.op is Ops.SLICE and b.src[1].arg * b.src[0].dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
if any(b.op is Ops.SLICE and b.slice_offset() * b.src[0].dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
||||||
return GraphRunner.supports_uop(batch_devs, new_call)
|
return GraphRunner.supports_uop(batch_devs, new_call)
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,8 @@ def memory_plan_rewrite(linear:UOp, held_bufs:set[UOp]|None=None) -> UOp:
|
||||||
arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()}
|
arenas = {key: UOp.new_buffer(key[0], sz, dtypes.int8) for key, sz in arena_sizes.items()}
|
||||||
replace_map:dict[UOp, UOp] = {}
|
replace_map:dict[UOp, UOp] = {}
|
||||||
for buf_uop, offset in offsets.items():
|
for buf_uop, offset in offsets.items():
|
||||||
replace_map[buf_uop] = UOp(Ops.SLICE, buf_uop.dtype, (arenas[_key(buf_uop)], UOp.const(dtypes.weakint, offset)), buf_uop.arg)
|
replace_map[buf_uop] = UOp(Ops.SLICE, buf_uop.dtype, (arenas[_key(buf_uop)], UOp.const(dtypes.weakint, buf_uop.arg),
|
||||||
|
UOp.const(dtypes.weakint, offset)))
|
||||||
|
|
||||||
if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6):
|
if DEBUG >= 1 and (omem:=sum(nbytes.values()) / 1e6) != (nmem:=sum(arena_sizes.values()) / 1e6):
|
||||||
print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs")
|
print(f"memory reduced from {omem:.2f} MB -> {nmem:.2f} MB, {len(first_appearance)} -> {len(arenas)} bufs")
|
||||||
|
|
|
||||||
|
|
@ -352,7 +352,7 @@ def late_buffer_view(t:UOp, b:UOp):
|
||||||
if len(shape) == 0: offset = x.src[1].arg
|
if len(shape) == 0: offset = x.src[1].arg
|
||||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||||
|
|
||||||
return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.src[0], UOp.const(dtypes.weakint, offset)), size),))
|
return b.replace(src=(UOp(Ops.SLICE, t.dtype, (x.src[0], UOp.const(dtypes.weakint, size), UOp.const(dtypes.weakint, offset))),))
|
||||||
|
|
||||||
to_bufferview = PatternMatcher([
|
to_bufferview = PatternMatcher([
|
||||||
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||||
|
|
@ -568,7 +568,8 @@ def split_store(x:UOp) -> UOp|None:
|
||||||
if ret.op is Ops.STORE: stored = ret.src[1]
|
if ret.op is Ops.STORE: stored = ret.src[1]
|
||||||
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
|
elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
|
||||||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||||
if stored.op in {Ops.COPY, Ops.SLICE}: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
if stored.op is Ops.COPY: ret = stored.replace(src=stored.src + ret.ended_ranges)
|
||||||
|
elif stored.op is Ops.SLICE: ret = stored
|
||||||
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
|
||||||
|
|
||||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
|
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storag
|
||||||
from tinygrad.device import Buffer, MultiBuffer, canonicalize_device
|
from tinygrad.device import Buffer, MultiBuffer, canonicalize_device
|
||||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey
|
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey
|
||||||
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST
|
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY, DISALLOW_BROADCAST, strides_for_shape
|
||||||
from tinygrad.helpers import colored, ansilen, printable
|
from tinygrad.helpers import colored, ansilen, printable
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tinygrad.renderer import Estimates
|
from tinygrad.renderer import Estimates
|
||||||
|
|
@ -28,7 +28,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
|
||||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||||
|
|
||||||
range_start = {Ops.STAGE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
range_start = {Ops.STAGE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
|
||||||
Ops.COPY: 2, Ops.SLICE: 2, Ops.LINEAR: 0}
|
Ops.COPY: 2, Ops.SLICE: 3, Ops.LINEAR: 0}
|
||||||
|
|
||||||
# https://en.wikipedia.org/wiki/Identity_element
|
# https://en.wikipedia.org/wiki/Identity_element
|
||||||
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
|
def identity_element(op:Ops, dt:DType) -> PyConst: return dt.const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dt.min}[op])
|
||||||
|
|
@ -267,9 +267,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
case Ops.BINARY: return (len(self.arg),)
|
case Ops.BINARY: return (len(self.arg),)
|
||||||
case Ops.BUFFER: return (self.arg,)
|
case Ops.BUFFER: return (self.arg,)
|
||||||
case Ops.SLICE:
|
case Ops.SLICE:
|
||||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
return tuple(self.src[1].sgep(i) for i in range(self.src[1].dtype.count))
|
||||||
if self.src[0].op is Ops.INDEX: return ()
|
|
||||||
return (self.arg,)
|
|
||||||
case Ops.CUSTOM_FUNCTION: return None
|
case Ops.CUSTOM_FUNCTION: return None
|
||||||
case Ops.STAGE:
|
case Ops.STAGE:
|
||||||
# STAGE adds the existing shape to the front, opposite of INDEX
|
# STAGE adds the existing shape to the front, opposite of INDEX
|
||||||
|
|
@ -774,6 +772,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
if self.op is Ops.GETTUPLE and self.src[0].op is Ops.TUPLE: return self.src[0].src[self.arg].has_buffer_identity()
|
if self.op is Ops.GETTUPLE and self.src[0].op is Ops.TUPLE: return self.src[0].src[self.arg].has_buffer_identity()
|
||||||
return self.op in {Ops.BUFFER, Ops.SLICE, Ops.PARAM}
|
return self.op in {Ops.BUFFER, Ops.SLICE, Ops.PARAM}
|
||||||
|
|
||||||
|
def slice_offset(self) -> sint:
|
||||||
|
assert self.op is Ops.SLICE
|
||||||
|
offs = tuple(self.src[2].sgep(i) for i in range(self.src[2].dtype.count))
|
||||||
|
if len(offs) == 1: return offs[0]
|
||||||
|
assert len(offs) == len(self.src[0].shape), "multi-axis SLICE offset must match source rank"
|
||||||
|
return UOp.const(dtypes.weakint, 0).usum(*(off*st for off,st in zip(offs, strides_for_shape(self.src[0].shape))))
|
||||||
|
|
||||||
def _base_buffer_is_realized(self) -> bool:
|
def _base_buffer_is_realized(self) -> bool:
|
||||||
"""Walk through AFTER chain to find if the underlying buffer is realized (has allocated memory)."""
|
"""Walk through AFTER chain to find if the underlying buffer is realized (has allocated memory)."""
|
||||||
u = self.base
|
u = self.base
|
||||||
|
|
@ -797,14 +802,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
if self.op is Ops.SLICE:
|
if self.op is Ops.SLICE:
|
||||||
if (cret:=buffers.get(self)) is not None: return cret
|
if (cret:=buffers.get(self)) is not None: return cret
|
||||||
buf = self.src[0].buffer
|
buf = self.src[0].buffer
|
||||||
offset = self.src[1].arg
|
slice_off = self.slice_offset()
|
||||||
|
assert isinstance(slice_off, int), "SLICE offset must be concrete for realized buffers"
|
||||||
|
size = prod(self.max_shape)
|
||||||
if isinstance(buf, MultiBuffer):
|
if isinstance(buf, MultiBuffer):
|
||||||
mbuf = MultiBuffer.__new__(MultiBuffer)
|
mbuf = MultiBuffer.__new__(MultiBuffer)
|
||||||
mbuf.bufs = [b.view(self.arg, self.dtype, offset * self.src[0].dtype.itemsize) for b in buf.bufs]
|
mbuf.bufs = [b.view(size, self.dtype, slice_off * self.src[0].dtype.itemsize) for b in buf.bufs]
|
||||||
buffers[self] = mbuf
|
buffers[self] = mbuf
|
||||||
return mbuf
|
return mbuf
|
||||||
assert isinstance(buf, Buffer), "must be a Buffer for SLICE"
|
assert isinstance(buf, Buffer), "must be a Buffer for SLICE"
|
||||||
buffers[self] = bv = buf.view(self.arg, self.dtype, offset * self.src[0].dtype.itemsize)
|
buffers[self] = bv = buf.view(size, self.dtype, slice_off * self.src[0].dtype.itemsize)
|
||||||
return bv
|
return bv
|
||||||
if self.op is Ops.MSELECT:
|
if self.op is Ops.MSELECT:
|
||||||
ret = self.src[0].buffer
|
ret = self.src[0].buffer
|
||||||
|
|
|
||||||
|
|
@ -225,9 +225,12 @@ spec_program = PatternMatcher([
|
||||||
spec_full = PatternMatcher([
|
spec_full = PatternMatcher([
|
||||||
# SLICE on BUFFER is allowed if BUFFER is
|
# SLICE on BUFFER is allowed if BUFFER is
|
||||||
(UPat(Ops.SLICE, src=(UPat(GroupOp.Movement.union({Ops.BUFFER, Ops.PARAM, Ops.STAGE, Ops.AFTER})),
|
(UPat(Ops.SLICE, src=(UPat(GroupOp.Movement.union({Ops.BUFFER, Ops.PARAM, Ops.STAGE, Ops.AFTER})),
|
||||||
UPat(Ops.CONST, dtype=dtypes.weakint)), allow_any_len=True, name="bv"),
|
UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint)), name="bv"),
|
||||||
lambda bv: isinstance(bv.arg, int)),
|
lambda bv: bv.src[1].dtype.count == bv.src[2].dtype.count and all(x >= 0 for x in bv.shape)),
|
||||||
|
|
||||||
|
# TODO: SLICE shouldn't go on INDEX. why is this allowed? remove these both
|
||||||
|
(UPat(Ops.SLICE, src=(UPat((Ops.INDEX,)), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint)),
|
||||||
|
name="bv"), lambda bv: bv.src[1].dtype.count == bv.src[2].dtype.count and all(x >= 0 for x in bv.shape)),
|
||||||
(UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True),
|
(UPat(Ops.CALL, src=(UPat((Ops.SLICE,)),), allow_any_len=True), lambda: True),
|
||||||
|
|
||||||
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
|
# codegen may end ranges after gpudims has replaced RANGE with SPECIAL.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue