mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb5d827ed9 |
2 changed files with 5 additions and 2 deletions
|
|
@ -60,7 +60,9 @@ load_store_indexing = PatternMatcher([
|
||||||
def expand_index(buf:UOp, vec:UOp):
|
def expand_index(buf:UOp, vec:UOp):
|
||||||
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
|
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
|
||||||
# generate the individual indexes
|
# generate the individual indexes
|
||||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]),
|
# we use `.buf_target()` here to avoid traversing into the AFTER
|
||||||
|
buf_target = buf.buf_target().rtag() if buf.op is Ops.AFTER else buf
|
||||||
|
midx = graph_rewrite(UOp.sink(*[buf_target.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]),
|
||||||
symbolic+load_store_indexing, name=f"index_buf_{buf.arg}")
|
symbolic+load_store_indexing, name=f"index_buf_{buf.arg}")
|
||||||
# extract all the relevant offsets
|
# extract all the relevant offsets
|
||||||
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
||||||
|
|
@ -93,7 +95,7 @@ def expand_index(buf:UOp, vec:UOp):
|
||||||
assert None not in idxs, f"some idxs are missing {idxs}"
|
assert None not in idxs, f"some idxs are missing {idxs}"
|
||||||
# the `.base` here is for images: we want the PTRCAT to be a normal pointer
|
# the `.base` here is for images: we want the PTRCAT to be a normal pointer
|
||||||
post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(global_offset), tuple(ret))
|
post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(global_offset), tuple(ret))
|
||||||
return post_cat.gep(tuple(cast(list[int], idxs)))
|
return post_cat.gep(tuple(cast(list[int], idxs))).substitute({buf_target:buf})
|
||||||
|
|
||||||
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
|
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
|
||||||
# TODO: this is written in many places
|
# TODO: this is written in many places
|
||||||
|
|
|
||||||
|
|
@ -615,6 +615,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
|
|
||||||
def buf_target(self) -> UOp:
|
def buf_target(self) -> UOp:
|
||||||
# the buffer that's being loaded from or stored to
|
# the buffer that's being loaded from or stored to
|
||||||
|
# NOTE: this is the good one to keep
|
||||||
match self.op:
|
match self.op:
|
||||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
|
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
|
||||||
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()
|
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue