mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
11 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50be175f88 | ||
|
|
33d80db64c | ||
|
|
28b6cff043 | ||
|
|
23ee2769bd | ||
|
|
5055668d1e | ||
|
|
8305a5804c | ||
|
|
911bb4ff44 |
||
|
|
942d69e139 | ||
|
|
343921f873 | ||
|
|
c88ad9d24f | ||
|
|
60dcc9f4df |
6 changed files with 15 additions and 13 deletions
|
|
@ -106,12 +106,12 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
|
||||
return post_cat.gep(tuple(cast(list[int], idxs)))
|
||||
|
||||
def cat_after_store(cat:UOp, data:UOp):
|
||||
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
|
||||
# TODO: this is written in many places
|
||||
offset = 0
|
||||
ret: list[UOp] = []
|
||||
for s in cat.src:
|
||||
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
|
||||
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
|
||||
offset += s.dtype.count
|
||||
# dtype CAT
|
||||
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
|
||||
|
|
@ -119,13 +119,13 @@ def cat_after_store(cat:UOp, data:UOp):
|
|||
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
|
||||
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
|
||||
|
||||
def gep_on_store(gep:UOp, st:UOp):
|
||||
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
|
||||
# NOTE: we need to invert the gep here, but it may be an expanding gep
|
||||
# fake argsort. TODO: handle duplicates
|
||||
a = {}
|
||||
for i,x in enumerate(gep.arg): a[x] = i
|
||||
new_arg = tuple(x[1] for x in sorted(a.items()))
|
||||
return gep.src[0].store(st.gep(new_arg))
|
||||
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
|
||||
|
||||
load_store_folding = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
|
||||
|
|
@ -135,12 +135,12 @@ load_store_folding = PatternMatcher([
|
|||
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
||||
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
||||
# GEP on data of STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
|
||||
# put PTRCAT after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
||||
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||
# put PTRCAT after STORE
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
|
||||
])
|
||||
|
||||
# ***** optional patterns *****
|
||||
|
|
@ -310,8 +310,9 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
|
||||
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
|
||||
(UPat(Ops.STORE, src=(UPat(src=(UPat.var("buf"), UPat.var("idx"), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat.var("val")),
|
||||
name="store", allow_any_len=True),
|
||||
lambda gate,store,buf,idx,val: UOp(Ops.STORE, dtype=store.dtype, src=(buf.index(idx), val, UOp(Ops.IF, src=(gate,)),)+store.src[2:])),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
@ -339,7 +340,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
lst = [acc.load()] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.store(ret).load() if len(reduce_range) != 0 else ret
|
||||
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||
|
||||
def no_vectorized_reduce(inp:UOp, red:UOp):
|
||||
if inp.dtype != red.dtype:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def do_expand(root:UOp):
|
|||
if root.op is Ops.IF:
|
||||
# for the first arg of IF, just pass them through ignoring UNROLLS
|
||||
new_srcs.append(src)
|
||||
elif root.op is Ops.REDUCE and src.op is Ops.RANGE:
|
||||
elif root.op in {Ops.REDUCE, Ops.STORE} and src.op is Ops.RANGE:
|
||||
# for any range args of REDUCE, pass them through
|
||||
new_srcs.append(src)
|
||||
elif src.dtype.count > 1:
|
||||
|
|
|
|||
|
|
@ -101,6 +101,7 @@ class BlockContext:
|
|||
idx_context, store_context = ctx.last_ctx(u.src[0]), ctx.last_ctx(u.src[1])
|
||||
ctx.child_ctxs[u] = tuple([y for y in store_context if y not in idx_context and y.op is Ops.RANGE])
|
||||
else: ctx.child_ctxs[u] = ()
|
||||
#elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([x for x in ctx.block_ctxs[u] if x not in u.src])
|
||||
return ctx
|
||||
|
||||
# ***** make blocks *****
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
|||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
return buf.index(idx, valid).store(x.src[1])
|
||||
return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE])
|
||||
|
||||
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
|
||||
if all(x.mask is None for x in view.arg.views): return c
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class PythonProgram:
|
|||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
dl[i] = dtype
|
||||
if uop is Ops.STORE:
|
||||
assert len(inp) == 2, "expected store is ([(memory, offset, gate)], [value])"
|
||||
#assert len(inp) == 2, "expected store is ([(memory, offset, gate)], [value])"
|
||||
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
|
||||
for (m,o,g),v in zip(inp[0], val):
|
||||
if g: _store(m, o+j, v)
|
||||
|
|
|
|||
|
|
@ -409,7 +409,7 @@ REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT}
|
|||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT}
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load()), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
# VECTORIZE/CONST, VECTORIZE/GEP
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue