Remote properties is a dataclass (#10283)

Not strictly required for anything but soon there will be like 4 new
properties and having it be a huge json just seems like a bad taste.

It also seems right to not have a separate endpoint for this, just
`GetProperties` request that returns a repr of this similar to how
requests are sent in `BatchRequest`.

This will also make a switch to anything other than http much simpler
if it will be required for any reason, like just a tcp stream of
`BatchRequest`s
This commit is contained in:
uuuvn 2025-05-13 23:56:58 +05:00 committed by GitHub
commit ddff9857b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 22 deletions

View file

@ -791,7 +791,7 @@ jobs:
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.default.properties['remotedev'] == 'METAL', Device.default.properties['remotedev']"
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |

View file

@ -68,4 +68,4 @@ def not_support_multi_device():
return CI and REAL_DEV in ("GPU", "CUDA")
# NOTE: This will open REMOTE if it's the default device
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)

View file

@ -8,7 +8,7 @@ from __future__ import annotations
from typing import Callable, Optional, Any, cast
from collections import defaultdict
from dataclasses import dataclass, field, replace
import multiprocessing, functools, asyncio, http, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib
from tinygrad.renderer import Renderer, ProgramSpec
from tinygrad.dtype import DTYPES_DICT, dtypes
from tinygrad.ops import UOp, Ops, Variable, sint
@ -23,6 +23,15 @@ from tinygrad.runtime.graph.cpu import CPUGraph
@dataclass(frozen=True)
class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True)
@dataclass(frozen=True)
class RemoteProperties:
real_device: str
renderer: tuple[str, str, tuple[Any, ...]]
graph_supported: bool
@dataclass(frozen=True)
class GetProperties(RemoteRequest): pass
@dataclass(frozen=True)
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
@ -74,8 +83,8 @@ class GraphExec(RemoteRequest):
wait: bool
# for safe deserialization
eval_globals = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem,
GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
eval_globals = {x.__name__:x for x in [RemoteProperties, GetProperties, BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree,
ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
@ -141,6 +150,11 @@ class RemoteHandler:
if DEBUG >= 1: print(c)
session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"]
match c:
case GetProperties():
cls, args = dev.renderer.__reduce__()
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
ret = repr(RemoteProperties(dev.device, (cls.__module__, cls.__name__, args), graph_cls is not None)).encode()
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
@ -170,11 +184,6 @@ class RemoteHandler:
case GraphExec():
r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait)
if r is not None: ret = str(r).encode()
elif path == "/properties" and method == "GET":
cls, args = Device[self.base_device].renderer.__reduce__()
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
ret = json.dumps({'remotedev': self.base_device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
return status, ret
@ -231,18 +240,19 @@ class RemoteDevice(Compiled):
while 1:
try:
self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
self.properties = json.loads(self.send("GET", "properties").decode())
self.q(GetProperties())
self.properties = safe_eval(ast.parse(self.batch_submit(), mode="eval").body)
break
except Exception as e:
print(e)
time.sleep(0.1)
if DEBUG >= 1: print(f"remote has device {self.properties['remotedev']}")
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
renderer = self.properties['renderer']
renderer = self.properties.renderer
if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties['graph'] else None
graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties.graph_supported else None
super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph)
def __del__(self):
@ -261,15 +271,11 @@ class RemoteDevice(Compiled):
def batch_submit(self):
data = self.req.serialize()
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
ret = self.send("POST", "batch", data)
self.conn.request("POST", "/batch", data)
response = self.conn.getresponse()
assert response.status == 200, f"POST /batch failed: {response}"
ret = response.read()
self.req = BatchRequest()
return ret
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
# TODO: retry logic
self.conn.request(method, "/"+path, data)
response = self.conn.getresponse()
assert response.status == 200, f"failed on {method} {path}"
return response.read()
if __name__ == "__main__": remote_server(getenv("PORT", 6667))