metal: arch is GPU family (#16223)

This commit is contained in:
Christopher Milan 2026-05-15 18:22:48 -07:00 committed by GitHub
commit 79c0ae5b89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 6 additions and 15 deletions

View file

@ -59,8 +59,3 @@ kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgr
self.assertEqual(curr:=device.sysdevice.currentAllocatedSize(), before+size, msg=f"{curr=} - {before=}")
device.allocator.free(buf, buf.size, BufferSpec(nolru=True))
self.assertEqual(curr:=device.sysdevice.currentAllocatedSize(), before, msg=f"{curr=} - {before=}")
def test_gpu_family(self):
device = Device['METAL']
self.assertGreater(device.gpu_family, 0)
self.assertLessEqual(device.gpu_family, 15)

View file

@ -342,7 +342,7 @@ def is_dtype_supported(dtype:DType, target:Target|None=None) -> bool:
target = target or DEV.target(Device.DEFAULT)
if dtype == dtypes.bfloat16:
match target.device:
case "METAL": return not CI or BENCHMARKS
case "METAL": target.arch.startswith("Apple") and int(target.arch[5:]) >= 6
case "CUDA": return (not CI or BENCHMARKS) and target.renderer != "PTX"
case "NV": return (not CI or BENCHMARKS) and target.renderer not in ("PTX", "NAK")
case "CPU": return (not CI or BENCHMARKS) and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and target.renderer != "LVP"

View file

@ -343,7 +343,7 @@ class MetalRenderer(CStyleLanguage):
def __init__(self, target:Target):
super().__init__(target)
from tinygrad.runtime.ops_metal import MetalCompiler
self.compiler, self.tensor_cores = MetalCompiler(), tc.metal if target.arch == "arm64" else []
self.compiler, self.tensor_cores = MetalCompiler(), tc.metal if target.arch.startswith("Apple") and int(target.arch[5:]) >= 7 else []
# language options
kernel_typedef = "kernel void"

View file

@ -23,7 +23,7 @@ class MetalGraph(GraphRunner):
self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(self.calls),
metal.MTLResourceCPUCacheModeDefaultCache)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
self.needs_icb_fix = int(self.dev.gpu_family < 9) # ICB fix not required on M3+ (Apple9+)
self.needs_icb_fix = int(not self.dev.arch.startswith("Apple") or int(self.dev.arch[5:]) < 9) # ICB fix not required on M3+ (Apple9+)
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)

View file

@ -37,12 +37,8 @@ class MetalDevice(Compiled):
self.timeline_signal = self.sysdevice.newSharedEvent()
self.timeline_value = 0
# probe GPU family: Apple9=M3/M4, Apple8=M2, Apple7=M1, etc. values are 1000+N.
self.gpu_family = 0
for i in range(15, 0, -1):
if self.sysdevice.supportsFamily(1000 + i):
self.gpu_family = i
break
# https://developer.apple.com/documentation/metal/mtlgpufamily
def check_family(f): return next(filter(self.sysdevice.supportsFamily, reversed([v for v, nm in metal.enum_MTLGPUFamily.items() if f in nm])), 0)
Compiled.profile_events += [ProfileDeviceEvent(device)]
@ -51,7 +47,7 @@ class MetalDevice(Compiled):
# This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework.
super().__init__(device, MetalAllocator(self), [MetalRenderer],
functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(self.sysdevice.name()).lower() else None,
arch=platform.machine())
arch=metal.enum_MTLGPUFamily[check_family("Apple") or check_family("Mac")][12:])
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: