remove intel and amx support (#16482)

This commit is contained in:
George Hotz 2026-06-02 18:53:05 -07:00 committed by GitHub
commit ffadd7a315
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 16 additions and 393 deletions

View file

@ -126,12 +126,6 @@ jobs:
run: BIG=2 MPS=1 python3.11 test/speed/external_test_speed_v_torch.py
- name: Test tensor cores
run: DEV=METAL python3.11 test/opt/test_tensor_cores.py
- name: Test AMX tensor cores
run: |
DEBUG=2 DEV=CPU AMX=1 python3.11 test/opt/test_tensor_cores.py
DEBUG=2 DEV=CPU:LLVM AMX=1 python3.11 test/opt/test_tensor_cores.py
DEBUG=2 DEV=CPU AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
DEBUG=2 DEV=CPU:LLVM AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
- name: Run Tensor Core GEMM (float)
run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py
- name: Run Tensor Core GEMM (half)

View file

@ -183,21 +183,11 @@ jobs:
DEBUG=2 ALLOW_TF32=1 DEV=PYTHON::sm_80 python3 test/backend/test_ops.py TestOps.test_gemm
DEBUG=2 DEV=PYTHON::sm_75 python3 test/backend/test_ops.py TestOps.test_gemm_fp16
ALLOW_TF32=1 DEV=PYTHON::sm_89 python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test emulated INTEL OpenCL tensor cores
run: DEBUG=2 DEV=PYTHON::INTEL HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
- name: Test emulated AMX tensor cores
env:
DEV: 'PYTHON::AMX'
run: |
DEBUG=2 python3 test/backend/test_ops.py TestOps.test_gemm
python3 -m pytest -nauto test/opt/test_tensor_cores.py
- name: Test device flop counts
run: |
DEBUG=2 DEV=PYTHON::METAL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::gfx1100 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::sm_80 python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::INTEL python3 ./test/null/test_uops_stats.py TestUOpsStatsMatmulHalf
DEBUG=2 DEV=PYTHON::AMX python3 ./test/null/test_uops_stats.py TestUOpsStats.test_simple_matmul
linter:
name: Linters

View file

@ -62,7 +62,7 @@ A lot of work can still be done here. For example, we never copy the inputs to o
Many accelerators have Tensor Cores / MAC arrays / systolic arrays. The main value of these is that, since they are 2-D, they create an n^2 ratio between the compute and the input data.
GPUs use Tensor Cores instead of MAC arrays to fit better in the GPU warp paradigm. This is because the output of Tensor Cores is O(n) wrt the input, while the output of MAC arrays like the AMX is O(n^2)
GPUs use Tensor Cores instead of MAC arrays to fit better in the GPU warp paradigm. This is because the output of Tensor Cores is O(n) wrt the input, while the output of MAC arrays is O(n^2)
We have a simple framework in tinygrad for adding these ALU blocks and achieving good performance from them.

View file

@ -83,9 +83,5 @@ NV backend supports several interfaces for communicating with devices:
## CPU Arch
The CPU renderers may be additionally configured using the arch component of [the `DEV` environment variable](env_vars.md#dev-variable).
CPU arch should be specified as a comma-separated list of parameters, and must contain at least two values: the architecture family (i.e. x86_64, arm64, or riscv64) and the cpu type (as accepted by `clang`'s `-march`).
If native is specified as the cpu type, tinygrad (or delegate compiler) will query the host cpu type. Additional comma-separated values may be specified as follows:
* `AMX`: emit Apple silicon AMX instructions
All other additional values are interpreted as cpu feature flags. When a value is preceded by a `-` character, the corresponding feature flag will be disabled, otherwise the flag will be enabled.
If `native` is specified as the cpu type, tinygrad (or the delegate compiler) will query the host cpu type. Additional comma-separated values are interpreted as cpu feature flags. When a value is preceded by a `-` character, the corresponding feature flag will be disabled, otherwise the flag will be enabled.
Note that enabled feature flags should not be preceded by a `+`.

View file

@ -1,180 +0,0 @@
#!/usr/bin/env python3
import numpy as np
import time
import sys
np.set_printoptions(linewidth=160)
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
from tinygrad.runtime.ops_llvm import LLVMDevice, LLVMProgram, LLVMCompiler
from llvmlite import ir # type: ignore
from tinygrad.helpers import flat_mv
from tinygrad.device import MallocAllocator
# https://github.com/corsix/amx/blob/main/Instructions.md
# 12 lines for AMX support
from functools import partialmethod
class AMX:
  # Namespace of helpers that emit AMX instructions as raw `.word` inline assembly through an IRBuilder.
  # Every helper is invoked as AMX.<name>(builder, ...) and adds one `builder.asm(...)` call to that builder.
  # NOTE(review): the trailing `True` passed to builder.asm is presumably llvmlite's side_effect flag -- confirm.

  # Emit opcode `op` with the immediate `imm5`; the encoded word is 0x201000 + (op << 5) + imm5.
  # No operands are passed to the asm (empty constraint string, empty argument tuple).
  @staticmethod
  def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)

  # Emit opcode `op` whose operand is the 64-bit value `gpr`, passed to the asm with an "r" constraint.
  # NOTE(review): `0$0 - ((0$0 >> 4) * 6)` presumably turns the substituted register name (e.g. `x12`) into the
  # hex literal `0x12` and then converts it back to the register number 12 -- confirm against the assembler.
  @staticmethod
  def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)

  # Named wrappers: each binds the opcode number (and for set/clr also the immediate), so callers only pass
  # the builder, plus the operand value for the op_gpr-based forms.
  # set/clr share opcode 17 and differ only in the immediate (0 vs 1).
  set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
  # opcodes 0-7
  ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
  ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
  # opcodes 8-9
  extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
  # opcodes 10-16 (opcode 17 is set/clr above)
  fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
  mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
  # opcodes 18-22
  vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)
def int_const(x):
  """Return `x` wrapped as a 64-bit integer LLVM IR constant."""
  i64 = ir.IntType(64)
  return ir.Constant(i64, x)
# ---- benchmark setup: host data, device buffers, and the LLVM IR for the AMX kernel ----
N = 4096
# N = 1024
# N = 64
# bytes in one N x N float32 matrix; used for the GB/s figure printed at the end
BW = N*N*4
# matrix is 64M, max load bandwidth is 57 GB/s
# cache line looks like 256 bytes (64 floats)
# output buffer on the host side: 256 float32 values
na = np.zeros((256), dtype=np.float32)
# na = np.zeros((N, N), dtype=np.float32)
nb = np.random.randn(N, N).astype(np.float32)
nc = np.random.randn(N, N).astype(np.float32)
# reference result: nb viewed as rows of 32 floats, summed down the rows (32 values); compared against `na` at the end
ns = nb.reshape(-1, 32).sum(axis=0)
# raw buffers for the output (a) and the two inputs (b, c); only the inputs are copied in
a = MallocAllocator.alloc(na.nbytes)
b = MallocAllocator.alloc(nb.nbytes)
c = MallocAllocator.alloc(nc.nbytes)
MallocAllocator._copyin(b, flat_mv(nb.data))
MallocAllocator._copyin(c, flat_mv(nc.data))
# `exec` takes three float pointers and returns an i64
module = ir.Module(name=__file__)
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
# load all
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
# the three pointer arguments as 64-bit integers (ym is not used by the loop built below)
zm, xm, ym = [entry.ptrtoint(func.args[i], ir.IntType(64)) for i in range(3)]
# single loop: a header block holding the phi, and a body block ("loop_y_exit") that does the work and branches back
loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
exit = ir.IRBuilder(func.append_basic_block(name="exit"))
# y is the loop index in float elements: starts at 0 and advances by 64 per iteration
y = loop_1.phi(ir.IntType(64), name="y")
y.add_incoming(int_const(0), entry._block)
yp = loop_1_exit.add(y, int_const(32*2))
y.add_incoming(yp, loop_1_exit._block)
# declared for the prefetch experiment; the calls that use it are commented out below
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
xptr = y
# byte address of element y in the second argument (4 bytes per float)
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
#prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType()))
#loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)])
# NOTE(review): bit 62 in the ldx/ldy operand presumably requests a double-width load (32 floats) -- confirm
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1<<62), addr))
# the ldy operand is the next 32 floats of the same buffer (xm again, offset by 32 elements)
xptr = loop_1_exit.add(xptr, int_const(32))
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
# four fma32 issues per iteration; the operand constants are AMX bitfields
# NOTE(review): the meaning of bits 63/29/28/20 and the 16*4 offsets is not derivable from this file -- check the AMX instruction reference
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16*4)<<10))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16*4)))
# `set` is emitted in the entry block (before its branch is added); the single `stz` to the first argument and `clr` go in the exit block
AMX.set(entry)
AMX.stz(exit, exit.add(zm, int_const(1 << 62 | (0 << 56) | 0)))
AMX.clr(exit)
# wire the blocks together: entry -> header -> body; the body exits once the next index equals N*N
entry.branch(loop_1._block)
loop_1.branch(loop_1_exit._block)
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block)
exit.ret(int_const(0))
# compile the textual IR of the module into a runnable program named "exec"
device = LLVMDevice("llvm")
prog = LLVMProgram(device, "exec", LLVMCompiler(device).compile(str(module)))
"""
loop_1 = ir.IRBuilder(func.append_basic_block(name="loop_y"))
loop_2 = ir.IRBuilder(func.append_basic_block(name="loop_x"))
loop_3 = ir.IRBuilder(func.append_basic_block(name="loop_k"))
loop_3_exit = ir.IRBuilder(func.append_basic_block(name="loop_k_exit"))
loop_2_exit = ir.IRBuilder(func.append_basic_block(name="loop_x_exit"))
loop_1_exit = ir.IRBuilder(func.append_basic_block(name="loop_y_exit"))
y = loop_1.phi(ir.IntType(64), name="y")
x = loop_2.phi(ir.IntType(64), name="x")
k = loop_3.phi(ir.IntType(64), name="k")
exit = ir.IRBuilder(func.append_basic_block(name="exit"))
AMX.set(loop_2)
# stride
xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(N)))
yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(N)))
# if you are okay with the wrong answer, this is faster
#xptr = loop_3_exit.add(x, loop_3_exit.mul(k, int_const(32)))
#yptr = loop_3_exit.add(y, loop_3_exit.mul(k, int_const(32)))
# double loads load 32 floats
AMX.ldx(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(xm, loop_3_exit.mul(int_const(4), xptr))))
AMX.ldy(loop_3_exit, loop_3_exit.add(int_const(1<<62), loop_3_exit.add(ym, loop_3_exit.mul(int_const(4), yptr))))
# <Z row> <X offset> <Y offset>
AMX.fma32(loop_3_exit, int_const(0<<20 | (0*16*4)<<10 | (0*16*4)))
AMX.fma32(loop_3_exit, int_const(1<<20 | (1*16*4)<<10 | (0*16*4)))
AMX.fma32(loop_3_exit, int_const(2<<20 | (0*16*4)<<10 | (1*16*4)))
AMX.fma32(loop_3_exit, int_const(3<<20 | (1*16*4)<<10 | (1*16*4)))
# store
gptr = loop_2_exit.mul(loop_2_exit.add(loop_2.mul(y, int_const(N)), x), int_const(4))
zmp = loop_2_exit.add(zm, gptr)
for j in range(2):
for r in range(16):
z_row = j*2
ptr = ((j*16)+r)*N
AMX.stz(loop_2_exit, loop_2_exit.add(zmp, int_const(1 << 62 | ((r*4+z_row) << 56) | ptr*4)))
AMX.clr(loop_2_exit)
yp = loop_1_exit.add(y, int_const(32))
xp = loop_2_exit.add(x, int_const(32))
kp = loop_3_exit.add(k, int_const(1))
y.add_incoming(int_const(0), entry._block)
x.add_incoming(int_const(0), loop_1._block)
k.add_incoming(int_const(0), loop_2._block)
y.add_incoming(yp, loop_1_exit._block)
x.add_incoming(xp, loop_2_exit._block)
k.add_incoming(kp, loop_3_exit._block)
entry.branch(loop_1._block)
loop_1.branch(loop_2._block)
loop_2.branch(loop_3._block)
loop_3.branch(loop_3_exit._block)
loop_3_exit.cbranch(loop_3_exit.icmp_unsigned("==", kp, int_const(N)), loop_2_exit._block, loop_3._block)
loop_2_exit.cbranch(loop_2_exit.icmp_unsigned("==", xp, int_const(N)), loop_1_exit._block, loop_2._block)
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N)), exit._block, loop_1._block)
exit.ret(int_const(0))
device = LLVMDevice("llvm")
prog = LLVMProgram(device, "exec", LLVMCompiler(device).compile(str(module)))
"""
def timeit(fxn):
  """Call `fxn` once and return how long the call took, in seconds.

  Timing uses `time.perf_counter`. The return value of `fxn` is discarded;
  any exception it raises propagates to the caller.
  """
  st = time.perf_counter()
  fxn()  # result intentionally ignored, only the duration matters
  return time.perf_counter() - st
# run the compiled program 20 times and keep the fastest run
tm = min([timeit(lambda: prog(a, b, c, N**2)) for _ in range(20)])
# copy the output buffer back into the host array `na`
MallocAllocator._copyout(flat_mv(na.data), a)
# report: element count, best time in microseconds, and bandwidth computed from BW bytes per run
print(f"{N*N:10d} {tm*1e6:9.2f} us, {BW*1e-9/tm:.2f} GB/s")
# check the first len(ns) outputs against the numpy reference sum
np.testing.assert_allclose(na[:ns.shape[0]], ns, atol=1e-4, rtol=1e-4)
# comp = (nb.T @ nc).T
# np.testing.assert_allclose(na, comp, atol=1e-4, rtol=1e-5)

View file

@ -1,43 +0,0 @@
#!/usr/bin/env python3
import numpy as np
from tinygrad.runtime.ops_cl import CLProgram, CLCompiler
from tinygrad import Device, dtypes
from tinygrad.device import Buffer
from hexdump import hexdump
# https://github.com/intel/intel-graphics-compiler/blob/master/documentation/visa/instructions/DPAS.md
# https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html
# https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
# https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_split_matrix_multiply_accumulate.html
# https://hc34.hotchips.org/assets/program/conference/day1/GPU%20HPC/Intel_s%20Ponte%20Vecchio%20GPU%20-%20Architecture%20Systems%20and%20Software%20FINAL.pdf
# Smoke test: compile and run one OpenCL kernel that calls intel_sub_group_f16_f16_matrix_mad_k16,
# then print the device output next to a numpy reference. Nothing is asserted; compare the two printed lines by eye.
device = Device["CL"]
# NOTE: only the subgroup type 8 ones work
# Kernel "test": requires a sub-group size of 8. Each work-item reads one int from data1 and one int8 from data2,
# calls the matrix-mad builtin with a 0.0f accumulator, and writes one float to data0.
# The kernel source is an f-string, hence the doubled braces.
prog = CLProgram(device, "test", CLCompiler(device, "test").compile(f"""
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void test(__global float* data0, const __global int* data1, const __global int8* data2) {{
int lidx0 = get_local_id(0);
int a = data1[lidx0];
int8 b = data2[lidx0];
float out = intel_sub_group_f16_f16_matrix_mad_k16(a, b, 0.0f);
data0[lidx0] = out;
}}
"""))
#with open("/tmp/test.elf", "wb") as f: f.write(prog.lib)
# a: 8 float32 outputs, one per work-item
a = Buffer("CL", 8, dtypes.float32).allocate()
# b: 16 halves (32 bytes), read by the kernel as 8 ints
b = Buffer("CL", 0x10, dtypes.float16).allocate()
# c: 8*16 halves (256 bytes), read by the kernel as 8 int8 vectors
c = Buffer("CL", 8*0x10, dtypes.float16).allocate()
# inputs: a fixed 16-element row and a random 8x16 matrix, both float16
row = np.array([1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8], np.float16)
mat = np.random.random((8, 0x10)).astype(np.float16)
b.copyin(row.data)
c.copyin(mat.data)
# launch a single work-group of 8 work-items and wait for completion; `ret` is whatever the program call returns
ret = prog(a._buf, b._buf, c._buf, global_size=[1,1,1], local_size=[8,1,1], wait=True)
print(ret)
out = np.frombuffer(a.as_memoryview(), np.float32)
# reference computed in float32: row (16) times mat transposed (16x8) gives 8 values
real = row.astype(np.float32)@mat.T.astype(np.float32)
print("out:", out)
print("real", real)

View file

@ -132,7 +132,6 @@ class TestDevVar(unittest.TestCase):
for d, t in [("AMD", Target(device="AMD", renderer="")), ("AMD:LLVM", Target(device="AMD", renderer="LLVM")),
(":LLVM", Target(device="", renderer="LLVM")), ("AMD::gfx1100", Target(device="AMD", arch="gfx1100")),
("AMD:LLVM:gfx1100", Target(device="AMD", renderer="LLVM", arch="gfx1100")), ("::gfx1100", Target(arch="gfx1100")),
("CPU:LLVM:arm64,native,AMX", Target(device="CPU", renderer="LLVM", arch="arm64,native,AMX")),
("USB+", Target(interface="USB")), ("USB+AMD", Target(device="AMD", interface="USB")),
("PCI:0+AMD", Target(device="AMD", interface="PCI", indices="0")), (":0+AMD", Target(device="AMD", indices="0")),
("PCI:0,1+AMD", Target(device="AMD", interface="PCI", indices="0,1")),

View file

@ -4,11 +4,8 @@ from tinygrad.uop.ops import UOp, Ops
from tinygrad.codegen import to_program
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.helpers import DEV
from test.helpers import replace_opts
AMX = "AMX" in DEV.arch
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@staticmethod
@ -32,7 +29,6 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(tuple(program.src[2].src)) == (2, 1)
@unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
def test_float4_multidim(self):
a = Tensor.empty(2, 8).realize()
b = Tensor.empty(2, 8).realize()
@ -43,25 +39,6 @@ class TestFloat4(unittest.TestCase):
renderer=Device[Device.DEFAULT].renderer).src[2].src)
assert TestFloat4.count_float4(uops) == (4, 2)
@unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16")
def test_float4_multidim_amx(self):
def kernel_for_shape(size, shift):
a = Tensor.empty(2, size).realize()
b = Tensor.empty(2, size).realize()
c = a + b
s = c.schedule_linear().src[0]
return tuple(to_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]),
renderer=Device[Device.DEFAULT].renderer).src[2].src)
sizes = [12, 8, 16]
shifts = [3, 2, 4]
expected_upcast_size = [4, 8, 16]
expected_output = [(6,3), (2,1), (2,1)]
for i in range(len(sizes)):
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
def test_float4_unaligned_load(self):
a = Tensor.empty(9).realize().shrink(((1, 9),))
b = Tensor.empty(9).realize().shrink(((1, 9),))
@ -74,7 +51,6 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(tuple(program.src[2].src)) == (0, 1)
@unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
def test_float4_multidim_unaligned_load(self):
a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
@ -86,25 +62,6 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(uops) == (0, 2)
@unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16")
def test_float4_multidim_unaligned_load_amx(self):
def kernel_for_shape(size, shift):
a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
c = a + b
s = c.schedule_linear().src[0]
return tuple(to_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]),
renderer=Device[Device.DEFAULT].renderer).src[2].src)
sizes = [13, 9, 17]
shifts = [3, 2, 4]
expected_upcast_size = [4, 8, 16]
expected_output = [(0,3), (0,1), (0,1)]
for i in range(len(sizes)):
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), expected_upcast_size[i]) == expected_output[i]
def test_float4_sometimes_unaligned(self):
a = Tensor.empty(1, 1, 8).realize()
b = Tensor.empty(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))

View file

@ -18,8 +18,6 @@ from test.backend.test_linearizer import helper_realized_ast, helper_linearizer_
# NOTE: to_program always passes in Device[Device.DEFAULT].renderer explicitly for process_replay!!!
AMX = "AMX" in DEV.arch
def run_program(prg:UOp, bufs:list[Buffer]):
buf_uops = [UOp.new_buffer(b.device, b.size, b.dtype) for b in bufs]
for u,b in zip(buf_uops, bufs): buffers[u] = b
@ -69,15 +67,14 @@ class TestTensorCores(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
@Context(ALLOW_TF32=1)
@unittest.skipIf(Device.DEFAULT == "PYTHON", "not generated on EMULATED device")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_codegen(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2]
n, m, k = tc.dims
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, dtype=tc.dtype_out)
prg = to_program(replace_opts(r.schedule_linear().src[-1].src[0],
@ -127,7 +124,7 @@ class TestTensorCores(unittest.TestCase):
# check excessive padding doesn't trigger padded TC in TC_OPT=2
helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
if not AMX and tc not in amd_cdna_1616128: # AMX tc.dims[2] == 1
if tc not in amd_cdna_1616128:
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//8, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
@Context(ALLOW_TF32=1)

View file

@ -174,7 +174,7 @@ class TestSpeed(unittest.TestCase):
def test_permute(self):
for N in [1024, 4096]:
# this is a 64MB tensor, M1 L1 cache is 128kB
# to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
# to fit easily in L1, rotations should be 128x128 chunks.
def f(a, b): return a.permute(1,0).contiguous()
helper_test_generic_square('permute', N, f, f, onearg=True)

View file

@ -171,7 +171,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
lengths = [4]
elif ctx is not None and ctx.supports_float4:
# TODO: a better way to get this than ctx
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if "AMX" in ctx.target.arch else [4,2])
lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else [4,2]
lengths.append(1) # worst case, it's not folded
# filter fold lengths that don't divide

View file

@ -33,8 +33,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
good_tc_opt = True
except KernelOptError:
pass
# skip hand-coded TC opts if AMX, upcasting will make kernel slower
if good_tc_opt and "AMX" not in k.ren.target.arch:
if good_tc_opt:
if rngs is not None:
for tc_dim in [1,0]: # attempt to upcast M and N
szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]

View file

@ -143,17 +143,3 @@ metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype
(('l0', 'r0', 'r1', 'l3', 'r2'), ('u0',), ('l1', 'l2', 'l4'))))
for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
# ***** Apple AMX *****
amx = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
swizzle=(((), ('u0', 'u1', 'u2', 'u3', 'u4', 'u5', 'u6', 'u7'), ()),
((), ('u4', 'u5', 'u6', 'u7', 'u0', 'u1', 'u2', 'u3'), ())),
opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
# ***** Intel ****
intel = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
opts=("l0","l0","l0","u1","u1","u1"),
swizzle=((('r1', 'r2', 'r3'), ('u0', 'u1', 'u2'), ('l0', 'l1', 'l2', 'r0')),
(('l0', 'l1', 'l2'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))]

View file

@ -258,24 +258,7 @@ class ClangRenderer(CStyleLanguage):
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"
def _render_defines(self, uops) -> list[str]:
prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
# https://github.com/corsix/amx
for name, (N, M, _), dtype_in, _, _, _, _, _ in wmma_args(uops):
prefix += [
'#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
'#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
]
# 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
# to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
# wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
out, dt1, dt2 = self.render_dtype(dtype_in.vec(N*N)), self.render_dtype(dtype_in.vec(N)), self.render_dtype(dtype_in.vec(M))
prefix += [f"""static {out} __{name}({dt1} data1, {dt2} data2, {out} data0){{
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
AMX_SET(1);\n return data0;\n}}"""]
return prefix
def _render_defines(self, uops) -> list[str]: return [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str: return ""
@ -289,8 +272,7 @@ class ClangRenderer(CStyleLanguage):
def __init__(self, target:Target):
super().__init__(target)
from tinygrad.runtime.support.compiler_cpu import ClangCompiler
if "AMX" in target.arch: self.tensor_cores = tc.amx
self.compiler = ClangCompiler([x for x in target.arch.split(",") if x != "AMX"])
self.compiler = ClangCompiler(target.arch.split(","))
class OpenCLRenderer(CStyleLanguage):
has_aux = True
@ -337,23 +319,6 @@ class OpenCLRenderer(CStyleLanguage):
if (d != dtypes.half or "cl_khr_fp16" in self.target.arch) and
(d != dtypes.double or "cl_khr_fp64" in self.target.arch) and d not in dtypes.fp8s}
class IntelRenderer(OpenCLRenderer):
suffix, kernel_typedef = "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
tensor_cores = tc.intel
string_rewrite = PatternMatcher([
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float),)), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16),)), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
]) + OpenCLRenderer.string_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = []
for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops):
dt_in = ("ushort", "bf16") if dtype_in == dtypes.bfloat16 else (dtype_in.name, "f16")
prefix.append(f"""{dtype_out.name}8 __{name}({dt_in[0]}16 a, {dt_in[0]}16 b, {dtype_out.name}8 c) {{
return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)
class MetalRenderer(CStyleLanguage):
shared_max = 32768
def __init__(self, target:Target):

View file

@ -34,19 +34,6 @@ def lcast(input_type:DType, output_type:DType):
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
# https://github.com/corsix/amx
def render_wmma_amx(ctx, wmma: UOp) -> str:
def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
return "\n".join([
*[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
*[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
*[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16",
dtypes.fp8e4m3: ".fp8.fp8", dtypes.fp8e5m2: ".bf8.bf8"}
@ -147,14 +134,6 @@ class LLVMRenderer(Renderer):
vc = -1
local_args: list[str] = []
for u in uops:
if self.tensor_cores == tc.amx and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
vc += 1
r[u] = f"%wmma{vc}"
for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
name = "test"
for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP}: continue
@ -197,14 +176,13 @@ class CPULLVMRenderer(LLVMRenderer):
has_threads = bool(getenv("THREADS", 1))
global_max = (CPU_COUNT.value, 0, 0)
abi = 'win64cc' if sys.platform == 'win32' else None
string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
string_rewrite = base_rewrite
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
def __init__(self, target:Target):
super().__init__(target)
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler
if "AMX" in target.arch: self.tensor_cores = tc.amx
self.compiler = CPULLVMCompiler([x for x in target.arch.split(",") if x != "AMX"])
self.compiler = CPULLVMCompiler(target.arch.split(","))
# FIXME: fp16 works on non-osx, but only if the cpu supports it
def supported_dtypes(self):

View file

@ -4,7 +4,7 @@ import ctypes, functools, hashlib
from tinygrad.runtime.autogen import opencl as cl
from tinygrad.runtime.support import c
from tinygrad.helpers import to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing, unwrap
from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError
from tinygrad.dtype import ImageDType
@ -117,14 +117,13 @@ class CLDevice(Compiled):
ctypes.byref(buf := ctypes.create_string_buffer(exts_len.value)), None),
ctypes.string_at(buf).decode().split())[1]
renderer = IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer
self.cl_compiler = CLCompiler(self, f"{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}")
arch = ",".join(self.device_exts)
if "cl_khr_image2d_from_buffer" in self.device_exts:
check(cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_IMAGE_PITCH_ALIGNMENT, 4, ctypes.byref(ipa := ctypes.c_uint32()), None))
arch += f",IMAGE_PITCH_ALIGNMENT={ipa.value}"
super().__init__(device, CLAllocator(self), [renderer], functools.partial(CLProgram, self), arch=arch)
super().__init__(device, CLAllocator(self), [OpenCLRenderer], functools.partial(CLProgram, self), arch=arch)
def count(self) -> int: return len(unwrap(self.device_ids))

View file

@ -191,18 +191,6 @@ class PythonProgram:
values[u] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
elif device == "INTEL":
# A (16 elements on 8 threads)
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
# B (16 elements on 8 threads)
def b_elem(x, col, k, goff): return x[k][goff+col]
# C, D (8 elements on 8 threads)
def c_map(lane, elem): return (lane, elem)
values[u] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif device == "CPU":
def elem(x, col, row, _): return x[col+row][0] # k is always 0
def c_map(lane, elem): return (elem%16, elem//16)
values[u] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {u.arg}")
elif u.op in GroupOp.ALU:
assert all_same([len(x) for x in src_values]), f"{[len(x) for x in src_values]} doesn't match on {u.op}"
@ -224,8 +212,6 @@ class PythonRenderer(Renderer):
{"AMD":"gfx1100", "AMD_RDNA4":"gfx1201", "AMD_MFMA":"gfx950", "CUDA":"sm_80", "CUDA_SM75":"sm_75", "CUDA_SM89":"sm_89"}.get(emu, emu))
target = replace(target, renderer="PYTHON")
if target.arch == "METAL": self.target, self.tensor_cores = replace(target, device="METAL"), tc.metal
elif target.arch == "INTEL": self.target, self.suffix, self.tensor_cores = replace(target, device="INTEL"), "INTEL", tc.intel
elif target.arch == "AMX": self.target, self.tensor_cores = replace(target, device="CPU"), tc.amx
elif target.arch.startswith("gfx"):
self.target = replace(target, device="AMD")
self.tensor_cores = tc.get_amd(target.arch)

View file

@ -69,8 +69,8 @@ class LLVMCompiler(Compiler):
super().__init__(cache_key or f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
def __del__(self):
llvm.LLVMDisposePassBuilderOptions(self.pbo)
llvm.LLVMContextDispose(self.context)
if hasattr(self, 'pbo'): llvm.LLVMDisposePassBuilderOptions(self.pbo)
if hasattr(self, 'context'): llvm.LLVMContextDispose(self.context)
def compile_to_obj(self, src:str) -> bytes:
self.diag_msgs.clear()