Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
cb5d827ed9 use buf_target in expand_index 2025-11-19 15:37:19 -08:00
2 changed files with 5 additions and 2 deletions

View file

@ -60,7 +60,9 @@ load_store_indexing = PatternMatcher([
def expand_index(buf:UOp, vec:UOp):
if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
# generate the individual indexes
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]),
# we use `.buf_target()` here to avoid traversing into the AFTER
buf_target = buf.buf_target().rtag() if buf.op is Ops.AFTER else buf
midx = graph_rewrite(UOp.sink(*[buf_target.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]),
symbolic+load_store_indexing, name=f"index_buf_{buf.arg}")
# extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
@ -93,7 +95,7 @@ def expand_index(buf:UOp, vec:UOp):
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(global_offset), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
return post_cat.gep(tuple(cast(list[int], idxs))).substitute({buf_target:buf})
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
# TODO: this is written in many places

View file

@ -615,6 +615,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def buf_target(self) -> UOp:
# the buffer that's being loaded from or store to
# NOTE: this is the good one to keep
match self.op:
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()