mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
bugfixes
This commit is contained in:
parent
b09142a893
commit
9bbd12dc65
3 changed files with 7 additions and 4 deletions
|
|
@ -264,6 +264,8 @@ load_store_indexing = PatternMatcher([
|
|||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
# image load valid idx simplification
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
|
||||
# index True is just Index
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
|
||||
# delete_redundant_gates (after expand)
|
||||
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
|
||||
UPat.var("val"))), delete_redundant_gates),
|
||||
|
|
@ -278,7 +280,8 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||
# give any loads that are masked an alt value
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), name="x"), lambda x: x.replace(src=x.src+(x.const_like(0),))),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
|
||||
# gate any stores that aren't gated with ifs
|
||||
#(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
|
||||
# lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ base_rewrite = PatternMatcher([
|
|||
# new load/store
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
|
||||
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var"))),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var")), allow_any_len=True),
|
||||
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
|
||||
|
|
|
|||
|
|
@ -188,8 +188,8 @@ def add_to_mul(c:UOp, x:UOp):
|
|||
def prefetch_l1(ld:UOp, idx:UOp):
|
||||
if ld.src[-1].op is Ops.CUSTOM: return None
|
||||
ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
|
||||
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.index(UOp.const(dtypes.int, ld.dtype.count*2)),), arg="__builtin_HEXAGON_Y2_dcfetch({0});")
|
||||
x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.substitute({ranges[-1]: ranges[-1].src[0]}),), arg="__builtin_HEXAGON_Y2_dcfetch({0});")
|
||||
x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1]+UOp.const(dtypes.int, ld.dtype.count*2),), arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
|
||||
x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1].substitute({ranges[-1]: ranges[-1].src[0]}),), arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
|
||||
return ld.replace(src=ld.src+(x1, x2))
|
||||
|
||||
def vectorize_shuffle(vec:UOp):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue