mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
local_cach
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c370cd524 | ||
|
|
9185e962a3 |
1 changed file with 16 additions and 1 deletion
|
|
@@ -449,7 +449,22 @@ class Kernel:
|
|||
if op.op in GroupOp.Buffer and op in self.bufs:
|
||||
st = self.sts[self.bufs.index(op)]
|
||||
# replace the VIEW source
|
||||
return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
|
||||
if op.op is Ops.LOAD:
|
||||
global_buf = ret.src[0].src[0]
|
||||
# add locals cache
|
||||
local_shape = [s if self.axis_types[i] not in (AxisType.GLOBAL, AxisType.REDUCE)
|
||||
and ss != 0 else 1 for i,(s,ss) in enumerate(zip(st.shape, st.real_strides()))]
|
||||
# NOTE: this can have any permutation here
|
||||
lst = lst_store = ShapeTracker.from_shape(tuple(local_shape)).expand(st.shape)
|
||||
lbuf = UOp(Ops.DEFINE_LOCAL, dtype=global_buf.dtype.base.ptr(prod(local_shape), AddrSpace.LOCAL), arg=1000+global_buf.arg)
|
||||
# TODO: permute to place any UPCASTs in 0 stride LOCALs
|
||||
# any permutes of st + lst_store together are fine
|
||||
print(list(zip(self.shape_str(), st.shape, st.real_strides())))
|
||||
ret = ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
|
||||
ret = lbuf.view(lst).load(lbuf.view(lst_store).store(ret))
|
||||
else:
|
||||
ret = ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
|
||||
return ret
|
||||
if op.op is Ops.SINK:
|
||||
# NOTE: should group_for_reduces be added to the local_dims?
|
||||
kernel_name = ret.arg.name if ret.arg is not None else self.name if name_override is None else name_override
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue