mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Remote multi (basic) (#10269)
* Basic remote multi support
  Simplest thing to be able to use remote with multiple GPUs; very slow because there are no transfers (cross-device copies go through copyin/copyout).
* tests
This commit is contained in:
parent
5f64bbc63d
commit
ba87eca0f1
3 changed files with 27 additions and 23 deletions
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
|
@ -473,7 +473,7 @@ jobs:
|
|||
run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py
|
||||
REMOTEDEV=GPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py
|
||||
REMOTEDEV=GPU IMAGE=2 REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
|
||||
- name: Test Optimization Helpers
|
||||
|
|
@ -795,7 +795,7 @@ jobs:
|
|||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_multitensor.py
|
||||
|
||||
osxtests:
|
||||
strategy:
|
||||
|
|
|
|||
|
|
@ -64,8 +64,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
|
|||
return out_buf.cast(uop.dtype.fmt).tolist()[0]
|
||||
|
||||
def not_support_multi_device():
|
||||
# REMOTE doesn't support multi device anywhere, GPU and CUDA don't support multi device if in CI
|
||||
return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA"))
|
||||
# GPU and CUDA don't support multi device if in CI
|
||||
return CI and REAL_DEV in ("GPU", "CUDA")
|
||||
|
||||
# NOTE: This will open REMOTE if it's the default device
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from tinygrad.runtime.graph.cpu import CPUGraph
|
|||
# ***** API *****
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoteRequest: session: str|None = field(default=None, kw_only=True)
|
||||
class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
|
|
@ -116,9 +116,9 @@ class RemoteSession:
|
|||
buffers: dict[int, Buffer] = field(default_factory=dict)
|
||||
|
||||
class RemoteHandler:
|
||||
def __init__(self, device: str):
|
||||
self.device = device
|
||||
self.sessions: defaultdict[str, RemoteSession] = defaultdict(RemoteSession)
|
||||
def __init__(self, base_device: str):
|
||||
self.base_device = base_device
|
||||
self.sessions: defaultdict[tuple[str, int], RemoteSession] = defaultdict(RemoteSession)
|
||||
|
||||
async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
|
||||
while (req_hdr:=(await reader.readline()).decode().strip()):
|
||||
|
|
@ -139,17 +139,17 @@ class RemoteHandler:
|
|||
# the cmds are always last (currently in datahash)
|
||||
for c in req._q:
|
||||
if DEBUG >= 1: print(c)
|
||||
session = self.sessions[unwrap(c.session)]
|
||||
session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"]
|
||||
match c:
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = Buffer(self.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
case BufferFree(): del session.buffers[c.buffer_num]
|
||||
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
|
||||
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
|
||||
case ProgramAlloc():
|
||||
lib = Device[self.device].compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = Device[self.device].runtime(c.name, lib)
|
||||
lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
|
||||
case ProgramFree(): del session.programs[(c.name, c.datahash)]
|
||||
case ProgramExec():
|
||||
bufs = [session.buffers[x]._buf for x in c.bufs]
|
||||
|
|
@ -157,10 +157,10 @@ class RemoteHandler:
|
|||
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
|
||||
if r is not None: ret = str(r).encode()
|
||||
case GraphAlloc():
|
||||
graph_fn: Callable = unwrap(Device[self.device].graph)
|
||||
graph_fn: Callable = unwrap(dev.graph)
|
||||
def _parse_ji(gi: GraphComputeItem):
|
||||
prg = session.programs[(gi.name, gi.datahash)]
|
||||
ps = ProgramSpec(gi.name, '', self.device, UOp(Ops.NOOP), vars=list(gi.vars),
|
||||
ps = ProgramSpec(gi.name, '', dev.device, UOp(Ops.NOOP), vars=list(gi.vars),
|
||||
global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
|
||||
local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
|
||||
return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [session.buffers[buf] for buf in gi.bufs])
|
||||
|
|
@ -171,10 +171,10 @@ class RemoteHandler:
|
|||
r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif path == "/properties" and method == "GET":
|
||||
cls, args = Device[self.device].renderer.__reduce__()
|
||||
cls, args = Device[self.base_device].renderer.__reduce__()
|
||||
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
|
||||
graph_cls = gt if (gt:=graph_class(Device[self.device])) is not CPUGraph else None
|
||||
ret = json.dumps({'remotedev': self.device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
|
||||
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
|
||||
ret = json.dumps({'remotedev': self.base_device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
|
||||
else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
|
||||
return status, ret
|
||||
|
||||
|
|
@ -219,14 +219,12 @@ class RemoteProgram:
|
|||
|
||||
class RemoteDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
if (host:=getenv("HOST", "")) != "": self.host = host
|
||||
else:
|
||||
multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
|
||||
self.host = "127.0.0.1:6667"
|
||||
self.host = getenv("HOST", "") or RemoteDevice.local_server()
|
||||
|
||||
# state for the connection
|
||||
self.session = binascii.hexlify(os.urandom(0x10)).decode()
|
||||
self.buffer_num, self.graph_num = 0, 0
|
||||
self.session = (binascii.hexlify(os.urandom(0x10)).decode(), int(device.split(":")[1]) if ":" in device else 0)
|
||||
self.buffer_num: int = 0
|
||||
self.graph_num: int = 0
|
||||
self.req: BatchRequest = BatchRequest()
|
||||
|
||||
if DEBUG >= 1: print(f"remote with host {self.host}")
|
||||
|
|
@ -254,6 +252,12 @@ class RemoteDevice(Compiled):
|
|||
|
||||
def q(self, x:RemoteRequest): self.req.q(replace(x, session=self.session))
|
||||
|
||||
@functools.cache
|
||||
@staticmethod
|
||||
def local_server():
|
||||
multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
|
||||
return "127.0.0.1:6667"
|
||||
|
||||
def batch_submit(self):
|
||||
data = self.req.serialize()
|
||||
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue