mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into new_x86_backend
This commit is contained in:
commit
e9f2e89f8f
57 changed files with 1210 additions and 702 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -704,6 +704,8 @@ jobs:
|
|||
# TODO: run all once emulator is faster
|
||||
- name: Run RDNA3 ops tests
|
||||
run: SKIP_SLOW_TEST=1 AMD_LLVM=0 pytest -n=auto test/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril or test_nonzero or test_softmax_argmax" --durations 20
|
||||
- name: Run RDNA4 emulator tests
|
||||
run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20
|
||||
|
||||
testnvidia:
|
||||
strategy:
|
||||
|
|
|
|||
|
|
@ -49,10 +49,11 @@ from tinygrad.helpers import Context, DEBUG, colored
|
|||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
from extra.assembly.amd import decode_inst
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC,
|
||||
DS, FLAT, GLOBAL, SCRATCH, VOPD, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOPDOp)
|
||||
from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE as PCODE_RDNA3
|
||||
from extra.assembly.amd.autogen.rdna4.str_pcode import PCODE as PCODE_RDNA4
|
||||
from extra.assembly.amd.autogen.rdna3 import ins as ir3
|
||||
from extra.assembly.amd.autogen.rdna4 import ins as ir4
|
||||
from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
|
||||
from extra.assembly.amd.autogen.common import Fmt, OpType
|
||||
from extra.assembly.amd.pcode import parse_block, _FUNCS
|
||||
|
||||
|
|
@ -79,15 +80,23 @@ def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits:
|
|||
if neg_bits & (1 << mod_bit): fv = fv.neg()
|
||||
return fv.bitcast(ut).cast(dtypes.uint32) if bits == 16 else fv.bitcast(ut)
|
||||
|
||||
# Map VOPD ops to VOP2 ops for pcode lookup
|
||||
# Map VOPD ops to VOP2 ops for pcode lookup (both RDNA3 and RDNA4)
|
||||
VOPD_TO_VOP2 = {
|
||||
VOPDOp.V_DUAL_FMAC_F32: VOP2Op.V_FMAC_F32_E32, VOPDOp.V_DUAL_MUL_F32: VOP2Op.V_MUL_F32_E32,
|
||||
VOPDOp.V_DUAL_ADD_F32: VOP2Op.V_ADD_F32_E32, VOPDOp.V_DUAL_SUB_F32: VOP2Op.V_SUB_F32_E32,
|
||||
VOPDOp.V_DUAL_SUBREV_F32: VOP2Op.V_SUBREV_F32_E32, VOPDOp.V_DUAL_MAX_F32: VOP2Op.V_MAX_F32_E32,
|
||||
VOPDOp.V_DUAL_MIN_F32: VOP2Op.V_MIN_F32_E32, VOPDOp.V_DUAL_ADD_NC_U32: VOP2Op.V_ADD_NC_U32_E32,
|
||||
VOPDOp.V_DUAL_LSHLREV_B32: VOP2Op.V_LSHLREV_B32_E32, VOPDOp.V_DUAL_AND_B32: VOP2Op.V_AND_B32_E32,
|
||||
VOPDOp.V_DUAL_MOV_B32: VOP1Op.V_MOV_B32_E32, VOPDOp.V_DUAL_CNDMASK_B32: VOP2Op.V_CNDMASK_B32_E32,
|
||||
VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32_E32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32_E32,
|
||||
ir3.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32, ir3.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
|
||||
ir3.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32, ir3.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
|
||||
ir3.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32, ir3.VOPDOp.V_DUAL_MAX_F32: ir3.VOP2Op.V_MAX_F32_E32,
|
||||
ir3.VOPDOp.V_DUAL_MIN_F32: ir3.VOP2Op.V_MIN_F32_E32, ir3.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
|
||||
ir3.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32, ir3.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
|
||||
ir3.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32, ir3.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
|
||||
ir3.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
|
||||
# RDNA4 mappings (same VOP1/VOP2 targets, RDNA4 uses _NUM_ suffix for min/max)
|
||||
ir4.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32, ir4.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
|
||||
ir4.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32, ir4.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
|
||||
ir4.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32, ir4.VOPDOp.V_DUAL_MAX_NUM_F32: ir3.VOP2Op.V_MAX_F32_E32,
|
||||
ir4.VOPDOp.V_DUAL_MIN_NUM_F32: ir3.VOP2Op.V_MIN_F32_E32, ir4.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
|
||||
ir4.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32, ir4.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
|
||||
ir4.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32, ir4.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
|
||||
ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
|
||||
}
|
||||
WAVE_SIZE = 32
|
||||
# Special registers stored after inline constants (256-259)
|
||||
|
|
@ -146,11 +155,15 @@ _pcode_fixes = {
|
|||
'V_TRIG_PREOP_F64': ("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)"),
|
||||
}
|
||||
|
||||
def _get_pcode_dict(op) -> dict:
|
||||
"""Return the PCODE dictionary for the given opcode based on its architecture."""
|
||||
return PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3
|
||||
|
||||
# Pcode parser
|
||||
@functools.cache
|
||||
def get_pcode(op) -> str:
|
||||
op_name = op.name
|
||||
pcode = PCODE[op]
|
||||
pcode = _get_pcode_dict(op)[op]
|
||||
if op_name in _pcode_fixes: pcode = pcode.replace(*_pcode_fixes[op_name])
|
||||
if 'V_DIV_SCALE' in op_name:
|
||||
dt, exp_lim, ldexp_val = ('f32', '23', '64') if 'F32' in op_name else ('f64', '52', '128')
|
||||
|
|
@ -174,7 +187,12 @@ def get_pcode(op) -> str:
|
|||
def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
|
||||
vars: dict = srcs.copy() if srcs else {}
|
||||
assigns: list[tuple[str, UOp]] = []
|
||||
lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
|
||||
raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
|
||||
# TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
|
||||
lines: list[str] = []
|
||||
for l in raw_lines:
|
||||
if lines and lines[-1].endswith('&&'): lines[-1] = lines[-1] + ' ' + l
|
||||
else: lines.append(l)
|
||||
_, final, _ = parse_block(lines, 0, vars, assigns=assigns)
|
||||
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
|
||||
for var, val in final.items():
|
||||
|
|
@ -317,9 +335,9 @@ class _Ctx:
|
|||
return base, mask, size
|
||||
|
||||
# Dynamic register access (takes UOp index instead of int)
|
||||
def rsgpr_dyn(self, reg: UOp) -> UOp:
|
||||
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
|
||||
"""Read SGPR with dynamic register index."""
|
||||
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load() if valid is not None else self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
|
||||
|
||||
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
|
||||
"""Write SGPR with dynamic register index. Writes to NULL (124) are discarded."""
|
||||
|
|
@ -341,15 +359,18 @@ class _Ctx:
|
|||
If lane is None, only scalar access is supported (off must be < 256).
|
||||
is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
|
||||
is_float_const = (off >= _c(240)) & (off <= _c(248))
|
||||
sgpr_lo = self.rsgpr_dyn(off)
|
||||
is_vgpr = off >= _c(256)
|
||||
is_sgpr = is_vgpr.ne(True)
|
||||
sgpr_lo = self.rsgpr_dyn(off, is_sgpr)
|
||||
|
||||
if lane is not None:
|
||||
is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
|
||||
vgpr_reg = off - _c(256)
|
||||
vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
|
||||
vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr)) if bits == 64 else vgpr_lo
|
||||
|
||||
if bits == 64:
|
||||
sgpr_val = _u64(sgpr_lo, self.rsgpr_dyn(off + _c(1)))
|
||||
sgpr_hi = self.rsgpr_dyn(off + _c(1), is_sgpr)
|
||||
sgpr_val = _u64(sgpr_lo, sgpr_hi)
|
||||
# Integer inline constants: sign-extend 32-bit value from buffer to 64-bit
|
||||
# Float constants: cast F32 to F64
|
||||
int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
|
||||
|
|
@ -402,17 +423,19 @@ class _Ctx:
|
|||
return UOp.sink(*self.scalar_stores(assigns, sdst_reg, sdst_size), *self.inc_pc())
|
||||
|
||||
def compile_lane_pcode(self, op, inst) -> UOp:
|
||||
"""Compile READLANE/READFIRSTLANE/WRITELANE using pcode parser."""
|
||||
"""Compile cross-lane ops (READLANE/WRITELANE/PERMLANE) using pcode parser."""
|
||||
pcode = get_pcode(op)
|
||||
op_name = op.name if hasattr(op, 'name') else str(op)
|
||||
src0_off, vdst_off = self.inst_field(type(inst).src0), self.inst_field(type(inst).vdst)
|
||||
src0_reg = (src0_off >= _c(256)).where(src0_off - _c(256), _c(0)) # VGPR index or 0
|
||||
src1_off = self.inst_field(type(inst).src1) if hasattr(type(inst), 'src1') else None
|
||||
src2_off = self.inst_field(type(inst).src2) if hasattr(type(inst), 'src2') else None
|
||||
exec_lo = self.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
srcs = {
|
||||
'SRC0': src0_reg, 'VDST': vdst_off, 'EXEC_LO': exec_lo, 'EXEC': exec_lo.cast(dtypes.uint64), '_vgpr': self.vgpr,
|
||||
'S0': self.rsrc_dyn(src0_off, _c(0, dtypes.int)) if 'WRITELANE' in op_name else src0_reg,
|
||||
'S1': self.rsrc_dyn(src1_off, _c(0, dtypes.int)) if src1_off is not None else _c(0),
|
||||
'S2': self.rsrc_dyn(src2_off, _c(0, dtypes.int)) if src2_off is not None else _c(0),
|
||||
}
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
stores = []
|
||||
|
|
@ -480,13 +503,14 @@ class _Ctx:
|
|||
# INSTRUCTION HANDLERS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _compile_sopp(inst: SOPP, ctx: _Ctx) -> UOp:
|
||||
simm16 = ctx.inst_field_signed(SOPP.simm16).cast(dtypes.int16)
|
||||
if inst.op == SOPPOp.S_ENDPGM:
|
||||
def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
|
||||
simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16)
|
||||
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM):
|
||||
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)),
|
||||
ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)))
|
||||
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP): return UOp.sink(*ctx.inc_pc()) # S_NOP is a no-op
|
||||
# NOTE: we ignore SOPPs without PCODE
|
||||
if inst.op in PCODE:
|
||||
if inst.op in _get_pcode_dict(inst.op):
|
||||
pcode = get_pcode(inst.op)
|
||||
pc_bytes = ctx.rpc() # PC is already 64-bit byte address
|
||||
vcc, exec_lo = ctx.rsgpr_dyn(_c(VCC_LO.offset)), ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
|
|
@ -498,50 +522,57 @@ def _compile_sopp(inst: SOPP, ctx: _Ctx) -> UOp:
|
|||
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), lo), ctx.wsgpr_dyn(_c(PC_HI_IDX), hi))
|
||||
return UOp.sink(*ctx.inc_pc())
|
||||
|
||||
def _compile_smem(inst: SMEM, ctx: _Ctx) -> UOp:
|
||||
def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
|
||||
# Cache invalidation instructions are no-ops in the emulator (we don't model caches)
|
||||
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV):
|
||||
cache_inv_ops = [ir3.SMEMOp.S_GL1_INV, ir3.SMEMOp.S_DCACHE_INV, ir4.SMEMOp.S_DCACHE_INV]
|
||||
if hasattr(ir4.SMEMOp, 'S_GL1_INV'): cache_inv_ops.append(ir4.SMEMOp.S_GL1_INV)
|
||||
if inst.op in cache_inv_ops:
|
||||
return UOp.sink(*ctx.inc_pc())
|
||||
# Dynamic sbase field (bits 5:0) - SGPR pair, field value * 2 = register offset
|
||||
sbase = ctx.inst_field(SMEM.sbase) * _c(2)
|
||||
sbase = ctx.inst_field(type(inst).sbase) * _c(2)
|
||||
# Dynamic sdata field (bits 12:6) - destination SGPR
|
||||
sdata_reg = ctx.inst_field(SMEM.sdata)
|
||||
offset = ctx.inst_field_signed(SMEM.offset) # 21-bit signed immediate
|
||||
# Dynamic soffset field (bits 63:57) - SGPR for additional offset (NULL=124 reads as 0)
|
||||
soffset = ctx.inst_field(SMEM.soffset)
|
||||
sdata_reg = ctx.inst_field(type(inst).sdata)
|
||||
# RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field
|
||||
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset
|
||||
offset = ctx.inst_field_signed(offset_field) # signed immediate
|
||||
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0)
|
||||
soffset = ctx.inst_field(type(inst).soffset)
|
||||
addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + ctx.rsgpr_dyn(soffset).cast(dtypes.uint64)
|
||||
ndwords = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}.get(inst.op, 1)
|
||||
_SMEM_NDWORDS = {ir3.SMEMOp.S_LOAD_B32: 1, ir3.SMEMOp.S_LOAD_B64: 2, ir3.SMEMOp.S_LOAD_B128: 4,
|
||||
ir3.SMEMOp.S_LOAD_B256: 8, ir3.SMEMOp.S_LOAD_B512: 16, ir4.SMEMOp.S_LOAD_B32: 1, ir4.SMEMOp.S_LOAD_B64: 2,
|
||||
ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16}
|
||||
ndwords = _SMEM_NDWORDS[inst.op]
|
||||
stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), ctx.vmem.index((addr + UOp.const(dtypes.uint64, i * 4) >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int)))
|
||||
for i in range(ndwords)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
|
||||
def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp:
|
||||
bits = inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
|
||||
if isinstance(inst, SOPK):
|
||||
sdst_off = ctx.inst_field(SOPK.sdst)
|
||||
simm16 = ctx.inst_field(SOPK.simm16)
|
||||
if isinstance(inst, (ir3.SOPK, ir4.SOPK)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
simm16 = ctx.inst_field(type(inst).simm16)
|
||||
# Sign-extend simm16
|
||||
simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32)
|
||||
srcs = {'S0': ctx.rsgpr_dyn(sdst_off), 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
|
||||
dst_off, dst_size = sdst_off, 1
|
||||
elif isinstance(inst, SOP1):
|
||||
sdst_off = ctx.inst_field(SOP1.sdst)
|
||||
ssrc0_off = ctx.inst_field(SOP1.ssrc0)
|
||||
elif isinstance(inst, (ir3.SOP1, ir4.SOP1)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
|
||||
dst_off, dst_size = sdst_off, bits['d'] // 32
|
||||
elif isinstance(inst, SOP2):
|
||||
sdst_off = ctx.inst_field(SOP2.sdst)
|
||||
ssrc0_off = ctx.inst_field(SOP2.ssrc0)
|
||||
ssrc1_off = ctx.inst_field(SOP2.ssrc1)
|
||||
elif isinstance(inst, (ir3.SOP2, ir4.SOP2)):
|
||||
sdst_off = ctx.inst_field(type(inst).sdst)
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
|
||||
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
|
||||
if literal is not None: srcs['SIMM32'] = literal
|
||||
dst_off, dst_size = sdst_off, bits['d'] // 32
|
||||
elif isinstance(inst, SOPC):
|
||||
ssrc0_off = ctx.inst_field(SOPC.ssrc0)
|
||||
ssrc1_off = ctx.inst_field(SOPC.ssrc1)
|
||||
elif isinstance(inst, (ir3.SOPC, ir4.SOPC)):
|
||||
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
|
||||
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
|
||||
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
|
||||
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
|
||||
dst_off, dst_size = _c(0), 0 # SOPC writes to SCC, not sdst
|
||||
|
|
@ -550,18 +581,18 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
|
|||
|
||||
return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size)
|
||||
|
||||
def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if op_name == 'V_READFIRSTLANE_B32_E32': return ctx.compile_lane_pcode(inst.op, inst)
|
||||
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
|
||||
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
vdst_reg = ctx.inst_field(VOP1.vdst)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
|
||||
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
|
||||
elif write_hi_half: vdst_reg -= 128
|
||||
if isinstance(inst, VOP1):
|
||||
if isinstance(inst, (ir3.VOP1, ir4.VOP1)):
|
||||
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(VOP1.src0)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
|
||||
if bits['s0'] == 16:
|
||||
src0_hi = src0_off >= _c(384)
|
||||
|
|
@ -570,13 +601,13 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
|
|||
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
|
||||
srcs = {'S0': s0}
|
||||
else:
|
||||
vsrc1_reg = ctx.inst_field(VOP2.vsrc1)
|
||||
vsrc1_reg = ctx.inst_field(type(inst).vsrc1)
|
||||
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
|
||||
vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) # FMAC/FMAMK hi-half dest needs hi-half accumulator
|
||||
# Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(VOP2.src0)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
|
||||
if bits['s0'] == 16:
|
||||
src0_hi = src0_off >= _c(384)
|
||||
|
|
@ -584,19 +615,20 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
|
|||
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
|
||||
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
|
||||
srcs = {'S0': s0, 'S1': s1, 'D0': d0}
|
||||
if inst.op in (VOP2Op.V_FMAAK_F32_E32, VOP2Op.V_FMAMK_F32_E32, VOP2Op.V_FMAAK_F16_E32, VOP2Op.V_FMAMK_F16_E32):
|
||||
if inst.op in (ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOP2Op.V_FMAMK_F32_E32, ir3.VOP2Op.V_FMAAK_F16_E32,
|
||||
ir3.VOP2Op.V_FMAMK_F16_E32):
|
||||
assert literal is not None
|
||||
srcs['SIMM32'] = literal
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
|
||||
|
||||
def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
|
||||
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
|
||||
|
||||
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
|
||||
if is_vopc:
|
||||
src0_off = ctx.inst_field(VOPC.src0)
|
||||
vsrc1_off = ctx.inst_field(VOPC.vsrc1)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
vsrc1_off = ctx.inst_field(type(inst).vsrc1)
|
||||
# For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
|
||||
if bits['s0'] == 16:
|
||||
vsrc1_hi = vsrc1_off >= _c(128)
|
||||
|
|
@ -605,9 +637,9 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
|
|||
vsrc1_hi = False
|
||||
src1_off = _c(256) + vsrc1_off
|
||||
else:
|
||||
src0_off = ctx.inst_field(VOP3.src0)
|
||||
src1_off = ctx.inst_field(VOP3.src1)
|
||||
dst_off = ctx.inst_field(VOP3.vdst)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
src1_off = ctx.inst_field(type(inst).src1)
|
||||
dst_off = ctx.inst_field(type(inst).vdst)
|
||||
vsrc1_hi = False
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
|
||||
|
|
@ -636,7 +668,7 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
|
|||
stores = [ctx.wsgpr_dyn(dst_off, new_result)] if not is_vopc else [ctx.wsgpr_dyn(_c(VCC_LO.offset), new_result)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
bits = inst.canonical_op_bits
|
||||
opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
|
||||
|
|
@ -645,18 +677,22 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
|
|||
if op_name in ('V_READLANE_B32', 'V_READFIRSTLANE_B32', 'V_READFIRSTLANE_B32_E64', 'V_WRITELANE_B32'):
|
||||
return ctx.compile_lane_pcode(inst.op, inst)
|
||||
|
||||
# V_PERMLANE16_B32 / V_PERMLANEX16_B32: cross-lane swizzle via pcode
|
||||
if 'PERMLANE16' in op_name or 'PERMLANEX16' in op_name:
|
||||
return ctx.compile_lane_pcode(inst.op, inst)
|
||||
|
||||
# VOP3 VOPC (v_cmp_*_e64) - delegate to unified VOPC handler
|
||||
if 'V_CMP' in op_name or 'V_CMPX' in op_name:
|
||||
return _compile_vopc(inst, ctx, opsel=opsel, abs_bits=getattr(inst, 'abs', 0) or 0, neg_bits=getattr(inst, 'neg', 0) or 0)
|
||||
|
||||
# Regular VOP3 - read operands dynamically
|
||||
lane = ctx.range()
|
||||
vdst_reg = ctx.inst_field(VOP3.vdst)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
ops = inst.canonical_operands
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
|
||||
if bits['s0'] == 16:
|
||||
src0 = _apply_opsel(src0, 0, opsel)
|
||||
src1 = _apply_opsel(src1, 1, opsel)
|
||||
|
|
@ -666,19 +702,19 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
|
|||
src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
|
||||
src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
|
||||
srcs = {'S0': src0, 'S1': src1, 'S2': src2}
|
||||
if inst.op in (VOP3Op.V_CNDMASK_B32_E64, VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
|
||||
if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
|
||||
# FMAC instructions need D0 (accumulator) from destination register
|
||||
if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
|
||||
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
|
||||
|
||||
def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
|
||||
|
||||
# Read operands dynamically from instruction encoding
|
||||
vdst_reg, sdst_off = ctx.inst_field(VOP3SD.vdst), ctx.inst_field(VOP3SD.sdst)
|
||||
src0_off, src1_off, src2_off = ctx.inst_field(VOP3SD.src0), ctx.inst_field(VOP3SD.src1), ctx.inst_field(VOP3SD.src2)
|
||||
vdst_reg, sdst_off = ctx.inst_field(type(inst).vdst), ctx.inst_field(type(inst).sdst)
|
||||
src0_off, src1_off, src2_off = ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
|
||||
has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
|
||||
|
|
@ -724,13 +760,13 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
|
|||
else:
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
|
||||
|
||||
def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(VOP3P.vdst)
|
||||
src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
|
||||
src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
|
||||
src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
src0_r = ctx.inst_field(type(inst).src0) - _c(256)
|
||||
src1_r = ctx.inst_field(type(inst).src1) - _c(256)
|
||||
src2_r = ctx.inst_field(type(inst).src2) - _c(256)
|
||||
is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name # F16/BF16 output vs F32 output
|
||||
is_bf16 = 'BF16' in op_name
|
||||
cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
|
||||
|
|
@ -757,16 +793,16 @@ def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
|
|||
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
|
||||
|
||||
lane = ctx.range()
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(VOP3P.vdst)
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src0), lane, 16)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src1), lane, 16)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src2), lane, 16)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16)
|
||||
opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
|
||||
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
|
||||
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
|
||||
|
|
@ -788,7 +824,7 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
|||
s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
|
||||
s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
|
||||
s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
|
||||
srcs = {'S0': s0_mod, 'S1': s1_mod, 'S2': s2_mod,
|
||||
srcs = {'S@0': s0_mod, 'S@1': s1_mod, 'S@2': s2_mod,
|
||||
'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
|
||||
else:
|
||||
def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
|
||||
|
|
@ -806,18 +842,19 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
|||
if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||||
|
||||
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
|
||||
def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
# Read operands dynamically
|
||||
vdstx_reg = ctx.inst_field(VOPD.vdstx)
|
||||
# Read operands dynamically - use type(inst) to get correct field descriptors
|
||||
inst_type = type(inst)
|
||||
vdstx_reg = ctx.inst_field(inst_type.vdstx)
|
||||
# vdsty has complex encoding: actual = (raw << 1) | ((vdstx & 1) ^ 1)
|
||||
vdsty_raw = ctx.inst_field(VOPD.vdsty)
|
||||
vdsty_raw = ctx.inst_field(inst_type.vdsty)
|
||||
vdsty_reg = (vdsty_raw << _c(1)) | ((vdstx_reg & _c(1)) ^ _c(1))
|
||||
srcx0_off = ctx.inst_field(VOPD.srcx0)
|
||||
srcy0_off = ctx.inst_field(VOPD.srcy0)
|
||||
vsrcx1_reg = ctx.inst_field(VOPD.vsrcx1)
|
||||
vsrcy1_reg = ctx.inst_field(VOPD.vsrcy1)
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
|
||||
srcx0_off = ctx.inst_field(inst_type.srcx0)
|
||||
srcy0_off = ctx.inst_field(inst_type.srcy0)
|
||||
vsrcx1_reg = ctx.inst_field(inst_type.vsrcx1)
|
||||
vsrcy1_reg = ctx.inst_field(inst_type.vsrcy1)
|
||||
literal = ctx.inst_field(inst_type.literal) if hasattr(inst_type, 'literal') else None
|
||||
|
||||
lane = ctx.range()
|
||||
srcy0, srcy1 = ctx.rsrc_dyn(srcy0_off, lane, literal=literal), ctx.rvgpr_dyn(vsrcy1_reg, lane)
|
||||
|
|
@ -828,49 +865,64 @@ def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
|
|||
assert vop is not None, f"no VOP mapping for VOPD {label}: {op}"
|
||||
if label == 'Y': srcs = {'S0': srcy0, 'S1': srcy1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
|
||||
else: srcs = {'S0': ctx.rsrc_dyn(src0_off, lane, literal=literal), 'S1': ctx.rvgpr_dyn(vsrc1_reg, lane), 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
|
||||
if op in (VOPDOp.V_DUAL_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32):
|
||||
if op in (ir3.VOPDOp.V_DUAL_FMAAK_F32, ir3.VOPDOp.V_DUAL_FMAMK_F32, ir4.VOPDOp.V_DUAL_FMAAK_F32, ir4.VOPDOp.V_DUAL_FMAMK_F32):
|
||||
assert literal is not None
|
||||
srcs['SIMM32'] = literal
|
||||
if op == VOPDOp.V_DUAL_CNDMASK_B32: srcs['VCC'] = ctx.rsgpr_dyn(_c(VCC_LO.offset))
|
||||
if op in (ir3.VOPDOp.V_DUAL_CNDMASK_B32, ir4.VOPDOp.V_DUAL_CNDMASK_B32): srcs['VCC'] = ctx.rsgpr_dyn(_c(VCC_LO.offset))
|
||||
pcode = get_pcode(vop)
|
||||
srcs.update({'VCC': ctx.rsgpr_dyn(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
|
||||
for dest, val in parse_pcode(pcode, srcs)[1]:
|
||||
if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1))
|
||||
return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc())
|
||||
|
||||
def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
|
||||
def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS | ir4.VFLAT | ir4.VGLOBAL | ir4.VSCRATCH, ctx: _Ctx) -> UOp:
|
||||
"""Unified memory operation compiler for DS, FLAT, GLOBAL, SCRATCH."""
|
||||
exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
|
||||
pcode = get_pcode(inst.op)
|
||||
|
||||
is_lds = isinstance(inst, DS)
|
||||
is_scratch = isinstance(inst, SCRATCH)
|
||||
is_lds = isinstance(inst, (ir3.DS, ir4.DS))
|
||||
is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH))
|
||||
mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem
|
||||
addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2)
|
||||
|
||||
# Extract register info - all dynamic for deduplication
|
||||
if is_lds:
|
||||
addr_reg = ctx.inst_field(DS.addr)
|
||||
vdata_reg = ctx.inst_field(DS.data0)
|
||||
vdst_reg = ctx.inst_field(DS.vdst)
|
||||
offset0 = ctx.inst_field(DS.offset0)
|
||||
offset1 = ctx.inst_field(DS.offset1)
|
||||
addr_reg = ctx.inst_field(type(inst).addr)
|
||||
vdata_reg = ctx.inst_field(type(inst).data0)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
offset0 = ctx.inst_field(type(inst).offset0)
|
||||
offset1 = ctx.inst_field(type(inst).offset1)
|
||||
offset = offset0 # DS uses offset0 as primary offset
|
||||
saddr_reg = None
|
||||
else:
|
||||
elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset
|
||||
addr_reg = ctx.inst_field(type(inst).vaddr)
|
||||
vdata_reg = ctx.inst_field(type(inst).vsrc)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
offset = ctx.inst_field_signed(type(inst).ioffset)
|
||||
offset0, offset1 = _c(0), _c(0)
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
|
||||
else: # RDNA3: addr, data, offset
|
||||
addr_reg = ctx.inst_field(type(inst).addr)
|
||||
vdata_reg = ctx.inst_field(type(inst).data)
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
offset = ctx.inst_field_signed(type(inst).offset)
|
||||
offset0, offset1 = _c(0), _c(0)
|
||||
# Dynamic saddr - read field, NULL (124) or >= 128 means no saddr
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(inst, 'saddr') else None
|
||||
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
|
||||
|
||||
# Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
|
||||
data_bits_mem = inst.canonical_op_bits.get('data', 32)
|
||||
is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
|
||||
has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
|
||||
data1_reg = ctx.inst_field(DS.data1) if is_lds else _c(0)
|
||||
data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0)
|
||||
|
||||
# DS_PERMUTE/DS_BPERMUTE: cross-lane VGPR access via pcode
|
||||
if is_lds and 'PERMUTE' in op_name:
|
||||
pcode = get_pcode(inst.op)
|
||||
srcs = {'ADDR': addr_reg, 'DATA0': vdata_reg, 'VDST': vdst_reg, 'OFFSET': offset,
|
||||
'EXEC': exec_mask.cast(dtypes.uint64), '_vgpr': ctx.vgpr}
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
stores = [ctx.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)) for dest, val in assigns if dest.startswith('VGPR[')]
|
||||
return UOp.sink(*stores, *ctx.inc_pc())
|
||||
|
||||
def make_addr(lane: UOp) -> UOp:
|
||||
if is_lds: return ctx.rvgpr_dyn(addr_reg, lane)
|
||||
|
|
@ -912,14 +964,26 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
|
|||
else:
|
||||
data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)),
|
||||
'DATA2': _u64(ctx.rvgpr_dyn(data1_reg, lane), ctx.rvgpr_dyn(data1_reg + _c(1), lane)) if has_data1 else UOp.const(dtypes.uint64, 0)}
|
||||
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane, **data}
|
||||
# RDNA3 uses ADDR/OFFSET, RDNA4 uses vgpr_a/offset (lowercase) + CalcDsAddr function
|
||||
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane,
|
||||
'vgpr_a': ctx.rvgpr_dyn(addr_reg, lane), 'offset': offset, **data}
|
||||
active = _lane_active(exec_mask, lane)
|
||||
# saddr < 124 means valid SGPR pair, otherwise use 0 (NULL means no saddr contribution)
|
||||
use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False)
|
||||
saddr_raw = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
|
||||
saddr_base = use_saddr.where(saddr_raw, UOp.const(dtypes.uint64, 0))
|
||||
# Sign-extend offset to 64-bit for the final address calculation
|
||||
ioffset64 = offset.cast(dtypes.int64).cast(dtypes.uint64)
|
||||
# v_addr for CalcGlobalAddr: when saddr valid, use low 32 bits as offset; otherwise full 64-bit address. Include ioffset.
|
||||
vaddr_full = _u64(ctx.rvgpr_dyn(addr_reg, lane), ctx.rvgpr_dyn(addr_reg + _c(1), lane))
|
||||
vaddr_lo = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
|
||||
vaddr_base = use_saddr.where(vaddr_lo + ioffset64, vaddr_full + ioffset64)
|
||||
if is_atomic:
|
||||
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
|
||||
'_vmem': mem, '_active': active, 'laneId': lane}
|
||||
'_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane}
|
||||
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
|
||||
for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
|
||||
return srcs
|
||||
|
||||
|
|
@ -970,10 +1034,15 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
|
|||
|
||||
# Dispatch table: instruction type -> handler function
|
||||
_INST_HANDLERS: dict[type, Callable[..., UOp]] = {
|
||||
SOPP: _compile_sopp, SMEM: _compile_smem, SOP1: _compile_sop, SOP2: _compile_sop, SOPC: _compile_sop, SOPK: _compile_sop,
|
||||
VOP1: _compile_vop12, VOP1_SDST: _compile_vop12, VOP2: _compile_vop12, VOPC: _compile_vopc, VOP3: _compile_vop3, VOP3_SDST: _compile_vop3,
|
||||
VOP3SD: _compile_vop3sd, VOP3P: _compile_vop3p, VOPD: _compile_vopd,
|
||||
DS: _compile_mem_op, FLAT: _compile_mem_op, GLOBAL: _compile_mem_op, SCRATCH: _compile_mem_op,
|
||||
ir3.SOPP: _compile_sopp, ir3.SMEM: _compile_smem, ir3.SOP1: _compile_sop, ir3.SOP2: _compile_sop, ir3.SOPC: _compile_sop, ir3.SOPK: _compile_sop,
|
||||
ir3.VOP1: _compile_vop12, ir3.VOP1_SDST: _compile_vop12, ir3.VOP2: _compile_vop12, ir3.VOPC: _compile_vopc, ir3.VOP3: _compile_vop3,
|
||||
ir3.VOP3_SDST: _compile_vop3, ir3.VOP3SD: _compile_vop3sd, ir3.VOP3P: _compile_vop3p, ir3.VOPD: _compile_vopd,
|
||||
ir3.DS: _compile_mem_op, ir3.FLAT: _compile_mem_op, ir3.GLOBAL: _compile_mem_op, ir3.SCRATCH: _compile_mem_op,
|
||||
# RDNA4 instruction classes
|
||||
ir4.SOPP: _compile_sopp, ir4.SMEM: _compile_smem, ir4.SOP1: _compile_sop, ir4.SOP2: _compile_sop, ir4.SOPC: _compile_sop, ir4.SOPK: _compile_sop,
|
||||
ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOPC: _compile_vopc, ir4.VOP3: _compile_vop3,
|
||||
ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd,
|
||||
ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op,
|
||||
}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -983,9 +1052,9 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = {
|
|||
_canonical_runner_cache: list[tuple[int, int, int, object]] = [] # [(base, mask, size, runner), ...]
|
||||
|
||||
@functools.cache
|
||||
def _get_runner(inst_bytes: bytes):
|
||||
def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
||||
"""Build and compile instruction to CompiledRunner. Cached by instruction bytes, with canonical dedup."""
|
||||
inst = decode_inst(inst_bytes)
|
||||
inst = decode_inst(inst_bytes, arch)
|
||||
inst_size = inst.size()
|
||||
inst_int = int.from_bytes(inst_bytes[:inst_size], 'little')
|
||||
|
||||
|
|
@ -1014,15 +1083,15 @@ def _get_runner(inst_bytes: bytes):
|
|||
return runner, True
|
||||
|
||||
@functools.cache
|
||||
def decode_program(data: bytes) -> dict[int, tuple[str, Callable, list[int], Any]]:
|
||||
def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]:
|
||||
"""Decode program to {pc: (name, fxn, globals, runner)}."""
|
||||
result: dict[int, tuple[str, Callable, list[int], Any]] = {}
|
||||
i = 0
|
||||
while i < len(data):
|
||||
inst = decode_inst(data[i:])
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
|
||||
inst = decode_inst(data[i:], arch)
|
||||
if hasattr(inst, 'op') and inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END): break
|
||||
try:
|
||||
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]))
|
||||
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]), arch)
|
||||
if DEBUG >= 3:
|
||||
try: inst_str = repr(inst)
|
||||
except Exception: inst_str = f"<{type(inst).__name__} at PC={i}>"
|
||||
|
|
@ -1081,9 +1150,9 @@ class WaveState:
|
|||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
|
||||
scratch_size: int = 0) -> int:
|
||||
scratch_size: int = 0, arch: str = "rdna3") -> int:
|
||||
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
|
||||
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw))
|
||||
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch)
|
||||
program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses
|
||||
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
|
||||
total_threads = lx * ly * lz
|
||||
|
|
@ -1111,6 +1180,12 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
|
||||
if rsrc2 & enabled: st._write_sgpr(sgpr_idx, gid); sgpr_idx += 1
|
||||
|
||||
# RDNA4 uses TTMP registers for workgroup IDs: ttmp[9]=gidx, ttmp[10]=gidy, ttmp[11]=gidz
|
||||
if arch == "rdna4":
|
||||
st._write_sgpr(ttmp[9].offset, gidx)
|
||||
st._write_sgpr(ttmp[10].offset, gidy)
|
||||
st._write_sgpr(ttmp[11].offset, gidz)
|
||||
|
||||
# v0 = packed workitem IDs, scratch stride in secret SGPR
|
||||
for lane in range(n_lanes):
|
||||
tid = wave_start + lane
|
||||
|
|
@ -1127,7 +1202,7 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
|||
assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}"
|
||||
assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0"
|
||||
if DEBUG >= 6:
|
||||
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw))
|
||||
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch)
|
||||
print(f"[emu] exec PC={pc:X}: {inst!r}")
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")
|
||||
|
|
|
|||
|
|
@ -200,6 +200,17 @@ def _ff1(val: UOp, bits: int) -> UOp:
|
|||
result = cond.where(_const(dtypes.int, i), result)
|
||||
return result
|
||||
|
||||
def _sad_u8(a: UOp, b: UOp, acc: UOp, masked: bool = False) -> UOp:
  """Accumulate the per-byte absolute differences of two 32-bit words onto an accumulator.

  All three operands are cast to uint32. For each of the four byte lanes, |a_byte - b_byte| is added to `acc`.
  With `masked=True`, a lane contributes nothing when the byte taken from `a` is zero.
  """
  lhs, rhs = a.cast(dtypes.uint32), b.cast(dtypes.uint32)
  total = acc.cast(dtypes.uint32)
  for shift in (0, 8, 16, 24):
    lhs_byte = (lhs >> _u32(shift)) & _u32(0xFF)
    rhs_byte = (rhs >> _u32(shift)) & _u32(0xFF)
    # unsigned absolute difference: subtract the smaller byte from the larger so the result never wraps
    abs_diff = (lhs_byte > rhs_byte).where(lhs_byte - rhs_byte, rhs_byte - lhs_byte)
    if masked:
      # NOTE(review): the mask is keyed on the byte from `a`; confirm against the ISA pseudocode which operand
      # is meant to be the mask reference for the masked-SAD instructions.
      abs_diff = lhs_byte.ne(_u32(0)).where(abs_diff, _u32(0))
    total = total + abs_diff
  return total
|
||||
|
||||
_FUNCS: dict[str, Callable[..., UOp]] = {
|
||||
'sqrt': lambda a: UOp(Ops.SQRT, a.dtype, (a,)), 'trunc': lambda a: UOp(Ops.TRUNC, a.dtype, (a,)),
|
||||
'log2': lambda a: UOp(Ops.LOG2, a.dtype, (a,)), 'sin': lambda a: _trig_reduce(a),
|
||||
|
|
@ -254,6 +265,15 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
|
|||
# Float to int16 conversions
|
||||
'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16),
|
||||
'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16),
|
||||
# SAD (Sum of Absolute Differences) - sum |a_i - b_i| for 4 bytes + accumulator
|
||||
'v_sad_u8': lambda a, b, c: _sad_u8(a, b, c),
|
||||
'v_msad_u8': lambda a, b, c: _sad_u8(a, b, c, masked=True),
|
||||
# System NOPs - these are scheduling hints, no effect on emulation
|
||||
'MIN': lambda a, b: (a < b).where(a, b),
|
||||
's_nop': lambda a: _u32(0),
|
||||
# Address calculation for memory operations
|
||||
'CalcDsAddr': lambda a, o, *r: a.cast(dtypes.uint32) + o.cast(dtypes.uint32),
|
||||
'CalcGlobalAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
|
||||
}
|
||||
for is_max, name in [(False, 'min'), (True, 'max')]:
|
||||
for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
|
||||
|
|
@ -435,7 +455,7 @@ class Parser:
|
|||
self.eat('DOT')
|
||||
dt_name = self.eat('IDENT').val
|
||||
return self._handle_mem_load(addr, DTYPES.get(dt_name, dtypes.uint32))
|
||||
if name == 'VGPR':
|
||||
if name == 'VGPR' and self.at('LBRACKET'):
|
||||
self.eat('LBRACKET')
|
||||
lane = self.parse()
|
||||
self.eat('RBRACKET')
|
||||
|
|
@ -462,7 +482,21 @@ class Parser:
|
|||
if self.try_eat('LBRACE'):
|
||||
idx = self.eat('NUM').val
|
||||
self.eat('RBRACE')
|
||||
elem = self.vars.get(f'{name}{idx}', _u32(0))
|
||||
# Handle VGPR{lane}[reg] - 2D array access after loop unrolling
|
||||
if name == 'VGPR' and self.at('LBRACKET'):
|
||||
self.eat('LBRACKET')
|
||||
reg = self.parse()
|
||||
self.eat('RBRACKET')
|
||||
vgpr = self.vars.get('_vgpr')
|
||||
if vgpr is None: return _u32(0)
|
||||
return vgpr.index(_to_u32(reg) * _u32(32) + _u32(int(idx)), ptr=True).load()
|
||||
elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}'))
|
||||
if elem is None:
|
||||
# Extract bit idx from base variable (like var[idx])
|
||||
base = self.vars.get(name)
|
||||
assert isinstance(base, UOp), f"unknown variable: {name}{idx}"
|
||||
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
elem = (base.cast(dt) >> _const(dt, int(idx))) & _const(dt, 1)
|
||||
if self.try_eat('DOT'):
|
||||
dt_name = self.eat('IDENT').val
|
||||
return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32))
|
||||
|
|
@ -475,15 +509,13 @@ class Parser:
|
|||
return self._handle_bracket_rest(first, _u32(0), name)
|
||||
if name in self.vars:
|
||||
v = self.vars[name]
|
||||
return v if isinstance(v, UOp) else _u32(0) if isinstance(v, dict) else _u32(0)
|
||||
assert isinstance(v, UOp), f"expected UOp for {name}, got {type(v)}"
|
||||
return v
|
||||
raise RuntimeError(f"unknown variable: {name}")
|
||||
raise RuntimeError(f"unexpected token in primary: {self.peek()}")
|
||||
|
||||
def _handle_dot(self, base, field: str) -> UOp:
|
||||
if isinstance(base, str): return _u32(0)
|
||||
if not isinstance(base, UOp):
|
||||
if isinstance(base, dict): return base.get(field, _u32(0))
|
||||
return _u32(0)
|
||||
assert isinstance(base, UOp), f"expected UOp for dot access, got {type(base)}"
|
||||
if field == 'u64' and self.at('LBRACKET') and self.peek(1).type == 'IDENT' and self.peek(1).val == 'laneId':
|
||||
self.eat('LBRACKET')
|
||||
self.eat_val('laneId', 'IDENT')
|
||||
|
|
@ -541,9 +573,11 @@ class Parser:
|
|||
var_name = self._find_var_name(base)
|
||||
if first.op == Ops.CONST:
|
||||
idx = int(first.arg)
|
||||
# Check for array element (var@idx)
|
||||
if var_name and f'{var_name}@{idx}' in self.vars:
|
||||
v = self.vars[f'{var_name}@{idx}']
|
||||
return _cast_to(v, dt_suffix) if dt_suffix else v
|
||||
# Bit extraction
|
||||
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
base_cast = base.cast(dt) if base.dtype != dt else base
|
||||
result = ((base_cast >> _const(dt, idx)) & _const(dt, 1))
|
||||
|
|
@ -631,13 +665,14 @@ class Parser:
|
|||
raise RuntimeError(f"unexpected token after {bits}': {self.peek()}")
|
||||
|
||||
def _parse_number(self, num: str) -> UOp:
|
||||
if num.startswith('0x') or num.startswith('0X'): return _const(dtypes.uint64, int(num.rstrip('ULul'), 16))
|
||||
suffix, num = _strip_suffix(num)
|
||||
if '.' in num or suffix in ('F', 'f'):
|
||||
return _const(dtypes.float32 if suffix in ('F', 'f') else dtypes.float64, float(num))
|
||||
val = int(num)
|
||||
if 'ULL' in suffix: return _const(dtypes.uint64, val)
|
||||
if 'LL' in suffix or 'L' in suffix: return _const(dtypes.uint64, val)
|
||||
if num.startswith('0x') or num.startswith('0X'):
|
||||
is_u64 = num.upper().endswith('ULL') or num.upper().endswith('LL') or num.upper().endswith('UL')
|
||||
return _const(dtypes.uint64 if is_u64 else dtypes.uint32, int(num.rstrip('ULul'), 16))
|
||||
suffix, num_str = _strip_suffix(num)
|
||||
if '.' in num_str or suffix in ('F', 'f'):
|
||||
return _const(dtypes.float32 if suffix in ('F', 'f') else dtypes.float64, float(num_str))
|
||||
val = int(num_str)
|
||||
if 'ULL' in suffix or 'LL' in suffix or 'L' in suffix: return _const(dtypes.uint64, val)
|
||||
if 'U' in suffix: return _const(dtypes.uint32, val)
|
||||
return _const(dtypes.int if val < 0 else dtypes.uint32, val)
|
||||
|
||||
|
|
@ -655,7 +690,8 @@ class Parser:
|
|||
if ';' in body or '\n' in body or 'return' in body.lower():
|
||||
lines = [l.strip() for l in body.replace(';', '\n').split('\n') if l.strip() and not l.strip().startswith('//')]
|
||||
_, _, result = parse_block(lines, 0, lv, self.funcs)
|
||||
return result if result is not None else _u32(0)
|
||||
assert result is not None, f"lambda {name} must return a value"
|
||||
return result
|
||||
return parse_expr(body, lv, self.funcs)
|
||||
if name in self.funcs:
|
||||
return self.funcs[name](*args)
|
||||
|
|
@ -663,7 +699,7 @@ class Parser:
|
|||
|
||||
def _handle_mem_load(self, addr: UOp, dt) -> UOp:
|
||||
mem = self.vars.get('_vmem') if '_vmem' in self.vars else self.vars.get('_lds')
|
||||
if mem is None: return _const(dt, 0)
|
||||
assert mem is not None, "memory load requires _vmem or _lds"
|
||||
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
|
||||
active = self.vars.get('_active')
|
||||
gate = (active,) if active is not None else ()
|
||||
|
|
@ -725,29 +761,9 @@ def parse_tokens(toks: list[Token], vars: dict[str, VarVal], funcs: dict | None
|
|||
|
||||
# Unified block parser for pcode
|
||||
def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
|
||||
"""Substitute loop variable and evaluate bracket expressions.
|
||||
Converts var[loop_var] to var{val} for array element access (like the old regex parser)."""
|
||||
"""Substitute loop variable with its value."""
|
||||
toks = tokenize(line)
|
||||
# First pass: convert var[loop_var] to var{loop_var} to mark for array element assignment
|
||||
result_toks: list[Token] = []
|
||||
j = 0
|
||||
while j < len(toks):
|
||||
t = toks[j]
|
||||
# Check for pattern: IDENT[loop_var] where it's not preceded by a dot (not .type[...])
|
||||
if t.type == 'IDENT' and j+3 < len(toks) and toks[j+1].type == 'LBRACKET' and toks[j+2].type == 'IDENT' and toks[j+2].val == loop_var and toks[j+3].type == 'RBRACKET':
|
||||
# Check that it's not .type[loop_var]
|
||||
if not result_toks or result_toks[-1].type != 'DOT':
|
||||
result_toks.append(t)
|
||||
result_toks.append(Token('LBRACE', '{'))
|
||||
result_toks.append(Token('NUM', str(val)))
|
||||
result_toks.append(Token('RBRACE', '}'))
|
||||
j += 4
|
||||
continue
|
||||
result_toks.append(t)
|
||||
j += 1
|
||||
# Second pass: substitute loop variable in remaining positions
|
||||
subst_parts = [str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in result_toks if t.type != 'EOF']
|
||||
return ' '.join(subst_parts)
|
||||
return ' '.join(str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in toks if t.type != 'EOF')
|
||||
|
||||
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
|
||||
"""Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
|
||||
|
|
@ -797,8 +813,9 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
def parse_bound():
|
||||
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
|
||||
if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
|
||||
expr = p.parse()
|
||||
return int(expr.arg) if expr.op == Ops.CONST else 0
|
||||
expr = p.parse().simplify()
|
||||
assert expr.op == Ops.CONST, f"loop bound must be constant, got {expr}"
|
||||
return int(expr.arg)
|
||||
start_val = parse_bound()
|
||||
p.eat('COLON')
|
||||
end_val = parse_bound()
|
||||
|
|
@ -844,7 +861,9 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
|
||||
# declare
|
||||
if first == 'declare':
|
||||
if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT': vars[toks[1].val] = _u32(0)
|
||||
# Initialize scalar declarations (skip arrays and vars already passed as srcs)
|
||||
if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT':
|
||||
vars.setdefault(toks[1].val, _u32(0))
|
||||
i += 1; continue
|
||||
|
||||
# lambda definition
|
||||
|
|
@ -902,6 +921,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
j, lane_toks = _match_bracket(toks, 1)
|
||||
if j < len(toks) and toks[j].type == 'LBRACKET':
|
||||
j, reg_toks = _match_bracket(toks, j)
|
||||
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix
|
||||
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
|
||||
if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
|
||||
|
|
@ -965,19 +985,32 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
|
||||
i += 1; continue
|
||||
|
||||
# Array element: var{idx} = value
|
||||
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACE' and toks[2].type == 'NUM':
|
||||
var, idx = toks[0].val, int(toks[2].val)
|
||||
j = 4
|
||||
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
|
||||
if j < len(toks):
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
if existing is not None and isinstance(existing, UOp):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
|
||||
else:
|
||||
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
|
||||
i += 1; continue
|
||||
# Array element: var[idx] = value (static index) or var[expr] = value (dynamic)
|
||||
if len(toks) >= 4 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
|
||||
var = toks[0].val
|
||||
j, idx_toks = _match_bracket(toks, 1)
|
||||
if j < len(toks) and toks[j].type == 'EQUALS':
|
||||
# Static index: var[NUM] = value
|
||||
if len(idx_toks) == 1 and idx_toks[0].type == 'NUM':
|
||||
idx = int(idx_toks[0].val.rstrip('UuLl'))
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
existing = block_assigns.get(var, vars.get(var))
|
||||
if existing is not None and isinstance(existing, UOp):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
|
||||
else:
|
||||
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
|
||||
i += 1; continue
|
||||
# Dynamic index: var[expr] = value where var has @-elements
|
||||
elems = [(k.split('@')[1], v) for k, v in {**vars, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
|
||||
if elems:
|
||||
idx_expr = parse_tokens(idx_toks, vars, funcs)
|
||||
val = parse_tokens(toks[j+1:], vars, funcs)
|
||||
for elem_idx_str, old_elem in elems:
|
||||
elem_idx = int(elem_idx_str)
|
||||
cond = _to_u32(idx_expr).eq(_u32(elem_idx))
|
||||
new_val = cond.where(val.cast(old_elem.dtype) if val.dtype != old_elem.dtype else val, old_elem)
|
||||
block_assigns[f'{var}@{elem_idx}'] = vars[f'{var}@{elem_idx}'] = new_val
|
||||
i += 1; continue
|
||||
|
||||
# Compound assignment: var += or var -=
|
||||
assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
|
||||
|
|
@ -1024,13 +1057,14 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
def parse_cond(s, kw):
|
||||
ll = s.lower()
|
||||
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
|
||||
def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
|
||||
def is_const(c, v): return c.op == Ops.CONST and c.arg is v
|
||||
cond = parse_cond(line, 'if')
|
||||
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
|
||||
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
|
||||
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
|
||||
vars_snap = dict(vars)
|
||||
static_true = is_const(cond, True) # track if any condition is statically true
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not is_const(cond, False) else None)
|
||||
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
while i < len(lines):
|
||||
|
|
@ -1039,12 +1073,16 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
lf = ltoks[0].val.lower()
|
||||
if lf == 'elsif':
|
||||
c = parse_cond(lines[i], 'elsif')
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
|
||||
take = not static_true and not is_const(c, False)
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns if take else None)
|
||||
if take:
|
||||
conditions.append((c, ret if ret is not None else branch))
|
||||
if is_const(c, True): static_true = True
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'else':
|
||||
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
|
||||
else_branch = (ret, branch)
|
||||
i += 1
|
||||
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not static_true else None)
|
||||
if not static_true: else_branch = (ret, branch)
|
||||
vars.clear(); vars.update(vars_snap)
|
||||
elif lf == 'endif': i += 1; break
|
||||
else: break
|
||||
|
|
@ -1056,17 +1094,21 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
|||
if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
|
||||
result = c.where(rv, result)
|
||||
return i, block_assigns, result
|
||||
# Main style: merge variable assignments with WHERE
|
||||
else_assigns = else_branch[1]
|
||||
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
|
||||
for var in all_vars:
|
||||
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
|
||||
for cond, ba in reversed(conditions):
|
||||
if isinstance(ba, dict) and var in ba:
|
||||
tv = ba[var]
|
||||
if isinstance(tv, UOp) and isinstance(res, UOp):
|
||||
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
|
||||
block_assigns[var] = vars[var] = res
|
||||
# If statically true, use that branch directly; otherwise merge with WHERE
|
||||
if static_true:
|
||||
ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
|
||||
block_assigns.update(ba); vars.update(ba)
|
||||
else:
|
||||
else_assigns = else_branch[1]
|
||||
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
|
||||
for var in all_vars:
|
||||
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
|
||||
for cond, ba in reversed(conditions):
|
||||
if isinstance(ba, dict) and var in ba:
|
||||
tv = ba[var]
|
||||
if isinstance(tv, UOp) and isinstance(res, UOp):
|
||||
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
|
||||
block_assigns[var] = vars[var] = res
|
||||
continue
|
||||
|
||||
# Regular assignment: var = value
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Iterator
|
||||
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
from extra.assembly.amd.sqtt import decode, print_packets, INST, VALUINST, IMMEDIATE, WAVESTART, WAVEEND, InstOp, PacketType, IMMEDIATE_MASK
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd import decode_inst
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOPP, s_endpgm
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
|
||||
|
||||
|
|
@ -16,19 +13,11 @@ class InstructionInfo:
|
|||
wave: int
|
||||
inst: Inst
|
||||
|
||||
def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionInfo|None]]:
|
||||
def map_insts(data:bytes, lib:bytes, target:int) -> Iterator[tuple[PacketType, InstructionInfo|None]]:
|
||||
"""maps SQTT packets to instructions, yields (packet, instruction_info or None)"""
|
||||
# map pcs to insts
|
||||
pc_map:dict[int, Inst] = {}
|
||||
image, sections, _ = elf_loader(lib)
|
||||
text = next((sh for sh in sections if sh.name == ".text"), None)
|
||||
assert text is not None, "no .text section found"
|
||||
text_off, text_size = text.header.sh_addr, text.header.sh_size
|
||||
offset = text_off
|
||||
while offset < text_off + text_size:
|
||||
inst = decode_inst(image[offset:])
|
||||
pc_map[offset-text_off] = inst
|
||||
offset += inst.size()
|
||||
from tinygrad.viz.serve import amd_decode
|
||||
pc_map = amd_decode(lib, target)
|
||||
|
||||
wave_pc:dict[int, int] = {}
|
||||
# only processing packets on one [CU, SIMD] unit
|
||||
|
|
@ -37,7 +26,7 @@ def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionIn
|
|||
if not simd_select(p): continue
|
||||
if isinstance(p, WAVESTART):
|
||||
assert p.wave not in wave_pc, "only one inflight wave per unit"
|
||||
wave_pc[p.wave] = 0
|
||||
wave_pc[p.wave] = next(iter(pc_map))
|
||||
continue
|
||||
if isinstance(p, WAVEEND):
|
||||
pc = wave_pc.pop(p.wave)
|
||||
|
|
@ -80,22 +69,22 @@ def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionIn
|
|||
# test to compare every packet with the rocprof decoder
|
||||
|
||||
def test_rocprof_inst_traces_match(sqtt, prg, target):
|
||||
from tinygrad.viz.serve import llvm_disasm
|
||||
from tinygrad.viz.serve import amd_decode
|
||||
from extra.sqtt.roc import decode as roc_decode, InstExec
|
||||
disasm = {addr+prg.base:inst_disasm for addr, inst_disasm in llvm_disasm(target, prg.lib).items()}
|
||||
rctx = roc_decode([sqtt], {prg.name:disasm})
|
||||
rwaves = rctx.inst_execs[(sqtt.kern, sqtt.exec_tag)]
|
||||
addr_table = amd_decode(prg.lib, target)
|
||||
disasm = {addr+prg.base:(inst.disasm(), inst.size()) for addr,inst in addr_table.items()}
|
||||
rctx = roc_decode([sqtt], {prg.tag:disasm})
|
||||
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
|
||||
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
|
||||
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
|
||||
rwaves_base = next(iter(disasm)) # base program counter
|
||||
|
||||
passed_insts = 0
|
||||
for pkt, info in map_insts(sqtt.blob, prg.lib):
|
||||
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
|
||||
if DEBUG >= 2: print_packets([pkt])
|
||||
if info is None: continue
|
||||
if DEBUG >= 2: print(f"{' '*29}{info.inst.disasm()}")
|
||||
rocprof_inst = next(rwaves_iter[info.wave][0])
|
||||
ref_pc = rocprof_inst.pc-rwaves_base
|
||||
ref_pc = rocprof_inst.pc-prg.base
|
||||
# always check pc matches
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm[rocprof_inst.pc][0]} != {info.pc}:{info.inst.disasm()}"
|
||||
# special handling for s_endpgm, it marks the wave completion.
|
||||
|
|
@ -110,7 +99,8 @@ def test_rocprof_inst_traces_match(sqtt, prg, target):
|
|||
for k,v in rwaves_iter.items():
|
||||
assert len(v) == 0, f"incomplete wave {k}"
|
||||
|
||||
print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units")
|
||||
if len(rwaves):
|
||||
print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse, pickle, pathlib
|
||||
|
|
@ -123,7 +113,7 @@ if __name__ == "__main__":
|
|||
with open(args.profile, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
kern_events = {e.name:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
|
||||
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
|
||||
target = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"))).props["gfx_target_version"]
|
||||
for e in sqtt_events:
|
||||
if args.kernel is not None and args.kernel != e.kern: continue
|
||||
|
|
|
|||
|
|
@ -43,6 +43,23 @@ VCC = VCC_LO # For VOP3SD sdst field (VCC_LO is exported from dsl)
|
|||
USE_HW = os.environ.get("USE_HW", "0") == "1"
|
||||
FLOAT_TOLERANCE = 1e-5
|
||||
|
||||
def get_gpu_target() -> tuple[int, int, int]:
  """Return the AMD device target as a (major, minor, stepping) tuple.

  When USE_HW is off there is no hardware to query, so the placeholder
  (0, 0, 0) is returned and tinygrad.device is not imported by this call.
  """
  if USE_HW:
    # imported lazily so emulator-only runs never reach the device layer from here
    from tinygrad.device import Device
    return Device["AMD"].target
  return (0, 0, 0)
|
||||
|
||||
def skip_unless_gfx(min_major: int, min_minor: int = 0, reason: str = ""):
  """Build a decorator that skips a test when the GPU target is older than required.

  Args:
    min_major: minimum value accepted for target[0].
    min_minor: minimum value accepted for target[1] when target[0] equals min_major.
    reason: skip message; when empty a "requires gfx..." message is generated.

  Without USE_HW the decorated test is returned unchanged. With USE_HW the target
  is queried at the moment the decorator is applied to the test function.
  """
  import unittest
  def decorator(test_func):
    # emulator-only runs are never skipped
    if not USE_HW: return test_func
    target = get_gpu_target()
    older_major = target[0] < min_major
    # target[1] is only inspected when the major versions are equal
    if not (older_major or (target[0] == min_major and target[1] < min_minor)):
      return test_func
    return unittest.skip(reason or f"requires gfx{min_major}{min_minor}0+")(test_func)
  return decorator
|
||||
|
||||
# Output buffer layout: vgpr[16][32], sgpr[16], vcc, scc, exec
|
||||
N_VGPRS, N_SGPRS, WAVE_SIZE = 16, 16, 32
|
||||
VGPR_BYTES = N_VGPRS * WAVE_SIZE * 4 # 16 regs * 32 lanes * 4 bytes = 2048
|
||||
|
|
@ -212,8 +229,12 @@ amdhsa.kernels:
|
|||
|
||||
return parse_output(bytes(out_buf), n_lanes)
|
||||
|
||||
def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgprs: int = N_VGPRS) -> list[str]:
|
||||
"""Compare two WaveStates and return list of differences."""
|
||||
def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgprs: int = N_VGPRS, ulp_tolerance: int = 0) -> list[str]:
|
||||
"""Compare two WaveStates and return list of differences.
|
||||
|
||||
Args:
|
||||
ulp_tolerance: Allow up to this many ULPs difference for float comparisons (0 = exact match required)
|
||||
"""
|
||||
import math
|
||||
diffs = []
|
||||
for i in range(n_vgprs):
|
||||
|
|
@ -224,6 +245,11 @@ def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgp
|
|||
emu_f, hw_f = _f32(emu_val), _f32(hw_val)
|
||||
if math.isnan(emu_f) and math.isnan(hw_f):
|
||||
continue
|
||||
# Check ULP difference for floats (only for same-sign values)
|
||||
if ulp_tolerance > 0 and (emu_val < 0x80000000) == (hw_val < 0x80000000):
|
||||
ulp_diff = abs(int(emu_val) - int(hw_val))
|
||||
if ulp_diff <= ulp_tolerance:
|
||||
continue
|
||||
diffs.append(f"v[{i}] lane {lane}: emu=0x{emu_val:08x} ({emu_f:.6g}) hw=0x{hw_val:08x} ({hw_f:.6g})")
|
||||
for i in range(N_SGPRS):
|
||||
emu_val = emu_st.sgpr[i]
|
||||
|
|
@ -236,16 +262,19 @@ def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgp
|
|||
diffs.append(f"scc: emu={emu_st.scc} hw={hw_st.scc}")
|
||||
return diffs
|
||||
|
||||
def run_program(instructions: list, n_lanes: int = 1) -> WaveState:
|
||||
def run_program(instructions: list, n_lanes: int = 1, ulp_tolerance: int = 0) -> WaveState:
|
||||
"""Run instructions and return WaveState.
|
||||
|
||||
If USE_HW=1, runs on both emulator and hardware, compares results, and raises if they differ.
|
||||
Otherwise, runs only on emulator.
|
||||
|
||||
Args:
|
||||
ulp_tolerance: Allow up to this many ULPs difference for float comparisons (0 = exact match required)
|
||||
"""
|
||||
emu_st = run_program_emu(instructions, n_lanes)
|
||||
if USE_HW:
|
||||
hw_st = run_program_hw(instructions, n_lanes)
|
||||
diffs = compare_wave_states(emu_st, hw_st, n_lanes)
|
||||
diffs = compare_wave_states(emu_st, hw_st, n_lanes, ulp_tolerance=ulp_tolerance)
|
||||
if diffs:
|
||||
raise AssertionError(f"Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
|
||||
return hw_st
|
||||
|
|
|
|||
|
|
@ -719,5 +719,47 @@ class TestAtomicOrdering(unittest.TestCase):
|
|||
self.assertEqual(st.vgpr[0][4], 150, "Final value should be 150")
|
||||
|
||||
|
||||
class TestDsPermute(unittest.TestCase):
  """DS_PERMUTE_B32 and DS_BPERMUTE_B32 checks, executed through run_program."""

  def test_ds_permute_b32_identity(self):
    """DS_PERMUTE_B32 on one lane: lane 0 sends its data to lane 0."""
    # a single lane keeps the test simple
    program = [
      v_mov_b32_e32(v[0], 0),           # addr = 0 (lane 0)
      v_mov_b32_e32(v[1], 0xDEADBEEF),  # data
      ds_permute_b32(v[2], v[0], v[1]),
      s_waitcnt(lgkmcnt=0),
    ]
    state = run_program(program, n_lanes=1)
    # lane 0 is both sender and receiver, so v[2] in lane 0 holds the data word
    self.assertEqual(state.vgpr[0][2], 0xDEADBEEF)

  def test_ds_bpermute_b32_identity(self):
    """DS_BPERMUTE_B32 on one lane: lane 0 reads back its own v[1]."""
    program = [
      v_mov_b32_e32(v[0], 0),           # addr = 0 (read from lane 0)
      v_mov_b32_e32(v[1], 0xCAFEBABE),  # data in lane 0
      ds_bpermute_b32(v[2], v[0], v[1]),
      s_waitcnt(lgkmcnt=0),
    ]
    state = run_program(program, n_lanes=1)
    # lane 0 reads from lane 0's v[1]
    self.assertEqual(state.vgpr[0][2], 0xCAFEBABE)

  def test_ds_permute_b32_broadcast(self):
    """DS_PERMUTE_B32 with four lanes that all target lane 0."""
    # every lane writes the same word, so the value seen in lane 0 does not
    # depend on which writer wins (the original note says the highest lane does)
    program = [
      v_mov_b32_e32(v[0], 0),           # all lanes send to addr 0 (lane 0)
      v_mov_b32_e32(v[1], 0x11111111),  # all lanes send the same data
      ds_permute_b32(v[2], v[0], v[1]),
      s_waitcnt(lgkmcnt=0),
    ]
    state = run_program(program, n_lanes=4)
    self.assertEqual(state.vgpr[0][2], 0x11111111)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class TestBasicScalar(unittest.TestCase):
|
|||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.sgpr[1], 0x80000000)
|
||||
|
||||
@skip_unless_gfx(11, 5, "SALU FP ops require gfx1150+")
|
||||
def test_s_fmamk_f32(self):
|
||||
"""S_FMAMK_F32: D = S0 * literal + S1."""
|
||||
# 2.0 * 3.0 + 1.0 = 7.0
|
||||
|
|
@ -73,6 +74,7 @@ class TestBasicScalar(unittest.TestCase):
|
|||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.sgpr[2], f2i(7.0))
|
||||
|
||||
@skip_unless_gfx(11, 5, "SALU FP ops require gfx1150+")
|
||||
def test_s_fmamk_f32_negative(self):
|
||||
"""S_FMAMK_F32 with negative values."""
|
||||
# -2.0 * 4.0 + 10.0 = 2.0
|
||||
|
|
|
|||
|
|
@ -1561,5 +1561,24 @@ class TestCvtNormF16(unittest.TestCase):
|
|||
self.assertAlmostEqual(result, 32768, delta=1)
|
||||
|
||||
|
||||
class TestPermlane64(unittest.TestCase):
  """V_PERMLANE64_B32 (wave64 cross-half swap) behaviour checks."""

  def test_v_permlane64_b32_is_nop_in_wave32(self):
    """In wave32 mode V_PERMLANE64_B32 must leave its destination untouched.

    Per AMD pcode: "if WAVE32 then s_nop(...) else ... endif"
    The emulator runs in wave32 mode, so the instruction should not modify registers.
    """
    program = [
      v_mov_b32_e32(v[0], 0xCAFEBABE),   # source
      v_mov_b32_e32(v[1], 0x12345678),   # dest, expected to survive
      v_permlane64_b32_e32(v[1], v[0]),  # NOP in wave32
    ]
    state = run_program(program, n_lanes=1)
    # v[1] still holds the value written before the permlane
    self.assertEqual(state.vgpr[0][1], 0x12345678)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -3262,5 +3262,81 @@ class TestMinMaxF16Vop3(unittest.TestCase):
|
|||
self.assertAlmostEqual(result, 4.0, delta=0.01)
|
||||
|
||||
|
||||
class TestSadHi(unittest.TestCase):
  """V_SAD_HI_U8 checks: result is (sum of absolute byte differences << 16) + accumulator."""

  def test_v_sad_hi_u8_basic(self):
    """Bytes 1..4 against 5..8 with accumulator 100."""
    # |1-5| + |2-6| + |3-7| + |4-8| = 16; 16 << 16 = 0x100000; + 100 = 0x100064
    program = [
      v_mov_b32_e32(v[0], 0x04030201),
      v_mov_b32_e32(v[1], 0x08070605),
      v_mov_b32_e32(v[2], 100),
      v_sad_hi_u8(v[3], v[0], v[1], v[2]),
    ]
    state = run_program(program, n_lanes=1)
    self.assertEqual(state.vgpr[0][3], (16 << 16) + 100)

  def test_v_sad_hi_u8_zero_diff(self):
    """Identical operands contribute nothing, so only the accumulator remains."""
    program = [
      v_mov_b32_e32(v[0], 0x12345678),
      v_mov_b32_e32(v[2], 50),
      v_sad_hi_u8(v[3], v[0], v[0], v[2]),
    ]
    state = run_program(program, n_lanes=1)
    self.assertEqual(state.vgpr[0][3], 50)
|
||||
|
||||
|
||||
class TestPermlane(unittest.TestCase):
  """V_PERMLANE16_B32 and V_PERMLANEX16_B32 checks."""

  def test_v_permlane16_b32_identity(self):
    """V_PERMLANE16_B32 where every position selects itself within the row."""
    # lanesel holds one nibble per position: position i uses lanesel[i*4+3:i*4]
    # identity mapping 0->0, 1->1, ..., 15->15 is therefore
    # lanesel = 0xFEDCBA9876543210 (positions 15-0 in nibbles)
    program = [
      v_mov_b32_e32(v[0], 0xDEADBEEF),  # source data
      s_mov_b32(s[0], 0x76543210),      # lanesel low (positions 0-7)
      s_mov_b32(s[1], 0xFEDCBA98),      # lanesel high (positions 8-15)
      v_permlane16_b32(v[1], v[0], s[0], s[1]),
    ]
    state = run_program(program, n_lanes=1)
    # position 0 -> lanesel[3:0] = 0, so lane 0 reads lane 0
    self.assertEqual(state.vgpr[0][1], 0xDEADBEEF)

  def test_v_permlane16_b32_broadcast(self):
    """V_PERMLANE16_B32 with an all-zero lanesel: every lane reads lane 0 of its row."""
    program = [
      v_mov_b32_e32(v[0], 0xCAFEBABE),  # source data
      s_mov_b32(s[0], 0),               # lanesel low = 0 (all read lane 0)
      s_mov_b32(s[1], 0),               # lanesel high = 0
      v_permlane16_b32(v[1], v[0], s[0], s[1]),
    ]
    state = run_program(program, n_lanes=4)
    # each of the four lanes should hold the broadcast word
    for lane_idx in range(4):
      self.assertEqual(state.vgpr[lane_idx][1], 0xCAFEBABE)

  def test_v_permlanex16_b32_identity(self):
    """V_PERMLANEX16_B32 cross-row read with an identity lanesel."""
    # wave32: row 0 (lanes 0-15) reads from row 1 (lanes 16-31) and vice versa.
    # Lane 16 does not exist in a 1-lane run, hence all 32 lanes are used here.
    program = [
      v_mov_b32_e32(v[0], 0x11111111),  # every lane starts with this value
      s_mov_b32(s[0], 0x76543210),      # lanesel low
      s_mov_b32(s[1], 0xFEDCBA98),      # lanesel high
      v_permlanex16_b32(v[1], v[0], s[0], s[1]),
    ]
    state = run_program(program, n_lanes=32)
    # lane 0 (row 0) reads lane 0 of row 1, i.e. lane 16
    self.assertEqual(state.vgpr[0][1], 0x11111111)
    # lane 16 (row 1) reads lane 0 of row 0, i.e. lane 0
    self.assertEqual(state.vgpr[16][1], 0x11111111)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -767,6 +767,7 @@ class TestDot2F32F16(unittest.TestCase):
|
|||
"""V_DOT2_F32_F16 with negative f16 values."""
|
||||
# src0 = {hi=-2.0, lo=3.0}, src1 = {hi=1.0, lo=2.0}
|
||||
# result = 3.0*2.0 + (-2.0)*1.0 + 0 = 6 - 2 = 4.0
|
||||
# NOTE: Hardware DOT2 may have up to 1 ULP difference due to internal implementation
|
||||
src0 = (f32_to_f16(-2.0) << 16) | f32_to_f16(3.0)
|
||||
src1 = (f32_to_f16(1.0) << 16) | f32_to_f16(2.0)
|
||||
instructions = [
|
||||
|
|
@ -777,7 +778,7 @@ class TestDot2F32F16(unittest.TestCase):
|
|||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
st = run_program(instructions, n_lanes=1, ulp_tolerance=1)
|
||||
result = i2f(st.vgpr[0][3])
|
||||
self.assertAlmostEqual(result, 4.0, places=2)
|
||||
|
||||
|
|
|
|||
98
extra/assembly/amd/test/test_rdna4_emu.py
Normal file
98
extra/assembly/amd/test/test_rdna4_emu.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import unittest, ctypes
|
||||
from extra.assembly.amd.autogen.rdna4 import ins as ir4
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.emu import WaveState, decode_program
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
class TestRDNA4Emu(unittest.TestCase):
  """Single-lane RDNA4 emulator tests centred on VOPD V_DUAL_MOV_B32."""

  @staticmethod
  def _dual_mov(vdstx, vdsty, srcx0, srcy0):
    """Build a VOPD with V_DUAL_MOV_B32 on both halves; vsrcx1/vsrcy1 are fixed to v[0]."""
    return ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
                    vdstx=vdstx, vdsty=vdsty, srcx0=srcx0, srcy0=srcy0, vsrcx1=v[0], vsrcy1=v[0])

  def _run(self, insts: list, sgprs: dict[int, int] = None, vgprs: dict[tuple[int, int], int] = None) -> WaveState:
    """Assemble `insts`, step them on a one-lane WaveState and return that state.

    Args:
      insts: instruction objects; an S_ENDPGM is appended when none is present.
      sgprs: optional {sgpr index: value} written before execution.
      vgprs: optional {(reg, lane): value} written before execution.
    """
    # make sure the program ends with S_ENDPGM
    has_endpgm = any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts)
    if not has_endpgm: insts = list(insts) + [ir4.SOPP(ir4.SOPPOp.S_ENDPGM, simm=0)]

    # assemble into a ctypes buffer; decoded offsets are rebased onto that buffer's address
    code = b''.join(i.to_bytes() for i in insts)
    code_buf = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
    code_addr = ctypes.addressof(code_buf)
    program = {code_addr + offset: val for offset, val in decode_program(code, "rdna4").items()}

    # initial wave state: pc at the first instruction, then caller-provided registers
    st = WaveState(n_lanes=1)
    st.pc = code_addr
    for idx, val in (sgprs or {}).items(): st._write_sgpr(idx, val)
    for (reg, lane), val in (vgprs or {}).items(): st._write_vgpr(reg, lane, val)

    # vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access)
    vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()

    # buffer addresses handed to each compiled instruction, selected via its globals list
    c_bufs = [ctypes.c_uint64(addr) for addr in (st.sgpr_buf._buf.va_addr, st.vgpr_buf._buf.va_addr, vmem_buf._buf.va_addr, 0, 0)]
    # step at most 100 instructions; stop on the all-ones pc or when pc leaves the program
    for _ in range(100):
      pc = st.pc
      if pc == 0xFFFFFFFFFFFFFFFF or pc not in program: break
      _, fxn, globals_list, _ = program[pc]
      fxn(*[c_bufs[g] for g in globals_list])
    return st

  def test_vopd_dual_mov(self):
    """One VOPD with two V_DUAL_MOV_B32 halves: v[1]=s[1], v[2]=s[2]."""
    insts = [self._dual_mov(v[1], v[2], s[1], s[2])]
    state = self._run(insts, sgprs={1: 0x40e00000, 2: 0x41100000})  # 7.0f, 9.0f
    self.assertEqual(state._read_vgpr(1, 0), 0x40e00000)  # v[1] = 7.0
    self.assertEqual(state._read_vgpr(2, 0), 0x41100000)  # v[2] = 9.0

  def test_vopd_dual_mov_after_other_vopd(self):
    """Two VOPDs back to back: (v[3]=0, v[0]=s[0]) then (v[1]=s[1], v[2]=s[2])."""
    # mirrors the BEAM kernel sequence that fails
    insts = [
      self._dual_mov(v[3], v[0], 0, s[0]),     # v[3]=0, v[0]=s[0]
      self._dual_mov(v[1], v[2], s[1], s[2]),  # v[1]=s[1], v[2]=s[2]
    ]
    state = self._run(insts, sgprs={0: 0x40a00000, 1: 0x40e00000, 2: 0x41100000})  # 5.0f, 7.0f, 9.0f
    self.assertEqual(state._read_vgpr(1, 0), 0x40e00000)  # v[1] = 7.0
    self.assertEqual(state._read_vgpr(2, 0), 0x41100000)  # v[2] = 9.0

  def test_vopd_with_s_add_f32_sequence(self):
    """Full BEAM kernel sequence: three S_ADD_F32 followed by two VOPDs."""
    # this is the exact sequence from the failing BEAM kernel
    insts = [
      ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[0], ssrc0=s[0], ssrc1=s[8]),   # s[0] = s[0] + s[8]
      ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[1], ssrc0=s[1], ssrc1=s[9]),   # s[1] = s[1] + s[9]
      ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[2], ssrc0=s[2], ssrc1=s[10]),  # s[2] = s[2] + s[10]
      self._dual_mov(v[3], v[0], 0, s[0]),
      self._dual_mov(v[1], v[2], s[1], s[2]),
    ]
    # inputs: s[0:2] = [1,2,3], s[8:10] = [4,5,6]; after the adds s[0:2] = [5,7,9]
    state = self._run(insts, sgprs={0: 0x3f800000, 1: 0x40000000, 2: 0x40400000,    # 1.0, 2.0, 3.0
                                    8: 0x40800000, 9: 0x40a00000, 10: 0x40c00000})  # 4.0, 5.0, 6.0
    self.assertEqual(state._read_vgpr(1, 0), 0x40e00000)  # v[1] = 7.0
    self.assertEqual(state._read_vgpr(2, 0), 0x41100000)  # v[2] = 9.0

  def test_s_mov_b32_then_vopd(self):
    """S_MOV_B32 copies followed by a VOPD - simulates the BEAM kernel sequence."""
    # s[10:12] are preset by the harness, copied into s[0:2], then moved to VGPRs by the VOPD
    insts = [
      ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[0], ssrc0=s[10]),  # s[0] = s[10]
      ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[1], ssrc0=s[11]),  # s[1] = s[11]
      ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[2], ssrc0=s[12]),  # s[2] = s[12]
      self._dual_mov(v[1], v[2], s[1], s[2]),
    ]
    state = self._run(insts, sgprs={10: 0x40a00000, 11: 0x40e00000, 12: 0x41100000})  # 5.0, 7.0, 9.0
    self.assertEqual(state._read_vgpr(1, 0), 0x40e00000)  # v[1] = 7.0
    self.assertEqual(state._read_vgpr(2, 0), 0x41100000)  # v[2] = 9.0
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -203,12 +203,12 @@ class SQTTExamplesTestBase(unittest.TestCase):
|
|||
class TestSQTTExamplesRDNA3(SQTTExamplesTestBase):
|
||||
target = "gfx1100"
|
||||
expected = {
|
||||
"profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929],
|
||||
"profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923],
|
||||
"profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_empty_run_0": [1844, 1885, 1905, 1956, 1983, 1889],
|
||||
"profile_empty_run_1": [1780, 1885, 1905, 1956, 1983, 1889],
|
||||
"profile_gemm_run_0": [2656, 2025, 2045, 2096, 2123, 2029, 3183, 2019, 2039, 2090, 2117, 2023, 19119, 2013, 2033, 2084, 2111, 2017],
|
||||
"profile_gemm_run_1": [2662, 2025, 2045, 2096, 2123, 2029, 3179, 2019, 2039, 2090, 2117, 2023, 19113, 2071, 2091, 2142, 2169, 2075],
|
||||
"profile_plus_run_0": [1886, 2013, 2033, 2084, 2111, 2017],
|
||||
"profile_plus_run_1": [1988, 2071, 2091, 2142, 2169, 2075],
|
||||
}
|
||||
|
||||
class TestSQTTExamplesRDNA4(SQTTExamplesTestBase):
  # gfx target string used by this subclass
  target = "gfx1200"
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
import sys, os, zlib, struct, hashlib
|
||||
from tinygrad.helpers import DEBUG, getenv, fetch
|
||||
import os, zlib, struct, hashlib
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.runtime.support.usb import USB3
|
||||
|
||||
SUPPORTED_CONTROLLERS = [
|
||||
|
|
@ -50,7 +50,7 @@ patched_fw = patch(file_path, file_hash, patches)
|
|||
dev = None
|
||||
for vendor, device in SUPPORTED_CONTROLLERS:
|
||||
try:
|
||||
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04)
|
||||
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04, use_bot=True)
|
||||
break
|
||||
except RuntimeError: pass
|
||||
if dev is None:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.am.am as am
|
|||
import tinygrad.runtime.autogen.amdgpu_drm as amdgpu_drm
|
||||
from tinygrad.helpers import from_mv
|
||||
from test.mockgpu.driver import VirtDriver, VirtFileDesc, TextFileDesc, DirFileDesc, VirtFile
|
||||
from test.mockgpu.amd.amdgpu import AMDGPU, gpu_props
|
||||
from test.mockgpu.amd.amdgpu import AMDGPU, gpu_props, GFX_TARGET_VERSION, MOCKGPU_ARCH
|
||||
|
||||
libc = ctypes.CDLL(ctypes.util.find_library("c"))
|
||||
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
|
||||
|
|
@ -90,35 +90,30 @@ class AMDDriver(VirtDriver):
|
|||
def _prepare_gpu(self, gpu_id):
|
||||
self.doorbells[gpu_id] = memoryview(bytearray(0x2000))
|
||||
self.gpus[gpu_id] = AMDGPU(gpu_id)
|
||||
# IP versions: rdna3 = GC 11.0.0, NBIF 4.3.0; rdna4 = GC 12.0.0, NBIF 6.3.1
|
||||
ip_versions = {"rdna3": {"gc": (11, 0, 0), "sdma": (6, 0, 0), "nbif": (4, 3, 0)},
|
||||
"rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)}}[MOCKGPU_ARCH]
|
||||
def ip_discovery_files(hwid, ver, base_addr):
|
||||
p = f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}/0'
|
||||
return [VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'{p}/major', functools.partial(TextFileDesc, text=str(ver[0]))),
|
||||
VirtFile(f'{p}/minor', functools.partial(TextFileDesc, text=str(ver[1]))),
|
||||
VirtFile(f'{p}/revision', functools.partial(TextFileDesc, text=str(ver[2]))),
|
||||
VirtFile(f'{p}/base_addr', functools.partial(TextFileDesc, text=base_addr))]
|
||||
self.tracked_files += [
|
||||
VirtFile('/sys/module/amdgpu', functools.partial(TextFileDesc, text="1")),
|
||||
VirtFile('/sys/module/amdgpu/parameters/ppfeaturemask', functools.partial(TextFileDesc, text="0xffff3fff")),
|
||||
VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}', functools.partial(DirFileDesc, child_names=['gpu_id', 'properties'])),
|
||||
VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}/gpu_id', functools.partial(TextFileDesc, text=f"{gpu_id}")),
|
||||
VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}/properties',
|
||||
functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id))),
|
||||
functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id, gfx_target_version=GFX_TARGET_VERSION))),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/power_dpm_force_performance_level',
|
||||
functools.partial(TextFileDesc, text='profile_standard\n')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0',
|
||||
functools.partial(DirFileDesc, child_names=[str(am.GC_HWID), str(am.SDMA0_HWID), str(am.NBIF_HWID)])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/major', functools.partial(TextFileDesc, text='11')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/base_addr',
|
||||
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/major', functools.partial(TextFileDesc, text='6')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/base_addr',
|
||||
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/major', functools.partial(TextFileDesc, text='4')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/minor', functools.partial(TextFileDesc, text='3')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
|
||||
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/base_addr',
|
||||
functools.partial(TextFileDesc, text='0x00000000\n0x00000014\n0x00000D20\n0x00010400\n0x0241B000\n0x04040000')),
|
||||
*ip_discovery_files(am.GC_HWID, ip_versions["gc"], '0x00001260\n0x0000A000\n0x0001C000\n0x02402C00'),
|
||||
*ip_discovery_files(am.SDMA0_HWID, ip_versions["sdma"], '0x00001260\n0x0000A000\n0x0001C000\n0x02402C00'),
|
||||
*ip_discovery_files(am.NBIF_HWID, ip_versions["nbif"], '0x00000000\n0x00000014\n0x00000D20\n0x00010400\n0x0241B000\n0x04040000'),
|
||||
VirtFile(f'/dev/dri/renderD{gpu_id}', functools.partial(DRMFileDesc, driver=self, gpu=f"{self.gpus[gpu_id]}")),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import ctypes, time
|
||||
from test.mockgpu.gpu import VirtGPU
|
||||
from test.mockgpu.helpers import _try_dlopen_remu
|
||||
from tinygrad.helpers import getbits, to_mv
|
||||
from tinygrad.helpers import getbits, to_mv, getenv
|
||||
from tinygrad.runtime.support import c
|
||||
|
||||
MOCKGPU_ARCH = getenv("MOCKGPU_ARCH", "rdna3")
|
||||
GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000}[MOCKGPU_ARCH]
|
||||
import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.pm4_nv as pm4
|
||||
|
||||
SDMA_MAX_COPY_SIZE = 0x400000
|
||||
|
|
@ -194,10 +197,11 @@ class PM4Executor(AMDQueue):
|
|||
scratch_size = wavesize * 4 # This gives the scratch size per thread (lane)
|
||||
|
||||
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
|
||||
# Pass valid memory ranges, rsrc2, and scratch_size to Python emulator
|
||||
# Pass valid memory ranges, rsrc2, scratch_size and arch to Python emulator
|
||||
if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges
|
||||
if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2
|
||||
if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size
|
||||
if hasattr(remu, 'arch'): remu.arch = self.gpu.arch
|
||||
err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
|
||||
if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel")
|
||||
|
||||
|
|
@ -314,6 +318,7 @@ class AMDGPU(VirtGPU):
|
|||
self.regs = AMDGPURegisters()
|
||||
self.mapped_ranges = set()
|
||||
self.queues = []
|
||||
self.arch = MOCKGPU_ARCH
|
||||
|
||||
def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size))
|
||||
def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size))
|
||||
|
|
@ -342,7 +347,7 @@ simd_arrays_per_engine 2
|
|||
cu_per_simd_array 8
|
||||
simd_per_cu 2
|
||||
max_slots_scratch_cu 32
|
||||
gfx_target_version 110000
|
||||
gfx_target_version {gfx_target_version}
|
||||
vendor_id 4098
|
||||
device_id 29772
|
||||
location_id 34304
|
||||
|
|
|
|||
|
|
@ -16,14 +16,15 @@ def _try_dlopen_gpuocelot():
|
|||
return None
|
||||
|
||||
class PythonRemu:
|
||||
"""Python RDNA3 emulator wrapper that matches the libremu.so interface."""
|
||||
"""Python RDNA3/RDNA4 emulator wrapper that matches the libremu.so interface."""
|
||||
valid_mem_ranges: set[tuple[int, int]] = set()
|
||||
rsrc2: int = 0x19c # Default: USER_SGPR_COUNT=14, enable X and Y workgroup IDs
|
||||
scratch_size: int = 0 # private_segment_fixed_size from kernel descriptor
|
||||
arch: str = "rdna3" # Architecture: rdna3 or rdna4
|
||||
|
||||
def run_asm(self, lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int) -> int:
|
||||
from extra.assembly.amd.emu import run_asm
|
||||
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size)
|
||||
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch)
|
||||
|
||||
def _try_dlopen_remu():
|
||||
# Use Python emulator only if PYTHON_REMU=1
|
||||
|
|
|
|||
|
|
@ -195,8 +195,10 @@ class TestAssignIssues(unittest.TestCase):
|
|||
t.shrink(((1, 3), (1, 3))).assign(Tensor.ones(2, 2))
|
||||
np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_broadcast(self):
|
||||
# broadcasting during assign should behave like PyTorch
|
||||
# NOTE: we don't want implicit dtype casting (int64 -> float32 loses precision), so this fails
|
||||
torch_tensor = torch.zeros(3, 5)
|
||||
torch_tensor[:] = torch.arange(5)
|
||||
t = Tensor.zeros(3, 5)
|
||||
|
|
|
|||
|
|
@ -256,6 +256,18 @@ class TestMultiTensor(unittest.TestCase):
|
|||
a,b = _test_allreduce(Tensor.rand(256, 256))
|
||||
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
|
||||
|
||||
def test_multiple_to_single_device_naive(self):
|
||||
with Context(RING=0):
|
||||
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
|
||||
def test_multiple_to_single_device_ring(self):
|
||||
with Context(RING=2):
|
||||
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
|
||||
def test_allreduce_all2all(self):
|
||||
with Context(ALL2ALL=2):
|
||||
a,b = _test_allreduce(Tensor.rand(256, 256))
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# allow define from star imports
|
||||
|
||||
import unittest
|
||||
import textwrap, functools
|
||||
import functools
|
||||
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
|
|
@ -11,69 +11,27 @@ from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
|||
from tinygrad.viz.serve import amdgpu_cfg
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd.dsl import s
|
||||
|
||||
template = """.text
|
||||
.globl fn_name
|
||||
.p2align 8
|
||||
.type fn_name,@function
|
||||
fn_name:
|
||||
INSTRUCTION
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel fn_name
|
||||
.amdhsa_kernarg_size 8
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
|
||||
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
|
||||
.amdhsa_wavefront_size32 1
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: fn_name
|
||||
.symbol: fn_name.kd
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 8
|
||||
.max_flat_workgroup_size: 1024
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 8
|
||||
.args:
|
||||
- .address_space: global
|
||||
.name: a
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.type_name: 'float*'
|
||||
.value_kind: global_buffer
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
# TODO: this belongs to the dsl infrastructure
|
||||
from extra.gemm.amd_asm_matmul import Kernel
|
||||
|
||||
# TODO: shouldn't need compiler once we can output ELF
|
||||
# outputs a text disassembly for humans and a machine readable binary
|
||||
def assemble(name:str, insts:list[str|Inst], compiler:Compiler) -> tuple[str, bytes]:
|
||||
asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
|
||||
src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm))
|
||||
def assemble(name:str, k:Kernel, compiler:Compiler) -> tuple[str, bytes]:
|
||||
src = k.to_asm()
|
||||
return (src, compiler.compile(src))
|
||||
|
||||
def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, compiler:Compiler, n_threads:int=1, n_workgroups:int=1) -> UOp:
|
||||
def asm_kernel(out:UOp, k:Kernel, name:str, device:str, compiler:Compiler, n_threads:int=1, n_workgroups:int=1) -> UOp:
|
||||
lidx = UOp.special(n_threads, "lidx0")
|
||||
gidx = UOp.special(n_workgroups, "gidx0")
|
||||
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
|
||||
src, lib = assemble(name, insts, compiler)
|
||||
src, lib = assemble(name, k, compiler)
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
def run_asm(name:str, insts:list) -> None:
|
||||
fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
|
||||
def run_asm(name:str, k:Kernel) -> None:
|
||||
fxn = functools.partial(asm_kernel, k=k, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
|
||||
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
out.realize()
|
||||
|
||||
|
|
@ -85,32 +43,32 @@ class TestCfg(unittest.TestCase):
|
|||
self.skipTest(f"tests written for RDNA, got arch {arch}")
|
||||
|
||||
def test_simple(self):
|
||||
run_asm("simple", [
|
||||
"entry:",
|
||||
"s_branch bb1",
|
||||
"bb1:",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_branch(), target="bb1")
|
||||
k.label("bb1")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("simple", k)
|
||||
|
||||
def test_diamond(self):
|
||||
run_asm("diamond", insts:=[
|
||||
"entry:",
|
||||
s_mov_b32(s[0], 0),
|
||||
s_mov_b32(s[1], 0),
|
||||
s_cmp_eq_u64(s[0:1], 0),
|
||||
"s_cbranch_scc1 if",
|
||||
"s_branch else",
|
||||
"if:",
|
||||
s_nop(1),
|
||||
"s_branch end",
|
||||
"else:",
|
||||
s_nop(0),
|
||||
"end:",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
_, lib = assemble("diamond", insts, HIPCompiler(Device[Device.DEFAULT].arch))
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[0], 0))
|
||||
k.emit(s_mov_b32(s[1], 0))
|
||||
k.emit(s_cmp_eq_u64(s[0:1], 0))
|
||||
k.emit(s_cbranch_scc1(), target="if")
|
||||
k.emit(s_branch(), target="else")
|
||||
k.label("if")
|
||||
k.emit(s_nop(1))
|
||||
k.emit(s_branch(), target="end")
|
||||
k.label("else")
|
||||
k.emit(s_nop(0))
|
||||
k.label("end")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("diamond", k)
|
||||
_, lib = assemble("diamond", k, HIPCompiler(Device[Device.DEFAULT].arch))
|
||||
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
|
||||
self.assertEqual(len(cfg["blocks"]), 5)
|
||||
edge_count = sum(len(v) for v in cfg["paths"].values())
|
||||
|
|
@ -124,133 +82,138 @@ class TestCfg(unittest.TestCase):
|
|||
self.assertEqual(insts, ['s_mov_b32', 's_cmp_eq_u64'])
|
||||
|
||||
def test_loop(self):
|
||||
run_asm("simple_loop", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 4),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
"s_cbranch_scc0 loop",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 4))
|
||||
k.label("loop")
|
||||
k.emit(s_add_u32(s[1], s[1], -1))
|
||||
k.emit(s_cmp_eq_i32(s[1], 0))
|
||||
k.emit(s_cbranch_scc0(), target="loop")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("simple_loop", k)
|
||||
|
||||
def test_loop_branch(self):
|
||||
run_asm("loop_if", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 4),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 2),
|
||||
"s_cbranch_scc1 cond",
|
||||
"s_branch cont",
|
||||
"cond:",
|
||||
s_add_u32(s[1], s[1], -2),
|
||||
"cont:",
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
"s_cbranch_scc0 loop",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 4))
|
||||
k.label("loop")
|
||||
k.emit(s_add_u32(s[1], s[1], -1))
|
||||
k.emit(s_cmp_eq_i32(s[1], 2))
|
||||
k.emit(s_cbranch_scc1(), target="cond")
|
||||
k.emit(s_branch(), target="cont")
|
||||
k.label("cond")
|
||||
k.emit(s_add_u32(s[1], s[1], -2))
|
||||
k.label("cont")
|
||||
k.emit(s_cmp_eq_i32(s[1], 0))
|
||||
k.emit(s_cbranch_scc0(), target="loop")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("loop_if", k)
|
||||
|
||||
def test_loop_break(self):
|
||||
run_asm("loop_break", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 8),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 5),
|
||||
"s_cbranch_scc1 break",
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
"s_cbranch_scc0 loop",
|
||||
"break:",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 8))
|
||||
k.label("loop")
|
||||
k.emit(s_add_u32(s[1], s[1], -1))
|
||||
k.emit(s_cmp_eq_i32(s[1], 5))
|
||||
k.emit(s_cbranch_scc1(), target="break")
|
||||
k.emit(s_cmp_eq_i32(s[1], 0))
|
||||
k.emit(s_cbranch_scc0(), target="loop")
|
||||
k.label("break")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("loop_break", k)
|
||||
|
||||
def test_switch(self):
|
||||
run_asm("switch_case", [
|
||||
"entry:",
|
||||
s_cmp_eq_i32(s[0], 0),
|
||||
"s_cbranch_scc1 case0",
|
||||
s_cmp_eq_i32(s[0], 1),
|
||||
"s_cbranch_scc1 case1",
|
||||
"s_branch case2",
|
||||
"case0:",
|
||||
s_nop(0),
|
||||
"s_branch join",
|
||||
"case1:",
|
||||
s_nop(1),
|
||||
"s_branch join",
|
||||
"case2:",
|
||||
s_nop(2),
|
||||
"s_branch join",
|
||||
"join:",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_cmp_eq_i32(s[0], 0))
|
||||
k.emit(s_cbranch_scc1(), target="case0")
|
||||
k.emit(s_cmp_eq_i32(s[0], 1))
|
||||
k.emit(s_cbranch_scc1(), target="case1")
|
||||
k.emit(s_branch(), target="case2")
|
||||
k.label("case0")
|
||||
k.emit(s_nop(0))
|
||||
k.emit(s_branch(), target="join")
|
||||
k.label("case1")
|
||||
k.emit(s_nop(1))
|
||||
k.emit(s_branch(), target="join")
|
||||
k.label("case2")
|
||||
k.emit(s_nop(2))
|
||||
k.emit(s_branch(), target="join")
|
||||
k.label("join")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("switch_case", k)
|
||||
|
||||
def test_ping_pong(self):
|
||||
run_asm("ping_pong", [
|
||||
"entry:",
|
||||
s_cmp_eq_i32(s[0], 0),
|
||||
"s_cbranch_scc1 ping",
|
||||
"s_branch pong",
|
||||
"ping:",
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
"s_cbranch_scc1 pong",
|
||||
"s_branch end",
|
||||
"pong:",
|
||||
s_cmp_eq_i32(s[2], 0),
|
||||
"s_cbranch_scc1 ping",
|
||||
"end:",
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_cmp_eq_i32(s[0], 0))
|
||||
k.emit(s_cbranch_scc1(), target="ping")
|
||||
k.emit(s_branch(), target="pong")
|
||||
k.label("ping")
|
||||
k.emit(s_cmp_eq_i32(s[1], 0))
|
||||
k.emit(s_cbranch_scc1(), target="pong")
|
||||
k.emit(s_branch(), target="end")
|
||||
k.label("pong")
|
||||
k.emit(s_cmp_eq_i32(s[2], 0))
|
||||
k.emit(s_cbranch_scc1(), target="ping")
|
||||
k.label("end")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("ping_pong", k)
|
||||
|
||||
def test_colored_blocks(self):
|
||||
N = 10
|
||||
asm = ["entry:", "s_branch init0"]
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_branch(), target="init0")
|
||||
for i in range(N):
|
||||
asm += [f"init{i}:", s_mov_b32(s[1], i + 1), f"s_branch {(loop:=f'loop{i}')}"]
|
||||
asm += [
|
||||
f"{loop}:",
|
||||
s_nop(i & 7),
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
f"s_cbranch_scc0 {loop}",
|
||||
f"s_branch {'init' + str(i+1) if i + 1 < N else 'end'}",
|
||||
]
|
||||
asm += ["end:", s_endpgm(), s_code_end()]
|
||||
run_asm("test_colored_blocks", asm)
|
||||
loop = f"loop{i}"
|
||||
k.label(f"init{i}")
|
||||
k.emit(s_mov_b32(s[1], i + 1))
|
||||
k.emit(s_branch(), target=loop)
|
||||
k.label(loop)
|
||||
k.emit(s_nop(i & 7))
|
||||
k.emit(s_add_u32(s[1], s[1], -1))
|
||||
k.emit(s_cmp_eq_i32(s[1], 0))
|
||||
k.emit(s_cbranch_scc0(), target=loop)
|
||||
k.emit(s_branch(), target=f"init{i+1}" if i + 1 < N else "end")
|
||||
k.label("end")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("test_colored_blocks", k)
|
||||
|
||||
def test_jump_back_to_end(self):
|
||||
run_asm("jump_back_to_end", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 2),
|
||||
"s_cbranch_execz loop",
|
||||
"end:",
|
||||
s_endpgm(),
|
||||
"loop:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_cmp_eq_i32(s[1], 0),
|
||||
"s_branch end",
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 2))
|
||||
k.emit(s_cbranch_execz(), target="loop")
|
||||
k.label("end")
|
||||
k.emit(s_endpgm())
|
||||
k.label("loop")
|
||||
k.emit(s_add_u32(s[1], s[1], -1))
|
||||
k.emit(s_cmp_eq_i32(s[1], 0))
|
||||
k.emit(s_branch(), target="end")
|
||||
k.emit(s_code_end())
|
||||
run_asm("jump_back_to_end", k)
|
||||
|
||||
def test_hit_count(self):
|
||||
run_asm("test_hit_count", [
|
||||
"entry:",
|
||||
s_mov_b32(s[1], 1),
|
||||
"s_branch alt",
|
||||
"continue:",
|
||||
s_mov_b32(s[2], 2),
|
||||
s_add_u32(s[1], s[1], s[2]),
|
||||
"alt:",
|
||||
s_add_u32(s[1], s[1], -1),
|
||||
s_endpgm(),
|
||||
s_code_end(),
|
||||
])
|
||||
k = Kernel(arch=Device["AMD"].arch)
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 1))
|
||||
k.emit(s_branch(), target="alt")
|
||||
k.label("continue")
|
||||
k.emit(s_mov_b32(s[2], 2))
|
||||
k.emit(s_add_u32(s[1], s[1], s[2]))
|
||||
k.label("alt")
|
||||
k.emit(s_add_u32(s[1], s[1], -1))
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("test_hit_count", k)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -462,17 +462,25 @@ class TestAssign(unittest.TestCase):
|
|||
np.testing.assert_allclose(a.numpy(), [0., 0., 1., 2., 3., 0., 0., 0.])
|
||||
|
||||
def test_assign_bitcast(self):
|
||||
# assign to a bitcast view should modify the underlying buffer (only works on DISK currently)
|
||||
# assign to a bitcast view should modify the underlying buffer
|
||||
a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
|
||||
# IEEE 754: 1.0f = 0x3f800000, 2.0f = 0x40000000, 3.0f = 0x40400000, 4.0f = 0x40800000
|
||||
a.bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.uint32)).realize()
|
||||
np.testing.assert_allclose(a.numpy(), [1.0, 2.0, 3.0, 4.0]) # TODO: should be [4.0, 3.0, 2.0, 1.0]
|
||||
np.testing.assert_allclose(a.numpy(), [4.0, 3.0, 2.0, 1.0])
|
||||
# double bitcast
|
||||
b = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
|
||||
b.bitcast(dtypes.uint32).bitcast(dtypes.int32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.int32)).realize()
|
||||
np.testing.assert_allclose(b.numpy(), [4.0, 3.0, 2.0, 1.0])
|
||||
# shrink then bitcast
|
||||
c = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
|
||||
c[0:2].bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000], dtype=dtypes.uint32)).realize()
|
||||
np.testing.assert_allclose(c.numpy(), [4.0, 3.0, 3.0, 4.0])
|
||||
|
||||
def test_assign_bitcast_different_size(self):
|
||||
# assign to a shape-changing bitcast view (only works on DISK currently)
|
||||
# different-size bitcast creates a new tensor, not a view, so assign doesn't modify the original
|
||||
a = Tensor([0]*8, dtype=dtypes.uint8).realize()
|
||||
a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
|
||||
np.testing.assert_equal(a.numpy(), [0]*8) # TODO: should be [57, 48, 0, 0, 0, 0, 0, 0] (little-endian 12345)
|
||||
np.testing.assert_equal(a.numpy(), [0]*8)
|
||||
|
||||
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
|
||||
def test_cast_assignment(self):
|
||||
|
|
@ -485,6 +493,38 @@ class TestAssign(unittest.TestCase):
|
|||
assert oba1 is None and oba2 is None
|
||||
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
|
||||
|
||||
def test_assign_dtype_mismatch(self):
|
||||
# assign should not implicitly cast dtypes - this can lose precision
|
||||
a = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
|
||||
b = Tensor([1, 2, 3, 4], dtype=dtypes.int32)
|
||||
with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
|
||||
a.assign(b)
|
||||
|
||||
def test_assign_dtype_mismatch_int64_to_float32(self):
|
||||
# int64 -> float32 loses precision for large values, should not be implicit
|
||||
a = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
|
||||
b = Tensor([16777217], dtype=dtypes.int64) # 2^24 + 1, not exactly representable in float32
|
||||
with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
|
||||
a.assign(b)
|
||||
|
||||
def test_assign_shape_broadcast(self):
|
||||
# shape broadcasting should work when dtypes match
|
||||
a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
|
||||
b = Tensor([1., 2., 3., 4., 5.], dtype=dtypes.float32)
|
||||
a.assign(b)
|
||||
a.realize()
|
||||
expected = np.array([[1., 2., 3., 4., 5.]] * 3)
|
||||
np.testing.assert_allclose(a.numpy(), expected)
|
||||
|
||||
def test_assign_shape_broadcast_2d(self):
|
||||
# broadcast (1, 5) to (3, 5)
|
||||
a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
|
||||
b = Tensor([[1., 2., 3., 4., 5.]], dtype=dtypes.float32)
|
||||
a.assign(b)
|
||||
a.realize()
|
||||
expected = np.array([[1., 2., 3., 4., 5.]] * 3)
|
||||
np.testing.assert_allclose(a.numpy(), expected)
|
||||
|
||||
def test_disk_assignment(self):
|
||||
a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
|
||||
np.testing.assert_equal(a, np.ones(5))
|
||||
|
|
@ -579,12 +619,12 @@ class TestAssignOrdering(unittest.TestCase):
|
|||
def test_slice_write_then_full_read(self):
|
||||
"""Write to slice, then read full buffer."""
|
||||
# without .realize(): orphan slice assign not triggered by .numpy()
|
||||
buf = Tensor.zeros(4).contiguous().realize()
|
||||
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
buf[1:3].assign(Tensor([5, 6]))
|
||||
np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0]) # TODO: wrong! should be [0, 5, 6, 0]
|
||||
|
||||
# with .realize(): assign executes
|
||||
buf = Tensor.zeros(4).contiguous().realize()
|
||||
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
buf[1:3].assign(Tensor([5, 6])).realize()
|
||||
np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
|
||||
|
||||
|
|
@ -666,7 +706,7 @@ class TestAssignOrdering(unittest.TestCase):
|
|||
|
||||
def test_three_buffer_chain(self):
|
||||
"""Chain: A depends on B, B depends on C - ordering matters."""
|
||||
a = Tensor.zeros(4).contiguous().realize()
|
||||
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
b = Tensor([1, 2, 3, 4]).contiguous().realize()
|
||||
c = Tensor([10, 10, 10, 10]).contiguous().realize()
|
||||
# b reads from c, a reads from b
|
||||
|
|
@ -678,8 +718,8 @@ class TestAssignOrdering(unittest.TestCase):
|
|||
|
||||
def test_interleaved_assign_read_patterns(self):
|
||||
"""Complex interleaved pattern: write A, read A into B, write B, read B."""
|
||||
a = Tensor.zeros(4).contiguous().realize()
|
||||
b = Tensor.zeros(4).contiguous().realize()
|
||||
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
|
||||
a.assign(Tensor([1, 2, 3, 4]))
|
||||
b.assign(a.contiguous()) # b should get [1,2,3,4]
|
||||
|
|
|
|||
|
|
@ -5,24 +5,13 @@ from tinygrad.runtime.support.c import DLL, record, init_records
|
|||
from tinygrad.runtime.support import c
|
||||
from tinygrad.runtime.support.autogen import gen
|
||||
|
||||
class TestAutogen(unittest.TestCase):
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
class TestC(unittest.TestCase):
|
||||
def compile(self, src):
|
||||
with tempfile.NamedTemporaryFile(suffix=".so") as f:
|
||||
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
|
||||
return DLL("test", f.name)
|
||||
|
||||
def run_gen(self, contents):
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
|
||||
f.write(contents)
|
||||
f.flush()
|
||||
|
||||
generated_code = gen(name="test_header", dll=None, files=[f.name])
|
||||
|
||||
namespace = {}
|
||||
exec(generated_code, namespace)
|
||||
return namespace
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_packed_struct(self):
|
||||
@record
|
||||
class Baz:
|
||||
|
|
@ -45,7 +34,6 @@ class TestAutogen(unittest.TestCase):
|
|||
assert b.c == 1
|
||||
assert b.d == 0
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_packed_struct_interop(self):
|
||||
@record
|
||||
class Baz:
|
||||
|
|
@ -75,7 +63,6 @@ class TestAutogen(unittest.TestCase):
|
|||
self.assertEqual(test(b), b.a + b.b + b.c + b.d)
|
||||
|
||||
# https://github.com/python/cpython/issues/90914
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_bitfield_interop(self):
|
||||
@record
|
||||
class Baz:
|
||||
|
|
@ -103,7 +90,6 @@ class TestAutogen(unittest.TestCase):
|
|||
def test(x:Baz) -> ctypes.c_int: ...
|
||||
for i in range(8): self.assertEqual(test(Baz(*(j==i for j in range(8)))), i==2)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_struct_interop(self):
|
||||
@record
|
||||
class Baz:
|
||||
|
|
@ -131,7 +117,6 @@ class TestAutogen(unittest.TestCase):
|
|||
def test(x:Baz) -> Baz: ...
|
||||
self.assertEqual(bytes(test(Baz(*range(8)))), struct.pack("8i", *range(7, -1, -1)))
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_aos_interop(self):
|
||||
@record
|
||||
class Item:
|
||||
|
|
@ -151,7 +136,6 @@ class TestAutogen(unittest.TestCase):
|
|||
def test(arr:(Item * 3)) -> ctypes.c_int: ...
|
||||
self.assertEqual(test((Item * 3)(Item(10), Item(20), Item(30))), 60)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_soa_interop(self):
|
||||
@record
|
||||
class Row:
|
||||
|
|
@ -173,7 +157,6 @@ class TestAutogen(unittest.TestCase):
|
|||
self.assertEqual(r.data[1], 20)
|
||||
self.assertEqual(r.data[2], 10)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_soa_ptr_interop(self):
|
||||
@record
|
||||
class Row:
|
||||
|
|
@ -191,7 +174,6 @@ class TestAutogen(unittest.TestCase):
|
|||
def test(x:Row) -> ctypes.c_int: ...
|
||||
assert test(Row((ctypes.c_int * 3)(10, 20, 30))) == 60
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_nested_struct_interop(self):
|
||||
@record
|
||||
class Inner:
|
||||
|
|
@ -217,7 +199,6 @@ class TestAutogen(unittest.TestCase):
|
|||
self.assertEqual(o.inner.a, 20)
|
||||
self.assertEqual(o.b, 10)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_struct_pointer_interop(self):
|
||||
@record
|
||||
class Foo:
|
||||
|
|
@ -242,7 +223,88 @@ class TestAutogen(unittest.TestCase):
|
|||
self.assertEqual(out.contents.a, 20)
|
||||
self.assertEqual(out.contents.b, 10)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_pointer_field_roundtrip(self):
|
||||
# This tests storing a pointer in a record struct field and passing it to C
|
||||
# Mimics how mesa.struct_lp_build_tgsi_params.mask is used
|
||||
from tinygrad.runtime.support.c import POINTER
|
||||
@record
|
||||
class Inner:
|
||||
SIZE = 8
|
||||
value: Annotated[ctypes.c_int, 0]
|
||||
flag: Annotated[ctypes.c_int, 4]
|
||||
@record
|
||||
class Outer:
|
||||
SIZE = 16
|
||||
x: Annotated[ctypes.c_int, 0]
|
||||
inner_ptr: Annotated[POINTER[Inner], 8]
|
||||
init_records()
|
||||
|
||||
src = """
|
||||
struct inner { int value; int flag; };
|
||||
struct outer { int x; struct inner *inner_ptr; };
|
||||
int test(struct inner *p) {
|
||||
return p->value + p->flag;
|
||||
}
|
||||
"""
|
||||
dll = self.compile(src)
|
||||
@dll.bind
|
||||
def test(p:POINTER[Inner]) -> ctypes.c_int: ...
|
||||
|
||||
inner = Inner(value=42, flag=10)
|
||||
outer = Outer(x=1, inner_ptr=ctypes.pointer(inner))
|
||||
# Retrieve pointer from struct field and pass to C
|
||||
self.assertEqual(test(outer.inner_ptr), 52)
|
||||
|
||||
def test_pointer_field_loses_reference(self):
|
||||
# BUG: When a pointer is stored in a record struct field, only the address bytes are saved.
|
||||
# The pointer's _objects dict (which prevents GC of the pointed-to object) is lost.
|
||||
# This causes the pointed-to object to be garbage collected, leading to use-after-free.
|
||||
from tinygrad.runtime.support.c import POINTER
|
||||
@record
|
||||
class MaskContext:
|
||||
SIZE = 16
|
||||
value: Annotated[ctypes.c_int, 0]
|
||||
initialized: Annotated[ctypes.c_int, 4]
|
||||
ptr: Annotated[ctypes.c_void_p, 8]
|
||||
@record
|
||||
class Params:
|
||||
SIZE = 16
|
||||
x: Annotated[ctypes.c_int, 0]
|
||||
mask: Annotated[POINTER[MaskContext], 8]
|
||||
init_records()
|
||||
|
||||
src = """
|
||||
struct mask_ctx { int value; int initialized; void *ptr; };
|
||||
void mask_begin(struct mask_ctx *m, int val) { m->value = val; m->initialized = 1; }
|
||||
int mask_end(struct mask_ctx *m) { return m->value + m->initialized; }
|
||||
"""
|
||||
dll = self.compile(src)
|
||||
@dll.bind
|
||||
def mask_begin(m:POINTER[MaskContext], val:ctypes.c_int) -> None: ...
|
||||
@dll.bind
|
||||
def mask_end(m:POINTER[MaskContext]) -> ctypes.c_int: ...
|
||||
|
||||
# When MaskContext() is created inline, it gets garbage collected after the pointer
|
||||
# is stored because only the address bytes are saved, not the _objects reference.
|
||||
params = Params(x=1, mask=ctypes.pointer(MaskContext()))
|
||||
mask_begin(params.mask, 42)
|
||||
result = mask_end(params.mask)
|
||||
self.assertEqual(result, 43) # 42 + 1
|
||||
|
||||
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
class TestAutogen(unittest.TestCase):
|
||||
def run_gen(self, contents):
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
|
||||
f.write(contents)
|
||||
f.flush()
|
||||
|
||||
generated_code = gen(name="test_header", dll=None, files=[f.name])
|
||||
|
||||
namespace = {}
|
||||
exec(generated_code, namespace)
|
||||
return namespace
|
||||
|
||||
def test_packed_structs(self):
|
||||
ns = self.run_gen("""
|
||||
typedef unsigned NvU32;
|
||||
|
|
@ -292,47 +354,6 @@ typedef struct
|
|||
assert frts_cmd.readVbiosDesc.__class__ is FWSECLIC_READ_VBIOS_DESC
|
||||
assert frts_cmd.frtsRegionDesc.__class__ is FWSECLIC_FRTS_REGION_DESC
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
@unittest.skipIf(OSX, "can't find stdint?")
|
||||
def test_packed_fields(self):
|
||||
ns = self.run_gen("""#include <stdint.h>
|
||||
typedef struct die_info
|
||||
{
|
||||
uint16_t die_id;
|
||||
uint16_t die_offset; /* Points to the corresponding die_header structure */
|
||||
} die_info;
|
||||
|
||||
typedef struct ip_discovery_header
|
||||
{
|
||||
uint32_t signature; /* Table Signature */
|
||||
uint16_t version; /* Table Version */
|
||||
uint16_t size; /* Table Size */
|
||||
uint32_t id; /* Table ID */
|
||||
uint16_t num_dies; /* Number of Dies */
|
||||
die_info die_info[16]; /* list die information for up to 16 dies */
|
||||
union {
|
||||
uint16_t padding[1]; /* version <= 3 */
|
||||
struct { /* version == 4 */
|
||||
uint8_t base_addr_64_bit : 1; /* ip structures are using 64 bit base address */
|
||||
uint8_t reserved : 7;
|
||||
uint8_t reserved2;
|
||||
};
|
||||
};
|
||||
} ip_discovery_header;
|
||||
""")
|
||||
|
||||
ip_discovery_header = ns['ip_discovery_header']
|
||||
|
||||
hdr = b'IPDS\x04\x00|\x1d\x80\x1a\xffd\x01\x00\x00\x00\x8c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00' # noqa: E501
|
||||
ihdr = ip_discovery_header.from_buffer_copy(hdr)
|
||||
|
||||
assert ctypes.sizeof(ihdr) == 80
|
||||
assert ihdr.signature == 0x53445049
|
||||
assert ihdr.version == 0x0004
|
||||
assert ihdr.num_dies == 1
|
||||
assert ihdr.base_addr_64_bit == 1
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_gen_from_header(self):
|
||||
namespace = self.run_gen("""
|
||||
typedef struct {
|
||||
|
|
@ -378,7 +399,6 @@ typedef struct ip_discovery_header
|
|||
self.assertTrue(hasattr(rect, 'height'))
|
||||
self.assertTrue(hasattr(rect, 'color'))
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_struct_ordering(self):
|
||||
namespace = self.run_gen("""
|
||||
struct A;
|
||||
|
|
@ -408,77 +428,6 @@ typedef struct ip_discovery_header
|
|||
self.assertTrue(hasattr(b, 'c_ptr'))
|
||||
self.assertTrue(hasattr(c, 'a_ptr'))
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_pointer_field_roundtrip(self):
|
||||
# This tests storing a pointer in a record struct field and passing it to C
|
||||
# Mimics how mesa.struct_lp_build_tgsi_params.mask is used
|
||||
from tinygrad.runtime.support.c import POINTER
|
||||
@record
|
||||
class Inner:
|
||||
SIZE = 8
|
||||
value: Annotated[ctypes.c_int, 0]
|
||||
flag: Annotated[ctypes.c_int, 4]
|
||||
@record
|
||||
class Outer:
|
||||
SIZE = 16
|
||||
x: Annotated[ctypes.c_int, 0]
|
||||
inner_ptr: Annotated[POINTER[Inner], 8]
|
||||
init_records()
|
||||
|
||||
src = """
|
||||
struct inner { int value; int flag; };
|
||||
struct outer { int x; struct inner *inner_ptr; };
|
||||
int test(struct inner *p) {
|
||||
return p->value + p->flag;
|
||||
}
|
||||
"""
|
||||
dll = self.compile(src)
|
||||
@dll.bind
|
||||
def test(p:POINTER[Inner]) -> ctypes.c_int: ...
|
||||
|
||||
inner = Inner(value=42, flag=10)
|
||||
outer = Outer(x=1, inner_ptr=ctypes.pointer(inner))
|
||||
# Retrieve pointer from struct field and pass to C
|
||||
self.assertEqual(test(outer.inner_ptr), 52)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_pointer_field_loses_reference(self):
|
||||
# BUG: When a pointer is stored in a record struct field, only the address bytes are saved.
|
||||
# The pointer's _objects dict (which prevents GC of the pointed-to object) is lost.
|
||||
# This causes the pointed-to object to be garbage collected, leading to use-after-free.
|
||||
from tinygrad.runtime.support.c import POINTER
|
||||
@record
|
||||
class MaskContext:
|
||||
SIZE = 16
|
||||
value: Annotated[ctypes.c_int, 0]
|
||||
initialized: Annotated[ctypes.c_int, 4]
|
||||
ptr: Annotated[ctypes.c_void_p, 8]
|
||||
@record
|
||||
class Params:
|
||||
SIZE = 16
|
||||
x: Annotated[ctypes.c_int, 0]
|
||||
mask: Annotated[POINTER[MaskContext], 8]
|
||||
init_records()
|
||||
|
||||
src = """
|
||||
struct mask_ctx { int value; int initialized; void *ptr; };
|
||||
void mask_begin(struct mask_ctx *m, int val) { m->value = val; m->initialized = 1; }
|
||||
int mask_end(struct mask_ctx *m) { return m->value + m->initialized; }
|
||||
"""
|
||||
dll = self.compile(src)
|
||||
@dll.bind
|
||||
def mask_begin(m:POINTER[MaskContext], val:ctypes.c_int) -> None: ...
|
||||
@dll.bind
|
||||
def mask_end(m:POINTER[MaskContext]) -> ctypes.c_int: ...
|
||||
|
||||
# When MaskContext() is created inline, it gets garbage collected after the pointer
|
||||
# is stored because only the address bytes are saved, not the _objects reference.
|
||||
params = Params(x=1, mask=ctypes.pointer(MaskContext()))
|
||||
mask_begin(params.mask, 42)
|
||||
result = mask_end(params.mask)
|
||||
self.assertEqual(result, 43) # 42 + 1
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_anonymous_children(self):
|
||||
namespace = self.run_gen("""
|
||||
struct foo {
|
||||
|
|
@ -491,7 +440,6 @@ typedef struct ip_discovery_header
|
|||
self.assertIn('struct_foo', namespace)
|
||||
self.assertIn('struct_foo_bar', namespace)
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_enums(self):
|
||||
namespace = self.run_gen("""
|
||||
enum Foo { A, B, C };
|
||||
|
|
@ -511,4 +459,43 @@ typedef struct ip_discovery_header
|
|||
assert namespace["enum_Bar"].get(1) == "Y"
|
||||
assert namespace["enum_Bar"].get(2) == "Z"
|
||||
|
||||
@unittest.skipIf(OSX, "can't find stdint?")
|
||||
def test_packed_fields(self):
|
||||
ns = self.run_gen("""#include <stdint.h>
|
||||
typedef struct die_info
|
||||
{
|
||||
uint16_t die_id;
|
||||
uint16_t die_offset; /* Points to the corresponding die_header structure */
|
||||
} die_info;
|
||||
|
||||
typedef struct ip_discovery_header
|
||||
{
|
||||
uint32_t signature; /* Table Signature */
|
||||
uint16_t version; /* Table Version */
|
||||
uint16_t size; /* Table Size */
|
||||
uint32_t id; /* Table ID */
|
||||
uint16_t num_dies; /* Number of Dies */
|
||||
die_info die_info[16]; /* list die information for up to 16 dies */
|
||||
union {
|
||||
uint16_t padding[1]; /* version <= 3 */
|
||||
struct { /* version == 4 */
|
||||
uint8_t base_addr_64_bit : 1; /* ip structures are using 64 bit base address */
|
||||
uint8_t reserved : 7;
|
||||
uint8_t reserved2;
|
||||
};
|
||||
};
|
||||
} ip_discovery_header;
|
||||
""")
|
||||
|
||||
ip_discovery_header = ns['ip_discovery_header']
|
||||
|
||||
hdr = b'IPDS\x04\x00|\x1d\x80\x1a\xffd\x01\x00\x00\x00\x8c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00' # noqa: E501
|
||||
ihdr = ip_discovery_header.from_buffer_copy(hdr)
|
||||
|
||||
assert ctypes.sizeof(ihdr) == 80
|
||||
assert ihdr.signature == 0x53445049
|
||||
assert ihdr.version == 0x0004
|
||||
assert ihdr.num_dies == 1
|
||||
assert ihdr.base_addr_64_bit == 1
|
||||
|
||||
if __name__ == "__main__": unittest.main()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class TestCall(unittest.TestCase):
|
|||
b = Tensor.randn(K, N)
|
||||
Tensor.realize(a, b)
|
||||
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
||||
|
||||
@unittest.skip("needs GEMM on mixins")
|
||||
def test_call_gemm_uop(self):
|
||||
|
|
@ -56,7 +56,7 @@ class TestCall(unittest.TestCase):
|
|||
y = UOp.param(1, dtypes.float, shape=(K, N))
|
||||
c = Tensor.call(a, b, fxn=x@y)
|
||||
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -59,13 +59,13 @@ class TestRawDiskBuffer(unittest.TestCase):
|
|||
_test_bitcasted(t, dtypes.float32, 0.0)
|
||||
_test_bitcasted(t, dtypes.uint32, 0)
|
||||
# pi in float16 stored via int16
|
||||
t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
|
||||
t.assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16).bitcast(dtypes.uint8)).realize()
|
||||
_test_bitcasted(t, dtypes.float16, 3.140625)
|
||||
_test_bitcasted(t, dtypes.float32, 50.064727)
|
||||
_test_bitcasted(t, dtypes.uint16, 0x4248)
|
||||
_test_bitcasted(t, dtypes.uint32, 0x42484248)
|
||||
# pi in float32 stored via float32
|
||||
t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize()
|
||||
t.assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32).bitcast(dtypes.uint8)).realize()
|
||||
_test_bitcasted(t, dtypes.float32, 3.1415927)
|
||||
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
|
||||
# doesn't support normal cast
|
||||
|
|
@ -348,13 +348,25 @@ class TestDiskTensor(unittest.TestCase):
|
|||
|
||||
def test_assign_with_bitcast(self):
|
||||
# bitcast assign is used in safe_save for writing header length
|
||||
# this tests the synchronous disk assign hack handles bitcast correctly
|
||||
# bitcast on source side works, bitcast on target side raises
|
||||
pathlib.Path(temp(fn:="dt_assign_bitcast")).unlink(missing_ok=True)
|
||||
t = Tensor.empty(16, device=f"disk:{temp(fn)}", dtype=dtypes.uint8)
|
||||
t[0:8].bitcast(dtypes.int64).assign([12345])
|
||||
# verify the data was written correctly
|
||||
# correct way: bitcast the source to match target dtype
|
||||
t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
|
||||
val = int.from_bytes(t[0:8].data(), 'little')
|
||||
self.assertEqual(val, 12345)
|
||||
# bitcast on target with non-broadcastable dtype raises
|
||||
with self.assertRaises(RuntimeError):
|
||||
t[0:4].bitcast(dtypes.int32).assign(Tensor([12345], dtype=dtypes.int64))
|
||||
|
||||
def test_assign_to_bitcast_view(self):
|
||||
# assign float values to a float32 view of a uint8 disk buffer (used by safe_save)
|
||||
pathlib.Path(temp(fn:="dt_bitcast_view_assign")).unlink(missing_ok=True)
|
||||
t = Tensor.empty(32, device=f"disk:{temp(fn)}", dtype=dtypes.uint8)
|
||||
# create float32 view of bytes 8-24 (4 floats)
|
||||
float_view = t[8:24].bitcast(dtypes.float32)
|
||||
float_view.assign(Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32, device="CPU"))
|
||||
np.testing.assert_array_equal(float_view.numpy(), [1.0, 2.0, 3.0, 4.0])
|
||||
|
||||
def test_assign_cross_device(self):
|
||||
# disk assign allows cross-device (source on GPU/CPU, target on disk)
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ class TestScheduleCache(unittest.TestCase):
|
|||
_, var_vals = t.schedule_with_vars()
|
||||
self.assertEqual(var_vals, {'pos': 42})
|
||||
|
||||
@Context(SPEC=0)
|
||||
def test_custom_kernel(self):
|
||||
for i in range(4):
|
||||
a = Tensor.empty(1)
|
||||
|
|
@ -43,7 +42,6 @@ class TestScheduleCache(unittest.TestCase):
|
|||
a.realize()
|
||||
self.assertEqual(a.item(), i)
|
||||
|
||||
@Context(SPEC=0)
|
||||
def test_same_custom_function_reuses_cache(self):
|
||||
schedule_cache.clear()
|
||||
fxn = functools.partial(custom_set0_kernel, num=10)
|
||||
|
|
|
|||
|
|
@ -470,5 +470,19 @@ class TestUnfoldableImageChannelSelection(unittest.TestCase):
|
|||
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0)))
|
||||
self.assertEqual(self._count_nans(load), 1)
|
||||
|
||||
class TestDropTrueGate(unittest.TestCase):
|
||||
def test_drop_true_gate_on_index(self):
|
||||
# test that INDEX with a constant True gate gets simplified to drop the gate
|
||||
from tinygrad.codegen.late.devectorizer import load_store_indexing
|
||||
from tinygrad.uop.ops import graph_rewrite
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
idx = UOp.const(dtypes.index, 0)
|
||||
true_gate = UOp.const(dtypes.bool, True)
|
||||
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
|
||||
# apply the optimization
|
||||
result = graph_rewrite(index_with_gate, load_store_indexing)
|
||||
# the True gate should be dropped (INDEX should only have 2 sources)
|
||||
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -479,6 +479,26 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for u in uops:
|
||||
self.assertNotEqual(u.dtype, dtypes.long)
|
||||
|
||||
def test_load_idx_no_math_on_loaded(self):
|
||||
# test the (x+y)<c pattern where x has loads - we shouldn't do math on loaded indices
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(128000), arg=0, src=())
|
||||
c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
|
||||
c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
|
||||
c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
|
||||
c4 = c3.index(c1) # c4 is a load
|
||||
c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
|
||||
c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
|
||||
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=())
|
||||
c8 = c7.index(c6)
|
||||
# (loaded + range) < const pattern - loaded value shouldn't be promoted to long
|
||||
loaded_idx = c4.cast(dtypes.index)
|
||||
comparison = (loaded_idx + c5) < UOp.const(dtypes.index, 60000)
|
||||
c9 = comparison.where(c8.cast(dtypes.uint).cast(dtypes.uchar), 0).reduce(c5, arg=Ops.ADD)
|
||||
c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
|
||||
uops = to_uops_list([c10])
|
||||
for u in uops:
|
||||
self.assertNotEqual(u.dtype, dtypes.long)
|
||||
|
||||
def test_fold_gated_load(self):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||
|
|
|
|||
|
|
@ -1074,6 +1074,24 @@ class TestGatedUopGivenValid(unittest.TestCase):
|
|||
expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0))
|
||||
self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))
|
||||
|
||||
class TestRangeSplitting(unittest.TestCase):
|
||||
def test_range_split_on_mod(self):
|
||||
# test that mark_range_mod splits RANGE(8) into RANGE(4)*2 + RANGE(2) when used with %2
|
||||
from tinygrad.codegen.simplify import pm_split_ranges, pm_flatten_range
|
||||
r0 = UOp.range(uconst(8), 0)
|
||||
# create a simple expression using the range with mod: store range%2 to a buffer
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
val = (r0 % uconst(2)).cast(dtypes.int)
|
||||
store = UOp(Ops.STORE, dtypes.void, (buf.index(uconst(0)), val))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.END, dtypes.void, (store, r0)),))
|
||||
# count RANGEs before
|
||||
ranges_before = len([u for u in sink.toposort() if u.op is Ops.RANGE])
|
||||
# apply the range splitting optimization
|
||||
sink_after = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="test split ranges")
|
||||
# count RANGEs after - should have more due to splitting
|
||||
ranges_after = len([u for u in sink_after.toposort() if u.op is Ops.RANGE])
|
||||
self.assertGreater(ranges_after, ranges_before, "RANGE should be split when used with mod of divisible constant")
|
||||
|
||||
class TestBounds(unittest.TestCase):
|
||||
def test_unrolled_arange(self):
|
||||
# #include <metal_stdlib>
|
||||
|
|
|
|||
|
|
@ -322,8 +322,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
|
||||
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=ctx.acc_num)
|
||||
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
|
||||
acc.index(UOp.const(dtypes.int, 0)).store(identity)
|
||||
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity)
|
||||
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
|
|||
j += "\x20"*(round_up(len(j),8)-len(j))
|
||||
pathlib.Path(fn).unlink(missing_ok=True)
|
||||
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
t[0:8].bitcast(dtypes.int64).assign([len(j)])
|
||||
t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
|
||||
t[8:8+len(j)].assign(list(j.encode('utf-8')))
|
||||
for k,v in safe_load(t).items(): v.assign(tensors[k])
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ import tinygrad.runtime.support.objc as objc
|
|||
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent, CompilerSet, CompilerPair
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
from tinygrad.runtime.autogen import metal
|
||||
from tinygrad.runtime.support.c import DLL
|
||||
|
||||
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
|
||||
REQUEST_TYPE_COMPILE = 13
|
||||
|
||||
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
||||
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
||||
DLL("CoreGraphics", "CoreGraphics")
|
||||
|
||||
# FIXME: these need autogen to support objc categories
|
||||
# https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ObjectiveC/Chapters/ocCategories.html
|
||||
|
|
@ -67,7 +68,7 @@ class MetalCompiler(Compiler):
|
|||
# doesn't seem to be anything we can do.
|
||||
with contextlib.suppress(FileNotFoundError, ModuleNotFoundError):
|
||||
import tinygrad.runtime.autogen.llvm # noqa: F401
|
||||
support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
|
||||
support = DLL("MTLCompiler", "MTLCompiler")
|
||||
support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -131,13 +131,15 @@ def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
|
|||
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]
|
||||
|
||||
class DLL(ctypes.CDLL):
|
||||
_loaded_: set[str] = set()
|
||||
|
||||
@staticmethod
|
||||
def findlib(nm:str, paths:list[str], extra_paths=[]):
|
||||
if nm == 'libc' and OSX: return '/usr/lib/libc.dylib'
|
||||
if pathlib.Path(path:=getenv(nm.replace('-', '_').upper()+"_PATH", '')).is_file(): return path
|
||||
for p in paths:
|
||||
libpaths = {"posix": ["/usr/lib64", "/usr/lib", "/usr/local/lib"], "nt": os.environ['PATH'].split(os.pathsep),
|
||||
"darwin": ["/opt/homebrew/lib", f"/System/Library/Frameworks/{p}.framework"],
|
||||
"darwin": ["/opt/homebrew/lib", f"/System/Library/Frameworks/{p}.framework", f"/System/Library/PrivateFrameworks/{p}.framework"],
|
||||
'linux': ['/lib', '/lib64', f"/lib/{sysconfig.get_config_var('MULTIARCH')}", "/usr/lib/wsl/lib/"]}
|
||||
if (pth:=pathlib.Path(p)).is_absolute():
|
||||
if pth.is_file(): return p
|
||||
|
|
@ -154,12 +156,12 @@ class DLL(ctypes.CDLL):
|
|||
if f.read(4) == b'\x7FELF': return str(l)
|
||||
|
||||
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
|
||||
self.nm, self.emsg, self.loaded = nm, emsg, False
|
||||
self.nm, self.emsg = nm, emsg
|
||||
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
|
||||
if DEBUG >= 3: print(f"loading {nm} from {path}")
|
||||
try:
|
||||
super().__init__(path, **kwargs)
|
||||
self.loaded = True
|
||||
self._loaded_.add(self.nm)
|
||||
except OSError as e:
|
||||
self.emsg = str(e)
|
||||
if DEBUG >= 3: print(f"loading {nm} failed: {e}")
|
||||
|
|
@ -175,5 +177,6 @@ class DLL(ctypes.CDLL):
|
|||
return wrapper
|
||||
|
||||
def __getattr__(self, nm):
|
||||
if not self.loaded: raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
|
||||
if self.nm not in self._loaded_:
|
||||
raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
|
||||
return super().__getattr__(nm)
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ from tinygrad.helpers import DEBUG, to_mv, round_up, OSX
|
|||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
|
||||
class USB3:
|
||||
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31):
|
||||
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
|
||||
self.vendor, self.dev = vendor, dev
|
||||
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
|
||||
self.max_streams = max_streams
|
||||
self.max_streams, self.use_bot = max_streams, use_bot
|
||||
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
|
||||
|
||||
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
|
||||
|
|
@ -25,30 +25,34 @@ class USB3:
|
|||
# Set configuration and claim interface
|
||||
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
|
||||
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
|
||||
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
|
||||
|
||||
# Clear any stalled endpoints
|
||||
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
|
||||
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
|
||||
if use_bot:
|
||||
self._tag = 0
|
||||
else:
|
||||
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
|
||||
|
||||
# Allocate streams
|
||||
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
|
||||
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
|
||||
raise RuntimeError(f"alloc_streams failed: {rc}")
|
||||
# Clear any stalled endpoints
|
||||
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
|
||||
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
|
||||
|
||||
# Base cmd
|
||||
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
|
||||
# Allocate streams
|
||||
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
|
||||
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
|
||||
raise RuntimeError(f"alloc_streams failed: {rc}")
|
||||
|
||||
# Init pools
|
||||
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
|
||||
# Base cmd
|
||||
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
|
||||
|
||||
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
|
||||
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
|
||||
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
|
||||
# Init pools
|
||||
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
|
||||
|
||||
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
|
||||
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
|
||||
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
|
||||
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
|
||||
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
|
||||
|
||||
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
|
||||
|
||||
def _prep_transfer(self, tr, ep, stream_id, buf, length):
|
||||
tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf
|
||||
|
|
@ -68,38 +72,90 @@ class USB3:
|
|||
if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
|
||||
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
|
||||
|
||||
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
|
||||
transferred = ctypes.c_int(0)
|
||||
rc = libusb.libusb_bulk_transfer(
|
||||
self.handle,
|
||||
ep,
|
||||
(ctypes.c_ubyte * len(payload))(*payload),
|
||||
len(payload),
|
||||
ctypes.byref(transferred),
|
||||
timeout,
|
||||
)
|
||||
assert rc == 0, f"bulk OUT 0x{ep:02X} failed: {rc}"
|
||||
assert transferred.value == len(payload), f"bulk OUT short write on 0x{ep:02X}: {transferred.value}/{len(payload)} bytes"
|
||||
|
||||
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> bytes:
|
||||
buf, transferred = (ctypes.c_ubyte * length)(), ctypes.c_int(0)
|
||||
rc = libusb.libusb_bulk_transfer(
|
||||
self.handle,
|
||||
ep,
|
||||
buf,
|
||||
length,
|
||||
ctypes.byref(transferred),
|
||||
timeout,
|
||||
)
|
||||
assert rc == 0, f"bulk IN 0x{ep:02X} failed: {rc}"
|
||||
return bytes(buf[:transferred.value])
|
||||
|
||||
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
|
||||
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
|
||||
results, tr_window, op_window = [], [], []
|
||||
results:list[bytes|None] = []
|
||||
tr_window, op_window = [], []
|
||||
|
||||
for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
|
||||
# allocate slot and stream. stream is 1-based
|
||||
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
|
||||
if self.use_bot:
|
||||
dir_in = rlen > 0
|
||||
data_len = rlen if dir_in else (len(send_data) if send_data is not None else 0)
|
||||
assert (data_len == 0) if dir_in else (rlen == 0), "BOT mode only supports either read or write per command"
|
||||
|
||||
# build cmd packet
|
||||
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
|
||||
# CBW
|
||||
self._tag += 1
|
||||
flags = 0x80 if dir_in else 0x00
|
||||
cbw = struct.pack("<IIIBBB", 0x43425355, self._tag, data_len, flags, 0, len(cdb)) + cdb + b"\x00" * (16 - len(cdb))
|
||||
self._bulk_out(self.ep_data_out, cbw)
|
||||
|
||||
# cmd + stat transfers
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
|
||||
# DAT
|
||||
if dir_in:
|
||||
results.append(self._bulk_in(self.ep_data_in, rlen))
|
||||
else:
|
||||
if send_data is not None:
|
||||
self._bulk_out(self.ep_data_out, send_data)
|
||||
results.append(None)
|
||||
|
||||
if rlen:
|
||||
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
|
||||
# CSW
|
||||
sig, rtag, residue, status = struct.unpack("<IIIB", self._bulk_in(self.ep_data_in, 13, timeout=2000))
|
||||
assert sig == 0x53425355, f"Bad CSW signature 0x{sig:08X}, expected 0x53425355"
|
||||
assert rtag == self._tag, f"CSW tag mismatch: got {rtag}, expected {self._tag}"
|
||||
assert status == 0, f"SCSI command failed, CSW status=0x{status:02X}, residue={residue}"
|
||||
else:
|
||||
# allocate slot and stream. stream is 1-based
|
||||
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
|
||||
|
||||
if send_data is not None:
|
||||
if len(send_data) > len(self.buf_data_out[slot]):
|
||||
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
|
||||
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
|
||||
# build cmd packet
|
||||
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
|
||||
|
||||
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
|
||||
# cmd + stat transfers
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
|
||||
|
||||
op_window.append((idx, slot, rlen))
|
||||
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
|
||||
self._submit_and_wait(tr_window)
|
||||
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
|
||||
tr_window = []
|
||||
if rlen:
|
||||
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
|
||||
|
||||
if send_data is not None:
|
||||
if len(send_data) > len(self.buf_data_out[slot]):
|
||||
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
|
||||
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
|
||||
|
||||
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
|
||||
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
|
||||
|
||||
op_window.append((idx, slot, rlen))
|
||||
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
|
||||
self._submit_and_wait(tr_window)
|
||||
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
|
||||
tr_window = []
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,8 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
|||
# allgather
|
||||
copied_chunks = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
|
||||
else:
|
||||
this_chunk: list[UOp|None] = [None] * n_lbs
|
||||
this_chunk[(i+n_lbs-1)%n_lbs] = rc
|
||||
|
|
@ -93,12 +94,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
|
|||
return s.shrink(tuple(new_arg))
|
||||
for i, x in enumerate(ms.src):
|
||||
if x.op is Ops.COPY:
|
||||
# if src device doesn't have a renderer, we have to view after the copy
|
||||
# TODO: a way to understand this
|
||||
if x.src[0].device in {"DISK", "NPY"}:
|
||||
ret.append(apply_shrink(x, i))
|
||||
else:
|
||||
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
|
||||
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
|
||||
else:
|
||||
ret.append(apply_shrink(x, i).contiguous())
|
||||
return ms.replace(src=tuple(ret))
|
||||
|
|
|
|||
|
|
@ -126,6 +126,10 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
|
||||
# ** assign rules **
|
||||
|
||||
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
|
||||
lambda target, src, assign: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
|
||||
|
||||
# assign only to buffer, otherwise make it a CONTIGUOUS
|
||||
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
|
||||
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
|
||||
|
|
@ -271,20 +275,20 @@ def late_buffer_view(t:UOp, b:UOp):
|
|||
size = prod(shape)
|
||||
|
||||
# walk up for the INDEX
|
||||
# NOTE: even though we allow RESHAPE and SHRINK, they can combine to form non-contiguous access patterns (e.g. t[::2])
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
while x.op is not Ops.INDEX:
|
||||
assert x.op in {Ops.BITCAST, Ops.CONTIGUOUS, Ops.SHRINK, Ops.RESHAPE}, f"unexpected op {x.op} in buffer view walk"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
|
||||
else: offset = sum(idx.vmin for idx in x.src[1:])
|
||||
if offset < 0: raise RuntimeError(f"negative offset {offset} in buffer view")
|
||||
|
||||
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
|
||||
|
||||
to_bufferview = PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
|
||||
])
|
||||
|
||||
DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
|
||||
|
|
@ -451,9 +455,6 @@ to_define_global = PatternMatcher([
|
|||
# this is only needed if you are using symbolic
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
||||
# remove RANGE with 0 size
|
||||
(UPat(Ops.RANGE, name="r"), lambda r: UOp.const(dtypes.index, 0) if r.vmax == 0 else None),
|
||||
|
||||
# renumber the ranges starting with 0 so that kernel deduping works
|
||||
(UPat(Ops.RANGE, name="r"), renumber_range),
|
||||
])
|
||||
|
|
@ -469,9 +470,6 @@ rangeify_codegen = PatternMatcher([
|
|||
# TODO: this can be moved into codegen?
|
||||
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
|
||||
|
||||
# strip the arg from store
|
||||
(UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),
|
||||
|
||||
# add loads to non ptr indexes
|
||||
# TODO: this can be moved into codegen?
|
||||
#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))
|
||||
|
|
|
|||
|
|
@ -286,20 +286,20 @@ class Tensor(OpMixin):
|
|||
return self
|
||||
|
||||
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
|
||||
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
|
||||
if self.uop is x.uop: return self # a self assign is a NOOP
|
||||
# broadcast x (shape only, dtype must match)
|
||||
if self.shape != x.shape: x = x._broadcast_to(self.shape)
|
||||
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
|
||||
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
|
||||
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
|
||||
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
|
||||
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
if is_disk:
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
if self.uop is x.uop: return self # a self assign is a NOOP
|
||||
# NOTE: we allow cross device assign
|
||||
# broadcast x
|
||||
if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
|
||||
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
|
||||
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
|
||||
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
|
||||
assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
|
||||
return self.replace(self._apply_uop(UOp.assign, x))
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
|
|
@ -3868,7 +3868,10 @@ class Tensor(OpMixin):
|
|||
|
||||
def bitcast(self, dtype:DTypeLike) -> Tensor:
|
||||
"""
|
||||
Bitcasts `self` to the given `dtype` of the same itemsize.
|
||||
Bitcasts `self` to the given `dtype`.
|
||||
|
||||
When the target dtype has the same itemsize, this is a view of the same memory.
|
||||
When itemsizes differ, the last dimension is adjusted and a new Tensor is created.
|
||||
|
||||
`self` must not require a gradient.
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class Ops(FastEnum):
|
|||
# ** 1 -- defines/special **
|
||||
|
||||
# define GLOBAL/VAR are ptrs to outside the Kernel
|
||||
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); BIND = auto()
|
||||
DEFINE_VAR = auto(); BIND = auto()
|
||||
|
||||
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
|
||||
SPECIAL = auto()
|
||||
|
|
@ -28,6 +28,9 @@ class Ops(FastEnum):
|
|||
NOOP = auto(); REWRITE_ERROR = auto()
|
||||
PARAM = auto(); CALL = auto()
|
||||
|
||||
# TODO: remove this alias, DEFINE_GLOBAL is PARAM now
|
||||
DEFINE_GLOBAL = PARAM
|
||||
|
||||
# renderer
|
||||
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
|
||||
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()
|
||||
|
|
|
|||
|
|
@ -228,8 +228,9 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
case Ops.BUFFER_VIEW: return (self.arg[0],)
|
||||
case Ops.ENCDEC: return self.arg[0]
|
||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
case Ops.PARAM:
|
||||
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
|
||||
# NOTE: copied from marg
|
||||
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
|
||||
return None
|
||||
|
|
@ -420,10 +421,8 @@ class UOp(OpMixin, Generic[OpT], metaclass=UOpMetaClass):
|
|||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||
def store(self, src:UOp|ConstType, **kwargs):
|
||||
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
|
||||
def end(self, *src:UOp):
|
||||
if len(src) == 0: return self
|
||||
return UOp(Ops.END, src=(self,)+src)
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
|
||||
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
def contract(self, *rngs:UOp):
|
||||
|
|
@ -1445,6 +1444,7 @@ def pyrender(ast:UOp) -> str:
|
|||
for s in u.src: to_render.add(s)
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
|
||||
if u.op in {Ops.CUSTOM_KERNEL, Ops.CALL}: raise NotImplementedError("custom_kernel / call can't be pyrendered")
|
||||
if u.op in not_rendered: continue
|
||||
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
|
||||
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
|
||||
|
|
@ -1453,7 +1453,13 @@ def pyrender(ast:UOp) -> str:
|
|||
kernels: dict[UOp, tuple[str, str]] = {}
|
||||
r: dict[UOp, str] = {}
|
||||
ret: dict[str, str] = {}
|
||||
depth: dict[UOp, int] = {}
|
||||
for i,u in enumerate(lst):
|
||||
# limit inline depth to avoid "too many nested parentheses" in Python parser
|
||||
op_depth = 1 + max([depth[s] for s in u.src], default=0)
|
||||
if op_depth > 100: to_render.add(u)
|
||||
depth[u] = 0 if u in to_render else op_depth
|
||||
# do the rendering
|
||||
if u.op is Ops.KERNEL:
|
||||
if u.arg.ast not in kernels:
|
||||
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
|
||||
|
|
|
|||
|
|
@ -317,7 +317,8 @@ def eval_pyrender(code:str) -> UOp:
|
|||
return lcls['ast']
|
||||
|
||||
def test_pyrender(test_ast:UOp, assert_parents=True):
|
||||
code = pyrender(test_ast)
|
||||
try: code = pyrender(test_ast)
|
||||
except NotImplementedError: return None # this is okay, not all ops can be pyrendered
|
||||
ast:UOp = eval_pyrender(code)
|
||||
if ast is not test_ast:
|
||||
if assert_parents:
|
||||
|
|
|
|||
|
|
@ -306,11 +306,12 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
|||
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
||||
if (sqtt:=v.get(ProfileSQTTEvent)):
|
||||
for e in sqtt:
|
||||
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib)))
|
||||
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)),
|
||||
data=(e.blob, prg_events[k].lib, device_props[e.device]["gfx_target_version"])))
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
|
||||
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
||||
|
||||
def sqtt_timeline(data:bytes, lib:bytes) -> list[ProfileEvent]:
|
||||
def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]:
|
||||
from extra.assembly.amd.sqttmap import map_insts, InstructionInfo
|
||||
from extra.assembly.amd.sqtt import PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
|
||||
ret:list[ProfileEvent] = []
|
||||
|
|
@ -321,7 +322,7 @@ def sqtt_timeline(data:bytes, lib:bytes) -> list[ProfileEvent]:
|
|||
rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}"))
|
||||
key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=info.inst.disasm() if info is not None else None)
|
||||
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
|
||||
for p, info in map_insts(data, lib):
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
if isinstance(p, INST):
|
||||
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
|
||||
|
|
@ -346,7 +347,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
|
|||
# * init decoder
|
||||
from extra.sqtt.roc import decode
|
||||
base = unwrap(p.base)
|
||||
addr_table = amd_decode(unwrap(p.lib), device_props[p.device]["gfx_target_version"], )
|
||||
addr_table = amd_decode(unwrap(p.lib), device_props[p.device]["gfx_target_version"])
|
||||
disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()}
|
||||
rctx = decode(data, {p.tag:disasm})
|
||||
cu_events:dict[str, list[ProfileEvent]] = {}
|
||||
|
|
@ -422,16 +423,14 @@ def get_stdout(f: Callable) -> str:
|
|||
return buf.getvalue()
|
||||
|
||||
def amd_readelf(lib:bytes) -> list[dict]:
|
||||
from tinygrad.runtime.autogen import amdgpu_kd
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
import msgpack
|
||||
_, sections, __ = elf_loader(lib)
|
||||
data = next((s for s in sections if s.name.startswith(".note"))).content
|
||||
namesz, descsz, typ = struct.unpack_from(hdr:="<III", data, 0)
|
||||
offset = (struct.calcsize(hdr)+namesz+3) & -4
|
||||
notes = msgpack.unpackb(data[offset:offset+descsz])
|
||||
keys = {".sgpr_count":"SGPRs", ".vgpr_count":"VGPRs", ".max_flat_workgroup_size":"Max WGP size",
|
||||
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
|
||||
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
|
||||
image, sections, __ = elf_loader(lib)
|
||||
rodata = next((s for s in sections if s.name == ".rodata")).content
|
||||
kd = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t.from_buffer_copy(bytearray(rodata))
|
||||
vgpr_gran = kd.compute_pgm_rsrc1 & amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT
|
||||
return [{"label":f"{resource} Alloc", "value":val} for resource,val in [("VGPR", (vgpr_gran+1)*8-7), ("LDS",kd.group_segment_fixed_size),
|
||||
("Scratch", kd.private_segment_fixed_size)] if val > 0]
|
||||
|
||||
def amd_decode(lib:bytes, target:int) -> dict[int, Any]: # Any is the Inst class from extra.assembly.amd.dsl
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue