mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
e0988dbae5
commit
e3986a6b74
6 changed files with 191 additions and 70 deletions
|
|
@ -1,4 +1,4 @@
|
|||
import unittest, ctypes, struct, os, random, numpy as np
|
||||
import unittest, ctypes, struct, os, random, numpy as np, time
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import getenv, mv_address, DEBUG
|
||||
from test.helpers import slow
|
||||
|
|
@ -618,5 +618,32 @@ class TestHCQ(unittest.TestCase):
|
|||
_check_copy("AMD", "NV")
|
||||
_check_copy("NV", "AMD")
|
||||
|
||||
def test_speed_cross_device_rdma_copy_bandwidth(self):
  """Time one 200 MB transfer between device 0 and device 7 when they sit in different peer groups.

  Skipped when device 7 cannot be opened or when both devices share a peer group. After one
  warmup transfer, a second transfer is timed and its bandwidth must fall within 1..100 GB/s.
  """
  try:
    d1 = Device[f"{Device.DEFAULT}:7"]
  except Exception:
    self.skipTest("no multidevice, test skipped")

  d0 = TestHCQ.d0
  if d0.peer_group == d1.peer_group:
    self.skipTest("devices in same peer group, no RDMA path")

  SZ = 200_000_000
  a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
  b = Buffer(f"{Device.DEFAULT}:7", SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()

  def transfer_and_wait():
    # a (on d0) is the destination, b (on d1) is the source; wait for the last submitted timeline value on both
    d0.allocator._transfer(a._buf, b._buf, SZ, src_dev=d1, dest_dev=d0)
    d0.timeline_signal.wait(d0.timeline_value - 1)
    d1.timeline_signal.wait(d1.timeline_value - 1)

  # warmup
  transfer_and_wait()

  st = time.perf_counter()
  transfer_and_wait()
  et_ms = (time.perf_counter() - st) * 1e3

  gb_s = ((SZ / 1e9) / et_ms) * 1e3
  print(f"cross device rdma copy: {et_ms:.2f} ms, {gb_s:.2f} GB/s")
  assert 1 <= gb_s <= 100

  # both buffers must hold the same bytes once the copy has completed
  np.testing.assert_equal(a.numpy(), b.numpy(), "failed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
110
tinygrad/runtime/ops_rdma.py
Normal file
110
tinygrad/runtime/ops_rdma.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
from __future__ import annotations
|
||||
import mmap, struct, functools
|
||||
from typing import cast
|
||||
from tinygrad.uop.ops import sint
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQAllocator, HWQueue, HCQBuffer, FileIOInterface
|
||||
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta
|
||||
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace
|
||||
from tinygrad.runtime.support.mlx.mlxdev import MLXDev, MLXQP, to_be
|
||||
from tinygrad.helpers import unwrap
|
||||
|
||||
class RDMACopyQueue(HWQueue):
  """Queue of NIC-to-NIC copies expressed as mlx5 work-queue entries (WQEs).

  `copy` records a send WQE and a receive data segment, `encode_ring` appends the doorbell writes
  and the completion wait to a device hardware queue, and `_submit` places the recorded WQEs into
  the queue-pair rings and advances the software ring indices.
  """
  def __init__(self, dev:RDMADevice):
    self.dev = dev
    super().__init__()

  def _wqe_data(self, buf:HCQBuffer, sz:int, nic:RDMADevice) -> bytes:
    """Map `buf` on `nic` and return a 16-byte data segment: big-endian (byte count, mapping meta, address)."""
    cast(HCQAllocatorBase, nic.allocator).map(buf)
    nic_map = buf.mappings[nic]
    # the address keeps buf's offset inside its base buffer, since the mapping covers the base
    return struct.pack('>IIQ', sz, nic_map.meta, nic_map.va_addr + (buf.va_addr - buf.base.va_addr))

  def encode_ring(self, hwq:HWQueue, dev:HCQCompiled, iface:MLXIface, qp:MLXQP, cq_buf:HCQBuffer,
                  dbr_off:int, dbr_val:sint, cq_dbr_val:sint, cq_ci:int, uar_db_val:sint|None=None):
    """Append to `hwq` the commands that ring a doorbell, wait for one completion, and acknowledge it.

    Args:
      hwq: device hardware queue that receives the commands.
      dev: device whose allocator maps the doorbell-record, CQ and (optionally) UAR buffers.
      iface: NIC interface owning `dbr_buf` / `uar_buf`.
      qp: queue pair supplying `cq_size` and `cq_dbr`.
      cq_buf: buffer holding the completion queue entries (64 bytes each).
      dbr_off: byte offset in the doorbell-record buffer that receives `dbr_val`.
      dbr_val: value written at `dbr_off`.
      cq_dbr_val: value written at `qp.cq_dbr` after the completion is observed.
      cq_ci: CQ consumer index of the completion to wait for.
      uar_db_val: when given, 64-bit value written at UAR offset 0x800.

    Returns:
      self, so calls can be chained.
    """
    for buf in [iface.dbr_buf, cq_buf] + ([iface.uar_buf] if uar_db_val is not None else []): cast(HCQAllocator, dev.allocator).map(buf)
    hwq.write(iface.dbr_buf.offset(dbr_off), dbr_val)
    if uar_db_val is not None: hwq.write(iface.uar_buf.offset(0x800), uar_db_val, b64=True)
    # the expected owner bit flips on every pass over the CQ ring: take the pass parity from the real
    # ring size instead of assuming a 128-entry CQ
    owner = (cq_ci // qp.cq_size) & 1
    # polls the last 4 bytes of the CQE; mask 0x01000000 selects the owner bit (low bit of byte 63, assuming a little-endian read)
    hwq.poll_bit(cq_buf.offset((cq_ci & (qp.cq_size - 1)) * 64 + 60, 4), owner << 24, mask=0x01000000)
    hwq.write(iface.dbr_buf.offset(qp.cq_dbr), cq_dbr_val)
    return self

  def copy(self, dest:HCQBuffer, src:HCQBuffer, sz:int):
    """Record a copy of `sz` bytes from `src` to `dest`; the rings are only touched in `_submit`."""
    src_qp, dest_qp, _, _ = self.dev.iface.connect(remote_nic:=unwrap(dest.owner).rdma_dev())

    # bytes 0..3 (index/opcode word) are filled in at submit time, when the SQ head is known
    sq_wqe = bytearray(64)
    sq_wqe[4:8] = struct.pack('>I', (src_qp.qp_info['qpn'] << 8) | 2)
    sq_wqe[11] = 0x08 # CE: signal completion
    sq_wqe[16:32] = self._wqe_data(src, sz, self.dev)

    self.q(remote_nic, bytes(sq_wqe), self._wqe_data(dest, sz, remote_nic))
    return self

  def _submit(self, dev:RDMADevice):
    """Write every recorded (remote_nic, sq_wqe, rq_wqe) triple into the QP rings and advance the indices."""
    for remote_nic, sq_wqe, rq_wqe in zip(self._q[0::3], self._q[1::3], self._q[2::3]):
      src_qp, dest_qp, _, _ = dev.iface.connect(remote_nic)
      assert src_qp.sq_head + 1 - src_qp.cq_ci <= (1 << src_qp.log_sq_size), "SQ ring full"
      assert dest_qp.rq_head + 1 - dest_qp.cq_ci <= (1 << dest_qp.log_rq_size), "RQ ring full"
      # receive entries are 16 bytes each and start at offset 0 of the QP buffer
      dest_qp.qp_buf.view((dest_qp.rq_head & ((1 << dest_qp.log_rq_size) - 1)) * 16, 16)[:] = rq_wqe
      dest_qp.rq_head += 1
      # send entries are 64 bytes each and start at sq_offset; the first word carries the SQ index and opcode 0x0a
      sq_view = src_qp.qp_buf.view(src_qp.sq_offset + (src_qp.sq_head & ((1 << src_qp.log_sq_size) - 1)) * 64, 64)
      sq_view[:] = struct.pack('>I', (src_qp.sq_head << 8) | 0x0a) + sq_wqe[4:]
      src_qp.sq_head += 1
      # one completion is expected on each side per copy
      src_qp.cq_ci += 1
      dest_qp.cq_ci += 1
|
||||
|
||||
class MLXIface(PCIIfaceBase):
  """PCI interface for one NIC: owns the MLXDev plus HCQBuffer wrappers for its UAR page and doorbell records."""
  def __init__(self, dev:RDMADevice, dev_id:int):
    # pick the dev_id-th PCI device matching vendor 0x15b3 / device 0x101b (IndexError if there is none)
    cl, pcibus = System.list_devices(vendor=0x15b3, devices=((0xffff, (0x101b,)),))[dev_id]
    self.dev = dev
    self.pci_dev = cl("mlx", pcibus)
    self.mlx_dev = MLXDev(self.pci_dev, ip=f"10.0.0.{dev_id}")
    # the UAR page lives in BAR 0, `uar` pages of 0x1000 bytes from its start
    uar_paddr = self.mlx_dev.pci_dev.bar_info(0)[0] + self.mlx_dev.uar * 0x1000
    self.uar_buf = self._buf([uar_paddr])
    self.dbr_buf = self._buf(self.mlx_dev.dbr_paddrs)

  def is_bar_small(self) -> bool:
    return False

  def _buf(self, paddrs:list[int]) -> HCQBuffer:
    """Wrap a list of 0x1000-byte physical pages in an HCQBuffer backed by a fresh anonymous VA range."""
    size = len(paddrs) * 0x1000
    va = FileIOInterface.anon_mmap(0, size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS, 0)
    chunks = [(paddr, 0x1000) for paddr in paddrs]
    mapping = VirtMapping(va, size, chunks, AddrSpace.SYS, uncached=True, snooped=True)
    meta = PCIAllocationMeta(mapping, has_cpu_mapping=False)
    return HCQBuffer(va, size, meta=meta, owner=self.dev)

  @functools.cache
  def connect(self, remote_nic:RDMADevice) -> tuple[MLXQP, MLXQP, HCQBuffer, HCQBuffer]:
    """Create and connect a queue pair on this NIC and on `remote_nic`.

    Cached per (self, remote_nic), so repeated calls return the same QPs and CQ buffers.
    Returns (local QP, remote QP, local CQ buffer, remote CQ buffer).
    """
    local_qp = MLXQP(self.mlx_dev, log_sq_size=7, log_rq_size=7)
    remote_qp = MLXQP(remote_nic.iface.mlx_dev, log_sq_size=7, log_rq_size=7)
    local_qp.connect(remote_qp)
    remote_qp.connect(local_qp)
    local_cq = self._buf(local_qp.cq_paddrs)
    remote_cq = remote_nic.iface._buf(remote_qp.cq_paddrs)
    return local_qp, remote_qp, local_cq, remote_cq
|
||||
|
||||
class RDMAAllocator(HCQAllocatorBase):
  """Allocator for an RDMADevice: registers other devices' buffers with the NIC and runs cross-peer-group transfers."""
  def __init__(self, dev:RDMADevice): super().__init__(dev, batch_cnt=0)

  def _map(self, buf:HCQBuffer) -> HCQBuffer:
    """Register the pages backing `buf.base` with this NIC and return the NIC-side view of the base buffer.

    Page addresses are the owner's VRAM BAR address plus each physical offset. The registration uses
    2 MiB pages when every physical chunk is at least 2 MiB, otherwise 4 KiB pages. The returned
    buffer's `meta` is the value returned by `register_mem`.
    """
    owner = unwrap(buf.base.owner)
    bar, paddrs = owner.iface.pci_dev.bar_info(owner.iface.vram_bar)[0], buf.base.meta.mapping.paddrs # type: ignore[attr-defined]
    page_sz = (2 << 20) if min(sz for _, sz in paddrs) >= (2 << 20) else (4 << 10)
    pages = [bar + p + off for p, sz in paddrs for off in range(0, sz, page_sz)]
    return HCQBuffer(bar + paddrs[0][0], buf.base.size, owner=self.dev,
                     meta=self.dev.iface.mlx_dev.register_mem(pages, len(pages) * page_sz, page_sz.bit_length() - 1))

  def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQCompiled, dest_dev:HCQCompiled):
    """Copy `sz` bytes from `src` (on `src_dev`) to `dest` (on `dest_dev`) through the NICs.

    One compute queue per device waits for that device's prior work, rings the NIC doorbells,
    waits for the completion entry and then signals the device timeline.
    """
    # sync device
    # each queue is submitted to its own device, so it is built from that device's queue type
    src_q = unwrap(src_dev.hw_compute_queue_t)().wait(src_dev.timeline_signal, src_dev.timeline_value - 1)
    dest_q = unwrap(dest_dev.hw_compute_queue_t)().wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1)

    # rdma body + encode doorbell rings
    # NOTE: every ring value below is computed from the QP indices before .submit() advances them
    src_qp, dest_qp, src_cq_buf, dest_cq_buf = self.dev.iface.connect(remote_nic:=dest_dev.rdma_dev())
    RDMACopyQueue(self.dev).copy(dest, src, sz) \
      .encode_ring(src_q, src_dev, self.dev.iface, src_qp, src_cq_buf, src_qp.qp_dbr + 4,
                   to_be('I', src_qp.sq_head + 1), to_be('I', (src_qp.cq_ci + 1) & 0xFFFFFF), src_qp.cq_ci,
                   uar_db_val=to_be('Q', ((src_qp.sq_head << 8) | 0x0a) << 32 | ((src_qp.qp_info['qpn'] << 8) | 2))) \
      .encode_ring(dest_q, dest_dev, remote_nic.iface, dest_qp, dest_cq_buf, dest_qp.qp_dbr,
                   to_be('I', dest_qp.rq_head + 1), to_be('I', (dest_qp.cq_ci + 1) & 0xFFFFFF), dest_qp.cq_ci) \
      .submit(self.dev)

    # signal completion
    src_q.signal(src_dev.timeline_signal, src_dev.next_timeline()).submit(src_dev)
    dest_q.signal(dest_dev.timeline_signal, dest_dev.next_timeline()).submit(dest_dev)
|
||||
|
||||
class RDMADevice(HCQCompiled):
  """HCQ device wrapping one NIC; the index after ':' in the device string selects the NIC (0 when absent)."""
  def __init__(self, device:str=""):
    parts = device.split(":")
    self.device_id = int(parts[1]) if len(parts) > 1 else 0
    self.iface = MLXIface(self, self.device_id)
    allocator = RDMAAllocator(self)
    super().__init__(device, allocator, [], None, signal_t=None)
|
||||
|
|
@ -5,7 +5,7 @@ try: import fcntl # windows misses that
|
|||
except ImportError: fcntl = None #type:ignore[assignment]
|
||||
from tinygrad.helpers import PROFILE, getenv, to_mv, from_mv, cpu_profile, ProfileRangeEvent, select_first_inited, unwrap, suppress_finalizing
|
||||
from tinygrad.helpers import TracingKey
|
||||
from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
|
||||
from tinygrad.device import Device, BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
|
||||
from tinygrad.uop.ops import sym_infer, sint, UOp
|
||||
from tinygrad.runtime.autogen import libc
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
|
|
@ -491,6 +491,12 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|||
|
||||
def _is_cpu(self) -> bool:
  """Return True when this object has a `device` string whose part before the first ':' is "CPU"."""
  if not hasattr(self, 'device'): return False
  return self.device.split(":")[0] == "CPU"
|
||||
|
||||
def rdma_dev(self):
  """Return the RDMADevice found in this device's peer group.

  While none is present, opens `RDMA:0`, `RDMA:1`, ... one at a time and re-checks the peer group
  (presumably opening a device registers it in its peer group -- confirm against HCQCompiled).
  Raises RuntimeError once opening the next RDMA device fails with IndexError.
  """
  idx = 0
  while True:
    # matched by class name, not isinstance
    members = HCQCompiled.peer_groups[self.peer_group]
    found = next((d for d in members if type(d).__name__ == 'RDMADevice'), None)
    if found: return found
    try:
      Device[f'RDMA:{idx}']
    except IndexError:
      raise RuntimeError(f"No RDMA found for peer group '{self.peer_group}'")
    idx += 1
|
||||
|
||||
def finalize(self):
|
||||
try: self.synchronize() # Try to finalize device in any case.
|
||||
except RuntimeError as e: print(f"{self.device} synchronization failed before finalizing: {e}")
|
||||
|
|
@ -513,6 +519,9 @@ class HCQBuffer:
|
|||
assert self.view is not None, "buffer has no cpu_view"
|
||||
return self.view
|
||||
|
||||
@property
def base(self) -> HCQBuffer:
  """The buffer this one is a view of, or the buffer itself when `_base` is unset."""
  if self._base: return self._base
  return self
|
||||
|
||||
@property
def mappings(self):
  """The `_mappings` of the base buffer: views share the mappings of the buffer they were derived from."""
  holder = self if self._base is None else self._base
  return holder._mappings
|
||||
|
||||
|
|
@ -602,6 +611,8 @@ class HCQAllocator(HCQAllocatorBase, Generic[HCQDeviceType]):
|
|||
dest.cast('B')[i:i+lsize] = self.b[0].cpu_view().view(size=lsize, fmt='B')[:]
|
||||
|
||||
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQDeviceType, dest_dev:HCQDeviceType):
|
||||
if src_dev.peer_group != dest_dev.peer_group: return src_dev.rdma_dev().allocator._transfer(dest, src, sz, src_dev, dest_dev)
|
||||
|
||||
cast(HCQAllocator, src_dev.allocator).map(dest)
|
||||
|
||||
assert src_dev.hw_copy_queue_t is not None
|
||||
|
|
|
|||
0
tinygrad/runtime/support/mlx/__init__.py
Normal file
0
tinygrad/runtime/support/mlx/__init__.py
Normal file
|
|
@ -1,13 +1,14 @@
|
|||
import struct, time, random, json, sys, socket, ctypes, os, functools, itertools
|
||||
from tinygrad.helpers import getenv, wait_cond, next_power2, ceildiv, DEBUG, hi32, lo32
|
||||
from __future__ import annotations
|
||||
import struct, random, socket, ctypes, functools, itertools
|
||||
from tinygrad.helpers import getenv, wait_cond, round_up, next_power2, ceildiv, DEBUG, hi32, lo32
|
||||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.runtime.support.system import PCIDevice
|
||||
from tinygrad.runtime.autogen import mlx5, pci
|
||||
|
||||
MLX_DEBUG = getenv("MLX_DEBUG", 0)
|
||||
|
||||
MLX5_CMD_STRUCTS = {v: (getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_in_bits", None), getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_out_bits", None))
|
||||
for n, v in mlx5.__dict__.items() if n.startswith("MLX5_CMD_OP_")}
|
||||
MLX5_CMD_STRUCTS = {v: (getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_in_bits", None),
|
||||
getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_out_bits", None)) for n, v in mlx5.__dict__.items() if n.startswith("MLX5_CMD_OP_")}
|
||||
MLX5_CMD_STRUCTS[mlx5.MLX5_CMD_OP_ACCESS_REG] = (mlx5.struct_mlx5_ifc_access_register_in_bits, mlx5.struct_mlx5_ifc_access_register_out_bits)
|
||||
|
||||
def to_be(fmt, val):
  """Byte-swap `val`: pack it big-endian with struct format `fmt`, then read those bytes back little-endian.

  Storing the result in little-endian byte order therefore lays out the big-endian encoding of `val`.
  """
  big_endian_bytes = struct.pack('>' + fmt, val)
  swapped = struct.unpack('<' + fmt, big_endian_bytes)
  return swapped[0]
|
||||
|
|
@ -52,8 +53,8 @@ class MLXCmdQueue:
|
|||
self.log_stride, self.max_reg_cmds = cmd_l & 0xF, (1 << ((cmd_l >> 4) & 0xF)) - 1
|
||||
|
||||
stride = next_power2(ctypes.sizeof(mlx5.struct_mlx5_cmd_prot_block))
|
||||
self.queue, self.queue_paddrs = dev.pci_dev.alloc_sysmem(0x1000 + 128 * stride)
|
||||
self.mboxes = [(off:=0x1000 + i * stride, self.queue_paddrs[1 + (i * stride) // 0x1000] + (off % 0x1000)) for i in range(128)]
|
||||
self.queue, self.queue_paddrs = dev.pci_dev.alloc_sysmem(0x1000 + 1024 * stride)
|
||||
self.mboxes = [(off:=0x1000 + i * stride, self.queue_paddrs[1 + (i * stride) // 0x1000] + (off % 0x1000)) for i in range(1024)]
|
||||
|
||||
dev.iseg_w('cmdq_addr_h', hi32(self.queue_paddrs[0]))
|
||||
dev.iseg_w('cmdq_addr_l_sz', lo32(self.queue_paddrs[0]) | cmd_l)
|
||||
|
|
@ -84,18 +85,19 @@ class MLXCmdQueue:
|
|||
_in=[int.from_bytes(inp[i:i+4], 'little') for i in range(0, 16, 4)],
|
||||
out_ptr=to_be('Q', out_ptr), outlen=to_be('I', 16 + out_sz), token=tok, status_own=mlx5.CMD_OWNER_HW)
|
||||
cmd_bytes = bytearray(bytes(cmd))
|
||||
cmd_bytes[mlx5.struct_mlx5_cmd_layout.sig.offset] = (~functools.reduce(lambda a, b: a ^ b, cmd_bytes)) & 0xFF
|
||||
cmd_bytes[mlx5.struct_mlx5_cmd_layout.sig.offset] = (~functools.reduce(lambda a, b: a ^ b, cmd_bytes)) & 0xFF # type: ignore[attr-defined]
|
||||
|
||||
# submit and wait for completion
|
||||
slot_view = self.queue.view(slot << self.log_stride, len(cmd_bytes))
|
||||
slot_view[:] = cmd_bytes
|
||||
self.dev.iseg_w('cmd_dbell', 1 << slot)
|
||||
wait_cond(lambda: slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] & mlx5.CMD_OWNER_HW, value=0, msg=f"cmd 0x{opcode:04x}")
|
||||
wait_cond(lambda: slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] & mlx5.CMD_OWNER_HW, value=0, # type: ignore[attr-defined]
|
||||
msg=f"cmd 0x{opcode:04x}")
|
||||
|
||||
# check status and read output
|
||||
assert slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] >> 1 == 0, f"cmd 0x{opcode:04x} delivery error"
|
||||
assert slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] >> 1 == 0, f"cmd 0x{opcode:04x} delivery error" # type: ignore[attr-defined]
|
||||
|
||||
out_view = slot_view.view(mlx5.struct_mlx5_cmd_layout.out.offset, 16 + out_sz)
|
||||
out_view = slot_view.view(mlx5.struct_mlx5_cmd_layout.out.offset, 16 + out_sz) # type: ignore[attr-defined]
|
||||
status, syndrome = struct.unpack('>I', out_view[0:4])[0] >> 24, struct.unpack('>I', out_view[4:8])[0]
|
||||
assert status == 0, f"cmd 0x{opcode:04x} failed status=0x{status:x} syn=0x{syndrome:08x}"
|
||||
|
||||
|
|
@ -145,7 +147,7 @@ class MLXDev:
|
|||
self.cmd.exec(mlx5.MLX5_CMD_OP_MODIFY_NIC_VPORT_CONTEXT, field_select=dict(roce_en=1), nic_vport_context=dict(roce_en=1))
|
||||
|
||||
dbr_mem, self.dbr_paddrs = self.pci_dev.alloc_sysmem(0x1000)
|
||||
self.dbr = dbr_mem.mv.cast('I')
|
||||
self.dbr = dbr_mem.view(fmt='I')
|
||||
self.dbr_alloc = BumpAllocator(0x1000, wrap=False)
|
||||
|
||||
self.pd = self.cmd.exec(mlx5.MLX5_CMD_OP_ALLOC_PD)['pd']
|
||||
|
|
@ -159,8 +161,19 @@ class MLXDev:
|
|||
|
||||
if DEBUG >= 2: print(f"mlx5 {self.devfmt}: booted mac={self.mac.to_bytes(6,'big').hex(':')} mkey=0x{self.mkey:x}")
|
||||
|
||||
def register_mem(self, paddrs:list[int], size:int, log_page_size:int=12) -> int:
  """Create a memory key covering the pages in `paddrs` and return it as (mkey_index << 8) | 0x33.

  Args:
    paddrs: page addresses, one per page of size 1 << log_page_size; must be non-empty (paddrs[0] is the start address).
    size: length in bytes of the registered region.
    log_page_size: log2 of the page size (default 12).
  """
  n = len(paddrs)
  # translation entries are counted in pairs (ceildiv(n, 2)), so the table is zero-padded to an even entry count
  padded_n = round_up(n, 2)
  mtt = struct.pack(f'>{padded_n}Q', *paddrs, *([0] * (padded_n - n)))
  if MLX_DEBUG >= 1: print(f"mlx5 {self.devfmt}: register_mem pages={n} page_sz={1 << log_page_size} mtt_bytes={len(mtt)}")
  self.provide_pages(mlx5.MLX5_INIT_PAGES)

  key_lo = 0x33
  octwords = ceildiv(n, 2)
  mkey_entry = dict(access_mode_1_0=1, lr=1, lw=1, rr=1, rw=1, pd=self.pd, qpn=0xFFFFFF, mkey_7_0=key_lo,
                    start_addr=paddrs[0], len=size, log_page_size=log_page_size, translations_octword_size=octwords)
  res = self.cmd.exec(mlx5.MLX5_CMD_OP_CREATE_MKEY, translations_octword_actual_size=octwords, payload=mtt,
                      memory_key_mkey_entry=mkey_entry)
  return (res['mkey_index'] << 8) | key_lo
|
||||
|
||||
def provide_pages(self, mode):
  """Query how many pages the device reports for `mode` and, if the count is positive, allocate and hand them over."""
  npages = self.cmd.exec(mlx5.MLX5_CMD_OP_QUERY_PAGES, op_mod=mode)['num_pages']
  if npages <= 0: return
  if MLX_DEBUG >= 1: print(f"mlx5 {self.devfmt}: provide_pages mode={mode}, {npages} pages")
  _mem, paddrs = self.pci_dev.alloc_sysmem(npages * 0x1000)
  # the page addresses are passed as big-endian 64-bit values
  page_list = struct.pack(f'>{npages}Q', *paddrs)
  self.cmd.exec(mlx5.MLX5_CMD_OP_MANAGE_PAGES, op_mod=mlx5.MLX5_PAGES_GIVE, input_num_entries=npages, payload=page_list)
|
||||
|
||||
|
|
@ -189,20 +202,22 @@ class MLXDev:
|
|||
|
||||
class MLXQP:
|
||||
def __init__(self, dev:MLXDev, log_sq_size=4, log_rq_size=4, log_eq_size=7, log_cq_size=7):
|
||||
self.dev, self.cq_size, self.log_sq_size = dev, 1 << log_cq_size, log_sq_size
|
||||
self.dev, self.cq_size, self.log_sq_size, self.log_rq_size = dev, 1 << log_cq_size, log_sq_size, log_rq_size
|
||||
|
||||
self.cq_dbr, self.qp_dbr = dev.dbr_alloc.alloc(8, alignment=8), dev.dbr_alloc.alloc(8, alignment=8)
|
||||
|
||||
# create EQ, CQ
|
||||
self.eq_mem, self.eq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_EQ, log_eq_size, entry_sz=64, owner_off=31,
|
||||
self.eq_mem, self.eq_paddrs, self.eq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_EQ, log_eq_size, entry_sz=64, owner_off=31,
|
||||
eq_context_entry=dict(log_eq_size=log_eq_size, uar_page=dev.uar, log_page_size=0))
|
||||
|
||||
self.cq_mem, self.cq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_CQ, log_cq_size, entry_sz=64, owner_off=63,
|
||||
cq_context=dict(log_cq_size=log_cq_size, uar_page=dev.uar, c_eqn_or_apu_element=self.eq_info['eq_number'], dbr_addr=dev.dbr_paddrs[0] + self.cq_dbr, log_page_size=0))
|
||||
self.cq_mem, self.cq_paddrs, self.cq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_CQ, log_cq_size, entry_sz=64, owner_off=63,
|
||||
cq_context=dict(log_cq_size=log_cq_size, uar_page=dev.uar, c_eqn_or_apu_element=self.eq_info['eq_number'],
|
||||
dbr_addr=dev.dbr_paddrs[0] + self.cq_dbr, log_page_size=0))
|
||||
|
||||
# create QP, buffer is RQ (16B stride) + SQ (64B stride)
|
||||
self.sq_offset = (1 << log_rq_size) << 4
|
||||
self.qp_buf, self.qp_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_QP, log_sq_size, entry_sz=64, owner_off=0, extra_sz=self.sq_offset,
|
||||
self.qp_buf, self.qp_paddrs, self.qp_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_QP, log_sq_size, entry_sz=64,
|
||||
owner_off=0, extra_sz=self.sq_offset,
|
||||
qpc=dict(st=0, pm_state=3, pd=dev.pd, cqn_snd=self.cq_info['cqn'], cqn_rcv=self.cq_info['cqn'], log_msg_max=30, log_rq_size=log_rq_size,
|
||||
log_rq_stride=0, log_sq_size=log_sq_size, rlky=1, uar_page=dev.uar, log_page_size=0, dbr_addr=dev.dbr_paddrs[0] + self.qp_dbr))
|
||||
|
||||
|
|
@ -210,67 +225,24 @@ class MLXQP:
|
|||
self.qp_op(mlx5.MLX5_CMD_OP_RST2INIT_QP, qpc_args=dict(log_ack_req_freq=8), addr_args=dict(pkey_index=0, vhca_port_num=1))
|
||||
|
||||
self.cq_ci = self.sq_head = self.rq_head = 0
|
||||
for i in range(self.cq_size): self.cq_mem[i * 64 + 63] = 0x01 # init owner bits so poll_cq waits for real CQEs
|
||||
if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} (EQ={self.eq_info['eq_number']} CQ=0x{self.cq_info['cqn']:x})")
|
||||
|
||||
def create_queue(self, opcode, log_size, entry_sz, owner_off, extra_sz=0, **ctx_kw):
|
||||
mem, paddrs = self.dev.pci_dev.alloc_sysmem((n := ceildiv((1 << log_size) * entry_sz + extra_sz, 0x1000)) * 0x1000)
|
||||
return mem, self.dev.cmd.exec(opcode, payload=struct.pack(f'>{n}Q', *paddrs), **ctx_kw)
|
||||
return mem, paddrs, self.dev.cmd.exec(opcode, payload=struct.pack(f'>{n}Q', *paddrs), **ctx_kw)
|
||||
|
||||
def qp_op(self, opcode, qpc_args=None, addr_args=None, **kwargs):
  """Execute QP command `opcode` for this queue pair.

  The QP context always carries st, pm_state, pd and both CQ numbers; `qpc_args` adds further fields
  (repeating one of the fixed keys raises TypeError) and `addr_args` becomes `primary_address_path`.
  Extra keyword arguments are passed through to the command.
  """
  extra_qpc = qpc_args or {}
  cqn = self.cq_info['cqn']
  qpc = dict(st=0, pm_state=3, pd=self.dev.pd, cqn_snd=cqn, cqn_rcv=cqn, **extra_qpc)
  qpc = qpc | {'primary_address_path': addr_args or {}}
  self.dev.cmd.exec(opcode, qpn=self.qp_info['qpn'], qpc=qpc, **kwargs)
|
||||
|
||||
def connect(self, remote_qpn:int, remote_mac:int, remote_gid:int):
|
||||
def connect(self, remote:MLXQP):
|
||||
self.qp_op(mlx5.MLX5_CMD_OP_INIT2RTR_QP, opt_param_mask=0x1A,
|
||||
qpc_args=dict(mtu=3, log_msg_max=self.dev.caps['log_max_msg'], remote_qpn=remote_qpn, log_ack_req_freq=8, log_rra_max=3, rre=1, rwe=1,
|
||||
min_rnr_nak=12, next_rcv_psn=0),
|
||||
addr_args=dict(pkey_index=0, src_addr_index=0, hop_limit=64, udp_sport=udp_sport(self.qp_info['qpn'], remote_qpn), vhca_port_num=1,
|
||||
rmac_47_32=hi32(remote_mac), rmac_31_0=lo32(remote_mac), rgid_rip=remote_gid))
|
||||
qpc_args=dict(mtu=5, log_msg_max=self.dev.caps['log_max_msg'], remote_qpn=remote.qp_info['qpn'], log_ack_req_freq=8,
|
||||
log_rra_max=3, rre=1, rwe=1, min_rnr_nak=1, next_rcv_psn=0),
|
||||
addr_args=dict(pkey_index=0, src_addr_index=0, hop_limit=64, udp_sport=udp_sport(self.qp_info['qpn'], remote.qp_info['qpn']), vhca_port_num=1,
|
||||
rmac_47_32=hi32(remote.dev.mac), rmac_31_0=lo32(remote.dev.mac), rgid_rip=int.from_bytes(remote.dev.local_gid, 'big')))
|
||||
self.qp_op(mlx5.MLX5_CMD_OP_RTR2RTS_QP, qpc_args=dict(log_ack_req_freq=8, next_send_psn=0, log_sra_max=3, retry_count=7, rnr_retry=7),
|
||||
addr_args=dict(ack_timeout=14, vhca_port_num=1))
|
||||
|
||||
if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} connected (remote=0x{remote_qpn:x})")
|
||||
|
||||
def _post_sq(self, wqe_op, ds_count, data):
|
||||
wqe = self.qp_buf.view(self.sq_offset + (self.sq_head & 0xF) * 64, 64)
|
||||
wqe[0:8] = struct.pack('>II', (self.sq_head << 8) | wqe_op, (self.qp_info['qpn'] << 8) | ds_count)
|
||||
wqe[11] = 0x08 # CE: signal completion
|
||||
wqe[16:16 + len(data)] = data
|
||||
|
||||
self.sq_head += 1
|
||||
self.dev.dbr[self.qp_dbr // 4 + 1] = to_be('I', self.sq_head)
|
||||
self.dev.uar_view[0x800 // 8] = to_be('Q', int.from_bytes(wqe[0:8], 'big'))
|
||||
self.poll_cq()
|
||||
|
||||
def poll_cq(self, timeout=5.0):
|
||||
t = time.monotonic()
|
||||
while True:
|
||||
cqe = self.cq_mem.view((self.cq_ci & (self.cq_size - 1)) * 64, 64)
|
||||
opcode, owner = cqe[63] >> 4, cqe[63] & 1
|
||||
if opcode != 0x0F and owner == (1 if (self.cq_ci & self.cq_size) else 0):
|
||||
self.cq_ci += 1
|
||||
self.dev.dbr[self.cq_dbr // 4] = to_be('I', self.cq_ci & 0xFFFFFF)
|
||||
assert opcode not in (13, 14), f"CQE error: opcode={opcode} syndrome=0x{cqe[55]:02x} vendor=0x{cqe[54]:02x}"
|
||||
return opcode
|
||||
if time.monotonic() - t > timeout: raise TimeoutError("CQ poll timeout")
|
||||
time.sleep(0.0001)
|
||||
|
||||
def rdma_write(self, remote_addr, rkey, local_addr, lkey, length):
|
||||
self._post_sq(0x08, 3, struct.pack('>QI4xIIQ', remote_addr, rkey, length, lkey, local_addr))
|
||||
def send(self, addr, lkey, length): self._post_sq(0x0a, 2, struct.pack('>IIQ', length, lkey, addr))
|
||||
|
||||
if __name__ == "__main__":
|
||||
dev = MLXDev(PCIDevice("mlx5", getenv("MLX_PCI", "0000:41:00.0")))
|
||||
qp = MLXQP(dev)
|
||||
|
||||
if "--server" in sys.argv:
|
||||
print(json.dumps({"qpn": qp.qpn, "mac": dev.mac.to_bytes(6,'big').hex(), "gid": dev.local_gid.hex()}), flush=True)
|
||||
remote = json.loads(input())
|
||||
qp.connect(remote["qpn"], int(remote["mac"], 16), int(remote["gid"], 16))
|
||||
print("connected", flush=True)
|
||||
target_mem, target_paddrs = dev.pci_dev.alloc_sysmem(0x1000)
|
||||
print(json.dumps({"target_addr": target_paddrs[0], "rkey": dev.mkey}), flush=True)
|
||||
input()
|
||||
data = bytes(target_mem[i] for i in range(64))
|
||||
as_text = data.rstrip(b'\x00').decode('ascii', errors='replace')
|
||||
print(f"RECEIVED: {data.hex(' ')}\nAS TEXT: {as_text}", flush=True)
|
||||
if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} connected (remote=0x{remote.qp_info['qpn']:x})")
|
||||
|
|
@ -263,7 +263,8 @@ class PCIIfaceBase:
|
|||
return HCQBuffer(mapping.va_addr, size, view=barview, meta=PCIAllocationMeta(mapping, cpu_access, hMemory=mapping.paddrs[0][0]), owner=self.dev)
|
||||
|
||||
def free(self, b:HCQBuffer):
|
||||
for dev in b.mapped_devs[1:]: dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
|
||||
for dev in b.mapped_devs[1:]:
|
||||
if hasattr(dev.iface, 'dev_impl'): dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
|
||||
if b.meta.mapping.aspace is AddrSpace.PHYS: self.dev_impl.mm.vfree(b.meta.mapping)
|
||||
if self.is_local() and b.owner == self.dev and b.meta.has_cpu_mapping: FileIOInterface.munmap(b.va_addr, b.size)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue