mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Remote offset (#10311)
For memory savings from memory planner. Also for some reason it makes hlb cifar on mac noticeably faster. master: ``` 3 210.12 ms run, 4.34 ms python, 205.78 ms REMOTE, 2075.90 loss, 0.002698 LR, 2.07 GB used, 1558.41 GFLOPS, 327.45 GOPS 4 210.40 ms run, 4.33 ms python, 206.07 ms REMOTE, 2481.94 loss, 0.002262 LR, 2.07 GB used, 1556.34 GFLOPS, 327.45 GOPS 5 188.08 ms run, 4.41 ms python, 183.67 ms REMOTE, 1967.49 loss, 0.001827 LR, 2.07 GB used, 1741.00 GFLOPS, 327.45 GOPS 6 211.19 ms run, 4.26 ms python, 206.93 ms REMOTE, 1511.62 loss, 0.001392 LR, 2.07 GB used, 1550.51 GFLOPS, 327.45 GOPS ``` this: ``` 3 189.05 ms run, 4.50 ms python, 184.55 ms REMOTE, 2075.90 loss, 0.002698 LR, 1.60 GB used, 1732.08 GFLOPS, 327.45 GOPS 4 187.81 ms run, 4.11 ms python, 183.71 ms REMOTE, 2481.94 loss, 0.002262 LR, 1.60 GB used, 1743.49 GFLOPS, 327.45 GOPS 5 186.70 ms run, 4.09 ms python, 182.62 ms REMOTE, 1967.49 loss, 0.001827 LR, 1.60 GB used, 1753.89 GFLOPS, 327.45 GOPS 6 187.18 ms run, 4.06 ms python, 183.12 ms REMOTE, 1511.62 loss, 0.001392 LR, 1.60 GB used, 1749.36 GFLOPS, 327.45 GOPS ``` (`PYTHONPATH=. REMOTE=1 REMOTEDEV=METAL BS=256 STEPS=10 python examples/hlb_cifar10.py`) Clouldn't reliably reproduce the faster thing on tinybox though.
This commit is contained in:
parent
3c453e96a9
commit
c2bf2c6bb0
1 changed files with 19 additions and 4 deletions
|
|
@ -33,6 +33,7 @@ class RemoteProperties:
|
|||
graph_supported: bool
|
||||
graph_supports_multi: bool
|
||||
transfer_supported: bool
|
||||
offset_supported: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetProperties(RemoteRequest): pass
|
||||
|
|
@ -40,6 +41,9 @@ class GetProperties(RemoteRequest): pass
|
|||
@dataclass(frozen=True)
|
||||
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferFree(RemoteRequest): buffer_num: int # noqa: E702
|
||||
|
||||
|
|
@ -94,8 +98,9 @@ class GraphExec(RemoteRequest):
|
|||
wait: bool
|
||||
|
||||
# for safe deserialization
|
||||
eval_globals = {x.__name__:x for x in [SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferFree, CopyIn, CopyOut, Transfer, ProgramAlloc,
|
||||
ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
|
||||
eval_globals = {x.__name__:x for x in [SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferOffset, BufferFree, CopyIn, CopyOut, Transfer,
|
||||
ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp,
|
||||
Ops, dtypes]}
|
||||
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
|
||||
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
|
||||
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
|
||||
|
|
@ -169,12 +174,15 @@ class RemoteHandler:
|
|||
rp = RemoteProperties(
|
||||
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
|
||||
graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
|
||||
transfer_supported=hasattr(dev.allocator, '_transfer'),
|
||||
transfer_supported=hasattr(dev.allocator, '_transfer'), offset_supported=hasattr(dev.allocator, '_offset'),
|
||||
)
|
||||
ret = repr(rp).encode()
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
case BufferOffset():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already exists"
|
||||
session.buffers[c.buffer_num] = session.buffers[c.sbuffer_num].view(c.size, dtypes.uint8, c.offset).allocate()
|
||||
case BufferFree(): del session.buffers[c.buffer_num]
|
||||
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
|
||||
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
|
||||
|
|
@ -227,6 +235,9 @@ def remote_server(port:int):
|
|||
# ***** frontend *****
|
||||
|
||||
class RemoteAllocator(Allocator['RemoteDevice']):
|
||||
def __init__(self, dev:RemoteDevice):
|
||||
if dev.properties.offset_supported: self._offset = self._dyn_offset
|
||||
super().__init__(dev)
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
def _alloc(self, size:int, options:BufferSpec) -> int:
|
||||
self.dev.buffer_num += 1
|
||||
|
|
@ -245,6 +256,10 @@ class RemoteAllocator(Allocator['RemoteDevice']):
|
|||
else:
|
||||
src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
|
||||
dest_dev.allocator._copyin(dest, tmp)
|
||||
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
|
||||
self.dev.buffer_num += 1
|
||||
self.dev.q(BufferOffset(self.dev.buffer_num, size, offset, opaque))
|
||||
return self.dev.buffer_num
|
||||
|
||||
class RemoteProgram:
|
||||
def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
|
||||
|
|
@ -291,7 +306,7 @@ class RemoteDevice(Compiled):
|
|||
self.buffer_num: int = 0
|
||||
self.graph_num: int = 0
|
||||
|
||||
self.properties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
|
||||
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
|
||||
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
|
||||
renderer = self.properties.renderer
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue