mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cd10a2e3c |
9 changed files with 1439 additions and 6 deletions
674
test/amd/test_isel.py
Normal file
674
test/amd/test_isel.py
Normal file
|
|
@ -0,0 +1,674 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for the pcode-based instruction selector (isel.py)."""
|
||||
import unittest
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
from extra.assembly.amd.isel import (rdna3_isel, make_inst, normalize, _count_nodes, _is_direct_alu,
|
||||
_parse_pcode_patterns, _pattern_key, _runtime_key, uop_to_upat,
|
||||
_SENTINEL, _SENTINEL_SET, _DIRECT_TABLE, _STRUCTURAL_TABLE,
|
||||
_ALU_ENUM_TYPES, build_isel_patterns)
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
|
||||
from extra.assembly.amd.autogen.rdna3.enum import VOP2Op, VOP1Op, VOP3Op, SOP2Op, VOPCOp, VOP3SDOp
|
||||
|
||||
# helpers
|
||||
def _var(name, dtype=dtypes.float): return UOp(Ops.DEFINE_VAR, dtype, arg=(name, 0, 100))
|
||||
def _const(val, dtype=dtypes.float): return UOp(Ops.CONST, dtype, arg=val)
|
||||
|
||||
class TestMakeInst(unittest.TestCase):
|
||||
def test_vop2(self):
|
||||
inst = make_inst(VOP2Op.V_ADD_F32_E32)
|
||||
assert inst.op == VOP2Op.V_ADD_F32_E32
|
||||
|
||||
def test_vop1(self):
|
||||
inst = make_inst(VOP1Op.V_SQRT_F32_E32)
|
||||
assert inst.op == VOP1Op.V_SQRT_F32_E32
|
||||
|
||||
def test_vop3(self):
|
||||
inst = make_inst(VOP3Op.V_ADD_F64)
|
||||
assert inst.op == VOP3Op.V_ADD_F64
|
||||
|
||||
def test_vop3_sdst(self):
|
||||
# VOP3SDOp opcodes need the _SDST variant class
|
||||
inst = make_inst(VOP3SDOp.V_ADD_CO_CI_U32)
|
||||
assert inst.op == VOP3SDOp.V_ADD_CO_CI_U32
|
||||
|
||||
def test_vopc(self):
|
||||
inst = make_inst(VOPCOp.V_CMP_LT_F32_E32)
|
||||
assert inst.op == VOPCOp.V_CMP_LT_F32_E32
|
||||
|
||||
def test_sop2(self):
|
||||
inst = make_inst(SOP2Op.S_ADD_I32)
|
||||
assert inst.op == SOP2Op.S_ADD_I32
|
||||
|
||||
def test_invalid_raises(self):
|
||||
with self.assertRaises(RuntimeError): make_inst("not_an_opcode")
|
||||
|
||||
class TestNormalize(unittest.TestCase):
|
||||
def test_bitcast_sentinel(self):
|
||||
s0 = _SENTINEL['S0']
|
||||
bc = UOp(Ops.BITCAST, dtypes.float, (s0,))
|
||||
norm = normalize(bc)
|
||||
assert norm.op == Ops.DEFINE_VAR
|
||||
assert norm.dtype == dtypes.float
|
||||
|
||||
def test_cast_sentinel(self):
|
||||
s0 = _SENTINEL['S0']
|
||||
cast = UOp(Ops.CAST, dtypes.int, (s0,))
|
||||
norm = normalize(cast)
|
||||
assert norm.op == Ops.DEFINE_VAR
|
||||
assert norm.dtype == dtypes.int
|
||||
|
||||
def test_identity_bitcast(self):
|
||||
x = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
bc = UOp(Ops.BITCAST, dtypes.float, (x,))
|
||||
norm = normalize(bc)
|
||||
assert norm.op == Ops.CONST
|
||||
assert norm.arg == 1.0
|
||||
|
||||
def test_shift_mask_31(self):
|
||||
s0 = _SENTINEL['S0']
|
||||
c31 = UOp(Ops.CONST, dtypes.uint, arg=31)
|
||||
masked = UOp(Ops.AND, dtypes.uint, (s0, c31))
|
||||
norm = normalize(masked)
|
||||
assert norm.op == Ops.DEFINE_VAR
|
||||
|
||||
def test_shift_mask_63(self):
|
||||
s0 = _SENTINEL['S0']
|
||||
c63 = UOp(Ops.CONST, dtypes.uint, arg=63)
|
||||
masked = UOp(Ops.AND, dtypes.uint, (s0, c63))
|
||||
norm = normalize(masked)
|
||||
assert norm.op == Ops.DEFINE_VAR
|
||||
|
||||
def test_non_sentinel_bitcast_preserved(self):
|
||||
x = _var('x', dtypes.uint)
|
||||
bc = UOp(Ops.BITCAST, dtypes.float, (x,))
|
||||
norm = normalize(bc)
|
||||
assert norm.op == Ops.BITCAST # not a sentinel, so preserved
|
||||
|
||||
def test_recursive(self):
|
||||
s0 = _SENTINEL['S0']
|
||||
s1 = _SENTINEL['S1']
|
||||
bc0 = UOp(Ops.BITCAST, dtypes.float, (s0,))
|
||||
bc1 = UOp(Ops.BITCAST, dtypes.float, (s1,))
|
||||
add = UOp(Ops.ADD, dtypes.float, (bc0, bc1))
|
||||
norm = normalize(add)
|
||||
assert norm.op == Ops.ADD
|
||||
assert all(s.dtype == dtypes.float for s in norm.src)
|
||||
assert all(s.op == Ops.DEFINE_VAR for s in norm.src)
|
||||
|
||||
class TestCountNodes(unittest.TestCase):
|
||||
def test_leaf(self):
|
||||
assert _count_nodes(_var('x')) == 1
|
||||
|
||||
def test_binary(self):
|
||||
x, y = _var('x'), _var('y')
|
||||
add = UOp(Ops.ADD, dtypes.float, (x, y))
|
||||
assert _count_nodes(add) == 3
|
||||
|
||||
def test_dag_sharing(self):
|
||||
x = _var('x')
|
||||
add = UOp(Ops.ADD, dtypes.float, (x, x))
|
||||
assert _count_nodes(add) == 2 # x counted once
|
||||
|
||||
class TestIsDirectAlu(unittest.TestCase):
|
||||
def test_add_sentinels(self):
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
|
||||
s1 = _SENTINEL['S1'].replace(dtype=dtypes.float)
|
||||
add = UOp(Ops.ADD, dtypes.float, (s0, s1))
|
||||
assert _is_direct_alu(add)
|
||||
|
||||
def test_cast_sentinel(self):
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.int)
|
||||
cast = UOp(Ops.CAST, dtypes.float, (s0,))
|
||||
assert _is_direct_alu(cast)
|
||||
|
||||
def test_nested_not_direct(self):
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.uint)
|
||||
c = _const(0xFFFFFFFF, dtypes.uint)
|
||||
xor = UOp(Ops.XOR, dtypes.uint, (s0, c))
|
||||
assert not _is_direct_alu(xor) # const child is not DEFINE_VAR
|
||||
|
||||
class TestPatternKey(unittest.TestCase):
|
||||
def test_sentinel_var(self):
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
|
||||
key = _pattern_key(s0)
|
||||
assert key == 'var(S0,dtypes.float)'
|
||||
|
||||
def test_const(self):
|
||||
c = _const(42, dtypes.uint)
|
||||
key = _pattern_key(c)
|
||||
assert key == 'const(42,dtypes.uint)'
|
||||
|
||||
def test_binary_op(self):
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
|
||||
s1 = _SENTINEL['S1'].replace(dtype=dtypes.float)
|
||||
add = UOp(Ops.ADD, dtypes.float, (s0, s1))
|
||||
key = _pattern_key(add)
|
||||
assert key == 'Ops.ADD(dtypes.float,var(S0,dtypes.float),var(S1,dtypes.float))'
|
||||
|
||||
class TestRuntimeKey(unittest.TestCase):
|
||||
def test_matches_pattern_key(self):
|
||||
# runtime key on a matched UOp should equal pattern key on the pcode template
|
||||
x = _var('x', dtypes.uint)
|
||||
c = _const(0xFFFFFFFF, dtypes.uint)
|
||||
xor = UOp(Ops.XOR, dtypes.uint, (x, c))
|
||||
rkey = _runtime_key(xor)
|
||||
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.uint)
|
||||
xor_template = UOp(Ops.XOR, dtypes.uint, (s0, c))
|
||||
pkey = _pattern_key(xor_template)
|
||||
assert rkey == pkey
|
||||
|
||||
class TestUopToUpat(unittest.TestCase):
|
||||
def test_sentinel_becomes_var(self):
|
||||
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
|
||||
pat = uop_to_upat(s0)
|
||||
assert pat.name == 'S0'
|
||||
assert pat.dtype == (dtypes.float,)
|
||||
|
||||
def test_const_preserved(self):
|
||||
c = _const(42, dtypes.uint)
|
||||
pat = uop_to_upat(c)
|
||||
assert pat.op == (Ops.CONST,)
|
||||
assert pat.arg == 42
|
||||
|
||||
class TestBuildPerformance(unittest.TestCase):
|
||||
def test_builds_under_2_seconds(self):
|
||||
import time
|
||||
t0 = time.time()
|
||||
build_isel_patterns(PCODE)
|
||||
elapsed = time.time() - t0
|
||||
assert elapsed < 2.0, f"build took {elapsed:.2f}s, expected <2s"
|
||||
|
||||
def test_alu_filter(self):
|
||||
# verify only ALU enum types are parsed
|
||||
for opcode in PCODE:
|
||||
if type(opcode).__name__ not in _ALU_ENUM_TYPES: continue
|
||||
# these should parse without hanging
|
||||
|
||||
class TestDirectPatterns(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.pm = rdna3_isel()
|
||||
|
||||
def _check(self, uop, expected_name_substr):
|
||||
result = self.pm.rewrite(uop)
|
||||
self.assertIsNotNone(result, f"no match for {uop.op} {uop.dtype}")
|
||||
self.assertEqual(result.op, Ops.INS)
|
||||
self.assertIn(expected_name_substr, result.arg.op.name, f"expected {expected_name_substr} in {result.arg.op.name}")
|
||||
return result
|
||||
|
||||
# arithmetic
|
||||
def test_add_f32(self): self._check(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))), 'V_ADD_F32')
|
||||
def test_add_f64(self): self._check(UOp(Ops.ADD, dtypes.double, (_var('a', dtypes.double), _var('b', dtypes.double))), 'V_ADD_F64')
|
||||
def test_add_i32(self): self._check(UOp(Ops.ADD, dtypes.int, (_var('a', dtypes.int), _var('b', dtypes.int))), 'ADD_NC_I32')
|
||||
def test_add_u32(self): self._check(UOp(Ops.ADD, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'ADD_NC_U32')
|
||||
def test_mul_f32(self): self._check(UOp(Ops.MUL, dtypes.float, (_var('a'), _var('b'))), 'V_MUL_F32')
|
||||
def test_mul_f64(self): self._check(UOp(Ops.MUL, dtypes.double, (_var('a', dtypes.double), _var('b', dtypes.double))), 'V_MUL_F64')
|
||||
def test_mul_i32(self): self._check(UOp(Ops.MUL, dtypes.int, (_var('a', dtypes.int), _var('b', dtypes.int))), 'MUL_I32')
|
||||
def test_mul_u32(self): self._check(UOp(Ops.MUL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'MUL_U32')
|
||||
|
||||
# bitwise
|
||||
def test_and_u32(self): self._check(UOp(Ops.AND, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'AND_B32')
|
||||
def test_or_u32(self): self._check(UOp(Ops.OR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'OR_B32')
|
||||
def test_xor_u32(self): self._check(UOp(Ops.XOR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'XOR_B32')
|
||||
# u64 bitwise ops are SOP-only, skipped in vgpr_only mode
|
||||
def test_and_u64_skipped(self):
|
||||
result = self.pm.rewrite(UOp(Ops.AND, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
|
||||
self.assertIsNone(result)
|
||||
def test_or_u64_skipped(self):
|
||||
result = self.pm.rewrite(UOp(Ops.OR, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
|
||||
self.assertIsNone(result)
|
||||
def test_xor_u64_skipped(self):
|
||||
result = self.pm.rewrite(UOp(Ops.XOR, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
|
||||
self.assertIsNone(result)
|
||||
|
||||
# shifts
|
||||
def test_shl_u32(self): self._check(UOp(Ops.SHL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'LSH')
|
||||
def test_shr_u32(self): self._check(UOp(Ops.SHR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'LSH')
|
||||
|
||||
# unary float
|
||||
def test_sqrt_f32(self): self._check(UOp(Ops.SQRT, dtypes.float, (_var('a'),)), 'SQRT_F32')
|
||||
def test_sqrt_f64(self): self._check(UOp(Ops.SQRT, dtypes.double, (_var('a', dtypes.double),)), 'SQRT_F64')
|
||||
def test_trunc_f32(self): self._check(UOp(Ops.TRUNC, dtypes.float, (_var('a'),)), 'TRUNC_F32')
|
||||
def test_trunc_f64(self): self._check(UOp(Ops.TRUNC, dtypes.double, (_var('a', dtypes.double),)), 'TRUNC_F64')
|
||||
def test_log2_f32(self): self._check(UOp(Ops.LOG2, dtypes.float, (_var('a'),)), 'LOG')
|
||||
def test_exp2_f32(self): self._check(UOp(Ops.EXP2, dtypes.float, (_var('a'),)), 'EXP')
|
||||
|
||||
# conversions
|
||||
def test_cast_i32_to_f32(self): self._check(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.int),)), 'CVT_F32_I32')
|
||||
def test_cast_f32_to_f64(self): self._check(UOp(Ops.CAST, dtypes.double, (_var('a'),)), 'CVT_F64_F32')
|
||||
def test_cast_f64_to_f32(self): self._check(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.double),)), 'CVT_F32_F64')
|
||||
def test_cast_i32_to_f64(self): self._check(UOp(Ops.CAST, dtypes.double, (_var('a', dtypes.int),)), 'CVT_F64_I32')
|
||||
def test_cast_f32_to_f16(self): self._check(UOp(Ops.CAST, dtypes.half, (_var('a'),)), 'CVT_F16_F32')
|
||||
|
||||
# compares are skipped by ISel (VOPC writes VCC, not VGPRs; LLVM handles natively)
|
||||
def test_cmplt_skipped(self):
|
||||
result = self.pm.rewrite(UOp(Ops.CMPLT, dtypes.bool, (_var('a', dtypes.int), _var('b', dtypes.int))))
|
||||
self.assertIsNone(result)
|
||||
def test_cmpne_skipped(self):
|
||||
result = self.pm.rewrite(UOp(Ops.CMPNE, dtypes.bool, (_var('a'), _var('b'))))
|
||||
self.assertIsNone(result)
|
||||
|
||||
# check that unmatched types return None
|
||||
def test_no_match(self):
|
||||
# there's no direct ADD for bools
|
||||
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.bool, (_var('a', dtypes.bool), _var('b', dtypes.bool))))
|
||||
self.assertIsNone(result)
|
||||
|
||||
class TestStructuralPatterns(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.pm = rdna3_isel()
|
||||
|
||||
def _check(self, uop, expected_name_substr, expected_src_count=None):
|
||||
result = self.pm.rewrite(uop)
|
||||
self.assertIsNotNone(result, f"no match for structural pattern")
|
||||
self.assertEqual(result.op, Ops.INS)
|
||||
self.assertIn(expected_name_substr, result.arg.op.name, f"expected {expected_name_substr} in {result.arg.op.name}")
|
||||
if expected_src_count is not None:
|
||||
self.assertEqual(len(result.src), expected_src_count, f"expected {expected_src_count} srcs, got {len(result.src)}")
|
||||
return result
|
||||
|
||||
def test_not_u32(self):
|
||||
x = _var('x', dtypes.uint)
|
||||
xor = UOp(Ops.XOR, dtypes.uint, (x, _const(0xFFFFFFFF, dtypes.uint)))
|
||||
self._check(xor, 'NOT_B32', 1)
|
||||
|
||||
# u64 NOT is SOP-only (S_NOT_B64), skipped in vgpr_only mode
|
||||
def test_not_u64_skipped(self):
|
||||
x = _var('x', dtypes.ulong)
|
||||
xor = UOp(Ops.XOR, dtypes.ulong, (x, _const(0xFFFFFFFFFFFFFFFF, dtypes.ulong)))
|
||||
result = self.pm.rewrite(xor)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_sub_u32(self):
|
||||
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
|
||||
neg = UOp(Ops.MUL, dtypes.uint, (y, _const(-1, dtypes.uint)))
|
||||
sub = UOp(Ops.ADD, dtypes.uint, (x, neg))
|
||||
result = self._check(sub, 'SUB_NC_U32', 2)
|
||||
# verify source order: x is first, y is second
|
||||
self.assertEqual(result.src[0].arg, ('x', 0, 100))
|
||||
self.assertEqual(result.src[1].arg, ('y', 0, 100))
|
||||
|
||||
def test_rcp_f32(self):
|
||||
a = _var('a')
|
||||
rcp = UOp(Ops.RECIPROCAL, dtypes.float, (a,))
|
||||
mul_rcp = UOp(Ops.MUL, dtypes.float, (_const(1.0), rcp))
|
||||
result = self._check(mul_rcp, 'RCP_F32', 1)
|
||||
self.assertEqual(result.src[0].arg, ('a', 0, 100))
|
||||
|
||||
def test_rcp_f64(self):
|
||||
a = _var('a', dtypes.double)
|
||||
rcp = UOp(Ops.RECIPROCAL, dtypes.double, (a,))
|
||||
mul_rcp = UOp(Ops.MUL, dtypes.double, (_const(1.0, dtypes.double), rcp))
|
||||
self._check(mul_rcp, 'RCP_F64', 1)
|
||||
|
||||
def test_cvt_i32_f32(self):
|
||||
# CAST(i32, TRUNC(f32, x)) -> V_CVT_I32_F32
|
||||
a = _var('a')
|
||||
trunc = UOp(Ops.TRUNC, dtypes.float, (a,))
|
||||
cast = UOp(Ops.CAST, dtypes.int, (trunc,))
|
||||
self._check(cast, 'CVT_I32_F32', 1)
|
||||
|
||||
def test_mad_u32(self):
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
mul = UOp(Ops.MUL, dtypes.uint, (x, y))
|
||||
mad = UOp(Ops.ADD, dtypes.uint, (mul, z))
|
||||
result = self._check(mad, 'MAD_U32_U24', 3)
|
||||
self.assertEqual(result.src[0].arg, ('x', 0, 100))
|
||||
self.assertEqual(result.src[1].arg, ('y', 0, 100))
|
||||
self.assertEqual(result.src[2].arg, ('z', 0, 100))
|
||||
|
||||
def test_add3_u32(self):
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
add1 = UOp(Ops.ADD, dtypes.uint, (x, y))
|
||||
add3 = UOp(Ops.ADD, dtypes.uint, (add1, z))
|
||||
result = self._check(add3, 'ADD3_U32', 3)
|
||||
|
||||
def test_xor3_b32(self):
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
xor1 = UOp(Ops.XOR, dtypes.uint, (x, y))
|
||||
xor3 = UOp(Ops.XOR, dtypes.uint, (xor1, z))
|
||||
self._check(xor3, 'XOR3_B32', 3)
|
||||
|
||||
def test_and_or_b32(self):
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
and_op = UOp(Ops.AND, dtypes.uint, (x, y))
|
||||
or_op = UOp(Ops.OR, dtypes.uint, (and_op, z))
|
||||
self._check(or_op, 'AND_OR_B32', 3)
|
||||
|
||||
def test_or3_b32(self):
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
or1 = UOp(Ops.OR, dtypes.uint, (x, y))
|
||||
or3 = UOp(Ops.OR, dtypes.uint, (or1, z))
|
||||
self._check(or3, 'OR3_B32', 3)
|
||||
|
||||
# NAND/NOR were SOP-only, in vgpr_only mode they decompose to V_XOR_B32(AND/OR, mask)
|
||||
def test_nand_decomposes(self):
|
||||
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
|
||||
and_op = UOp(Ops.AND, dtypes.uint, (x, y))
|
||||
nand = UOp(Ops.XOR, dtypes.uint, (and_op, _const(0xFFFFFFFF, dtypes.uint)))
|
||||
self._check(nand, 'XOR_B32')
|
||||
|
||||
def test_nor_decomposes(self):
|
||||
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
|
||||
or_op = UOp(Ops.OR, dtypes.uint, (x, y))
|
||||
nor = UOp(Ops.XOR, dtypes.uint, (or_op, _const(0xFFFFFFFF, dtypes.uint)))
|
||||
self._check(nor, 'XOR_B32')
|
||||
|
||||
def test_xnor_b32(self):
|
||||
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
|
||||
xor_op = UOp(Ops.XOR, dtypes.uint, (x, y))
|
||||
xnor = UOp(Ops.XOR, dtypes.uint, (xor_op, _const(0xFFFFFFFF, dtypes.uint)))
|
||||
self._check(xnor, 'XNOR_B32', 2)
|
||||
|
||||
def test_min_u32(self):
|
||||
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
|
||||
cmp = UOp(Ops.CMPLT, dtypes.bool, (x, y))
|
||||
where = UOp(Ops.WHERE, dtypes.uint, (cmp, x, y))
|
||||
self._check(where, 'MIN_U32', 2)
|
||||
|
||||
class TestInstProperties(unittest.TestCase):
|
||||
"""Verify that Inst objects produced by isel have correct properties."""
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.pm = rdna3_isel()
|
||||
|
||||
def test_ins_has_dtype(self):
|
||||
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
|
||||
self.assertEqual(result.dtype, dtypes.float)
|
||||
|
||||
def test_ins_preserves_sources(self):
|
||||
a, b = _var('a'), _var('b')
|
||||
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, (a, b)))
|
||||
self.assertEqual(result.src, (a, b))
|
||||
|
||||
def test_ins_tag_default_none(self):
|
||||
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
|
||||
# tag should not be set (defaults to None or empty)
|
||||
self.assertIsNone(result.tag)
|
||||
|
||||
class TestTableCoverage(unittest.TestCase):
|
||||
"""Verify that the tables have expected coverage."""
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
rdna3_isel() # populate tables
|
||||
|
||||
def test_direct_table_has_add(self):
|
||||
found = any(op == Ops.ADD for (op, _, _) in _DIRECT_TABLE)
|
||||
self.assertTrue(found)
|
||||
|
||||
def test_direct_table_has_cast(self):
|
||||
found = any(op == Ops.CAST for (op, _, _) in _DIRECT_TABLE)
|
||||
self.assertTrue(found)
|
||||
|
||||
def test_direct_table_skips_cmplt(self):
|
||||
# compares are in the table but skipped at runtime (bool output)
|
||||
found = any(op == Ops.CMPLT for (op, _, _) in _DIRECT_TABLE)
|
||||
self.assertTrue(found) # entries exist but callbacks skip them
|
||||
|
||||
def test_structural_table_has_not(self):
|
||||
found = any('NOT' in inst.op.name for inst in _STRUCTURAL_TABLE.values())
|
||||
self.assertTrue(found)
|
||||
|
||||
def test_structural_table_has_rcp(self):
|
||||
found = any('RCP' in inst.op.name for inst in _STRUCTURAL_TABLE.values())
|
||||
self.assertTrue(found)
|
||||
|
||||
def test_structural_table_has_sub(self):
|
||||
found = any('SUB' in inst.op.name for inst in _STRUCTURAL_TABLE.values())
|
||||
self.assertTrue(found)
|
||||
|
||||
def test_direct_count(self):
|
||||
self.assertGreaterEqual(len(_DIRECT_TABLE), 25, "expected at least 25 direct patterns")
|
||||
|
||||
def test_structural_count(self):
|
||||
self.assertGreaterEqual(len(_STRUCTURAL_TABLE), 15, "expected at least 15 structural patterns")
|
||||
|
||||
class TestEmulatorValidation(unittest.TestCase):
|
||||
"""Validate isel-produced Inst objects execute correctly in the emulator.
|
||||
|
||||
For each pattern, we:
|
||||
1. Run the UOp through isel to get Ops.INS with arg=Inst
|
||||
2. Copy the Inst and assign concrete registers
|
||||
3. Set up operand values via MOV instructions
|
||||
4. Execute through the emulator
|
||||
5. Verify the output matches expected computation
|
||||
"""
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.pm = rdna3_isel()
|
||||
|
||||
def _run(self, instructions, n_lanes=1):
|
||||
from extra.assembly.amd.test.hw.helpers import run_program_emu
|
||||
return run_program_emu(instructions, n_lanes)
|
||||
|
||||
def _get_inst(self, uop):
|
||||
"""Get isel result, return (Inst, src_count)."""
|
||||
result = self.pm.rewrite(uop)
|
||||
assert result is not None and result.op == Ops.INS, f"isel failed for {uop.op} {uop.dtype}"
|
||||
return result.arg, len(result.src)
|
||||
|
||||
def _copy_inst(self, inst):
|
||||
import copy
|
||||
return copy.copy(inst)
|
||||
|
||||
# ── direct ALU: float arithmetic ──
|
||||
|
||||
def test_emu_add_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f, f2i
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 1.5), v_mov_b32_e32(v[1], 2.25), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.75, places=5)
|
||||
|
||||
def test_emu_mul_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.MUL, dtypes.float, (_var('a'), _var('b'))))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 3.0), v_mov_b32_e32(v[1], 4.0), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 12.0, places=5)
|
||||
|
||||
# ── direct ALU: integer arithmetic ──
|
||||
|
||||
def test_emu_add_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.ADD, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 10), v_mov_b32_e32(v[1], 20), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 30)
|
||||
|
||||
# ── direct ALU: bitwise ──
|
||||
|
||||
def test_emu_and_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.AND, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 0x0F00)
|
||||
|
||||
def test_emu_or_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.OR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 0xFFF0)
|
||||
|
||||
def test_emu_xor_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.XOR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 0xF0F0)
|
||||
|
||||
# ── direct ALU: shifts ──
|
||||
|
||||
def test_emu_shl_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.SHL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
|
||||
ci = self._copy_inst(inst)
|
||||
# LSHLREV: vdst = vsrc1 << src0 (reversed operands!)
|
||||
ci.src0 = v[1]; ci.vsrc1 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 1), v_mov_b32_e32(v[1], 4), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 16) # 1 << 4 = 16
|
||||
|
||||
def test_emu_shr_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.SHR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
|
||||
ci = self._copy_inst(inst)
|
||||
# LSHRREV: vdst = vsrc1 >> src0 (reversed operands!)
|
||||
ci.src0 = v[1]; ci.vsrc1 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 16), v_mov_b32_e32(v[1], 4), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 1) # 16 >> 4 = 1
|
||||
|
||||
# ── direct ALU: unary float ──
|
||||
|
||||
def test_emu_sqrt_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.SQRT, dtypes.float, (_var('a'),)))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 4.0), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 2.0, places=4)
|
||||
|
||||
def test_emu_trunc_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.TRUNC, dtypes.float, (_var('a'),)))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 3.7), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, places=5)
|
||||
|
||||
def test_emu_exp2_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.EXP2, dtypes.float, (_var('a'),)))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 3.0), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 8.0, delta=0.01)
|
||||
|
||||
def test_emu_log2_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.LOG2, dtypes.float, (_var('a'),)))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 8.0), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, delta=0.01)
|
||||
|
||||
# ── direct ALU: conversions ──
|
||||
|
||||
def test_emu_cast_i32_to_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.int),)))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 42), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 42.0, places=5)
|
||||
|
||||
def test_emu_cast_f32_to_f16(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f, f16
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
inst, _ = self._get_inst(UOp(Ops.CAST, dtypes.half, (_var('a'),)))
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 1.5), ci])
|
||||
# f16 result is in lower 16 bits of v[2]
|
||||
self.assertAlmostEqual(f16(st.vgpr[0][2]), 1.5, places=2)
|
||||
|
||||
# compares are skipped by ISel (VOPC writes VCC; LLVM handles natively)
|
||||
|
||||
# ── structural: NOT ──
|
||||
|
||||
def test_emu_not_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
x = _var('x', dtypes.uint)
|
||||
xor = UOp(Ops.XOR, dtypes.uint, (x, _const(0xFFFFFFFF, dtypes.uint)))
|
||||
inst, _ = self._get_inst(xor)
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 0x0000FF00), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 0xFFFF00FF)
|
||||
|
||||
# ── structural: SUB ──
|
||||
|
||||
def test_emu_sub_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
|
||||
neg = UOp(Ops.MUL, dtypes.uint, (y, _const(-1, dtypes.uint)))
|
||||
sub = UOp(Ops.ADD, dtypes.uint, (x, neg))
|
||||
inst, nsrc = self._get_inst(sub)
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 30), v_mov_b32_e32(v[1], 12), ci])
|
||||
self.assertEqual(st.vgpr[0][2], 18)
|
||||
|
||||
# ── structural: RCP ──
|
||||
|
||||
def test_emu_rcp_f32(self):
|
||||
from extra.assembly.amd.test.hw.helpers import i2f
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
a = _var('a')
|
||||
rcp = UOp(Ops.RECIPROCAL, dtypes.float, (a,))
|
||||
mul_rcp = UOp(Ops.MUL, dtypes.float, (_const(1.0), rcp))
|
||||
inst, _ = self._get_inst(mul_rcp)
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.vdst = v[2]
|
||||
st = self._run([v_mov_b32_e32(v[0], 4.0), ci])
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 0.25, places=4)
|
||||
|
||||
# ── structural: MAD ──
|
||||
|
||||
def test_emu_mad_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
mul = UOp(Ops.MUL, dtypes.uint, (x, y))
|
||||
mad = UOp(Ops.ADD, dtypes.uint, (mul, z))
|
||||
inst, _ = self._get_inst(mad)
|
||||
ci = self._copy_inst(inst)
|
||||
# VOP3 format: src0, src1, src2, vdst
|
||||
ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
|
||||
st = self._run([v_mov_b32_e32(v[0], 3), v_mov_b32_e32(v[1], 4), v_mov_b32_e32(v[2], 5), ci])
|
||||
self.assertEqual(st.vgpr[0][3], 17) # 3*4 + 5 = 17
|
||||
|
||||
# ── structural: ADD3 ──
|
||||
|
||||
def test_emu_add3_u32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
add1 = UOp(Ops.ADD, dtypes.uint, (x, y))
|
||||
add3 = UOp(Ops.ADD, dtypes.uint, (add1, z))
|
||||
inst, _ = self._get_inst(add3)
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
|
||||
st = self._run([v_mov_b32_e32(v[0], 10), v_mov_b32_e32(v[1], 20), v_mov_b32_e32(v[2], 30), ci])
|
||||
self.assertEqual(st.vgpr[0][3], 60) # 10+20+30
|
||||
|
||||
# ── structural: XOR3 ──
|
||||
|
||||
def test_emu_xor3_b32(self):
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
|
||||
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
|
||||
xor1 = UOp(Ops.XOR, dtypes.uint, (x, y))
|
||||
xor3 = UOp(Ops.XOR, dtypes.uint, (xor1, z))
|
||||
inst, _ = self._get_inst(xor3)
|
||||
ci = self._copy_inst(inst)
|
||||
ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
|
||||
st = self._run([v_mov_b32_e32(v[0], 0xFF), v_mov_b32_e32(v[1], 0x0F), v_mov_b32_e32(v[2], 0x33), ci])
|
||||
self.assertEqual(st.vgpr[0][3], 0xFF ^ 0x0F ^ 0x33)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -189,7 +189,7 @@ CPU_CC, CPU_LLVM, CPU_LVP = ContextVar("CPU_CC", ""), ContextVar("CPU_LLVM", 0),
|
|||
NV_CC, NV_PTX, NV_NAK = ContextVar("NV_CC", ""), ContextVar("NV_PTX", 0), ContextVar("NV_NAK", 0)
|
||||
CUDA_CC, CUDA_PTX, CUDA_NVCC = ContextVar("CUDA_CC", ""), ContextVar("CUDA_PTX", 0), ContextVar("CUDA_NVCC", 0)
|
||||
NULL_IR3, NULL_NAK, NULL_ALLOW_COPYOUT = ContextVar("NULL_IR3", 0), ContextVar("NULL_NAK", 0), ContextVar("NULL_ALLOW_COPYOUT", 0)
|
||||
AMD_CC, AMD_LLVM, AMD_HIPCC = ContextVar("AMD_CC", ""), ContextVar("AMD_LLVM", 0), ContextVar("AMD_HIPCC", 0)
|
||||
AMD_CC, AMD_LLVM, AMD_HIPCC, AMD_ISEL, AMD_ASM = ContextVar("AMD_CC", ""), ContextVar("AMD_LLVM", 0), ContextVar("AMD_HIPCC", 0), ContextVar("AMD_ISEL", 0), ContextVar("AMD_ASM", 0)
|
||||
QCOM_CC, QCOM_IR3 = ContextVar("QCOM_CC", ""), ContextVar("QCOM_IR3", 0)
|
||||
# VIZ implies PROFILE, but you can run PROFILE without VIZ
|
||||
VIZ = ContextVar("VIZ", 0)
|
||||
|
|
|
|||
|
|
@ -444,6 +444,9 @@ class Inst:
|
|||
|
||||
def __eq__(self, other): return type(self) is type(other) and self._raw == other._raw
|
||||
def __hash__(self): return hash((type(self), self._raw))
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Inst): return NotImplemented
|
||||
return (type(self).__name__, self._raw) < (type(other).__name__, other._raw)
|
||||
|
||||
def __repr__(self):
|
||||
# collect (repr, is_default) pairs, strip trailing defaults so repr roundtrips with eval
|
||||
|
|
|
|||
325
tinygrad/renderer/amd/isel.py
Normal file
325
tinygrad/renderer/amd/isel.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
# Instruction selection for AMD GPUs via pcode-derived PatternMatcher
|
||||
# Parses AMD pcode specs into UOp templates, normalizes them, and converts to UPat patterns
|
||||
# that rewrite renderer-level UOps into Ops.INS with arg=Inst objects
|
||||
|
||||
import functools
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, GroupOp
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from extra.assembly.amd.emu import parse_pcode
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
|
||||
from extra.assembly.amd.autogen.rdna3 import ins as rdna3_ins
|
||||
|
||||
# sentinel UOps representing source operands (typed as u32, like real registers)
|
||||
_SENTINEL = {f'S{i}': UOp(Ops.DEFINE_VAR, dtypes.uint32, arg=(f'S{i}', 0, 0xFFFFFFFF)) for i in range(4)}
|
||||
_SENTINEL_SET = set(_SENTINEL.values())
|
||||
|
||||
# only parse ALU-relevant opcode types (memory ops need different sentinels)
|
||||
_ALU_ENUM_TYPES = frozenset({'SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'VOP1Op', 'VOP2Op',
|
||||
'VOP3Op', 'VOP3POp', 'VOP3SDOp', 'VOPCOp', 'VINTERPOp'})
|
||||
# SOP types use SGPRs — skip for LLVM inline asm renderer (VGPR-only)
|
||||
_SOP_ENUM_TYPES = frozenset({'SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp'})
|
||||
|
||||
# opcode enum class name -> list of Inst class suffixes to try
|
||||
_VARIANT_SUFFIXES = ['', '_SDST']
|
||||
|
||||
def make_inst(opcode):
|
||||
"""Create an Inst object with just the opcode set (registers defaulted)."""
|
||||
base_name = type(opcode).__name__[:-2]
|
||||
for suffix in _VARIANT_SUFFIXES:
|
||||
cls = getattr(rdna3_ins, base_name + suffix, None)
|
||||
if cls is None: continue
|
||||
try: return cls(op=opcode)
|
||||
except (RuntimeError, TypeError): continue
|
||||
raise RuntimeError(f"no Inst class found for {opcode}")
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Normalization: strip register-model artifacts from pcode UOps
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
def normalize(uop, _cache=None):
|
||||
"""Strip register-model artifacts from pcode UOps to match renderer-level UOps.
|
||||
|
||||
Pcode models registers as typeless u32 words and uses BITCAST/CAST to reinterpret.
|
||||
The renderer's UOps are natively typed — we strip these artifacts:
|
||||
- BITCAST(f32, sentinel_u32) -> sentinel typed as f32
|
||||
- CAST(i32, sentinel_u32) -> sentinel typed as i32 (same-size reinterpret)
|
||||
- CAST(u64, sentinel_u32) -> sentinel typed as u64 (widening for 64-bit ops)
|
||||
- BITCAST(T, x) where x.dtype == T -> x (identity bitcast)
|
||||
- AND(x, mask) where mask is shift masking -> x (hardware does this implicitly)
|
||||
"""
|
||||
if _cache is None: _cache = {}
|
||||
if id(uop) in _cache: return _cache[id(uop)]
|
||||
|
||||
# first recurse so children are normalized before we check patterns
|
||||
new_src = tuple(normalize(s, _cache) for s in uop.src)
|
||||
uop = uop if new_src == uop.src else uop.replace(src=new_src)
|
||||
|
||||
# BITCAST or CAST on a sentinel -> sentinel with target dtype
|
||||
if uop.op in (Ops.BITCAST, Ops.CAST) and len(uop.src) == 1 and uop.src[0] in _SENTINEL_SET:
|
||||
result = uop.src[0].replace(dtype=uop.dtype)
|
||||
_cache[id(uop)] = result
|
||||
return result
|
||||
|
||||
# identity BITCAST: BITCAST(T, x) where x already has dtype T
|
||||
if uop.op == Ops.BITCAST and len(uop.src) == 1 and uop.src[0].dtype == uop.dtype:
|
||||
_cache[id(uop)] = uop.src[0]
|
||||
return uop.src[0]
|
||||
|
||||
# shift masking: AND(sentinel, 31) or AND(sentinel, 63) -> sentinel (hardware masks shift amounts)
|
||||
if uop.op == Ops.AND and len(uop.src) == 2:
|
||||
if uop.src[1].op == Ops.CONST and uop.src[1].arg in (31, 63) and uop.src[0].op == Ops.DEFINE_VAR:
|
||||
_cache[id(uop)] = uop.src[0]
|
||||
return uop.src[0]
|
||||
|
||||
_cache[id(uop)] = uop
|
||||
return uop
|
||||
|
||||
def _count_nodes(uop, _seen=None):
|
||||
if _seen is None: _seen = set()
|
||||
if id(uop) in _seen: return 0
|
||||
_seen.add(id(uop))
|
||||
return 1 + sum(_count_nodes(s, _seen) for s in uop.src)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# UOp template -> UPat conversion
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
def uop_to_upat(uop, _seen=None):
|
||||
"""Convert a normalized UOp template into a matchable UPat pattern."""
|
||||
if _seen is None: _seen = {}
|
||||
if id(uop) in _seen: return _seen[id(uop)]
|
||||
if uop.op == Ops.DEFINE_VAR and isinstance(uop.arg, tuple) and uop.arg[0] in _SENTINEL:
|
||||
result = UPat.var(uop.arg[0], dtype=uop.dtype)
|
||||
_seen[id(uop)] = result
|
||||
return result
|
||||
if uop.op in (Ops.CONST, Ops.VCONST):
|
||||
result = UPat(uop.op, uop.dtype, arg=uop.arg)
|
||||
_seen[id(uop)] = result
|
||||
return result
|
||||
src = tuple(uop_to_upat(s, _seen) for s in uop.src) if uop.src else None
|
||||
result = UPat(uop.op, uop.dtype, src=src)
|
||||
_seen[id(uop)] = result
|
||||
return result
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Pattern classification and selection
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
def _select_best_opcode(opcodes):
|
||||
"""Prefer shorter encodings: VOP1/VOP2 > SOP > VOP3/VOPC."""
|
||||
_PREF = {'VOP1Op': 0, 'VOP2Op': 0, 'SOP1Op': 1, 'SOP2Op': 1, 'SOPCOp': 1, 'VOPCOp': 2, 'VOP3Op': 3, 'VOP3SDOp': 3, 'VOP3POp': 4}
|
||||
return min(opcodes, key=lambda oc: (_PREF.get(type(oc).__name__, 9), oc.value))
|
||||
|
||||
def _is_direct_alu(norm_uop):
|
||||
"""Check if normalized UOp is a direct ALU: op(sentinels...) with no intermediate ops."""
|
||||
if norm_uop.op not in GroupOp.ALU and norm_uop.op not in {Ops.CAST, Ops.BITCAST}: return False
|
||||
return all(s.op == Ops.DEFINE_VAR for s in norm_uop.src)
|
||||
|
||||
def _pattern_key(uop, _seen=None):
|
||||
"""Structural fingerprint for a normalized UOp template (sentinels become var placeholders)."""
|
||||
if _seen is None: _seen = {}
|
||||
if id(uop) in _seen: return _seen[id(uop)]
|
||||
if uop.op == Ops.DEFINE_VAR and isinstance(uop.arg, tuple) and uop.arg[0] in _SENTINEL:
|
||||
result = f'var({uop.arg[0]},{uop.dtype})'
|
||||
elif uop.op in (Ops.CONST, Ops.VCONST):
|
||||
result = f'const({uop.arg},{uop.dtype})'
|
||||
else:
|
||||
children = ','.join(_pattern_key(s, _seen) for s in uop.src)
|
||||
result = f'{uop.op}({uop.dtype},{children})'
|
||||
_seen[id(uop)] = result
|
||||
return result
|
||||
|
||||
def _runtime_key(uop, _var_counter=None, _seen=None):
|
||||
"""Compute a structural key from a matched UOp at runtime (real data, not sentinels).
|
||||
Leaf UOps (non-ALU with no recognized children) are treated as variables."""
|
||||
if _seen is None: _seen = {}
|
||||
if _var_counter is None: _var_counter = [0]
|
||||
uid = id(uop)
|
||||
if uid in _seen: return _seen[uid]
|
||||
_ALU_OPS = GroupOp.ALU | {Ops.CAST, Ops.BITCAST, Ops.WHERE}
|
||||
if uop.op in (Ops.CONST, Ops.VCONST):
|
||||
result = f'const({uop.arg},{uop.dtype})'
|
||||
elif uop.op not in _ALU_OPS:
|
||||
result = f'var(S{_var_counter[0]},{uop.dtype})'
|
||||
_var_counter[0] += 1
|
||||
else:
|
||||
children = ','.join(_runtime_key(s, _var_counter, _seen) for s in uop.src)
|
||||
result = f'{uop.op}({uop.dtype},{children})'
|
||||
_seen[uid] = result
|
||||
return result
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Build tables from pcode
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
def _parse_pcode_patterns(pcode_dict, vgpr_only=False):
|
||||
"""Parse ALU pcode entries, normalize, and categorize into direct vs structural."""
|
||||
# key: (op, src_dtypes_tuple, dst_dtype) -> [opcodes]
|
||||
direct: dict[tuple, list] = {}
|
||||
structural: list[tuple] = [] # [(opcode, norm_uop)]
|
||||
allowed_types = _ALU_ENUM_TYPES - _SOP_ENUM_TYPES if vgpr_only else _ALU_ENUM_TYPES
|
||||
|
||||
for opcode, pcode_str in pcode_dict.items():
|
||||
if type(opcode).__name__ not in allowed_types: continue
|
||||
try: env, assigns = parse_pcode(pcode_str, dict(_SENTINEL))
|
||||
except Exception: continue
|
||||
d0 = next(((n, u) for n, u in assigns if n.startswith('D0')), None)
|
||||
if d0 is None: continue
|
||||
_, uop = d0
|
||||
if _count_nodes(uop) > 5: continue
|
||||
norm = normalize(uop)
|
||||
if norm.op == Ops.DEFINE_VAR or norm.op in (Ops.CONST, Ops.VCONST): continue
|
||||
|
||||
if _is_direct_alu(norm):
|
||||
src_dtypes = tuple(s.dtype for s in norm.src)
|
||||
key = (norm.op, src_dtypes, norm.dtype)
|
||||
direct.setdefault(key, []).append(opcode)
|
||||
else:
|
||||
structural.append((opcode, norm))
|
||||
|
||||
# pick best opcode for each direct pattern
|
||||
direct_best = {k: _select_best_opcode(v) for k, v in direct.items()}
|
||||
|
||||
# deduplicate structural patterns by shape, pick best
|
||||
seen: dict[str, list] = {}
|
||||
for opcode, norm in structural:
|
||||
key = _pattern_key(norm)
|
||||
seen.setdefault(key, []).append((opcode, norm))
|
||||
structural_best = []
|
||||
for key, group in seen.items():
|
||||
best = _select_best_opcode([oc for oc, _ in group])
|
||||
best_norm = next(n for oc, n in group if oc == best)
|
||||
structural_best.append((best, best_norm, key))
|
||||
|
||||
return direct_best, structural_best
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Global tables (populated by build_isel_patterns, used by callbacks)
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
# direct: (op, src_dtypes, dst_dtype) -> Inst
|
||||
_DIRECT_TABLE: dict[tuple, object] = {}
|
||||
# structural: pattern_key_string -> Inst
|
||||
_STRUCTURAL_TABLE: dict[str, object] = {}
|
||||
|
||||
# ops that LLVM handles natively (compares write VCC, not VGPRs)
|
||||
_SKIP_OPS = frozenset({Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ})
|
||||
|
||||
def _isel_direct(m):
|
||||
"""Callback for direct ALU: look up Inst by (op, src_dtypes, dtype)."""
|
||||
if m.op in _SKIP_OPS or m.dtype == dtypes.bool: return None
|
||||
src_dtypes = tuple(s.dtype for s in m.src)
|
||||
inst = _DIRECT_TABLE.get((m.op, src_dtypes, m.dtype))
|
||||
if inst is None: return None
|
||||
return UOp(Ops.INS, m.dtype, m.src, arg=inst)
|
||||
|
||||
def _isel_structural(m, **kwargs):
|
||||
"""Callback for structural patterns: compute runtime key, look up Inst."""
|
||||
if m.op in _SKIP_OPS or m.dtype == dtypes.bool: return None
|
||||
key = _runtime_key(m)
|
||||
inst = _STRUCTURAL_TABLE.get(key)
|
||||
if inst is None: return None
|
||||
# collect source vars in order (leaves of the matched tree)
|
||||
srcs = _collect_leaves(m)
|
||||
return UOp(Ops.INS, m.dtype, tuple(srcs), arg=inst)
|
||||
|
||||
def _collect_leaves(uop, _seen=None):
|
||||
"""Collect leaf UOps (non-ALU) from a matched tree in left-to-right order."""
|
||||
if _seen is None: _seen = set()
|
||||
_ALU_OPS = GroupOp.ALU | {Ops.CAST, Ops.BITCAST, Ops.WHERE}
|
||||
uid = id(uop)
|
||||
if uid in _seen: return []
|
||||
_seen.add(uid)
|
||||
if uop.op in (Ops.CONST, Ops.VCONST): return []
|
||||
if uop.op not in _ALU_OPS: return [uop]
|
||||
result = []
|
||||
for s in uop.src:
|
||||
result.extend(_collect_leaves(s, _seen))
|
||||
return result
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Build the PatternMatcher
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
def build_isel_patterns(pcode_dict=PCODE, vgpr_only=False) -> PatternMatcher:
|
||||
"""Parse pcode and build a PatternMatcher for instruction selection."""
|
||||
direct_best, structural_best = _parse_pcode_patterns(pcode_dict, vgpr_only=vgpr_only)
|
||||
|
||||
# populate direct table
|
||||
_DIRECT_TABLE.clear()
|
||||
for (op, src_dtypes, dtype), opcode in direct_best.items():
|
||||
_DIRECT_TABLE[(op, src_dtypes, dtype)] = make_inst(opcode)
|
||||
|
||||
# populate structural table
|
||||
_STRUCTURAL_TABLE.clear()
|
||||
for opcode, norm, pkey in structural_best:
|
||||
_STRUCTURAL_TABLE[pkey] = make_inst(opcode)
|
||||
|
||||
patterns: list[tuple] = []
|
||||
|
||||
# structural patterns first (more specific, should match before catch-all direct)
|
||||
for opcode, norm, pkey in structural_best:
|
||||
pat = uop_to_upat(norm).named('m')
|
||||
patterns.append((pat, _isel_structural))
|
||||
|
||||
# direct ALU: catch-all patterns that look up by (op, src_dtypes, dtype)
|
||||
patterns.append((UPat(GroupOp.ALU, name='m'), _isel_direct))
|
||||
patterns.append((UPat(Ops.CAST, name='m'), _isel_direct))
|
||||
patterns.append((UPat(Ops.BITCAST, name='m'), _isel_direct))
|
||||
|
||||
return PatternMatcher(patterns)
|
||||
|
||||
@functools.cache
|
||||
def rdna3_isel() -> PatternMatcher:
|
||||
"""Build the default RDNA3 instruction selector (VOP-only for LLVM inline asm)."""
|
||||
return build_isel_patterns(PCODE, vgpr_only=True)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# LLVM inline asm rendering for Ops.INS
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
import re
|
||||
from tinygrad.dtype import PtrDType
|
||||
|
||||
def _ins_mnemonic(inst) -> str:
|
||||
"""Get the assembly mnemonic from an Inst opcode (strip _E32/_E64 suffix)."""
|
||||
return re.sub(r'_e(32|64)$', '', inst.op.name.lower())
|
||||
|
||||
def _ldt(dt):
|
||||
"""LLVM type string for a DType."""
|
||||
if dt.vcount > 1: return f"<{dt.vcount} x {_ldt(dt.scalar())}>"
|
||||
if isinstance(dt, PtrDType): return _ldt(dt.base) + "*"
|
||||
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
|
||||
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
|
||||
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
|
||||
|
||||
def render_ins_llvm(ctx, x):
|
||||
"""Render Ops.INS as LLVM inline assembly call."""
|
||||
inst = x.arg
|
||||
mnem = _ins_mnemonic(inst)
|
||||
n_srcs = len(x.src)
|
||||
# build operand string: $0 = dest, $1..$N = sources
|
||||
ops = ", ".join(f"${i}" for i in range(n_srcs + 1))
|
||||
asm_str = f"{mnem} {ops}"
|
||||
# constraints: =v for output, v for each input (VGPR)
|
||||
constraints = "=v," + ",".join("v" for _ in range(n_srcs))
|
||||
# LLVM types and values
|
||||
ret_type = _ldt(x.dtype)
|
||||
args = ", ".join(f"{_ldt(s.dtype)} {ctx[s]}" for s in x.src)
|
||||
return f" {ctx[x]} = call {ret_type} asm \"{asm_str}\", \"{constraints}\"({args})"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# AMDISELRenderer: LLVM renderer with pcode-based instruction selection
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
|
||||
class AMDISELRenderer(AMDLLVMRenderer):
|
||||
"""AMD renderer that uses pcode-derived instruction selection for ALU ops."""
|
||||
def __init__(self, arch: str):
|
||||
super().__init__(arch)
|
||||
# add ISel as extra_matcher: rewrites ALU UOps → Ops.INS
|
||||
self.extra_matcher = self.extra_matcher + rdna3_isel()
|
||||
# add Ops.INS rendering to string_rewrite
|
||||
self.string_rewrite = PatternMatcher([(UPat(Ops.INS, name='x'), render_ins_llvm)]) + self.string_rewrite
|
||||
def __reduce__(self): return self.__class__, (self.arch,)
|
||||
427
tinygrad/renderer/amd/renderer.py
Normal file
427
tinygrad/renderer/amd/renderer.py
Normal file
|
|
@ -0,0 +1,427 @@
|
|||
# Direct AMD GPU assembly renderer — emits Inst objects, produces GAS text via disasm()
|
||||
# No LLVM. Uses HIPCompiler (COMGR) to assemble text into ELF.
|
||||
|
||||
import functools, math
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, GroupOp
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import AMDHIPRenderer
|
||||
from extra.assembly.amd.dsl import Inst, Reg, s, v, NULL, VCC_LO, EXEC_LO, M0
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (s_load_b64, s_load_b128, s_mov_b32, s_waitcnt, s_endpgm, s_barrier,
|
||||
s_branch, s_cbranch_scc0, s_cbranch_scc1, s_cmp_ge_i32, s_add_i32, s_and_b32, s_lshl_b32,
|
||||
v_mov_b32_e32, v_add_f32_e32, v_add_nc_u32_e32, v_lshlrev_b32_e32, v_lshrrev_b32_e32,
|
||||
v_and_b32_e32, v_mul_lo_u32, v_cmp_lt_i32_e32,
|
||||
global_load_b32, global_load_b64, global_load_b128, global_store_b32, global_store_b64, global_store_b128,
|
||||
ds_load_b32, ds_store_b32)
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
from extra.assembly.amd.isel import rdna3_isel, make_inst
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Register allocator — simple bump allocator
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
class RegFile:
|
||||
"""Simple register allocator: bump-allocates VGPRs and SGPRs."""
|
||||
def __init__(self):
|
||||
self.next_vgpr = 1 # v0 = workitem_id_x (reserved by hardware)
|
||||
self.next_sgpr = 0 # s[0:1] = kernarg_ptr (reserved by ABI)
|
||||
self.max_vgpr = 1
|
||||
self.max_sgpr = 0
|
||||
|
||||
def alloc_vgpr(self, count=1) -> Reg:
|
||||
r = v[self.next_vgpr] if count == 1 else v[self.next_vgpr:self.next_vgpr + count - 1]
|
||||
self.next_vgpr += count
|
||||
self.max_vgpr = max(self.max_vgpr, self.next_vgpr)
|
||||
return r
|
||||
|
||||
def alloc_sgpr(self, count=1) -> Reg:
|
||||
# align to 2 for 64-bit, 4 for 128-bit
|
||||
if count >= 4: self.next_sgpr = (self.next_sgpr + 3) & ~3
|
||||
elif count >= 2: self.next_sgpr = (self.next_sgpr + 1) & ~1
|
||||
r = s[self.next_sgpr] if count == 1 else s[self.next_sgpr:self.next_sgpr + count - 1]
|
||||
self.next_sgpr += count
|
||||
self.max_sgpr = max(self.max_sgpr, self.next_sgpr)
|
||||
return r
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Instruction emitter (like amd_asm_matmul.Kernel)
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
class AsmKernel:
|
||||
def __init__(self, arch='gfx1100'):
|
||||
self.instructions: list[Inst] = []
|
||||
self.labels: dict[str, int] = {}
|
||||
self.pos = 0
|
||||
self.arch = arch
|
||||
self.regs = RegFile()
|
||||
self.lds_size = 0
|
||||
|
||||
def emit(self, inst, target=None):
|
||||
self.instructions.append(inst)
|
||||
inst._target = target
|
||||
inst._pos = self.pos
|
||||
self.pos += inst.size()
|
||||
return inst
|
||||
|
||||
def label(self, name):
|
||||
self.labels[name] = self.pos
|
||||
|
||||
def waitcnt(self, lgkm=None, vm=None):
|
||||
vmcnt = vm if vm is not None else 63
|
||||
lgkmcnt = lgkm if lgkm is not None else 63
|
||||
expcnt = 7
|
||||
wc = (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
|
||||
self.emit(s_waitcnt(simm16=wc))
|
||||
|
||||
def resolve_branches(self):
|
||||
for inst in self.instructions:
|
||||
if hasattr(inst, '_target') and inst._target is not None:
|
||||
offset_dwords = (self.labels[inst._target] - inst._pos - inst.size()) // 4
|
||||
inst.simm16 = offset_dwords
|
||||
|
||||
def to_asm(self, name='kernel', kernarg_size=0, n_params=0) -> str:
|
||||
self.resolve_branches()
|
||||
body = ['\t' + disasm(inst) for inst in self.instructions]
|
||||
|
||||
hsa = [
|
||||
('group_segment_fixed_size', self.lds_size), ('private_segment_fixed_size', 0), ('kernarg_size', kernarg_size),
|
||||
('user_sgpr_count', 2), ('user_sgpr_kernarg_segment_ptr', 1),
|
||||
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
|
||||
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
|
||||
('system_vgpr_workitem_id', 0), ('next_free_vgpr', self.regs.max_vgpr),
|
||||
('next_free_sgpr', max(self.regs.max_sgpr, 4)), # minimum 4 SGPRs
|
||||
('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
|
||||
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3),
|
||||
('dx10_clamp', 1), ('ieee_mode', 1), ('fp16_overflow', 0),
|
||||
('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0), ('shared_vgpr_count', 0)]
|
||||
|
||||
args_meta = '\n'.join(
|
||||
f' - .address_space: global\n .offset: {i*8}\n .size: 8\n .value_kind: global_buffer'
|
||||
for i in range(n_params))
|
||||
|
||||
return '\n'.join([
|
||||
'\t.text', f'\t.amdgcn_target "amdgcn-amd-amdhsa--{self.arch}"',
|
||||
f'\t.protected\t{name}', f'\t.globl\t{name}', '\t.p2align\t8', f'\t.type\t{name},@function', f'{name}:',
|
||||
*body,
|
||||
'\t.section\t.rodata,"a",@progbits', '\t.p2align\t6, 0x0', f'\t.amdhsa_kernel {name}',
|
||||
*[f'\t\t.amdhsa_{k} {v}' for k, v in hsa],
|
||||
f'\t.end_amdhsa_kernel', '\t.text', f'.Lfunc_end0:', f'\t.size\t{name}, .Lfunc_end0-{name}',
|
||||
'\t.amdgpu_metadata', '---', 'amdhsa.kernels:', ' - .args:',
|
||||
args_meta,
|
||||
f' .group_segment_fixed_size: {self.lds_size}', ' .kernarg_segment_align: 8',
|
||||
f' .kernarg_segment_size: {kernarg_size}', ' .max_flat_workgroup_size: 1024',
|
||||
f' .name: {name}', ' .private_segment_fixed_size: 0',
|
||||
f' .sgpr_count: {max(self.regs.max_sgpr, 4)}', f' .symbol: {name}.kd',
|
||||
f' .vgpr_count: {self.regs.max_vgpr}', ' .wavefront_size: 32',
|
||||
f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
|
||||
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# UOp → Inst rendering
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
# dtype → register count for VGPRs
|
||||
def _dtype_regs(dt: DType) -> int:
|
||||
if isinstance(dt, PtrDType): return 2 # 64-bit pointer
|
||||
return max(1, dt.itemsize // 4) * (dt.vcount if hasattr(dt, 'vcount') and dt.vcount > 1 else 1)
|
||||
|
||||
# dtype → global load instruction
|
||||
def _global_load(vdst, addr, saddr, offset=0, nregs=1):
|
||||
if nregs == 1: return global_load_b32(vdst=vdst, addr=addr, saddr=saddr, offset=offset)
|
||||
if nregs == 2: return global_load_b64(vdst=vdst, addr=addr, saddr=saddr, offset=offset)
|
||||
if nregs == 4: return global_load_b128(vdst=vdst, addr=addr, saddr=saddr, offset=offset)
|
||||
raise RuntimeError(f"unsupported global load size: {nregs} regs")
|
||||
|
||||
def _global_store(addr, data, saddr, offset=0, nregs=1):
|
||||
if nregs == 1: return global_store_b32(addr=addr, data=data, saddr=saddr, offset=offset)
|
||||
if nregs == 2: return global_store_b64(addr=addr, data=data, saddr=saddr, offset=offset)
|
||||
if nregs == 4: return global_store_b128(addr=addr, data=data, saddr=saddr, offset=offset)
|
||||
raise RuntimeError(f"unsupported global store size: {nregs} regs")
|
||||
|
||||
def render_kernel(uops: list[UOp], arch='gfx1100') -> str:
|
||||
"""Render linearized UOps into GAS assembly text."""
|
||||
k = AsmKernel(arch)
|
||||
# r maps UOp → register (Reg)
|
||||
r: dict[UOp, Reg] = {}
|
||||
# s_args: SGPR pairs for kernel argument pointers, loaded from kernarg segment
|
||||
# kernarg_ptr is in s[0:1] (set by HSA ABI)
|
||||
kernarg_base = k.regs.alloc_sgpr(2) # s[0:1] = kernarg segment pointer
|
||||
# system SGPRs for workgroup IDs come after user SGPRs
|
||||
# with user_sgpr_count=2, workgroup_id_x is s[2], workgroup_id_y is s[3]
|
||||
wg_id_x_sgpr = 2
|
||||
wg_id_y_sgpr = 3
|
||||
|
||||
name = 'test'
|
||||
params: list[tuple[int, Reg]] = [] # (param_idx, sgpr_pair)
|
||||
specials: dict[str, Reg] = {}
|
||||
loop_stack: list[tuple[str, str, Reg]] = [] # (label_start, label_end, range_reg)
|
||||
n_params = 0
|
||||
|
||||
# first pass: count params
|
||||
for u in uops:
|
||||
if u.op is Ops.PARAM: n_params = max(n_params, u.arg + 1)
|
||||
if u.op is Ops.SINK and u.arg is not None: name = u.arg.function_name
|
||||
|
||||
kernarg_size = n_params * 8 # each param is 8 bytes (pointer)
|
||||
|
||||
# load all kernel argument pointers
|
||||
param_sgprs: dict[int, Reg] = {}
|
||||
for i in range(n_params):
|
||||
sp = k.regs.alloc_sgpr(2)
|
||||
param_sgprs[i] = sp
|
||||
k.emit(s_load_b64(sdata=sp, sbase=kernarg_base, offset=i * 8, soffset=NULL))
|
||||
k.waitcnt(lgkm=0)
|
||||
|
||||
for u in uops:
|
||||
if u.op is Ops.SINK:
|
||||
continue
|
||||
|
||||
elif u.op is Ops.PARAM:
|
||||
r[u] = param_sgprs[u.arg]
|
||||
|
||||
elif u.op is Ops.CONST:
|
||||
if u.dtype == dtypes.float:
|
||||
vr = k.regs.alloc_vgpr()
|
||||
k.emit(v_mov_b32_e32(vr, u.arg))
|
||||
r[u] = vr
|
||||
elif u.dtype in (dtypes.int, dtypes.int32, dtypes.uint, dtypes.uint32):
|
||||
vr = k.regs.alloc_vgpr()
|
||||
k.emit(v_mov_b32_e32(vr, u.arg if isinstance(u.arg, int) and -16 <= u.arg <= 64 else u.arg))
|
||||
r[u] = vr
|
||||
elif u.dtype == dtypes.bool:
|
||||
# booleans: 1=true, 0=false — stored in VGPR as int
|
||||
vr = k.regs.alloc_vgpr()
|
||||
k.emit(v_mov_b32_e32(vr, 1 if u.arg else 0))
|
||||
r[u] = vr
|
||||
else:
|
||||
raise RuntimeError(f"unsupported CONST dtype {u.dtype}")
|
||||
|
||||
elif u.op is Ops.SPECIAL:
|
||||
kind, idx = u.arg[0], int(u.arg[-1])
|
||||
if kind == 'l':
|
||||
# local thread ID — workitem_id_{x,y,z} already in v0 (only x for 1D)
|
||||
if idx == 0:
|
||||
r[u] = v[0] # workitem_id_x is pre-loaded in v0 by hardware
|
||||
else:
|
||||
raise RuntimeError(f"unsupported local dim {idx}")
|
||||
elif kind == 'g':
|
||||
# workgroup ID — in system SGPRs (after user SGPRs)
|
||||
sgpr_off = wg_id_x_sgpr + idx
|
||||
vr = k.regs.alloc_vgpr()
|
||||
k.emit(v_mov_b32_e32(vr, s[sgpr_off]))
|
||||
r[u] = vr
|
||||
else:
|
||||
raise RuntimeError(f"unsupported SPECIAL kind {kind}")
|
||||
|
||||
elif u.op is Ops.INDEX:
|
||||
# INDEX(ptr, idx) — compute byte address: base_ptr + idx * element_size
|
||||
base = r[u.src[0]]
|
||||
idx_reg = r[u.src[1]]
|
||||
assert isinstance(u.dtype, PtrDType), f"INDEX must produce pointer, got {u.dtype}"
|
||||
elem_size = u.dtype.base.itemsize
|
||||
# compute byte offset: idx * elem_size
|
||||
offset_vr = k.regs.alloc_vgpr()
|
||||
if elem_size == 4:
|
||||
k.emit(v_lshlrev_b32_e32(offset_vr, 2, idx_reg))
|
||||
elif elem_size == 8:
|
||||
k.emit(v_lshlrev_b32_e32(offset_vr, 3, idx_reg))
|
||||
elif elem_size == 16:
|
||||
k.emit(v_lshlrev_b32_e32(offset_vr, 4, idx_reg))
|
||||
elif elem_size == 2:
|
||||
k.emit(v_lshlrev_b32_e32(offset_vr, 1, idx_reg))
|
||||
elif elem_size == 1:
|
||||
k.emit(v_mov_b32_e32(offset_vr, idx_reg))
|
||||
else:
|
||||
k.emit(v_mul_lo_u32(offset_vr, elem_size, idx_reg))
|
||||
# base is an SGPR pair (64-bit pointer), offset is VGPR — use scalar+vector addressing
|
||||
r[u] = offset_vr # store offset VGPR; base SGPR pair stored separately
|
||||
# stash the base pointer for LOAD/STORE to use
|
||||
u._base_sgpr = base
|
||||
|
||||
elif u.op is Ops.LOAD:
|
||||
idx_uop = u.src[0]
|
||||
assert idx_uop.op is Ops.INDEX or (idx_uop.op is Ops.CAST and idx_uop.src[0].op is Ops.INDEX)
|
||||
real_idx = idx_uop.src[0] if idx_uop.op is Ops.CAST else idx_uop
|
||||
base_sgpr = real_idx._base_sgpr
|
||||
offset_vr = r[real_idx]
|
||||
nregs = max(1, u.dtype.itemsize // 4) * (u.dtype.vcount if hasattr(u.dtype, 'vcount') and u.dtype.vcount > 1 else 1)
|
||||
dst = k.regs.alloc_vgpr(nregs)
|
||||
k.emit(_global_load(dst, offset_vr, base_sgpr, nregs=nregs))
|
||||
k.waitcnt(vm=0)
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.STORE:
|
||||
idx_uop = u.src[0]
|
||||
assert idx_uop.op is Ops.INDEX or (idx_uop.op is Ops.CAST and idx_uop.src[0].op is Ops.INDEX)
|
||||
real_idx = idx_uop.src[0] if idx_uop.op is Ops.CAST else idx_uop
|
||||
base_sgpr = real_idx._base_sgpr
|
||||
offset_vr = r[real_idx]
|
||||
val_reg = r[u.src[1]]
|
||||
nregs = max(1, u.src[1].dtype.itemsize // 4) * (u.src[1].dtype.vcount if hasattr(u.src[1].dtype, 'vcount') and u.src[1].dtype.vcount > 1 else 1)
|
||||
k.emit(_global_store(offset_vr, val_reg, base_sgpr, nregs=nregs))
|
||||
r[u] = offset_vr # stores don't produce values, but map for dependencies
|
||||
|
||||
elif u.op is Ops.ADD:
|
||||
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
if u.dtype == dtypes.float:
|
||||
k.emit(v_add_f32_e32(dst, a_reg, b_reg))
|
||||
elif u.dtype in (dtypes.int, dtypes.int32, dtypes.uint, dtypes.uint32):
|
||||
k.emit(v_add_nc_u32_e32(dst, a_reg, b_reg))
|
||||
else:
|
||||
raise RuntimeError(f"unsupported ADD dtype {u.dtype}")
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.MUL:
|
||||
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
if u.dtype == dtypes.float:
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mul_f32_e32
|
||||
k.emit(v_mul_f32_e32(dst, a_reg, b_reg))
|
||||
elif u.dtype in (dtypes.int, dtypes.int32, dtypes.uint, dtypes.uint32):
|
||||
k.emit(v_mul_lo_u32(dst, a_reg, b_reg))
|
||||
else:
|
||||
raise RuntimeError(f"unsupported MUL dtype {u.dtype}")
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.SHL:
|
||||
# SHL(val, shift) -> v_lshlrev_b32(shift, val) (reversed operands)
|
||||
val_reg, shift_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
k.emit(v_lshlrev_b32_e32(dst, shift_reg, val_reg))
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.SHR:
|
||||
val_reg, shift_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
k.emit(v_lshrrev_b32_e32(dst, shift_reg, val_reg))
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.AND:
|
||||
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
k.emit(v_and_b32_e32(dst, a_reg, b_reg))
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.CAST:
|
||||
# for now: pointer casts are noops, numeric casts need work
|
||||
if isinstance(u.dtype, PtrDType):
|
||||
r[u] = r[u.src[0]]
|
||||
if hasattr(u.src[0], '_base_sgpr'): u._base_sgpr = u.src[0]._base_sgpr
|
||||
else:
|
||||
raise RuntimeError(f"unsupported CAST {u.src[0].dtype} -> {u.dtype}")
|
||||
|
||||
elif u.op is Ops.VECTORIZE:
|
||||
# VECTORIZE packs scalars into a vector — just allocate contiguous VGPRs
|
||||
count = len(u.src)
|
||||
dst = k.regs.alloc_vgpr(count)
|
||||
for i, src_u in enumerate(u.src):
|
||||
src_reg = r[src_u]
|
||||
target = v[dst.offset - 256 + i] if count > 1 else dst
|
||||
if src_reg.offset != target.offset:
|
||||
k.emit(v_mov_b32_e32(target, src_reg))
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.GEP:
|
||||
# GEP extracts element from vector — just offset into the VGPR range
|
||||
base_reg = r[u.src[0]]
|
||||
idx = u.arg[0]
|
||||
r[u] = v[base_reg.offset - 256 + idx]
|
||||
|
||||
elif u.op in (Ops.NOOP, Ops.GROUP, Ops.AFTER):
|
||||
if u.src: r[u] = r[u.src[0]]
|
||||
|
||||
elif u.op is Ops.RANGE:
|
||||
# loop: counter starts at 0, increments by 1, bound is src[0]
|
||||
label_start = f'loop_{id(u)}'
|
||||
label_end = f'end_{id(u)}'
|
||||
ctr = k.regs.alloc_vgpr()
|
||||
k.emit(v_mov_b32_e32(ctr, 0))
|
||||
k.label(label_start)
|
||||
r[u] = ctr
|
||||
loop_stack.append((label_start, label_end, ctr))
|
||||
|
||||
elif u.op is Ops.END:
|
||||
label_start, label_end, ctr = loop_stack.pop()
|
||||
# increment counter
|
||||
k.emit(v_add_nc_u32_e32(ctr, 1, ctr))
|
||||
# compare and branch: use SGPR compare since loop bound should be uniform
|
||||
bound_uop = u.src[1] # the RANGE uop's src[0] is the bound
|
||||
# actually END.src = (range_uop, ...), range_uop.src[0] = bound
|
||||
range_uop = u.src[0]
|
||||
bound_reg = r[range_uop.src[0]]
|
||||
k.emit(v_cmp_lt_i32_e32(ctr, bound_reg))
|
||||
k.emit(s_cbranch_scc1(), target=label_start)
|
||||
k.label(label_end)
|
||||
|
||||
elif u.op is Ops.BARRIER:
|
||||
k.emit(s_barrier())
|
||||
|
||||
elif u.op is Ops.DEFINE_LOCAL:
|
||||
# LDS allocation — just track size, address computed at use time
|
||||
r[u] = v[0] # placeholder, LDS addressing handled separately
|
||||
k.lds_size = max(k.lds_size, u.dtype.size * u.dtype.base.itemsize if hasattr(u.dtype, 'size') else 0)
|
||||
|
||||
elif u.op is Ops.DEFINE_REG:
|
||||
# register "spill" region — allocate VGPRs
|
||||
size = u.dtype.size if hasattr(u.dtype, 'size') else 1
|
||||
vr = k.regs.alloc_vgpr(size)
|
||||
r[u] = vr
|
||||
|
||||
elif u.op is Ops.CMPLT:
|
||||
# compare: write result to VCC, then v_cndmask to get bool in VGPR
|
||||
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
k.emit(v_cmp_lt_i32_e32(a_reg, b_reg))
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_cndmask_b32_e32
|
||||
k.emit(v_cndmask_b32_e32(dst, 0, 1))
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.CMPNE:
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_cmp_ne_u32_e32, v_cndmask_b32_e32
|
||||
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
k.emit(v_cmp_ne_u32_e32(a_reg, b_reg))
|
||||
k.emit(v_cndmask_b32_e32(dst, 0, 1))
|
||||
r[u] = dst
|
||||
|
||||
elif u.op is Ops.WHERE:
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_cndmask_b32_e32
|
||||
cond_reg, true_reg, false_reg = r[u.src[0]], r[u.src[1]], r[u.src[2]]
|
||||
dst = k.regs.alloc_vgpr()
|
||||
# set VCC from condition (nonzero = true)
|
||||
k.emit(v_cmp_lt_i32_e32(0, cond_reg)) # VCC = cond_reg != 0
|
||||
k.emit(v_cndmask_b32_e32(dst, false_reg, true_reg))
|
||||
r[u] = dst
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"unsupported UOp: {u.op} dtype={u.dtype}")
|
||||
|
||||
# epilogue
|
||||
k.waitcnt(vm=0, lgkm=0)
|
||||
k.emit(s_endpgm())
|
||||
|
||||
return k.to_asm(name=name, kernarg_size=kernarg_size, n_params=n_params)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# AMDAssemblyRenderer: Renderer subclass for tinygrad integration
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
class AMDAssemblyRenderer(Renderer):
|
||||
device = "AMD"
|
||||
suffix = "s" # GAS assembly
|
||||
supports_float4 = True
|
||||
has_local = True
|
||||
has_shared = True
|
||||
global_max = AMDHIPRenderer.global_max
|
||||
shared_max = AMDHIPRenderer.shared_max
|
||||
|
||||
def __init__(self, arch: str):
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
self.arch = arch
|
||||
self.compiler = HIPCompiler(arch)
|
||||
|
||||
def render(self, uops: list[UOp]) -> str:
|
||||
return render_kernel(uops, arch=self.arch)
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.arch,)
|
||||
|
|
@ -8,9 +8,11 @@ from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filte
|
|||
from tinygrad.uop.ops import sint
|
||||
from tinygrad.device import Compiled, DMAFdRef, BufferSpec, CompilerSet
|
||||
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar
|
||||
from tinygrad.helpers import VIZ, AMD_CC, AMD_LLVM, AMD_HIPCC, ceildiv, unwrap
|
||||
from tinygrad.helpers import VIZ, AMD_CC, AMD_LLVM, AMD_HIPCC, AMD_ISEL, AMD_ASM, ceildiv, unwrap
|
||||
from tinygrad.renderer.cstyle import AMDHIPRenderer, AMDHIPCCRenderer
|
||||
from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
from extra.assembly.amd.isel import AMDISELRenderer
|
||||
from extra.assembly.amd.renderer import AMDAssemblyRenderer
|
||||
from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt, amdgpu_kd, amdgpu_drm
|
||||
from tinygrad.runtime.autogen.am import am
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
|
@ -970,7 +972,9 @@ class AMDDevice(HCQCompiled):
|
|||
|
||||
compilers = CompilerSet([(functools.partial(AMDHIPRenderer, self.arch), None),
|
||||
(functools.partial(AMDLLVMRenderer, self.arch), AMD_LLVM),
|
||||
(functools.partial(AMDHIPCCRenderer, self.arch), AMD_HIPCC)], ctrl_var=AMD_CC)
|
||||
(functools.partial(AMDHIPCCRenderer, self.arch), AMD_HIPCC),
|
||||
(functools.partial(AMDISELRenderer, self.arch), AMD_ISEL),
|
||||
(functools.partial(AMDAssemblyRenderer, self.arch), AMD_ASM)], ctrl_var=AMD_CC)
|
||||
|
||||
super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal,
|
||||
functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class Ops(FastEnum):
|
|||
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
|
||||
CUSTOM = auto(); CUSTOMI = auto()
|
||||
|
||||
# INS is a machine instruction
|
||||
# machine instruction: arg=Inst object, tag=register assignment
|
||||
INS = auto()
|
||||
|
||||
# ** 6 -- ops that don't exist in programs **
|
||||
|
|
|
|||
|
|
@ -289,7 +289,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.ASSIGN: return self.src[1]._shape
|
||||
|
||||
# elementwise ops keep the shape the same. all inputs with shape must match
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE, Ops.INS}):
|
||||
input_shapes = [x._shape for x in self.src if x._shape is not None]
|
||||
if len(input_shapes) == 0: return None
|
||||
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ shared_codegen_spec = PatternMatcher([
|
|||
# CUSTOM (inline and non inline)
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
|
||||
# assembly instruction
|
||||
# machine instruction (ISel output)
|
||||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# INDEX (2-arg and 3-arg with bool gate)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue