Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
1cd10a2e3c amd isel renderer 2026-02-13 17:50:39 +08:00
9 changed files with 1439 additions and 6 deletions

674
test/amd/test_isel.py Normal file
View file

@ -0,0 +1,674 @@
#!/usr/bin/env python3
"""Tests for the pcode-based instruction selector (isel.py)."""
import unittest
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from extra.assembly.amd.isel import (rdna3_isel, make_inst, normalize, _count_nodes, _is_direct_alu,
_parse_pcode_patterns, _pattern_key, _runtime_key, uop_to_upat,
_SENTINEL, _SENTINEL_SET, _DIRECT_TABLE, _STRUCTURAL_TABLE,
_ALU_ENUM_TYPES, build_isel_patterns)
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
from extra.assembly.amd.autogen.rdna3.enum import VOP2Op, VOP1Op, VOP3Op, SOP2Op, VOPCOp, VOP3SDOp
# helpers
def _var(name, dtype=dtypes.float): return UOp(Ops.DEFINE_VAR, dtype, arg=(name, 0, 100))
def _const(val, dtype=dtypes.float): return UOp(Ops.CONST, dtype, arg=val)
class TestMakeInst(unittest.TestCase):
def test_vop2(self):
inst = make_inst(VOP2Op.V_ADD_F32_E32)
assert inst.op == VOP2Op.V_ADD_F32_E32
def test_vop1(self):
inst = make_inst(VOP1Op.V_SQRT_F32_E32)
assert inst.op == VOP1Op.V_SQRT_F32_E32
def test_vop3(self):
inst = make_inst(VOP3Op.V_ADD_F64)
assert inst.op == VOP3Op.V_ADD_F64
def test_vop3_sdst(self):
# VOP3SDOp opcodes need the _SDST variant class
inst = make_inst(VOP3SDOp.V_ADD_CO_CI_U32)
assert inst.op == VOP3SDOp.V_ADD_CO_CI_U32
def test_vopc(self):
inst = make_inst(VOPCOp.V_CMP_LT_F32_E32)
assert inst.op == VOPCOp.V_CMP_LT_F32_E32
def test_sop2(self):
inst = make_inst(SOP2Op.S_ADD_I32)
assert inst.op == SOP2Op.S_ADD_I32
def test_invalid_raises(self):
with self.assertRaises(RuntimeError): make_inst("not_an_opcode")
class TestNormalize(unittest.TestCase):
def test_bitcast_sentinel(self):
s0 = _SENTINEL['S0']
bc = UOp(Ops.BITCAST, dtypes.float, (s0,))
norm = normalize(bc)
assert norm.op == Ops.DEFINE_VAR
assert norm.dtype == dtypes.float
def test_cast_sentinel(self):
s0 = _SENTINEL['S0']
cast = UOp(Ops.CAST, dtypes.int, (s0,))
norm = normalize(cast)
assert norm.op == Ops.DEFINE_VAR
assert norm.dtype == dtypes.int
def test_identity_bitcast(self):
x = UOp(Ops.CONST, dtypes.float, arg=1.0)
bc = UOp(Ops.BITCAST, dtypes.float, (x,))
norm = normalize(bc)
assert norm.op == Ops.CONST
assert norm.arg == 1.0
def test_shift_mask_31(self):
s0 = _SENTINEL['S0']
c31 = UOp(Ops.CONST, dtypes.uint, arg=31)
masked = UOp(Ops.AND, dtypes.uint, (s0, c31))
norm = normalize(masked)
assert norm.op == Ops.DEFINE_VAR
def test_shift_mask_63(self):
s0 = _SENTINEL['S0']
c63 = UOp(Ops.CONST, dtypes.uint, arg=63)
masked = UOp(Ops.AND, dtypes.uint, (s0, c63))
norm = normalize(masked)
assert norm.op == Ops.DEFINE_VAR
def test_non_sentinel_bitcast_preserved(self):
x = _var('x', dtypes.uint)
bc = UOp(Ops.BITCAST, dtypes.float, (x,))
norm = normalize(bc)
assert norm.op == Ops.BITCAST # not a sentinel, so preserved
def test_recursive(self):
s0 = _SENTINEL['S0']
s1 = _SENTINEL['S1']
bc0 = UOp(Ops.BITCAST, dtypes.float, (s0,))
bc1 = UOp(Ops.BITCAST, dtypes.float, (s1,))
add = UOp(Ops.ADD, dtypes.float, (bc0, bc1))
norm = normalize(add)
assert norm.op == Ops.ADD
assert all(s.dtype == dtypes.float for s in norm.src)
assert all(s.op == Ops.DEFINE_VAR for s in norm.src)
class TestCountNodes(unittest.TestCase):
def test_leaf(self):
assert _count_nodes(_var('x')) == 1
def test_binary(self):
x, y = _var('x'), _var('y')
add = UOp(Ops.ADD, dtypes.float, (x, y))
assert _count_nodes(add) == 3
def test_dag_sharing(self):
x = _var('x')
add = UOp(Ops.ADD, dtypes.float, (x, x))
assert _count_nodes(add) == 2 # x counted once
class TestIsDirectAlu(unittest.TestCase):
def test_add_sentinels(self):
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
s1 = _SENTINEL['S1'].replace(dtype=dtypes.float)
add = UOp(Ops.ADD, dtypes.float, (s0, s1))
assert _is_direct_alu(add)
def test_cast_sentinel(self):
s0 = _SENTINEL['S0'].replace(dtype=dtypes.int)
cast = UOp(Ops.CAST, dtypes.float, (s0,))
assert _is_direct_alu(cast)
def test_nested_not_direct(self):
s0 = _SENTINEL['S0'].replace(dtype=dtypes.uint)
c = _const(0xFFFFFFFF, dtypes.uint)
xor = UOp(Ops.XOR, dtypes.uint, (s0, c))
assert not _is_direct_alu(xor) # const child is not DEFINE_VAR
class TestPatternKey(unittest.TestCase):
def test_sentinel_var(self):
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
key = _pattern_key(s0)
assert key == 'var(S0,dtypes.float)'
def test_const(self):
c = _const(42, dtypes.uint)
key = _pattern_key(c)
assert key == 'const(42,dtypes.uint)'
def test_binary_op(self):
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
s1 = _SENTINEL['S1'].replace(dtype=dtypes.float)
add = UOp(Ops.ADD, dtypes.float, (s0, s1))
key = _pattern_key(add)
assert key == 'Ops.ADD(dtypes.float,var(S0,dtypes.float),var(S1,dtypes.float))'
class TestRuntimeKey(unittest.TestCase):
def test_matches_pattern_key(self):
# runtime key on a matched UOp should equal pattern key on the pcode template
x = _var('x', dtypes.uint)
c = _const(0xFFFFFFFF, dtypes.uint)
xor = UOp(Ops.XOR, dtypes.uint, (x, c))
rkey = _runtime_key(xor)
s0 = _SENTINEL['S0'].replace(dtype=dtypes.uint)
xor_template = UOp(Ops.XOR, dtypes.uint, (s0, c))
pkey = _pattern_key(xor_template)
assert rkey == pkey
class TestUopToUpat(unittest.TestCase):
def test_sentinel_becomes_var(self):
s0 = _SENTINEL['S0'].replace(dtype=dtypes.float)
pat = uop_to_upat(s0)
assert pat.name == 'S0'
assert pat.dtype == (dtypes.float,)
def test_const_preserved(self):
c = _const(42, dtypes.uint)
pat = uop_to_upat(c)
assert pat.op == (Ops.CONST,)
assert pat.arg == 42
class TestBuildPerformance(unittest.TestCase):
def test_builds_under_2_seconds(self):
import time
t0 = time.time()
build_isel_patterns(PCODE)
elapsed = time.time() - t0
assert elapsed < 2.0, f"build took {elapsed:.2f}s, expected <2s"
def test_alu_filter(self):
# verify only ALU enum types are parsed
for opcode in PCODE:
if type(opcode).__name__ not in _ALU_ENUM_TYPES: continue
# these should parse without hanging
class TestDirectPatterns(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.pm = rdna3_isel()
def _check(self, uop, expected_name_substr):
result = self.pm.rewrite(uop)
self.assertIsNotNone(result, f"no match for {uop.op} {uop.dtype}")
self.assertEqual(result.op, Ops.INS)
self.assertIn(expected_name_substr, result.arg.op.name, f"expected {expected_name_substr} in {result.arg.op.name}")
return result
# arithmetic
def test_add_f32(self): self._check(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))), 'V_ADD_F32')
def test_add_f64(self): self._check(UOp(Ops.ADD, dtypes.double, (_var('a', dtypes.double), _var('b', dtypes.double))), 'V_ADD_F64')
def test_add_i32(self): self._check(UOp(Ops.ADD, dtypes.int, (_var('a', dtypes.int), _var('b', dtypes.int))), 'ADD_NC_I32')
def test_add_u32(self): self._check(UOp(Ops.ADD, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'ADD_NC_U32')
def test_mul_f32(self): self._check(UOp(Ops.MUL, dtypes.float, (_var('a'), _var('b'))), 'V_MUL_F32')
def test_mul_f64(self): self._check(UOp(Ops.MUL, dtypes.double, (_var('a', dtypes.double), _var('b', dtypes.double))), 'V_MUL_F64')
def test_mul_i32(self): self._check(UOp(Ops.MUL, dtypes.int, (_var('a', dtypes.int), _var('b', dtypes.int))), 'MUL_I32')
def test_mul_u32(self): self._check(UOp(Ops.MUL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'MUL_U32')
# bitwise
def test_and_u32(self): self._check(UOp(Ops.AND, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'AND_B32')
def test_or_u32(self): self._check(UOp(Ops.OR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'OR_B32')
def test_xor_u32(self): self._check(UOp(Ops.XOR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'XOR_B32')
# u64 bitwise ops are SOP-only, skipped in vgpr_only mode
def test_and_u64_skipped(self):
result = self.pm.rewrite(UOp(Ops.AND, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
self.assertIsNone(result)
def test_or_u64_skipped(self):
result = self.pm.rewrite(UOp(Ops.OR, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
self.assertIsNone(result)
def test_xor_u64_skipped(self):
result = self.pm.rewrite(UOp(Ops.XOR, dtypes.ulong, (_var('a', dtypes.ulong), _var('b', dtypes.ulong))))
self.assertIsNone(result)
# shifts
def test_shl_u32(self): self._check(UOp(Ops.SHL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'LSH')
def test_shr_u32(self): self._check(UOp(Ops.SHR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))), 'LSH')
# unary float
def test_sqrt_f32(self): self._check(UOp(Ops.SQRT, dtypes.float, (_var('a'),)), 'SQRT_F32')
def test_sqrt_f64(self): self._check(UOp(Ops.SQRT, dtypes.double, (_var('a', dtypes.double),)), 'SQRT_F64')
def test_trunc_f32(self): self._check(UOp(Ops.TRUNC, dtypes.float, (_var('a'),)), 'TRUNC_F32')
def test_trunc_f64(self): self._check(UOp(Ops.TRUNC, dtypes.double, (_var('a', dtypes.double),)), 'TRUNC_F64')
def test_log2_f32(self): self._check(UOp(Ops.LOG2, dtypes.float, (_var('a'),)), 'LOG')
def test_exp2_f32(self): self._check(UOp(Ops.EXP2, dtypes.float, (_var('a'),)), 'EXP')
# conversions
def test_cast_i32_to_f32(self): self._check(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.int),)), 'CVT_F32_I32')
def test_cast_f32_to_f64(self): self._check(UOp(Ops.CAST, dtypes.double, (_var('a'),)), 'CVT_F64_F32')
def test_cast_f64_to_f32(self): self._check(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.double),)), 'CVT_F32_F64')
def test_cast_i32_to_f64(self): self._check(UOp(Ops.CAST, dtypes.double, (_var('a', dtypes.int),)), 'CVT_F64_I32')
def test_cast_f32_to_f16(self): self._check(UOp(Ops.CAST, dtypes.half, (_var('a'),)), 'CVT_F16_F32')
# compares are skipped by ISel (VOPC writes VCC, not VGPRs; LLVM handles natively)
def test_cmplt_skipped(self):
result = self.pm.rewrite(UOp(Ops.CMPLT, dtypes.bool, (_var('a', dtypes.int), _var('b', dtypes.int))))
self.assertIsNone(result)
def test_cmpne_skipped(self):
result = self.pm.rewrite(UOp(Ops.CMPNE, dtypes.bool, (_var('a'), _var('b'))))
self.assertIsNone(result)
# check that unmatched types return None
def test_no_match(self):
# there's no direct ADD for bools
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.bool, (_var('a', dtypes.bool), _var('b', dtypes.bool))))
self.assertIsNone(result)
class TestStructuralPatterns(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.pm = rdna3_isel()
def _check(self, uop, expected_name_substr, expected_src_count=None):
result = self.pm.rewrite(uop)
self.assertIsNotNone(result, f"no match for structural pattern")
self.assertEqual(result.op, Ops.INS)
self.assertIn(expected_name_substr, result.arg.op.name, f"expected {expected_name_substr} in {result.arg.op.name}")
if expected_src_count is not None:
self.assertEqual(len(result.src), expected_src_count, f"expected {expected_src_count} srcs, got {len(result.src)}")
return result
def test_not_u32(self):
x = _var('x', dtypes.uint)
xor = UOp(Ops.XOR, dtypes.uint, (x, _const(0xFFFFFFFF, dtypes.uint)))
self._check(xor, 'NOT_B32', 1)
# u64 NOT is SOP-only (S_NOT_B64), skipped in vgpr_only mode
def test_not_u64_skipped(self):
x = _var('x', dtypes.ulong)
xor = UOp(Ops.XOR, dtypes.ulong, (x, _const(0xFFFFFFFFFFFFFFFF, dtypes.ulong)))
result = self.pm.rewrite(xor)
self.assertIsNone(result)
def test_sub_u32(self):
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
neg = UOp(Ops.MUL, dtypes.uint, (y, _const(-1, dtypes.uint)))
sub = UOp(Ops.ADD, dtypes.uint, (x, neg))
result = self._check(sub, 'SUB_NC_U32', 2)
# verify source order: x is first, y is second
self.assertEqual(result.src[0].arg, ('x', 0, 100))
self.assertEqual(result.src[1].arg, ('y', 0, 100))
def test_rcp_f32(self):
a = _var('a')
rcp = UOp(Ops.RECIPROCAL, dtypes.float, (a,))
mul_rcp = UOp(Ops.MUL, dtypes.float, (_const(1.0), rcp))
result = self._check(mul_rcp, 'RCP_F32', 1)
self.assertEqual(result.src[0].arg, ('a', 0, 100))
def test_rcp_f64(self):
a = _var('a', dtypes.double)
rcp = UOp(Ops.RECIPROCAL, dtypes.double, (a,))
mul_rcp = UOp(Ops.MUL, dtypes.double, (_const(1.0, dtypes.double), rcp))
self._check(mul_rcp, 'RCP_F64', 1)
def test_cvt_i32_f32(self):
# CAST(i32, TRUNC(f32, x)) -> V_CVT_I32_F32
a = _var('a')
trunc = UOp(Ops.TRUNC, dtypes.float, (a,))
cast = UOp(Ops.CAST, dtypes.int, (trunc,))
self._check(cast, 'CVT_I32_F32', 1)
def test_mad_u32(self):
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
mul = UOp(Ops.MUL, dtypes.uint, (x, y))
mad = UOp(Ops.ADD, dtypes.uint, (mul, z))
result = self._check(mad, 'MAD_U32_U24', 3)
self.assertEqual(result.src[0].arg, ('x', 0, 100))
self.assertEqual(result.src[1].arg, ('y', 0, 100))
self.assertEqual(result.src[2].arg, ('z', 0, 100))
def test_add3_u32(self):
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
add1 = UOp(Ops.ADD, dtypes.uint, (x, y))
add3 = UOp(Ops.ADD, dtypes.uint, (add1, z))
result = self._check(add3, 'ADD3_U32', 3)
def test_xor3_b32(self):
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
xor1 = UOp(Ops.XOR, dtypes.uint, (x, y))
xor3 = UOp(Ops.XOR, dtypes.uint, (xor1, z))
self._check(xor3, 'XOR3_B32', 3)
def test_and_or_b32(self):
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
and_op = UOp(Ops.AND, dtypes.uint, (x, y))
or_op = UOp(Ops.OR, dtypes.uint, (and_op, z))
self._check(or_op, 'AND_OR_B32', 3)
def test_or3_b32(self):
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
or1 = UOp(Ops.OR, dtypes.uint, (x, y))
or3 = UOp(Ops.OR, dtypes.uint, (or1, z))
self._check(or3, 'OR3_B32', 3)
# NAND/NOR were SOP-only, in vgpr_only mode they decompose to V_XOR_B32(AND/OR, mask)
def test_nand_decomposes(self):
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
and_op = UOp(Ops.AND, dtypes.uint, (x, y))
nand = UOp(Ops.XOR, dtypes.uint, (and_op, _const(0xFFFFFFFF, dtypes.uint)))
self._check(nand, 'XOR_B32')
def test_nor_decomposes(self):
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
or_op = UOp(Ops.OR, dtypes.uint, (x, y))
nor = UOp(Ops.XOR, dtypes.uint, (or_op, _const(0xFFFFFFFF, dtypes.uint)))
self._check(nor, 'XOR_B32')
def test_xnor_b32(self):
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
xor_op = UOp(Ops.XOR, dtypes.uint, (x, y))
xnor = UOp(Ops.XOR, dtypes.uint, (xor_op, _const(0xFFFFFFFF, dtypes.uint)))
self._check(xnor, 'XNOR_B32', 2)
def test_min_u32(self):
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
cmp = UOp(Ops.CMPLT, dtypes.bool, (x, y))
where = UOp(Ops.WHERE, dtypes.uint, (cmp, x, y))
self._check(where, 'MIN_U32', 2)
class TestInstProperties(unittest.TestCase):
"""Verify that Inst objects produced by isel have correct properties."""
@classmethod
def setUpClass(cls):
cls.pm = rdna3_isel()
def test_ins_has_dtype(self):
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
self.assertEqual(result.dtype, dtypes.float)
def test_ins_preserves_sources(self):
a, b = _var('a'), _var('b')
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, (a, b)))
self.assertEqual(result.src, (a, b))
def test_ins_tag_default_none(self):
result = self.pm.rewrite(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
# tag should not be set (defaults to None or empty)
self.assertIsNone(result.tag)
class TestTableCoverage(unittest.TestCase):
"""Verify that the tables have expected coverage."""
@classmethod
def setUpClass(cls):
rdna3_isel() # populate tables
def test_direct_table_has_add(self):
found = any(op == Ops.ADD for (op, _, _) in _DIRECT_TABLE)
self.assertTrue(found)
def test_direct_table_has_cast(self):
found = any(op == Ops.CAST for (op, _, _) in _DIRECT_TABLE)
self.assertTrue(found)
def test_direct_table_skips_cmplt(self):
# compares are in the table but skipped at runtime (bool output)
found = any(op == Ops.CMPLT for (op, _, _) in _DIRECT_TABLE)
self.assertTrue(found) # entries exist but callbacks skip them
def test_structural_table_has_not(self):
found = any('NOT' in inst.op.name for inst in _STRUCTURAL_TABLE.values())
self.assertTrue(found)
def test_structural_table_has_rcp(self):
found = any('RCP' in inst.op.name for inst in _STRUCTURAL_TABLE.values())
self.assertTrue(found)
def test_structural_table_has_sub(self):
found = any('SUB' in inst.op.name for inst in _STRUCTURAL_TABLE.values())
self.assertTrue(found)
def test_direct_count(self):
self.assertGreaterEqual(len(_DIRECT_TABLE), 25, "expected at least 25 direct patterns")
def test_structural_count(self):
self.assertGreaterEqual(len(_STRUCTURAL_TABLE), 15, "expected at least 15 structural patterns")
class TestEmulatorValidation(unittest.TestCase):
"""Validate isel-produced Inst objects execute correctly in the emulator.
For each pattern, we:
1. Run the UOp through isel to get Ops.INS with arg=Inst
2. Copy the Inst and assign concrete registers
3. Set up operand values via MOV instructions
4. Execute through the emulator
5. Verify the output matches expected computation
"""
@classmethod
def setUpClass(cls):
cls.pm = rdna3_isel()
def _run(self, instructions, n_lanes=1):
from extra.assembly.amd.test.hw.helpers import run_program_emu
return run_program_emu(instructions, n_lanes)
def _get_inst(self, uop):
"""Get isel result, return (Inst, src_count)."""
result = self.pm.rewrite(uop)
assert result is not None and result.op == Ops.INS, f"isel failed for {uop.op} {uop.dtype}"
return result.arg, len(result.src)
def _copy_inst(self, inst):
import copy
return copy.copy(inst)
# ── direct ALU: float arithmetic ──
def test_emu_add_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f, f2i
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.ADD, dtypes.float, (_var('a'), _var('b'))))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 1.5), v_mov_b32_e32(v[1], 2.25), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.75, places=5)
def test_emu_mul_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.MUL, dtypes.float, (_var('a'), _var('b'))))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 3.0), v_mov_b32_e32(v[1], 4.0), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 12.0, places=5)
# ── direct ALU: integer arithmetic ──
def test_emu_add_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.ADD, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 10), v_mov_b32_e32(v[1], 20), ci])
self.assertEqual(st.vgpr[0][2], 30)
# ── direct ALU: bitwise ──
def test_emu_and_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.AND, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
self.assertEqual(st.vgpr[0][2], 0x0F00)
def test_emu_or_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.OR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
self.assertEqual(st.vgpr[0][2], 0xFFF0)
def test_emu_xor_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.XOR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 0xFF00), v_mov_b32_e32(v[1], 0x0FF0), ci])
self.assertEqual(st.vgpr[0][2], 0xF0F0)
# ── direct ALU: shifts ──
def test_emu_shl_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.SHL, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
ci = self._copy_inst(inst)
# LSHLREV: vdst = vsrc1 << src0 (reversed operands!)
ci.src0 = v[1]; ci.vsrc1 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 1), v_mov_b32_e32(v[1], 4), ci])
self.assertEqual(st.vgpr[0][2], 16) # 1 << 4 = 16
def test_emu_shr_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.SHR, dtypes.uint, (_var('a', dtypes.uint), _var('b', dtypes.uint))))
ci = self._copy_inst(inst)
# LSHRREV: vdst = vsrc1 >> src0 (reversed operands!)
ci.src0 = v[1]; ci.vsrc1 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 16), v_mov_b32_e32(v[1], 4), ci])
self.assertEqual(st.vgpr[0][2], 1) # 16 >> 4 = 1
# ── direct ALU: unary float ──
def test_emu_sqrt_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.SQRT, dtypes.float, (_var('a'),)))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 4.0), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 2.0, places=4)
def test_emu_trunc_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.TRUNC, dtypes.float, (_var('a'),)))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 3.7), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, places=5)
def test_emu_exp2_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.EXP2, dtypes.float, (_var('a'),)))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 3.0), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 8.0, delta=0.01)
def test_emu_log2_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.LOG2, dtypes.float, (_var('a'),)))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 8.0), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, delta=0.01)
# ── direct ALU: conversions ──
def test_emu_cast_i32_to_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.CAST, dtypes.float, (_var('a', dtypes.int),)))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 42), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 42.0, places=5)
def test_emu_cast_f32_to_f16(self):
from extra.assembly.amd.test.hw.helpers import i2f, f16
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
inst, _ = self._get_inst(UOp(Ops.CAST, dtypes.half, (_var('a'),)))
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 1.5), ci])
# f16 result is in lower 16 bits of v[2]
self.assertAlmostEqual(f16(st.vgpr[0][2]), 1.5, places=2)
# compares are skipped by ISel (VOPC writes VCC; LLVM handles natively)
# ── structural: NOT ──
def test_emu_not_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
x = _var('x', dtypes.uint)
xor = UOp(Ops.XOR, dtypes.uint, (x, _const(0xFFFFFFFF, dtypes.uint)))
inst, _ = self._get_inst(xor)
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 0x0000FF00), ci])
self.assertEqual(st.vgpr[0][2], 0xFFFF00FF)
# ── structural: SUB ──
def test_emu_sub_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
x, y = _var('x', dtypes.uint), _var('y', dtypes.uint)
neg = UOp(Ops.MUL, dtypes.uint, (y, _const(-1, dtypes.uint)))
sub = UOp(Ops.ADD, dtypes.uint, (x, neg))
inst, nsrc = self._get_inst(sub)
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vsrc1 = v[1]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 30), v_mov_b32_e32(v[1], 12), ci])
self.assertEqual(st.vgpr[0][2], 18)
# ── structural: RCP ──
def test_emu_rcp_f32(self):
from extra.assembly.amd.test.hw.helpers import i2f
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
a = _var('a')
rcp = UOp(Ops.RECIPROCAL, dtypes.float, (a,))
mul_rcp = UOp(Ops.MUL, dtypes.float, (_const(1.0), rcp))
inst, _ = self._get_inst(mul_rcp)
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.vdst = v[2]
st = self._run([v_mov_b32_e32(v[0], 4.0), ci])
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 0.25, places=4)
# ── structural: MAD ──
def test_emu_mad_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
mul = UOp(Ops.MUL, dtypes.uint, (x, y))
mad = UOp(Ops.ADD, dtypes.uint, (mul, z))
inst, _ = self._get_inst(mad)
ci = self._copy_inst(inst)
# VOP3 format: src0, src1, src2, vdst
ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
st = self._run([v_mov_b32_e32(v[0], 3), v_mov_b32_e32(v[1], 4), v_mov_b32_e32(v[2], 5), ci])
self.assertEqual(st.vgpr[0][3], 17) # 3*4 + 5 = 17
# ── structural: ADD3 ──
def test_emu_add3_u32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
add1 = UOp(Ops.ADD, dtypes.uint, (x, y))
add3 = UOp(Ops.ADD, dtypes.uint, (add1, z))
inst, _ = self._get_inst(add3)
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
st = self._run([v_mov_b32_e32(v[0], 10), v_mov_b32_e32(v[1], 20), v_mov_b32_e32(v[2], 30), ci])
self.assertEqual(st.vgpr[0][3], 60) # 10+20+30
# ── structural: XOR3 ──
def test_emu_xor3_b32(self):
from extra.assembly.amd.autogen.rdna3.ins import v, v_mov_b32_e32
x, y, z = _var('x', dtypes.uint), _var('y', dtypes.uint), _var('z', dtypes.uint)
xor1 = UOp(Ops.XOR, dtypes.uint, (x, y))
xor3 = UOp(Ops.XOR, dtypes.uint, (xor1, z))
inst, _ = self._get_inst(xor3)
ci = self._copy_inst(inst)
ci.src0 = v[0]; ci.src1 = v[1]; ci.src2 = v[2]; ci.vdst = v[3]
st = self._run([v_mov_b32_e32(v[0], 0xFF), v_mov_b32_e32(v[1], 0x0F), v_mov_b32_e32(v[2], 0x33), ci])
self.assertEqual(st.vgpr[0][3], 0xFF ^ 0x0F ^ 0x33)
if __name__ == '__main__':
unittest.main()

View file

@ -189,7 +189,7 @@ CPU_CC, CPU_LLVM, CPU_LVP = ContextVar("CPU_CC", ""), ContextVar("CPU_LLVM", 0),
NV_CC, NV_PTX, NV_NAK = ContextVar("NV_CC", ""), ContextVar("NV_PTX", 0), ContextVar("NV_NAK", 0)
CUDA_CC, CUDA_PTX, CUDA_NVCC = ContextVar("CUDA_CC", ""), ContextVar("CUDA_PTX", 0), ContextVar("CUDA_NVCC", 0)
NULL_IR3, NULL_NAK, NULL_ALLOW_COPYOUT = ContextVar("NULL_IR3", 0), ContextVar("NULL_NAK", 0), ContextVar("NULL_ALLOW_COPYOUT", 0)
AMD_CC, AMD_LLVM, AMD_HIPCC = ContextVar("AMD_CC", ""), ContextVar("AMD_LLVM", 0), ContextVar("AMD_HIPCC", 0)
AMD_CC, AMD_LLVM, AMD_HIPCC, AMD_ISEL, AMD_ASM = ContextVar("AMD_CC", ""), ContextVar("AMD_LLVM", 0), ContextVar("AMD_HIPCC", 0), ContextVar("AMD_ISEL", 0), ContextVar("AMD_ASM", 0)
QCOM_CC, QCOM_IR3 = ContextVar("QCOM_CC", ""), ContextVar("QCOM_IR3", 0)
# VIZ implies PROFILE, but you can run PROFILE without VIZ
VIZ = ContextVar("VIZ", 0)

View file

@ -444,6 +444,9 @@ class Inst:
def __eq__(self, other): return type(self) is type(other) and self._raw == other._raw
def __hash__(self): return hash((type(self), self._raw))
def __lt__(self, other):
if not isinstance(other, Inst): return NotImplemented
return (type(self).__name__, self._raw) < (type(other).__name__, other._raw)
def __repr__(self):
# collect (repr, is_default) pairs, strip trailing defaults so repr roundtrips with eval

View file

@ -0,0 +1,325 @@
# Instruction selection for AMD GPUs via pcode-derived PatternMatcher
# Parses AMD pcode specs into UOp templates, normalizes them, and converts to UPat patterns
# that rewrite renderer-level UOps into Ops.INS with arg=Inst objects
import functools
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, GroupOp
from tinygrad.dtype import dtypes, DType
from extra.assembly.amd.emu import parse_pcode
from extra.assembly.amd.autogen.rdna3.str_pcode import PCODE
from extra.assembly.amd.autogen.rdna3 import ins as rdna3_ins
# sentinel UOps representing source operands (typed as u32, like real registers)
_SENTINEL = {f'S{i}': UOp(Ops.DEFINE_VAR, dtypes.uint32, arg=(f'S{i}', 0, 0xFFFFFFFF)) for i in range(4)}
_SENTINEL_SET = set(_SENTINEL.values())
# only parse ALU-relevant opcode types (memory ops need different sentinels)
_ALU_ENUM_TYPES = frozenset({'SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'VOP1Op', 'VOP2Op',
'VOP3Op', 'VOP3POp', 'VOP3SDOp', 'VOPCOp', 'VINTERPOp'})
# SOP types use SGPRs — skip for LLVM inline asm renderer (VGPR-only)
_SOP_ENUM_TYPES = frozenset({'SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp'})
# opcode enum class name -> list of Inst class suffixes to try
_VARIANT_SUFFIXES = ['', '_SDST']
def make_inst(opcode):
"""Create an Inst object with just the opcode set (registers defaulted)."""
base_name = type(opcode).__name__[:-2]
for suffix in _VARIANT_SUFFIXES:
cls = getattr(rdna3_ins, base_name + suffix, None)
if cls is None: continue
try: return cls(op=opcode)
except (RuntimeError, TypeError): continue
raise RuntimeError(f"no Inst class found for {opcode}")
# ═══════════════════════════════════════════════════════════════
# Normalization: strip register-model artifacts from pcode UOps
# ═══════════════════════════════════════════════════════════════
def normalize(uop, _cache=None):
"""Strip register-model artifacts from pcode UOps to match renderer-level UOps.
Pcode models registers as typeless u32 words and uses BITCAST/CAST to reinterpret.
The renderer's UOps are natively typed — we strip these artifacts:
- BITCAST(f32, sentinel_u32) -> sentinel typed as f32
- CAST(i32, sentinel_u32) -> sentinel typed as i32 (same-size reinterpret)
- CAST(u64, sentinel_u32) -> sentinel typed as u64 (widening for 64-bit ops)
- BITCAST(T, x) where x.dtype == T -> x (identity bitcast)
- AND(x, mask) where mask is shift masking -> x (hardware does this implicitly)
"""
if _cache is None: _cache = {}
if id(uop) in _cache: return _cache[id(uop)]
# first recurse so children are normalized before we check patterns
new_src = tuple(normalize(s, _cache) for s in uop.src)
uop = uop if new_src == uop.src else uop.replace(src=new_src)
# BITCAST or CAST on a sentinel -> sentinel with target dtype
if uop.op in (Ops.BITCAST, Ops.CAST) and len(uop.src) == 1 and uop.src[0] in _SENTINEL_SET:
result = uop.src[0].replace(dtype=uop.dtype)
_cache[id(uop)] = result
return result
# identity BITCAST: BITCAST(T, x) where x already has dtype T
if uop.op == Ops.BITCAST and len(uop.src) == 1 and uop.src[0].dtype == uop.dtype:
_cache[id(uop)] = uop.src[0]
return uop.src[0]
# shift masking: AND(sentinel, 31) or AND(sentinel, 63) -> sentinel (hardware masks shift amounts)
if uop.op == Ops.AND and len(uop.src) == 2:
if uop.src[1].op == Ops.CONST and uop.src[1].arg in (31, 63) and uop.src[0].op == Ops.DEFINE_VAR:
_cache[id(uop)] = uop.src[0]
return uop.src[0]
_cache[id(uop)] = uop
return uop
def _count_nodes(uop, _seen=None):
if _seen is None: _seen = set()
if id(uop) in _seen: return 0
_seen.add(id(uop))
return 1 + sum(_count_nodes(s, _seen) for s in uop.src)
# ═══════════════════════════════════════════════════════════════
# UOp template -> UPat conversion
# ═══════════════════════════════════════════════════════════════
def uop_to_upat(uop, _seen=None):
"""Convert a normalized UOp template into a matchable UPat pattern."""
if _seen is None: _seen = {}
if id(uop) in _seen: return _seen[id(uop)]
if uop.op == Ops.DEFINE_VAR and isinstance(uop.arg, tuple) and uop.arg[0] in _SENTINEL:
result = UPat.var(uop.arg[0], dtype=uop.dtype)
_seen[id(uop)] = result
return result
if uop.op in (Ops.CONST, Ops.VCONST):
result = UPat(uop.op, uop.dtype, arg=uop.arg)
_seen[id(uop)] = result
return result
src = tuple(uop_to_upat(s, _seen) for s in uop.src) if uop.src else None
result = UPat(uop.op, uop.dtype, src=src)
_seen[id(uop)] = result
return result
# ═══════════════════════════════════════════════════════════════
# Pattern classification and selection
# ═══════════════════════════════════════════════════════════════
def _select_best_opcode(opcodes):
"""Prefer shorter encodings: VOP1/VOP2 > SOP > VOP3/VOPC."""
_PREF = {'VOP1Op': 0, 'VOP2Op': 0, 'SOP1Op': 1, 'SOP2Op': 1, 'SOPCOp': 1, 'VOPCOp': 2, 'VOP3Op': 3, 'VOP3SDOp': 3, 'VOP3POp': 4}
return min(opcodes, key=lambda oc: (_PREF.get(type(oc).__name__, 9), oc.value))
def _is_direct_alu(norm_uop):
"""Check if normalized UOp is a direct ALU: op(sentinels...) with no intermediate ops."""
if norm_uop.op not in GroupOp.ALU and norm_uop.op not in {Ops.CAST, Ops.BITCAST}: return False
return all(s.op == Ops.DEFINE_VAR for s in norm_uop.src)
def _pattern_key(uop, _seen=None):
"""Structural fingerprint for a normalized UOp template (sentinels become var placeholders)."""
if _seen is None: _seen = {}
if id(uop) in _seen: return _seen[id(uop)]
if uop.op == Ops.DEFINE_VAR and isinstance(uop.arg, tuple) and uop.arg[0] in _SENTINEL:
result = f'var({uop.arg[0]},{uop.dtype})'
elif uop.op in (Ops.CONST, Ops.VCONST):
result = f'const({uop.arg},{uop.dtype})'
else:
children = ','.join(_pattern_key(s, _seen) for s in uop.src)
result = f'{uop.op}({uop.dtype},{children})'
_seen[id(uop)] = result
return result
def _runtime_key(uop, _var_counter=None, _seen=None):
"""Compute a structural key from a matched UOp at runtime (real data, not sentinels).
Leaf UOps (non-ALU with no recognized children) are treated as variables."""
if _seen is None: _seen = {}
if _var_counter is None: _var_counter = [0]
uid = id(uop)
if uid in _seen: return _seen[uid]
_ALU_OPS = GroupOp.ALU | {Ops.CAST, Ops.BITCAST, Ops.WHERE}
if uop.op in (Ops.CONST, Ops.VCONST):
result = f'const({uop.arg},{uop.dtype})'
elif uop.op not in _ALU_OPS:
result = f'var(S{_var_counter[0]},{uop.dtype})'
_var_counter[0] += 1
else:
children = ','.join(_runtime_key(s, _var_counter, _seen) for s in uop.src)
result = f'{uop.op}({uop.dtype},{children})'
_seen[uid] = result
return result
# ═══════════════════════════════════════════════════════════════
# Build tables from pcode
# ═══════════════════════════════════════════════════════════════
def _parse_pcode_patterns(pcode_dict, vgpr_only=False):
"""Parse ALU pcode entries, normalize, and categorize into direct vs structural."""
# key: (op, src_dtypes_tuple, dst_dtype) -> [opcodes]
direct: dict[tuple, list] = {}
structural: list[tuple] = [] # [(opcode, norm_uop)]
allowed_types = _ALU_ENUM_TYPES - _SOP_ENUM_TYPES if vgpr_only else _ALU_ENUM_TYPES
for opcode, pcode_str in pcode_dict.items():
if type(opcode).__name__ not in allowed_types: continue
try: env, assigns = parse_pcode(pcode_str, dict(_SENTINEL))
except Exception: continue
d0 = next(((n, u) for n, u in assigns if n.startswith('D0')), None)
if d0 is None: continue
_, uop = d0
if _count_nodes(uop) > 5: continue
norm = normalize(uop)
if norm.op == Ops.DEFINE_VAR or norm.op in (Ops.CONST, Ops.VCONST): continue
if _is_direct_alu(norm):
src_dtypes = tuple(s.dtype for s in norm.src)
key = (norm.op, src_dtypes, norm.dtype)
direct.setdefault(key, []).append(opcode)
else:
structural.append((opcode, norm))
# pick best opcode for each direct pattern
direct_best = {k: _select_best_opcode(v) for k, v in direct.items()}
# deduplicate structural patterns by shape, pick best
seen: dict[str, list] = {}
for opcode, norm in structural:
key = _pattern_key(norm)
seen.setdefault(key, []).append((opcode, norm))
structural_best = []
for key, group in seen.items():
best = _select_best_opcode([oc for oc, _ in group])
best_norm = next(n for oc, n in group if oc == best)
structural_best.append((best, best_norm, key))
return direct_best, structural_best
# ═══════════════════════════════════════════════════════════════
# Global tables (populated by build_isel_patterns, used by callbacks)
# ═══════════════════════════════════════════════════════════════
# direct: (op, src_dtypes, dst_dtype) -> Inst
_DIRECT_TABLE: dict[tuple, object] = {}
# structural: pattern_key_string -> Inst
_STRUCTURAL_TABLE: dict[str, object] = {}
# ops that LLVM handles natively (compares write VCC, not VGPRs)
_SKIP_OPS = frozenset({Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ})
def _isel_direct(m):
"""Callback for direct ALU: look up Inst by (op, src_dtypes, dtype)."""
if m.op in _SKIP_OPS or m.dtype == dtypes.bool: return None
src_dtypes = tuple(s.dtype for s in m.src)
inst = _DIRECT_TABLE.get((m.op, src_dtypes, m.dtype))
if inst is None: return None
return UOp(Ops.INS, m.dtype, m.src, arg=inst)
def _isel_structural(m, **kwargs):
"""Callback for structural patterns: compute runtime key, look up Inst."""
if m.op in _SKIP_OPS or m.dtype == dtypes.bool: return None
key = _runtime_key(m)
inst = _STRUCTURAL_TABLE.get(key)
if inst is None: return None
# collect source vars in order (leaves of the matched tree)
srcs = _collect_leaves(m)
return UOp(Ops.INS, m.dtype, tuple(srcs), arg=inst)
def _collect_leaves(uop, _seen=None):
"""Collect leaf UOps (non-ALU) from a matched tree in left-to-right order."""
if _seen is None: _seen = set()
_ALU_OPS = GroupOp.ALU | {Ops.CAST, Ops.BITCAST, Ops.WHERE}
uid = id(uop)
if uid in _seen: return []
_seen.add(uid)
if uop.op in (Ops.CONST, Ops.VCONST): return []
if uop.op not in _ALU_OPS: return [uop]
result = []
for s in uop.src:
result.extend(_collect_leaves(s, _seen))
return result
# ═══════════════════════════════════════════════════════════════
# Build the PatternMatcher
# ═══════════════════════════════════════════════════════════════
def build_isel_patterns(pcode_dict=PCODE, vgpr_only=False) -> PatternMatcher:
"""Parse pcode and build a PatternMatcher for instruction selection."""
direct_best, structural_best = _parse_pcode_patterns(pcode_dict, vgpr_only=vgpr_only)
# populate direct table
_DIRECT_TABLE.clear()
for (op, src_dtypes, dtype), opcode in direct_best.items():
_DIRECT_TABLE[(op, src_dtypes, dtype)] = make_inst(opcode)
# populate structural table
_STRUCTURAL_TABLE.clear()
for opcode, norm, pkey in structural_best:
_STRUCTURAL_TABLE[pkey] = make_inst(opcode)
patterns: list[tuple] = []
# structural patterns first (more specific, should match before catch-all direct)
for opcode, norm, pkey in structural_best:
pat = uop_to_upat(norm).named('m')
patterns.append((pat, _isel_structural))
# direct ALU: catch-all patterns that look up by (op, src_dtypes, dtype)
patterns.append((UPat(GroupOp.ALU, name='m'), _isel_direct))
patterns.append((UPat(Ops.CAST, name='m'), _isel_direct))
patterns.append((UPat(Ops.BITCAST, name='m'), _isel_direct))
return PatternMatcher(patterns)
@functools.cache
def rdna3_isel() -> PatternMatcher:
"""Build the default RDNA3 instruction selector (VOP-only for LLVM inline asm)."""
return build_isel_patterns(PCODE, vgpr_only=True)
# ═══════════════════════════════════════════════════════════════
# LLVM inline asm rendering for Ops.INS
# ═══════════════════════════════════════════════════════════════
import re
from tinygrad.dtype import PtrDType
def _ins_mnemonic(inst) -> str:
"""Get the assembly mnemonic from an Inst opcode (strip _E32/_E64 suffix)."""
return re.sub(r'_e(32|64)$', '', inst.op.name.lower())
def _ldt(dt):
"""LLVM type string for a DType."""
if dt.vcount > 1: return f"<{dt.vcount} x {_ldt(dt.scalar())}>"
if isinstance(dt, PtrDType): return _ldt(dt.base) + "*"
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
def render_ins_llvm(ctx, x):
"""Render Ops.INS as LLVM inline assembly call."""
inst = x.arg
mnem = _ins_mnemonic(inst)
n_srcs = len(x.src)
# build operand string: $0 = dest, $1..$N = sources
ops = ", ".join(f"${i}" for i in range(n_srcs + 1))
asm_str = f"{mnem} {ops}"
# constraints: =v for output, v for each input (VGPR)
constraints = "=v," + ",".join("v" for _ in range(n_srcs))
# LLVM types and values
ret_type = _ldt(x.dtype)
args = ", ".join(f"{_ldt(s.dtype)} {ctx[s]}" for s in x.src)
return f" {ctx[x]} = call {ret_type} asm \"{asm_str}\", \"{constraints}\"({args})"
# ═══════════════════════════════════════════════════════════════
# AMDISELRenderer: LLVM renderer with pcode-based instruction selection
# ═══════════════════════════════════════════════════════════════
from tinygrad.renderer.llvmir import AMDLLVMRenderer
class AMDISELRenderer(AMDLLVMRenderer):
"""AMD renderer that uses pcode-derived instruction selection for ALU ops."""
def __init__(self, arch: str):
super().__init__(arch)
# add ISel as extra_matcher: rewrites ALU UOps → Ops.INS
self.extra_matcher = self.extra_matcher + rdna3_isel()
# add Ops.INS rendering to string_rewrite
self.string_rewrite = PatternMatcher([(UPat(Ops.INS, name='x'), render_ins_llvm)]) + self.string_rewrite
def __reduce__(self): return self.__class__, (self.arch,)

View file

@ -0,0 +1,427 @@
# Direct AMD GPU assembly renderer — emits Inst objects, produces GAS text via disasm()
# No LLVM. Uses HIPCompiler (COMGR) to assemble text into ELF.
import functools, math
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import AMDHIPRenderer
from extra.assembly.amd.dsl import Inst, Reg, s, v, NULL, VCC_LO, EXEC_LO, M0
from extra.assembly.amd.autogen.rdna3.ins import (s_load_b64, s_load_b128, s_mov_b32, s_waitcnt, s_endpgm, s_barrier,
s_branch, s_cbranch_scc0, s_cbranch_scc1, s_cmp_ge_i32, s_add_i32, s_and_b32, s_lshl_b32,
v_mov_b32_e32, v_add_f32_e32, v_add_nc_u32_e32, v_lshlrev_b32_e32, v_lshrrev_b32_e32,
v_and_b32_e32, v_mul_lo_u32, v_cmp_lt_i32_e32,
global_load_b32, global_load_b64, global_load_b128, global_store_b32, global_store_b64, global_store_b128,
ds_load_b32, ds_store_b32)
from extra.assembly.amd.test.disasm import disasm
from extra.assembly.amd.isel import rdna3_isel, make_inst
# ═══════════════════════════════════════════════════════════════
# Register allocator — simple bump allocator
# ═══════════════════════════════════════════════════════════════
class RegFile:
"""Simple register allocator: bump-allocates VGPRs and SGPRs."""
def __init__(self):
self.next_vgpr = 1 # v0 = workitem_id_x (reserved by hardware)
self.next_sgpr = 0 # s[0:1] = kernarg_ptr (reserved by ABI)
self.max_vgpr = 1
self.max_sgpr = 0
def alloc_vgpr(self, count=1) -> Reg:
r = v[self.next_vgpr] if count == 1 else v[self.next_vgpr:self.next_vgpr + count - 1]
self.next_vgpr += count
self.max_vgpr = max(self.max_vgpr, self.next_vgpr)
return r
def alloc_sgpr(self, count=1) -> Reg:
# align to 2 for 64-bit, 4 for 128-bit
if count >= 4: self.next_sgpr = (self.next_sgpr + 3) & ~3
elif count >= 2: self.next_sgpr = (self.next_sgpr + 1) & ~1
r = s[self.next_sgpr] if count == 1 else s[self.next_sgpr:self.next_sgpr + count - 1]
self.next_sgpr += count
self.max_sgpr = max(self.max_sgpr, self.next_sgpr)
return r
# ═══════════════════════════════════════════════════════════════
# Instruction emitter (like amd_asm_matmul.Kernel)
# ═══════════════════════════════════════════════════════════════
class AsmKernel:
def __init__(self, arch='gfx1100'):
self.instructions: list[Inst] = []
self.labels: dict[str, int] = {}
self.pos = 0
self.arch = arch
self.regs = RegFile()
self.lds_size = 0
def emit(self, inst, target=None):
self.instructions.append(inst)
inst._target = target
inst._pos = self.pos
self.pos += inst.size()
return inst
def label(self, name):
self.labels[name] = self.pos
def waitcnt(self, lgkm=None, vm=None):
vmcnt = vm if vm is not None else 63
lgkmcnt = lgkm if lgkm is not None else 63
expcnt = 7
wc = (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
self.emit(s_waitcnt(simm16=wc))
def resolve_branches(self):
for inst in self.instructions:
if hasattr(inst, '_target') and inst._target is not None:
offset_dwords = (self.labels[inst._target] - inst._pos - inst.size()) // 4
inst.simm16 = offset_dwords
def to_asm(self, name='kernel', kernarg_size=0, n_params=0) -> str:
self.resolve_branches()
body = ['\t' + disasm(inst) for inst in self.instructions]
hsa = [
('group_segment_fixed_size', self.lds_size), ('private_segment_fixed_size', 0), ('kernarg_size', kernarg_size),
('user_sgpr_count', 2), ('user_sgpr_kernarg_segment_ptr', 1),
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
('system_vgpr_workitem_id', 0), ('next_free_vgpr', self.regs.max_vgpr),
('next_free_sgpr', max(self.regs.max_sgpr, 4)), # minimum 4 SGPRs
('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3),
('dx10_clamp', 1), ('ieee_mode', 1), ('fp16_overflow', 0),
('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0), ('shared_vgpr_count', 0)]
args_meta = '\n'.join(
f' - .address_space: global\n .offset: {i*8}\n .size: 8\n .value_kind: global_buffer'
for i in range(n_params))
return '\n'.join([
'\t.text', f'\t.amdgcn_target "amdgcn-amd-amdhsa--{self.arch}"',
f'\t.protected\t{name}', f'\t.globl\t{name}', '\t.p2align\t8', f'\t.type\t{name},@function', f'{name}:',
*body,
'\t.section\t.rodata,"a",@progbits', '\t.p2align\t6, 0x0', f'\t.amdhsa_kernel {name}',
*[f'\t\t.amdhsa_{k} {v}' for k, v in hsa],
f'\t.end_amdhsa_kernel', '\t.text', f'.Lfunc_end0:', f'\t.size\t{name}, .Lfunc_end0-{name}',
'\t.amdgpu_metadata', '---', 'amdhsa.kernels:', ' - .args:',
args_meta,
f' .group_segment_fixed_size: {self.lds_size}', ' .kernarg_segment_align: 8',
f' .kernarg_segment_size: {kernarg_size}', ' .max_flat_workgroup_size: 1024',
f' .name: {name}', ' .private_segment_fixed_size: 0',
f' .sgpr_count: {max(self.regs.max_sgpr, 4)}', f' .symbol: {name}.kd',
f' .vgpr_count: {self.regs.max_vgpr}', ' .wavefront_size: 32',
f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])
# ═══════════════════════════════════════════════════════════════
# UOp → Inst rendering
# ═══════════════════════════════════════════════════════════════
# dtype → register count for VGPRs
def _dtype_regs(dt: DType) -> int:
if isinstance(dt, PtrDType): return 2 # 64-bit pointer
return max(1, dt.itemsize // 4) * (dt.vcount if hasattr(dt, 'vcount') and dt.vcount > 1 else 1)
# dtype → global load instruction
def _global_load(vdst, addr, saddr, offset=0, nregs=1):
if nregs == 1: return global_load_b32(vdst=vdst, addr=addr, saddr=saddr, offset=offset)
if nregs == 2: return global_load_b64(vdst=vdst, addr=addr, saddr=saddr, offset=offset)
if nregs == 4: return global_load_b128(vdst=vdst, addr=addr, saddr=saddr, offset=offset)
raise RuntimeError(f"unsupported global load size: {nregs} regs")
def _global_store(addr, data, saddr, offset=0, nregs=1):
if nregs == 1: return global_store_b32(addr=addr, data=data, saddr=saddr, offset=offset)
if nregs == 2: return global_store_b64(addr=addr, data=data, saddr=saddr, offset=offset)
if nregs == 4: return global_store_b128(addr=addr, data=data, saddr=saddr, offset=offset)
raise RuntimeError(f"unsupported global store size: {nregs} regs")
def render_kernel(uops: list[UOp], arch='gfx1100') -> str:
"""Render linearized UOps into GAS assembly text."""
k = AsmKernel(arch)
# r maps UOp → register (Reg)
r: dict[UOp, Reg] = {}
# s_args: SGPR pairs for kernel argument pointers, loaded from kernarg segment
# kernarg_ptr is in s[0:1] (set by HSA ABI)
kernarg_base = k.regs.alloc_sgpr(2) # s[0:1] = kernarg segment pointer
# system SGPRs for workgroup IDs come after user SGPRs
# with user_sgpr_count=2, workgroup_id_x is s[2], workgroup_id_y is s[3]
wg_id_x_sgpr = 2
wg_id_y_sgpr = 3
name = 'test'
params: list[tuple[int, Reg]] = [] # (param_idx, sgpr_pair)
specials: dict[str, Reg] = {}
loop_stack: list[tuple[str, str, Reg]] = [] # (label_start, label_end, range_reg)
n_params = 0
# first pass: count params
for u in uops:
if u.op is Ops.PARAM: n_params = max(n_params, u.arg + 1)
if u.op is Ops.SINK and u.arg is not None: name = u.arg.function_name
kernarg_size = n_params * 8 # each param is 8 bytes (pointer)
# load all kernel argument pointers
param_sgprs: dict[int, Reg] = {}
for i in range(n_params):
sp = k.regs.alloc_sgpr(2)
param_sgprs[i] = sp
k.emit(s_load_b64(sdata=sp, sbase=kernarg_base, offset=i * 8, soffset=NULL))
k.waitcnt(lgkm=0)
for u in uops:
if u.op is Ops.SINK:
continue
elif u.op is Ops.PARAM:
r[u] = param_sgprs[u.arg]
elif u.op is Ops.CONST:
if u.dtype == dtypes.float:
vr = k.regs.alloc_vgpr()
k.emit(v_mov_b32_e32(vr, u.arg))
r[u] = vr
elif u.dtype in (dtypes.int, dtypes.int32, dtypes.uint, dtypes.uint32):
vr = k.regs.alloc_vgpr()
k.emit(v_mov_b32_e32(vr, u.arg if isinstance(u.arg, int) and -16 <= u.arg <= 64 else u.arg))
r[u] = vr
elif u.dtype == dtypes.bool:
# booleans: 1=true, 0=false — stored in VGPR as int
vr = k.regs.alloc_vgpr()
k.emit(v_mov_b32_e32(vr, 1 if u.arg else 0))
r[u] = vr
else:
raise RuntimeError(f"unsupported CONST dtype {u.dtype}")
elif u.op is Ops.SPECIAL:
kind, idx = u.arg[0], int(u.arg[-1])
if kind == 'l':
# local thread ID — workitem_id_{x,y,z} already in v0 (only x for 1D)
if idx == 0:
r[u] = v[0] # workitem_id_x is pre-loaded in v0 by hardware
else:
raise RuntimeError(f"unsupported local dim {idx}")
elif kind == 'g':
# workgroup ID — in system SGPRs (after user SGPRs)
sgpr_off = wg_id_x_sgpr + idx
vr = k.regs.alloc_vgpr()
k.emit(v_mov_b32_e32(vr, s[sgpr_off]))
r[u] = vr
else:
raise RuntimeError(f"unsupported SPECIAL kind {kind}")
elif u.op is Ops.INDEX:
# INDEX(ptr, idx) — compute byte address: base_ptr + idx * element_size
base = r[u.src[0]]
idx_reg = r[u.src[1]]
assert isinstance(u.dtype, PtrDType), f"INDEX must produce pointer, got {u.dtype}"
elem_size = u.dtype.base.itemsize
# compute byte offset: idx * elem_size
offset_vr = k.regs.alloc_vgpr()
if elem_size == 4:
k.emit(v_lshlrev_b32_e32(offset_vr, 2, idx_reg))
elif elem_size == 8:
k.emit(v_lshlrev_b32_e32(offset_vr, 3, idx_reg))
elif elem_size == 16:
k.emit(v_lshlrev_b32_e32(offset_vr, 4, idx_reg))
elif elem_size == 2:
k.emit(v_lshlrev_b32_e32(offset_vr, 1, idx_reg))
elif elem_size == 1:
k.emit(v_mov_b32_e32(offset_vr, idx_reg))
else:
k.emit(v_mul_lo_u32(offset_vr, elem_size, idx_reg))
# base is an SGPR pair (64-bit pointer), offset is VGPR — use scalar+vector addressing
r[u] = offset_vr # store offset VGPR; base SGPR pair stored separately
# stash the base pointer for LOAD/STORE to use
u._base_sgpr = base
elif u.op is Ops.LOAD:
idx_uop = u.src[0]
assert idx_uop.op is Ops.INDEX or (idx_uop.op is Ops.CAST and idx_uop.src[0].op is Ops.INDEX)
real_idx = idx_uop.src[0] if idx_uop.op is Ops.CAST else idx_uop
base_sgpr = real_idx._base_sgpr
offset_vr = r[real_idx]
nregs = max(1, u.dtype.itemsize // 4) * (u.dtype.vcount if hasattr(u.dtype, 'vcount') and u.dtype.vcount > 1 else 1)
dst = k.regs.alloc_vgpr(nregs)
k.emit(_global_load(dst, offset_vr, base_sgpr, nregs=nregs))
k.waitcnt(vm=0)
r[u] = dst
elif u.op is Ops.STORE:
idx_uop = u.src[0]
assert idx_uop.op is Ops.INDEX or (idx_uop.op is Ops.CAST and idx_uop.src[0].op is Ops.INDEX)
real_idx = idx_uop.src[0] if idx_uop.op is Ops.CAST else idx_uop
base_sgpr = real_idx._base_sgpr
offset_vr = r[real_idx]
val_reg = r[u.src[1]]
nregs = max(1, u.src[1].dtype.itemsize // 4) * (u.src[1].dtype.vcount if hasattr(u.src[1].dtype, 'vcount') and u.src[1].dtype.vcount > 1 else 1)
k.emit(_global_store(offset_vr, val_reg, base_sgpr, nregs=nregs))
r[u] = offset_vr # stores don't produce values, but map for dependencies
elif u.op is Ops.ADD:
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
if u.dtype == dtypes.float:
k.emit(v_add_f32_e32(dst, a_reg, b_reg))
elif u.dtype in (dtypes.int, dtypes.int32, dtypes.uint, dtypes.uint32):
k.emit(v_add_nc_u32_e32(dst, a_reg, b_reg))
else:
raise RuntimeError(f"unsupported ADD dtype {u.dtype}")
r[u] = dst
elif u.op is Ops.MUL:
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
if u.dtype == dtypes.float:
from extra.assembly.amd.autogen.rdna3.ins import v_mul_f32_e32
k.emit(v_mul_f32_e32(dst, a_reg, b_reg))
elif u.dtype in (dtypes.int, dtypes.int32, dtypes.uint, dtypes.uint32):
k.emit(v_mul_lo_u32(dst, a_reg, b_reg))
else:
raise RuntimeError(f"unsupported MUL dtype {u.dtype}")
r[u] = dst
elif u.op is Ops.SHL:
# SHL(val, shift) -> v_lshlrev_b32(shift, val) (reversed operands)
val_reg, shift_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
k.emit(v_lshlrev_b32_e32(dst, shift_reg, val_reg))
r[u] = dst
elif u.op is Ops.SHR:
val_reg, shift_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
k.emit(v_lshrrev_b32_e32(dst, shift_reg, val_reg))
r[u] = dst
elif u.op is Ops.AND:
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
k.emit(v_and_b32_e32(dst, a_reg, b_reg))
r[u] = dst
elif u.op is Ops.CAST:
# for now: pointer casts are noops, numeric casts need work
if isinstance(u.dtype, PtrDType):
r[u] = r[u.src[0]]
if hasattr(u.src[0], '_base_sgpr'): u._base_sgpr = u.src[0]._base_sgpr
else:
raise RuntimeError(f"unsupported CAST {u.src[0].dtype} -> {u.dtype}")
elif u.op is Ops.VECTORIZE:
# VECTORIZE packs scalars into a vector — just allocate contiguous VGPRs
count = len(u.src)
dst = k.regs.alloc_vgpr(count)
for i, src_u in enumerate(u.src):
src_reg = r[src_u]
target = v[dst.offset - 256 + i] if count > 1 else dst
if src_reg.offset != target.offset:
k.emit(v_mov_b32_e32(target, src_reg))
r[u] = dst
elif u.op is Ops.GEP:
# GEP extracts element from vector — just offset into the VGPR range
base_reg = r[u.src[0]]
idx = u.arg[0]
r[u] = v[base_reg.offset - 256 + idx]
elif u.op in (Ops.NOOP, Ops.GROUP, Ops.AFTER):
if u.src: r[u] = r[u.src[0]]
elif u.op is Ops.RANGE:
# loop: counter starts at 0, increments by 1, bound is src[0]
label_start = f'loop_{id(u)}'
label_end = f'end_{id(u)}'
ctr = k.regs.alloc_vgpr()
k.emit(v_mov_b32_e32(ctr, 0))
k.label(label_start)
r[u] = ctr
loop_stack.append((label_start, label_end, ctr))
elif u.op is Ops.END:
label_start, label_end, ctr = loop_stack.pop()
# increment counter
k.emit(v_add_nc_u32_e32(ctr, 1, ctr))
# compare and branch: use SGPR compare since loop bound should be uniform
bound_uop = u.src[1] # the RANGE uop's src[0] is the bound
# actually END.src = (range_uop, ...), range_uop.src[0] = bound
range_uop = u.src[0]
bound_reg = r[range_uop.src[0]]
k.emit(v_cmp_lt_i32_e32(ctr, bound_reg))
k.emit(s_cbranch_scc1(), target=label_start)
k.label(label_end)
elif u.op is Ops.BARRIER:
k.emit(s_barrier())
elif u.op is Ops.DEFINE_LOCAL:
# LDS allocation — just track size, address computed at use time
r[u] = v[0] # placeholder, LDS addressing handled separately
k.lds_size = max(k.lds_size, u.dtype.size * u.dtype.base.itemsize if hasattr(u.dtype, 'size') else 0)
elif u.op is Ops.DEFINE_REG:
# register "spill" region — allocate VGPRs
size = u.dtype.size if hasattr(u.dtype, 'size') else 1
vr = k.regs.alloc_vgpr(size)
r[u] = vr
elif u.op is Ops.CMPLT:
# compare: write result to VCC, then v_cndmask to get bool in VGPR
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
k.emit(v_cmp_lt_i32_e32(a_reg, b_reg))
from extra.assembly.amd.autogen.rdna3.ins import v_cndmask_b32_e32
k.emit(v_cndmask_b32_e32(dst, 0, 1))
r[u] = dst
elif u.op is Ops.CMPNE:
from extra.assembly.amd.autogen.rdna3.ins import v_cmp_ne_u32_e32, v_cndmask_b32_e32
a_reg, b_reg = r[u.src[0]], r[u.src[1]]
dst = k.regs.alloc_vgpr()
k.emit(v_cmp_ne_u32_e32(a_reg, b_reg))
k.emit(v_cndmask_b32_e32(dst, 0, 1))
r[u] = dst
elif u.op is Ops.WHERE:
from extra.assembly.amd.autogen.rdna3.ins import v_cndmask_b32_e32
cond_reg, true_reg, false_reg = r[u.src[0]], r[u.src[1]], r[u.src[2]]
dst = k.regs.alloc_vgpr()
# set VCC from condition (nonzero = true)
k.emit(v_cmp_lt_i32_e32(0, cond_reg)) # VCC = cond_reg != 0
k.emit(v_cndmask_b32_e32(dst, false_reg, true_reg))
r[u] = dst
else:
raise RuntimeError(f"unsupported UOp: {u.op} dtype={u.dtype}")
# epilogue
k.waitcnt(vm=0, lgkm=0)
k.emit(s_endpgm())
return k.to_asm(name=name, kernarg_size=kernarg_size, n_params=n_params)
# ═══════════════════════════════════════════════════════════════
# AMDAssemblyRenderer: Renderer subclass for tinygrad integration
# ═══════════════════════════════════════════════════════════════
class AMDAssemblyRenderer(Renderer):
device = "AMD"
suffix = "s" # GAS assembly
supports_float4 = True
has_local = True
has_shared = True
global_max = AMDHIPRenderer.global_max
shared_max = AMDHIPRenderer.shared_max
def __init__(self, arch: str):
from tinygrad.runtime.support.compiler_amd import HIPCompiler
self.arch = arch
self.compiler = HIPCompiler(arch)
def render(self, uops: list[UOp]) -> str:
return render_kernel(uops, arch=self.arch)
def __reduce__(self): return self.__class__, (self.arch,)

View file

@ -8,9 +8,11 @@ from tinygrad.runtime.support.hcq import MMIOInterface, BumpAllocator, hcq_filte
from tinygrad.uop.ops import sint
from tinygrad.device import Compiled, DMAFdRef, BufferSpec, CompilerSet
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar
from tinygrad.helpers import VIZ, AMD_CC, AMD_LLVM, AMD_HIPCC, ceildiv, unwrap
from tinygrad.helpers import VIZ, AMD_CC, AMD_LLVM, AMD_HIPCC, AMD_ISEL, AMD_ASM, ceildiv, unwrap
from tinygrad.renderer.cstyle import AMDHIPRenderer, AMDHIPCCRenderer
from tinygrad.renderer.llvmir import AMDLLVMRenderer
from extra.assembly.amd.isel import AMDISELRenderer
from extra.assembly.amd.renderer import AMDAssemblyRenderer
from tinygrad.runtime.autogen import kfd, hsa, pci, sqtt, amdgpu_kd, amdgpu_drm
from tinygrad.runtime.autogen.am import am
from tinygrad.runtime.support.elf import elf_loader
@ -970,7 +972,9 @@ class AMDDevice(HCQCompiled):
compilers = CompilerSet([(functools.partial(AMDHIPRenderer, self.arch), None),
(functools.partial(AMDLLVMRenderer, self.arch), AMD_LLVM),
(functools.partial(AMDHIPCCRenderer, self.arch), AMD_HIPCC)], ctrl_var=AMD_CC)
(functools.partial(AMDHIPCCRenderer, self.arch), AMD_HIPCC),
(functools.partial(AMDISELRenderer, self.arch), AMD_ISEL),
(functools.partial(AMDAssemblyRenderer, self.arch), AMD_ASM)], ctrl_var=AMD_CC)
super().__init__(device, AMDAllocator(self), compilers, functools.partial(AMDProgram, self), AMDSignal,
functools.partial(AMDComputeAQLQueue if self.is_aql else AMDComputeQueue, self),

View file

@ -76,7 +76,7 @@ class Ops(FastEnum):
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
CUSTOM = auto(); CUSTOMI = auto()
# INS is a machine instruction
# machine instruction: arg=Inst object, tag=register assignment
INS = auto()
# ** 6 -- ops that don't exist in programs **

View file

@ -289,7 +289,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.ASSIGN: return self.src[1]._shape
# elementwise ops keep the shape the same. all inputs with shape must match
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE, Ops.INS}):
input_shapes = [x._shape for x in self.src if x._shape is not None]
if len(input_shapes) == 0: return None
if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}")

View file

@ -177,7 +177,7 @@ shared_codegen_spec = PatternMatcher([
# CUSTOM (inline and non inline)
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
# assembly instruction
# machine instruction (ISel output)
(UPat(Ops.INS), lambda: True),
# INDEX (2-arg and 3-arg with bool gate)