Remote multi (basic) (#10269)

* Basic remote multi support

The simplest change that makes remote usable with multiple GPUs. It is very slow
because there are no direct transfers yet: cross-device copies go through copyin/copyout.

* Tests: run test/test_multitensor.py under REMOTE=1 in CI
This commit is contained in:
uuuvn 2025-05-13 21:52:47 +05:00 committed by GitHub
commit ba87eca0f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 27 additions and 23 deletions

View file

@ -473,7 +473,7 @@ jobs:
run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py
- name: Run REMOTE=1 Test
run: |
REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py
REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py
REMOTEDEV=GPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py
REMOTEDEV=GPU IMAGE=2 REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
- name: Test Optimization Helpers
@ -795,7 +795,7 @@ jobs:
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py
osxtests:
strategy:

View file

@ -64,8 +64,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
return out_buf.cast(uop.dtype.fmt).tolist()[0]
def not_support_multi_device():
# REMOTE doesn't support multi device anywhere, GPU and CUDA don't support multi device if in CI
return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA"))
# GPU and CUDA don't support multi device if in CI
return CI and REAL_DEV in ("GPU", "CUDA")
# NOTE: This will open REMOTE if it's the default device
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])

View file

@ -21,7 +21,7 @@ from tinygrad.runtime.graph.cpu import CPUGraph
# ***** API *****
@dataclass(frozen=True)
class RemoteRequest: session: str|None = field(default=None, kw_only=True)
class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True)
@dataclass(frozen=True)
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@ -116,9 +116,9 @@ class RemoteSession:
buffers: dict[int, Buffer] = field(default_factory=dict)
class RemoteHandler:
def __init__(self, device: str):
self.device = device
self.sessions: defaultdict[str, RemoteSession] = defaultdict(RemoteSession)
def __init__(self, base_device: str):
self.base_device = base_device
self.sessions: defaultdict[tuple[str, int], RemoteSession] = defaultdict(RemoteSession)
async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
while (req_hdr:=(await reader.readline()).decode().strip()):
@ -139,17 +139,17 @@ class RemoteHandler:
# the cmds are always last (currently in datahash)
for c in req._q:
if DEBUG >= 1: print(c)
session = self.sessions[unwrap(c.session)]
session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"]
match c:
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = Buffer(self.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
case BufferFree(): del session.buffers[c.buffer_num]
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
case ProgramAlloc():
lib = Device[self.device].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[self.device].runtime(c.name, lib)
lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
case ProgramFree(): del session.programs[(c.name, c.datahash)]
case ProgramExec():
bufs = [session.buffers[x]._buf for x in c.bufs]
@ -157,10 +157,10 @@ class RemoteHandler:
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
if r is not None: ret = str(r).encode()
case GraphAlloc():
graph_fn: Callable = unwrap(Device[self.device].graph)
graph_fn: Callable = unwrap(dev.graph)
def _parse_ji(gi: GraphComputeItem):
prg = session.programs[(gi.name, gi.datahash)]
ps = ProgramSpec(gi.name, '', self.device, UOp(Ops.NOOP), vars=list(gi.vars),
ps = ProgramSpec(gi.name, '', dev.device, UOp(Ops.NOOP), vars=list(gi.vars),
global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs])
@ -171,10 +171,10 @@ class RemoteHandler:
r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait)
if r is not None: ret = str(r).encode()
elif path == "/properties" and method == "GET":
cls, args = Device[self.device].renderer.__reduce__()
cls, args = Device[self.base_device].renderer.__reduce__()
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
graph_cls = gt if (gt:=graph_class(Device[self.device])) is not CPUGraph else None
ret = json.dumps({'remotedev': self.device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
ret = json.dumps({'remotedev': self.base_device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
return status, ret
@ -219,14 +219,12 @@ class RemoteProgram:
class RemoteDevice(Compiled):
def __init__(self, device:str):
if (host:=getenv("HOST", "")) != "": self.host = host
else:
multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
self.host = "127.0.0.1:6667"
self.host = getenv("HOST", "") or RemoteDevice.local_server()
# state for the connection
self.session = binascii.hexlify(os.urandom(0x10)).decode()
self.buffer_num, self.graph_num = 0, 0
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0)
self.buffer_num: int = 0
self.graph_num: int = 0
self.req: BatchRequest = BatchRequest()
if DEBUG >= 1: print(f"remote with host {self.host}")
@ -254,6 +252,12 @@ class RemoteDevice(Compiled):
def q(self, x:RemoteRequest): self.req.q(replace(x, session=self.session))
# Starts the local remote server as a daemon process on port 6667 and returns its address.
# The result is memoized, so repeated calls return the same address without spawning again.
# NOTE: staticmethod must be the outermost decorator. That way the cache wraps the plain
# function (not a staticmethod object, which is only callable on Python 3.10+), and the
# attribute behaves as a static method on both class and instance access.
@staticmethod
@functools.cache
def local_server():
  multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
  return "127.0.0.1:6667"
def batch_submit(self):
data = self.req.serialize()
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):