mlx: init runtime (#15612)

* mlx: init

* x

* swap
This commit is contained in:
nimlgen 2026-04-05 22:52:29 +03:00 committed by GitHub
commit e3986a6b74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 191 additions and 70 deletions

View file

@ -1,4 +1,4 @@
import unittest, ctypes, struct, os, random, numpy as np
import unittest, ctypes, struct, os, random, numpy as np, time
from tinygrad import Device, Tensor, dtypes
from tinygrad.helpers import getenv, mv_address, DEBUG
from test.helpers import slow
@ -618,5 +618,32 @@ class TestHCQ(unittest.TestCase):
_check_copy("AMD", "NV")
_check_copy("NV", "AMD")
def test_speed_cross_device_rdma_copy_bandwidth(self):
  # Times one 200 MB copy between two devices that live in different peer groups, then sanity-checks the
  # measured bandwidth and that both buffers hold the same bytes afterwards.
  remote_name = f"{Device.DEFAULT}:7"
  try: d1 = Device[remote_name]
  except Exception: self.skipTest("no multidevice, test skipped")
  d0 = TestHCQ.d0
  if d0.peer_group == d1.peer_group: self.skipTest("devices in same peer group, no RDMA path")

  SZ = 200_000_000
  a = Buffer(Device.DEFAULT, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()
  b = Buffer(remote_name, SZ, dtypes.uint8, options=BufferSpec(nolru=True)).allocate()

  def copy_and_sync():
    # a (on d0) is the destination, b (on d1) the source; afterwards block on `timeline_value - 1` of both devices
    d0.allocator._transfer(a._buf, b._buf, SZ, src_dev=d1, dest_dev=d0)
    for dev in (d0, d1): dev.timeline_signal.wait(dev.timeline_value - 1)

  # warmup
  copy_and_sync()

  # timed run: only the transfer and the two waits are inside the measured window
  st = time.perf_counter()
  copy_and_sync()
  et_ms = (time.perf_counter() - st) * 1e3

  gb_s = ((SZ / 1e9) / et_ms) * 1e3
  print(f"cross device rdma copy: {et_ms:.2f} ms, {gb_s:.2f} GB/s")
  assert 1 <= gb_s <= 100
  np.testing.assert_equal(a.numpy(), b.numpy(), "failed")
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,110 @@
from __future__ import annotations
import mmap, struct, functools
from typing import cast
from tinygrad.uop.ops import sint
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQAllocator, HWQueue, HCQBuffer, FileIOInterface
from tinygrad.runtime.support.system import System, PCIIfaceBase, PCIAllocationMeta
from tinygrad.runtime.support.memory import VirtMapping, AddrSpace
from tinygrad.runtime.support.mlx.mlxdev import MLXDev, MLXQP, to_be
from tinygrad.helpers import unwrap
class RDMACopyQueue(HWQueue):
  """Queue of NIC-to-NIC copies, each posted as one send WQE on the source QP and one receive WQE on the remote QP.

  `copy` only records the work (three queue entries per copy: remote NIC, send WQE, receive WQE);
  `_submit` writes the WQEs into the QP rings and advances the software ring indices. Ringing the
  doorbells is done by device queues prepared with `encode_ring`.
  """
  def __init__(self, dev:RDMADevice):
    self.dev = dev
    super().__init__()

  def _wqe_data(self, buf:HCQBuffer, sz:int, nic:RDMADevice) -> bytes:
    """Map `buf` on `nic` and return the 16-byte big-endian data segment (byte count, key, address).

    The key is the `meta` of the buffer's mapping for `nic`; the address is that mapping's `va_addr`
    plus the offset of `buf` inside its base buffer.
    """
    cast(HCQAllocatorBase, nic.allocator).map(buf)
    nic_mapping = buf.mappings[nic]
    off_in_base = buf.va_addr - buf.base.va_addr
    return struct.pack('>IIQ', sz, nic_mapping.meta, nic_mapping.va_addr + off_in_base)

  def encode_ring(self, hwq:HWQueue, dev:HCQCompiled, iface:MLXIface, qp:MLXQP, cq_buf:HCQBuffer,
                  dbr_off:int, dbr_val:sint, cq_dbr_val:sint, cq_ci:int, uar_db_val:sint|None=None):
    """Append to `hwq` the commands that ring a QP doorbell, wait for its completion entry and update the CQ doorbell.

    Args:
      hwq: device queue the commands are appended to.
      dev: device that will run `hwq`; the doorbell/CQ buffers are mapped on its allocator.
      iface: NIC interface owning the doorbell record buffer (and UAR buffer).
      qp: queue pair whose CQ is polled.
      cq_buf: buffer covering the CQ entries of `qp`.
      dbr_off: byte offset of the doorbell record to write; dbr_val: value written there.
      cq_dbr_val: value written to the CQ doorbell record (`qp.cq_dbr`) after the completion is seen.
      cq_ci: CQ consumer index of the completion entry to wait for.
      uar_db_val: when given, also written as a 64-bit value at offset 0x800 of the UAR buffer.

    Returns:
      self, so calls can be chained.
    """
    ring_uar = uar_db_val is not None
    to_map = [iface.dbr_buf, cq_buf]
    if ring_uar: to_map.append(iface.uar_buf)
    for buf in to_map: cast(HCQAllocator, dev.allocator).map(buf)

    hwq.write(iface.dbr_buf.offset(dbr_off), dbr_val)
    if ring_uar: hwq.write(iface.uar_buf.offset(0x800), uar_db_val, b64=True)

    # Wait on the last 4 bytes of the 64-byte CQE at the consumer slot. The expected owner bit flips every
    # time the consumer index wraps the ring, so it is the `cq_size` bit of `cq_ci` (derived from the CQ
    # size rather than assuming a 128-entry ring).
    cqe_slot = cq_ci & (qp.cq_size - 1)
    owner = 1 if cq_ci & qp.cq_size else 0
    hwq.poll_bit(cq_buf.offset(cqe_slot * 64 + 60, 4), owner << 24, mask=0x01000000)
    hwq.write(iface.dbr_buf.offset(qp.cq_dbr), cq_dbr_val)
    return self

  def copy(self, dest:HCQBuffer, src:HCQBuffer, sz:int):
    """Record a copy of `sz` bytes from `src` to `dest`; the remote NIC is the RDMA device of `dest`'s owner.

    Returns:
      self, so calls can be chained.
    """
    src_qp = self.dev.iface.connect(remote_nic:=unwrap(dest.owner).rdma_dev())[0]

    # Send WQE template. Bytes 0..3 (opcode and ring index) are filled in by `_submit`.
    sq_wqe = bytearray(64)
    sq_wqe[4:8] = struct.pack('>I', (src_qp.qp_info['qpn'] << 8) | 2)  # QP number and a count of 2
    sq_wqe[11] = 0x08 # CE: signal completion
    sq_wqe[16:32] = self._wqe_data(src, sz, self.dev)

    self.q(remote_nic, bytes(sq_wqe), self._wqe_data(dest, sz, remote_nic))
    return self

  def _submit(self, dev:RDMADevice):
    """Write every recorded copy into the QP rings and advance the ring and completion indices.

    Raises:
      AssertionError: if the send or receive ring has no free slot.
    """
    for remote_nic, sq_wqe, rq_wqe in zip(self._q[0::3], self._q[1::3], self._q[2::3]):
      src_qp, dest_qp, _, _ = dev.iface.connect(remote_nic)
      sq_entries, rq_entries = 1 << src_qp.log_sq_size, 1 << dest_qp.log_rq_size
      assert src_qp.sq_head + 1 - src_qp.cq_ci <= sq_entries, "SQ ring full"
      assert dest_qp.rq_head + 1 - dest_qp.cq_ci <= rq_entries, "RQ ring full"

      # receive side: 16-byte entries at the start of the remote QP buffer
      rq_slot = dest_qp.rq_head & (rq_entries - 1)
      dest_qp.qp_buf.view(rq_slot * 16, 16)[:] = rq_wqe
      dest_qp.rq_head += 1

      # send side: 64-byte entries starting at `sq_offset`; the first 4 bytes carry the ring index and opcode 0x0a
      sq_slot = src_qp.sq_head & (sq_entries - 1)
      sq_view = src_qp.qp_buf.view(src_qp.sq_offset + sq_slot * 64, 64)
      sq_view[:] = struct.pack('>I', (src_qp.sq_head << 8) | 0x0a) + sq_wqe[4:]
      src_qp.sq_head += 1

      # one completion is accounted for on each side per copy
      src_qp.cq_ci += 1
      dest_qp.cq_ci += 1
class MLXIface(PCIIfaceBase):
  """PCI interface of one NIC: opens the device, boots `MLXDev` and exposes its UAR page and doorbell records as buffers."""
  def __init__(self, dev:RDMADevice, dev_id:int):
    # pick the `dev_id`-th PCI device with vendor id 0x15b3 matching the device filter below
    found = System.list_devices(vendor=0x15b3, devices=((0xffff, (0x101b,)),))
    cl, pcibus = found[dev_id]
    self.dev = dev
    self.pci_dev = cl("mlx", pcibus)
    self.mlx_dev = MLXDev(self.pci_dev, ip=f"10.0.0.{dev_id}")

    # the UAR page sits `uar` 4 KiB pages past the first element of BAR 0's info
    uar_paddr = self.mlx_dev.pci_dev.bar_info(0)[0] + self.mlx_dev.uar * 0x1000
    self.uar_buf = self._buf([uar_paddr])
    self.dbr_buf = self._buf(self.mlx_dev.dbr_paddrs)

  def is_bar_small(self) -> bool: return False

  def _buf(self, paddrs:list[int]) -> HCQBuffer:
    """Wrap a list of 4 KiB physical pages in an `HCQBuffer` owned by this device.

    A fresh anonymous private mapping supplies the virtual address range; the buffer is flagged as
    having no CPU mapping.
    """
    size = len(paddrs) * 0x1000
    va = FileIOInterface.anon_mmap(0, size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS, 0)
    page_ranges = [(p, 0x1000) for p in paddrs]
    mapping = VirtMapping(va, size, page_ranges, AddrSpace.SYS, uncached=True, snooped=True)
    meta = PCIAllocationMeta(mapping, has_cpu_mapping=False)
    return HCQBuffer(va, size, meta=meta, owner=self.dev)

  @functools.cache
  def connect(self, remote_nic:RDMADevice) -> tuple[MLXQP, MLXQP, HCQBuffer, HCQBuffer]:
    """Create and connect one QP on this NIC and one on `remote_nic`; the result is cached per (self, remote_nic).

    Returns:
      (local QP, remote QP, buffer over the local QP's CQ pages, buffer over the remote QP's CQ pages).
    """
    remote_iface = remote_nic.iface
    src_qp = MLXQP(self.mlx_dev, log_sq_size=7, log_rq_size=7)
    dest_qp = MLXQP(remote_iface.mlx_dev, log_sq_size=7, log_rq_size=7)
    src_qp.connect(dest_qp)
    dest_qp.connect(src_qp)
    src_cq_buf = self._buf(src_qp.cq_paddrs)
    dest_cq_buf = remote_iface._buf(dest_qp.cq_paddrs)
    return src_qp, dest_qp, src_cq_buf, dest_cq_buf
class RDMAAllocator(HCQAllocatorBase):
  """Allocator of an RDMA NIC: registers other devices' memory with the NIC and performs cross-peer-group transfers."""
  def __init__(self, dev:RDMADevice): super().__init__(dev, batch_cnt=0)

  def _map(self, buf:HCQBuffer) -> HCQBuffer:
    """Register the memory behind `buf`'s base buffer with this NIC.

    The owner's physical ranges are expanded into a page list offset by the first element of the owner's
    VRAM BAR info (presumably the BAR base address). 2 MiB pages are used when every range is at least
    2 MiB, 4 KiB pages otherwise.

    Returns:
      An `HCQBuffer` owned by this device whose `meta` is the key returned by `register_mem`.
    """
    owner = unwrap(buf.base.owner)
    bar, paddrs = owner.iface.pci_dev.bar_info(owner.iface.vram_bar)[0], buf.base.meta.mapping.paddrs # type: ignore[attr-defined]
    page_sz = (2 << 20) if min(sz for _, sz in paddrs) >= (2 << 20) else (4 << 10)
    pages = [bar + p + off for p, sz in paddrs for off in range(0, sz, page_sz)]
    mkey = self.dev.iface.mlx_dev.register_mem(pages, len(pages) * page_sz, page_sz.bit_length() - 1)
    return HCQBuffer(bar + paddrs[0][0], buf.base.size, owner=self.dev, meta=mkey)

  def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQCompiled, dest_dev:HCQCompiled):
    """Copy `sz` bytes from `src` (on `src_dev`) to `dest` (on `dest_dev`) through the NICs.

    Each device gets its own compute queue that waits for the device's previous work, rings the NIC
    doorbells, waits for the NIC completion and finally signals the device timeline.
    """
    # sync device
    # Each queue is created from the queue type of the device it is submitted to: `src_q` is submitted to
    # `src_dev` below, so it must be a `src_dev` queue even when the two devices use different queue types.
    src_q = unwrap(src_dev.hw_compute_queue_t)().wait(src_dev.timeline_signal, src_dev.timeline_value - 1)
    dest_q = unwrap(dest_dev.hw_compute_queue_t)().wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1)

    # rdma body + encode doorbell rings
    # The doorbell values are computed from the ring indices before `submit` advances them, hence the `+ 1`.
    src_qp, dest_qp, src_cq_buf, dest_cq_buf = self.dev.iface.connect(remote_nic:=dest_dev.rdma_dev())
    RDMACopyQueue(self.dev).copy(dest, src, sz) \
      .encode_ring(src_q, src_dev, self.dev.iface, src_qp, src_cq_buf, src_qp.qp_dbr + 4,
                   to_be('I', src_qp.sq_head + 1), to_be('I', (src_qp.cq_ci + 1) & 0xFFFFFF), src_qp.cq_ci,
                   uar_db_val=to_be('Q', ((src_qp.sq_head << 8) | 0x0a) << 32 | ((src_qp.qp_info['qpn'] << 8) | 2))) \
      .encode_ring(dest_q, dest_dev, remote_nic.iface, dest_qp, dest_cq_buf, dest_qp.qp_dbr,
                   to_be('I', dest_qp.rq_head + 1), to_be('I', (dest_qp.cq_ci + 1) & 0xFFFFFF), dest_qp.cq_ci) \
      .submit(self.dev)

    # signal completion
    src_q.signal(src_dev.timeline_signal, src_dev.next_timeline()).submit(src_dev)
    dest_q.signal(dest_dev.timeline_signal, dest_dev.next_timeline()).submit(dest_dev)
class RDMADevice(HCQCompiled):
  """HCQ device backed by one NIC; it has an allocator but is constructed with no renderers, no compiler and no signal type."""
  def __init__(self, device:str=""):
    # "NAME:<idx>" selects NIC <idx>; a name without ":" selects NIC 0
    name_parts = device.split(":")
    self.device_id = int(name_parts[1]) if len(name_parts) > 1 else 0
    self.iface = MLXIface(self, self.device_id)
    super().__init__(device, RDMAAllocator(self), [], None, signal_t=None)

View file

@ -5,7 +5,7 @@ try: import fcntl # windows misses that
except ImportError: fcntl = None #type:ignore[assignment]
from tinygrad.helpers import PROFILE, getenv, to_mv, from_mv, cpu_profile, ProfileRangeEvent, select_first_inited, unwrap, suppress_finalizing
from tinygrad.helpers import TracingKey
from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.device import Device, BufferSpec, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.support.memory import BumpAllocator
@ -491,6 +491,12 @@ class HCQCompiled(Compiled, Generic[SignalType]):
# True when the device name (the part of `self.device` before the first ":") is "CPU"; False while `device` is not set yet.
def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] == "CPU"
def rdma_dev(self):
  """Return the RDMA device registered in this device's peer group, opening RDMA devices one by one until one shows up.

  Raises:
    RuntimeError: when opening the next `RDMA:<i>` device fails with IndexError before a match is found.
  """
  idx = 0
  while True:
    # matched by class name; presumably to avoid importing the RDMA runtime here -- TODO confirm
    group = HCQCompiled.peer_groups[self.peer_group]
    found = next((d for d in group if type(d).__name__ == 'RDMADevice'), None)
    if found: return found
    # not there yet: open the next RDMA device, then re-check the peer group on the next iteration
    try: Device[f'RDMA:{idx}']
    except IndexError: raise RuntimeError(f"No RDMA found for peer group '{self.peer_group}'")
    idx += 1
def finalize(self):
try: self.synchronize() # Try to finalize device in any case.
except RuntimeError as e: print(f"{self.device} synchronization failed before finalizing: {e}")
@ -513,6 +519,9 @@ class HCQBuffer:
assert self.view is not None, "buffer has no cpu_view"
return self.view
# The buffer this one is derived from (`_base`), or the buffer itself when `_base` is unset.
@property
def base(self) -> HCQBuffer: return self._base or self
# Mappings live on the base buffer: a buffer with `_base` set returns the base's `_mappings` instead of its own.
@property
def mappings(self): return self._mappings if self._base is None else self._base._mappings
@ -602,6 +611,8 @@ class HCQAllocator(HCQAllocatorBase, Generic[HCQDeviceType]):
dest.cast('B')[i:i+lsize] = self.b[0].cpu_view().view(size=lsize, fmt='B')[:]
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQDeviceType, dest_dev:HCQDeviceType):
if src_dev.peer_group != dest_dev.peer_group: return src_dev.rdma_dev().allocator._transfer(dest, src, sz, src_dev, dest_dev)
cast(HCQAllocator, src_dev.allocator).map(dest)
assert src_dev.hw_copy_queue_t is not None

View file

View file

@ -1,13 +1,14 @@
import struct, time, random, json, sys, socket, ctypes, os, functools, itertools
from tinygrad.helpers import getenv, wait_cond, next_power2, ceildiv, DEBUG, hi32, lo32
from __future__ import annotations
import struct, random, socket, ctypes, functools, itertools
from tinygrad.helpers import getenv, wait_cond, round_up, next_power2, ceildiv, DEBUG, hi32, lo32
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.runtime.support.system import PCIDevice
from tinygrad.runtime.autogen import mlx5, pci
MLX_DEBUG = getenv("MLX_DEBUG", 0)
MLX5_CMD_STRUCTS = {v: (getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_in_bits", None), getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_out_bits", None))
for n, v in mlx5.__dict__.items() if n.startswith("MLX5_CMD_OP_")}
MLX5_CMD_STRUCTS = {v: (getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_in_bits", None),
getattr(mlx5, f"struct_mlx5_ifc_{n[12:].lower()}_out_bits", None)) for n, v in mlx5.__dict__.items() if n.startswith("MLX5_CMD_OP_")}
MLX5_CMD_STRUCTS[mlx5.MLX5_CMD_OP_ACCESS_REG] = (mlx5.struct_mlx5_ifc_access_register_in_bits, mlx5.struct_mlx5_ifc_access_register_out_bits)
def to_be(fmt, val):
  """Return the integer whose little-endian encoding equals the big-endian encoding of `val`.

  `fmt` is a single `struct` format character (e.g. 'I' or 'Q') giving the width. Storing the result
  as a little-endian value of that width therefore lays `val` out in big-endian byte order.
  """
  big_endian_bytes = struct.pack('>' + fmt, val)
  return struct.unpack('<' + fmt, big_endian_bytes)[0]
@ -52,8 +53,8 @@ class MLXCmdQueue:
self.log_stride, self.max_reg_cmds = cmd_l & 0xF, (1 << ((cmd_l >> 4) & 0xF)) - 1
stride = next_power2(ctypes.sizeof(mlx5.struct_mlx5_cmd_prot_block))
self.queue, self.queue_paddrs = dev.pci_dev.alloc_sysmem(0x1000 + 128 * stride)
self.mboxes = [(off:=0x1000 + i * stride, self.queue_paddrs[1 + (i * stride) // 0x1000] + (off % 0x1000)) for i in range(128)]
self.queue, self.queue_paddrs = dev.pci_dev.alloc_sysmem(0x1000 + 1024 * stride)
self.mboxes = [(off:=0x1000 + i * stride, self.queue_paddrs[1 + (i * stride) // 0x1000] + (off % 0x1000)) for i in range(1024)]
dev.iseg_w('cmdq_addr_h', hi32(self.queue_paddrs[0]))
dev.iseg_w('cmdq_addr_l_sz', lo32(self.queue_paddrs[0]) | cmd_l)
@ -84,18 +85,19 @@ class MLXCmdQueue:
_in=[int.from_bytes(inp[i:i+4], 'little') for i in range(0, 16, 4)],
out_ptr=to_be('Q', out_ptr), outlen=to_be('I', 16 + out_sz), token=tok, status_own=mlx5.CMD_OWNER_HW)
cmd_bytes = bytearray(bytes(cmd))
cmd_bytes[mlx5.struct_mlx5_cmd_layout.sig.offset] = (~functools.reduce(lambda a, b: a ^ b, cmd_bytes)) & 0xFF
cmd_bytes[mlx5.struct_mlx5_cmd_layout.sig.offset] = (~functools.reduce(lambda a, b: a ^ b, cmd_bytes)) & 0xFF # type: ignore[attr-defined]
# submit and wait for completion
slot_view = self.queue.view(slot << self.log_stride, len(cmd_bytes))
slot_view[:] = cmd_bytes
self.dev.iseg_w('cmd_dbell', 1 << slot)
wait_cond(lambda: slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] & mlx5.CMD_OWNER_HW, value=0, msg=f"cmd 0x{opcode:04x}")
wait_cond(lambda: slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] & mlx5.CMD_OWNER_HW, value=0, # type: ignore[attr-defined]
msg=f"cmd 0x{opcode:04x}")
# check status and read output
assert slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] >> 1 == 0, f"cmd 0x{opcode:04x} delivery error"
assert slot_view[mlx5.struct_mlx5_cmd_layout.status_own.offset] >> 1 == 0, f"cmd 0x{opcode:04x} delivery error" # type: ignore[attr-defined]
out_view = slot_view.view(mlx5.struct_mlx5_cmd_layout.out.offset, 16 + out_sz)
out_view = slot_view.view(mlx5.struct_mlx5_cmd_layout.out.offset, 16 + out_sz) # type: ignore[attr-defined]
status, syndrome = struct.unpack('>I', out_view[0:4])[0] >> 24, struct.unpack('>I', out_view[4:8])[0]
assert status == 0, f"cmd 0x{opcode:04x} failed status=0x{status:x} syn=0x{syndrome:08x}"
@ -145,7 +147,7 @@ class MLXDev:
self.cmd.exec(mlx5.MLX5_CMD_OP_MODIFY_NIC_VPORT_CONTEXT, field_select=dict(roce_en=1), nic_vport_context=dict(roce_en=1))
dbr_mem, self.dbr_paddrs = self.pci_dev.alloc_sysmem(0x1000)
self.dbr = dbr_mem.mv.cast('I')
self.dbr = dbr_mem.view(fmt='I')
self.dbr_alloc = BumpAllocator(0x1000, wrap=False)
self.pd = self.cmd.exec(mlx5.MLX5_CMD_OP_ALLOC_PD)['pd']
@ -159,8 +161,19 @@ class MLXDev:
if DEBUG >= 2: print(f"mlx5 {self.devfmt}: booted mac={self.mac.to_bytes(6,'big').hex(':')} mkey=0x{self.mkey:x}")
def register_mem(self, paddrs:list[int], size:int, log_page_size:int=12) -> int:
  """Create a memory key covering the given pages and return it.

  Args:
    paddrs: address of every page, in order; the first one is used as the key's start address.
    size: length in bytes covered by the key.
    log_page_size: log2 of the page size (default 12, i.e. 4 KiB pages).

  Returns:
    The key: the `mkey_index` returned by the command in the upper bits, 0x33 in the low 8 bits.
  """
  n = len(paddrs)
  # the translation table is passed as big-endian 64-bit entries, zero-padded to an even entry count
  padded_n = round_up(n, 2)
  mtt = struct.pack(f'>{padded_n}Q', *paddrs, *([0] * (padded_n - n)))
  if MLX_DEBUG >= 1: print(f"mlx5 {self.devfmt}: register_mem pages={n} page_sz={1 << log_page_size} mtt_bytes={len(mtt)}")

  self.provide_pages(mlx5.MLX5_INIT_PAGES)

  key_lo = 0x33
  octwords = ceildiv(n, 2)  # two 8-byte entries per octword
  mkey_entry = dict(access_mode_1_0=1, lr=1, lw=1, rr=1, rw=1, pd=self.pd, qpn=0xFFFFFF, mkey_7_0=key_lo,
                    start_addr=paddrs[0], len=size, log_page_size=log_page_size, translations_octword_size=octwords)
  res = self.cmd.exec(mlx5.MLX5_CMD_OP_CREATE_MKEY, translations_octword_actual_size=octwords, payload=mtt,
                      memory_key_mkey_entry=mkey_entry)
  return (res['mkey_index'] << 8) | key_lo
def provide_pages(self, mode):
if (npages:=self.cmd.exec(mlx5.MLX5_CMD_OP_QUERY_PAGES, op_mod=mode)['num_pages']) <= 0: return
if MLX_DEBUG >= 1: print(f"mlx5 {self.devfmt}: provide_pages mode={mode}, {npages} pages")
mem, paddrs = self.pci_dev.alloc_sysmem(npages * 0x1000)
self.cmd.exec(mlx5.MLX5_CMD_OP_MANAGE_PAGES, op_mod=mlx5.MLX5_PAGES_GIVE, input_num_entries=npages, payload=struct.pack(f'>{npages}Q', *paddrs))
@ -189,20 +202,22 @@ class MLXDev:
class MLXQP:
def __init__(self, dev:MLXDev, log_sq_size=4, log_rq_size=4, log_eq_size=7, log_cq_size=7):
self.dev, self.cq_size, self.log_sq_size = dev, 1 << log_cq_size, log_sq_size
self.dev, self.cq_size, self.log_sq_size, self.log_rq_size = dev, 1 << log_cq_size, log_sq_size, log_rq_size
self.cq_dbr, self.qp_dbr = dev.dbr_alloc.alloc(8, alignment=8), dev.dbr_alloc.alloc(8, alignment=8)
# create EQ, CQ
self.eq_mem, self.eq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_EQ, log_eq_size, entry_sz=64, owner_off=31,
self.eq_mem, self.eq_paddrs, self.eq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_EQ, log_eq_size, entry_sz=64, owner_off=31,
eq_context_entry=dict(log_eq_size=log_eq_size, uar_page=dev.uar, log_page_size=0))
self.cq_mem, self.cq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_CQ, log_cq_size, entry_sz=64, owner_off=63,
cq_context=dict(log_cq_size=log_cq_size, uar_page=dev.uar, c_eqn_or_apu_element=self.eq_info['eq_number'], dbr_addr=dev.dbr_paddrs[0] + self.cq_dbr, log_page_size=0))
self.cq_mem, self.cq_paddrs, self.cq_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_CQ, log_cq_size, entry_sz=64, owner_off=63,
cq_context=dict(log_cq_size=log_cq_size, uar_page=dev.uar, c_eqn_or_apu_element=self.eq_info['eq_number'],
dbr_addr=dev.dbr_paddrs[0] + self.cq_dbr, log_page_size=0))
# create QP, buffer is RQ (16B stride) + SQ (64B stride)
self.sq_offset = (1 << log_rq_size) << 4
self.qp_buf, self.qp_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_QP, log_sq_size, entry_sz=64, owner_off=0, extra_sz=self.sq_offset,
self.qp_buf, self.qp_paddrs, self.qp_info = self.create_queue(mlx5.MLX5_CMD_OP_CREATE_QP, log_sq_size, entry_sz=64,
owner_off=0, extra_sz=self.sq_offset,
qpc=dict(st=0, pm_state=3, pd=dev.pd, cqn_snd=self.cq_info['cqn'], cqn_rcv=self.cq_info['cqn'], log_msg_max=30, log_rq_size=log_rq_size,
log_rq_stride=0, log_sq_size=log_sq_size, rlky=1, uar_page=dev.uar, log_page_size=0, dbr_addr=dev.dbr_paddrs[0] + self.qp_dbr))
@ -210,67 +225,24 @@ class MLXQP:
self.qp_op(mlx5.MLX5_CMD_OP_RST2INIT_QP, qpc_args=dict(log_ack_req_freq=8), addr_args=dict(pkey_index=0, vhca_port_num=1))
self.cq_ci = self.sq_head = self.rq_head = 0
for i in range(self.cq_size): self.cq_mem[i * 64 + 63] = 0x01 # init owner bits so poll_cq waits for real CQEs
if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} (EQ={self.eq_info['eq_number']} CQ=0x{self.cq_info['cqn']:x})")
def create_queue(self, opcode, log_size, entry_sz, owner_off, extra_sz=0, **ctx_kw):
mem, paddrs = self.dev.pci_dev.alloc_sysmem((n := ceildiv((1 << log_size) * entry_sz + extra_sz, 0x1000)) * 0x1000)
return mem, self.dev.cmd.exec(opcode, payload=struct.pack(f'>{n}Q', *paddrs), **ctx_kw)
return mem, paddrs, self.dev.cmd.exec(opcode, payload=struct.pack(f'>{n}Q', *paddrs), **ctx_kw)
def qp_op(self, opcode, qpc_args=None, addr_args=None, **kwargs):
qpc_args = dict(st=0, pm_state=3, pd=self.dev.pd, cqn_snd=self.cq_info['cqn'], cqn_rcv=self.cq_info['cqn'], **(qpc_args or {}))
self.dev.cmd.exec(opcode, qpn=self.qp_info['qpn'], qpc=(qpc_args or {}) | {'primary_address_path': addr_args or {}}, **kwargs)
def connect(self, remote_qpn:int, remote_mac:int, remote_gid:int):
def connect(self, remote:MLXQP):
self.qp_op(mlx5.MLX5_CMD_OP_INIT2RTR_QP, opt_param_mask=0x1A,
qpc_args=dict(mtu=3, log_msg_max=self.dev.caps['log_max_msg'], remote_qpn=remote_qpn, log_ack_req_freq=8, log_rra_max=3, rre=1, rwe=1,
min_rnr_nak=12, next_rcv_psn=0),
addr_args=dict(pkey_index=0, src_addr_index=0, hop_limit=64, udp_sport=udp_sport(self.qp_info['qpn'], remote_qpn), vhca_port_num=1,
rmac_47_32=hi32(remote_mac), rmac_31_0=lo32(remote_mac), rgid_rip=remote_gid))
qpc_args=dict(mtu=5, log_msg_max=self.dev.caps['log_max_msg'], remote_qpn=remote.qp_info['qpn'], log_ack_req_freq=8,
log_rra_max=3, rre=1, rwe=1, min_rnr_nak=1, next_rcv_psn=0),
addr_args=dict(pkey_index=0, src_addr_index=0, hop_limit=64, udp_sport=udp_sport(self.qp_info['qpn'], remote.qp_info['qpn']), vhca_port_num=1,
rmac_47_32=hi32(remote.dev.mac), rmac_31_0=lo32(remote.dev.mac), rgid_rip=int.from_bytes(remote.dev.local_gid, 'big')))
self.qp_op(mlx5.MLX5_CMD_OP_RTR2RTS_QP, qpc_args=dict(log_ack_req_freq=8, next_send_psn=0, log_sra_max=3, retry_count=7, rnr_retry=7),
addr_args=dict(ack_timeout=14, vhca_port_num=1))
if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} connected (remote=0x{remote_qpn:x})")
def _post_sq(self, wqe_op, ds_count, data):
wqe = self.qp_buf.view(self.sq_offset + (self.sq_head & 0xF) * 64, 64)
wqe[0:8] = struct.pack('>II', (self.sq_head << 8) | wqe_op, (self.qp_info['qpn'] << 8) | ds_count)
wqe[11] = 0x08 # CE: signal completion
wqe[16:16 + len(data)] = data
self.sq_head += 1
self.dev.dbr[self.qp_dbr // 4 + 1] = to_be('I', self.sq_head)
self.dev.uar_view[0x800 // 8] = to_be('Q', int.from_bytes(wqe[0:8], 'big'))
self.poll_cq()
def poll_cq(self, timeout=5.0):
t = time.monotonic()
while True:
cqe = self.cq_mem.view((self.cq_ci & (self.cq_size - 1)) * 64, 64)
opcode, owner = cqe[63] >> 4, cqe[63] & 1
if opcode != 0x0F and owner == (1 if (self.cq_ci & self.cq_size) else 0):
self.cq_ci += 1
self.dev.dbr[self.cq_dbr // 4] = to_be('I', self.cq_ci & 0xFFFFFF)
assert opcode not in (13, 14), f"CQE error: opcode={opcode} syndrome=0x{cqe[55]:02x} vendor=0x{cqe[54]:02x}"
return opcode
if time.monotonic() - t > timeout: raise TimeoutError("CQ poll timeout")
time.sleep(0.0001)
def rdma_write(self, remote_addr, rkey, local_addr, lkey, length):
self._post_sq(0x08, 3, struct.pack('>QI4xIIQ', remote_addr, rkey, length, lkey, local_addr))
def send(self, addr, lkey, length): self._post_sq(0x0a, 2, struct.pack('>IIQ', length, lkey, addr))
if __name__ == "__main__":
dev = MLXDev(PCIDevice("mlx5", getenv("MLX_PCI", "0000:41:00.0")))
qp = MLXQP(dev)
if "--server" in sys.argv:
print(json.dumps({"qpn": qp.qpn, "mac": dev.mac.to_bytes(6,'big').hex(), "gid": dev.local_gid.hex()}), flush=True)
remote = json.loads(input())
qp.connect(remote["qpn"], int(remote["mac"], 16), int(remote["gid"], 16))
print("connected", flush=True)
target_mem, target_paddrs = dev.pci_dev.alloc_sysmem(0x1000)
print(json.dumps({"target_addr": target_paddrs[0], "rkey": dev.mkey}), flush=True)
input()
data = bytes(target_mem[i] for i in range(64))
as_text = data.rstrip(b'\x00').decode('ascii', errors='replace')
print(f"RECEIVED: {data.hex(' ')}\nAS TEXT: {as_text}", flush=True)
if MLX_DEBUG >= 1: print(f"mlx5: QP 0x{self.qp_info['qpn']:x} connected (remote=0x{remote.qp_info['qpn']:x})")

View file

@ -263,7 +263,8 @@ class PCIIfaceBase:
return HCQBuffer(mapping.va_addr, size, view=barview, meta=PCIAllocationMeta(mapping, cpu_access, hMemory=mapping.paddrs[0][0]), owner=self.dev)
def free(self, b:HCQBuffer):
for dev in b.mapped_devs[1:]: dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
for dev in b.mapped_devs[1:]:
if hasattr(dev.iface, 'dev_impl'): dev.iface.dev_impl.mm.unmap_range(b.va_addr, b.size)
if b.meta.mapping.aspace is AddrSpace.PHYS: self.dev_impl.mm.vfree(b.meta.mapping)
if self.is_local() and b.owner == self.dev and b.meta.has_cpu_mapping: FileIOInterface.munmap(b.va_addr, b.size)