mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove arch from asm kernel class (#15977)
* rm arch from kernel * update other tests * update abstractions4.py
This commit is contained in:
parent
7a79c2948a
commit
a37b605523
4 changed files with 20 additions and 20 deletions
|
|
@ -105,7 +105,7 @@ def example_3_custom_uop(a:Tensor, correct):
|
|||
def example_5_custom_assembly(a:Tensor, correct):
|
||||
# Kernel class copied from amd_asm_matmul
|
||||
class Kernel:
|
||||
def __init__(self, arch='gfx1100'): self.instructions, self.labels, self.pos, self.arch = [], {}, 0, arch
|
||||
def __init__(self): self.instructions, self.labels, self.pos = [], {}, 0
|
||||
def label(self, name): self.labels[name] = self.pos
|
||||
def emit(self, inst, target=None):
|
||||
self.instructions.append(inst)
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR,
|
|||
# =============================================================================
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, arch='gfx1100'): self.instructions, self.labels, self.pos, self.arch = [], {}, 0, arch
|
||||
def __init__(self): self.instructions, self.labels, self.pos = [], {}, 0
|
||||
def label(self, name): self.labels[name] = self.pos
|
||||
|
||||
def emit(self, inst, target=None):
|
||||
|
|
@ -196,10 +196,10 @@ class Kernel:
|
|||
# Kernel builder
|
||||
# =============================================================================
|
||||
|
||||
def build_kernel(N, arch='gfx1100'):
|
||||
def build_kernel(N):
|
||||
assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
|
||||
assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
|
||||
k = Kernel(arch)
|
||||
k = Kernel()
|
||||
|
||||
# ===========================================================================
|
||||
# PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
|
||||
|
|
@ -443,7 +443,7 @@ def test_matmul():
|
|||
dev = Device[Device.DEFAULT]
|
||||
print(f"Device arch: {dev.renderer.target.arch}")
|
||||
|
||||
insts = build_kernel(N, dev.renderer.target.arch)
|
||||
insts = build_kernel(N)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
|
||||
|
|
|
|||
|
|
@ -99,13 +99,13 @@ def custom_lds_sync(A:UOp, arch:str) -> UOp:
|
|||
sink = UOp.sink(A.base, lds, threads, wg, arg=KernelInfo("custom_lds_sync"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
def custom_handwritten(A:UOp, arch:str) -> UOp:
|
||||
def custom_handwritten(A:UOp) -> UOp:
|
||||
A = A.flatten()
|
||||
threads = UOp.special(128, "lidx0")
|
||||
wg = UOp.special(1, "gidx0")
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
|
||||
pipes = {getenv("PIPE", "")} if getenv("PIPE", "") else {"SALU", "VALU", "TRANSCENDENTAL", "WMMA"}
|
||||
k = Kernel(arch)
|
||||
k = Kernel()
|
||||
# wrap in loop to filter out icache misses
|
||||
LOOP_N, UNROLL_N = 8, 5
|
||||
k.emit(r4.s_mov_b32(s[1], LOOP_N))
|
||||
|
|
@ -145,10 +145,10 @@ def custom_handwritten(A:UOp, arch:str) -> UOp:
|
|||
sink = UOp.sink(A.base, threads, wg, lds, arg=KernelInfo("custom_handwritten"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
def custom_data_deps(A:UOp, arch:str) -> UOp:
|
||||
def custom_data_deps(A:UOp) -> UOp:
|
||||
A = A.flatten()
|
||||
threads = UOp.special(A.numel(), "lidx0")
|
||||
k = Kernel(arch)
|
||||
k = Kernel()
|
||||
k.emit(s_load_b64(s[0:1], s[0:1], soffset=NULL))
|
||||
k.emit(s_waitcnt_lgkmcnt(sdst=NULL, simm16=0))
|
||||
k.emit(v_lshlrev_b32_e32(v[0], 2, v[0]))
|
||||
|
|
@ -198,13 +198,13 @@ class TestCustomKernel(unittest.TestCase):
|
|||
def test_handwritten(self):
|
||||
if self.arch != "rdna4": self.skipTest("only tested on rdna4")
|
||||
a = Tensor.empty(1024, dtype=dtypes.int32).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_handwritten, arch=self.arch))[0]
|
||||
a = Tensor.custom_kernel(a, fxn=custom_handwritten)[0]
|
||||
a.realize()
|
||||
|
||||
def test_data_deps(self):
|
||||
if self.arch != "rdna3": self.skipTest("only tested on rdna3")
|
||||
a = Tensor(np.full(32, 5.0, dtype=np.float32)).realize()
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_data_deps, arch=self.arch))[0]
|
||||
a = Tensor.custom_kernel(a, fxn=custom_data_deps)[0]
|
||||
a.realize()
|
||||
self.assertTrue((a.numpy() == 6.0).all())
|
||||
|
||||
|
|
|
|||
|
|
@ -728,7 +728,7 @@ class TestCfg(unittest.TestCase):
|
|||
return amdgpu_cfg(prg.src[4].arg, self.arch)
|
||||
|
||||
def test_simple(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_branch(), target="bb1")
|
||||
k.label("bb1")
|
||||
|
|
@ -738,7 +738,7 @@ class TestCfg(unittest.TestCase):
|
|||
self.assertEqual(len(cfg["blocks"]), 2)
|
||||
|
||||
def test_diamond(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[0], 0))
|
||||
k.emit(s_mov_b32(s[1], 0))
|
||||
|
|
@ -772,7 +772,7 @@ class TestCfg(unittest.TestCase):
|
|||
assert st.startswith("s_code_end") and st.endswith("x)"), st
|
||||
|
||||
def test_loop(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 4))
|
||||
k.label("loop")
|
||||
|
|
@ -784,7 +784,7 @@ class TestCfg(unittest.TestCase):
|
|||
self.get_cfg("simple_loop", k)
|
||||
|
||||
def test_loop_branch(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 4))
|
||||
k.label("loop")
|
||||
|
|
@ -802,7 +802,7 @@ class TestCfg(unittest.TestCase):
|
|||
self.get_cfg("loop_if", k)
|
||||
|
||||
def test_loop_break(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 8))
|
||||
k.label("loop")
|
||||
|
|
@ -817,7 +817,7 @@ class TestCfg(unittest.TestCase):
|
|||
self.get_cfg("loop_break", k)
|
||||
|
||||
def test_switch(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_cmp_eq_i32(s[0], 0))
|
||||
k.emit(s_cbranch_scc1(), target="case0")
|
||||
|
|
@ -839,7 +839,7 @@ class TestCfg(unittest.TestCase):
|
|||
self.get_cfg("switch_case", k)
|
||||
|
||||
def test_ping_pong(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_cmp_eq_i32(s[0], 0))
|
||||
k.emit(s_cbranch_scc1(), target="ping")
|
||||
|
|
@ -858,7 +858,7 @@ class TestCfg(unittest.TestCase):
|
|||
|
||||
def test_colored_blocks(self):
|
||||
N = 10
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_branch(), target="init0")
|
||||
for i in range(N):
|
||||
|
|
@ -878,7 +878,7 @@ class TestCfg(unittest.TestCase):
|
|||
self.get_cfg("test_colored_blocks", k)
|
||||
|
||||
def test_jump_back_to_end(self):
|
||||
k = Kernel(arch=self.arch)
|
||||
k = Kernel()
|
||||
k.label("entry")
|
||||
k.emit(s_mov_b32(s[1], 2))
|
||||
k.emit(s_cbranch_execz(), target="loop")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue