mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
assembly/amd: add gfx12_asm_vflat llvm tests, disasm fixes (#15046)
* add gfx12_asm_vflat.s
* work
This commit is contained in:
parent
010d2790ce
commit
ad99b77f6d
2 changed files with 15 additions and 6 deletions
|
|
@ -324,6 +324,12 @@ def _disasm_smem(inst: SMEM) -> str:
|
|||
if name in ('s_memrealtime', 's_memtime'): return f"{name} {_fmt_sdst(inst.sdata, dst_n, cdna)}"
|
||||
return f"{name} {_fmt_sdst(inst.sdata, dst_n, cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc"))
|
||||
|
||||
R4_TH_LOAD = {1: 'TH_LOAD_NT', 2: 'TH_LOAD_HT', 3: 'TH_LOAD_LU', 4: 'TH_LOAD_RT_WB', 5: 'TH_LOAD_NT_WB'}
|
||||
R4_TH_STORE = {1: 'TH_STORE_NT', 2: 'TH_STORE_HT', 3: 'TH_STORE_ST', 4: 'TH_STORE_RT_WB', 5: 'TH_STORE_NT_WB'}
|
||||
R4_TH_ATOMIC = {1: 'TH_ATOMIC_RETURN', 2: 'TH_ATOMIC_NT', 3: 'TH_ATOMIC_RETURN_NT',
|
||||
4: 'TH_ATOMIC_CASCADE_RT', 5: 'TH_ATOMIC_CASCADE_RETURN', 6: 'TH_ATOMIC_CASCADE_NT', 7: 'TH_ATOMIC_CASCADE_RETURN_NT'}
|
||||
R4_SCOPE = {1: 'SCOPE_SE', 2: 'SCOPE_DEV', 3: 'SCOPE_SYS'}
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name, cdna, r4 = inst.op_name.lower(), _is_cdna(inst), _is_r4(inst)
|
||||
acc = getattr(inst, 'acc', 0)
|
||||
|
|
@ -331,9 +337,10 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
if r4: seg = 'flat' if (cls_name:=inst.__class__.__name__) == 'VFLAT' else ('global' if cls_name == 'VGLOBAL' else 'scratch')
|
||||
else: seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
# Global/scratch uses 13-bit signed offset
|
||||
# Global/scratch uses 13-bit signed offset (RDNA3/CDNA), 24-bit signed offset (RDNA4)
|
||||
offset = inst.ioffset if r4 else inst.offset # type: ignore[attr-defined]
|
||||
if seg != 'flat':
|
||||
if r4: off_val = offset if offset < (1 << 23) else offset - (1 << 24) # sign extend 24-bit
|
||||
elif seg != 'flat':
|
||||
if cdna:
|
||||
# CDNA: bit 12 is sign bit but not in offset field
|
||||
raw = int.from_bytes(inst.to_bytes(), 'little')
|
||||
|
|
@ -348,7 +355,9 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1)
|
||||
off_s = f" offset:{off_val}" if off_val else ""
|
||||
if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}" # type: ignore[attr-defined]
|
||||
elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}" # type: ignore[attr-defined]
|
||||
elif r4:
|
||||
th_names = R4_TH_ATOMIC if 'atomic' in name else (R4_TH_STORE if 'store' in name else R4_TH_LOAD)
|
||||
mods = off_s + (f" th:{th_names[inst.th]}" if inst.th in th_names else "") + (f" scope:{R4_SCOPE[inst.scope]}" if inst.scope in R4_SCOPE else "")
|
||||
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
if seg == 'flat': saddr_s = ""
|
||||
elif _unwrap(inst.saddr) in (0x7F, 124): saddr_s = ", off"
|
||||
|
|
@ -357,7 +366,7 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[_unwrap(inst.saddr)]}"
|
||||
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if _unwrap(inst.saddr) < 106 else decode_src(_unwrap(inst.saddr), cdna)}"
|
||||
if 'addtid' in name: return f"{instr} {reg_fn(inst.data if 'store' in name else inst.vdst)}{saddr_s}{mods}"
|
||||
if 'addtid' in name: return f"{instr} {reg_fn((inst.vsrc if r4 else inst.data) if 'store' in name else inst.vdst)}{saddr_s}{mods}"
|
||||
# RDNA4: vaddr instead of addr, vsrc instead of data
|
||||
addr = inst.vaddr if r4 else inst.addr # type: ignore[attr-defined]
|
||||
data = inst.vsrc if r4 else inst.data # type: ignore[attr-defined]
|
||||
|
|
@ -372,7 +381,7 @@ def _disasm_flat(inst: FLAT) -> str:
|
|||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w)
|
||||
data_s, vdst_s = reg_fn(data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
if 'atomic' in name:
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc # type: ignore[attr-defined]
|
||||
glc_or_sc0 = inst.sc0 if cdna else (inst.th & 1 if r4 else inst.glc) # type: ignore[attr-defined]
|
||||
sfx = f"{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{sfx}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{sfx}"
|
||||
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ RDNA4_FILES = ['gfx12_asm_sop1.s', 'gfx12_asm_sop2.s', 'gfx12_asm_sopp.s', 'gfx1
|
|||
'gfx12_asm_vop1.s', 'gfx12_asm_vop2.s', 'gfx12_asm_vopc.s', 'gfx12_asm_vopcx.s', 'gfx12_asm_vop3.s', 'gfx12_asm_vop3c.s',
|
||||
'gfx12_asm_vop3cx.s', 'gfx12_asm_vop3p.s', 'gfx12_asm_vop3_from_vop1.s', 'gfx12_asm_vop3_from_vop2.s',
|
||||
'gfx12_asm_vop3p_features.s', 'gfx12_asm_vopd.s', 'gfx12_asm_vopd_features.s',
|
||||
'gfx12_asm_ds.s', 'gfx12_asm_smem.s',
|
||||
'gfx12_asm_ds.s', 'gfx12_asm_smem.s', 'gfx12_asm_vflat.s',
|
||||
'gfx12_asm_wmma_w32.s']
|
||||
|
||||
def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue