mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remote: support several hosts (#15585)
* remote: support several hossts * f
This commit is contained in:
parent
0ed8d9271d
commit
237084b276
2 changed files with 16 additions and 15 deletions
|
|
@ -13,8 +13,9 @@ if __name__ == "__main__":
|
|||
devs = RemotePCIDevice.remote_list(0x1002, ((0, (0,)),), 0) or RemotePCIDevice.remote_list(0x10de, ((0, (0,)),), 0x03)
|
||||
if not devs: raise RuntimeError("no GPU found on remote")
|
||||
|
||||
pci = RemotePCIDevice("BN", devs[0])
|
||||
print(f"connected to {os.environ['REMOTE']}, device: {devs[0]}\n")
|
||||
sock, name = devs[0]
|
||||
pci = RemotePCIDevice("BN", name, sock=sock)
|
||||
print(f"connected to {os.environ['REMOTE']}, device: {name}\n")
|
||||
|
||||
# ping (minimal server round-trip, no device I/O)
|
||||
from tinygrad.runtime.support.system import RemoteCmd
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import os, mmap, array, functools, ctypes, select, contextlib, dataclasses, sys, itertools, struct, socket, subprocess, time, enum, atexit
|
||||
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch, system, _ensure_downloads_dir, DEBUG
|
||||
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch, system, _ensure_downloads_dir, DEBUG, flatten
|
||||
from tinygrad.runtime.autogen import libc, pci, vfio, iokit, corefoundation
|
||||
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer, hcq_filter_visible_devices
|
||||
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace, BumpAllocator
|
||||
|
|
@ -74,8 +74,8 @@ class _System:
|
|||
return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)])
|
||||
|
||||
@functools.cache
|
||||
def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None) -> list[tuple[type, str]]:
|
||||
if getenv("REMOTE", ""): return [(RemotePCIDevice, x) for x in RemotePCIDevice.remote_list(vendor, devices, base_class)]
|
||||
def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None):
|
||||
if getenv("REMOTE", ""): return [(functools.partial(RemotePCIDevice,sock=s), x) for s,x in RemotePCIDevice.remote_list(vendor,devices,base_class)]
|
||||
return [(APLRemotePCIDevice if OSX else PCIDevice, x) for x in System.pci_scan_bus(vendor, devices, base_class)]
|
||||
|
||||
def pci_probe_device(self, devpref:str, dev_id:int, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None):
|
||||
|
|
@ -314,15 +314,13 @@ class RemotePCIDevice(PCIDevice):
|
|||
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def remote_sock() -> socket.socket:
|
||||
host_port = getenv("REMOTE", "127.0.0.1:6667").split(":")
|
||||
host, port = host_port[0], int(host_port[1]) if len(host_port) > 1 else 6667
|
||||
def remote_sock(host:str, port:int) -> socket.socket:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
sock.settimeout(getenv("REMOTE_TIMEOUT", 3))
|
||||
sock.connect((host, port))
|
||||
sock.settimeout(None)
|
||||
if DEBUG >= 1:
|
||||
if DEBUG >= 1 and RemotePCIDevice._start_time == 0.0:
|
||||
RemotePCIDevice._start_time = time.perf_counter()
|
||||
def _print_stats():
|
||||
dt = time.perf_counter() - RemotePCIDevice._start_time
|
||||
|
|
@ -334,11 +332,13 @@ class RemotePCIDevice(PCIDevice):
|
|||
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[str]:
|
||||
host, port = (sock:=RemotePCIDevice.remote_sock()).getpeername()
|
||||
def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[tuple[socket.socket, str]]:
|
||||
payload = array.array('I', itertools.chain.from_iterable((m, d) for m, ds in devices for d in ds)).tobytes()
|
||||
data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload)
|
||||
return [f"remote:{host}:{port}:{d}" for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')] if data_len else []
|
||||
def q(r:str) -> list[tuple[socket.socket, str]]:
|
||||
sock = RemotePCIDevice.remote_sock((host:=r.strip().split(":")[0]), (port:=int(r.strip().split(":")[1]) if ":" in r else 6667))
|
||||
data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload)
|
||||
return [(sock, f"remote:{host}:{port}:{d}") for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')]
|
||||
return flatten([q(r) for r in getenv("REMOTE", "").split(",") if r.strip()])
|
||||
|
||||
@staticmethod
|
||||
def _recvall(sock:socket.socket, n:int) -> bytes:
|
||||
|
|
@ -359,8 +359,8 @@ class RemotePCIDevice(PCIDevice):
|
|||
RemotePCIDevice._rpc_count += 1
|
||||
return (resp[1], resp[2]) + ((RemotePCIDevice._recvall(sock, readout_size) if readout_size > 0 else None),) + (fd,)
|
||||
|
||||
def __init__(self, devpref:str, pcibus:str, sock:socket.socket|None=None):
|
||||
self.sock, self.pcibus, self.dev_id = sock or self.remote_sock(), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
|
||||
def __init__(self, devpref:str, pcibus:str, sock:socket.socket):
|
||||
self.sock, self.pcibus, self.dev_id = sock, pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
|
||||
for buft in [socket.SO_SNDBUF, socket.SO_RCVBUF]: self.sock.setsockopt(socket.SOL_SOCKET, buft, 64 << 20)
|
||||
|
||||
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue