remote: support several hosts (#15585)

* remote: support several hosts

* f
This commit is contained in:
nimlgen 2026-04-03 11:22:15 +03:00 committed by GitHub
commit 237084b276
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 15 deletions

View file

@ -13,8 +13,9 @@ if __name__ == "__main__":
devs = RemotePCIDevice.remote_list(0x1002, ((0, (0,)),), 0) or RemotePCIDevice.remote_list(0x10de, ((0, (0,)),), 0x03)
if not devs: raise RuntimeError("no GPU found on remote")
pci = RemotePCIDevice("BN", devs[0])
print(f"connected to {os.environ['REMOTE']}, device: {devs[0]}\n")
sock, name = devs[0]
pci = RemotePCIDevice("BN", name, sock=sock)
print(f"connected to {os.environ['REMOTE']}, device: {name}\n")
# ping (minimal server round-trip, no device I/O)
from tinygrad.runtime.support.system import RemoteCmd

View file

@ -1,6 +1,6 @@
from __future__ import annotations
import os, mmap, array, functools, ctypes, select, contextlib, dataclasses, sys, itertools, struct, socket, subprocess, time, enum, atexit
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch, system, _ensure_downloads_dir, DEBUG
from tinygrad.helpers import round_up, getenv, OSX, temp, ceildiv, unwrap, fetch, system, _ensure_downloads_dir, DEBUG, flatten
from tinygrad.runtime.autogen import libc, pci, vfio, iokit, corefoundation
from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuffer, hcq_filter_visible_devices
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace, BumpAllocator
@ -74,8 +74,8 @@ class _System:
return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)])
@functools.cache
def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None) -> list[tuple[type, str]]:
if getenv("REMOTE", ""): return [(RemotePCIDevice, x) for x in RemotePCIDevice.remote_list(vendor, devices, base_class)]
def list_devices(self, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None):
if getenv("REMOTE", ""): return [(functools.partial(RemotePCIDevice,sock=s), x) for s,x in RemotePCIDevice.remote_list(vendor,devices,base_class)]
return [(APLRemotePCIDevice if OSX else PCIDevice, x) for x in System.pci_scan_bus(vendor, devices, base_class)]
def pci_probe_device(self, devpref:str, dev_id:int, vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None=None):
@ -314,15 +314,13 @@ class RemotePCIDevice(PCIDevice):
@staticmethod
@functools.cache
def remote_sock() -> socket.socket:
host_port = getenv("REMOTE", "127.0.0.1:6667").split(":")
host, port = host_port[0], int(host_port[1]) if len(host_port) > 1 else 6667
def remote_sock(host:str, port:int) -> socket.socket:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.settimeout(getenv("REMOTE_TIMEOUT", 3))
sock.connect((host, port))
sock.settimeout(None)
if DEBUG >= 1:
if DEBUG >= 1 and RemotePCIDevice._start_time == 0.0:
RemotePCIDevice._start_time = time.perf_counter()
def _print_stats():
dt = time.perf_counter() - RemotePCIDevice._start_time
@ -334,11 +332,13 @@ class RemotePCIDevice(PCIDevice):
@staticmethod
@functools.cache
def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[str]:
host, port = (sock:=RemotePCIDevice.remote_sock()).getpeername()
def remote_list(vendor:int, devices:tuple[tuple[int, tuple[int, ...]], ...], base_class:int|None) -> list[tuple[socket.socket, str]]:
payload = array.array('I', itertools.chain.from_iterable((m, d) for m, ds in devices for d in ds)).tobytes()
data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload)
return [f"remote:{host}:{port}:{d}" for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')] if data_len else []
def q(r:str) -> list[tuple[socket.socket, str]]:
  # Probe one "host[:port]" entry of REMOTE; returns a (socket, "remote:host:port:dev") pair per device it reports.
  # Parse the stripped entry once; the port falls back to 6667 when omitted.
  parts = r.strip().split(":")
  host, port = parts[0], int(parts[1]) if len(parts) > 1 else 6667
  sock = RemotePCIDevice.remote_sock(host, port)
  data_len, _, _, _ = RemotePCIDevice._rpc(sock, 0, RemoteCmd.PROBE, base_class or 0, len(payload), vendor, payload=payload)
  # An empty reply means this host has no matching devices. Without this guard ''.split('\n') yields [''],
  # which would produce one bogus device name with an empty id.
  if not data_len: return []
  return [(sock, f"remote:{host}:{port}:{d}") for d in RemotePCIDevice._recvall(sock, data_len).decode().split('\n')]
return flatten([q(r) for r in getenv("REMOTE", "").split(",") if r.strip()])
@staticmethod
def _recvall(sock:socket.socket, n:int) -> bytes:
@ -359,8 +359,8 @@ class RemotePCIDevice(PCIDevice):
RemotePCIDevice._rpc_count += 1
return (resp[1], resp[2]) + ((RemotePCIDevice._recvall(sock, readout_size) if readout_size > 0 else None),) + (fd,)
def __init__(self, devpref:str, pcibus:str, sock:socket.socket|None=None):
self.sock, self.pcibus, self.dev_id = sock or self.remote_sock(), pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
def __init__(self, devpref:str, pcibus:str, sock:socket.socket):
self.sock, self.pcibus, self.dev_id = sock, pcibus, int(pcibus.split(':')[-1]) if ':' in pcibus else 0
for buft in [socket.SO_SNDBUF, socket.SO_RCVBUF]: self.sock.setsockopt(socket.SOL_SOCKET, buft, 64 << 20)
self.lock_fd = System.flock_acquire(f"{devpref.lower()}_{pcibus.lower()}.lock")