mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
11 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50be175f88 | ||
|
|
33d80db64c | ||
|
|
28b6cff043 | ||
|
|
23ee2769bd | ||
|
|
5055668d1e | ||
|
|
8305a5804c | ||
|
|
911bb4ff44 |
||
|
|
942d69e139 | ||
|
|
343921f873 | ||
|
|
c88ad9d24f | ||
|
|
60dcc9f4df |
6 changed files with 15 additions and 13 deletions
|
|
@ -106,12 +106,12 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||||
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
|
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
|
||||||
return post_cat.gep(tuple(cast(list[int], idxs)))
|
return post_cat.gep(tuple(cast(list[int], idxs)))
|
||||||
|
|
||||||
def cat_after_store(cat:UOp, data:UOp):
|
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
|
||||||
# TODO: this is written in many places
|
# TODO: this is written in many places
|
||||||
offset = 0
|
offset = 0
|
||||||
ret: list[UOp] = []
|
ret: list[UOp] = []
|
||||||
for s in cat.src:
|
for s in cat.src:
|
||||||
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
|
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count))), *sto.src[2:]))
|
||||||
offset += s.dtype.count
|
offset += s.dtype.count
|
||||||
# dtype CAT
|
# dtype CAT
|
||||||
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
|
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
|
||||||
|
|
@ -119,13 +119,13 @@ def cat_after_store(cat:UOp, data:UOp):
|
||||||
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
|
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
|
||||||
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
|
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
|
||||||
|
|
||||||
def gep_on_store(gep:UOp, st:UOp):
|
def gep_on_store(gep:UOp, st:UOp, sto:UOp):
|
||||||
# NOTE: we need to invert the gep here, but it may be an expanding gep
|
# NOTE: we need to invert the gep here, but it may be an expanding gep
|
||||||
# fake argsort. TODO: handle duplicates
|
# fake argsort. TODO: handle duplicates
|
||||||
a = {}
|
a = {}
|
||||||
for i,x in enumerate(gep.arg): a[x] = i
|
for i,x in enumerate(gep.arg): a[x] = i
|
||||||
new_arg = tuple(x[1] for x in sorted(a.items()))
|
new_arg = tuple(x[1] for x in sorted(a.items()))
|
||||||
return gep.src[0].store(st.gep(new_arg))
|
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
|
||||||
|
|
||||||
load_store_folding = PatternMatcher([
|
load_store_folding = PatternMatcher([
|
||||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
|
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
|
||||||
|
|
@ -135,12 +135,12 @@ load_store_folding = PatternMatcher([
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
||||||
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
|
||||||
# GEP on data of STORE
|
# GEP on data of STORE
|
||||||
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
|
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
|
||||||
# put PTRCAT after LOAD
|
# put PTRCAT after LOAD
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
|
||||||
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
|
||||||
# put PTRCAT after STORE
|
# put PTRCAT after STORE
|
||||||
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
|
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
|
||||||
])
|
])
|
||||||
|
|
||||||
# ***** optional patterns *****
|
# ***** optional patterns *****
|
||||||
|
|
@ -310,8 +310,9 @@ pm_render = PatternMatcher([
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
||||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
|
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
|
||||||
# gate any stores that aren't gated with ifs
|
# gate any stores that aren't gated with ifs
|
||||||
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
|
(UPat(Ops.STORE, src=(UPat(src=(UPat.var("buf"), UPat.var("idx"), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat.var("val")),
|
||||||
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
|
name="store", allow_any_len=True),
|
||||||
|
lambda gate,store,buf,idx,val: UOp(Ops.STORE, dtype=store.dtype, src=(buf.index(idx), val, UOp(Ops.IF, src=(gate,)),)+store.src[2:])),
|
||||||
])
|
])
|
||||||
|
|
||||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||||
|
|
@ -339,7 +340,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||||
lst = [acc.load()] + lst # put acc as the first element
|
lst = [acc.load()] + lst # put acc as the first element
|
||||||
ctx.acc_num += 1
|
ctx.acc_num += 1
|
||||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||||
return acc.store(ret).load() if len(reduce_range) != 0 else ret
|
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||||
|
|
||||||
def no_vectorized_reduce(inp:UOp, red:UOp):
|
def no_vectorized_reduce(inp:UOp, red:UOp):
|
||||||
if inp.dtype != red.dtype:
|
if inp.dtype != red.dtype:
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ def do_expand(root:UOp):
|
||||||
if root.op is Ops.IF:
|
if root.op is Ops.IF:
|
||||||
# for the first arg of IF, just pass them through ignoring UNROLLS
|
# for the first arg of IF, just pass them through ignoring UNROLLS
|
||||||
new_srcs.append(src)
|
new_srcs.append(src)
|
||||||
elif root.op is Ops.REDUCE and src.op is Ops.RANGE:
|
elif root.op in {Ops.REDUCE, Ops.STORE} and src.op is Ops.RANGE:
|
||||||
# for any range args of REDUCE, pass them through
|
# for any range args of REDUCE or STORE, pass them through
|
||||||
new_srcs.append(src)
|
new_srcs.append(src)
|
||||||
elif src.dtype.count > 1:
|
elif src.dtype.count > 1:
|
||||||
|
|
|
||||||
|
|
@ -101,6 +101,7 @@ class BlockContext:
|
||||||
idx_context, store_context = ctx.last_ctx(u.src[0]), ctx.last_ctx(u.src[1])
|
idx_context, store_context = ctx.last_ctx(u.src[0]), ctx.last_ctx(u.src[1])
|
||||||
ctx.child_ctxs[u] = tuple([y for y in store_context if y not in idx_context and y.op is Ops.RANGE])
|
ctx.child_ctxs[u] = tuple([y for y in store_context if y not in idx_context and y.op is Ops.RANGE])
|
||||||
else: ctx.child_ctxs[u] = ()
|
else: ctx.child_ctxs[u] = ()
|
||||||
|
#elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([x for x in ctx.block_ctxs[u] if x not in u.src])
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
# ***** make blocks *****
|
# ***** make blocks *****
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||||
return buf.index(idx, valid).store(x.src[1])
|
return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE])
|
||||||
|
|
||||||
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
|
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
|
||||||
if all(x.mask is None for x in view.arg.views): return c
|
if all(x.mask is None for x in view.arg.views): return c
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ class PythonProgram:
|
||||||
assert dtype is not None, f"{uop} is missing a dtype"
|
assert dtype is not None, f"{uop} is missing a dtype"
|
||||||
dl[i] = dtype
|
dl[i] = dtype
|
||||||
if uop is Ops.STORE:
|
if uop is Ops.STORE:
|
||||||
assert len(inp) == 2, "expected store is ([(memory, offset, gate)], [value])"
|
#assert len(inp) == 2, "expected store is ([(memory, offset, gate)], [value])"
|
||||||
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
|
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
|
||||||
for (m,o,g),v in zip(inp[0], val):
|
for (m,o,g),v in zip(inp[0], val):
|
||||||
if g: _store(m, o+j, v)
|
if g: _store(m, o+j, v)
|
||||||
|
|
|
||||||
|
|
@ -409,7 +409,7 @@ REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT}
|
||||||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT}
|
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT}
|
||||||
sym = symbolic_flat+PatternMatcher([
|
sym = symbolic_flat+PatternMatcher([
|
||||||
# LOAD/STORE -> NOOP
|
# LOAD/STORE -> NOOP
|
||||||
(UPat.var('x').store(UPat.var('x').load()), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||||
# VECTORIZE/CONST, VECTORIZE/GEP
|
# VECTORIZE/CONST, VECTORIZE/GEP
|
||||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue