mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Remote properties is a dataclass (#10283)
Not strictly required for anything, but soon there will be around four new properties, and having them in one huge JSON blob just seems like bad taste. It also seems right not to have a separate endpoint for this, just a `GetProperties` request that returns a repr of the dataclass, similar to how requests are sent in `BatchRequest`. This will also make a switch to anything other than HTTP much simpler should it ever be required for any reason, e.g. just a TCP stream of `BatchRequest`s.
This commit is contained in:
parent
ba87eca0f1
commit
ddff9857b8
3 changed files with 28 additions and 22 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -791,7 +791,7 @@ jobs:
|
|||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
||||
python -c "from tinygrad import Device; assert Device.default.properties['remotedev'] == 'METAL', Device.default.properties['remotedev']"
|
||||
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
|
||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
|
|
|
|||
|
|
@ -68,4 +68,4 @@ def not_support_multi_device():
|
|||
return CI and REAL_DEV in ("GPU", "CUDA")
|
||||
|
||||
# NOTE: This will open REMOTE if it's the default device
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from __future__ import annotations
|
|||
from typing import Callable, Optional, Any, cast
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, replace
|
||||
import multiprocessing, functools, asyncio, http, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
|
||||
import multiprocessing, functools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
from tinygrad.dtype import DTYPES_DICT, dtypes
|
||||
from tinygrad.ops import UOp, Ops, Variable, sint
|
||||
|
|
@ -23,6 +23,15 @@ from tinygrad.runtime.graph.cpu import CPUGraph
|
|||
@dataclass(frozen=True)
|
||||
class RemoteRequest: session: tuple[str, int]|None = field(default=None, kw_only=True)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoteProperties:
|
||||
real_device: str
|
||||
renderer: tuple[str, str, tuple[Any, ...]]
|
||||
graph_supported: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GetProperties(RemoteRequest): pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
|
||||
|
|
@ -74,8 +83,8 @@ class GraphExec(RemoteRequest):
|
|||
wait: bool
|
||||
|
||||
# for safe deserialization
|
||||
eval_globals = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, GraphComputeItem,
|
||||
GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
|
||||
eval_globals = {x.__name__:x for x in [RemoteProperties, GetProperties, BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree,
|
||||
ProgramExec, GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes]}
|
||||
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
|
||||
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
|
||||
ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
|
||||
|
|
@ -141,6 +150,11 @@ class RemoteHandler:
|
|||
if DEBUG >= 1: print(c)
|
||||
session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session)[1]}"]
|
||||
match c:
|
||||
case GetProperties():
|
||||
cls, args = dev.renderer.__reduce__()
|
||||
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
|
||||
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
|
||||
ret = repr(RemoteProperties(dev.device, (cls.__module__, cls.__name__, args), graph_cls is not None)).encode()
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
|
|
@ -170,11 +184,6 @@ class RemoteHandler:
|
|||
case GraphExec():
|
||||
r = session.graphs[c.graph_num]([session.buffers[buf] for buf in c.bufs], c.var_vals, wait=c.wait)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif path == "/properties" and method == "GET":
|
||||
cls, args = Device[self.base_device].renderer.__reduce__()
|
||||
# CPUGraph re-renders kernel from uops specified in CompiledRunner, this is not supported
|
||||
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
|
||||
ret = json.dumps({'remotedev': self.base_device, 'renderer': (cls.__module__, cls.__name__, args), 'graph': graph_cls is not None}).encode()
|
||||
else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
|
||||
return status, ret
|
||||
|
||||
|
|
@ -231,18 +240,19 @@ class RemoteDevice(Compiled):
|
|||
while 1:
|
||||
try:
|
||||
self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
|
||||
self.properties = json.loads(self.send("GET", "properties").decode())
|
||||
self.q(GetProperties())
|
||||
self.properties = safe_eval(ast.parse(self.batch_submit(), mode="eval").body)
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
time.sleep(0.1)
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties['remotedev']}")
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
|
||||
# TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
|
||||
renderer = self.properties['renderer']
|
||||
renderer = self.properties.renderer
|
||||
if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
|
||||
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
|
||||
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
|
||||
graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties['graph'] else None
|
||||
graph = fromimport('tinygrad.runtime.graph.remote', 'RemoteGraph') if self.properties.graph_supported else None
|
||||
super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self), graph)
|
||||
|
||||
def __del__(self):
|
||||
|
|
@ -261,15 +271,11 @@ class RemoteDevice(Compiled):
|
|||
def batch_submit(self):
|
||||
data = self.req.serialize()
|
||||
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
|
||||
ret = self.send("POST", "batch", data)
|
||||
self.conn.request("POST", "/batch", data)
|
||||
response = self.conn.getresponse()
|
||||
assert response.status == 200, f"POST /batch failed: {response}"
|
||||
ret = response.read()
|
||||
self.req = BatchRequest()
|
||||
return ret
|
||||
|
||||
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
|
||||
# TODO: retry logic
|
||||
self.conn.request(method, "/"+path, data)
|
||||
response = self.conn.getresponse()
|
||||
assert response.status == 200, f"failed on {method} {path}"
|
||||
return response.read()
|
||||
|
||||
if __name__ == "__main__": remote_server(getenv("PORT", 6667))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue