Remote offset (#10311)

Adds buffer offset support to the remote device so it gets the memory savings
from the memory planner. For some reason it also makes hlb_cifar10 on Mac
noticeably faster.

master:
```
  3  210.12 ms run,    4.34 ms python,  205.78 ms REMOTE, 2075.90 loss, 0.002698 LR, 2.07 GB used,   1558.41 GFLOPS,    327.45 GOPS
  4  210.40 ms run,    4.33 ms python,  206.07 ms REMOTE, 2481.94 loss, 0.002262 LR, 2.07 GB used,   1556.34 GFLOPS,    327.45 GOPS
  5  188.08 ms run,    4.41 ms python,  183.67 ms REMOTE, 1967.49 loss, 0.001827 LR, 2.07 GB used,   1741.00 GFLOPS,    327.45 GOPS
  6  211.19 ms run,    4.26 ms python,  206.93 ms REMOTE, 1511.62 loss, 0.001392 LR, 2.07 GB used,   1550.51 GFLOPS,    327.45 GOPS
```

this:
```
  3  189.05 ms run,    4.50 ms python,  184.55 ms REMOTE, 2075.90 loss, 0.002698 LR, 1.60 GB used,   1732.08 GFLOPS,    327.45 GOPS
  4  187.81 ms run,    4.11 ms python,  183.71 ms REMOTE, 2481.94 loss, 0.002262 LR, 1.60 GB used,   1743.49 GFLOPS,    327.45 GOPS
  5  186.70 ms run,    4.09 ms python,  182.62 ms REMOTE, 1967.49 loss, 0.001827 LR, 1.60 GB used,   1753.89 GFLOPS,    327.45 GOPS
  6  187.18 ms run,    4.06 ms python,  183.12 ms REMOTE, 1511.62 loss, 0.001392 LR, 1.60 GB used,   1749.36 GFLOPS,    327.45 GOPS
```

(`PYTHONPATH=. REMOTE=1 REMOTEDEV=METAL BS=256 STEPS=10 python examples/hlb_cifar10.py`)

Couldn't reliably reproduce the speedup on tinybox though.
This commit is contained in:
uuuvn 2025-05-15 23:20:01 +05:00 committed by GitHub
commit c2bf2c6bb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -33,6 +33,7 @@ class RemoteProperties:
graph_supported: bool
graph_supports_multi: bool
transfer_supported: bool
offset_supported: bool
@dataclass(frozen=True)
class GetProperties(RemoteRequest): pass
@ -40,6 +41,9 @@ class GetProperties(RemoteRequest): pass
@dataclass(frozen=True)
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@dataclass(frozen=True)
class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702
@dataclass(frozen=True)
class BufferFree(RemoteRequest): buffer_num: int # noqa: E702
@ -94,8 +98,9 @@ class GraphExec(RemoteRequest):
wait: bool
# for safe deserialization
eval_globals = {x.__name__:x for x in [SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferFree, CopyIn, CopyOut, Transfer, ProgramAlloc,
ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
eval_globals = {x.__name__:x for x in [SessionFree, RemoteProperties, GetProperties, BufferAlloc, BufferOffset, BufferFree, CopyIn, CopyOut, Transfer,
ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp,
Ops, dtypes]}
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
@ -169,12 +174,15 @@ class RemoteHandler:
rp = RemoteProperties(
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),
transfer_supported=hasattr(dev.allocator, '_transfer'),
transfer_supported=hasattr(dev.allocator, '_transfer'), offset_supported=hasattr(dev.allocator, '_offset'),
)
ret = repr(rp).encode()
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
case BufferOffset():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already exists"
session.buffers[c.buffer_num] = session.buffers[c.sbuffer_num].view(c.size, dtypes.uint8, c.offset).allocate()
case BufferFree(): del session.buffers[c.buffer_num]
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
@ -227,6 +235,9 @@ def remote_server(port:int):
# ***** frontend *****
class RemoteAllocator(Allocator['RemoteDevice']):
def __init__(self, dev:RemoteDevice):
if dev.properties.offset_supported: self._offset = self._dyn_offset
super().__init__(dev)
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferSpec) -> int:
self.dev.buffer_num += 1
@ -245,6 +256,10 @@ class RemoteAllocator(Allocator['RemoteDevice']):
else:
src_dev.allocator._copyout(tmp:=memoryview(bytearray(sz)), src)
dest_dev.allocator._copyin(dest, tmp)
def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
self.dev.buffer_num += 1
self.dev.q(BufferOffset(self.dev.buffer_num, size, offset, opaque))
return self.dev.buffer_num
class RemoteProgram:
def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
@ -291,7 +306,7 @@ class RemoteDevice(Compiled):
self.buffer_num: int = 0
self.graph_num: int = 0
self.properties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
renderer = self.properties.renderer