bufferview offset is units of input dtype (#16378)

This commit is contained in:
George Hotz 2026-05-26 08:49:31 -07:00 committed by GitHub
commit 7f1b02854e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 21 additions and 15 deletions

View file

@ -29,7 +29,7 @@ pm_insert_deps = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), insert_deps)]
pm_replace_params = PatternMatcher([
(UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.input_addrs_uop.index(UOp.const(dtypes.int, p.arg))),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.INDEX, name="addr"), UPat(Ops.CONST, dtype=dtypes.weakint, name="off")), name="bv"),
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg)),
lambda ctx, bv, addr, off: addr.cast(dtypes.uint64) + UOp.const(dtypes.uint64, off.arg * ctx.input_uops[addr.src[1].arg].dtype.itemsize)),
])
# **************** graph-only passes ****************

View file

@ -292,7 +292,7 @@ def bufferize_kernargs(ctx:HCQ2LowerCtx, target:UOp, buf_node:UOp) -> UOp:
isz = dctx.kernargs_host.dtype.base.itemsize
off = dctx.kernargs_allocator.alloc(buf_node.arg, 16)
hbufs.append(UOp(Ops.BUFFER_VIEW, dctx.kernargs_host.dtype,
src=(dctx.kernargs_host, UOp.const(dtypes.weakint, off)), arg=buf_node.arg // isz))
src=(dctx.kernargs_host, UOp.const(dtypes.weakint, off // isz)), arg=buf_node.arg // isz))
addrs.append(dctx.kernargs_gpu + off)
return _maybe_mstack(tuple(addrs)).after(*_lower_stores(_maybe_mstack(tuple(hbufs)), buf_node, target.src[1:]))
@ -366,7 +366,7 @@ def resolve_getaddr(ctx:HCQ2LowerCtx, m:UOp) -> UOp:
pm_resolve_patches = symbolic + PatternMatcher([
# resolve getaddrs
(UPat(Ops.GETADDR, src=(UPat(Ops.BUFFER_VIEW, name="bv"),)), # getaddr(buffer_view(x)) -> offset+getaddr(x)
lambda ctx, bv: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0],)) + UOp.const(dtypes.uint64, bv.src[1].arg)),
lambda ctx, bv: UOp(Ops.GETADDR, dtypes.uint64, src=(bv.src[0],)) + UOp.const(dtypes.uint64, bv.src[1].arg * bv.src[0].dtype.itemsize)),
(UPat(Ops.GETADDR, src=(UPat((Ops.BUFFER, Ops.MSTACK), name="m"),)), resolve_getaddr), # getaddr(buffer|mstack) -> addr_table load|const
(UPat(Ops.GETADDR, src=(UPat.cvar("const"),)), lambda ctx, const: const), # getaddr(const) -> const
@ -387,7 +387,7 @@ def parametrize_host_buffer(ctx:HCQ2LowerCtx, buf:UOp) -> UOp:
pm_parametrize_host_buffers = PatternMatcher([
# resolve buffer views to parametrize only root buffers
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER_VIEW, name="bv"), UPat.var("idx")), name="bi"),
lambda bv, idx, bi: bi.replace(src=(bv.src[0], idx + bv.src[1].arg // bv.src[0].dtype.itemsize))),
lambda bv, idx, bi: bi.replace(src=(bv.src[0], idx * bv.dtype.itemsize // bv.src[0].dtype.itemsize + bv.src[1].arg))),
# parametrize host buffers
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK)),), allow_any_len=True, name="buf"), parametrize_host_buffer),

Binary file not shown.

View file

@ -79,7 +79,7 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
\op{BufferView} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is \texttt{weakint} bytes. \\
\op{BufferView} & $(T, \mathrm{offset})$ & size, dtype & Zero-copy \textit{size} elems of dtype; offset is \texttt{weakint} elems of $T$ dtype. \\
\bottomrule
\end{tabular}

View file

@ -38,7 +38,7 @@ def check_assign(buffer_lists, copies=None):
for orig_si, new_si in zip(linear.src, result.src):
for orig, new in zip(orig_si.src[1:], new_si.src[1:]):
if new.op is Ops.BUFFER_VIEW and id(orig) not in replace_map:
replace_map[id(orig)] = (new.src[0], new.src[1].arg, new.arg * new.dtype.itemsize)
replace_map[id(orig)] = (new.src[0], new.src[1].arg * new.src[0].dtype.itemsize, new.arg * new.dtype.itemsize)
# verify pinned buffers are not planned
for buf in held_bufs:

View file

@ -15,7 +15,9 @@ class TestMetalGraph(unittest.TestCase):
buf = MagicMock()
if offset > 0:
buf.op = Ops.BUFFER_VIEW
buf.src = (None, UOp.const(dtypes.weakint, offset))
src = MagicMock()
src.dtype = dtypes.uint8
buf.src = (src, UOp.const(dtypes.weakint, offset))
buf.dtype = dtypes.uint8
else:
buf.op = Ops.BUFFER

View file

@ -60,8 +60,11 @@ def _make_buffer_view(src:UOp) -> UOp|None:
"""If movement ops on src collapse to a contiguous range, return BUFFER_VIEW.reshape(src.shape). Otherwise None."""
if (offset := src.contiguous_view_offset()) is None: return None
buf = src.base
offset *= src.dtype.itemsize
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.src[1].arg, buf.src[0]
if buf.op is Ops.BUFFER_VIEW:
byte_offset = buf.src[1].arg * buf.src[0].dtype.itemsize + offset * src.dtype.itemsize
buf = buf.src[0]
if byte_offset % buf.dtype.itemsize != 0: return None
offset = byte_offset // buf.dtype.itemsize
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf, UOp.const(dtypes.weakint, offset)), src.numel()).reshape(src.shape)
def contiguous_mops_to_view(c:UOp, src:UOp):

View file

@ -27,7 +27,9 @@ def get_call_name(call:UOp, bufs:list[Buffer], var_vals:dict[str, int]|None=None
ast, arg_uops = call.src[0], get_call_arg_uops(call)
if ast.op is Ops.PROGRAM: return ast.arg.name
if ast.op is Ops.BUFFER_VIEW: return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {ast.src[1].arg:<10d}", "yellow")
if ast.op is Ops.BUFFER_VIEW:
offset = ast.src[1].arg * arg_uops[1].dtype.itemsize
return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {offset:<10d}", "yellow")
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return colored(f"enc/dec {_uop_sz_to_str(arg_uops[0])}", "yellow")
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": return colored(f"batched {len(ast.src[0].src)}", "cyan")
@ -149,7 +151,7 @@ def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], d
def exec_view(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
resolved = resolve_params(call, ctx.input_uops)
bufs = [cast(Buffer, b.buffer) for b in resolved]
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.src[1].arg)
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.src[1].arg*bufs[1].dtype.itemsize)
with track_stats(ctx, call, bv.device, [bv, bufs[1]], ctx.var_vals): buffers[resolved[0]] = bv
return None

View file

@ -109,5 +109,5 @@ class MetalGraph(GraphRunner):
@staticmethod
def supports_uop(batch_devs, new_call:UOp) -> bool:
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
if any(b.op is Ops.BUFFER_VIEW and b.src[1].arg > 0xFFFFFFFF for b in new_call.src[1:]): return False
if any(b.op is Ops.BUFFER_VIEW and b.src[1].arg * b.src[0].dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
return GraphRunner.supports_uop(batch_devs, new_call)

View file

@ -350,7 +350,6 @@ def late_buffer_view(t:UOp, b:UOp):
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
offset *= x.base.dtype.itemsize
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base, UOp.const(dtypes.weakint, offset)), size), b.src[1]))

View file

@ -799,11 +799,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
offset = self.src[1].arg
if isinstance(buf, MultiBuffer):
mbuf = MultiBuffer.__new__(MultiBuffer)
mbuf.bufs = [b.view(self.arg, self.dtype, offset) for b in buf.bufs]
mbuf.bufs = [b.view(self.arg, self.dtype, offset * self.src[0].dtype.itemsize) for b in buf.bufs]
buffers[self] = mbuf
return mbuf
assert isinstance(buf, Buffer), "must be a Buffer for BUFFER_VIEW"
buffers[self] = bv = buf.view(self.arg, self.dtype, offset)
buffers[self] = bv = buf.view(self.arg, self.dtype, offset * self.src[0].dtype.itemsize)
return bv
if self.op is Ops.MSELECT:
ret = self.src[0].buffer