Merge branch 'master' into mac_pytest

This commit is contained in:
George Hotz 2026-02-02 23:40:03 +08:00 committed by GitHub
commit 1019a3d8f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
70 changed files with 13004 additions and 2502 deletions

View file

@ -704,6 +704,8 @@ jobs:
# TODO: run all once emulator is faster
- name: Run RDNA3 ops tests
run: SKIP_SLOW_TEST=1 AMD_LLVM=0 pytest -n=auto test/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril or test_nonzero or test_softmax_argmax" --durations 20
- name: Run RDNA4 emulator tests
run: MOCKGPU_ARCH=rdna4 python -m pytest test/test_tiny.py -v --durations 20
testnvidia:
strategy:

View file

@ -10,6 +10,7 @@ export DEBUG=${DEBUG:-0}
export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}

View file

@ -49,10 +49,11 @@ from tinygrad.helpers import Context, DEBUG, colored
from tinygrad.engine.realize import get_runner
from extra.assembly.amd import decode_inst
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP1_SDST, VOP2, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC,
DS, FLAT, GLOBAL, SCRATCH, VOPD, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOPDOp)
from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE as PCODE_RDNA3
from extra.assembly.amd.autogen.rdna4.str_pcode import PCODE as PCODE_RDNA4
from extra.assembly.amd.autogen.rdna3 import ins as ir3
from extra.assembly.amd.autogen.rdna4 import ins as ir4
from extra.assembly.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
from extra.assembly.amd.autogen.common import Fmt, OpType
from extra.assembly.amd.pcode import parse_block, _FUNCS
@ -79,15 +80,23 @@ def _apply_src_mods(val: UOp, mod_bit: int, abs_bits: int, neg_bits: int, bits:
if neg_bits & (1 << mod_bit): fv = fv.neg()
return fv.bitcast(ut).cast(dtypes.uint32) if bits == 16 else fv.bitcast(ut)
# Map VOPD ops to VOP2 ops for pcode lookup
# Map VOPD ops to VOP2 ops for pcode lookup (both RDNA3 and RDNA4)
VOPD_TO_VOP2 = {
VOPDOp.V_DUAL_FMAC_F32: VOP2Op.V_FMAC_F32_E32, VOPDOp.V_DUAL_MUL_F32: VOP2Op.V_MUL_F32_E32,
VOPDOp.V_DUAL_ADD_F32: VOP2Op.V_ADD_F32_E32, VOPDOp.V_DUAL_SUB_F32: VOP2Op.V_SUB_F32_E32,
VOPDOp.V_DUAL_SUBREV_F32: VOP2Op.V_SUBREV_F32_E32, VOPDOp.V_DUAL_MAX_F32: VOP2Op.V_MAX_F32_E32,
VOPDOp.V_DUAL_MIN_F32: VOP2Op.V_MIN_F32_E32, VOPDOp.V_DUAL_ADD_NC_U32: VOP2Op.V_ADD_NC_U32_E32,
VOPDOp.V_DUAL_LSHLREV_B32: VOP2Op.V_LSHLREV_B32_E32, VOPDOp.V_DUAL_AND_B32: VOP2Op.V_AND_B32_E32,
VOPDOp.V_DUAL_MOV_B32: VOP1Op.V_MOV_B32_E32, VOPDOp.V_DUAL_CNDMASK_B32: VOP2Op.V_CNDMASK_B32_E32,
VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32_E32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32_E32,
ir3.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32, ir3.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
ir3.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32, ir3.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
ir3.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32, ir3.VOPDOp.V_DUAL_MAX_F32: ir3.VOP2Op.V_MAX_F32_E32,
ir3.VOPDOp.V_DUAL_MIN_F32: ir3.VOP2Op.V_MIN_F32_E32, ir3.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
ir3.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32, ir3.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
ir3.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32, ir3.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
ir3.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
# RDNA4 mappings (same VOP1/VOP2 targets, RDNA4 uses _NUM_ suffix for min/max)
ir4.VOPDOp.V_DUAL_FMAC_F32: ir3.VOP2Op.V_FMAC_F32_E32, ir4.VOPDOp.V_DUAL_MUL_F32: ir3.VOP2Op.V_MUL_F32_E32,
ir4.VOPDOp.V_DUAL_ADD_F32: ir3.VOP2Op.V_ADD_F32_E32, ir4.VOPDOp.V_DUAL_SUB_F32: ir3.VOP2Op.V_SUB_F32_E32,
ir4.VOPDOp.V_DUAL_SUBREV_F32: ir3.VOP2Op.V_SUBREV_F32_E32, ir4.VOPDOp.V_DUAL_MAX_NUM_F32: ir3.VOP2Op.V_MAX_F32_E32,
ir4.VOPDOp.V_DUAL_MIN_NUM_F32: ir3.VOP2Op.V_MIN_F32_E32, ir4.VOPDOp.V_DUAL_ADD_NC_U32: ir3.VOP2Op.V_ADD_NC_U32_E32,
ir4.VOPDOp.V_DUAL_LSHLREV_B32: ir3.VOP2Op.V_LSHLREV_B32_E32, ir4.VOPDOp.V_DUAL_AND_B32: ir3.VOP2Op.V_AND_B32_E32,
ir4.VOPDOp.V_DUAL_MOV_B32: ir3.VOP1Op.V_MOV_B32_E32, ir4.VOPDOp.V_DUAL_CNDMASK_B32: ir3.VOP2Op.V_CNDMASK_B32_E32,
ir4.VOPDOp.V_DUAL_FMAAK_F32: ir3.VOP2Op.V_FMAAK_F32_E32, ir4.VOPDOp.V_DUAL_FMAMK_F32: ir3.VOP2Op.V_FMAMK_F32_E32,
}
WAVE_SIZE = 32
# Special registers stored after inline constants (256-259)
@ -146,11 +155,15 @@ _pcode_fixes = {
'V_TRIG_PREOP_F64': ("result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff)", "result = trig_preop_result(shift)"),
}
def _get_pcode_dict(op) -> dict:
"""Return the PCODE dictionary for the given opcode based on its architecture."""
return PCODE_RDNA4 if 'rdna4' in type(op).__module__ else PCODE_RDNA3
# Pcode parser
@functools.cache
def get_pcode(op) -> str:
op_name = op.name
pcode = PCODE[op]
pcode = _get_pcode_dict(op)[op]
if op_name in _pcode_fixes: pcode = pcode.replace(*_pcode_fixes[op_name])
if 'V_DIV_SCALE' in op_name:
dt, exp_lim, ldexp_val = ('f32', '23', '64') if 'F32' in op_name else ('f64', '52', '128')
@ -174,7 +187,12 @@ def get_pcode(op) -> str:
def parse_pcode(pcode: str, srcs: dict[str, UOp] | None = None) -> tuple[dict, list[tuple[str, UOp]]]:
vars: dict = srcs.copy() if srcs else {}
assigns: list[tuple[str, UOp]] = []
lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
raw_lines = [l.strip().rstrip(';') for l in pcode.split('\n') if l.strip() and not l.strip().startswith('//')]
# TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
lines: list[str] = []
for l in raw_lines:
if lines and lines[-1].endswith('&&'): lines[-1] = lines[-1] + ' ' + l
else: lines.append(l)
_, final, _ = parse_block(lines, 0, vars, assigns=assigns)
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
for var, val in final.items():
@ -317,9 +335,9 @@ class _Ctx:
return base, mask, size
# Dynamic register access (takes UOp index instead of int)
def rsgpr_dyn(self, reg: UOp) -> UOp:
def rsgpr_dyn(self, reg: UOp, valid: UOp | None = None) -> UOp:
"""Read SGPR with dynamic register index."""
return self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
return self.sgpr.index(reg.cast(dtypes.int), valid, ptr=True).load() if valid is not None else self.sgpr.index(reg.cast(dtypes.int), ptr=True).load()
def wsgpr_dyn(self, reg: UOp, val: UOp) -> UOp:
"""Write SGPR with dynamic register index. Writes to NULL (124) are discarded."""
@ -341,15 +359,18 @@ class _Ctx:
If lane is None, only scalar access is supported (off must be < 256).
is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
is_float_const = (off >= _c(240)) & (off <= _c(248))
sgpr_lo = self.rsgpr_dyn(off)
is_vgpr = off >= _c(256)
is_sgpr = is_vgpr.ne(True)
sgpr_lo = self.rsgpr_dyn(off, is_sgpr)
if lane is not None:
is_vgpr, vgpr_reg = off >= _c(256), off - _c(256)
vgpr_reg = off - _c(256)
vgpr_lo = self.rvgpr_dyn(vgpr_reg, lane, is_vgpr)
vgpr_val = _u64(vgpr_lo, self.rvgpr_dyn(vgpr_reg + _c(1), lane, is_vgpr)) if bits == 64 else vgpr_lo
if bits == 64:
sgpr_val = _u64(sgpr_lo, self.rsgpr_dyn(off + _c(1)))
sgpr_hi = self.rsgpr_dyn(off + _c(1), is_sgpr)
sgpr_val = _u64(sgpr_lo, sgpr_hi)
# Integer inline constants: sign-extend 32-bit value from buffer to 64-bit
# Float constants: cast F32 to F64
int_inline = sgpr_lo.cast(dtypes.int32).cast(dtypes.int64)
@ -402,17 +423,19 @@ class _Ctx:
return UOp.sink(*self.scalar_stores(assigns, sdst_reg, sdst_size), *self.inc_pc())
def compile_lane_pcode(self, op, inst) -> UOp:
"""Compile READLANE/READFIRSTLANE/WRITELANE using pcode parser."""
"""Compile cross-lane ops (READLANE/WRITELANE/PERMLANE) using pcode parser."""
pcode = get_pcode(op)
op_name = op.name if hasattr(op, 'name') else str(op)
src0_off, vdst_off = self.inst_field(type(inst).src0), self.inst_field(type(inst).vdst)
src0_reg = (src0_off >= _c(256)).where(src0_off - _c(256), _c(0)) # VGPR index or 0
src1_off = self.inst_field(type(inst).src1) if hasattr(type(inst), 'src1') else None
src2_off = self.inst_field(type(inst).src2) if hasattr(type(inst), 'src2') else None
exec_lo = self.rsgpr_dyn(_c(EXEC_LO.offset))
srcs = {
'SRC0': src0_reg, 'VDST': vdst_off, 'EXEC_LO': exec_lo, 'EXEC': exec_lo.cast(dtypes.uint64), '_vgpr': self.vgpr,
'S0': self.rsrc_dyn(src0_off, _c(0, dtypes.int)) if 'WRITELANE' in op_name else src0_reg,
'S1': self.rsrc_dyn(src1_off, _c(0, dtypes.int)) if src1_off is not None else _c(0),
'S2': self.rsrc_dyn(src2_off, _c(0, dtypes.int)) if src2_off is not None else _c(0),
}
_, assigns = parse_pcode(pcode, srcs)
stores = []
@ -480,13 +503,14 @@ class _Ctx:
# INSTRUCTION HANDLERS
# ═══════════════════════════════════════════════════════════════════════════════
def _compile_sopp(inst: SOPP, ctx: _Ctx) -> UOp:
simm16 = ctx.inst_field_signed(SOPP.simm16).cast(dtypes.int16)
if inst.op == SOPPOp.S_ENDPGM:
def _compile_sopp(inst: ir3.SOPP | ir4.SOPP, ctx: _Ctx) -> UOp:
simm16 = ctx.inst_field_signed(type(inst).simm16).cast(dtypes.int16)
if inst.op in (ir3.SOPPOp.S_ENDPGM, ir4.SOPPOp.S_ENDPGM):
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)),
ctx.wsgpr_dyn(_c(PC_HI_IDX), UOp.const(dtypes.uint32, 0xFFFFFFFF)))
if inst.op in (ir3.SOPPOp.S_NOP, ir4.SOPPOp.S_NOP): return UOp.sink(*ctx.inc_pc()) # S_NOP is a no-op
# NOTE: we ignore SOPPs without PCODE
if inst.op in PCODE:
if inst.op in _get_pcode_dict(inst.op):
pcode = get_pcode(inst.op)
pc_bytes = ctx.rpc() # PC is already 64-bit byte address
vcc, exec_lo = ctx.rsgpr_dyn(_c(VCC_LO.offset)), ctx.rsgpr_dyn(_c(EXEC_LO.offset))
@ -498,50 +522,57 @@ def _compile_sopp(inst: SOPP, ctx: _Ctx) -> UOp:
return UOp.sink(ctx.wsgpr_dyn(_c(PC_LO_IDX), lo), ctx.wsgpr_dyn(_c(PC_HI_IDX), hi))
return UOp.sink(*ctx.inc_pc())
def _compile_smem(inst: SMEM, ctx: _Ctx) -> UOp:
def _compile_smem(inst: ir3.SMEM | ir4.SMEM, ctx: _Ctx) -> UOp:
# Cache invalidation instructions are no-ops in the emulator (we don't model caches)
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV):
cache_inv_ops = [ir3.SMEMOp.S_GL1_INV, ir3.SMEMOp.S_DCACHE_INV, ir4.SMEMOp.S_DCACHE_INV]
if hasattr(ir4.SMEMOp, 'S_GL1_INV'): cache_inv_ops.append(ir4.SMEMOp.S_GL1_INV)
if inst.op in cache_inv_ops:
return UOp.sink(*ctx.inc_pc())
# Dynamic sbase field (bits 5:0) - SGPR pair, field value * 2 = register offset
sbase = ctx.inst_field(SMEM.sbase) * _c(2)
sbase = ctx.inst_field(type(inst).sbase) * _c(2)
# Dynamic sdata field (bits 12:6) - destination SGPR
sdata_reg = ctx.inst_field(SMEM.sdata)
offset = ctx.inst_field_signed(SMEM.offset) # 21-bit signed immediate
# Dynamic soffset field (bits 63:57) - SGPR for additional offset (NULL=124 reads as 0)
soffset = ctx.inst_field(SMEM.soffset)
sdata_reg = ctx.inst_field(type(inst).sdata)
# RDNA4 uses 'ioffset', RDNA3 uses 'offset' - use type(inst) to get correct field
offset_field = type(inst).ioffset if hasattr(type(inst), 'ioffset') else type(inst).offset
offset = ctx.inst_field_signed(offset_field) # signed immediate
# Dynamic soffset field - SGPR for additional offset (NULL=124 reads as 0)
soffset = ctx.inst_field(type(inst).soffset)
addr = _u64(ctx.rsgpr_dyn(sbase), ctx.rsgpr_dyn(sbase + _c(1))) + offset.cast(dtypes.uint64) + ctx.rsgpr_dyn(soffset).cast(dtypes.uint64)
ndwords = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}.get(inst.op, 1)
_SMEM_NDWORDS = {ir3.SMEMOp.S_LOAD_B32: 1, ir3.SMEMOp.S_LOAD_B64: 2, ir3.SMEMOp.S_LOAD_B128: 4,
ir3.SMEMOp.S_LOAD_B256: 8, ir3.SMEMOp.S_LOAD_B512: 16, ir4.SMEMOp.S_LOAD_B32: 1, ir4.SMEMOp.S_LOAD_B64: 2,
ir4.SMEMOp.S_LOAD_B96: 3, ir4.SMEMOp.S_LOAD_B128: 4, ir4.SMEMOp.S_LOAD_B256: 8, ir4.SMEMOp.S_LOAD_B512: 16}
ndwords = _SMEM_NDWORDS[inst.op]
stores = [ctx.wsgpr_dyn(sdata_reg + _c(i), ctx.vmem.index((addr + UOp.const(dtypes.uint64, i * 4) >> UOp.const(dtypes.uint64, 2)).cast(dtypes.int)))
for i in range(ndwords)]
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
def _compile_sop(inst: ir3.SOP1 | ir3.SOP2 | ir3.SOPC | ir3.SOPK | ir4.SOP1 | ir4.SOP2 | ir4.SOPC | ir4.SOPK, ctx: _Ctx) -> UOp:
bits = inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
if isinstance(inst, SOPK):
sdst_off = ctx.inst_field(SOPK.sdst)
simm16 = ctx.inst_field(SOPK.simm16)
if isinstance(inst, (ir3.SOPK, ir4.SOPK)):
sdst_off = ctx.inst_field(type(inst).sdst)
simm16 = ctx.inst_field(type(inst).simm16)
# Sign-extend simm16
simm16_sext = simm16.cast(dtypes.int16).cast(dtypes.int32)
srcs = {'S0': ctx.rsgpr_dyn(sdst_off), 'SIMM16': simm16_sext, 'D0': ctx.rsgpr_dyn(sdst_off)}
dst_off, dst_size = sdst_off, 1
elif isinstance(inst, SOP1):
sdst_off = ctx.inst_field(SOP1.sdst)
ssrc0_off = ctx.inst_field(SOP1.ssrc0)
elif isinstance(inst, (ir3.SOP1, ir4.SOP1)):
sdst_off = ctx.inst_field(type(inst).sdst)
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal)}
dst_off, dst_size = sdst_off, bits['d'] // 32
elif isinstance(inst, SOP2):
sdst_off = ctx.inst_field(SOP2.sdst)
ssrc0_off = ctx.inst_field(SOP2.ssrc0)
ssrc1_off = ctx.inst_field(SOP2.ssrc1)
elif isinstance(inst, (ir3.SOP2, ir4.SOP2)):
sdst_off = ctx.inst_field(type(inst).sdst)
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
if literal is not None: srcs['SIMM32'] = literal
dst_off, dst_size = sdst_off, bits['d'] // 32
elif isinstance(inst, SOPC):
ssrc0_off = ctx.inst_field(SOPC.ssrc0)
ssrc1_off = ctx.inst_field(SOPC.ssrc1)
elif isinstance(inst, (ir3.SOPC, ir4.SOPC)):
ssrc0_off = ctx.inst_field(type(inst).ssrc0)
ssrc1_off = ctx.inst_field(type(inst).ssrc1)
srcs = {'S0': ctx.rsrc_dyn(ssrc0_off, None, bits['s0'], literal),
'S1': ctx.rsrc_dyn(ssrc1_off, None, bits['s1'], literal)}
dst_off, dst_size = _c(0), 0 # SOPC writes to SCC, not sdst
@ -550,18 +581,18 @@ def _compile_sop(inst: SOP1 | SOP2 | SOPC | SOPK, ctx: _Ctx) -> UOp:
return ctx.compile_sop_pcode(inst.op, srcs, dst_off, dst_size)
def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
if op_name == 'V_READFIRSTLANE_B32_E32': return ctx.compile_lane_pcode(inst.op, inst)
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
lane, exec_mask, bits = ctx.range(), ctx.rsgpr_dyn(_c(EXEC_LO.offset)), inst.canonical_op_bits
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
vdst_reg = ctx.inst_field(VOP1.vdst)
vdst_reg = ctx.inst_field(type(inst).vdst)
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
elif write_hi_half: vdst_reg -= 128
if isinstance(inst, VOP1):
if isinstance(inst, (ir3.VOP1, ir4.VOP1)):
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
src0_off = ctx.inst_field(VOP1.src0)
src0_off = ctx.inst_field(type(inst).src0)
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
if bits['s0'] == 16:
src0_hi = src0_off >= _c(384)
@ -570,13 +601,13 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
srcs = {'S0': s0}
else:
vsrc1_reg = ctx.inst_field(VOP2.vsrc1)
vsrc1_reg = ctx.inst_field(type(inst).vsrc1)
vsrc1_hi = bits['s0'] == 16 and (vsrc1_reg >= _c(128))
vsrc1_actual = _cond(vsrc1_hi, vsrc1_reg - _c(128), vsrc1_reg)
s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) # FMAC/FMAMK hi-half dest needs hi-half accumulator
# Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops)
src0_off = ctx.inst_field(VOP2.src0)
src0_off = ctx.inst_field(type(inst).src0)
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal)
if bits['s0'] == 16:
src0_hi = src0_off >= _c(384)
@ -584,19 +615,20 @@ def _compile_vop12(inst: VOP1 | VOP1_SDST | VOP2, ctx: _Ctx) -> UOp:
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
srcs = {'S0': s0, 'S1': s1, 'D0': d0}
if inst.op in (VOP2Op.V_FMAAK_F32_E32, VOP2Op.V_FMAMK_F32_E32, VOP2Op.V_FMAAK_F16_E32, VOP2Op.V_FMAMK_F16_E32):
if inst.op in (ir3.VOP2Op.V_FMAAK_F32_E32, ir3.VOP2Op.V_FMAMK_F32_E32, ir3.VOP2Op.V_FMAAK_F16_E32,
ir3.VOP2Op.V_FMAMK_F16_E32):
assert literal is not None
srcs['SIMM32'] = literal
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half)
def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
def _compile_vopc(inst: ir3.VOPC | ir3.VOP3 | ir4.VOPC | ir4.VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
exec_mask, op_name, bits = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst), inst.canonical_op_bits
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
if is_vopc:
src0_off = ctx.inst_field(VOPC.src0)
vsrc1_off = ctx.inst_field(VOPC.vsrc1)
src0_off = ctx.inst_field(type(inst).src0)
vsrc1_off = ctx.inst_field(type(inst).vsrc1)
# For 16-bit ops, vsrc1 >= 128 means hi-half of v[vsrc1-128]
if bits['s0'] == 16:
vsrc1_hi = vsrc1_off >= _c(128)
@ -605,9 +637,9 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
vsrc1_hi = False
src1_off = _c(256) + vsrc1_off
else:
src0_off = ctx.inst_field(VOP3.src0)
src1_off = ctx.inst_field(VOP3.src1)
dst_off = ctx.inst_field(VOP3.vdst)
src0_off = ctx.inst_field(type(inst).src0)
src1_off = ctx.inst_field(type(inst).src1)
dst_off = ctx.inst_field(type(inst).vdst)
vsrc1_hi = False
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
@ -636,7 +668,7 @@ def _compile_vopc(inst: VOPC | VOP3, ctx: _Ctx, opsel: int = 0, abs_bits: int =
stores = [ctx.wsgpr_dyn(dst_off, new_result)] if not is_vopc else [ctx.wsgpr_dyn(_c(VCC_LO.offset), new_result)]
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
bits = inst.canonical_op_bits
opsel, op_name = getattr(inst, 'opsel', 0) or 0, _op_name(inst)
@ -645,18 +677,22 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
if op_name in ('V_READLANE_B32', 'V_READFIRSTLANE_B32', 'V_READFIRSTLANE_B32_E64', 'V_WRITELANE_B32'):
return ctx.compile_lane_pcode(inst.op, inst)
# V_PERMLANE16_B32 / V_PERMLANEX16_B32: cross-lane swizzle via pcode
if 'PERMLANE16' in op_name or 'PERMLANEX16' in op_name:
return ctx.compile_lane_pcode(inst.op, inst)
# VOP3 VOPC (v_cmp_*_e64) - delegate to unified VOPC handler
if 'V_CMP' in op_name or 'V_CMPX' in op_name:
return _compile_vopc(inst, ctx, opsel=opsel, abs_bits=getattr(inst, 'abs', 0) or 0, neg_bits=getattr(inst, 'neg', 0) or 0)
# Regular VOP3 - read operands dynamically
lane = ctx.range()
vdst_reg = ctx.inst_field(VOP3.vdst)
vdst_reg = ctx.inst_field(type(inst).vdst)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
ops = inst.canonical_operands
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3.src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, bits['s0'], literal, 's0' in ops and ops['s0'][0] == Fmt.FMT_NUM_F64)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, bits['s1'], literal, 's1' in ops and ops['s1'][0] == Fmt.FMT_NUM_F64)
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, bits['s2'], literal, 's2' in ops and ops['s2'][0] == Fmt.FMT_NUM_F64)
if bits['s0'] == 16:
src0 = _apply_opsel(src0, 0, opsel)
src1 = _apply_opsel(src1, 1, opsel)
@ -666,19 +702,19 @@ def _compile_vop3(inst: VOP3, ctx: _Ctx) -> UOp:
src1 = _apply_src_mods(src1, 1, abs_bits, neg_bits, bits['s1'])
src2 = _apply_src_mods(src2, 2, abs_bits, neg_bits, bits['s2'])
srcs = {'S0': src0, 'S1': src1, 'S2': src2}
if inst.op in (VOP3Op.V_CNDMASK_B32_E64, VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
if inst.op in (ir3.VOP3Op.V_CNDMASK_B32_E64, ir3.VOP3Op.V_CNDMASK_B16) and src2 is not None: srcs['VCC'] = src2
# FMAC instructions need D0 (accumulator) from destination register
if 'FMAC' in op_name: srcs['D0'] = ctx.rvgpr_dyn(vdst_reg, lane)
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
# Read operands dynamically from instruction encoding
vdst_reg, sdst_off = ctx.inst_field(VOP3SD.vdst), ctx.inst_field(VOP3SD.sdst)
src0_off, src1_off, src2_off = ctx.inst_field(VOP3SD.src0), ctx.inst_field(VOP3SD.src1), ctx.inst_field(VOP3SD.src2)
vdst_reg, sdst_off = ctx.inst_field(type(inst).vdst), ctx.inst_field(type(inst).sdst)
src0_off, src1_off, src2_off = ctx.inst_field(type(inst).src0), ctx.inst_field(type(inst).src1), ctx.inst_field(type(inst).src2)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
has_carry_in = 's2' in ops and ops['s2'][2] == OpType.OPR_SREG
@ -724,13 +760,13 @@ def _compile_vop3sd(inst: VOP3SD, ctx: _Ctx) -> UOp:
else:
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, sdst_reg=inst.sdst.offset)
def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
def _compile_wmma(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(VOP3P.vdst)
src0_r = ctx.inst_field(VOP3P.src0) - _c(256)
src1_r = ctx.inst_field(VOP3P.src1) - _c(256)
src2_r = ctx.inst_field(VOP3P.src2) - _c(256)
vdst_reg = ctx.inst_field(type(inst).vdst)
src0_r = ctx.inst_field(type(inst).src0) - _c(256)
src1_r = ctx.inst_field(type(inst).src1) - _c(256)
src2_r = ctx.inst_field(type(inst).src2) - _c(256)
is_f16_output = 'F16_16X16X16_F16' in op_name or 'BF16_16X16X16_BF16' in op_name # F16/BF16 output vs F32 output
is_bf16 = 'BF16' in op_name
cvt = _FUNCS['bf16_to_f32'] if is_bf16 else _FUNCS['f16_to_f32']
@ -757,16 +793,16 @@ def _compile_wmma(inst: VOP3P, ctx: _Ctx) -> UOp:
stores = [ctx.wvgpr_dyn(vdst_reg + _c(i // 32), UOp.const(dtypes.int, i % 32), mat_d[i].bitcast(dtypes.uint32), exec_mask) for i in range(256)]
return UOp.sink(*stores, *ctx.inc_pc())
def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
op_name = _op_name(inst)
if 'WMMA' in op_name and ('16X16X16_F16' in op_name or '16X16X16_BF16' in op_name): return _compile_wmma(inst, ctx)
lane = ctx.range()
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(VOP3P.vdst)
src0 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src0), lane, 16)
src1 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src1), lane, 16)
src2 = ctx.rsrc_dyn(ctx.inst_field(VOP3P.src2), lane, 16)
vdst_reg = ctx.inst_field(type(inst).vdst)
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16)
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16)
opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
@ -788,7 +824,7 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
s0_mod = apply_neg_mix(apply_abs(src0, 1, 1, 1), 1, 1, 1)
s1_mod = apply_neg_mix(apply_abs(src1, 2, 2, 2), 2, 2, 2)
s2_mod = apply_neg_mix(apply_abs(src2, 4, 4, 4), 4, 4, 4)
srcs = {'S0': s0_mod, 'S1': s1_mod, 'S2': s2_mod,
srcs = {'S@0': s0_mod, 'S@1': s1_mod, 'S@2': s2_mod,
'OPSEL_HI': UOp.const(dtypes.uint32, combined_opsel_hi), 'OPSEL': UOp.const(dtypes.uint32, opsel)}
else:
def get_half_bits(val: UOp, use_hi: bool, apply_neg: bool = False) -> UOp:
@ -806,18 +842,19 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
def _compile_vopd(inst: ir3.VOPD | ir4.VOPD, ctx: _Ctx) -> UOp:
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
# Read operands dynamically
vdstx_reg = ctx.inst_field(VOPD.vdstx)
# Read operands dynamically - use type(inst) to get correct field descriptors
inst_type = type(inst)
vdstx_reg = ctx.inst_field(inst_type.vdstx)
# vdsty has complex encoding: actual = (raw << 1) | ((vdstx & 1) ^ 1)
vdsty_raw = ctx.inst_field(VOPD.vdsty)
vdsty_raw = ctx.inst_field(inst_type.vdsty)
vdsty_reg = (vdsty_raw << _c(1)) | ((vdstx_reg & _c(1)) ^ _c(1))
srcx0_off = ctx.inst_field(VOPD.srcx0)
srcy0_off = ctx.inst_field(VOPD.srcy0)
vsrcx1_reg = ctx.inst_field(VOPD.vsrcx1)
vsrcy1_reg = ctx.inst_field(VOPD.vsrcy1)
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None
srcx0_off = ctx.inst_field(inst_type.srcx0)
srcy0_off = ctx.inst_field(inst_type.srcy0)
vsrcx1_reg = ctx.inst_field(inst_type.vsrcx1)
vsrcy1_reg = ctx.inst_field(inst_type.vsrcy1)
literal = ctx.inst_field(inst_type.literal) if hasattr(inst_type, 'literal') else None
lane = ctx.range()
srcy0, srcy1 = ctx.rsrc_dyn(srcy0_off, lane, literal=literal), ctx.rvgpr_dyn(vsrcy1_reg, lane)
@ -828,49 +865,64 @@ def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
assert vop is not None, f"no VOP mapping for VOPD {label}: {op}"
if label == 'Y': srcs = {'S0': srcy0, 'S1': srcy1, 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
else: srcs = {'S0': ctx.rsrc_dyn(src0_off, lane, literal=literal), 'S1': ctx.rvgpr_dyn(vsrc1_reg, lane), 'D0': ctx.rvgpr_dyn(vdst_reg, lane)}
if op in (VOPDOp.V_DUAL_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32):
if op in (ir3.VOPDOp.V_DUAL_FMAAK_F32, ir3.VOPDOp.V_DUAL_FMAMK_F32, ir4.VOPDOp.V_DUAL_FMAAK_F32, ir4.VOPDOp.V_DUAL_FMAMK_F32):
assert literal is not None
srcs['SIMM32'] = literal
if op == VOPDOp.V_DUAL_CNDMASK_B32: srcs['VCC'] = ctx.rsgpr_dyn(_c(VCC_LO.offset))
if op in (ir3.VOPDOp.V_DUAL_CNDMASK_B32, ir4.VOPDOp.V_DUAL_CNDMASK_B32): srcs['VCC'] = ctx.rsgpr_dyn(_c(VCC_LO.offset))
pcode = get_pcode(vop)
srcs.update({'VCC': ctx.rsgpr_dyn(_c(VCC_LO.offset)), 'EXEC': exec_mask, 'SCC': ctx.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
for dest, val in parse_pcode(pcode, srcs)[1]:
if dest.startswith('D0'): all_stores.append(ctx.wvgpr_dyn(vdst_reg, lane, _val_to_u32(val), exec_mask, after=srcy1))
return UOp.sink(UOp.group(*all_stores).end(lane), *ctx.inc_pc())
def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
def _compile_mem_op(inst: ir3.DS | ir3.FLAT | ir3.GLOBAL | ir3.SCRATCH | ir4.DS | ir4.VFLAT | ir4.VGLOBAL | ir4.VSCRATCH, ctx: _Ctx) -> UOp:
"""Unified memory operation compiler for DS, FLAT, GLOBAL, SCRATCH."""
exec_mask, op_name = ctx.rsgpr_dyn(_c(EXEC_LO.offset)), _op_name(inst)
pcode = get_pcode(inst.op)
is_lds = isinstance(inst, DS)
is_scratch = isinstance(inst, SCRATCH)
is_lds = isinstance(inst, (ir3.DS, ir4.DS))
is_scratch = isinstance(inst, (ir3.SCRATCH, ir4.VSCRATCH))
mem = ctx.lds if is_lds else ctx.scratch if is_scratch else ctx.vmem
addr_shift = UOp.const(dtypes.uint32 if is_lds else dtypes.uint64, 2)
# Extract register info - all dynamic for deduplication
if is_lds:
addr_reg = ctx.inst_field(DS.addr)
vdata_reg = ctx.inst_field(DS.data0)
vdst_reg = ctx.inst_field(DS.vdst)
offset0 = ctx.inst_field(DS.offset0)
offset1 = ctx.inst_field(DS.offset1)
addr_reg = ctx.inst_field(type(inst).addr)
vdata_reg = ctx.inst_field(type(inst).data0)
vdst_reg = ctx.inst_field(type(inst).vdst)
offset0 = ctx.inst_field(type(inst).offset0)
offset1 = ctx.inst_field(type(inst).offset1)
offset = offset0 # DS uses offset0 as primary offset
saddr_reg = None
else:
elif isinstance(inst, (ir4.VGLOBAL, ir4.VSCRATCH, ir4.VFLAT)): # RDNA4: vaddr, vsrc, ioffset
addr_reg = ctx.inst_field(type(inst).vaddr)
vdata_reg = ctx.inst_field(type(inst).vsrc)
vdst_reg = ctx.inst_field(type(inst).vdst)
offset = ctx.inst_field_signed(type(inst).ioffset)
offset0, offset1 = _c(0), _c(0)
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
else: # RDNA3: addr, data, offset
addr_reg = ctx.inst_field(type(inst).addr)
vdata_reg = ctx.inst_field(type(inst).data)
vdst_reg = ctx.inst_field(type(inst).vdst)
offset = ctx.inst_field_signed(type(inst).offset)
offset0, offset1 = _c(0), _c(0)
# Dynamic saddr - read field, NULL (124) or >= 128 means no saddr
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(inst, 'saddr') else None
saddr_reg = ctx.inst_field(type(inst).saddr) if hasattr(type(inst), 'saddr') else None
# Data width from canonical_op_bits (32/64/96/128), default to 32 for untyped ops
data_bits_mem = inst.canonical_op_bits.get('data', 32)
is_atomic, glc = 'ATOMIC' in op_name, getattr(inst, 'glc', 0)
has_data1 = is_lds and hasattr(inst, 'data1') and inst.data1 is not None
data1_reg = ctx.inst_field(DS.data1) if is_lds else _c(0)
data1_reg = ctx.inst_field(type(inst).data1) if is_lds else _c(0)
# DS_PERMUTE/DS_BPERMUTE: cross-lane VGPR access via pcode
if is_lds and 'PERMUTE' in op_name:
pcode = get_pcode(inst.op)
srcs = {'ADDR': addr_reg, 'DATA0': vdata_reg, 'VDST': vdst_reg, 'OFFSET': offset,
'EXEC': exec_mask.cast(dtypes.uint64), '_vgpr': ctx.vgpr}
_, assigns = parse_pcode(pcode, srcs)
stores = [ctx.vgpr.index(val[0].cast(dtypes.int)).store(val[1].cast(dtypes.uint32)) for dest, val in assigns if dest.startswith('VGPR[')]
return UOp.sink(*stores, *ctx.inc_pc())
def make_addr(lane: UOp) -> UOp:
if is_lds: return ctx.rvgpr_dyn(addr_reg, lane)
@ -912,14 +964,26 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
else:
data = {'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)),
'DATA2': _u64(ctx.rvgpr_dyn(data1_reg, lane), ctx.rvgpr_dyn(data1_reg + _c(1), lane)) if has_data1 else UOp.const(dtypes.uint64, 0)}
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane, **data}
# RDNA3 uses ADDR/OFFSET, RDNA4 uses vgpr_a/offset (lowercase) + CalcDsAddr function
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane,
'vgpr_a': ctx.rvgpr_dyn(addr_reg, lane), 'offset': offset, **data}
active = _lane_active(exec_mask, lane)
# saddr < 124 means valid SGPR pair, otherwise use 0 (NULL means no saddr contribution)
use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False)
saddr_raw = _u64(ctx.rsgpr_dyn(saddr_reg), ctx.rsgpr_dyn(saddr_reg + _c(1))) if saddr_reg is not None else UOp.const(dtypes.uint64, 0)
saddr_base = use_saddr.where(saddr_raw, UOp.const(dtypes.uint64, 0))
# Sign-extend offset to 64-bit for the final address calculation
ioffset64 = offset.cast(dtypes.int64).cast(dtypes.uint64)
# v_addr for CalcGlobalAddr: when saddr valid, use low 32 bits as offset; otherwise full 64-bit address. Include ioffset.
vaddr_full = _u64(ctx.rvgpr_dyn(addr_reg, lane), ctx.rvgpr_dyn(addr_reg + _c(1), lane))
vaddr_lo = ctx.rvgpr_dyn(addr_reg, lane).cast(dtypes.uint64)
vaddr_base = use_saddr.where(vaddr_lo + ioffset64, vaddr_full + ioffset64)
if is_atomic:
return {'ADDR': addr, 'DATA': _u64(ctx.rvgpr_dyn(vdata_reg, lane), ctx.rvgpr_dyn(vdata_reg + _c(1), lane)) if data_bits_mem == 64 else ctx.rvgpr_dyn(vdata_reg, lane),
'_vmem': mem, '_active': active, 'laneId': lane}
'_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
vdata = ctx.rvgpr_dyn(vdata_reg, lane).cast(dtypes.uint64) if 'STORE' in op_name else ctx.rvgpr_dyn(vdst_reg, lane) if 'D16' in op_name else UOp.const(dtypes.uint32, 0)
if 'STORE' in op_name and data_bits_mem >= 64: vdata = vdata | (ctx.rvgpr_dyn(vdata_reg + _c(1), lane).cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane}
srcs = {'ADDR': addr, 'VDATA': vdata, '_vmem': mem, '_active': active, 'laneId': lane, 'v_addr': vaddr_base, 's_saddr': saddr_base}
for i in range(data_bits_mem // 32): srcs[f'VDATA{i}'] = ctx.rvgpr_dyn(vdata_reg + _c(i), lane) if 'STORE' in op_name else UOp.const(dtypes.uint32, 0)
return srcs
@ -970,10 +1034,15 @@ def _compile_mem_op(inst: DS | FLAT | GLOBAL | SCRATCH, ctx: _Ctx) -> UOp:
# Dispatch table: instruction type -> handler function
_INST_HANDLERS: dict[type, Callable[..., UOp]] = {
SOPP: _compile_sopp, SMEM: _compile_smem, SOP1: _compile_sop, SOP2: _compile_sop, SOPC: _compile_sop, SOPK: _compile_sop,
VOP1: _compile_vop12, VOP1_SDST: _compile_vop12, VOP2: _compile_vop12, VOPC: _compile_vopc, VOP3: _compile_vop3, VOP3_SDST: _compile_vop3,
VOP3SD: _compile_vop3sd, VOP3P: _compile_vop3p, VOPD: _compile_vopd,
DS: _compile_mem_op, FLAT: _compile_mem_op, GLOBAL: _compile_mem_op, SCRATCH: _compile_mem_op,
ir3.SOPP: _compile_sopp, ir3.SMEM: _compile_smem, ir3.SOP1: _compile_sop, ir3.SOP2: _compile_sop, ir3.SOPC: _compile_sop, ir3.SOPK: _compile_sop,
ir3.VOP1: _compile_vop12, ir3.VOP1_SDST: _compile_vop12, ir3.VOP2: _compile_vop12, ir3.VOPC: _compile_vopc, ir3.VOP3: _compile_vop3,
ir3.VOP3_SDST: _compile_vop3, ir3.VOP3SD: _compile_vop3sd, ir3.VOP3P: _compile_vop3p, ir3.VOPD: _compile_vopd,
ir3.DS: _compile_mem_op, ir3.FLAT: _compile_mem_op, ir3.GLOBAL: _compile_mem_op, ir3.SCRATCH: _compile_mem_op,
# RDNA4 instruction classes
ir4.SOPP: _compile_sopp, ir4.SMEM: _compile_smem, ir4.SOP1: _compile_sop, ir4.SOP2: _compile_sop, ir4.SOPC: _compile_sop, ir4.SOPK: _compile_sop,
ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOPC: _compile_vopc, ir4.VOP3: _compile_vop3,
ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd,
ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op,
}
# ═══════════════════════════════════════════════════════════════════════════════
@ -983,9 +1052,9 @@ _INST_HANDLERS: dict[type, Callable[..., UOp]] = {
_canonical_runner_cache: list[tuple[int, int, int, object]] = [] # [(base, mask, size, runner), ...]
@functools.cache
def _get_runner(inst_bytes: bytes):
def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
"""Build and compile instruction to CompiledRunner. Cached by instruction bytes, with canonical dedup."""
inst = decode_inst(inst_bytes)
inst = decode_inst(inst_bytes, arch)
inst_size = inst.size()
inst_int = int.from_bytes(inst_bytes[:inst_size], 'little')
@ -1014,15 +1083,15 @@ def _get_runner(inst_bytes: bytes):
return runner, True
@functools.cache
def decode_program(data: bytes) -> dict[int, tuple[str, Callable, list[int], Any]]:
def decode_program(data: bytes, arch: str = "rdna3") -> dict[int, tuple[str, Callable, list[int], Any]]:
"""Decode program to {pc: (name, fxn, globals, runner)}."""
result: dict[int, tuple[str, Callable, list[int], Any]] = {}
i = 0
while i < len(data):
inst = decode_inst(data[i:])
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
inst = decode_inst(data[i:], arch)
if hasattr(inst, 'op') and inst.op in (ir3.SOPPOp.S_CODE_END, ir4.SOPPOp.S_CODE_END): break
try:
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]))
runner, is_new = _get_runner(bytes(data[i:i + inst.size() + 4]), arch)
if DEBUG >= 3:
try: inst_str = repr(inst)
except Exception: inst_str = f"<{type(inst).__name__} at PC={i}>"
@ -1081,9 +1150,9 @@ class WaveState:
# ═══════════════════════════════════════════════════════════════════════════════
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
scratch_size: int = 0) -> int:
scratch_size: int = 0, arch: str = "rdna3") -> int:
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw))
program_raw = decode_program(bytes((ctypes.c_char * lib_sz).from_address(lib).raw), arch)
program = {lib + offset: val for offset, val in program_raw.items()} # Remap to actual addresses
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
total_threads = lx * ly * lz
@ -1111,6 +1180,12 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
(hsa.AMD_COMPUTE_PGM_RSRC_TWO_ENABLE_SGPR_WORKGROUP_ID_Z, gidz)]:
if rsrc2 & enabled: st._write_sgpr(sgpr_idx, gid); sgpr_idx += 1
# RDNA4 uses TTMP registers for workgroup IDs: ttmp[9]=gidx, ttmp[10]=gidy, ttmp[11]=gidz
if arch == "rdna4":
st._write_sgpr(ttmp[9].offset, gidx)
st._write_sgpr(ttmp[10].offset, gidy)
st._write_sgpr(ttmp[11].offset, gidz)
# v0 = packed workitem IDs, scratch stride in secret SGPR
for lane in range(n_lanes):
tid = wave_start + lane
@ -1127,7 +1202,7 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
assert fxn is not None, f"[emu] No fxn for {name} at PC={pc}"
assert 4 not in globals_list or scratch_buf, f"SCRATCH instruction {name} but scratch_size=0"
if DEBUG >= 6:
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw))
inst = decode_inst(bytes((ctypes.c_char * 12).from_address(pc).raw), arch)
print(f"[emu] exec PC={pc:X}: {inst!r}")
fxn(*[c_bufs[g] for g in globals_list])
else: raise RuntimeError("exceeded 1M instructions, likely infinite loop")

View file

@ -200,6 +200,17 @@ def _ff1(val: UOp, bits: int) -> UOp:
result = cond.where(_const(dtypes.int, i), result)
return result
def _sad_u8(a: UOp, b: UOp, acc: UOp, masked: bool = False) -> UOp:
"""Sum of absolute differences of 4 unsigned bytes + accumulator. If masked, skips bytes where a == 0."""
a, b, acc = a.cast(dtypes.uint32), b.cast(dtypes.uint32), acc.cast(dtypes.uint32)
result = acc
for i in range(4):
a_byte = (a >> _u32(i * 8)) & _u32(0xFF)
b_byte = (b >> _u32(i * 8)) & _u32(0xFF)
diff = (a_byte > b_byte).where(a_byte - b_byte, b_byte - a_byte)
result = result + (a_byte.ne(_u32(0)).where(diff, _u32(0)) if masked else diff)
return result
_FUNCS: dict[str, Callable[..., UOp]] = {
'sqrt': lambda a: UOp(Ops.SQRT, a.dtype, (a,)), 'trunc': lambda a: UOp(Ops.TRUNC, a.dtype, (a,)),
'log2': lambda a: UOp(Ops.LOG2, a.dtype, (a,)), 'sin': lambda a: _trig_reduce(a),
@ -254,6 +265,15 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
# Float to int16 conversions
'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16),
'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16),
# SAD (Sum of Absolute Differences) - sum |a_i - b_i| for 4 bytes + accumulator
'v_sad_u8': lambda a, b, c: _sad_u8(a, b, c),
'v_msad_u8': lambda a, b, c: _sad_u8(a, b, c, masked=True),
# System NOPs - these are scheduling hints, no effect on emulation
'MIN': lambda a, b: (a < b).where(a, b),
's_nop': lambda a: _u32(0),
# Address calculation for memory operations
'CalcDsAddr': lambda a, o, *r: a.cast(dtypes.uint32) + o.cast(dtypes.uint32),
'CalcGlobalAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
}
for is_max, name in [(False, 'min'), (True, 'max')]:
for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
@ -435,7 +455,7 @@ class Parser:
self.eat('DOT')
dt_name = self.eat('IDENT').val
return self._handle_mem_load(addr, DTYPES.get(dt_name, dtypes.uint32))
if name == 'VGPR':
if name == 'VGPR' and self.at('LBRACKET'):
self.eat('LBRACKET')
lane = self.parse()
self.eat('RBRACKET')
@ -462,7 +482,21 @@ class Parser:
if self.try_eat('LBRACE'):
idx = self.eat('NUM').val
self.eat('RBRACE')
elem = self.vars.get(f'{name}{idx}', _u32(0))
# Handle VGPR{lane}[reg] - 2D array access after loop unrolling
if name == 'VGPR' and self.at('LBRACKET'):
self.eat('LBRACKET')
reg = self.parse()
self.eat('RBRACKET')
vgpr = self.vars.get('_vgpr')
if vgpr is None: return _u32(0)
return vgpr.index(_to_u32(reg) * _u32(32) + _u32(int(idx)), ptr=True).load()
elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}'))
if elem is None:
# Extract bit idx from base variable (like var[idx])
base = self.vars.get(name)
assert isinstance(base, UOp), f"unknown variable: {name}{idx}"
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
elem = (base.cast(dt) >> _const(dt, int(idx))) & _const(dt, 1)
if self.try_eat('DOT'):
dt_name = self.eat('IDENT').val
return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32))
@ -475,15 +509,13 @@ class Parser:
return self._handle_bracket_rest(first, _u32(0), name)
if name in self.vars:
v = self.vars[name]
return v if isinstance(v, UOp) else _u32(0) if isinstance(v, dict) else _u32(0)
assert isinstance(v, UOp), f"expected UOp for {name}, got {type(v)}"
return v
raise RuntimeError(f"unknown variable: {name}")
raise RuntimeError(f"unexpected token in primary: {self.peek()}")
def _handle_dot(self, base, field: str) -> UOp:
if isinstance(base, str): return _u32(0)
if not isinstance(base, UOp):
if isinstance(base, dict): return base.get(field, _u32(0))
return _u32(0)
assert isinstance(base, UOp), f"expected UOp for dot access, got {type(base)}"
if field == 'u64' and self.at('LBRACKET') and self.peek(1).type == 'IDENT' and self.peek(1).val == 'laneId':
self.eat('LBRACKET')
self.eat_val('laneId', 'IDENT')
@ -541,9 +573,11 @@ class Parser:
var_name = self._find_var_name(base)
if first.op == Ops.CONST:
idx = int(first.arg)
# Check for array element (var@idx)
if var_name and f'{var_name}@{idx}' in self.vars:
v = self.vars[f'{var_name}@{idx}']
return _cast_to(v, dt_suffix) if dt_suffix else v
# Bit extraction
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
base_cast = base.cast(dt) if base.dtype != dt else base
result = ((base_cast >> _const(dt, idx)) & _const(dt, 1))
@ -631,13 +665,14 @@ class Parser:
raise RuntimeError(f"unexpected token after {bits}': {self.peek()}")
def _parse_number(self, num: str) -> UOp:
if num.startswith('0x') or num.startswith('0X'): return _const(dtypes.uint64, int(num.rstrip('ULul'), 16))
suffix, num = _strip_suffix(num)
if '.' in num or suffix in ('F', 'f'):
return _const(dtypes.float32 if suffix in ('F', 'f') else dtypes.float64, float(num))
val = int(num)
if 'ULL' in suffix: return _const(dtypes.uint64, val)
if 'LL' in suffix or 'L' in suffix: return _const(dtypes.uint64, val)
if num.startswith('0x') or num.startswith('0X'):
is_u64 = num.upper().endswith('ULL') or num.upper().endswith('LL') or num.upper().endswith('UL')
return _const(dtypes.uint64 if is_u64 else dtypes.uint32, int(num.rstrip('ULul'), 16))
suffix, num_str = _strip_suffix(num)
if '.' in num_str or suffix in ('F', 'f'):
return _const(dtypes.float32 if suffix in ('F', 'f') else dtypes.float64, float(num_str))
val = int(num_str)
if 'ULL' in suffix or 'LL' in suffix or 'L' in suffix: return _const(dtypes.uint64, val)
if 'U' in suffix: return _const(dtypes.uint32, val)
return _const(dtypes.int if val < 0 else dtypes.uint32, val)
@ -655,7 +690,8 @@ class Parser:
if ';' in body or '\n' in body or 'return' in body.lower():
lines = [l.strip() for l in body.replace(';', '\n').split('\n') if l.strip() and not l.strip().startswith('//')]
_, _, result = parse_block(lines, 0, lv, self.funcs)
return result if result is not None else _u32(0)
assert result is not None, f"lambda {name} must return a value"
return result
return parse_expr(body, lv, self.funcs)
if name in self.funcs:
return self.funcs[name](*args)
@ -663,7 +699,7 @@ class Parser:
def _handle_mem_load(self, addr: UOp, dt) -> UOp:
mem = self.vars.get('_vmem') if '_vmem' in self.vars else self.vars.get('_lds')
if mem is None: return _const(dt, 0)
assert mem is not None, "memory load requires _vmem or _lds"
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
active = self.vars.get('_active')
gate = (active,) if active is not None else ()
@ -725,29 +761,9 @@ def parse_tokens(toks: list[Token], vars: dict[str, VarVal], funcs: dict | None
# Unified block parser for pcode
def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
"""Substitute loop variable and evaluate bracket expressions.
Converts var[loop_var] to var{val} for array element access (like the old regex parser)."""
"""Substitute loop variable with its value."""
toks = tokenize(line)
# First pass: convert var[loop_var] to var{loop_var} to mark for array element assignment
result_toks: list[Token] = []
j = 0
while j < len(toks):
t = toks[j]
# Check for pattern: IDENT[loop_var] where it's not preceded by a dot (not .type[...])
if t.type == 'IDENT' and j+3 < len(toks) and toks[j+1].type == 'LBRACKET' and toks[j+2].type == 'IDENT' and toks[j+2].val == loop_var and toks[j+3].type == 'RBRACKET':
# Check that it's not .type[loop_var]
if not result_toks or result_toks[-1].type != 'DOT':
result_toks.append(t)
result_toks.append(Token('LBRACE', '{'))
result_toks.append(Token('NUM', str(val)))
result_toks.append(Token('RBRACE', '}'))
j += 4
continue
result_toks.append(t)
j += 1
# Second pass: substitute loop variable in remaining positions
subst_parts = [str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in result_toks if t.type != 'EOF']
return ' '.join(subst_parts)
return ' '.join(str(val) if t.type == 'IDENT' and t.val == loop_var else t.val for t in toks if t.type != 'EOF')
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
"""Set bits [offset:offset+width) in old to val, masking and shifting appropriately."""
@ -797,8 +813,9 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
def parse_bound():
if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
expr = p.parse()
return int(expr.arg) if expr.op == Ops.CONST else 0
expr = p.parse().simplify()
assert expr.op == Ops.CONST, f"loop bound must be constant, got {expr}"
return int(expr.arg)
start_val = parse_bound()
p.eat('COLON')
end_val = parse_bound()
@ -844,7 +861,9 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
# declare
if first == 'declare':
if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT': vars[toks[1].val] = _u32(0)
# Initialize scalar declarations (skip arrays and vars already passed as srcs)
if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT':
vars.setdefault(toks[1].val, _u32(0))
i += 1; continue
# lambda definition
@ -902,6 +921,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
j, lane_toks = _match_bracket(toks, 1)
if j < len(toks) and toks[j].type == 'LBRACKET':
j, reg_toks = _match_bracket(toks, j)
if j < len(toks) and toks[j].type == 'DOT': j += 2 # skip .type suffix
if j < len(toks) and toks[j].type == 'EQUALS': j += 1
ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
@ -965,19 +985,32 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
i += 1; continue
# Array element: var{idx} = value
if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACE' and toks[2].type == 'NUM':
var, idx = toks[0].val, int(toks[2].val)
j = 4
while j < len(toks) and toks[j].type != 'EQUALS': j += 1
if j < len(toks):
val = parse_tokens(toks[j+1:], vars, funcs)
existing = block_assigns.get(var, vars.get(var))
if existing is not None and isinstance(existing, UOp):
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
else:
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
i += 1; continue
# Array element: var[idx] = value (static index) or var[expr] = value (dynamic)
if len(toks) >= 4 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
var = toks[0].val
j, idx_toks = _match_bracket(toks, 1)
if j < len(toks) and toks[j].type == 'EQUALS':
# Static index: var[NUM] = value
if len(idx_toks) == 1 and idx_toks[0].type == 'NUM':
idx = int(idx_toks[0].val.rstrip('UuLl'))
val = parse_tokens(toks[j+1:], vars, funcs)
existing = block_assigns.get(var, vars.get(var))
if existing is not None and isinstance(existing, UOp):
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
else:
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
i += 1; continue
# Dynamic index: var[expr] = value where var has @-elements
elems = [(k.split('@')[1], v) for k, v in {**vars, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
if elems:
idx_expr = parse_tokens(idx_toks, vars, funcs)
val = parse_tokens(toks[j+1:], vars, funcs)
for elem_idx_str, old_elem in elems:
elem_idx = int(elem_idx_str)
cond = _to_u32(idx_expr).eq(_u32(elem_idx))
new_val = cond.where(val.cast(old_elem.dtype) if val.dtype != old_elem.dtype else val, old_elem)
block_assigns[f'{var}@{elem_idx}'] = vars[f'{var}@{elem_idx}'] = new_val
i += 1; continue
# Compound assignment: var += or var -=
assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
@ -1024,13 +1057,14 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
def parse_cond(s, kw):
ll = s.lower()
return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
def not_static_false(c): return c.op != Ops.CONST or c.arg is not False
def is_const(c, v): return c.op == Ops.CONST and c.arg is v
cond = parse_cond(line, 'if')
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not_static_false(cond) else []
conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
vars_snap = dict(vars)
static_true = is_const(cond, True) # track if any condition is statically true
i += 1
i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not is_const(cond, False) else None)
if conditions: conditions[0] = (cond, ret if ret is not None else branch)
vars.clear(); vars.update(vars_snap)
while i < len(lines):
@ -1039,12 +1073,16 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
lf = ltoks[0].val.lower()
if lf == 'elsif':
c = parse_cond(lines[i], 'elsif')
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
if not_static_false(c): conditions.append((c, ret if ret is not None else branch))
take = not static_true and not is_const(c, False)
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns if take else None)
if take:
conditions.append((c, ret if ret is not None else branch))
if is_const(c, True): static_true = True
vars.clear(); vars.update(vars_snap)
elif lf == 'else':
i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns)
else_branch = (ret, branch)
i += 1
i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not static_true else None)
if not static_true: else_branch = (ret, branch)
vars.clear(); vars.update(vars_snap)
elif lf == 'endif': i += 1; break
else: break
@ -1056,17 +1094,21 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
result = c.where(rv, result)
return i, block_assigns, result
# Main style: merge variable assignments with WHERE
else_assigns = else_branch[1]
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
for var in all_vars:
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
for cond, ba in reversed(conditions):
if isinstance(ba, dict) and var in ba:
tv = ba[var]
if isinstance(tv, UOp) and isinstance(res, UOp):
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
block_assigns[var] = vars[var] = res
# If statically true, use that branch directly; otherwise merge with WHERE
if static_true:
ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
block_assigns.update(ba); vars.update(ba)
else:
else_assigns = else_branch[1]
all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
for var in all_vars:
res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
for cond, ba in reversed(conditions):
if isinstance(ba, dict) and var in ba:
tv = ba[var]
if isinstance(tv, UOp) and isinstance(res, UOp):
res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
block_assigns[var] = vars[var] = res
continue
# Regular assignment: var = value

View file

@ -2,11 +2,8 @@
from dataclasses import dataclass
from typing import Iterator
from tinygrad.runtime.support.elf import elf_loader
from extra.assembly.amd.sqtt import decode, print_packets, INST, VALUINST, IMMEDIATE, WAVESTART, WAVEEND, InstOp, PacketType, IMMEDIATE_MASK
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd import decode_inst
from extra.assembly.amd.autogen.rdna3.ins import SOPP, s_endpgm
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
@ -16,19 +13,11 @@ class InstructionInfo:
wave: int
inst: Inst
def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionInfo|None]]:
def map_insts(data:bytes, lib:bytes, target:int) -> Iterator[tuple[PacketType, InstructionInfo|None]]:
"""maps SQTT packets to instructions, yields (packet, instruction_info or None)"""
# map pcs to insts
pc_map:dict[int, Inst] = {}
image, sections, _ = elf_loader(lib)
text = next((sh for sh in sections if sh.name == ".text"), None)
assert text is not None, "no .text section found"
text_off, text_size = text.header.sh_addr, text.header.sh_size
offset = text_off
while offset < text_off + text_size:
inst = decode_inst(image[offset:])
pc_map[offset-text_off] = inst
offset += inst.size()
from tinygrad.viz.serve import amd_decode
pc_map = amd_decode(lib, target)
wave_pc:dict[int, int] = {}
# only processing packets on one [CU, SIMD] unit
@ -37,7 +26,7 @@ def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionIn
if not simd_select(p): continue
if isinstance(p, WAVESTART):
assert p.wave not in wave_pc, "only one inflight wave per unit"
wave_pc[p.wave] = 0
wave_pc[p.wave] = next(iter(pc_map))
continue
if isinstance(p, WAVEEND):
pc = wave_pc.pop(p.wave)
@ -80,22 +69,22 @@ def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionIn
# test to compare every packet with the rocprof decoder
def test_rocprof_inst_traces_match(sqtt, prg, target):
from tinygrad.viz.serve import llvm_disasm
from tinygrad.viz.serve import amd_decode
from extra.sqtt.roc import decode as roc_decode, InstExec
disasm = {addr+prg.base:inst_disasm for addr, inst_disasm in llvm_disasm(target, prg.lib).items()}
rctx = roc_decode([sqtt], {prg.name:disasm})
rwaves = rctx.inst_execs[(sqtt.kern, sqtt.exec_tag)]
addr_table = amd_decode(prg.lib, target)
disasm = {addr+prg.base:(inst.disasm(), inst.size()) for addr,inst in addr_table.items()}
rctx = roc_decode([sqtt], {prg.tag:disasm})
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
rwaves_base = next(iter(disasm)) # base program counter
passed_insts = 0
for pkt, info in map_insts(sqtt.blob, prg.lib):
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
if DEBUG >= 2: print_packets([pkt])
if info is None: continue
if DEBUG >= 2: print(f"{' '*29}{info.inst.disasm()}")
rocprof_inst = next(rwaves_iter[info.wave][0])
ref_pc = rocprof_inst.pc-rwaves_base
ref_pc = rocprof_inst.pc-prg.base
# always check pc matches
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm[rocprof_inst.pc][0]} != {info.pc}:{info.inst.disasm()}"
# special handling for s_endpgm, it marks the wave completion.
@ -110,7 +99,8 @@ def test_rocprof_inst_traces_match(sqtt, prg, target):
for k,v in rwaves_iter.items():
assert len(v) == 0, f"incomplete wave {k}"
print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units")
if len(rwaves):
print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units")
if __name__ == "__main__":
import argparse, pickle, pathlib
@ -123,7 +113,7 @@ if __name__ == "__main__":
with open(args.profile, "rb") as f:
data = pickle.load(f)
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
kern_events = {e.name:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
target = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"))).props["gfx_target_version"]
for e in sqtt_events:
if args.kernel is not None and args.kernel != e.kern: continue

View file

@ -43,6 +43,23 @@ VCC = VCC_LO # For VOP3SD sdst field (VCC_LO is exported from dsl)
USE_HW = os.environ.get("USE_HW", "0") == "1"
FLOAT_TOLERANCE = 1e-5
def get_gpu_target() -> tuple[int, int, int]:
"""Get the GPU target as (major, minor, stepping) tuple."""
if not USE_HW: return (0, 0, 0)
from tinygrad.device import Device
return Device["AMD"].target
def skip_unless_gfx(min_major: int, min_minor: int = 0, reason: str = ""):
"""Skip test if GPU target is below the minimum required version."""
import unittest
def decorator(test_func):
if not USE_HW: return test_func
target = get_gpu_target()
if target[0] < min_major or (target[0] == min_major and target[1] < min_minor):
return unittest.skip(reason or f"requires gfx{min_major}{min_minor}0+")(test_func)
return test_func
return decorator
# Output buffer layout: vgpr[16][32], sgpr[16], vcc, scc, exec
N_VGPRS, N_SGPRS, WAVE_SIZE = 16, 16, 32
VGPR_BYTES = N_VGPRS * WAVE_SIZE * 4 # 16 regs * 32 lanes * 4 bytes = 2048
@ -212,8 +229,12 @@ amdhsa.kernels:
return parse_output(bytes(out_buf), n_lanes)
def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgprs: int = N_VGPRS) -> list[str]:
"""Compare two WaveStates and return list of differences."""
def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgprs: int = N_VGPRS, ulp_tolerance: int = 0) -> list[str]:
"""Compare two WaveStates and return list of differences.
Args:
ulp_tolerance: Allow up to this many ULPs difference for float comparisons (0 = exact match required)
"""
import math
diffs = []
for i in range(n_vgprs):
@ -224,6 +245,11 @@ def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgp
emu_f, hw_f = _f32(emu_val), _f32(hw_val)
if math.isnan(emu_f) and math.isnan(hw_f):
continue
# Check ULP difference for floats (only for same-sign values)
if ulp_tolerance > 0 and (emu_val < 0x80000000) == (hw_val < 0x80000000):
ulp_diff = abs(int(emu_val) - int(hw_val))
if ulp_diff <= ulp_tolerance:
continue
diffs.append(f"v[{i}] lane {lane}: emu=0x{emu_val:08x} ({emu_f:.6g}) hw=0x{hw_val:08x} ({hw_f:.6g})")
for i in range(N_SGPRS):
emu_val = emu_st.sgpr[i]
@ -236,16 +262,19 @@ def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgp
diffs.append(f"scc: emu={emu_st.scc} hw={hw_st.scc}")
return diffs
def run_program(instructions: list, n_lanes: int = 1) -> WaveState:
def run_program(instructions: list, n_lanes: int = 1, ulp_tolerance: int = 0) -> WaveState:
"""Run instructions and return WaveState.
If USE_HW=1, runs on both emulator and hardware, compares results, and raises if they differ.
Otherwise, runs only on emulator.
Args:
ulp_tolerance: Allow up to this many ULPs difference for float comparisons (0 = exact match required)
"""
emu_st = run_program_emu(instructions, n_lanes)
if USE_HW:
hw_st = run_program_hw(instructions, n_lanes)
diffs = compare_wave_states(emu_st, hw_st, n_lanes)
diffs = compare_wave_states(emu_st, hw_st, n_lanes, ulp_tolerance=ulp_tolerance)
if diffs:
raise AssertionError(f"Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
return hw_st

View file

@ -719,5 +719,47 @@ class TestAtomicOrdering(unittest.TestCase):
self.assertEqual(st.vgpr[0][4], 150, "Final value should be 150")
class TestDsPermute(unittest.TestCase):
"""Tests for DS_PERMUTE_B32 and DS_BPERMUTE_B32 instructions."""
def test_ds_permute_b32_identity(self):
"""DS_PERMUTE_B32 with identity permutation (lane 0 sends to lane 0)."""
# For simplicity, test with single lane
instructions = [
v_mov_b32_e32(v[0], 0), # addr = 0 (lane 0)
v_mov_b32_e32(v[1], 0xDEADBEEF), # data
ds_permute_b32(v[2], v[0], v[1]),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
# Lane 0 sends to lane 0, so lane 0 gets 0xDEADBEEF
self.assertEqual(st.vgpr[0][2], 0xDEADBEEF)
def test_ds_bpermute_b32_identity(self):
"""DS_BPERMUTE_B32 with identity permutation (each lane reads from itself)."""
instructions = [
v_mov_b32_e32(v[0], 0), # addr = 0 (read from lane 0)
v_mov_b32_e32(v[1], 0xCAFEBABE), # data in lane 0
ds_bpermute_b32(v[2], v[0], v[1]),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
# Lane 0 reads from lane 0's v[1]
self.assertEqual(st.vgpr[0][2], 0xCAFEBABE)
def test_ds_permute_b32_broadcast(self):
"""DS_PERMUTE_B32 broadcast - all lanes send to lane 0."""
# With 4 lanes, all sending to lane 0, highest lane wins
instructions = [
v_mov_b32_e32(v[0], 0), # All lanes send to addr 0 (lane 0)
v_mov_b32_e32(v[1], 0x11111111), # All lanes send same data
ds_permute_b32(v[2], v[0], v[1]),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=4)
# Lane 0 receives data (highest numbered active lane wins)
self.assertEqual(st.vgpr[0][2], 0x11111111)
if __name__ == '__main__':
unittest.main()

View file

@ -62,6 +62,7 @@ class TestBasicScalar(unittest.TestCase):
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0x80000000)
@skip_unless_gfx(11, 5, "SALU FP ops require gfx1150+")
def test_s_fmamk_f32(self):
"""S_FMAMK_F32: D = S0 * literal + S1."""
# 2.0 * 3.0 + 1.0 = 7.0
@ -73,6 +74,7 @@ class TestBasicScalar(unittest.TestCase):
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], f2i(7.0))
@skip_unless_gfx(11, 5, "SALU FP ops require gfx1150+")
def test_s_fmamk_f32_negative(self):
"""S_FMAMK_F32 with negative values."""
# -2.0 * 4.0 + 10.0 = 2.0

View file

@ -1561,5 +1561,24 @@ class TestCvtNormF16(unittest.TestCase):
self.assertAlmostEqual(result, 32768, delta=1)
class TestPermlane64(unittest.TestCase):
"""Tests for V_PERMLANE64_B32 instruction (wave64 cross-half swap)."""
def test_v_permlane64_b32_is_nop_in_wave32(self):
"""V_PERMLANE64_B32 is a NOP in wave32 mode.
Per AMD pcode: "if WAVE32 then s_nop(...) else ... endif"
The emulator runs in wave32 mode, so this instruction should not modify registers.
"""
instructions = [
v_mov_b32_e32(v[0], 0xCAFEBABE), # source
v_mov_b32_e32(v[1], 0x12345678), # dest (should be preserved)
v_permlane64_b32_e32(v[1], v[0]), # NOP in wave32
]
st = run_program(instructions, n_lanes=1)
# Dest register should be unchanged (NOP behavior in wave32)
self.assertEqual(st.vgpr[0][1], 0x12345678)
if __name__ == '__main__':
unittest.main()

View file

@ -3262,5 +3262,81 @@ class TestMinMaxF16Vop3(unittest.TestCase):
self.assertAlmostEqual(result, 4.0, delta=0.01)
class TestSadHi(unittest.TestCase):
"""Tests for V_SAD_HI_U8 instruction."""
def test_v_sad_hi_u8_basic(self):
"""V_SAD_HI_U8: (sad << 16) + acc."""
# |1-5| + |2-6| + |3-7| + |4-8| = 16, << 16 = 0x100000, + 100 = 0x100064
instructions = [
v_mov_b32_e32(v[0], 0x04030201),
v_mov_b32_e32(v[1], 0x08070605),
v_mov_b32_e32(v[2], 100),
v_sad_hi_u8(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][3], (16 << 16) + 100)
def test_v_sad_hi_u8_zero_diff(self):
"""V_SAD_HI_U8: identical inputs gives acc only."""
instructions = [
v_mov_b32_e32(v[0], 0x12345678),
v_mov_b32_e32(v[2], 50),
v_sad_hi_u8(v[3], v[0], v[0], v[2]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][3], 50)
class TestPermlane(unittest.TestCase):
"""Tests for V_PERMLANE16_B32 and V_PERMLANEX16_B32 instructions."""
def test_v_permlane16_b32_identity(self):
"""V_PERMLANE16_B32 with identity permutation (lane i reads from lane i within row)."""
# lanesel encodes 4 bits per position: position i gets lanesel[i*4+3:i*4]
# Identity: position 0->0, 1->1, ..., 15->15
# lanesel = 0xFEDCBA9876543210 (positions 15-0 in nibbles)
instructions = [
v_mov_b32_e32(v[0], 0xDEADBEEF), # source data
s_mov_b32(s[0], 0x76543210), # lanesel low (positions 0-7)
s_mov_b32(s[1], 0xFEDCBA98), # lanesel high (positions 8-15)
v_permlane16_b32(v[1], v[0], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
# Lane 0 reads from lane 0 (position 0 -> lanesel[3:0] = 0)
self.assertEqual(st.vgpr[0][1], 0xDEADBEEF)
def test_v_permlane16_b32_broadcast(self):
"""V_PERMLANE16_B32 broadcast lane 0 to all lanes in row."""
# lanesel = all zeros -> all positions read from lane 0 within row
instructions = [
v_mov_b32_e32(v[0], 0xCAFEBABE), # source data
s_mov_b32(s[0], 0), # lanesel low = 0 (all read lane 0)
s_mov_b32(s[1], 0), # lanesel high = 0
v_permlane16_b32(v[1], v[0], s[0], s[1]),
]
st = run_program(instructions, n_lanes=4)
# All lanes read from lane 0 of their row
for lane in range(4):
self.assertEqual(st.vgpr[lane][1], 0xCAFEBABE)
def test_v_permlanex16_b32_identity(self):
"""V_PERMLANEX16_B32 cross-row read with identity selection."""
# In wave32: row 0 (lanes 0-15) reads from row 1 (lanes 16-31) and vice versa
# With single lane in row 0, it reads from lane 0 of row 1 (lane 16)
# But lane 16 doesn't exist in 1-lane test, so use 32 lanes
instructions = [
v_mov_b32_e32(v[0], 0x11111111), # All lanes have this initially
s_mov_b32(s[0], 0x76543210), # lanesel low
s_mov_b32(s[1], 0xFEDCBA98), # lanesel high
v_permlanex16_b32(v[1], v[0], s[0], s[1]),
]
st = run_program(instructions, n_lanes=32)
# Lane 0 in row 0 reads from lane 0 of row 1 (lane 16)
self.assertEqual(st.vgpr[0][1], 0x11111111)
# Lane 16 in row 1 reads from lane 0 of row 0 (lane 0)
self.assertEqual(st.vgpr[16][1], 0x11111111)
if __name__ == '__main__':
unittest.main()

View file

@ -767,6 +767,7 @@ class TestDot2F32F16(unittest.TestCase):
"""V_DOT2_F32_F16 with negative f16 values."""
# src0 = {hi=-2.0, lo=3.0}, src1 = {hi=1.0, lo=2.0}
# result = 3.0*2.0 + (-2.0)*1.0 + 0 = 6 - 2 = 4.0
# NOTE: Hardware DOT2 may have up to 1 ULP difference due to internal implementation
src0 = (f32_to_f16(-2.0) << 16) | f32_to_f16(3.0)
src1 = (f32_to_f16(1.0) << 16) | f32_to_f16(2.0)
instructions = [
@ -777,7 +778,7 @@ class TestDot2F32F16(unittest.TestCase):
v_mov_b32_e32(v[2], 0),
v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
st = run_program(instructions, n_lanes=1, ulp_tolerance=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 4.0, places=2)

View file

@ -0,0 +1,98 @@
import unittest, ctypes
from extra.assembly.amd.autogen.rdna4 import ins as ir4
from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.emu import WaveState, decode_program
from tinygrad.device import Buffer, BufferSpec
from tinygrad.dtype import dtypes
class TestRDNA4Emu(unittest.TestCase):
def _run(self, insts: list, sgprs: dict[int, int] = None, vgprs: dict[tuple[int, int], int] = None) -> WaveState:
"""Run instructions and return final WaveState."""
# Add S_ENDPGM if not present
if not any(isinstance(i, ir4.SOPP) and i.op == ir4.SOPPOp.S_ENDPGM for i in insts):
insts = list(insts) + [ir4.SOPP(ir4.SOPPOp.S_ENDPGM, simm=0)]
# Assemble and decode
code = b''.join(i.to_bytes() for i in insts)
code_buf = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
code_addr = ctypes.addressof(code_buf)
program_raw = decode_program(code, "rdna4")
program = {code_addr + offset: val for offset, val in program_raw.items()}
# Setup wave state
st = WaveState(n_lanes=1)
st.pc = code_addr
if sgprs:
for idx, val in sgprs.items(): st._write_sgpr(idx, val)
if vgprs:
for (reg, lane), val in vgprs.items(): st._write_vgpr(reg, lane, val)
# Setup vmem buffer with external_ptr=0 (maps to address 0, allows any pointer access)
vmem_buf = Buffer('CPU', 1 << 40, dtypes.uint32, options=BufferSpec(external_ptr=0)).ensure_allocated()
# Execute
c_bufs = [ctypes.c_uint64(st.sgpr_buf._buf.va_addr), ctypes.c_uint64(st.vgpr_buf._buf.va_addr),
ctypes.c_uint64(vmem_buf._buf.va_addr), ctypes.c_uint64(0), ctypes.c_uint64(0)]
for _ in range(100):
if (pc := st.pc) == 0xFFFFFFFFFFFFFFFF or pc not in program: break
_, fxn, globals_list, _ = program[pc]
fxn(*[c_bufs[g] for g in globals_list])
return st
def test_vopd_dual_mov(self):
"""Test VOPD with two V_DUAL_MOV_B32 operations: v[1]=s[1], v[2]=s[2]."""
insts = [ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0])]
st = self._run(insts, sgprs={1: 0x40e00000, 2: 0x41100000}) # 7.0f, 9.0f
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
def test_vopd_dual_mov_after_other_vopd(self):
"""Test VOPD reuse: first VOPD(v[3]=0, v[0]=?), then VOPD(v[1]=s[1], v[2]=s[2])."""
# This matches the BEAM kernel sequence that fails
insts = [
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]), # v[3]=0, v[0]=s[0]
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]), # v[1]=s[1], v[2]=s[2]
]
st = self._run(insts, sgprs={0: 0x40a00000, 1: 0x40e00000, 2: 0x41100000}) # 5.0f, 7.0f, 9.0f
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
def test_vopd_with_s_add_f32_sequence(self):
"""Test full BEAM kernel sequence: s_add_f32 then VOPD."""
# This is the exact sequence from the failing BEAM kernel
insts = [
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[0], ssrc0=s[0], ssrc1=s[8]), # s[0] = s[0] + s[8]
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[1], ssrc0=s[1], ssrc1=s[9]), # s[1] = s[1] + s[9]
ir4.SOP2(ir4.SOP2Op.S_ADD_F32, sdst=s[2], ssrc0=s[2], ssrc1=s[10]), # s[2] = s[2] + s[10]
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[3], vdsty=v[0], srcx0=0, srcy0=s[0], vsrcx1=v[0], vsrcy1=v[0]),
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]),
]
# Input: s[0:2] = [1,2,3], s[8:10] = [4,5,6]
# After s_add_f32: s[0:2] = [5,7,9]
st = self._run(insts, sgprs={0: 0x3f800000, 1: 0x40000000, 2: 0x40400000, # 1.0, 2.0, 3.0
8: 0x40800000, 9: 0x40a00000, 10: 0x40c00000}) # 4.0, 5.0, 6.0
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
def test_s_mov_b32_then_vopd(self):
"""Test s_mov_b32 followed by VOPD - simulates BEAM kernel sequence."""
# Use s_mov_b32 with SGPR source (copy from pre-initialized SGPRs)
# s[10:12] will have values set by test harness, copy to s[0:2], then VOPD to VGPRs
insts = [
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[0], ssrc0=s[10]), # s[0] = s[10]
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[1], ssrc0=s[11]), # s[1] = s[11]
ir4.SOP1(ir4.SOP1Op.S_MOV_B32, sdst=s[2], ssrc0=s[12]), # s[2] = s[12]
ir4.VOPD(ir4.VOPDOp.V_DUAL_MOV_B32, ir4.VOPDOp.V_DUAL_MOV_B32,
vdstx=v[1], vdsty=v[2], srcx0=s[1], srcy0=s[2], vsrcx1=v[0], vsrcy1=v[0]),
]
st = self._run(insts, sgprs={10: 0x40a00000, 11: 0x40e00000, 12: 0x41100000}) # 5.0, 7.0, 9.0
self.assertEqual(st._read_vgpr(1, 0), 0x40e00000) # v[1] = 7.0
self.assertEqual(st._read_vgpr(2, 0), 0x41100000) # v[2] = 9.0
if __name__ == '__main__':
unittest.main()

View file

@ -203,12 +203,12 @@ class SQTTExamplesTestBase(unittest.TestCase):
class TestSQTTExamplesRDNA3(SQTTExamplesTestBase):
target = "gfx1100"
expected = {
"profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929],
"profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923],
"profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912],
"profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912],
"profile_empty_run_0": [1844, 1885, 1905, 1956, 1983, 1889],
"profile_empty_run_1": [1780, 1885, 1905, 1956, 1983, 1889],
"profile_gemm_run_0": [2656, 2025, 2045, 2096, 2123, 2029, 3183, 2019, 2039, 2090, 2117, 2023, 19119, 2013, 2033, 2084, 2111, 2017],
"profile_gemm_run_1": [2662, 2025, 2045, 2096, 2123, 2029, 3179, 2019, 2039, 2090, 2117, 2023, 19113, 2071, 2091, 2142, 2169, 2075],
"profile_plus_run_0": [1886, 2013, 2033, 2084, 2111, 2017],
"profile_plus_run_1": [1988, 2071, 2091, 2142, 2169, 2075],
}
class TestSQTTExamplesRDNA4(SQTTExamplesTestBase): target = "gfx1200"

View file

@ -471,7 +471,7 @@ THREADS = 128
def test_matmul():
dev = Device[Device.DEFAULT]
print(f"Device arch: {dev.arch}")
print(f"Device arch: {dev.renderer.arch}")
if getenv("STOCK", 0):
# Load the stock kernel from amd_seb/kernel8_batched_gmem.s
@ -479,7 +479,7 @@ def test_matmul():
asm = stock_path.read_text()
print(f"Loaded stock kernel from {stock_path}")
else:
asm = build_kernel(dev.arch)
asm = build_kernel(dev.renderer.arch)
binary = dev.compiler.compile(asm)
print(f"Compiled! Binary size: {len(binary)} bytes")

11517
extra/gemm/asm/cdna/asm.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,95 @@
import atexit, functools
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.renderer import Estimates
from tinygrad.helpers import getenv, all_same, dedup
from extra.gemm.asm.cdna.asm import build_kernel, GEMM_ARGS
# ** CDNA4 assembly gemm
WORKGROUP_SIZE = 256
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
batch, M, K = A.shape
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
assert K == K2
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(wg, "gidx0")
k = build_kernel(batch, M, N, K, A.dtype.base)
sink = UOp.sink(C.base, A.base, B.base, lidx, gidx,
arg=KernelInfo(name=k.name, estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
binary = HIPCompiler(arch).compile(k.to_asm())
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=k.to_text()), UOp(Ops.BINARY, arg=binary)))
counters = {"used":0, "todos":[]}
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
atexit.register(lambda: print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used'))
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}")
# only sharding on the batch is tested, others might work too
if isinstance(a.device, tuple) and not (a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None):
return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
N = b.shape[1]
if isinstance(a.device, tuple): batch //= len(a.device)
if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
if (key:=(M, N, K)) not in GEMM_ARGS: return todo(f"GEMM shape not supported {key}")
return True
# ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
# note: this can be removed after we have GEMM on mixins
def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
M, K = A.shape[0]*A.shape[1], A.shape[2]
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
assert K == K2
m = UOp.range(M, 1, AxisType.LOOP)
n = UOp.range(N, 2, AxisType.LOOP)
k = UOp.range(K, 0, AxisType.REDUCE)
mul = (A.index((m*UOp.const(dtypes.index, K)+k))*B.index((k*UOp.const(dtypes.index, N)+n))).cast(dtypes.float32)
red = mul.reduce(k, arg=Ops.ADD, dtype=dtypes.float32).cast(C.dtype.base)
store = C.index((m*UOp.const(dtypes.index, N)+n), ptr=True).store(red).end(m, n)
return store.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp):
out, a, b = kernel.src
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
grad_a = (g_t @ b_t.T).uop
a_T = a_t.transpose(-2, -1)
a_T = a_T.reshape(*a_T.shape[:-1], 1, a_T.shape[-1])
g_r = g_t.reshape(*g_t.shape[:-2], 1, *g_t.shape[-2:]).transpose(-1, -2)
grad_b = (a_T * g_r).sum((-1, 0)).uop
return (None, grad_a, grad_b)
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
counters["used"] += 1
squeeze = a.ndim == 2
if squeeze: a = a.unsqueeze(0)
batch, M, K = a.shape
N = b.shape[1]
is_multi = isinstance(a.device, tuple)
if is_multi:
out = Tensor(Tensor.empty(batch//len(a.device), M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
else:
out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device)
dname = a.device[0] if is_multi else a.device
arch = getattr(Device[dname].renderer, "arch", None)
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
numWG = GEMM_ARGS[(M, N, K)][0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0]
else:
out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
return out.squeeze(0) if squeeze else out

File diff suppressed because it is too large Load diff

View file

@ -1,78 +0,0 @@
.text
.section .text.
.global gemm
.p2align 8
.type gemm,@function
gemm:
INSTRUCTIONS
.section .rodata,"a",@progbits
.p2align 6, 0x0
.amdhsa_kernel gemm
# basic memory requirements
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 28
# register usage (RSRC1)
.amdhsa_next_free_vgpr 504
.amdhsa_next_free_sgpr 96
# workgroup / workitem IDs (RSRC2)
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
# user SGPRs, we only specify the kernel args ptr in s[0:1]
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_count 2
.amdhsa_user_sgpr_kernarg_preload_length 0
.amdhsa_user_sgpr_kernarg_preload_offset 0
# gfx90a / gfx940 specifics (RSRC3)
.amdhsa_accum_offset 248
.amdhsa_uses_dynamic_stack 0
.amdhsa_tg_split 0
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.kernels:
- .name: gemm
.symbol: gemm.kd
.args:
- .name: C
.address_space: global
.offset: 0
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: B
.address_space: global
.offset: 8
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: A
.address_space: global
.offset: 16
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: sz
.offset: 24
.size: 4
.value_kind: by_value
.value_type: u32
.group_segment_fixed_size: 133120
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.kernarg_segment_size: 28
.max_flat_workgroup_size: 256
.sgpr_count: 88
.sgpr_spill_count: 0
.vgpr_count: 248
.vgpr_spill_count: 0
.wavefront_size: 64
amdhsa.version:
- 1
- 0
...
.end_amdgpu_metadata

View file

@ -1,73 +0,0 @@
# Run assembly on the AMD runtime and check correctness
# VIZ=2 to profile
import pathlib
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.engine.realize import Estimates
from tinygrad.helpers import getenv
fp = pathlib.Path(__file__).parent/"gemm.s"
N = getenv("N", 8192)
THREADS_PER_WG = 256
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
assert N % THREADS_PER_WG == 0, "N must be divisible by THREADS_PER_WG"
# ** generate inputs on CPU
scale = 10.0
import torch
torch.manual_seed(0)
A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
Bt = B.t().contiguous() # transpose B for the asm gemm
C_torch = A@B
# ** copy buffers to AMD
# input creation and validation run on the copy engine for simpler tracing
def from_torch(t:torch.Tensor) -> Tensor:
return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize()
C_tiny = from_torch(A) @ from_torch(B)
C_asm = Tensor.empty_like(C_tiny)
# ** assembly custom kernel
def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
lidx = UOp.special(THREADS_PER_WG, "lidx0")
gidx = UOp.special(NUM_WG, "gidx0")
src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
sz = UOp.variable("SZ", 256, 8192)
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
# ** run gemms
sched = Tensor.schedule(C_tiny, C_asm)
eis = [si.lower() for si in sched]
with Context(DEBUG=2):
for ei in eis:
et = ei.run({"SZ":N}, wait=True)
print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
# ** correctness
import ctypes
def torch_bf16(t:Tensor) -> torch.tensor:
asm_out = t.to("cpu").realize().uop.buffer._buf
buf = (ctypes.c_uint16*C_asm.uop.size).from_address(asm_out.va_addr)
return torch.frombuffer(buf, dtype=torch.bfloat16, count=C_asm.uop.size).reshape(C_asm.shape)
assert torch.allclose(torch_bf16(C_asm), C_torch, rtol=1e-2, atol=1e-3)
assert torch.allclose(torch_bf16(C_tiny), C_torch, rtol=1e-2, atol=1e-3)

View file

@ -0,0 +1,46 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.helpers import getenv
from extra.gemm.asm.cdna.gemm import asm_gemm
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, multi=False) -> None:
Tensor.manual_seed(0)
a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype)
with Context(DEBUG=0):
Tensor.realize(a_rand, b_rand)
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8)) if multi else None
a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None)
tst = asm_gemm(a, b)
tst.sum().backward()
Tensor.realize(tst, a.grad, b.grad)
a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None)
with Context(ASM_GEMM=0): ref = a_ref @ b_ref
ref.sum().backward()
Tensor.realize(ref, a_ref.grad, b_ref.grad)
with Context(DEBUG=0):
assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch"
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
class TestGemm(unittest.TestCase):
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, multi=True)
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, multi=True)
def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, multi=True)
def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, multi=True)
def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, multi=True)
def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, multi=True)
def test_gemm_unsupported(self):
with self.assertRaisesRegex(AssertionError, "shape not supported"):
verify_asm_gemm(8, 8192, 1024, 4096, multi=True)
if __name__ == "__main__":
unittest.main()

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python3
import sys, os, zlib, struct, hashlib
from tinygrad.helpers import DEBUG, getenv, fetch
import os, zlib, struct, hashlib
from tinygrad.helpers import getenv
from tinygrad.runtime.support.usb import USB3
SUPPORTED_CONTROLLERS = [
@ -50,7 +50,7 @@ patched_fw = patch(file_path, file_hash, patches)
dev = None
for vendor, device in SUPPORTED_CONTROLLERS:
try:
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04)
dev = USB3(vendor, device, 0x81, 0x83, 0x02, 0x04, use_bot=True)
break
except RuntimeError: pass
if dev is None:

View file

@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.am.am as am
import tinygrad.runtime.autogen.amdgpu_drm as amdgpu_drm
from tinygrad.helpers import from_mv
from test.mockgpu.driver import VirtDriver, VirtFileDesc, TextFileDesc, DirFileDesc, VirtFile
from test.mockgpu.amd.amdgpu import AMDGPU, gpu_props
from test.mockgpu.amd.amdgpu import AMDGPU, gpu_props, GFX_TARGET_VERSION, MOCKGPU_ARCH
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
@ -90,35 +90,30 @@ class AMDDriver(VirtDriver):
def _prepare_gpu(self, gpu_id):
self.doorbells[gpu_id] = memoryview(bytearray(0x2000))
self.gpus[gpu_id] = AMDGPU(gpu_id)
# IP versions: rdna3 = GC 11.0.0, NBIF 4.3.0; rdna4 = GC 12.0.0, NBIF 6.3.1
ip_versions = {"rdna3": {"gc": (11, 0, 0), "sdma": (6, 0, 0), "nbif": (4, 3, 0)},
"rdna4": {"gc": (12, 0, 0), "sdma": (6, 0, 0), "nbif": (6, 3, 1)}}[MOCKGPU_ARCH]
def ip_discovery_files(hwid, ver, base_addr):
p = f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}/0'
return [VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{hwid}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'{p}/major', functools.partial(TextFileDesc, text=str(ver[0]))),
VirtFile(f'{p}/minor', functools.partial(TextFileDesc, text=str(ver[1]))),
VirtFile(f'{p}/revision', functools.partial(TextFileDesc, text=str(ver[2]))),
VirtFile(f'{p}/base_addr', functools.partial(TextFileDesc, text=base_addr))]
self.tracked_files += [
VirtFile('/sys/module/amdgpu', functools.partial(TextFileDesc, text="1")),
VirtFile('/sys/module/amdgpu/parameters/ppfeaturemask', functools.partial(TextFileDesc, text="0xffff3fff")),
VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}', functools.partial(DirFileDesc, child_names=['gpu_id', 'properties'])),
VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}/gpu_id', functools.partial(TextFileDesc, text=f"{gpu_id}")),
VirtFile(f'/sys/devices/virtual/kfd/kfd/topology/nodes/{gpu_id}/properties',
functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id))),
functools.partial(TextFileDesc, text=gpu_props.format(drm_render_minor=gpu_id, gfx_target_version=GFX_TARGET_VERSION))),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/power_dpm_force_performance_level',
functools.partial(TextFileDesc, text='profile_standard\n')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0',
functools.partial(DirFileDesc, child_names=[str(am.GC_HWID), str(am.SDMA0_HWID), str(am.NBIF_HWID)])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/major', functools.partial(TextFileDesc, text='11')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.GC_HWID}/0/base_addr',
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/major', functools.partial(TextFileDesc, text='6')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/minor', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.SDMA0_HWID}/0/base_addr',
functools.partial(TextFileDesc, text='0x00001260\n0x0000A000\n0x0001C000\n0x02402C00')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}', functools.partial(DirFileDesc, child_names=['0'])),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/major', functools.partial(TextFileDesc, text='4')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/minor', functools.partial(TextFileDesc, text='3')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/revision', functools.partial(TextFileDesc, text='0')),
VirtFile(f'/sys/class/drm/renderD{gpu_id}/device/ip_discovery/die/0/{am.NBIF_HWID}/0/base_addr',
functools.partial(TextFileDesc, text='0x00000000\n0x00000014\n0x00000D20\n0x00010400\n0x0241B000\n0x04040000')),
*ip_discovery_files(am.GC_HWID, ip_versions["gc"], '0x00001260\n0x0000A000\n0x0001C000\n0x02402C00'),
*ip_discovery_files(am.SDMA0_HWID, ip_versions["sdma"], '0x00001260\n0x0000A000\n0x0001C000\n0x02402C00'),
*ip_discovery_files(am.NBIF_HWID, ip_versions["nbif"], '0x00000000\n0x00000014\n0x00000D20\n0x00010400\n0x0241B000\n0x04040000'),
VirtFile(f'/dev/dri/renderD{gpu_id}', functools.partial(DRMFileDesc, driver=self, gpu=f"{self.gpus[gpu_id]}")),
]

View file

@ -1,8 +1,11 @@
import ctypes, time
from test.mockgpu.gpu import VirtGPU
from test.mockgpu.helpers import _try_dlopen_remu
from tinygrad.helpers import getbits, to_mv
from tinygrad.helpers import getbits, to_mv, getenv
from tinygrad.runtime.support import c
MOCKGPU_ARCH = getenv("MOCKGPU_ARCH", "rdna3")
GFX_TARGET_VERSION = {"rdna3": 110000, "rdna4": 120000}[MOCKGPU_ARCH]
import tinygrad.runtime.autogen.amd_gpu as amd_gpu, tinygrad.runtime.autogen.am.pm4_nv as pm4
SDMA_MAX_COPY_SIZE = 0x400000
@ -194,10 +197,11 @@ class PM4Executor(AMDQueue):
scratch_size = wavesize * 4 # This gives the scratch size per thread (lane)
assert prg_sz > 0, "Invalid prg ptr (not found in mapped ranges)"
# Pass valid memory ranges, rsrc2, and scratch_size to Python emulator
# Pass valid memory ranges, rsrc2, scratch_size and arch to Python emulator
if hasattr(remu, 'valid_mem_ranges'): remu.valid_mem_ranges = self.gpu.mapped_ranges
if hasattr(remu, 'rsrc2'): remu.rsrc2 = rsrc2
if hasattr(remu, 'scratch_size'): remu.scratch_size = scratch_size
if hasattr(remu, 'arch'): remu.arch = self.gpu.arch
err = remu.run_asm(prg_addr, prg_sz, *gl, *lc, args_addr)
if err != 0: raise RuntimeError("remu does not support the new instruction introduced in this kernel")
@ -314,6 +318,7 @@ class AMDGPU(VirtGPU):
self.regs = AMDGPURegisters()
self.mapped_ranges = set()
self.queues = []
self.arch = MOCKGPU_ARCH
def map_range(self, vaddr, size): self.mapped_ranges.add((vaddr, size))
def unmap_range(self, vaddr, size): self.mapped_ranges.remove((vaddr, size))
@ -342,7 +347,7 @@ simd_arrays_per_engine 2
cu_per_simd_array 8
simd_per_cu 2
max_slots_scratch_cu 32
gfx_target_version 110000
gfx_target_version {gfx_target_version}
vendor_id 4098
device_id 29772
location_id 34304

View file

@ -16,14 +16,15 @@ def _try_dlopen_gpuocelot():
return None
class PythonRemu:
"""Python RDNA3 emulator wrapper that matches the libremu.so interface."""
"""Python RDNA3/RDNA4 emulator wrapper that matches the libremu.so interface."""
valid_mem_ranges: set[tuple[int, int]] = set()
rsrc2: int = 0x19c # Default: USER_SGPR_COUNT=14, enable X and Y workgroup IDs
scratch_size: int = 0 # private_segment_fixed_size from kernel descriptor
arch: str = "rdna3" # Architecture: rdna3 or rdna4
def run_asm(self, lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int) -> int:
from extra.assembly.amd.emu import run_asm
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size)
return run_asm(lib, lib_sz, gx, gy, gz, lx, ly, lz, args_ptr, self.rsrc2, self.scratch_size, self.arch)
def _try_dlopen_remu():
# Use Python emulator only if PYTHON_REMU=1

View file

@ -195,8 +195,10 @@ class TestAssignIssues(unittest.TestCase):
t.shrink(((1, 3), (1, 3))).assign(Tensor.ones(2, 2))
np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
@unittest.expectedFailure
def test_assign_broadcast(self):
# broadcasting during assign should behave like PyTorch
# NOTE: we don't want implicit dtype casting (int64 -> float32 loses precision), so this fails
torch_tensor = torch.zeros(3, 5)
torch_tensor[:] = torch.arange(5)
t = Tensor.zeros(3, 5)

View file

@ -256,6 +256,18 @@ class TestMultiTensor(unittest.TestCase):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_multiple_to_single_device_naive(self):
with Context(RING=0):
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
def test_multiple_to_single_device_ring(self):
with Context(RING=2):
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
def test_allreduce_all2all(self):
with Context(ALL2ALL=2):
a,b = _test_allreduce(Tensor.rand(256, 256))
@ -1273,19 +1285,20 @@ class TestMultiRamUsage(unittest.TestCase):
_ = Tensor.zeros(self.N, self.N).contiguous().shard(devices_2, axis=0).contiguous().realize()
self.assertUsed(self.N*self.N*4) # sharding should not increase total ram usage
def _test_matmul_half(self, devs):
def _test_matmul_half(self, dev_count:int):
N = 32
total_mem = {}
devs = tuple(f"NULL:{i}" for i in range(dev_count))
for dtype in {dtypes.float, dtypes.half}:
GlobalCounters.reset()
a = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=0)
b = Tensor.empty((N, N), dtype=dtype).shard(devs, axis=None)
a = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=0)
b = Tensor.empty((N, N), dtype=dtype, device=devs[0]).shard(devs, axis=None)
(a @ b).realize()
total_mem[dtype] = GlobalCounters.global_mem
self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
def test_matmul_half(self): self._test_matmul_half(devices_2)
def test_matmul_half_alt(self): self._test_matmul_half(devices_4)
def test_matmul_half(self): self._test_matmul_half(dev_count=2)
def test_matmul_half_alt(self): self._test_matmul_half(dev_count=4)
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiFromUnrenderable(unittest.TestCase):

View file

@ -2,7 +2,7 @@
# allow define from star imports
import unittest
import textwrap, functools
import functools
from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, KernelInfo
@ -11,69 +11,27 @@ from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.viz.serve import amdgpu_cfg
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.dsl import s
template = """.text
.globl fn_name
.p2align 8
.type fn_name,@function
fn_name:
INSTRUCTION
.rodata
.p2align 6
.amdhsa_kernel fn_name
.amdhsa_kernarg_size 8
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_wavefront_size32 1
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: fn_name
.symbol: fn_name.kd
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 8
.max_flat_workgroup_size: 1024
.kernarg_segment_align: 8
.kernarg_segment_size: 8
.args:
- .address_space: global
.name: a
.offset: 0
.size: 8
.type_name: 'float*'
.value_kind: global_buffer
...
.end_amdgpu_metadata
"""
# TODO: this belongs to the dsl infrastructure
from extra.gemm.amd_asm_matmul import Kernel
# TODO: shouldn't need compiler once we can output ELF
# outputs a text disassembly for humans and a machine readable binary
def assemble(name:str, insts:list[str|Inst], compiler:Compiler) -> tuple[str, bytes]:
asm = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(asm))
def assemble(name:str, k:Kernel, compiler:Compiler) -> tuple[str, bytes]:
src = k.to_asm()
return (src, compiler.compile(src))
def asm_kernel(out:UOp, insts:list[str|Inst], name:str, device:str, compiler:Compiler, n_threads:int=1, n_workgroups:int=1) -> UOp:
def asm_kernel(out:UOp, k:Kernel, name:str, device:str, compiler:Compiler, n_threads:int=1, n_workgroups:int=1) -> UOp:
lidx = UOp.special(n_threads, "lidx0")
gidx = UOp.special(n_workgroups, "gidx0")
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
src, lib = assemble(name, insts, compiler)
src, lib = assemble(name, k, compiler)
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
def run_asm(name:str, insts:list) -> None:
fxn = functools.partial(asm_kernel, insts=insts, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
def run_asm(name:str, k:Kernel) -> None:
fxn = functools.partial(asm_kernel, k=k, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
out.realize()
@ -85,32 +43,32 @@ class TestCfg(unittest.TestCase):
self.skipTest(f"tests written for RDNA, got arch {arch}")
def test_simple(self):
run_asm("simple", [
"entry:",
"s_branch bb1",
"bb1:",
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_branch(), target="bb1")
k.label("bb1")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("simple", k)
def test_diamond(self):
run_asm("diamond", insts:=[
"entry:",
s_mov_b32(s[0], 0),
s_mov_b32(s[1], 0),
s_cmp_eq_u64(s[0:1], 0),
"s_cbranch_scc1 if",
"s_branch else",
"if:",
s_nop(1),
"s_branch end",
"else:",
s_nop(0),
"end:",
s_endpgm(),
s_code_end(),
])
_, lib = assemble("diamond", insts, HIPCompiler(Device[Device.DEFAULT].arch))
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_mov_b32(s[0], 0))
k.emit(s_mov_b32(s[1], 0))
k.emit(s_cmp_eq_u64(s[0:1], 0))
k.emit(s_cbranch_scc1(), target="if")
k.emit(s_branch(), target="else")
k.label("if")
k.emit(s_nop(1))
k.emit(s_branch(), target="end")
k.label("else")
k.emit(s_nop(0))
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("diamond", k)
_, lib = assemble("diamond", k, HIPCompiler(Device[Device.DEFAULT].arch))
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
self.assertEqual(len(cfg["blocks"]), 5)
edge_count = sum(len(v) for v in cfg["paths"].values())
@ -124,133 +82,138 @@ class TestCfg(unittest.TestCase):
self.assertEqual(insts, ['s_mov_b32', 's_cmp_eq_u64'])
def test_loop(self):
run_asm("simple_loop", [
"entry:",
s_mov_b32(s[1], 4),
"loop:",
s_add_u32(s[1], s[1], -1),
s_cmp_eq_i32(s[1], 0),
"s_cbranch_scc0 loop",
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_mov_b32(s[1], 4))
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("simple_loop", k)
def test_loop_branch(self):
run_asm("loop_if", [
"entry:",
s_mov_b32(s[1], 4),
"loop:",
s_add_u32(s[1], s[1], -1),
s_cmp_eq_i32(s[1], 2),
"s_cbranch_scc1 cond",
"s_branch cont",
"cond:",
s_add_u32(s[1], s[1], -2),
"cont:",
s_cmp_eq_i32(s[1], 0),
"s_cbranch_scc0 loop",
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_mov_b32(s[1], 4))
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 2))
k.emit(s_cbranch_scc1(), target="cond")
k.emit(s_branch(), target="cont")
k.label("cond")
k.emit(s_add_u32(s[1], s[1], -2))
k.label("cont")
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("loop_if", k)
def test_loop_break(self):
run_asm("loop_break", [
"entry:",
s_mov_b32(s[1], 8),
"loop:",
s_add_u32(s[1], s[1], -1),
s_cmp_eq_i32(s[1], 5),
"s_cbranch_scc1 break",
s_cmp_eq_i32(s[1], 0),
"s_cbranch_scc0 loop",
"break:",
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_mov_b32(s[1], 8))
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 5))
k.emit(s_cbranch_scc1(), target="break")
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.label("break")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("loop_break", k)
def test_switch(self):
run_asm("switch_case", [
"entry:",
s_cmp_eq_i32(s[0], 0),
"s_cbranch_scc1 case0",
s_cmp_eq_i32(s[0], 1),
"s_cbranch_scc1 case1",
"s_branch case2",
"case0:",
s_nop(0),
"s_branch join",
"case1:",
s_nop(1),
"s_branch join",
"case2:",
s_nop(2),
"s_branch join",
"join:",
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_cmp_eq_i32(s[0], 0))
k.emit(s_cbranch_scc1(), target="case0")
k.emit(s_cmp_eq_i32(s[0], 1))
k.emit(s_cbranch_scc1(), target="case1")
k.emit(s_branch(), target="case2")
k.label("case0")
k.emit(s_nop(0))
k.emit(s_branch(), target="join")
k.label("case1")
k.emit(s_nop(1))
k.emit(s_branch(), target="join")
k.label("case2")
k.emit(s_nop(2))
k.emit(s_branch(), target="join")
k.label("join")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("switch_case", k)
def test_ping_pong(self):
run_asm("ping_pong", [
"entry:",
s_cmp_eq_i32(s[0], 0),
"s_cbranch_scc1 ping",
"s_branch pong",
"ping:",
s_cmp_eq_i32(s[1], 0),
"s_cbranch_scc1 pong",
"s_branch end",
"pong:",
s_cmp_eq_i32(s[2], 0),
"s_cbranch_scc1 ping",
"end:",
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_cmp_eq_i32(s[0], 0))
k.emit(s_cbranch_scc1(), target="ping")
k.emit(s_branch(), target="pong")
k.label("ping")
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc1(), target="pong")
k.emit(s_branch(), target="end")
k.label("pong")
k.emit(s_cmp_eq_i32(s[2], 0))
k.emit(s_cbranch_scc1(), target="ping")
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("ping_pong", k)
def test_colored_blocks(self):
N = 10
asm = ["entry:", "s_branch init0"]
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_branch(), target="init0")
for i in range(N):
asm += [f"init{i}:", s_mov_b32(s[1], i + 1), f"s_branch {(loop:=f'loop{i}')}"]
asm += [
f"{loop}:",
s_nop(i & 7),
s_add_u32(s[1], s[1], -1),
s_cmp_eq_i32(s[1], 0),
f"s_cbranch_scc0 {loop}",
f"s_branch {'init' + str(i+1) if i + 1 < N else 'end'}",
]
asm += ["end:", s_endpgm(), s_code_end()]
run_asm("test_colored_blocks", asm)
loop = f"loop{i}"
k.label(f"init{i}")
k.emit(s_mov_b32(s[1], i + 1))
k.emit(s_branch(), target=loop)
k.label(loop)
k.emit(s_nop(i & 7))
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target=loop)
k.emit(s_branch(), target=f"init{i+1}" if i + 1 < N else "end")
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("test_colored_blocks", k)
def test_jump_back_to_end(self):
run_asm("jump_back_to_end", [
"entry:",
s_mov_b32(s[1], 2),
"s_cbranch_execz loop",
"end:",
s_endpgm(),
"loop:",
s_add_u32(s[1], s[1], -1),
s_cmp_eq_i32(s[1], 0),
"s_branch end",
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_mov_b32(s[1], 2))
k.emit(s_cbranch_execz(), target="loop")
k.label("end")
k.emit(s_endpgm())
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_branch(), target="end")
k.emit(s_code_end())
run_asm("jump_back_to_end", k)
def test_hit_count(self):
run_asm("test_hit_count", [
"entry:",
s_mov_b32(s[1], 1),
"s_branch alt",
"continue:",
s_mov_b32(s[2], 2),
s_add_u32(s[1], s[1], s[2]),
"alt:",
s_add_u32(s[1], s[1], -1),
s_endpgm(),
s_code_end(),
])
k = Kernel(arch=Device["AMD"].arch)
k.label("entry")
k.emit(s_mov_b32(s[1], 1))
k.emit(s_branch(), target="alt")
k.label("continue")
k.emit(s_mov_b32(s[2], 2))
k.emit(s_add_u32(s[1], s[1], s[2]))
k.label("alt")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("test_hit_count", k)
if __name__ == "__main__":
unittest.main()

View file

@ -461,6 +461,27 @@ class TestAssign(unittest.TestCase):
a[2:5] = [1, 2, 3]
np.testing.assert_allclose(a.numpy(), [0., 0., 1., 2., 3., 0., 0., 0.])
def test_assign_bitcast(self):
# assign to a bitcast view should modify the underlying buffer
a = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
# IEEE 754: 1.0f = 0x3f800000, 2.0f = 0x40000000, 3.0f = 0x40400000, 4.0f = 0x40800000
a.bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.uint32)).realize()
np.testing.assert_allclose(a.numpy(), [4.0, 3.0, 2.0, 1.0])
# double bitcast
b = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
b.bitcast(dtypes.uint32).bitcast(dtypes.int32).assign(Tensor([0x40800000, 0x40400000, 0x40000000, 0x3f800000], dtype=dtypes.int32)).realize()
np.testing.assert_allclose(b.numpy(), [4.0, 3.0, 2.0, 1.0])
# shrink then bitcast
c = Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32).realize()
c[0:2].bitcast(dtypes.uint32).assign(Tensor([0x40800000, 0x40400000], dtype=dtypes.uint32)).realize()
np.testing.assert_allclose(c.numpy(), [4.0, 3.0, 3.0, 4.0])
def test_assign_bitcast_different_size(self):
# different-size bitcast creates a new tensor, not a view, so assign doesn't modify the original
a = Tensor([0]*8, dtype=dtypes.uint8).realize()
a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
np.testing.assert_equal(a.numpy(), [0]*8)
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@ -472,6 +493,38 @@ class TestAssign(unittest.TestCase):
assert oba1 is None and oba2 is None
np.testing.assert_allclose(a.numpy(), np.arange(N*N,dtype=np.int32).reshape((N,N)))
def test_assign_dtype_mismatch(self):
# assign should not implicitly cast dtypes - this can lose precision
a = Tensor.zeros(4, dtype=dtypes.float32).contiguous().realize()
b = Tensor([1, 2, 3, 4], dtype=dtypes.int32)
with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
a.assign(b)
def test_assign_dtype_mismatch_int64_to_float32(self):
# int64 -> float32 loses precision for large values, should not be implicit
a = Tensor.zeros(1, dtype=dtypes.float32).contiguous().realize()
b = Tensor([16777217], dtype=dtypes.int64) # 2^24 + 1, not exactly representable in float32
with self.assertRaisesRegex(RuntimeError, "assign dtype mismatch"):
a.assign(b)
def test_assign_shape_broadcast(self):
# shape broadcasting should work when dtypes match
a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
b = Tensor([1., 2., 3., 4., 5.], dtype=dtypes.float32)
a.assign(b)
a.realize()
expected = np.array([[1., 2., 3., 4., 5.]] * 3)
np.testing.assert_allclose(a.numpy(), expected)
def test_assign_shape_broadcast_2d(self):
# broadcast (1, 5) to (3, 5)
a = Tensor.zeros(3, 5, dtype=dtypes.float32).contiguous().realize()
b = Tensor([[1., 2., 3., 4., 5.]], dtype=dtypes.float32)
a.assign(b)
a.realize()
expected = np.array([[1., 2., 3., 4., 5.]] * 3)
np.testing.assert_allclose(a.numpy(), expected)
def test_disk_assignment(self):
a = Tensor.empty(5, device=f"disk:{temp('disk_assignment')}").assign(Tensor.ones(5)).numpy()
np.testing.assert_equal(a, np.ones(5))
@ -566,12 +619,12 @@ class TestAssignOrdering(unittest.TestCase):
def test_slice_write_then_full_read(self):
"""Write to slice, then read full buffer."""
# without .realize(): orphan slice assign not triggered by .numpy()
buf = Tensor.zeros(4).contiguous().realize()
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
buf[1:3].assign(Tensor([5, 6]))
np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0]) # TODO: wrong! should be [0, 5, 6, 0]
# with .realize(): assign executes
buf = Tensor.zeros(4).contiguous().realize()
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
buf[1:3].assign(Tensor([5, 6])).realize()
np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
@ -653,7 +706,7 @@ class TestAssignOrdering(unittest.TestCase):
def test_three_buffer_chain(self):
"""Chain: A depends on B, B depends on C - ordering matters."""
a = Tensor.zeros(4).contiguous().realize()
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
b = Tensor([1, 2, 3, 4]).contiguous().realize()
c = Tensor([10, 10, 10, 10]).contiguous().realize()
# b reads from c, a reads from b
@ -665,8 +718,8 @@ class TestAssignOrdering(unittest.TestCase):
def test_interleaved_assign_read_patterns(self):
"""Complex interleaved pattern: write A, read A into B, write B, read B."""
a = Tensor.zeros(4).contiguous().realize()
b = Tensor.zeros(4).contiguous().realize()
a = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
b = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
a.assign(Tensor([1, 2, 3, 4]))
b.assign(a.contiguous()) # b should get [1,2,3,4]

View file

@ -5,24 +5,13 @@ from tinygrad.runtime.support.c import DLL, record, init_records
from tinygrad.runtime.support import c
from tinygrad.runtime.support.autogen import gen
class TestAutogen(unittest.TestCase):
@unittest.skipIf(WIN, "doesn't compile on windows")
class TestC(unittest.TestCase):
def compile(self, src):
with tempfile.NamedTemporaryFile(suffix=".so") as f:
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
return DLL("test", f.name)
def run_gen(self, contents):
with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
f.write(contents)
f.flush()
generated_code = gen(name="test_header", dll=None, files=[f.name])
namespace = {}
exec(generated_code, namespace)
return namespace
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_packed_struct(self):
@record
class Baz:
@ -45,7 +34,6 @@ class TestAutogen(unittest.TestCase):
assert b.c == 1
assert b.d == 0
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_packed_struct_interop(self):
@record
class Baz:
@ -75,7 +63,6 @@ class TestAutogen(unittest.TestCase):
self.assertEqual(test(b), b.a + b.b + b.c + b.d)
# https://github.com/python/cpython/issues/90914
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_bitfield_interop(self):
@record
class Baz:
@ -103,7 +90,6 @@ class TestAutogen(unittest.TestCase):
def test(x:Baz) -> ctypes.c_int: ...
for i in range(8): self.assertEqual(test(Baz(*(j==i for j in range(8)))), i==2)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_struct_interop(self):
@record
class Baz:
@ -131,7 +117,6 @@ class TestAutogen(unittest.TestCase):
def test(x:Baz) -> Baz: ...
self.assertEqual(bytes(test(Baz(*range(8)))), struct.pack("8i", *range(7, -1, -1)))
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_aos_interop(self):
@record
class Item:
@ -151,7 +136,6 @@ class TestAutogen(unittest.TestCase):
def test(arr:(Item * 3)) -> ctypes.c_int: ...
self.assertEqual(test((Item * 3)(Item(10), Item(20), Item(30))), 60)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_soa_interop(self):
@record
class Row:
@ -173,7 +157,6 @@ class TestAutogen(unittest.TestCase):
self.assertEqual(r.data[1], 20)
self.assertEqual(r.data[2], 10)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_soa_ptr_interop(self):
@record
class Row:
@ -191,7 +174,6 @@ class TestAutogen(unittest.TestCase):
def test(x:Row) -> ctypes.c_int: ...
assert test(Row((ctypes.c_int * 3)(10, 20, 30))) == 60
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_nested_struct_interop(self):
@record
class Inner:
@ -217,7 +199,6 @@ class TestAutogen(unittest.TestCase):
self.assertEqual(o.inner.a, 20)
self.assertEqual(o.b, 10)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_struct_pointer_interop(self):
@record
class Foo:
@ -242,7 +223,88 @@ class TestAutogen(unittest.TestCase):
self.assertEqual(out.contents.a, 20)
self.assertEqual(out.contents.b, 10)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_pointer_field_roundtrip(self):
# This tests storing a pointer in a record struct field and passing it to C
# Mimics how mesa.struct_lp_build_tgsi_params.mask is used
from tinygrad.runtime.support.c import POINTER
@record
class Inner:
SIZE = 8
value: Annotated[ctypes.c_int, 0]
flag: Annotated[ctypes.c_int, 4]
@record
class Outer:
SIZE = 16
x: Annotated[ctypes.c_int, 0]
inner_ptr: Annotated[POINTER[Inner], 8]
init_records()
src = """
struct inner { int value; int flag; };
struct outer { int x; struct inner *inner_ptr; };
int test(struct inner *p) {
return p->value + p->flag;
}
"""
dll = self.compile(src)
@dll.bind
def test(p:POINTER[Inner]) -> ctypes.c_int: ...
inner = Inner(value=42, flag=10)
outer = Outer(x=1, inner_ptr=ctypes.pointer(inner))
# Retrieve pointer from struct field and pass to C
self.assertEqual(test(outer.inner_ptr), 52)
def test_pointer_field_loses_reference(self):
# BUG: When a pointer is stored in a record struct field, only the address bytes are saved.
# The pointer's _objects dict (which prevents GC of the pointed-to object) is lost.
# This causes the pointed-to object to be garbage collected, leading to use-after-free.
from tinygrad.runtime.support.c import POINTER
@record
class MaskContext:
SIZE = 16
value: Annotated[ctypes.c_int, 0]
initialized: Annotated[ctypes.c_int, 4]
ptr: Annotated[ctypes.c_void_p, 8]
@record
class Params:
SIZE = 16
x: Annotated[ctypes.c_int, 0]
mask: Annotated[POINTER[MaskContext], 8]
init_records()
src = """
struct mask_ctx { int value; int initialized; void *ptr; };
void mask_begin(struct mask_ctx *m, int val) { m->value = val; m->initialized = 1; }
int mask_end(struct mask_ctx *m) { return m->value + m->initialized; }
"""
dll = self.compile(src)
@dll.bind
def mask_begin(m:POINTER[MaskContext], val:ctypes.c_int) -> None: ...
@dll.bind
def mask_end(m:POINTER[MaskContext]) -> ctypes.c_int: ...
# When MaskContext() is created inline, it gets garbage collected after the pointer
# is stored because only the address bytes are saved, not the _objects reference.
params = Params(x=1, mask=ctypes.pointer(MaskContext()))
mask_begin(params.mask, 42)
result = mask_end(params.mask)
self.assertEqual(result, 43) # 42 + 1
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
@unittest.skipIf(WIN, "doesn't compile on windows")
class TestAutogen(unittest.TestCase):
def run_gen(self, contents):
with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
f.write(contents)
f.flush()
generated_code = gen(name="test_header", dll=None, files=[f.name])
namespace = {}
exec(generated_code, namespace)
return namespace
def test_packed_structs(self):
ns = self.run_gen("""
typedef unsigned NvU32;
@ -292,47 +354,6 @@ typedef struct
assert frts_cmd.readVbiosDesc.__class__ is FWSECLIC_READ_VBIOS_DESC
assert frts_cmd.frtsRegionDesc.__class__ is FWSECLIC_FRTS_REGION_DESC
@unittest.skipIf(WIN, "doesn't compile on windows")
@unittest.skipIf(OSX, "can't find stdint?")
def test_packed_fields(self):
ns = self.run_gen("""#include <stdint.h>
typedef struct die_info
{
uint16_t die_id;
uint16_t die_offset; /* Points to the corresponding die_header structure */
} die_info;
typedef struct ip_discovery_header
{
uint32_t signature; /* Table Signature */
uint16_t version; /* Table Version */
uint16_t size; /* Table Size */
uint32_t id; /* Table ID */
uint16_t num_dies; /* Number of Dies */
die_info die_info[16]; /* list die information for up to 16 dies */
union {
uint16_t padding[1]; /* version <= 3 */
struct { /* version == 4 */
uint8_t base_addr_64_bit : 1; /* ip structures are using 64 bit base address */
uint8_t reserved : 7;
uint8_t reserved2;
};
};
} ip_discovery_header;
""")
ip_discovery_header = ns['ip_discovery_header']
hdr = b'IPDS\x04\x00|\x1d\x80\x1a\xffd\x01\x00\x00\x00\x8c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00' # noqa: E501
ihdr = ip_discovery_header.from_buffer_copy(hdr)
assert ctypes.sizeof(ihdr) == 80
assert ihdr.signature == 0x53445049
assert ihdr.version == 0x0004
assert ihdr.num_dies == 1
assert ihdr.base_addr_64_bit == 1
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_gen_from_header(self):
namespace = self.run_gen("""
typedef struct {
@ -378,7 +399,6 @@ typedef struct ip_discovery_header
self.assertTrue(hasattr(rect, 'height'))
self.assertTrue(hasattr(rect, 'color'))
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_struct_ordering(self):
namespace = self.run_gen("""
struct A;
@ -408,77 +428,6 @@ typedef struct ip_discovery_header
self.assertTrue(hasattr(b, 'c_ptr'))
self.assertTrue(hasattr(c, 'a_ptr'))
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_pointer_field_roundtrip(self):
# This tests storing a pointer in a record struct field and passing it to C
# Mimics how mesa.struct_lp_build_tgsi_params.mask is used
from tinygrad.runtime.support.c import POINTER
@record
class Inner:
SIZE = 8
value: Annotated[ctypes.c_int, 0]
flag: Annotated[ctypes.c_int, 4]
@record
class Outer:
SIZE = 16
x: Annotated[ctypes.c_int, 0]
inner_ptr: Annotated[POINTER[Inner], 8]
init_records()
src = """
struct inner { int value; int flag; };
struct outer { int x; struct inner *inner_ptr; };
int test(struct inner *p) {
return p->value + p->flag;
}
"""
dll = self.compile(src)
@dll.bind
def test(p:POINTER[Inner]) -> ctypes.c_int: ...
inner = Inner(value=42, flag=10)
outer = Outer(x=1, inner_ptr=ctypes.pointer(inner))
# Retrieve pointer from struct field and pass to C
self.assertEqual(test(outer.inner_ptr), 52)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_pointer_field_loses_reference(self):
# BUG: When a pointer is stored in a record struct field, only the address bytes are saved.
# The pointer's _objects dict (which prevents GC of the pointed-to object) is lost.
# This causes the pointed-to object to be garbage collected, leading to use-after-free.
from tinygrad.runtime.support.c import POINTER
@record
class MaskContext:
SIZE = 16
value: Annotated[ctypes.c_int, 0]
initialized: Annotated[ctypes.c_int, 4]
ptr: Annotated[ctypes.c_void_p, 8]
@record
class Params:
SIZE = 16
x: Annotated[ctypes.c_int, 0]
mask: Annotated[POINTER[MaskContext], 8]
init_records()
src = """
struct mask_ctx { int value; int initialized; void *ptr; };
void mask_begin(struct mask_ctx *m, int val) { m->value = val; m->initialized = 1; }
int mask_end(struct mask_ctx *m) { return m->value + m->initialized; }
"""
dll = self.compile(src)
@dll.bind
def mask_begin(m:POINTER[MaskContext], val:ctypes.c_int) -> None: ...
@dll.bind
def mask_end(m:POINTER[MaskContext]) -> ctypes.c_int: ...
# When MaskContext() is created inline, it gets garbage collected after the pointer
# is stored because only the address bytes are saved, not the _objects reference.
params = Params(x=1, mask=ctypes.pointer(MaskContext()))
mask_begin(params.mask, 42)
result = mask_end(params.mask)
self.assertEqual(result, 43) # 42 + 1
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_anonymous_children(self):
namespace = self.run_gen("""
struct foo {
@ -491,7 +440,6 @@ typedef struct ip_discovery_header
self.assertIn('struct_foo', namespace)
self.assertIn('struct_foo_bar', namespace)
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_enums(self):
namespace = self.run_gen("""
enum Foo { A, B, C };
@ -511,4 +459,43 @@ typedef struct ip_discovery_header
assert namespace["enum_Bar"].get(1) == "Y"
assert namespace["enum_Bar"].get(2) == "Z"
@unittest.skipIf(OSX, "can't find stdint?")
def test_packed_fields(self):
ns = self.run_gen("""#include <stdint.h>
typedef struct die_info
{
uint16_t die_id;
uint16_t die_offset; /* Points to the corresponding die_header structure */
} die_info;
typedef struct ip_discovery_header
{
uint32_t signature; /* Table Signature */
uint16_t version; /* Table Version */
uint16_t size; /* Table Size */
uint32_t id; /* Table ID */
uint16_t num_dies; /* Number of Dies */
die_info die_info[16]; /* list die information for up to 16 dies */
union {
uint16_t padding[1]; /* version <= 3 */
struct { /* version == 4 */
uint8_t base_addr_64_bit : 1; /* ip structures are using 64 bit base address */
uint8_t reserved : 7;
uint8_t reserved2;
};
};
} ip_discovery_header;
""")
ip_discovery_header = ns['ip_discovery_header']
hdr = b'IPDS\x04\x00|\x1d\x80\x1a\xffd\x01\x00\x00\x00\x8c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00' # noqa: E501
ihdr = ip_discovery_header.from_buffer_copy(hdr)
assert ctypes.sizeof(ihdr) == 80
assert ihdr.signature == 0x53445049
assert ihdr.version == 0x0004
assert ihdr.num_dies == 1
assert ihdr.base_addr_64_bit == 1
if __name__ == "__main__": unittest.main()

View file

@ -42,7 +42,7 @@ class TestCall(unittest.TestCase):
b = Tensor.randn(K, N)
Tensor.realize(a, b)
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
@unittest.skip("needs GEMM on mixins")
def test_call_gemm_uop(self):
@ -56,7 +56,7 @@ class TestCall(unittest.TestCase):
y = UOp.param(1, dtypes.float, shape=(K, N))
c = Tensor.call(a, b, fxn=x@y)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
if __name__ == '__main__':
unittest.main()

View file

@ -59,13 +59,13 @@ class TestRawDiskBuffer(unittest.TestCase):
_test_bitcasted(t, dtypes.float32, 0.0)
_test_bitcasted(t, dtypes.uint32, 0)
# pi in float16 stored via int16
t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
t.assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16).bitcast(dtypes.uint8)).realize()
_test_bitcasted(t, dtypes.float16, 3.140625)
_test_bitcasted(t, dtypes.float32, 50.064727)
_test_bitcasted(t, dtypes.uint16, 0x4248)
_test_bitcasted(t, dtypes.uint32, 0x42484248)
# pi in float32 stored via float32
t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize()
t.assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32).bitcast(dtypes.uint8)).realize()
_test_bitcasted(t, dtypes.float32, 3.1415927)
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
# doesn't suport normal cast
@ -348,13 +348,25 @@ class TestDiskTensor(unittest.TestCase):
def test_assign_with_bitcast(self):
# bitcast assign is used in safe_save for writing header length
# this tests the synchronous disk assign hack handles bitcast correctly
# bitcast on source side works, bitcast on target side raises
pathlib.Path(temp(fn:="dt_assign_bitcast")).unlink(missing_ok=True)
t = Tensor.empty(16, device=f"disk:{temp(fn)}", dtype=dtypes.uint8)
t[0:8].bitcast(dtypes.int64).assign([12345])
# verify the data was written correctly
# correct way: bitcast the source to match target dtype
t[0:8].assign(Tensor([12345], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
val = int.from_bytes(t[0:8].data(), 'little')
self.assertEqual(val, 12345)
# bitcast on target with non-broadcastable dtype raises
with self.assertRaises(RuntimeError):
t[0:4].bitcast(dtypes.int32).assign(Tensor([12345], dtype=dtypes.int64))
def test_assign_to_bitcast_view(self):
# assign float values to a float32 view of a uint8 disk buffer (used by safe_save)
pathlib.Path(temp(fn:="dt_bitcast_view_assign")).unlink(missing_ok=True)
t = Tensor.empty(32, device=f"disk:{temp(fn)}", dtype=dtypes.uint8)
# create float32 view of bytes 8-24 (4 floats)
float_view = t[8:24].bitcast(dtypes.float32)
float_view.assign(Tensor([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32, device="CPU"))
np.testing.assert_array_equal(float_view.numpy(), [1.0, 2.0, 3.0, 4.0])
def test_assign_cross_device(self):
# disk assign allows cross-device (source on GPU/CPU, target on disk)

View file

@ -1,5 +1,5 @@
import unittest
from tinygrad import dtypes, Device
from tinygrad import dtypes
from tinygrad.device import Buffer
from tinygrad.engine.memory import _internal_memory_planner
@ -7,7 +7,7 @@ global_map = {}
def b(i, base=None, offset=0, pin=False, size=16):
global global_map
if i in global_map: return global_map[i]
global_map[i] = Buffer(Device.DEFAULT, size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
global_map[i] = Buffer("NULL", size, dtypes.int8, base=global_map[base] if base is not None else None, offset=offset)
if pin: global_map[i].ref(1)
return global_map[i]

View file

@ -35,7 +35,6 @@ class TestScheduleCache(unittest.TestCase):
_, var_vals = t.schedule_with_vars()
self.assertEqual(var_vals, {'pos': 42})
@Context(SPEC=0)
def test_custom_kernel(self):
for i in range(4):
a = Tensor.empty(1)
@ -43,7 +42,6 @@ class TestScheduleCache(unittest.TestCase):
a.realize()
self.assertEqual(a.item(), i)
@Context(SPEC=0)
def test_same_custom_function_reuses_cache(self):
schedule_cache.clear()
fxn = functools.partial(custom_set0_kernel, num=10)

View file

@ -470,5 +470,19 @@ class TestUnfoldableImageChannelSelection(unittest.TestCase):
load = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 10, 4)), arg=0).index(x, ptr=True), UOp.const(dtypes.float, 0)))
self.assertEqual(self._count_nans(load), 1)
class TestDropTrueGate(unittest.TestCase):
def test_drop_true_gate_on_index(self):
# test that INDEX with a constant True gate gets simplified to drop the gate
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.uop.ops import graph_rewrite
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
idx = UOp.const(dtypes.index, 0)
true_gate = UOp.const(dtypes.bool, True)
index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx, true_gate))
# apply the optimization
result = graph_rewrite(index_with_gate, load_store_indexing)
# the True gate should be dropped (INDEX should only have 2 sources)
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
if __name__ == '__main__':
unittest.main()

View file

@ -479,6 +479,26 @@ class TestUOpGraph(unittest.TestCase):
for u in uops:
self.assertNotEqual(u.dtype, dtypes.long)
def test_load_idx_no_math_on_loaded(self):
# test the (x+y)<c pattern where x has loads - we shouldn't do math on loaded indices
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(128000), arg=0, src=())
c1 = UOp.range(UOp.const(dtypes.index, 512), 1, AxisType.LOOP)
c2 = UOp.range(UOp.const(dtypes.index, 250), 2, AxisType.LOOP)
c3 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(512), arg=1, src=())
c4 = c3.index(c1) # c4 is a load
c5 = UOp.range(UOp.const(dtypes.index, 240), 0, AxisType.REDUCE)
c6 = ((c2*UOp.const(dtypes.index, 240))+c5)
c7 = UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(60000), arg=2, src=())
c8 = c7.index(c6)
# (loaded + range) < const pattern - loaded value shouldn't be promoted to long
loaded_idx = c4.cast(dtypes.index)
comparison = (loaded_idx + c5) < UOp.const(dtypes.index, 60000)
c9 = comparison.where(c8.cast(dtypes.uint).cast(dtypes.uchar), 0).reduce(c5, arg=Ops.ADD)
c10 = c0.index(((c1*UOp.const(dtypes.index, 250))+c2)).store(c9).end(c1, c2)
uops = to_uops_list([c10])
for u in uops:
self.assertNotEqual(u.dtype, dtypes.long)
def test_fold_gated_load(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
@ -686,6 +706,97 @@ class TestExpander(unittest.TestCase):
sink = expander_rewrite(sink)
print(sink)
class TestReduceCollapse(unittest.TestCase):
def test_multi_range_reduce_add(self):
"""Test that (x + y).reduce(r1, r2) distributes over multiple ranges"""
from tinygrad.codegen.simplify import pm_reduce_collapse
# Create two ranges
r1 = UOp.range(3, 0)
r2 = UOp.range(4, 1)
# Create x + y where x and y depend on different ranges
x = r1.cast(dtypes.float)
y = r2.cast(dtypes.float)
# (x + y).reduce(r1, r2) should be rewritten
red = (x + y).reduce(r1, r2, arg=Ops.ADD)
self.assertEqual(len(red.src), 3) # value + 2 ranges
result = graph_rewrite(red, pm_reduce_collapse, name='test')
# Should become add of two separate reduces
self.assertEqual(result.op, Ops.ADD)
class TestLoadStoreFolding(unittest.TestCase):
def test_gated_load_gep_preserves_alt(self):
"""Test that LOAD(GEP, alt) preserves alt value after rewrite"""
from tinygrad.codegen.late.devectorizer import load_store_folding
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.vec(4).ptr(), (), 0)
idx = UOp.const(dtypes.int, 0)
gate = UOp.const(dtypes.bool, True)
gated_index = buf.index(idx, gate)
gep = gated_index.gep(0)
alt = UOp.const(dtypes.float, 42.0)
gated_load = gep.load(alt)
self.assertEqual(len(gated_load.src), 2) # GEP + alt
result = graph_rewrite(gated_load, load_store_folding, name='test')
# After rewrite, should still have alt value preserved
self.assertEqual(result.op, Ops.GEP)
inner_load = result.src[0]
self.assertEqual(inner_load.op, Ops.LOAD)
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
def test_gated_load_ptrcat_preserves_alt(self):
"""Test that LOAD(PTRCAT, alt) preserves alt value after rewrite"""
from tinygrad.codegen.late.devectorizer import load_store_folding
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
idx1 = buf1.index(idx)
idx2 = buf2.index(idx)
ptrcat = UOp(Ops.PTRCAT, dtypes.float.ptr().vec(2), (idx1, idx2))
alt = UOp.const(dtypes.float.vec(2), 42.0)
gated_load = ptrcat.load(alt)
self.assertEqual(len(gated_load.src), 2) # PTRCAT + alt
result = graph_rewrite(gated_load, load_store_folding, name='test')
# After rewrite, should be CAT of LOADs, each preserving alt
self.assertEqual(result.op, Ops.CAT)
for inner_load in result.src:
self.assertEqual(inner_load.op, Ops.LOAD)
self.assertEqual(len(inner_load.src), 2) # INDEX + alt
self.assertEqual(inner_load.src[1].arg, 42.0) # alt value preserved
class TestConstBufferize(unittest.TestCase):
def test_const_bufferize_with_ranges(self):
"""Test that CONST.BUFFERIZE with ranges is folded correctly.
BUFFERIZE can have ranges as additional sources beyond the value.
The pattern at rangeify.py uses allow_any_len=True because
CONST doesn't depend on ranges (constant is same value everywhere).
"""
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
c = UOp.const(dtypes.float, 42.0)
r1 = UOp.range(3, 0)
bufferize_with_range = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
self.assertEqual(len(bufferize_with_range.src), 2) # const + 1 range
result = graph_rewrite(bufferize_with_range, pm_const_buffer_folding, name='test')
# BUFFERIZE should be removed, result is const broadcast to shape
self.assertNotEqual(result.op, Ops.BUFFERIZE)
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
self.assertIn(42.0, const_vals)
def test_const_bufferize_with_multiple_ranges(self):
"""Test CONST.BUFFERIZE with multiple ranges is also folded."""
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
c = UOp.const(dtypes.float, 3.14)
r1 = UOp.range(3, 0)
r2 = UOp.range(4, 1)
bufferize_with_ranges = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
self.assertEqual(len(bufferize_with_ranges.src), 3) # const + 2 ranges
result = graph_rewrite(bufferize_with_ranges, pm_const_buffer_folding, name='test')
# BUFFERIZE should be removed
self.assertNotEqual(result.op, Ops.BUFFERIZE)
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
self.assertIn(3.14, const_vals)
class TestUOpTags(unittest.TestCase):
def test_inc_by_one(self):
g = UOp.const(dtypes.int, 1) + UOp.const(dtypes.int, 1)

View file

@ -1074,6 +1074,24 @@ class TestGatedUopGivenValid(unittest.TestCase):
expected_vec = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (uconst(0), r0))
self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))
class TestRangeSplitting(unittest.TestCase):
def test_range_split_on_mod(self):
# test that mark_range_mod splits RANGE(8) into RANGE(4)*2 + RANGE(2) when used with %2
from tinygrad.codegen.simplify import pm_split_ranges, pm_flatten_range
r0 = UOp.range(uconst(8), 0)
# create a simple expression using the range with mod: store range%2 to a buffer
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
val = (r0 % uconst(2)).cast(dtypes.int)
store = UOp(Ops.STORE, dtypes.void, (buf.index(uconst(0)), val))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.END, dtypes.void, (store, r0)),))
# count RANGEs before
ranges_before = len([u for u in sink.toposort() if u.op is Ops.RANGE])
# apply the range splitting optimization
sink_after = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="test split ranges")
# count RANGEs after - should have more due to splitting
ranges_after = len([u for u in sink_after.toposort() if u.op is Ops.RANGE])
self.assertGreater(ranges_after, ranges_before, "RANGE should be split when used with mod of divisible constant")
class TestBounds(unittest.TestCase):
def test_unrolled_arange(self):
# #include <metal_stdlib>

View file

@ -115,8 +115,8 @@ pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError("if not allowed in graph"))),
# gated INDEX becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat()),
allow_any_len=True), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat(name="gate", dtype=dtypes.bool))).or_casted(), UPat())),
lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))]))
])
# requires lst be toposorted. like graph rewrite, but for lines

View file

@ -123,12 +123,12 @@ load_store_folding = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), allow_any_len=True, name="sto"), gep_on_store),
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st")), name="sto"), gep_on_store),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, cat.dtype.base.vec(cat.dtype.vcount), tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), allow_any_len=True, name="sto"), cat_after_store),
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data")), name="sto"), cat_after_store),
])
# *** correct load/store ***
@ -319,8 +319,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in ended_ranges])
identity = red.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=ctx.acc_num)
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity) if len(input_ranges) else \
acc.index(UOp.const(dtypes.int, 0)).store(identity)
acc_init = acc.after(*input_ranges).index(UOp.const(dtypes.int, 0)).store(identity)
lst = [acc.after(acc_init, *reduce_range).index(UOp.const(dtypes.int, 0))] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
@ -342,6 +341,6 @@ pm_add_loads = PatternMatcher([
(UPat(Ops.INDEX, name="idx"), lambda idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else
idx.replace(dtype=idx.src[0].dtype).load(dtype=idx.dtype.base)),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
])

View file

@ -206,6 +206,8 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# allow use of assembly for gemm
ASM_GEMM = ContextVar("ASM_GEMM", 0)
@dataclass(frozen=True)
class Metadata:

View file

@ -78,7 +78,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
j += "\x20"*(round_up(len(j),8)-len(j))
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:8].bitcast(dtypes.int64).assign([len(j)])
t[0:8].assign(Tensor([len(j)], dtype=dtypes.int64, device="CPU").bitcast(dtypes.uint8))
t[8:8+len(j)].assign(list(j.encode('utf-8')))
for k,v in safe_load(t).items(): v.assign(tensors[k])

View file

@ -46,10 +46,10 @@ base_rewrite = PatternMatcher([
# new load/store
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var")), allow_any_len=True),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var"))),
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),)), lambda ctx,bidx: f"(*{ctx[bidx]})"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var"))), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
# TODO: look for left-associative
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](

View file

@ -4,12 +4,13 @@ import tinygrad.runtime.support.objc as objc
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, ProfileDeviceEvent, CompilerSet, CompilerPair
from tinygrad.renderer.cstyle import MetalRenderer
from tinygrad.runtime.autogen import metal
from tinygrad.runtime.support.c import DLL
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
REQUEST_TYPE_COMPILE = 13
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
DLL("CoreGraphics", "CoreGraphics")
# FIXME: these need autogen to support objc categories
# https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ObjectiveC/Chapters/ocCategories.html
@ -67,7 +68,7 @@ class MetalCompiler(Compiler):
# doesn't seem to be anything we can do.
with contextlib.suppress(FileNotFoundError, ModuleNotFoundError):
import tinygrad.runtime.autogen.llvm # noqa: F401
support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
support = DLL("MTLCompiler", "MTLCompiler")
support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
def __init__(self):

View file

@ -88,8 +88,8 @@ def import_asic_regs(prefix:str, version:tuple[int, ...], cls=AMDReg) -> dict[st
return x
def _download_file(ver, suff) -> str:
dir_prefix = {"osssys": "oss"}.get(prefix, prefix)
fetch_name, file_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h", f"{prefix}_{'_'.join(map(str, version))}_{suff}.h"
return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=file_name, subdir="asic_regs")
fetch_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h"
return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=fetch_name, subdir="asic_regs")
for ver in fixup_ip_version(prefix, version):
try: offs, sh_masks = _extract_regs(_download_file(ver, "offset")), _extract_regs(_download_file(ver, "sh_mask"))

View file

@ -131,13 +131,15 @@ def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]
class DLL(ctypes.CDLL):
_loaded_: set[str] = set()
@staticmethod
def findlib(nm:str, paths:list[str], extra_paths=[]):
if nm == 'libc' and OSX: return '/usr/lib/libc.dylib'
if pathlib.Path(path:=getenv(nm.replace('-', '_').upper()+"_PATH", '')).is_file(): return path
for p in paths:
libpaths = {"posix": ["/usr/lib64", "/usr/lib", "/usr/local/lib"], "nt": os.environ['PATH'].split(os.pathsep),
"darwin": ["/opt/homebrew/lib", f"/System/Library/Frameworks/{p}.framework"],
"darwin": ["/opt/homebrew/lib", f"/System/Library/Frameworks/{p}.framework", f"/System/Library/PrivateFrameworks/{p}.framework"],
'linux': ['/lib', '/lib64', f"/lib/{sysconfig.get_config_var('MULTIARCH')}", "/usr/lib/wsl/lib/"]}
if (pth:=pathlib.Path(p)).is_absolute():
if pth.is_file(): return p
@ -154,12 +156,12 @@ class DLL(ctypes.CDLL):
if f.read(4) == b'\x7FELF': return str(l)
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
self.nm, self.emsg, self.loaded = nm, emsg, False
self.nm, self.emsg = nm, emsg
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
if DEBUG >= 3: print(f"loading {nm} from {path}")
try:
super().__init__(path, **kwargs)
self.loaded = True
self._loaded_.add(self.nm)
except OSError as e:
self.emsg = str(e)
if DEBUG >= 3: print(f"loading {nm} failed: {e}")
@ -175,5 +177,6 @@ class DLL(ctypes.CDLL):
return wrapper
def __getattr__(self, nm):
if not self.loaded: raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
if self.nm not in self._loaded_:
raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
return super().__getattr__(nm)

View file

@ -5,10 +5,10 @@ from tinygrad.helpers import DEBUG, to_mv, round_up, OSX
from tinygrad.runtime.support.hcq import MMIOInterface
class USB3:
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31):
def __init__(self, vendor:int, dev:int, ep_data_in:int, ep_stat_in:int, ep_data_out:int, ep_cmd_out:int, max_streams:int=31, use_bot=False):
self.vendor, self.dev = vendor, dev
self.ep_data_in, self.ep_stat_in, self.ep_data_out, self.ep_cmd_out = ep_data_in, ep_stat_in, ep_data_out, ep_cmd_out
self.max_streams = max_streams
self.max_streams, self.use_bot = max_streams, use_bot
self.ctx = ctypes.POINTER(libusb.struct_libusb_context)()
if libusb.libusb_init(ctypes.byref(self.ctx)): raise RuntimeError("libusb_init failed")
@ -25,30 +25,34 @@ class USB3:
# Set configuration and claim interface
if libusb.libusb_set_configuration(self.handle, 1): raise RuntimeError("set_configuration failed")
if libusb.libusb_claim_interface(self.handle, 0): raise RuntimeError("claim_interface failed. sudo required?")
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
# Clear any stalled endpoints
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
if use_bot:
self._tag = 0
else:
if libusb.libusb_set_interface_alt_setting(self.handle, 0, 1): raise RuntimeError("alt_setting failed")
# Allocate streams
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
raise RuntimeError(f"alloc_streams failed: {rc}")
# Clear any stalled endpoints
all_eps = (self.ep_data_out, self.ep_data_in, self.ep_stat_in, self.ep_cmd_out)
for ep in all_eps: libusb.libusb_clear_halt(self.handle, ep)
# Base cmd
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
# Allocate streams
stream_eps = (ctypes.c_uint8 * 3)(self.ep_data_out, self.ep_data_in, self.ep_stat_in)
if (rc:=libusb.libusb_alloc_streams(self.handle, self.max_streams * len(stream_eps), stream_eps, len(stream_eps))) < 0:
raise RuntimeError(f"alloc_streams failed: {rc}")
# Init pools
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
# Base cmd
cmd_template = bytes([0x01, 0x00, 0x00, 0x01, *([0] * 12), 0xE4, 0x24, 0x00, 0xB2, 0x1A, 0x00, 0x00, 0x00, *([0] * 8)])
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
# Init pools
self.tr = {ep: [libusb.libusb_alloc_transfer(0) for _ in range(self.max_streams)] for ep in all_eps}
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
self.buf_cmd = [(ctypes.c_uint8 * len(cmd_template))(*cmd_template) for _ in range(self.max_streams)]
self.buf_stat = [(ctypes.c_uint8 * 64)() for _ in range(self.max_streams)]
self.buf_data_in = [(ctypes.c_uint8 * 0x1000)() for _ in range(self.max_streams)]
self.buf_data_out = [(ctypes.c_uint8 * 0x80000)() for _ in range(self.max_streams)]
self.buf_data_out_mvs = [to_mv(ctypes.addressof(self.buf_data_out[i]), 0x80000) for i in range(self.max_streams)]
for slot in range(self.max_streams): struct.pack_into(">B", self.buf_cmd[slot], 3, slot + 1)
def _prep_transfer(self, tr, ep, stream_id, buf, length):
tr.contents.dev_handle, tr.contents.endpoint, tr.contents.length, tr.contents.buffer = self.handle, ep, length, buf
@ -68,38 +72,90 @@ class USB3:
if tr.contents.status == libusb.LIBUSB_TRANSFER_COMPLETED: running -= 1
elif tr.contents.status != 0xFF: raise RuntimeError(f"EP 0x{tr.contents.endpoint:02X} error: {tr.contents.status}")
def _bulk_out(self, ep: int, payload: bytes, timeout: int = 1000):
transferred = ctypes.c_int(0)
rc = libusb.libusb_bulk_transfer(
self.handle,
ep,
(ctypes.c_ubyte * len(payload))(*payload),
len(payload),
ctypes.byref(transferred),
timeout,
)
assert rc == 0, f"bulk OUT 0x{ep:02X} failed: {rc}"
assert transferred.value == len(payload), f"bulk OUT short write on 0x{ep:02X}: {transferred.value}/{len(payload)} bytes"
def _bulk_in(self, ep: int, length: int, timeout: int = 1000) -> bytes:
buf, transferred = (ctypes.c_ubyte * length)(), ctypes.c_int(0)
rc = libusb.libusb_bulk_transfer(
self.handle,
ep,
buf,
length,
ctypes.byref(transferred),
timeout,
)
assert rc == 0, f"bulk IN 0x{ep:02X} failed: {rc}"
return bytes(buf[:transferred.value])
def send_batch(self, cdbs:list[bytes], idata:list[int]|None=None, odata:list[bytes|None]|None=None) -> list[bytes|None]:
idata, odata = idata or [0] * len(cdbs), odata or [None] * len(cdbs)
results, tr_window, op_window = [], [], []
results:list[bytes|None] = []
tr_window, op_window = [], []
for idx, (cdb, rlen, send_data) in enumerate(zip(cdbs, idata, odata)):
# allocate slot and stream. stream is 1-based
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
if self.use_bot:
dir_in = rlen > 0
data_len = rlen if dir_in else (len(send_data) if send_data is not None else 0)
assert (data_len == 0) if dir_in else (rlen == 0), "BOT mode only supports either read or write per command"
# build cmd packet
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
# CBW
self._tag += 1
flags = 0x80 if dir_in else 0x00
cbw = struct.pack("<IIIBBB", 0x43425355, self._tag, data_len, flags, 0, len(cdb)) + cdb + b"\x00" * (16 - len(cdb))
self._bulk_out(self.ep_data_out, cbw)
# cmd + stat transfers
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
# DAT
if dir_in:
results.append(self._bulk_in(self.ep_data_in, rlen))
else:
if send_data is not None:
self._bulk_out(self.ep_data_out, send_data)
results.append(None)
if rlen:
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
# CSW
sig, rtag, residue, status = struct.unpack("<IIIB", self._bulk_in(self.ep_data_in, 13, timeout=2000))
assert sig == 0x53425355, f"Bad CSW signature 0x{sig:08X}, expected 0x53425355"
assert rtag == self._tag, f"CSW tag mismatch: got {rtag}, expected {self._tag}"
assert status == 0, f"SCSI command failed, CSW status=0x{status:02X}, residue={residue}"
else:
# allocate slot and stream. stream is 1-based
slot, stream = idx % self.max_streams, (idx % self.max_streams) + 1
if send_data is not None:
if len(send_data) > len(self.buf_data_out[slot]):
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
# build cmd packet
self.buf_cmd[slot][16:16+len(cdb)] = list(cdb)
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
# cmd + stat transfers
tr_window.append(self._prep_transfer(self.tr[self.ep_cmd_out][slot], self.ep_cmd_out, None, self.buf_cmd[slot], len(self.buf_cmd[slot])))
tr_window.append(self._prep_transfer(self.tr[self.ep_stat_in][slot], self.ep_stat_in, stream, self.buf_stat[slot], 64))
op_window.append((idx, slot, rlen))
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
self._submit_and_wait(tr_window)
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
tr_window = []
if rlen:
if rlen > len(self.buf_data_in[slot]): self.buf_data_in[slot] = (ctypes.c_uint8 * round_up(rlen, 0x1000))()
tr_window.append(self._prep_transfer(self.tr[self.ep_data_in][slot], self.ep_data_in, stream, self.buf_data_in[slot], rlen))
if send_data is not None:
if len(send_data) > len(self.buf_data_out[slot]):
self.buf_data_out[slot] = (ctypes.c_uint8 * len(send_data))()
self.buf_data_out_mvs[slot] = to_mv(ctypes.addressof(self.buf_data_out[slot]), len(send_data))
self.buf_data_out_mvs[slot][:len(send_data)] = bytes(send_data)
tr_window.append(self._prep_transfer(self.tr[self.ep_data_out][slot], self.ep_data_out, stream, self.buf_data_out[slot], len(send_data)))
op_window.append((idx, slot, rlen))
if (idx + 1 == len(cdbs)) or len(op_window) >= self.max_streams:
self._submit_and_wait(tr_window)
for idx, slot, rlen in op_window: results.append(bytes(self.buf_data_in[slot][:rlen]) if rlen else None)
tr_window = []
return results

View file

@ -71,7 +71,8 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
# allgather
copied_chunks = []
for i,rc in enumerate(reduced_chunks):
if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
else:
this_chunk: list[UOp|None] = [None] * n_lbs
this_chunk[(i+n_lbs-1)%n_lbs] = rc
@ -93,12 +94,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
return s.shrink(tuple(new_arg))
for i, x in enumerate(ms.src):
if x.op is Ops.COPY:
# if src device doesn't have a renderer, we have to view after the copy
# TODO: a way to understand this
if x.src[0].device in {"DISK", "NPY"}:
ret.append(apply_shrink(x, i))
else:
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
else:
ret.append(apply_shrink(x, i).contiguous())
return ms.replace(src=tuple(ret))

View file

@ -60,7 +60,7 @@ def split_reduceop(reduce:UOp, x:UOp):
mop_cleanup = PatternMatcher([
# merge adjacent RESHAPES, safe because they are not tagged
(UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, allow_any_len=True, name="x"),
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"),
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
])
@ -90,7 +90,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
(UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER), UPat()), name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@ -126,6 +126,10 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# ** assign rules **
# move bitcast from assign target to source: a.bitcast(X).assign(src) -> a.assign(src.bitcast(a.dtype))
(UPat(Ops.ASSIGN, src=(UPat(Ops.BITCAST, src=(UPat(name="target"),)), UPat(name="src")), name="assign"),
lambda target, src, assign: target.assign(src.bitcast(target.dtype)).replace(tag=assign.tag)),
# assign only to buffer, otherwise make it a CONTIGUOUS
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
@ -247,7 +251,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
# dont bufferize an arange
(UPat.any((r:=UPat(dtype=dtypes.index).cast()).named("src"), r.eq(UPat()).named("src")).f(Ops.BUFFERIZE,
allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# no buffers for const
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
# indexing a const is a const
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
@ -271,20 +275,20 @@ def late_buffer_view(t:UOp, b:UOp):
size = prod(shape)
# walk up for the INDEX
# NOTE: even though we allow RESHAPE and SHRINK, they can combine to form non-contiguous access patterns (e.g. t[::2])
x = t
while not any(u.op is Ops.INDEX for u in x.src):
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
while x.op is not Ops.INDEX:
assert x.op in {Ops.BITCAST, Ops.CONTIGUOUS, Ops.SHRINK, Ops.RESHAPE}, f"unexpected op {x.op} in buffer view walk"
x = x.src[0]
x = next(u for u in x.src if u.op is Ops.INDEX)
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
else: offset = sum(idx.vmin for idx in x.src[1:])
if offset < 0: raise RuntimeError(f"negative offset {offset} in buffer view")
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
to_bufferview = PatternMatcher([
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
])
DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
@ -451,9 +455,6 @@ to_define_global = PatternMatcher([
# this is only needed if you are using symbolic
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# remove RANGE with 0 size
(UPat(Ops.RANGE, name="r"), lambda r: UOp.const(dtypes.index, 0) if r.vmax == 0 else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])
@ -469,9 +470,6 @@ rangeify_codegen = PatternMatcher([
# TODO: this can be moved into codegen?
(UPat(Ops.NOOP, name="x"), lambda x: x.src[0]),
# strip the arg from store
(UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),
# add loads to non ptr indexes
# TODO: this can be moved into codegen?
#(UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg"))

View file

@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
@ -286,20 +286,20 @@ class Tensor(OpMixin):
return self
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if not isinstance(x, Tensor): x = Tensor(x, device="CPU" if is_disk else self.device, dtype=self.dtype)
if self.uop is x.uop: return self # a self assign is a NOOP
# broadcast x (shape only, dtype must match)
if self.shape != x.shape: x = x._broadcast_to(self.shape)
if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")
if not is_disk and self.device != x.device: raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")
# TODO: this is a hack for writing to DISK. remove with working assign
if isinstance(self.device, str) and self.device.startswith("DISK"):
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
if is_disk:
self._buffer().copyin(x._data())
return self
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
if self.uop is x.uop: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
# broadcast x
if least_upper_dtype(self.dtype, x.dtype) == self.dtype: x = x._broadcast_to(self.shape).cast(self.dtype)
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
assert not isinstance(self.device, tuple) or self.uop.axis == x.uop.axis, f"multi assign axis mismatch {self.uop.axis} != {x.uop.axis}"
return self.replace(self._apply_uop(UOp.assign, x))
def detach(self) -> Tensor:
@ -2431,6 +2431,9 @@ class Tensor(OpMixin):
```
"""
if IMAGE: return self.image_dot(w, dtype)
if ASM_GEMM:
from extra.gemm.asm.cdna.gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(self, w): return asm_gemm(self, w)
x, dx, dw = self, self.ndim, w.ndim
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")
@ -3865,7 +3868,10 @@ class Tensor(OpMixin):
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
Bitcasts `self` to the given `dtype` of the same itemsize.
Bitcasts `self` to the given `dtype`.
When the target dtype has the same itemsize, this is a view of the same memory.
When itemsizes differ, the last dimension is adjusted and a new Tensor is created.
`self` must not require a gradient.

View file

@ -14,7 +14,7 @@ class Ops(FastEnum):
# ** 1 -- defines/special **
# define GLOBAL/VAR are ptrs to outside the Kernel
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); BIND = auto()
DEFINE_VAR = auto(); BIND = auto()
# this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
SPECIAL = auto()
@ -28,6 +28,9 @@ class Ops(FastEnum):
NOOP = auto(); REWRITE_ERROR = auto()
PARAM = auto(); CALL = auto()
# TODO: remove this alias, DEFINE_GLOBAL is PARAM now
DEFINE_GLOBAL = PARAM
# renderer
# LINEAR is a list of UOps, SOURCE has a str arg that's human readable, BINARY has bytes arg that's compiled
PROGRAM = auto(); LINEAR = auto(); SOURCE = auto(); BINARY = auto()

View file

@ -206,7 +206,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None
@ -224,8 +224,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.BUFFER_VIEW: return (self.arg[0],)
case Ops.ENCDEC: return self.arg[0]
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
case Ops.PARAM:
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
# NOTE: copied from marg
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
return None
@ -413,10 +414,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, src:UOp|ConstType, **kwargs):
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
def end(self, *src:UOp):
if len(src) == 0: return self
return UOp(Ops.END, src=(self,)+src)
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
def contract(self, *rngs:UOp):
@ -1436,6 +1435,7 @@ def pyrender(ast:UOp) -> str:
for s in u.src: to_render.add(s)
if u.op is Ops.STORE: to_render.add(u.src[1])
if u.op in {Ops.REDUCE, Ops.REDUCE_AXIS}: to_render.add(u.src[0])
if u.op in {Ops.CUSTOM_KERNEL, Ops.CALL}: raise NotImplementedError("custom_kernel / call can't be pyrendered")
if u.op in not_rendered: continue
# checking the consumers is not enough, you have to make sure it's not used twice by the one consumer
if len(cmap[u]) == 1 and len([x for x in list(cmap[u].keys())[0].src if x is u]) == 1 and u.op not in always_rendered: continue
@ -1444,7 +1444,13 @@ def pyrender(ast:UOp) -> str:
kernels: dict[UOp, tuple[str, str]] = {}
r: dict[UOp, str] = {}
ret: dict[str, str] = {}
depth: dict[UOp, int] = {}
for i,u in enumerate(lst):
# limit inline depth to avoid "too many nested parentheses" in Python parser
op_depth = 1 + max([depth[s] for s in u.src], default=0)
if op_depth > 100: to_render.add(u)
depth[u] = 0 if u in to_render else op_depth
# do the rendering
if u.op is Ops.KERNEL:
if u.arg.ast not in kernels:
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")

View file

@ -84,14 +84,11 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND, Ops.CONTIGUOUS))), lambda: True),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@ -316,7 +313,8 @@ def eval_pyrender(code:str) -> UOp:
return lcls['ast']
def test_pyrender(test_ast:UOp, assert_parents=True):
code = pyrender(test_ast)
try: code = pyrender(test_ast)
except NotImplementedError: return None # this is okay, not all ops can be pyrendered
ast:UOp = eval_pyrender(code)
if ast is not test_ast:
if assert_parents:

View file

@ -306,11 +306,12 @@ def load_counters(profile:list[ProfileEvent]) -> None:
# to decode a SQTT trace, we need the raw stream, program binary and device properties
if (sqtt:=v.get(ProfileSQTTEvent)):
for e in sqtt:
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib)))
if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)),
data=(e.blob, prg_events[k].lib, device_props[e.device]["gfx_target_version"])))
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
def sqtt_timeline(data:bytes, lib:bytes) -> list[ProfileEvent]:
def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]:
from extra.assembly.amd.sqttmap import map_insts, InstructionInfo
from extra.assembly.amd.sqtt import PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC
ret:list[ProfileEvent] = []
@ -321,7 +322,7 @@ def sqtt_timeline(data:bytes, lib:bytes) -> list[ProfileEvent]:
rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}"))
key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=info.inst.disasm() if info is not None else None)
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
for p, info in map_insts(data, lib):
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
if isinstance(p, INST):
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
@ -346,7 +347,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
# * init decoder
from extra.sqtt.roc import decode
base = unwrap(p.base)
addr_table = amd_decode(unwrap(p.lib), device_props[p.device]["gfx_target_version"], )
addr_table = amd_decode(unwrap(p.lib), device_props[p.device]["gfx_target_version"])
disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()}
rctx = decode(data, {p.tag:disasm})
cu_events:dict[str, list[ProfileEvent]] = {}
@ -422,16 +423,14 @@ def get_stdout(f: Callable) -> str:
return buf.getvalue()
def amd_readelf(lib:bytes) -> list[dict]:
from tinygrad.runtime.autogen import amdgpu_kd
from tinygrad.runtime.support.elf import elf_loader
import msgpack
_, sections, __ = elf_loader(lib)
data = next((s for s in sections if s.name.startswith(".note"))).content
namesz, descsz, typ = struct.unpack_from(hdr:="<III", data, 0)
offset = (struct.calcsize(hdr)+namesz+3) & -4
notes = msgpack.unpackb(data[offset:offset+descsz])
keys = {".sgpr_count":"SGPRs", ".vgpr_count":"VGPRs", ".max_flat_workgroup_size":"Max WGP size",
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
image, sections, __ = elf_loader(lib)
rodata = next((s for s in sections if s.name == ".rodata")).content
kd = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t.from_buffer_copy(bytearray(rodata))
vgpr_gran = kd.compute_pgm_rsrc1 & amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT
return [{"label":f"{resource} Alloc", "value":val} for resource,val in [("VGPR", (vgpr_gran+1)*8-7), ("LDS",kd.group_segment_fixed_size),
("Scratch", kd.private_segment_fixed_size)] if val > 0]
def amd_decode(lib:bytes, target:int) -> dict[int, Any]: # Any is the Inst class from extra.assembly.amd.dsl
from tinygrad.runtime.support.elf import elf_loader
@ -558,7 +557,7 @@ def get_render(query:str) -> dict:
rows:dict[int, dict] = {}
for pc, (inst,_) in pc_to_inst.items():
if start_pc is None: start_pc = pc
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "hits":{"cols":inst_columns, "rows":[]}, "type":""}
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
for e in w.unpack_insts():
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
inst["hit_count"] += 1