Compare commits

...

11 commits

Author SHA1 Message Date
George Hotz
50be175f88 store 2025-07-22 19:28:39 -07:00
George Hotz
33d80db64c this can happen later 2025-07-22 19:12:37 -07:00
George Hotz
28b6cff043 broken 2025-07-22 19:05:11 -07:00
George Hotz
23ee2769bd gate store 2025-07-22 19:03:06 -07:00
George Hotz
5055668d1e oops, forgot that 2025-07-22 18:54:14 -07:00
George Hotz
8305a5804c just the range thing 2025-07-22 18:52:26 -07:00
George Hotz
911bb4ff44
Merge branch 'master' into endrange 2025-07-22 18:43:22 -07:00
George Hotz
942d69e139 store is endrange 2025-07-22 15:30:10 -07:00
George Hotz
343921f873 Revert "end the ranges in the stores"
This reverts commit c88ad9d24f.
2025-07-22 15:12:21 -07:00
George Hotz
c88ad9d24f end the ranges in the stores 2025-07-22 15:04:10 -07:00
George Hotz
60dcc9f4df insert endrange 2025-07-22 14:38:49 -07:00
6 changed files with 15 additions and 13 deletions

View file

@ -106,12 +106,12 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp):
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
# TODO: this is written in many places
offset = 0
ret: list[UOp] = []
for s in cat.src:
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
offset += s.dtype.count
# dtype CAT
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
@ -119,13 +119,13 @@ def cat_after_store(cat:UOp, data:UOp):
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
def gep_on_store(gep:UOp, st:UOp):
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
# NOTE: we need to invert the gep here, but it may be an expanding gep
# fake argsort. TODO: handle duplicates
a = {}
for i,x in enumerate(gep.arg): a[x] = i
new_arg = tuple(x[1] for x in sorted(a.items()))
return gep.src[0].store(st.gep(new_arg))
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
load_store_folding = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
@ -135,12 +135,12 @@ load_store_folding = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
])
# ***** optional patterns *****
@ -310,8 +310,9 @@ pm_render = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
(UPat(Ops.STORE, src=(UPat(src=(UPat.var("buf"), UPat.var("idx"), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat.var("val")),
name="store", allow_any_len=True),
lambda gate,store,buf,idx,val: UOp(Ops.STORE, dtype=store.dtype, src=(buf.index(idx), val, UOp(Ops.IF, src=(gate,)),)+store.src[2:])),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
@ -339,7 +340,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
lst = [acc.load()] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.store(ret).load() if len(reduce_range) != 0 else ret
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
def no_vectorized_reduce(inp:UOp, red:UOp):
if inp.dtype != red.dtype:

View file

@ -49,7 +49,7 @@ def do_expand(root:UOp):
if root.op is Ops.IF:
# for the first arg of IF, just pass them through ignoring UNROLLS
new_srcs.append(src)
elif root.op is Ops.REDUCE and src.op is Ops.RANGE:
elif root.op in {Ops.REDUCE, Ops.STORE} and src.op is Ops.RANGE:
# for any range args of REDUCE, pass them through
new_srcs.append(src)
elif src.dtype.count > 1:

View file

@ -101,6 +101,7 @@ class BlockContext:
idx_context, store_context = ctx.last_ctx(u.src[0]), ctx.last_ctx(u.src[1])
ctx.child_ctxs[u] = tuple([y for y in store_context if y not in idx_context and y.op is Ops.RANGE])
else: ctx.child_ctxs[u] = ()
#elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([x for x in ctx.block_ctxs[u] if x not in u.src])
return ctx
# ***** make blocks *****

View file

@ -57,7 +57,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
return buf.index(idx, valid).store(x.src[1])
return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE])
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
if all(x.mask is None for x in view.arg.views): return c

View file

@ -56,7 +56,7 @@ class PythonProgram:
assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop is Ops.STORE:
assert len(inp) == 2, "expected store is ([(memory, offset, gate)], [value])"
#assert len(inp) == 2, "expected store is ([(memory, offset, gate)], [value])"
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
for (m,o,g),v in zip(inp[0], val):
if g: _store(m, o+j, v)

View file

@ -409,7 +409,7 @@ REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT}
sym = symbolic_flat+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load()), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),