mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
amd: EMU DPP support (#15719)
* EMU DPP support from GPT 5.4 * cleanups * simple * nope * fix
This commit is contained in:
parent
2b8d303f75
commit
359b1582d6
12 changed files with 627 additions and 83 deletions
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
from typing import Callable
|
||||
from test.amd.helpers import decode_dpp16
|
||||
from tinygrad.renderer.amd.dsl import Inst, Reg
|
||||
|
||||
# Special register mappings for disassembly
|
||||
|
|
@ -838,22 +839,11 @@ def _disasm_vop1_sdwa(inst) -> str:
|
|||
|
||||
def _decode_dpp(dpp: int) -> str:
|
||||
"""Decode DPP control value to string."""
|
||||
if dpp < 0x100: return f"quad_perm:[{dpp&3},{(dpp>>2)&3},{(dpp>>4)&3},{(dpp>>6)&3}]"
|
||||
if 0x100 <= dpp <= 0x10f: return f"row_shl:{dpp & 0xf}"
|
||||
if 0x110 <= dpp <= 0x11f: return f"row_shr:{dpp & 0xf}"
|
||||
if 0x120 <= dpp <= 0x12f: return f"row_ror:{dpp & 0xf}"
|
||||
if dpp == 0x130: return "wave_shl:1"
|
||||
if dpp == 0x134: return "wave_rol:1"
|
||||
if dpp == 0x138: return "wave_shr:1"
|
||||
if dpp == 0x13c: return "wave_ror:1"
|
||||
if dpp == 0x140: return "row_mirror"
|
||||
if dpp == 0x141: return "row_half_mirror"
|
||||
if dpp == 0x142: return "row_bcast:15"
|
||||
if dpp == 0x143: return "row_bcast:31"
|
||||
if 0x150 <= dpp <= 0x15f: return f"row_newbcast:{dpp & 0xf}"
|
||||
if 0x160 <= dpp <= 0x16f: return f"row_share:{dpp & 0xf}"
|
||||
if 0x170 <= dpp <= 0x17f: return f"row_xmask:{dpp & 0xf}"
|
||||
return f"dpp:{dpp:#x}"
|
||||
op, arg = decode_dpp16(dpp)
|
||||
if op == "quad_perm": return f"quad_perm:[{','.join(str(x) for x in arg)}]"
|
||||
if op in ("row_mirror", "row_half_mirror"): return op
|
||||
if op == "dpp": return f"dpp:{arg:#x}"
|
||||
return f"{op}:{arg}"
|
||||
|
||||
def _disasm_vop1_dpp(inst) -> str:
|
||||
name = inst.op_name.lower().replace('_e32', '')
|
||||
|
|
|
|||
|
|
@ -12,8 +12,19 @@ ARCH_TO_TARGET:dict[str, list[str]] = {
|
|||
|
||||
TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items() for t in targets}
|
||||
|
||||
_DPP16_RANGE_OPS = {0x100: "row_shl", 0x110: "row_shr", 0x120: "row_ror", 0x150: "row_newbcast", 0x160: "row_share", 0x170: "row_xmask"}
|
||||
_DPP16_EXACT_OPS = {0x130: ("wave_shl", 1), 0x134: ("wave_rol", 1), 0x138: ("wave_shr", 1), 0x13c: ("wave_ror", 1),
|
||||
0x140: ("row_mirror", 0), 0x141: ("row_half_mirror", 0), 0x142: ("row_bcast", 15), 0x143: ("row_bcast", 31)}
|
||||
|
||||
def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
|
||||
|
||||
def decode_dpp16(dpp: int) -> tuple[str, int | tuple[int, int, int, int]]:
|
||||
"""Decode a DPP16 control word into a symbolic operation and argument."""
|
||||
if dpp < 0x100: return "quad_perm", tuple((dpp >> shift) & 0x3 for shift in range(0, 8, 2))
|
||||
if dpp in _DPP16_EXACT_OPS: return _DPP16_EXACT_OPS[dpp]
|
||||
if (base := dpp & 0x1f0) in _DPP16_RANGE_OPS: return _DPP16_RANGE_OPS[base], dpp & 0xf
|
||||
return "dpp", dpp
|
||||
|
||||
def get_mattr(arch:str) -> str:
|
||||
return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch]
|
||||
|
||||
|
|
|
|||
187
test/amd/hw/test_dpp.py
Normal file
187
test/amd/hw/test_dpp.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for DPP16 source swizzles.
|
||||
|
||||
These instructions trap in the default wave32 hw helper, so this file uses a
|
||||
minimal wave64 lane-store harness and compares emulator vs hardware directly
|
||||
when USE_HW=1.
|
||||
"""
|
||||
import ctypes, unittest
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
from tinygrad.helpers import flat_mv
|
||||
from test.amd.hw.helpers import USE_HW, assemble
|
||||
from test.mockgpu.amd.emu import run_asm
|
||||
|
||||
WAVE64 = 64
|
||||
|
||||
def _wave64_code(instructions: list, out_reg: int = 1) -> bytes:
|
||||
return assemble([
|
||||
s_mov_b32(s[80], s[0]),
|
||||
s_mov_b32(s[81], s[1]),
|
||||
v_mov_b32_e32(v[255], v[0]),
|
||||
*instructions,
|
||||
s_load_b64(s[92:93], s[80:81], 0, soffset=NULL),
|
||||
s_waitcnt(0),
|
||||
v_lshlrev_b32_e32(v[240], 2, v[255]),
|
||||
global_store_b32(addr=v[240], data=v[out_reg], saddr=s[92:93], offset=0),
|
||||
s_endpgm(),
|
||||
])
|
||||
|
||||
def _run_wave64_emu(instructions: list, out_reg: int = 1) -> list[int]:
|
||||
out_buf = (ctypes.c_uint32 * WAVE64)(*([0] * WAVE64))
|
||||
args = (ctypes.c_uint64 * 1)(ctypes.addressof(out_buf))
|
||||
code = _wave64_code(instructions, out_reg)
|
||||
kernel_buf = (ctypes.c_char * len(code)).from_buffer_copy(code)
|
||||
rsrc2 = 0x19c | (128 << 15)
|
||||
scratch_size = 0x10000
|
||||
result = run_asm(ctypes.addressof(kernel_buf), len(code), 1, 1, 1, WAVE64, 1, 1, ctypes.addressof(args), rsrc2, scratch_size)
|
||||
assert result == 0, f"run_asm failed with {result}"
|
||||
return list(out_buf)
|
||||
|
||||
def _run_wave64_hw(instructions: list, out_reg: int = 1) -> list[int]:
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.ops_amd import AMDProgram
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
dev = Device["AMD"]
|
||||
compiler = HIPCompiler(dev.arch) # type: ignore[attr-defined]
|
||||
code = _wave64_code(instructions, out_reg)
|
||||
byte_str = ', '.join(f'0x{b:02x}' for b in code)
|
||||
asm_src = f""".text
|
||||
.globl test
|
||||
.p2align 8
|
||||
.type test,@function
|
||||
test:
|
||||
.byte {byte_str}
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel test
|
||||
.amdhsa_next_free_vgpr 256
|
||||
.amdhsa_next_free_sgpr 96
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_kernarg_size 8
|
||||
.amdhsa_group_segment_fixed_size 65536
|
||||
.amdhsa_private_segment_fixed_size 65536
|
||||
.amdhsa_enable_private_segment 1
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: test
|
||||
.symbol: test.kd
|
||||
.kernarg_segment_size: 8
|
||||
.group_segment_fixed_size: 65536
|
||||
.private_segment_fixed_size: 65536
|
||||
.kernarg_segment_align: 8
|
||||
.wavefront_size: 64
|
||||
.sgpr_count: 96
|
||||
.vgpr_count: 256
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
lib = compiler.compile(asm_src)
|
||||
prg = AMDProgram(dev, "test", lib) # type: ignore[arg-type]
|
||||
out_gpu = dev.allocator.alloc(WAVE64 * 4)
|
||||
prg(out_gpu, global_size=(1, 1, 1), local_size=(WAVE64, 1, 1), wait=True)
|
||||
out = bytearray(WAVE64 * 4)
|
||||
dev.allocator._copyout(flat_mv(memoryview(out)), out_gpu)
|
||||
return [int.from_bytes(out[i*4:(i+1)*4], 'little') for i in range(WAVE64)]
|
||||
|
||||
def run_wave64(instructions: list, out_reg: int = 1) -> list[int]:
|
||||
emu = _run_wave64_emu(instructions, out_reg)
|
||||
if not USE_HW: return emu
|
||||
hw = _run_wave64_hw(instructions, out_reg)
|
||||
if emu != hw:
|
||||
diffs = [f"lane {i}: emu=0x{e:08x} hw=0x{h:08x}" for i, (e, h) in enumerate(zip(emu, hw)) if e != h]
|
||||
raise AssertionError("Emulator vs Hardware mismatch:\n" + '\n'.join(diffs[:16]))
|
||||
return hw
|
||||
|
||||
class TestDPP16(unittest.TestCase):
|
||||
def _run_copy(self, dpp: int, *, row_mask: int = 0xf, bank_mask: int = 0xf, bc: int = 1, dst_seed: int | None = None) -> list[int]:
|
||||
instructions = [
|
||||
v_mul_u32_u24_e32(v[0], 10, v[255]),
|
||||
v_add_nc_u32_e32(v[0], 3, v[0]),
|
||||
]
|
||||
if dst_seed is not None: instructions.append(v_mov_b32_e32(v[1], dst_seed))
|
||||
instructions += [v_mov_b32_e32(v[2], 0), v_or_b32_e32(v[1], DPP, v[2], vsrc0=v[0], dpp=dpp, row_mask=row_mask, bank_mask=bank_mask, bc=bc)]
|
||||
return run_wave64(instructions)
|
||||
|
||||
def test_quad_perm_reverse(self):
|
||||
out = self._run_copy(0x1b)
|
||||
self.assertEqual(out[0], 33)
|
||||
self.assertEqual(out[1], 23)
|
||||
self.assertEqual(out[2], 13)
|
||||
self.assertEqual(out[3], 3)
|
||||
self.assertEqual(out[4], 73)
|
||||
|
||||
def test_row_shl(self):
|
||||
out = self._run_copy(0x101)
|
||||
self.assertEqual(out[0], 13)
|
||||
self.assertEqual(out[7], 83)
|
||||
self.assertEqual(out[14], 153)
|
||||
self.assertEqual(out[15], 0)
|
||||
self.assertEqual(out[16], 173)
|
||||
|
||||
def test_row_shr(self):
|
||||
out = self._run_copy(0x111)
|
||||
self.assertEqual(out[0], 0)
|
||||
self.assertEqual(out[1], 3)
|
||||
self.assertEqual(out[8], 73)
|
||||
self.assertEqual(out[15], 143)
|
||||
self.assertEqual(out[16], 0)
|
||||
self.assertEqual(out[17], 163)
|
||||
|
||||
def test_row_ror(self):
|
||||
out = self._run_copy(0x121)
|
||||
self.assertEqual(out[0], 153)
|
||||
self.assertEqual(out[1], 3)
|
||||
self.assertEqual(out[15], 143)
|
||||
self.assertEqual(out[16], 313)
|
||||
|
||||
def test_row_mirror(self):
|
||||
out = self._run_copy(0x140)
|
||||
self.assertEqual(out[0], 153)
|
||||
self.assertEqual(out[5], 103)
|
||||
self.assertEqual(out[8], 73)
|
||||
self.assertEqual(out[16], 313)
|
||||
|
||||
def test_row_half_mirror(self):
|
||||
out = self._run_copy(0x141)
|
||||
self.assertEqual(out[0], 73)
|
||||
self.assertEqual(out[7], 3)
|
||||
self.assertEqual(out[8], 153)
|
||||
self.assertEqual(out[15], 83)
|
||||
self.assertEqual(out[16], 233)
|
||||
|
||||
def test_row_mask(self):
|
||||
out = self._run_copy(0x101, row_mask=0x5, dst_seed=0xDEADBEEF)
|
||||
self.assertEqual(out[0], 13)
|
||||
self.assertEqual(out[15], 0)
|
||||
self.assertEqual(out[16], 0xDEADBEEF)
|
||||
self.assertEqual(out[32], 333)
|
||||
self.assertEqual(out[47], 0)
|
||||
self.assertEqual(out[48], 0xDEADBEEF)
|
||||
|
||||
def test_bank_mask(self):
|
||||
out = self._run_copy(0x101, bank_mask=0x5, dst_seed=0xDEADBEEF)
|
||||
self.assertEqual(out[0], 13)
|
||||
self.assertEqual(out[3], 43)
|
||||
self.assertEqual(out[4], 0xDEADBEEF)
|
||||
self.assertEqual(out[8], 93)
|
||||
self.assertEqual(out[12], 0xDEADBEEF)
|
||||
|
||||
class TestVOPCDPP16(unittest.TestCase):
|
||||
def test_row_bcast15_materializes_vcc(self):
|
||||
out = run_wave64([
|
||||
v_mov_b32_e32(v[0], v[255]),
|
||||
v_cmp_eq_u32_e32(DPP, v[0], vsrc0=v[0], dpp=0x142, row_mask=0xf, bank_mask=0xf, bc=1),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_mov_b32_e32(v[3], 1),
|
||||
v_cndmask_b32_e32(v[1], v[2], v[3]),
|
||||
])
|
||||
for lane in (0, 16, 32, 48): self.assertEqual(out[lane], 1)
|
||||
for lane in (1, 15, 31, 47, 63): self.assertEqual(out[lane], 0)
|
||||
|
|
@ -833,8 +833,6 @@ class TestDsPermute(unittest.TestCase):
|
|||
src_lane = lane ^ 1
|
||||
expected = src_lane + 100
|
||||
self.assertEqual(st.vgpr[lane][2], expected, f"lane {lane}: expected v[1] from lane {src_lane} = {expected}, got {st.vgpr[lane][2]}")
|
||||
|
||||
|
||||
class TestDSSubDword(unittest.TestCase):
|
||||
"""Tests for sub-dword DS operations (ds_store_b16, ds_store_b16_d16_hi)."""
|
||||
|
||||
|
|
|
|||
129
test/amd/hw/test_rdna4_permlane_var.py
Normal file
129
test/amd/hw/test_rdna4_permlane_var.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""RDNA4 V_PERMLANE16_VAR_B32 / V_PERMLANEX16_VAR_B32 coverage.
|
||||
|
||||
Exercises the generated pcode path end-to-end in the emulator and compares against
|
||||
real RDNA4 hardware when USE_HW=1.
|
||||
"""
|
||||
import ctypes, unittest
|
||||
import tinygrad.runtime.autogen.amd.rdna4.ins as r4
|
||||
from tinygrad.helpers import flat_mv
|
||||
from tinygrad.renderer.amd.dsl import NULL
|
||||
from test.amd.hw.helpers import USE_HW, assemble
|
||||
from test.mockgpu.amd.emu import run_asm
|
||||
|
||||
LANES = 32
|
||||
|
||||
def _code(instructions: list, out_reg: int = 2) -> bytes:
|
||||
return assemble([
|
||||
r4.s_mov_b32(r4.s[80], r4.s[0]),
|
||||
r4.s_mov_b32(r4.s[81], r4.s[1]),
|
||||
r4.v_mov_b32_e32(r4.v[255], r4.v[0]),
|
||||
*instructions,
|
||||
r4.s_load_b64(r4.s[92:93], r4.s[80:81], soffset=NULL),
|
||||
r4.s_wait_kmcnt(simm16=0),
|
||||
r4.v_lshlrev_b32_e32(r4.v[240], 2, r4.v[255]),
|
||||
r4.v_mov_b32_e32(r4.v[241], 0),
|
||||
r4.global_store_b32(vaddr=r4.v[240:241], saddr=r4.s[92:93], vsrc=r4.v[out_reg]),
|
||||
r4.s_endpgm(),
|
||||
])
|
||||
|
||||
def _run_emu(instructions: list, out_reg: int = 2) -> list[int]:
|
||||
out_buf = (ctypes.c_uint32 * LANES)(*([0] * LANES))
|
||||
args = (ctypes.c_uint64 * 1)(ctypes.addressof(out_buf))
|
||||
code = _code(instructions, out_reg)
|
||||
kernel_buf = (ctypes.c_char * len(code)).from_buffer_copy(code)
|
||||
result = run_asm(ctypes.addressof(kernel_buf), len(code), 1, 1, 1, LANES, 1, 1, ctypes.addressof(args), arch='rdna4')
|
||||
assert result == 0, f"run_asm failed with {result}"
|
||||
return list(out_buf)
|
||||
|
||||
def _run_hw(instructions: list, out_reg: int = 2) -> list[int]:
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.ops_amd import AMDProgram
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
dev = Device['AMD']
|
||||
if not dev.arch.startswith('gfx12'): raise unittest.SkipTest('requires RDNA4 hardware')
|
||||
compiler = HIPCompiler(dev.arch)
|
||||
code = _code(instructions, out_reg)
|
||||
byte_str = ', '.join(f'0x{b:02x}' for b in code)
|
||||
asm_src = f""".text
|
||||
.globl test
|
||||
.p2align 8
|
||||
.type test,@function
|
||||
test:
|
||||
.byte {byte_str}
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel test
|
||||
.amdhsa_next_free_vgpr 256
|
||||
.amdhsa_next_free_sgpr 96
|
||||
.amdhsa_wavefront_size32 1
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_kernarg_size 8
|
||||
.amdhsa_group_segment_fixed_size 65536
|
||||
.amdhsa_private_segment_fixed_size 65536
|
||||
.amdhsa_enable_private_segment 1
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: test
|
||||
.symbol: test.kd
|
||||
.kernarg_segment_size: 8
|
||||
.group_segment_fixed_size: 65536
|
||||
.private_segment_fixed_size: 65536
|
||||
.kernarg_segment_align: 8
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 96
|
||||
.vgpr_count: 256
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
lib = compiler.compile(asm_src)
|
||||
prg = AMDProgram(dev, 'test', lib)
|
||||
out_gpu = dev.allocator.alloc(LANES * 4)
|
||||
prg(out_gpu, global_size=(1, 1, 1), local_size=(LANES, 1, 1), wait=True)
|
||||
out = bytearray(LANES * 4)
|
||||
dev.allocator._copyout(flat_mv(memoryview(out)), out_gpu)
|
||||
return [int.from_bytes(out[i*4:(i+1)*4], 'little') for i in range(LANES)]
|
||||
|
||||
def run_rdna4(instructions: list, out_reg: int = 2) -> list[int]:
|
||||
emu = _run_emu(instructions, out_reg)
|
||||
if not USE_HW: return emu
|
||||
hw = _run_hw(instructions, out_reg)
|
||||
if emu != hw:
|
||||
diffs = [f"lane {i}: emu=0x{e:08x} hw=0x{h:08x}" for i, (e, h) in enumerate(zip(emu, hw)) if e != h]
|
||||
raise AssertionError("Emulator vs Hardware mismatch:\n" + '\n'.join(diffs[:16]))
|
||||
return hw
|
||||
|
||||
class TestPermlaneVarRDNA4(unittest.TestCase):
|
||||
def test_v_permlane16_var_b32_reverse(self):
|
||||
out = run_rdna4([
|
||||
r4.v_mov_b32_e32(r4.v[0], r4.v[255]),
|
||||
r4.v_xor_b32_e32(r4.v[1], 15, r4.v[255]),
|
||||
r4.v_permlane16_var_b32(r4.v[2], r4.v[0], r4.v[1]),
|
||||
])
|
||||
self.assertEqual(out[0], 15)
|
||||
self.assertEqual(out[5], 10)
|
||||
self.assertEqual(out[15], 0)
|
||||
self.assertEqual(out[16], 31)
|
||||
self.assertEqual(out[21], 26)
|
||||
self.assertEqual(out[31], 16)
|
||||
|
||||
def test_v_permlanex16_var_b32_cross_row(self):
|
||||
out = run_rdna4([
|
||||
r4.v_mov_b32_e32(r4.v[0], r4.v[255]),
|
||||
r4.v_mov_b32_e32(r4.v[1], r4.v[255]),
|
||||
r4.v_permlanex16_var_b32(r4.v[2], r4.v[0], r4.v[1]),
|
||||
])
|
||||
self.assertEqual(out[0], 16)
|
||||
self.assertEqual(out[5], 21)
|
||||
self.assertEqual(out[15], 31)
|
||||
self.assertEqual(out[16], 0)
|
||||
self.assertEqual(out[21], 5)
|
||||
self.assertEqual(out[31], 15)
|
||||
35
test/amd/hw/test_vinterp.py
Normal file
35
test/amd/hw/test_vinterp.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Tests for VINTERP instructions."""
|
||||
import unittest
|
||||
from test.amd.hw.helpers import *
|
||||
|
||||
class TestVInterp(unittest.TestCase):
|
||||
def test_v_interp_p10_f32(self):
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], v[255]),
|
||||
v_cvt_f32_u32_e32(v[1], v[10]),
|
||||
s_mov_b32(s[0], f2i(100.0)),
|
||||
v_add_f32_e32(v[1], s[0], v[1]),
|
||||
v_cvt_f32_u32_e32(v[3], v[10]),
|
||||
s_mov_b32(s[1], f2i(10.0)),
|
||||
v_add_f32_e32(v[3], s[1], v[3]),
|
||||
s_mov_b32(s[2], f2i(2.0)),
|
||||
v_interp_p10_f32(v[4], v[1], s[2], v[3]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=8)
|
||||
for lane in range(4): self.assertAlmostEqual(i2f(st.vgpr[lane][4]), 212.0, places=5)
|
||||
for lane in range(4, 8): self.assertAlmostEqual(i2f(st.vgpr[lane][4]), 224.0, places=5)
|
||||
|
||||
def test_v_interp_p10_f16_f32(self):
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[10], v[255]),
|
||||
v_cvt_f32_u32_e32(v[11], v[10]),
|
||||
v_cvt_f16_f32_e32(v[1], v[11]),
|
||||
s_mov_b32(s[0], f2i(10.0)),
|
||||
v_add_f32_e32(v[12], s[0], v[11]),
|
||||
v_cvt_f16_f32_e32(v[3], v[12]),
|
||||
s_mov_b32(s[1], f2i(2.0)),
|
||||
v_interp_p10_f16_f32(v[4], v[1], s[1], v[3]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=8)
|
||||
for lane in range(4): self.assertAlmostEqual(i2f(st.vgpr[lane][4]), 12.0, places=5)
|
||||
for lane in range(4, 8): self.assertAlmostEqual(i2f(st.vgpr[lane][4]), 24.0, places=5)
|
||||
|
|
@ -30,6 +30,17 @@ class TestBasicArithmetic(unittest.TestCase):
|
|||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 8.0, places=5)
|
||||
|
||||
def test_v_add_f32_dpp_row_shl(self):
|
||||
"""V_ADD_F32 DPP row_shl swizzles src0 before the add."""
|
||||
instructions = [
|
||||
v_cvt_f32_u32_e32(v[0], v[255]),
|
||||
v_add_f32_e32(v[1], DPP, v[0], vsrc0=v[0], dpp=0x101, row_mask=0xf, bank_mask=0xf, bc=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=16)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][1]), 1.0, places=5)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[1][1]), 3.0, places=5)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[14][1]), 29.0, places=5)
|
||||
|
||||
def test_v_fmac_f32(self):
|
||||
"""V_FMAC_F32: d = d + a*b using inline constants."""
|
||||
instructions = [
|
||||
|
|
|
|||
|
|
@ -20,6 +20,28 @@ class TestFMA(unittest.TestCase):
|
|||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][3]), 9.0, places=5)
|
||||
|
||||
def test_v_mullit_f32_basic(self):
|
||||
"""V_MULLIT_F32 multiplies when the guard input is valid."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], f2i(2.0)),
|
||||
v_mov_b32_e32(v[1], f2i(3.0)),
|
||||
v_mov_b32_e32(v[2], f2i(1.0)),
|
||||
v_mullit_f32(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][3]), 6.0, places=5)
|
||||
|
||||
def test_v_mullit_f32_invalid_guard(self):
|
||||
"""V_MULLIT_F32 returns -MAX_FLOAT_F32 when the guard input is non-positive."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], f2i(2.0)),
|
||||
v_mov_b32_e32(v[1], f2i(3.0)),
|
||||
v_mov_b32_e32(v[2], f2i(0.0)),
|
||||
v_mullit_f32(v[3], v[0], v[1], v[2]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][3], 0xFF7FFFFF)
|
||||
|
||||
def test_v_fma_f32_negative(self):
|
||||
"""V_FMA_F32 with negative multiplier."""
|
||||
instructions = [
|
||||
|
|
@ -1592,8 +1614,7 @@ class TestModifierInteractions(unittest.TestCase):
|
|||
self.assertEqual(st.vgpr[0][2], 0x80000000, "-|(-0.0)| = -0.0")
|
||||
|
||||
def test_clamp_with_nan(self):
|
||||
"""Clamp with NaN input should still produce NaN."""
|
||||
import math
|
||||
"""Clamp with NaN input saturates to 0 on RDNA3 hardware."""
|
||||
quiet_nan = 0x7fc00000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], quiet_nan),
|
||||
|
|
@ -1601,7 +1622,7 @@ class TestModifierInteractions(unittest.TestCase):
|
|||
VOP3(VOP3Op.V_ADD_F32, vdst=v[1], src0=v[0], src1=0.0, clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertTrue(math.isnan(i2f(st.vgpr[0][1])))
|
||||
self.assertEqual(st.vgpr[0][1], 0)
|
||||
|
||||
def test_omod_ignored(self):
|
||||
"""OMOD field is ignored on RDNA3 hardware."""
|
||||
|
|
@ -3605,32 +3626,30 @@ class TestPermlane(unittest.TestCase):
|
|||
"""V_PERMLANE16_B32 broadcast lane 0 to all lanes in row."""
|
||||
# lanesel = all zeros -> all positions read from lane 0 within row
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xCAFEBABE), # source data
|
||||
v_mov_b32_e32(v[0], v[255]),
|
||||
s_mov_b32(s[0], 0), # lanesel low = 0 (all read lane 0)
|
||||
s_mov_b32(s[1], 0), # lanesel high = 0
|
||||
v_permlane16_b32(v[1], v[0], s[0], s[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=4)
|
||||
# All lanes read from lane 0 of their row
|
||||
for lane in range(4):
|
||||
self.assertEqual(st.vgpr[lane][1], 0xCAFEBABE)
|
||||
st = run_program(instructions, n_lanes=32)
|
||||
for lane in range(16): self.assertEqual(st.vgpr[lane][1], 0)
|
||||
for lane in range(16, 32): self.assertEqual(st.vgpr[lane][1], 16)
|
||||
|
||||
def test_v_permlanex16_b32_identity(self):
|
||||
"""V_PERMLANEX16_B32 cross-row read with identity selection."""
|
||||
# In wave32: row 0 (lanes 0-15) reads from row 1 (lanes 16-31) and vice versa
|
||||
# With single lane in row 0, it reads from lane 0 of row 1 (lane 16)
|
||||
# But lane 16 doesn't exist in 1-lane test, so use 32 lanes
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0x11111111), # All lanes have this initially
|
||||
v_mov_b32_e32(v[0], v[255]),
|
||||
s_mov_b32(s[0], 0x76543210), # lanesel low
|
||||
s_mov_b32(s[1], 0xFEDCBA98), # lanesel high
|
||||
v_permlanex16_b32(v[1], v[0], s[0], s[1]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=32)
|
||||
# Lane 0 in row 0 reads from lane 0 of row 1 (lane 16)
|
||||
self.assertEqual(st.vgpr[0][1], 0x11111111)
|
||||
# Lane 16 in row 1 reads from lane 0 of row 0 (lane 0)
|
||||
self.assertEqual(st.vgpr[16][1], 0x11111111)
|
||||
self.assertEqual(st.vgpr[0][1], 16)
|
||||
self.assertEqual(st.vgpr[5][1], 21)
|
||||
self.assertEqual(st.vgpr[15][1], 31)
|
||||
self.assertEqual(st.vgpr[16][1], 0)
|
||||
self.assertEqual(st.vgpr[21][1], 5)
|
||||
self.assertEqual(st.vgpr[31][1], 15)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from tinygrad.uop.ops import UOp, Ops
|
|||
from test.mockgpu.amd.emu import parse_pcode
|
||||
from test.mockgpu.amd.pcode import parse_expr
|
||||
from tinygrad.runtime.autogen.amd.rdna3.str_pcode import PCODE
|
||||
from tinygrad.runtime.autogen.amd.rdna3.enum import VOP1Op, VOP2Op, SOP2Op, DSOp
|
||||
from tinygrad.runtime.autogen.amd.rdna3.enum import VOP1Op, VOP2Op, SOP2Op, DSOp, GLOBALOp
|
||||
|
||||
def _srcs():
|
||||
"""Create minimal source variables for pcode parsing."""
|
||||
|
|
@ -113,6 +113,7 @@ class TestParseExpr(unittest.TestCase):
|
|||
result = parse_expr('cond ? a : b', vrs)
|
||||
self.assertEqual(result.op, Ops.WHERE)
|
||||
|
||||
|
||||
class TestForLoopParsing(unittest.TestCase):
|
||||
"""Test for loop parsing (CLZ/CTZ patterns)."""
|
||||
|
||||
|
|
@ -164,6 +165,20 @@ class TestForLoopParsing(unittest.TestCase):
|
|||
class TestDSPcodePatterns(unittest.TestCase):
|
||||
"""Test DS instruction pcode patterns."""
|
||||
|
||||
def test_global_atomic_add_f32_parsing(self):
|
||||
"""Test GLOBAL_ATOMIC_ADD_F32 keeps memory values in float dtype."""
|
||||
vmem = UOp(Ops.PARAM, dtypes.uint32.ptr(1024), arg=2)
|
||||
srcs = {
|
||||
'ADDR': UOp.const(dtypes.uint64, 0),
|
||||
'DATA': UOp.const(dtypes.uint32, 0x3f800000),
|
||||
'_vmem': vmem,
|
||||
}
|
||||
|
||||
_, assigns = parse_pcode(PCODE[GLOBALOp.GLOBAL_ATOMIC_ADD_F32], srcs)
|
||||
mem_write = next(val for dest, val in assigns if dest == 'MEM[ADDR].f32')
|
||||
self.assertEqual(mem_write[1].op, Ops.ADD) # type: ignore[index]
|
||||
self.assertEqual(mem_write[1].dtype, dtypes.float32) # type: ignore[index]
|
||||
|
||||
def test_ds_load_b32_pcode(self):
|
||||
"""Test DS_LOAD_B32 pcode is parseable."""
|
||||
pcode = PCODE.get(DSOp.DS_LOAD_B32)
|
||||
|
|
@ -285,6 +300,47 @@ class TestConditionalParsing(unittest.TestCase):
|
|||
# Result should be a WHERE (ternary becomes WHERE)
|
||||
self.assertEqual(val.op, Ops.WHERE)
|
||||
|
||||
class TestConcatWidthParsing(unittest.TestCase):
|
||||
"""Test that bit extracts keep the right width for concat/unary ops."""
|
||||
|
||||
def test_permlanex16_altrow_concat(self):
|
||||
for row, expected in [(0, 1), (1, 0), (2, 3), (3, 2)]:
|
||||
parsed = parse_expr('{ row[1], ~row[0] }', {'row': UOp.const(dtypes.uint32, row)})
|
||||
self.assertEqual(parsed.simplify().arg, expected)
|
||||
|
||||
def test_permlane64_altlane_concat(self):
|
||||
for lane, expected in [(0, 32), (1, 33), (31, 63), (32, 0), (63, 31)]:
|
||||
parsed = parse_expr('{ ~lane[5], lane[4:0] }', {'lane': UOp.const(dtypes.uint32, lane)})
|
||||
self.assertEqual(parsed.simplify().arg, expected)
|
||||
|
||||
def test_permlane64_wave64_pcode_indices(self):
|
||||
vgpr = UOp(Ops.PARAM, dtypes.uint32.ptr(256), arg=0)
|
||||
srcs = {
|
||||
'SRC0': UOp.const(dtypes.uint32, 0),
|
||||
'VDST': UOp.const(dtypes.uint32, 1),
|
||||
'EXEC_LO': UOp.const(dtypes.uint32, 0xFFFFFFFF),
|
||||
'EXEC': UOp.const(dtypes.uint64, 0xFFFFFFFFFFFFFFFF),
|
||||
'_vgpr': vgpr,
|
||||
'_wave_size': 64,
|
||||
'S0': UOp.const(dtypes.uint32, 0),
|
||||
'S1': UOp.const(dtypes.uint32, 0),
|
||||
'S2': UOp.const(dtypes.uint32, 0),
|
||||
}
|
||||
|
||||
def load_idx(v: UOp) -> int:
|
||||
simp = v.simplify()
|
||||
self.assertEqual(simp.op, Ops.LOAD)
|
||||
self.assertEqual(simp.src[0].op, Ops.INDEX)
|
||||
idx = simp.src[0].src[1].simplify()
|
||||
self.assertEqual(idx.op, Ops.CONST)
|
||||
return idx.arg
|
||||
|
||||
_, assigns = parse_pcode(PCODE[VOP1Op.V_PERMLANE64_B32_E32], srcs)
|
||||
self.assertEqual(len(assigns), 64)
|
||||
for lane, (dst_idx, src_idx) in {0: (64, 32), 31: (95, 63), 32: (96, 0), 63: (127, 31)}.items():
|
||||
self.assertEqual(assigns[lane][1][0].simplify().arg, dst_idx) # type: ignore[index]
|
||||
self.assertEqual(load_idx(assigns[lane][1][1]), src_idx) # type: ignore[index]
|
||||
|
||||
class TestAllPcode(unittest.TestCase):
|
||||
"""Test that all pcode from all architectures can be parsed."""
|
||||
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ from tinygrad.runtime.autogen.amd.rdna4 import ins as ir4
|
|||
from tinygrad.runtime.autogen.amd.cdna import ins as irc
|
||||
from tinygrad.renderer.amd.dsl import VCC_LO, EXEC_LO, SCC, ttmp
|
||||
from tinygrad.runtime.autogen.amd.common import Fmt, OpType
|
||||
from test.amd.helpers import decode_dpp16
|
||||
from test.mockgpu.amd.pcode import parse_block, _FUNCS, _set_bits, _val_to_bits
|
||||
|
||||
MASK32 = 0xFFFFFFFF
|
||||
|
|
@ -345,7 +346,7 @@ def parse_pcode(pcode: str, srcs: dict[str, UOp | int] | None = None) -> tuple[d
|
|||
# TODO: pcode.py should tokenize full pcode string instead of line-by-line, then this hack can be removed
|
||||
lines: list[str] = []
|
||||
for l in raw_lines:
|
||||
if lines and lines[-1].endswith('&&'): lines[-1] = lines[-1] + ' ' + l
|
||||
if lines and re.search(r'(&&|\|\||[&|+\-*/^])\s*$', lines[-1]): lines[-1] = lines[-1] + ' ' + l
|
||||
else: lines.append(l)
|
||||
_, final, _ = parse_block(lines, 0, env, assigns=assigns)
|
||||
sliced = set(d.split('[')[0] for d, _ in assigns if '[' in d)
|
||||
|
|
@ -638,11 +639,13 @@ class _Ctx:
|
|||
src0_reg = (src0_off >= _c(256)).where(src0_off - _c(256), _c(0)) # VGPR index or 0
|
||||
src1_off = self.inst_field(type(inst).src1) if hasattr(type(inst), 'src1') else None
|
||||
src2_off = self.inst_field(type(inst).src2) if hasattr(type(inst), 'src2') else None
|
||||
src1_reg = (src1_off >= _c(256)).where(src1_off - _c(256), src1_off) if src1_off is not None else _c(0)
|
||||
src2_reg = (src2_off >= _c(256)).where(src2_off - _c(256), src2_off) if src2_off is not None else _c(0)
|
||||
exec_val = self.rexec()
|
||||
exec_lo = exec_val.cast(dtypes.uint32) if exec_val.dtype == dtypes.uint64 else exec_val
|
||||
srcs = {
|
||||
'SRC0': src0_reg, 'VDST': vdst_off, 'EXEC_LO': exec_lo, 'EXEC': exec_val if exec_val.dtype == dtypes.uint64 else exec_val.cast(dtypes.uint64),
|
||||
'_vgpr': self.vgpr, '_wave_size': self.wave_size,
|
||||
'_vgpr': self.vgpr, '_wave_size': self.wave_size, 'SRC1': src1_reg, 'SRC2': src2_reg,
|
||||
'S0': self.rsrc_dyn(src0_off, _c(0, dtypes.int)) if 'WRITELANE' in op_name else src0_reg,
|
||||
'S1': self.rsrc_dyn(src1_off, _c(0, dtypes.int)) if src1_off is not None else _c(0),
|
||||
'S2': self.rsrc_dyn(src2_off, _c(0, dtypes.int)) if src2_off is not None else _c(0),
|
||||
|
|
@ -662,10 +665,11 @@ class _Ctx:
|
|||
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
|
||||
if 'VCC' not in srcs: srcs['VCC'] = self.rmask(_c(vcc_reg))
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, 'VDST': vdst_reg,
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0), '_vgpr': self.vgpr, '_wave_size': self.wave_size,
|
||||
# CDNA SDWA byte/word select constants (E32 always uses BYTE0/WORD0 defaults)
|
||||
'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3),
|
||||
'WORD0': _c(0), 'WORD1': _c(1)}) # rounding mode and SDWA constants
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0), 'ROUND_NEAREST_EVEN': _c(0), '_vgpr': self.vgpr, '_wave_size': self.wave_size,
|
||||
'MAX_FLOAT_F32': UOp.const(dtypes.float32, 3.4028234663852886e38),
|
||||
# CDNA SDWA byte/word select constants (E32 always uses BYTE0/WORD0 defaults)
|
||||
'SDWA_SRC0_SEL': _c(0), 'BYTE0': _c(0), 'BYTE1': _c(1), 'BYTE2': _c(2), 'BYTE3': _c(3),
|
||||
'WORD0': _c(0), 'WORD1': _c(1)}) # rounding mode and SDWA constants
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
# For integer ops with clamp, compute overflow using wide arithmetic
|
||||
|
|
@ -715,6 +719,10 @@ class _Ctx:
|
|||
new_vcc = _set_lane_bit(old_vcc, lane, val, exec_mask)
|
||||
raw_stores.extend([('vcc', s) for s in self.wmask(_c(VCC_LO.offset), new_vcc)])
|
||||
elif dest.startswith('D0'):
|
||||
dest_suffix = re.match(r'D0\.(\w+)', dest)
|
||||
if dest_suffix is not None:
|
||||
target_dt = {'u16': dtypes.uint16, 'i16': dtypes.int16, 'f16': dtypes.half}.get(dest_suffix.group(1))
|
||||
if target_dt is not None and val.dtype != target_dt: val = val.cast(target_dt)
|
||||
if (slice_match := re.match(r'D0\[(\d+)\s*:\s*(\d+)\]', dest)):
|
||||
d0_hi_bit, d0_lo_bit = int(slice_match.group(1)), int(slice_match.group(2))
|
||||
if d0_hi_bit != 31 or d0_lo_bit != 0:
|
||||
|
|
@ -727,7 +735,8 @@ class _Ctx:
|
|||
# For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1]
|
||||
if int_saturate is not None: val = int_saturate
|
||||
elif clmp and val.dtype in (dtypes.float32, dtypes.half, dtypes.float64):
|
||||
val = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0))
|
||||
clamped = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0))
|
||||
val = _FUNCS['isNAN'](val).where(UOp.const(val.dtype, 0.0), clamped)
|
||||
if val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
lo, hi = _split64(val)
|
||||
raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)),
|
||||
|
|
@ -915,6 +924,45 @@ def _sdwa_write(old: UOp, val: UOp, dst_sel: UOp, dst_unused: UOp) -> UOp:
|
|||
# For PAD and SEXT, unused bits are zero (PAD) or sign-extended (SEXT). For DWORD, just return val.
|
||||
return dst_sel.eq(_c(6)).where(val, dst_unused.eq(_c(2)).where(preserved, placed))
|
||||
|
||||
def _dpp_quad_sel(quad_lane: UOp, sels: tuple[int, int, int, int]) -> UOp:
|
||||
sel = _c(sels[0], dtypes.int)
|
||||
for i, src in enumerate(sels[1:], start=1): sel = quad_lane.eq(_c(i, dtypes.int)).where(_c(src, dtypes.int), sel)
|
||||
return sel
|
||||
|
||||
def _dpp16_ctrl(lane: UOp, dpp: int, row_mask: int, bank_mask: int, wave_size: int) -> tuple[UOp, UOp, UOp]:
|
||||
"""Return (src_lane, row/bank enabled, in-bounds) for a DPP16 swizzle."""
|
||||
lane_i = lane.cast(dtypes.int)
|
||||
row_base, lane_in_row = lane_i & _c(~15, dtypes.int), lane_i & _c(15, dtypes.int)
|
||||
row = lane_i // _c(16, dtypes.int)
|
||||
bank = lane_in_row >> _c(2, dtypes.int)
|
||||
enabled = (((_c(row_mask) >> row.cast(dtypes.uint32)) & _c(1)).ne(_c(0)) &
|
||||
(((_c(bank_mask) >> bank.cast(dtypes.uint32)) & _c(1)).ne(_c(0))))
|
||||
op, arg = decode_dpp16(dpp)
|
||||
src_lane, valid = lane_i, UOp.const(dtypes.bool, True)
|
||||
|
||||
if op == 'quad_perm': src_lane = (lane_i & _c(~3, dtypes.int)) + _dpp_quad_sel(lane_i & _c(3, dtypes.int), arg)
|
||||
elif op == 'row_shl': src_lane, valid = row_base + lane_in_row + _c(arg, dtypes.int), lane_in_row <= _c(15 - arg, dtypes.int)
|
||||
elif op == 'row_shr': src_lane, valid = row_base + lane_in_row - _c(arg, dtypes.int), lane_in_row >= _c(arg, dtypes.int)
|
||||
elif op == 'row_ror': src_lane = row_base + ((lane_in_row - _c(arg, dtypes.int)) & _c(15, dtypes.int))
|
||||
elif op == 'row_mirror': src_lane = row_base + (_c(15, dtypes.int) - lane_in_row)
|
||||
elif op == 'row_half_mirror': src_lane = row_base + ((lane_in_row & _c(8, dtypes.int)) | (_c(7, dtypes.int) - (lane_in_row & _c(7, dtypes.int))))
|
||||
elif op == 'row_bcast': src_lane = row_base
|
||||
elif op == 'wave_shl': src_lane, valid = lane_i + _c(arg, dtypes.int), lane_i < _c(wave_size - arg, dtypes.int)
|
||||
elif op == 'wave_rol': src_lane = (lane_i + _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
|
||||
elif op == 'wave_shr': src_lane, valid = lane_i - _c(arg, dtypes.int), lane_i >= _c(arg, dtypes.int)
|
||||
elif op == 'wave_ror': src_lane = (lane_i - _c(arg, dtypes.int)) % _c(wave_size, dtypes.int)
|
||||
else: raise NotImplementedError(f"DPP16 control {dpp:#x} ({op}:{arg}) not implemented in emulator")
|
||||
return src_lane, enabled, valid
|
||||
|
||||
def _load_dpp16_src0(ctx: _Ctx, inst, lane: UOp, fallback: UOp) -> UOp:
|
||||
"""Load a DPP16-swizzled src0 value from vsrc0."""
|
||||
src_lane, enabled, valid = _dpp16_ctrl(lane, getattr(inst, 'dpp', 0) or 0, getattr(inst, 'row_mask', 0xf) or 0xf,
|
||||
getattr(inst, 'bank_mask', 0xf) or 0xf, ctx.wave_size)
|
||||
safe_src_lane = (enabled & valid).where(src_lane, _c(0, dtypes.int))
|
||||
swizzled = ctx.rvgpr_dyn(ctx.inst_field(type(inst).vsrc0), safe_src_lane)
|
||||
invalid = UOp.const(fallback.dtype, 0) if getattr(inst, 'bc', 0) else fallback
|
||||
return enabled.where(valid.where(swizzled, invalid), fallback)
|
||||
|
||||
def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc.VOPC_SDWA_SDST, ctx: _Ctx) -> UOp:
|
||||
"""Compile CDNA SDWA (Sub-Dword Access) VOP1/VOP2/VOPC instructions."""
|
||||
is_vopc = isinstance(inst, irc.VOPC_SDWA_SDST)
|
||||
|
|
@ -998,7 +1046,9 @@ def _compile_sdwa(inst: irc.VOP1_SDWA | irc.VOP2_SDWA | irc.VOP2_SDWA_SDST | irc
|
|||
return UOp.sink(UOp.sink(*stores).end(lane), *ctx.inc_pc())
|
||||
return UOp.sink(*ctx.inc_pc())
|
||||
|
||||
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP2 | irc.VOP1 | irc.VOP2, ctx: _Ctx) -> UOp:
|
||||
def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP1_DPP16 | ir3.VOP2 | ir3.VOP2_DPP16 |
|
||||
ir4.VOP1 | ir4.VOP1_SDST | ir4.VOP1_DPP16 | ir4.VOP2 | ir4.VOP2_DPP16 |
|
||||
irc.VOP1 | irc.VOP1_DPP16 | irc.VOP2 | irc.VOP2_DPP16, ctx: _Ctx) -> UOp:
|
||||
op_name = _op_name(inst)
|
||||
if op_name in ('V_READFIRSTLANE_B32_E32', 'V_PERMLANE64_B32_E32'): return ctx.compile_lane_pcode(inst.op, inst)
|
||||
# v_accvgpr_mov_b32: ACCVGPR[vdst] = ACCVGPR[src0] (VOP1 encoding, no pcode)
|
||||
|
|
@ -1011,20 +1061,28 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
|||
lane, exec_mask, bits = ctx.range(), ctx.rexec(), inst.canonical_op_bits
|
||||
literal = ctx.inst_field(type(inst).literal) if hasattr(type(inst), 'literal') else None # type: ignore[union-attr]
|
||||
is_f64 = 'F64' in op_name and 'B64' not in op_name
|
||||
is_float = any(x in op_name for x in ('F16', 'F32', 'F64'))
|
||||
is_dpp16 = hasattr(type(inst), 'dpp') and hasattr(type(inst), 'vsrc0')
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
write_hi_half = bits['d'] == 16 and (vdst_reg >= _c(128))
|
||||
if isinstance(write_hi_half, UOp): vdst_reg = write_hi_half.where(vdst_reg - _c(128), vdst_reg)
|
||||
elif write_hi_half: vdst_reg -= 128
|
||||
src0_off = None
|
||||
if isinstance(inst, (ir3.VOP1, ir4.VOP1, irc.VOP1)):
|
||||
# Handle VOP1 hi-half source operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal, is_f64)
|
||||
if bits['s0'] == 16:
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
|
||||
if is_dpp16:
|
||||
s0 = _load_dpp16_src0(ctx, inst, lane, d0)
|
||||
else:
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal, is_f64)
|
||||
if bits['s0'] == 16 and not is_dpp16:
|
||||
src0_hi = src0_off >= _c(384)
|
||||
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
|
||||
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
|
||||
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane))
|
||||
if is_dpp16 and is_float:
|
||||
s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0'])
|
||||
srcs:dict[str, UOp | int] = {'S0': s0, 'D0': d0}
|
||||
else:
|
||||
vsrc1_reg = ctx.inst_field(type(inst).vsrc1)
|
||||
|
|
@ -1037,13 +1095,19 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
|||
s1 = _cond_hi16(vsrc1_hi, ctx.rvgpr_dyn(vsrc1_actual, lane))
|
||||
d0 = _cond_hi16(write_hi_half, ctx.rvgpr_dyn(vdst_reg, lane)) # FMAC/FMAMK hi-half dest needs hi-half accumulator
|
||||
# Handle VOP2 hi-half src0 operand (src0 >= v[128] for 16-bit ops)
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal, is_f64)
|
||||
if bits['s0'] == 16:
|
||||
if is_dpp16:
|
||||
s0 = _load_dpp16_src0(ctx, inst, lane, d0)
|
||||
else:
|
||||
src0_off = ctx.inst_field(type(inst).src0)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lane, bits['s0'], literal, is_f64)
|
||||
if bits['s0'] == 16 and not is_dpp16:
|
||||
src0_hi = src0_off >= _c(384)
|
||||
# Only compute hi-half when src0_off >= 384, use guarded index to prevent OOB access
|
||||
src0_reg = src0_hi.where(src0_off - _c(384), _c(0))
|
||||
s0 = src0_hi.where(_hi16(ctx.rvgpr_dyn(src0_reg, lane)), s0)
|
||||
if is_dpp16 and is_float:
|
||||
s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0'])
|
||||
s1 = _apply_src_mods(s1, 0, 1 if getattr(inst, 'src1_abs', 0) else 0, 1 if getattr(inst, 'src1_neg', 0) else 0, bits['s1'])
|
||||
srcs = {'S0': s0, 'S1': s1, 'D0': d0}
|
||||
# FMAAK_(DTYPE)_E32 series
|
||||
if 'V_FMAA' in _op_name(inst) or 'V_FMAM' in _op_name(inst):
|
||||
|
|
@ -1051,10 +1115,11 @@ def _compile_vop12(inst: ir3.VOP1 | ir3.VOP1_SDST | ir3.VOP2 | ir4.VOP1 | ir4.VO
|
|||
srcs['SIMM32'] = literal
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=write_hi_half, src0_off=src0_off)
|
||||
|
||||
def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
|
||||
def _compile_vopc(inst: ir3.VOPC|ir3.VOPC_DPP16|ir3.VOP3|ir4.VOPC|ir4.VOPC_DPP16|ir4.VOP3|irc.VOPC|irc.VOP3, ctx: _Ctx,
|
||||
opsel: int = 0, abs_bits: int = 0, neg_bits: int = 0) -> UOp:
|
||||
exec_mask, op_name, bits = ctx.rexec(), _op_name(inst), inst.canonical_op_bits
|
||||
is_cmpx, is_vopc = 'CMPX' in op_name, hasattr(inst, 'vsrc1') # is_vopc: e32 vs e64
|
||||
is_dpp16 = hasattr(type(inst), 'dpp') and hasattr(type(inst), 'vsrc0')
|
||||
|
||||
# Handle both VOPC (vsrc1) and VOP3 (src1) instruction formats - read operands dynamically
|
||||
if is_vopc:
|
||||
|
|
@ -1077,11 +1142,14 @@ def _compile_vopc(inst: ir3.VOPC|ir3.VOP3|ir4.VOPC|ir4.VOP3|irc.VOPC|irc.VOP3, c
|
|||
is_float, is_f64, pcode = any(x in op_name for x in ('_F32', '_F64', '_F16')), '_F64' in op_name, get_pcode(inst.op)
|
||||
def get_cmp_bit(lane) -> UOp:
|
||||
lc = lane.cast(dtypes.int) if isinstance(lane, UOp) else _c(lane, dtypes.int)
|
||||
s0 = ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
|
||||
s0 = _load_dpp16_src0(ctx, inst, lc, _c(0)) if is_dpp16 else ctx.rsrc_dyn(src0_off, lc, bits['s0'], literal, is_f64)
|
||||
s1 = _cond_hi16(vsrc1_hi, ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)) if bits['s0'] == 16 \
|
||||
else ctx.rsrc_dyn(src1_off, lc, bits['s1'], literal, is_f64)
|
||||
if bits['s0'] == 16 and opsel: s0, s1 = _apply_opsel(s0, 0, opsel), _apply_opsel(s1, 1, opsel)
|
||||
if is_float:
|
||||
if is_dpp16:
|
||||
s0 = _apply_src_mods(s0, 0, 1 if getattr(inst, 'src0_abs', 0) else 0, 1 if getattr(inst, 'src0_neg', 0) else 0, bits['s0'])
|
||||
s1 = _apply_src_mods(s1, 0, 1 if getattr(inst, 'src1_abs', 0) else 0, 1 if getattr(inst, 'src1_neg', 0) else 0, bits['s1'])
|
||||
s0 = _apply_src_mods(s0, 0, abs_bits, neg_bits, bits['s0'])
|
||||
s1 = _apply_src_mods(s1, 1, abs_bits, neg_bits, bits['s1'])
|
||||
for dest, val in parse_pcode(pcode, {'S0': s0, 'S1': s1, 'laneId': lc, 'D0': UOp.const(dtypes.uint64, 0)})[1]:
|
||||
|
|
@ -1176,6 +1244,19 @@ def _compile_vop3(inst: ir3.VOP3 | ir4.VOP3 | irc.VOP3, ctx: _Ctx) -> UOp:
|
|||
opsel_dst_hi = bool(opsel & 0b1000) and bits['d'] == 16
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask, opsel_dst_hi=opsel_dst_hi, clmp=getattr(inst, 'clmp', 0))
|
||||
|
||||
def _compile_vinterp(inst: ir3.VINTERP | ir4.VINTERP, ctx: _Ctx) -> UOp:
|
||||
lane, exec_mask = ctx.range(), ctx.rexec()
|
||||
inst_type = type(inst)
|
||||
vdst_reg = ctx.inst_field(inst_type.vdst)
|
||||
src0_off, src1_off, src2_off = ctx.inst_field(inst_type.src0), ctx.inst_field(inst_type.src1), ctx.inst_field(inst_type.src2)
|
||||
src0_reg = (src0_off >= _c(256)).where(src0_off - _c(256), src0_off)
|
||||
src2_reg = (src2_off >= _c(256)).where(src2_off - _c(256), src2_off)
|
||||
srcs = {
|
||||
'SRC0': src0_reg, 'SRC2': src2_reg,
|
||||
'S0': ctx.rsrc_dyn(src0_off, lane), 'S1': ctx.rsrc_dyn(src1_off, lane), 'S2': ctx.rsrc_dyn(src2_off, lane),
|
||||
}
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||||
|
||||
def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD | irc.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
exec_mask = ctx.rexec()
|
||||
bits, pcode, ops = inst.canonical_op_bits, get_pcode(inst.op), inst.canonical_operands
|
||||
|
|
@ -1787,7 +1868,7 @@ def _compile_mem_op(inst: ir3.DS|ir3.FLAT|ir3.GLOBAL|ir3.SCRATCH|ir4.DS|ir4.VFLA
|
|||
'DATA2': _u64(ctx.rvgpr_dyn(data1_reg, lane), ctx.rvgpr_dyn(data1_reg + _c(1), lane)) if has_data1 else UOp.const(dtypes.uint64, 0)}
|
||||
# RDNA3 uses ADDR/OFFSET, RDNA4 uses vgpr_a/offset (lowercase) + CalcDsAddr function
|
||||
return {'ADDR': addr, 'ADDR_BASE': addr, 'OFFSET': offset, 'OFFSET0': offset0, 'OFFSET1': offset1, '_lds': mem, 'laneId': lane,
|
||||
'vgpr_a': ctx.rvgpr_dyn(addr_reg, lane), 'offset': offset, **data}
|
||||
'vgpr_a': ctx.rvgpr_dyn(addr_reg, lane), 'offset': offset, 'offset0': offset0, 'offset1': offset1, **data}
|
||||
active = _lane_active(exec_mask, lane)
|
||||
# saddr < 124 means valid SGPR pair, otherwise use 0 (NULL means no saddr contribution)
|
||||
use_saddr = (saddr_reg < _c(124)) if saddr_reg is not None else UOp.const(dtypes.bool, False)
|
||||
|
|
@ -1936,17 +2017,20 @@ def _compile_mubuf(inst: irc.MUBUF, ctx: _Ctx) -> UOp:
|
|||
# Dispatch table: instruction type -> handler function
|
||||
_INST_HANDLERS: dict[type, Callable[..., UOp]] = {
|
||||
ir3.SOPP: _compile_sopp, ir3.SMEM: _compile_smem, ir3.SOP1: _compile_sop, ir3.SOP2: _compile_sop, ir3.SOPC: _compile_sop, ir3.SOPK: _compile_sop,
|
||||
ir3.VOP1: _compile_vop12, ir3.VOP1_SDST: _compile_vop12, ir3.VOP2: _compile_vop12, ir3.VOPC: _compile_vopc, ir3.VOP3: _compile_vop3,
|
||||
ir3.VOP1: _compile_vop12, ir3.VOP1_SDST: _compile_vop12, ir3.VOP1_DPP16: _compile_vop12, ir3.VOP2: _compile_vop12, ir3.VOP2_DPP16: _compile_vop12,
|
||||
ir3.VOPC: _compile_vopc, ir3.VOPC_DPP16: _compile_vopc, ir3.VOP3: _compile_vop3, ir3.VINTERP: _compile_vinterp,
|
||||
ir3.VOP3_SDST: _compile_vop3, ir3.VOP3SD: _compile_vop3sd, ir3.VOP3P: _compile_vop3p, ir3.VOPD: _compile_vopd,
|
||||
ir3.DS: _compile_mem_op, ir3.FLAT: _compile_mem_op, ir3.GLOBAL: _compile_mem_op, ir3.SCRATCH: _compile_mem_op,
|
||||
# RDNA4 instruction classes
|
||||
ir4.SOPP: _compile_sopp, ir4.SMEM: _compile_smem, ir4.SOP1: _compile_sop, ir4.SOP2: _compile_sop, ir4.SOPC: _compile_sop, ir4.SOPK: _compile_sop,
|
||||
ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOPC: _compile_vopc, ir4.VOP3: _compile_vop3,
|
||||
ir4.VOP1: _compile_vop12, ir4.VOP1_SDST: _compile_vop12, ir4.VOP1_DPP16: _compile_vop12, ir4.VOP2: _compile_vop12, ir4.VOP2_DPP16: _compile_vop12,
|
||||
ir4.VOPC: _compile_vopc, ir4.VOPC_DPP16: _compile_vopc, ir4.VOP3: _compile_vop3, ir4.VINTERP: _compile_vinterp,
|
||||
ir4.VOP3_SDST: _compile_vop3, ir4.VOP3SD: _compile_vop3sd, ir4.VOP3P: _compile_vop3p, ir4.VOPD: _compile_vopd,
|
||||
ir4.DS: _compile_mem_op, ir4.VFLAT: _compile_mem_op, ir4.VGLOBAL: _compile_mem_op, ir4.VSCRATCH: _compile_mem_op,
|
||||
# CDNA instruction classes
|
||||
irc.SOPP: _compile_sopp, irc.SMEM: _compile_smem, irc.SOP1: _compile_sop, irc.SOP2: _compile_sop, irc.SOPC: _compile_sop, irc.SOPK: _compile_sop,
|
||||
irc.VOP1: _compile_vop12, irc.VOP2: _compile_vop12, irc.VOPC: _compile_vopc, irc.VOP3: _compile_vop3,
|
||||
irc.VOP1: _compile_vop12, irc.VOP1_DPP16: _compile_vop12, irc.VOP2: _compile_vop12, irc.VOP2_DPP16: _compile_vop12,
|
||||
irc.VOPC: _compile_vopc, irc.VOP3: _compile_vop3,
|
||||
irc.VOP3_SDST: _compile_vop3, irc.VOP3SD: _compile_vop3sd, irc.VOP3P: _compile_vop3p,
|
||||
irc.VOP1_SDWA: _compile_sdwa, irc.VOP2_SDWA: _compile_sdwa, irc.VOP2_SDWA_SDST: _compile_sdwa, irc.VOPC_SDWA_SDST: _compile_sdwa,
|
||||
irc.DS: _compile_mem_op, irc.FLAT: _compile_mem_op, irc.GLOBAL: _compile_mem_op, irc.SCRATCH: _compile_mem_op,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes
|
|||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.uop.decompositions import f2f
|
||||
|
||||
# Type alias for vars dict: stores UOps for variables and tuples for lambda definitions
|
||||
# Type alias for vars dict: stores UOps and tuples for lambda definitions
|
||||
VarVal = UOp | tuple[str, list[str], str]
|
||||
|
||||
def _const(dt, v): return UOp.const(dt, v)
|
||||
|
|
@ -50,6 +50,22 @@ def _extract_bits(val: UOp, hi: int, lo: int) -> UOp:
|
|||
if result.dtype != target_dt: result = result.cast(target_dt)
|
||||
return result
|
||||
|
||||
def _expr_bits(v: UOp) -> int:
|
||||
if v.dtype == dtypes.bool: return 1
|
||||
if v.op in (Ops.AND, Ops.XOR):
|
||||
widths: list[int] = []
|
||||
for src in v.src:
|
||||
if src.op == Ops.CONST and isinstance(src.arg, int) and src.arg > 0 and (src.arg & (src.arg + 1)) == 0:
|
||||
widths.append(src.arg.bit_length())
|
||||
if widths: return max(widths)
|
||||
return v.dtype.bitsize
|
||||
|
||||
def _countbits(v: UOp) -> UOp:
|
||||
dt = dtypes.uint64 if _expr_bits(v) > 32 or v.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
vv, out = v.cast(dt), _u32(0)
|
||||
for i in range(_expr_bits(v)): out = out + ((vv >> _const(dt, i)) & _const(dt, 1)).cast(dtypes.uint32)
|
||||
return out
|
||||
|
||||
def _set_bit(old, pos, val):
|
||||
mask = _u32(1) << pos
|
||||
return (old & (mask ^ _u32(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & _u32(1)) << pos)
|
||||
|
|
@ -335,6 +351,7 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
|
|||
# System NOPs - these are scheduling hints, no effect on emulation
|
||||
'MIN': lambda a, b: (a < b).where(a, b),
|
||||
's_nop': lambda a: _u32(0),
|
||||
'countbits': _countbits,
|
||||
# Address calculation for memory operations
|
||||
'CalcDsAddr': lambda a, o, *r: a.cast(dtypes.uint32) + o.cast(dtypes.uint32),
|
||||
'CalcGlobalAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
|
||||
|
|
@ -389,7 +406,7 @@ def tokenize(s: str) -> list[Token]:
|
|||
if c.isspace():
|
||||
i += 1
|
||||
continue
|
||||
if i + 1 < n and s[i:i+2] in ('+=', '-='):
|
||||
if i + 1 < n and s[i:i+2] in ('+=', '-=', '|=', '&=', '^='):
|
||||
tokens.append(Token('ASSIGN_OP', s[i:i+2]))
|
||||
i += 2
|
||||
continue
|
||||
|
|
@ -503,7 +520,7 @@ class Parser:
|
|||
def unary(self) -> UOp:
|
||||
if self.try_eat_val('~', 'OP'):
|
||||
inner = self.unary()
|
||||
return inner ^ _const(inner.dtype, (1 << (inner.dtype.itemsize * 8)) - 1)
|
||||
return inner ^ _const(inner.dtype, (1 << _expr_bits(inner)) - 1)
|
||||
if self.try_eat_val('!', 'OP'):
|
||||
inner = self.unary()
|
||||
return inner.eq(_const(inner.dtype, 0))
|
||||
|
|
@ -539,7 +556,10 @@ class Parser:
|
|||
self.eat('COMMA')
|
||||
lo = self.parse()
|
||||
self.eat('RBRACE')
|
||||
return (hi.cast(dt:=_BITS_DT.get((s:=lo.dtype.bitsize) * 2, dtypes.uint64)) << _const(dt, s)) | lo.cast(dt)
|
||||
lo_bits, hi_bits = _expr_bits(lo), _expr_bits(hi)
|
||||
total_bits = lo_bits + hi_bits
|
||||
dt = _BITS_DT.get(total_bits, dtypes.uint32 if total_bits <= 32 else dtypes.uint64)
|
||||
return (hi.cast(dt) << _const(dt, lo_bits)) | lo.cast(dt)
|
||||
if self.at('NUM'):
|
||||
num = self.eat('NUM').val
|
||||
if self.try_eat('QUOTE'):
|
||||
|
|
@ -576,8 +596,8 @@ class Parser:
|
|||
if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32)
|
||||
if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64)
|
||||
if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64)
|
||||
if name == 'WAVE32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32)
|
||||
if name == 'WAVE64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32)
|
||||
if name.lower() == 'wave32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32)
|
||||
if name.lower() == 'wave64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32)
|
||||
if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1)
|
||||
if self.try_eat('LBRACE'):
|
||||
idx = self.eat('NUM').val
|
||||
|
|
@ -685,7 +705,7 @@ class Parser:
|
|||
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
base_cast = base.cast(dt) if base.dtype != dt else base
|
||||
result = ((base_cast >> _const(dt, idx)) & _const(dt, 1))
|
||||
return _cast_to(result, dt_suffix) if dt_suffix else result
|
||||
return _cast_to(result, dt_suffix) if dt_suffix else result.cast(dtypes.bool)
|
||||
if var_name:
|
||||
idx_u32 = _to_u32(first)
|
||||
elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars]
|
||||
|
|
@ -699,7 +719,7 @@ class Parser:
|
|||
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
base_cast = base.cast(dt) if base.dtype != dt else base
|
||||
result = (base_cast >> first.cast(dt)) & _const(dt, 1)
|
||||
return _cast_to(result, dt_suffix) if dt_suffix else result
|
||||
return _cast_to(result, dt_suffix) if dt_suffix else result.cast(dtypes.bool)
|
||||
|
||||
def _handle_brace_index(self, base) -> UOp:
|
||||
self.eat('LBRACE')
|
||||
|
|
@ -845,7 +865,7 @@ class Parser:
|
|||
hi = mem.index(safe_idx_hi, *gate)
|
||||
combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
|
||||
val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val)
|
||||
return val
|
||||
return _cast_to(val, dt)
|
||||
|
||||
def _coerce_cmp(self, l: UOp, r: UOp) -> tuple[UOp, UOp]:
|
||||
if l.dtype != r.dtype:
|
||||
|
|
@ -1044,14 +1064,12 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
|
|||
elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
|
||||
rhs = parse_tokens(toks[j:], env, funcs)
|
||||
if compound_op:
|
||||
mem = env.get('_vmem') if '_vmem' in env else env.get('_lds')
|
||||
if isinstance(mem, UOp):
|
||||
adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
|
||||
idx = (addr >> _const(adt, 2)).cast(dtypes.int)
|
||||
old = mem.index(idx)
|
||||
if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
old = old.cast(dtypes.uint64) | (mem.index(((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)).cast(dtypes.uint64) << _u64(32))
|
||||
rhs = (old + rhs) if compound_op == '+=' else (old - rhs)
|
||||
old = Parser([Token('EOF', '')], env, funcs)._handle_mem_load(addr, dt)
|
||||
if compound_op == '+=': rhs = old + rhs
|
||||
elif compound_op == '-=': rhs = old - rhs
|
||||
elif compound_op == '|=': rhs = old | rhs
|
||||
elif compound_op == '&=': rhs = old & rhs
|
||||
elif compound_op == '^=': rhs = old ^ rhs
|
||||
if assigns is not None: assigns.append((f'MEM[{_tok_str(addr_toks)}].{dt_name}', (addr, rhs)))
|
||||
i += 1
|
||||
continue
|
||||
|
|
@ -1188,7 +1206,11 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
|
|||
old = block_assigns.get(var, env.get(var, _u32(0)))
|
||||
rhs = parse_tokens(toks[assign_op+1:], env, funcs)
|
||||
if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
|
||||
block_assigns[var] = env[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
|
||||
if toks[assign_op].val == '+=': block_assigns[var] = env[var] = old + rhs
|
||||
elif toks[assign_op].val == '-=': block_assigns[var] = env[var] = old - rhs
|
||||
elif toks[assign_op].val == '|=': block_assigns[var] = env[var] = old | rhs
|
||||
elif toks[assign_op].val == '&=': block_assigns[var] = env[var] = old & rhs
|
||||
elif toks[assign_op].val == '^=': block_assigns[var] = env[var] = old ^ rhs
|
||||
i += 1
|
||||
continue
|
||||
|
||||
|
|
@ -1335,4 +1357,3 @@ def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dic
|
|||
|
||||
def parse_expr(expr: str, env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
|
||||
return parse_tokens(tokenize(expr.strip().rstrip(';')), env, funcs)
|
||||
|
||||
|
|
|
|||
|
|
@ -32,11 +32,13 @@ _FORMATS: dict[str, list[type[Inst]]] | None = None
|
|||
def _load_formats() -> dict[str, list[type[Inst]]]:
|
||||
global _FORMATS
|
||||
if _FORMATS is not None: return _FORMATS
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import (VOP1, VOP1_SDST, VOP1_LIT, VOP2, VOP2_LIT, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD,
|
||||
VINTERP, SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH)
|
||||
from tinygrad.runtime.autogen.amd.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP1_LIT as R4_VOP1_LIT,
|
||||
VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
|
||||
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT,
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import (VOP1, VOP1_SDST, VOP1_DPP16, VOP1_LIT, VOP2, VOP2_DPP16, VOP2_LIT, VOP3, VOP3_SDST,
|
||||
VOP3SD, VOP3P, VOPC, VOPC_DPP16, VOPD, VINTERP, SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, SMEM, DS, FLAT, GLOBAL,
|
||||
SCRATCH)
|
||||
from tinygrad.runtime.autogen.amd.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP1_DPP16 as R4_VOP1_DPP16,
|
||||
VOP1_LIT as R4_VOP1_LIT, VOP2 as R4_VOP2, VOP2_DPP16 as R4_VOP2_DPP16, VOP2_LIT as R4_VOP2_LIT, VOP3 as R4_VOP3,
|
||||
VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P, VOPC as R4_VOPC, VOPC_DPP16 as R4_VOPC_DPP16,
|
||||
VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT,
|
||||
SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT, SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT,
|
||||
SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP,
|
||||
SMEM as R4_SMEM, DS as R4_DS, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH)
|
||||
|
|
@ -50,10 +52,11 @@ def _load_formats() -> dict[str, list[type[Inst]]]:
|
|||
# Order: base before _LIT (base matches regular ops, _LIT catches lit-only ops excluded from base)
|
||||
_FORMATS = {
|
||||
"rdna3": [VOPD, VOP3P, VINTERP, VOP3SD, VOP3_SDST, VOP3, DS, GLOBAL, SCRATCH, FLAT, SMEM,
|
||||
SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, VOPC, VOP1_SDST, VOP1, VOP1_LIT, VOP2, VOP2_LIT],
|
||||
SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, VOPC_DPP16, VOPC, VOP1_SDST, VOP1_DPP16, VOP1, VOP1_LIT,
|
||||
VOP2_DPP16, VOP2, VOP2_LIT],
|
||||
"rdna4": [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3SD, R4_VOP3_SDST, R4_VOP3, R4_DS, R4_GLOBAL, R4_SCRATCH, R4_FLAT, R4_SMEM,
|
||||
R4_SOP1, R4_SOP1_LIT, R4_SOPC, R4_SOPC_LIT, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC, R4_VOP1_SDST, R4_VOP1, R4_VOP1_LIT,
|
||||
R4_SOP2, R4_SOP2_LIT, R4_VOP2, R4_VOP2_LIT],
|
||||
R4_SOP1, R4_SOP1_LIT, R4_SOPC, R4_SOPC_LIT, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC_DPP16, R4_VOPC, R4_VOP1_SDST,
|
||||
R4_VOP1_DPP16, R4_VOP1, R4_VOP1_LIT, R4_SOP2, R4_SOP2_LIT, R4_VOP2_DPP16, R4_VOP2, R4_VOP2_LIT],
|
||||
"cdna": [C_VOP3PX2, C_VOP3P_MFMA, C_VOP3P, C_VOP3SD, C_VOP3_SDST, C_VOP3, C_DS, C_GLOBAL, C_SCRATCH, C_FLAT, C_MUBUF, C_SMEM,
|
||||
C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_SOPK_LIT, C_VOPC_SDWA_SDST, C_VOPC,
|
||||
C_VOP1_DPP16, C_VOP1_SDWA, C_VOP1, C_VOP2_DPP16, C_VOP2_SDWA, C_SOP2, C_VOP2, C_VOP2_LIT],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue