universal

This commit is contained in:
George Hotz 2026-05-26 08:55:33 -07:00
commit ff3b0c658d
2 changed files with 5 additions and 6 deletions

View file

@ -285,8 +285,7 @@ pm_render = PatternMatcher([
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
# rewrite INDEX to SLICE
(UPat(Ops.INDEX, name="x"), lambda x:
UOp(Ops.SLICE, dtype=x.dtype, src=(x.src[0], x.src[1]*x.src[0].dtype.itemsize),
arg=0 if x.dtype.vcount == 1 else x.dtype.vcount)),
UOp(Ops.SLICE, dtype=x.dtype, src=(x.src[0], x.src[1]), arg=0 if x.dtype.vcount == 1 else x.dtype.vcount)),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***

View file

@ -43,14 +43,14 @@ base_rewrite = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
# default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
# index + slice is the new index
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
# new load/store
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.LOAD, src=(UPat.var("bidx"), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# slice is the new index
(UPat(Ops.SLICE, src=(UPat.var("buf"), UPat.var("off")), name="x"),
lambda ctx,buf,off,x: f"(({ctx.render_dtype(x.dtype)})((unsigned char *){ctx[buf]}+{ctx[off]}))"),
# alu/gep
# TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](