Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
6c370cd524 stuff 2025-07-30 14:12:06 -07:00
George Hotz
9185e962a3 add local caching to opts 2025-07-30 13:20:11 -07:00

View file

@ -449,7 +449,22 @@ class Kernel:
if op.op in GroupOp.Buffer and op in self.bufs: if op.op in GroupOp.Buffer and op in self.bufs:
st = self.sts[self.bufs.index(op)] st = self.sts[self.bufs.index(op)]
# replace the VIEW source # replace the VIEW source
return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:]) if op.op is Ops.LOAD:
global_buf = ret.src[0].src[0]
# add locals cache
local_shape = [s if self.axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE)
and ss != 0 else 1 for i,(s,ss) in enumerate(zip(st.shape, st.real_strides()))]
# NOTE: this can have any permutation here
lst = lst_store = ShapeTracker.from_shape(tuple(local_shape)).expand(st.shape)
lbuf = UOp(Ops.DEFINE_LOCAL, dtype=global_buf.dtype.base.ptr(prod(local_shape), AddrSpace.LOCAL), arg=1000+global_buf.arg)
# TODO: permute to place any UPCASTs in 0 stride LOCALs
# any permutes of st + lst_store together are fine
print(list(zip(self.shape_str(), st.shape, st.real_strides())))
ret = ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
ret = lbuf.view(lst).load(lbuf.view(lst_store).store(ret))
else:
ret = ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
return ret
if op.op is Ops.SINK: if op.op is Ops.SINK:
# NOTE: should group_for_reduces be added to the local_dims? # NOTE: should group_for_reduces be added to the local_dims?
kernel_name = ret.arg.name if ret.arg is not None else self.name if name_override is None else name_override kernel_name = ret.arg.name if ret.arg is not None else self.name if name_override is None else name_override