all BUFFER_VIEW, no INDEX

This commit is contained in:
George Hotz 2026-05-25 19:17:51 -07:00
commit 38bd43981b
4 changed files with 13 additions and 2 deletions

View file

@ -283,6 +283,10 @@ pm_render = PatternMatcher([
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
# rewrite INDEX to SLICE
(UPat(Ops.INDEX, name="x"), lambda x:
UOp(Ops.SLICE, dtype=x.dtype, src=(x.src[0], x.src[1]*x.src[0].dtype.itemsize),
arg=0 if x.dtype.vcount == 1 else x.dtype.vcount)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View file

@ -47,7 +47,10 @@ base_rewrite = PatternMatcher([
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# slice is the new index
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off")), name="x"),
lambda ctx,buf,off,x: f"(({ctx.render_dtype(x.dtype)})((unsigned char *){ctx[buf]}+{ctx[off]}))"),
# alu/gep
# TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
@ -197,7 +200,7 @@ class CStyleLanguage(Renderer):
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI, Ops.SLICE} or \
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):

View file

@ -267,6 +267,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.BINARY: return (len(self.arg),)
case Ops.BUFFER: return (self.arg,)
case Ops.SLICE:
if self.arg == 0: return ()
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
if self.src[0].op is Ops.INDEX: return ()
return (self.arg,)

View file

@ -200,6 +200,9 @@ spec_program = PatternMatcher([
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
# slice in program
(UPat(Ops.SLICE), lambda: True),
# movement ops are not allowed in programs
(UPat(GroupOp.Movement), lambda: False),