remove arch from asm kernel class (#15977)

* rm arch from kernel

* update other tests

* update abstractions4.py
This commit is contained in:
qazal 2026-04-29 21:39:52 +03:00 committed by GitHub
commit a37b605523
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 20 additions and 20 deletions

View file

@ -105,7 +105,7 @@ def example_3_custom_uop(a:Tensor, correct):
def example_5_custom_assembly(a:Tensor, correct):
# Kernel class copied from amd_asm_matmul
class Kernel:
def __init__(self, arch='gfx1100'): self.instructions, self.labels, self.pos, self.arch = [], {}, 0, arch
def __init__(self): self.instructions, self.labels, self.pos = [], {}, 0
def label(self, name): self.labels[name] = self.pos
def emit(self, inst, target=None):
self.instructions.append(inst)

View file

@ -167,7 +167,7 @@ PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR,
# =============================================================================
class Kernel:
def __init__(self, arch='gfx1100'): self.instructions, self.labels, self.pos, self.arch = [], {}, 0, arch
def __init__(self): self.instructions, self.labels, self.pos = [], {}, 0
def label(self, name): self.labels[name] = self.pos
def emit(self, inst, target=None):
@ -196,10 +196,10 @@ class Kernel:
# Kernel builder
# =============================================================================
def build_kernel(N, arch='gfx1100'):
def build_kernel(N):
assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
k = Kernel(arch)
k = Kernel()
# ===========================================================================
# PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
@ -443,7 +443,7 @@ def test_matmul():
dev = Device[Device.DEFAULT]
print(f"Device arch: {dev.renderer.target.arch}")
insts = build_kernel(N, dev.renderer.target.arch)
insts = build_kernel(N)
rng = np.random.default_rng(42)
a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)

View file

@ -99,13 +99,13 @@ def custom_lds_sync(A:UOp, arch:str) -> UOp:
sink = UOp.sink(A.base, lds, threads, wg, arg=KernelInfo("custom_lds_sync"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_handwritten(A:UOp, arch:str) -> UOp:
def custom_handwritten(A:UOp) -> UOp:
A = A.flatten()
threads = UOp.special(128, "lidx0")
wg = UOp.special(1, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
pipes = {getenv("PIPE", "")} if getenv("PIPE", "") else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}
k = Kernel(arch)
k = Kernel()
# wrap in loop to filter out icache misses
LOOP_N, UNROLL_N = 8, 5
k.emit(r4.s_mov_b32(s[1], LOOP_N))
@ -145,10 +145,10 @@ def custom_handwritten(A:UOp, arch:str) -> UOp:
sink = UOp.sink(A.base, threads, wg, lds, arg=KernelInfo("custom_handwritten"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_data_deps(A:UOp, arch:str) -> UOp:
def custom_data_deps(A:UOp) -> UOp:
A = A.flatten()
threads = UOp.special(A.numel(), "lidx0")
k = Kernel(arch)
k = Kernel()
k.emit(s_load_b64(s[0:1], s[0:1], soffset=NULL))
k.emit(s_waitcnt_lgkmcnt(sdst=NULL, simm16=0))
k.emit(v_lshlrev_b32_e32(v[0], 2, v[0]))
@ -198,13 +198,13 @@ class TestCustomKernel(unittest.TestCase):
def test_handwritten(self):
if self.arch != "rdna4": self.skipTest("only tested on rdna4")
a = Tensor.empty(1024, dtype=dtypes.int32).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_handwritten, arch=self.arch))[0]
a = Tensor.custom_kernel(a, fxn=custom_handwritten)[0]
a.realize()
def test_data_deps(self):
if self.arch != "rdna3": self.skipTest("only tested on rdna3")
a = Tensor(np.full(32, 5.0, dtype=np.float32)).realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_data_deps, arch=self.arch))[0]
a = Tensor.custom_kernel(a, fxn=custom_data_deps)[0]
a.realize()
self.assertTrue((a.numpy() == 6.0).all())

View file

@ -728,7 +728,7 @@ class TestCfg(unittest.TestCase):
return amdgpu_cfg(prg.src[4].arg, self.arch)
def test_simple(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_branch(), target="bb1")
k.label("bb1")
@ -738,7 +738,7 @@ class TestCfg(unittest.TestCase):
self.assertEqual(len(cfg["blocks"]), 2)
def test_diamond(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[0], 0))
k.emit(s_mov_b32(s[1], 0))
@ -772,7 +772,7 @@ class TestCfg(unittest.TestCase):
assert st.startswith("s_code_end") and st.endswith("x)"), st
def test_loop(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 4))
k.label("loop")
@ -784,7 +784,7 @@ class TestCfg(unittest.TestCase):
self.get_cfg("simple_loop", k)
def test_loop_branch(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 4))
k.label("loop")
@ -802,7 +802,7 @@ class TestCfg(unittest.TestCase):
self.get_cfg("loop_if", k)
def test_loop_break(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 8))
k.label("loop")
@ -817,7 +817,7 @@ class TestCfg(unittest.TestCase):
self.get_cfg("loop_break", k)
def test_switch(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_cmp_eq_i32(s[0], 0))
k.emit(s_cbranch_scc1(), target="case0")
@ -839,7 +839,7 @@ class TestCfg(unittest.TestCase):
self.get_cfg("switch_case", k)
def test_ping_pong(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_cmp_eq_i32(s[0], 0))
k.emit(s_cbranch_scc1(), target="ping")
@ -858,7 +858,7 @@ class TestCfg(unittest.TestCase):
def test_colored_blocks(self):
N = 10
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_branch(), target="init0")
for i in range(N):
@ -878,7 +878,7 @@ class TestCfg(unittest.TestCase):
self.get_cfg("test_colored_blocks", k)
def test_jump_back_to_end(self):
k = Kernel(arch=self.arch)
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 2))
k.emit(s_cbranch_execz(), target="loop")