mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
674 lines
28 KiB
Python
674 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
"""Tests for the pcode-based instruction selector (isel.py)."""
|
|
import unittest
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
from tinygrad.dtype import dtypes
|
|
from extra.assembly.amd.isel import (rdna3_isel, make_inst, normalize, _count_nodes, _is_direct_alu,
|
|
_parse_pcode_patterns, _pattern_key, _runtime_key, uop_to_upat,
|
|
_SENTINEL, _SENTINEL_SET, _DIRECT_TABLE, _STRUCTURAL_TABLE,
|
|
_ALU_ENUM_TYPES, build_isel_patterns)
|
|
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
|
|
from extra.assembly.amd.autogen.rdna3.enum import VOP2Op, VOP1Op, VOP3Op, SOP2Op, VOPCOp, VOP3SDOp
|
|
|
|
# helpers
|
|
def _var(name, dtype=dtypes.float):
  """Build a DEFINE_VAR UOp of type `dtype` whose arg is `(name, 0, 100)`."""
  var_arg = (name, 0, 100)
  return UOp(Ops.DEFINE_VAR, dtype, arg=var_arg)
|
|
def _const(val, dtype=dtypes.float):
  """Build a CONST UOp of type `dtype` whose arg is `val`."""
  const_uop = UOp(Ops.CONST, dtype, arg=val)
  return const_uop
|
|
|
|
class TestMakeInst(unittest.TestCase):
  """make_inst should return an instruction whose `op` equals the opcode it was built from."""

  def test_vop2(self):
    inst = make_inst(VOP2Op.V_ADD_F32_E32)
    self.assertEqual(inst.op, VOP2Op.V_ADD_F32_E32)

  def test_vop1(self):
    inst = make_inst(VOP1Op.V_SQRT_F32_E32)
    self.assertEqual(inst.op, VOP1Op.V_SQRT_F32_E32)

  def test_vop3(self):
    inst = make_inst(VOP3Op.V_ADD_F64)
    self.assertEqual(inst.op, VOP3Op.V_ADD_F64)

  def test_vop3_sdst(self):
    # VOP3SDOp opcodes need the _SDST variant class
    inst = make_inst(VOP3SDOp.V_ADD_CO_CI_U32)
    self.assertEqual(inst.op, VOP3SDOp.V_ADD_CO_CI_U32)

  def test_vopc(self):
    inst = make_inst(VOPCOp.V_CMP_LT_F32_E32)
    self.assertEqual(inst.op, VOPCOp.V_CMP_LT_F32_E32)

  def test_sop2(self):
    inst = make_inst(SOP2Op.S_ADD_I32)
    self.assertEqual(inst.op, SOP2Op.S_ADD_I32)

  def test_invalid_raises(self):
    # anything that is not a known opcode enum member is expected to raise RuntimeError
    with self.assertRaises(RuntimeError): make_inst("not_an_opcode")
|
|
|
|
class TestNormalize(unittest.TestCase):
  """Checks on normalize(): wrappers around sentinels are expected to fold away, other nodes are kept."""

  def test_bitcast_sentinel(self):
    # BITCAST of a sentinel should come back as a DEFINE_VAR carrying the bitcast's dtype
    s0 = _SENTINEL['S0']
    bc = UOp(Ops.BITCAST, dtypes.float, (s0,))
    norm = normalize(bc)
    self.assertEqual(norm.op, Ops.DEFINE_VAR)
    self.assertEqual(norm.dtype, dtypes.float)

  def test_cast_sentinel(self):
    # CAST of a sentinel should come back as a DEFINE_VAR carrying the cast's dtype
    s0 = _SENTINEL['S0']
    cast = UOp(Ops.CAST, dtypes.int, (s0,))
    norm = normalize(cast)
    self.assertEqual(norm.op, Ops.DEFINE_VAR)
    self.assertEqual(norm.dtype, dtypes.int)

  def test_identity_bitcast(self):
    # a float->float BITCAST of a constant should leave just the constant
    x = UOp(Ops.CONST, dtypes.float, arg=1.0)
    bc = UOp(Ops.BITCAST, dtypes.float, (x,))
    norm = normalize(bc)
    self.assertEqual(norm.op, Ops.CONST)
    self.assertEqual(norm.arg, 1.0)

  def test_shift_mask_31(self):
    # `sentinel & 31` should reduce to the sentinel itself
    s0 = _SENTINEL['S0']
    c31 = UOp(Ops.CONST, dtypes.uint, arg=31)
    masked = UOp(Ops.AND, dtypes.uint, (s0, c31))
    norm = normalize(masked)
    self.assertEqual(norm.op, Ops.DEFINE_VAR)

  def test_shift_mask_63(self):
    # `sentinel & 63` should reduce to the sentinel itself
    s0 = _SENTINEL['S0']
    c63 = UOp(Ops.CONST, dtypes.uint, arg=63)
    masked = UOp(Ops.AND, dtypes.uint, (s0, c63))
    norm = normalize(masked)
    self.assertEqual(norm.op, Ops.DEFINE_VAR)

  def test_non_sentinel_bitcast_preserved(self):
    x = _var('x', dtypes.uint)
    bc = UOp(Ops.BITCAST, dtypes.float, (x,))
    norm = normalize(bc)
    self.assertEqual(norm.op, Ops.BITCAST)  # not a sentinel, so preserved

  def test_recursive(self):
    # normalization has to reach the children of a non-sentinel root
    s0 = _SENTINEL['S0']
    s1 = _SENTINEL['S1']
    bc0 = UOp(Ops.BITCAST, dtypes.float, (s0,))
    bc1 = UOp(Ops.BITCAST, dtypes.float, (s1,))
    add = UOp(Ops.ADD, dtypes.float, (bc0, bc1))
    norm = normalize(add)
    self.assertEqual(norm.op, Ops.ADD)
    self.assertTrue(all(s.dtype == dtypes.float for s in norm.src))
    self.assertTrue(all(s.op == Ops.DEFINE_VAR for s in norm.src))
|
|
|
|
class TestCountNodes(unittest.TestCase):
  """_count_nodes should count each distinct node of a UOp graph exactly once."""

  def test_leaf(self):
    self.assertEqual(_count_nodes(_var('x')), 1)

  def test_binary(self):
    x, y = _var('x'), _var('y')
    add = UOp(Ops.ADD, dtypes.float, (x, y))
    self.assertEqual(_count_nodes(add), 3)

  def test_dag_sharing(self):
    x = _var('x')
    add = UOp(Ops.ADD, dtypes.float, (x, x))
    self.assertEqual(_count_nodes(add), 2)  # x counted once
|
|
|
|
class TestIsDirectAlu(unittest.TestCase):
  """_is_direct_alu should accept an op applied straight to sentinels and reject one with other children."""

  def test_add_sentinels(self):
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
    s1 = _SENTINEL['S1'].replace(dtype=dtypes.float)
    add = UOp(Ops.ADD, dtypes.float, (s0, s1))
    self.assertTrue(_is_direct_alu(add))

  def test_cast_sentinel(self):
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.int)
    cast = UOp(Ops.CAST, dtypes.float, (s0,))
    self.assertTrue(_is_direct_alu(cast))

  def test_nested_not_direct(self):
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.uint)
    c = _const(0xFFFFFFFF, dtypes.uint)
    xor = UOp(Ops.XOR, dtypes.uint, (s0, c))
    self.assertFalse(_is_direct_alu(xor))  # const child is not DEFINE_VAR
|
|
|
|
class TestPatternKey(unittest.TestCase):
  """Pins the exact string form that _pattern_key produces for a var, a const and a binary op."""

  def test_sentinel_var(self):
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
    key = _pattern_key(s0)
    self.assertEqual(key, 'var(S0,dtypes.float)')

  def test_const(self):
    c = _const(42, dtypes.uint)
    key = _pattern_key(c)
    self.assertEqual(key, 'const(42,dtypes.uint)')

  def test_binary_op(self):
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
    s1 = _SENTINEL['S1'].replace(dtype=dtypes.float)
    add = UOp(Ops.ADD, dtypes.float, (s0, s1))
    key = _pattern_key(add)
    self.assertEqual(key, 'Ops.ADD(dtypes.float,var(S0,dtypes.float),var(S1,dtypes.float))')
|
|
|
|
class TestRuntimeKey(unittest.TestCase):
  """_runtime_key on a concrete UOp must agree with _pattern_key on the matching template."""

  def test_matches_pattern_key(self):
    # runtime key on a matched UOp should equal pattern key on the pcode template
    x = _var('x', dtypes.uint)
    c = _const(0xFFFFFFFF, dtypes.uint)
    xor = UOp(Ops.XOR, dtypes.uint, (x, c))
    rkey = _runtime_key(xor)

    # same shape, but with the sentinel S0 standing in for the user variable
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.uint)
    xor_template = UOp(Ops.XOR, dtypes.uint, (s0, c))
    pkey = _pattern_key(xor_template)
    self.assertEqual(rkey, pkey)
|
|
|
|
class TestUopToUpat(unittest.TestCase):
  """uop_to_upat should turn sentinels into named pattern variables and keep constants literal."""

  def test_sentinel_becomes_var(self):
    s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
    pat = uop_to_upat(s0)
    self.assertEqual(pat.name, 'S0')
    self.assertEqual(pat.dtype, (dtypes.float,))

  def test_const_preserved(self):
    c = _const(42, dtypes.uint)
    pat = uop_to_upat(c)
    self.assertEqual(pat.op, (Ops.CONST,))
    self.assertEqual(pat.arg, 42)
|
|
|
|
class TestBuildPerformance(unittest.TestCase):
  """Guards on pattern-table construction: it must finish quickly and must not hang."""

  def test_builds_under_2_seconds(self):
    import time
    # perf_counter is monotonic; time.time() follows the wall clock and can jump mid-measurement
    start = time.perf_counter()
    build_isel_patterns(PCODE)
    elapsed = time.perf_counter() - start
    self.assertLess(elapsed, 2.0, f"build took {elapsed:.2f}s, expected <2s")

  def test_alu_filter(self):
    # verify only ALU enum types are parsed
    # NOTE(review): this loop only walks PCODE and filters by enum type name; it asserts nothing yet
    for opcode in PCODE:
      if type(opcode).__name__ not in _ALU_ENUM_TYPES: continue
      # these should parse without hanging
|
|
|
|
class TestDirectPatterns(unittest.TestCase):
  """Single-op UOps should be rewritten by the isel matcher into an Ops.INS with the expected opcode."""

  @classmethod
  def setUpClass(cls):
    # build the matcher once; every test in the class reuses it
    cls.pm = rdna3_isel()

  def _check(self, uop, expected_name_substr):
    """Rewrite `uop`, require an Ops.INS whose opcode name contains `expected_name_substr`, and return it."""
    result = self.pm.rewrite(uop)
    self.assertIsNotNone(result, f"no match for {uop.op} {uop.dtype}")
    self.assertEqual(result.op, Ops.INS)
    self.assertIn(expected_name_substr, result.arg.op.name, f"expected {expected_name_substr} in {result.arg.op.name}")
    return result

  # arithmetic
  def test_add_f32(self): self._check(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))), 'V_ADD_F32')
  def test_add_f64(self): self._check(UOp(Ops.ADD, dtypes.double, (_var('a', dtypes.double), _var('b', dtypes.double))), 'V_ADD_F64')
  def test_add_i32(self): self._check(UOp(Ops.ADD, dtypes.int, (_var('a', dtypes.int), _var('b', dtypes.int))), 'ADD_NC_I32')
  def test_add_u32(self): self._check(UOp(Ops.ADD, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'ADD_NC_U32')
  def test_mul_f32(self): self._check(UOp(Ops.MUL, dtypes.float, (_var('a'), _var('b'))), 'V_MUL_F32')
  def test_mul_f64(self): self._check(UOp(Ops.MUL, dtypes.double, (_var('a', dtypes.double), _var('b', dtypes.double))), 'V_MUL_F64')
  def test_mul_i32(self): self._check(UOp(Ops.MUL, dtypes.int, (_var('a', dtypes.int), _var('b', dtypes.int))), 'MUL_I32')
  def test_mul_u32(self): self._check(UOp(Ops.MUL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'MUL_U32')

  # bitwise
  def test_and_u32(self): self._check(UOp(Ops.AND, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'AND_B32')
  def test_or_u32(self): self._check(UOp(Ops.OR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'OR_B32')
  def test_xor_u32(self): self._check(UOp(Ops.XOR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'XOR_B32')

  # u64 bitwise ops are SOP-only, skipped in vgpr_only mode
  def test_and_u64_skipped(self):
    result = self.pm.rewrite(UOp(Ops.AND, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
    self.assertIsNone(result)

  def test_or_u64_skipped(self):
    result = self.pm.rewrite(UOp(Ops.OR, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
    self.assertIsNone(result)

  def test_xor_u64_skipped(self):
    result = self.pm.rewrite(UOp(Ops.XOR, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
    self.assertIsNone(result)

  # shifts: match the direction-specific substring so a SHL/SHR mix-up cannot pass unnoticed
  def test_shl_u32(self): self._check(UOp(Ops.SHL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'LSHL')
  def test_shr_u32(self): self._check(UOp(Ops.SHR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'LSHR')

  # unary float
  def test_sqrt_f32(self): self._check(UOp(Ops.SQRT, dtypes.float, (_var('a'),)), 'SQRT_F32')
  def test_sqrt_f64(self): self._check(UOp(Ops.SQRT, dtypes.double, (_var('a', dtypes.double),)), 'SQRT_F64')
  def test_trunc_f32(self): self._check(UOp(Ops.TRUNC, dtypes.float, (_var('a'),)), 'TRUNC_F32')
  def test_trunc_f64(self): self._check(UOp(Ops.TRUNC, dtypes.double, (_var('a', dtypes.double),)), 'TRUNC_F64')
  def test_log2_f32(self): self._check(UOp(Ops.LOG2, dtypes.float, (_var('a'),)), 'LOG')
  def test_exp2_f32(self): self._check(UOp(Ops.EXP2, dtypes.float, (_var('a'),)), 'EXP')

  # conversions
  def test_cast_i32_to_f32(self): self._check(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.int),)), 'CVT_F32_I32')
  def test_cast_f32_to_f64(self): self._check(UOp(Ops.CAST, dtypes.double, (_var('a'),)), 'CVT_F64_F32')
  def test_cast_f64_to_f32(self): self._check(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.double),)), 'CVT_F32_F64')
  def test_cast_i32_to_f64(self): self._check(UOp(Ops.CAST, dtypes.double, (_var('a', dtypes.int),)), 'CVT_F64_I32')
  def test_cast_f32_to_f16(self): self._check(UOp(Ops.CAST, dtypes.half, (_var('a'),)), 'CVT_F16_F32')

  # compares are skipped by ISel (VOPC writes VCC, not VGPRs; LLVM handles natively)
  def test_cmplt_skipped(self):
    result = self.pm.rewrite(UOp(Ops.CMPLT, dtypes.bool, (_var('a', dtypes.int), _var('b', dtypes.int))))
    self.assertIsNone(result)

  def test_cmpne_skipped(self):
    result = self.pm.rewrite(UOp(Ops.CMPNE, dtypes.bool, (_var('a'), _var('b'))))
    self.assertIsNone(result)

  # check that unmatched types return None
  def test_no_match(self):
    # there's no direct ADD for bools
    result = self.pm.rewrite(UOp(Ops.ADD, dtypes.bool, (_var('a', dtypes.bool), _var('b', dtypes.bool))))
    self.assertIsNone(result)
|
|
|
|
class TestStructuralPatterns(unittest.TestCase):
  """Multi-node UOp trees should be rewritten into a single Ops.INS with the expected opcode and sources."""

  @classmethod
  def setUpClass(cls):
    # build the matcher once; every test in the class reuses it
    cls.pm = rdna3_isel()

  def _check(self, uop, expected_name_substr, expected_src_count=None):
    """Rewrite `uop` and validate the resulting Ops.INS.

    Requires the opcode name to contain `expected_name_substr` and, when `expected_src_count` is given,
    the INS to have exactly that many sources. Returns the rewritten UOp for further checks.
    """
    result = self.pm.rewrite(uop)
    # name the root op and dtype so a failure identifies which tree did not match
    self.assertIsNotNone(result, f"no match for structural pattern {uop.op} {uop.dtype}")
    self.assertEqual(result.op, Ops.INS)
    self.assertIn(expected_name_substr, result.arg.op.name, f"expected {expected_name_substr} in {result.arg.op.name}")
    if expected_src_count is not None:
      self.assertEqual(len(result.src), expected_src_count, f"expected {expected_src_count} srcs, got {len(result.src)}")
    return result

  def test_not_u32(self):
    x = _var('x', dtypes.uint)
    xor = UOp(Ops.XOR, dtypes.uint, (x, _const(0xFFFFFFFF, dtypes.uint)))
    self._check(xor, 'NOT_B32', 1)

  # u64 NOT is SOP-only (S_NOT_B64), skipped in vgpr_only mode
  def test_not_u64_skipped(self):
    x = _var('x', dtypes.ulong)
    xor = UOp(Ops.XOR, dtypes.ulong, (x, _const(0xFFFFFFFFFFFFFFFF, dtypes.ulong)))
    result = self.pm.rewrite(xor)
    self.assertIsNone(result)

  def test_sub_u32(self):
    # subtraction is spelled x + y*(-1)
    x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
    neg = UOp(Ops.MUL, dtypes.uint, (y, _const(-1, dtypes.uint)))
    sub = UOp(Ops.ADD, dtypes.uint, (x, neg))
    result = self._check(sub, 'SUB_NC_U32', 2)
    # verify source order: x is first, y is second
    self.assertEqual(result.src[0].arg, ('x', 0, 100))
    self.assertEqual(result.src[1].arg, ('y', 0, 100))

  def test_rcp_f32(self):
    a = _var('a')
    rcp = UOp(Ops.RECIPROCAL, dtypes.float, (a,))
    mul_rcp = UOp(Ops.MUL, dtypes.float, (_const(1.0), rcp))
    result = self._check(mul_rcp, 'RCP_F32', 1)
    self.assertEqual(result.src[0].arg, ('a', 0, 100))

  def test_rcp_f64(self):
    a = _var('a', dtypes.double)
    rcp = UOp(Ops.RECIPROCAL, dtypes.double, (a,))
    mul_rcp = UOp(Ops.MUL, dtypes.double, (_const(1.0, dtypes.double), rcp))
    self._check(mul_rcp, 'RCP_F64', 1)

  def test_cvt_i32_f32(self):
    # CAST(i32, TRUNC(f32, x)) -> V_CVT_I32_F32
    a = _var('a')
    trunc = UOp(Ops.TRUNC, dtypes.float, (a,))
    cast = UOp(Ops.CAST, dtypes.int, (trunc,))
    self._check(cast, 'CVT_I32_F32', 1)

  def test_mad_u32(self):
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    mul = UOp(Ops.MUL, dtypes.uint, (x, y))
    mad = UOp(Ops.ADD, dtypes.uint, (mul, z))
    result = self._check(mad, 'MAD_U32_U24', 3)
    # sources must come out in x, y, z order
    self.assertEqual(result.src[0].arg, ('x', 0, 100))
    self.assertEqual(result.src[1].arg, ('y', 0, 100))
    self.assertEqual(result.src[2].arg, ('z', 0, 100))

  def test_add3_u32(self):
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    add1 = UOp(Ops.ADD, dtypes.uint, (x, y))
    add3 = UOp(Ops.ADD, dtypes.uint, (add1, z))
    self._check(add3, 'ADD3_U32', 3)

  def test_xor3_b32(self):
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    xor1 = UOp(Ops.XOR, dtypes.uint, (x, y))
    xor3 = UOp(Ops.XOR, dtypes.uint, (xor1, z))
    self._check(xor3, 'XOR3_B32', 3)

  def test_and_or_b32(self):
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    and_op = UOp(Ops.AND, dtypes.uint, (x, y))
    or_op = UOp(Ops.OR, dtypes.uint, (and_op, z))
    self._check(or_op, 'AND_OR_B32', 3)

  def test_or3_b32(self):
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    or1 = UOp(Ops.OR, dtypes.uint, (x, y))
    or3 = UOp(Ops.OR, dtypes.uint, (or1, z))
    self._check(or3, 'OR3_B32', 3)

  # NAND/NOR were SOP-only, in vgpr_only mode they decompose to V_XOR_B32(AND/OR, mask)
  def test_nand_decomposes(self):
    x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
    and_op = UOp(Ops.AND, dtypes.uint, (x, y))
    nand = UOp(Ops.XOR, dtypes.uint, (and_op, _const(0xFFFFFFFF, dtypes.uint)))
    self._check(nand, 'XOR_B32')

  def test_nor_decomposes(self):
    x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
    or_op = UOp(Ops.OR, dtypes.uint, (x, y))
    nor = UOp(Ops.XOR, dtypes.uint, (or_op, _const(0xFFFFFFFF, dtypes.uint)))
    self._check(nor, 'XOR_B32')

  def test_xnor_b32(self):
    x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
    xor_op = UOp(Ops.XOR, dtypes.uint, (x, y))
    xnor = UOp(Ops.XOR, dtypes.uint, (xor_op, _const(0xFFFFFFFF, dtypes.uint)))
    self._check(xnor, 'XNOR_B32', 2)

  def test_min_u32(self):
    # min is spelled WHERE(x < y, x, y)
    x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
    cmp = UOp(Ops.CMPLT, dtypes.bool, (x, y))
    where = UOp(Ops.WHERE, dtypes.uint, (cmp, x, y))
    self._check(where, 'MIN_U32', 2)
|
|
|
|
class TestInstProperties(unittest.TestCase):
  """Check dtype, sources and tag on the Ops.INS node that isel returns for a float ADD."""

  @classmethod
  def setUpClass(cls):
    matcher = rdna3_isel()
    cls.pm = matcher

  def test_ins_has_dtype(self):
    # the selected instruction keeps the dtype of the UOp it replaced
    lhs, rhs = _var('a'), _var('b')
    add = UOp(Ops.ADD, dtypes.float, (lhs, rhs))
    selected = self.pm.rewrite(add)
    self.assertEqual(selected.dtype, dtypes.float)

  def test_ins_preserves_sources(self):
    # the operands are carried over unchanged and in the same order
    a, b = _var('a'), _var('b')
    operands = (a, b)
    selected = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, operands))
    self.assertEqual(selected.src, (a, b))

  def test_ins_tag_default_none(self):
    lhs, rhs = _var('a'), _var('b')
    add = UOp(Ops.ADD, dtypes.float, (lhs, rhs))
    selected = self.pm.rewrite(add)
    # tag should not be set (defaults to None or empty)
    self.assertIsNone(selected.tag)
|
|
|
|
class TestTableCoverage(unittest.TestCase):
  """Sanity checks on what ends up in the direct and structural tables after the matcher is built."""

  @classmethod
  def setUpClass(cls):
    rdna3_isel()  # populate tables

  @staticmethod
  def _direct_has(wanted):
    """True when some key of _DIRECT_TABLE has `wanted` as its first element."""
    return any(entry_op == wanted for (entry_op, _, _) in _DIRECT_TABLE)

  @staticmethod
  def _structural_has(fragment):
    """True when some instruction in _STRUCTURAL_TABLE has `fragment` in its opcode name."""
    return any(fragment in entry.op.name for entry in _STRUCTURAL_TABLE.values())

  def test_direct_table_has_add(self):
    present = self._direct_has(Ops.ADD)
    self.assertTrue(present)

  def test_direct_table_has_cast(self):
    present = self._direct_has(Ops.CAST)
    self.assertTrue(present)

  def test_direct_table_skips_cmplt(self):
    # compares are in the table but skipped at runtime (bool output)
    present = self._direct_has(Ops.CMPLT)
    self.assertTrue(present)  # entries exist but callbacks skip them

  def test_structural_table_has_not(self):
    present = self._structural_has('NOT')
    self.assertTrue(present)

  def test_structural_table_has_rcp(self):
    present = self._structural_has('RCP')
    self.assertTrue(present)

  def test_structural_table_has_sub(self):
    present = self._structural_has('SUB')
    self.assertTrue(present)

  def test_direct_count(self):
    n_direct = len(_DIRECT_TABLE)
    self.assertGreaterEqual(n_direct, 25, "expected at least 25 direct patterns")

  def test_structural_count(self):
    n_structural = len(_STRUCTURAL_TABLE)
    self.assertGreaterEqual(n_structural, 15, "expected at least 15 structural patterns")
|
|
|
|
class TestEmulatorValidation(unittest.TestCase):
  """Validate isel-produced Inst objects execute correctly in the emulator.

  For each pattern, we:
  1. Run the UOp through isel to get Ops.INS with arg=Inst
  2. Copy the Inst and assign concrete registers
  3. Set up operand values via MOV instructions
  4. Execute through the emulator
  5. Verify the output matches expected computation
  """

  @classmethod
  def setUpClass(cls):
    # build the matcher once; every test in the class reuses it
    cls.pm = rdna3_isel()

  def _run(self, instructions, n_lanes=1):
    """Execute `instructions` in the emulator with `n_lanes` lanes and return whatever run_program_emu returns."""
    from extra.assembly.amd.test.hw.helpers import run_program_emu
    return run_program_emu(instructions, n_lanes)

  def _get_inst(self, uop):
    """Get isel result, return (Inst, src_count)."""
    result = self.pm.rewrite(uop)
    # unittest assertions instead of a bare assert, so the check also runs under `python -O`
    self.assertIsNotNone(result, f"isel failed for {uop.op} {uop.dtype}")
    self.assertEqual(result.op, Ops.INS, f"isel failed for {uop.op} {uop.dtype}")
    return result.arg, len(result.src)

  def _copy_inst(self, inst):
    """Return a shallow copy of `inst`, so register fields are set on the copy and not on the isel result."""
    import copy
    return copy.copy(inst)

  # ── direct ALU: float arithmetic ──

  def test_emu_add_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 1.5), v_mov_b32_e32(v[1], 2.25), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.75, places=5)

  def test_emu_mul_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.MUL, dtypes.float, (_var('a'), _var('b'))))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 3.0), v_mov_b32_e32(v[1], 4.0), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 12.0, places=5)

  # ── direct ALU: integer arithmetic ──

  def test_emu_add_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.ADD, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 10), v_mov_b32_e32(v[1], 20), ci])
    self.assertEqual(st.vgpr[0][2], 30)

  # ── direct ALU: bitwise ──

  def test_emu_and_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.AND, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
    self.assertEqual(st.vgpr[0][2], 0x0F00)

  def test_emu_or_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.OR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
    self.assertEqual(st.vgpr[0][2], 0xFFF0)

  def test_emu_xor_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.XOR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
    self.assertEqual(st.vgpr[0][2], 0xF0F0)

  # ── direct ALU: shifts ──

  def test_emu_shl_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.SHL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
    ci = self._copy_inst(inst)
    # LSHLREV: vdst = vsrc1 << src0 (reversed operands!)
    ci.src0 = v[1]; ci.vsrc1 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 1), v_mov_b32_e32(v[1], 4), ci])
    self.assertEqual(st.vgpr[0][2], 16)  # 1 << 4 = 16

  def test_emu_shr_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.SHR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
    ci = self._copy_inst(inst)
    # LSHRREV: vdst = vsrc1 >> src0 (reversed operands!)
    ci.src0 = v[1]; ci.vsrc1 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 16), v_mov_b32_e32(v[1], 4), ci])
    self.assertEqual(st.vgpr[0][2], 1)  # 16 >> 4 = 1

  # ── direct ALU: unary float ──

  def test_emu_sqrt_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.SQRT, dtypes.float, (_var('a'),)))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 4.0), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 2.0, places=4)

  def test_emu_trunc_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.TRUNC, dtypes.float, (_var('a'),)))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 3.7), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, places=5)

  def test_emu_exp2_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.EXP2, dtypes.float, (_var('a'),)))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 3.0), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 8.0, delta=0.01)

  def test_emu_log2_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.LOG2, dtypes.float, (_var('a'),)))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 8.0), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, delta=0.01)

  # ── direct ALU: conversions ──

  def test_emu_cast_i32_to_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.int),)))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 42), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 42.0, places=5)

  def test_emu_cast_f32_to_f16(self):
    from extra.assembly.amd.test.hw.helpers import f16
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    inst, _ = self._get_inst(UOp(Ops.CAST, dtypes.half, (_var('a'),)))
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 1.5), ci])
    # f16 result is in lower 16 bits of v[2]
    self.assertAlmostEqual(f16(st.vgpr[0][2]), 1.5, places=2)

  # compares are skipped by ISel (VOPC writes VCC; LLVM handles natively)

  # ── structural: NOT ──

  def test_emu_not_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    x = _var('x', dtypes.uint)
    xor = UOp(Ops.XOR, dtypes.uint, (x, _const(0xFFFFFFFF, dtypes.uint)))
    inst, _ = self._get_inst(xor)
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 0x0000FF00), ci])
    self.assertEqual(st.vgpr[0][2], 0xFFFF00FF)

  # ── structural: SUB ──

  def test_emu_sub_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
    neg = UOp(Ops.MUL, dtypes.uint, (y, _const(-1, dtypes.uint)))
    sub = UOp(Ops.ADD, dtypes.uint, (x, neg))
    inst, _ = self._get_inst(sub)
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 30), v_mov_b32_e32(v[1], 12), ci])
    self.assertEqual(st.vgpr[0][2], 18)

  # ── structural: RCP ──

  def test_emu_rcp_f32(self):
    from extra.assembly.amd.test.hw.helpers import i2f
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    a = _var('a')
    rcp = UOp(Ops.RECIPROCAL, dtypes.float, (a,))
    mul_rcp = UOp(Ops.MUL, dtypes.float, (_const(1.0), rcp))
    inst, _ = self._get_inst(mul_rcp)
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.vdst = v[2]
    st = self._run([v_mov_b32_e32(v[0], 4.0), ci])
    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 0.25, places=4)

  # ── structural: MAD ──

  def test_emu_mad_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    mul = UOp(Ops.MUL, dtypes.uint, (x, y))
    mad = UOp(Ops.ADD, dtypes.uint, (mul, z))
    inst, _ = self._get_inst(mad)
    ci = self._copy_inst(inst)
    # VOP3 format: src0, src1, src2, vdst
    ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
    st = self._run([v_mov_b32_e32(v[0], 3), v_mov_b32_e32(v[1], 4), v_mov_b32_e32(v[2], 5), ci])
    self.assertEqual(st.vgpr[0][3], 17)  # 3*4 + 5 = 17

  # ── structural: ADD3 ──

  def test_emu_add3_u32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    add1 = UOp(Ops.ADD, dtypes.uint, (x, y))
    add3 = UOp(Ops.ADD, dtypes.uint, (add1, z))
    inst, _ = self._get_inst(add3)
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
    st = self._run([v_mov_b32_e32(v[0], 10), v_mov_b32_e32(v[1], 20), v_mov_b32_e32(v[2], 30), ci])
    self.assertEqual(st.vgpr[0][3], 60)  # 10+20+30

  # ── structural: XOR3 ──

  def test_emu_xor3_b32(self):
    from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
    x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
    xor1 = UOp(Ops.XOR, dtypes.uint, (x, y))
    xor3 = UOp(Ops.XOR, dtypes.uint, (xor1, z))
    inst, _ = self._get_inst(xor3)
    ci = self._copy_inst(inst)
    ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
    st = self._run([v_mov_b32_e32(v[0], 0xFF), v_mov_b32_e32(v[1], 0x0F), v_mov_b32_e32(v[2], 0x33), ci])
    self.assertEqual(st.vgpr[0][3], 0xFF ^ 0x0F ^ 0x33)
|
|
|
|
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|