cl: image alignment in arch (#16106)

This commit is contained in:
Christopher Milan 2026-05-08 16:33:33 -07:00 committed by GitHub
commit 105b037c3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 24 additions and 18 deletions

View file

@ -169,8 +169,8 @@ jobs:
run: DEV=PYTHON python3 -m pytest -rA test/backend/test_renderer_failures.py::TestRendererFailures
- name: Test IMAGE support
run: |
IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm
IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_simple_conv2d
IMAGE=1 DEV=PYTHON::IMAGE_PITCH_ALIGNMENT=64 python3 test/backend/test_ops.py TestOps.test_gemm
IMAGE=1 DEV=PYTHON::IMAGE_PITCH_ALIGNMENT=64 python3 test/backend/test_ops.py TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores
run: |
DEBUG=2 FORWARD_ONLY=1 DEV=PYTHON::METAL python3 test/backend/test_ops.py TestOps.test_big_gemm

View file

@ -69,7 +69,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
# create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}:
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
# devectorize (TODO: does this need opts?)
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing

View file

@ -57,13 +57,13 @@ load_store_indexing = PatternMatcher([
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp):
def expand_index(ctx, buf:UOp, vec:UOp):
# determine optimal image shapes
if isinstance(dt:=buf.dtype, ImageDType):
x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0)
# search for dims that drop the most valid statements
best_drop, cands = -1, []
for ch, cw in ImageDType.valid_dims(dt):
for ch, cw in ImageDType.valid_dims(dt, ctx.target.arch):
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
best_drop, cands = dropped, [(ch, cw, cidx)]
elif dropped == best_drop: cands.append((ch, cw, cidx))
@ -366,9 +366,9 @@ pm_imageh_store = PatternMatcher([
(UPat(GroupOp.All, name="x"), lambda x: x.cast(dtypes.float))
])
def make_image(ls, buf, off):
def make_image(ctx, ls, buf, off):
if (vcount:=buf.dtype.vcount) != 1: buf = buf.src[0]
if buf.op == Ops.PARAM and not isinstance(dt:=buf.dtype, ImageDType) and (dims:=ImageDType.valid_dims(dt)):
if buf.op == Ops.PARAM and not isinstance(dt:=buf.dtype, ImageDType) and (dims:=ImageDType.valid_dims(dt, ctx)):
buf = buf.replace(dtype=(dtypes.imageh if dt.base == dtypes.half else dtypes.imagef)((*dims[0], 4)))
if vcount != 1: buf = UOp.vectorize(*([buf] * vcount))
if ls.op is Ops.LOAD: return ls.replace(src=(buf.index(off, ptr=True),), dtype=dtypes.float.vec(ls.dtype.vcount)).cast(dt.base)

View file

@ -51,7 +51,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
# upcast float4 images, this must be early so we don't accidentally add locals before the upcast
if IMAGE:
for buf_index,buf in enumerate(k.bufs):
if isinstance(buf.src[0].dtype, PtrDType) and ImageDType.valid_dims(buf.src[0].dtype):
if isinstance(buf.src[0].dtype, PtrDType) and ImageDType.valid_dims(buf.src[0].dtype, k.ren.target.arch):
# part of is_expanded
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]

View file

@ -138,11 +138,12 @@ class ImageDType(PtrDType):
# get list of (height, width) that do not require pitch padding
@staticmethod
def valid_dims(ptr:PtrDType) -> list[tuple[int,int]]:
ALIGN, MAXW, pxls = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64), 16384, ptr.size // 4
def valid_dims(ptr:PtrDType, arch:str) -> list[tuple[int,int]]:
if (ALIGN:=next((int(p.split('=')[1]) for p in arch.split(',') if p.startswith("IMAGE_PITCH_ALIGNMENT=")), 0)) == 0: return []
MAXW, pxls = 16384, ptr.size // 4
if ptr.base not in (dtypes.half, dtypes.float) or ptr.size > 4*MAXW*MAXW: return []
# height=1 images just need to abide by alignment requirements in bytes, not pixels!
if ptr.size % (ALIGN * 4) != 0: return [] if ptr.nbytes() % getenv("IMAGE_BASE_ALIGN", 64) != 0 or pxls > MAXW else [(1, pxls)]
if ptr.size % (ALIGN * 4) != 0: return [] if ptr.nbytes() % (64 if OSX else ALIGN) != 0 or pxls > MAXW else [(1, pxls)]
return [(pxls//ALIGN//k, ALIGN*k) for k in range(ceildiv(pxls//ALIGN, MAXW), min(pxls//ALIGN, MAXW//ALIGN)+1) if (pxls//ALIGN)%k == 0]
class dtypes:

View file

@ -119,7 +119,12 @@ class CLDevice(Compiled):
renderer = IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer
self.cl_compiler = CLCompiler(self, f"{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}")
super().__init__(device, CLAllocator(self), [renderer], functools.partial(CLProgram, self))
if "cl_khr_image2d_from_buffer" in self.device_exts:
check(cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_IMAGE_PITCH_ALIGNMENT, 4, ctypes.byref(ipa := ctypes.c_uint32()), None))
arch = f"IMAGE_PITCH_ALIGNMENT={ipa.value}"
else: arch = ""
super().__init__(device, CLAllocator(self), [renderer], functools.partial(CLProgram, self), arch=arch)
def count(self) -> int: return len(unwrap(self.device_ids))

View file

@ -220,8 +220,7 @@ class PythonRenderer(Renderer):
elif target.arch.startswith("sm"):
self.target = replace(target, device="CUDA")
self.tensor_cores = tc.get_cuda(target.arch)
elif target.arch == "": self.target = target
else: raise RuntimeError(f"unsupported arch: {target.arch}")
else: self.target = target
def render(self, uops:list[UOp]) -> str:
# the value of SPECIAL comes from local/global_size, not from its source

View file

@ -9,7 +9,7 @@ from tinygrad.runtime.autogen import kgsl, mesa
from tinygrad.renderer.cstyle import QCOMCLRenderer
from tinygrad.renderer.nir import IR3Renderer
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, ceildiv, prod, cpu_profile, lo32, suppress_finalizing
from tinygrad.helpers import next_power2, flatten, PROFILE
from tinygrad.helpers import next_power2, flatten, PROFILE, IMAGE
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.runtime.support.system import System
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
@ -371,7 +371,7 @@ class QCOMDevice(HCQCompiled):
System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276")
super().__init__(device, QCOMAllocator(self), [QCOMCLRenderer, IR3Renderer], functools.partial(QCOMProgram, self), QCOMSignal,
functools.partial(QCOMComputeQueue, self), arch="a%d%d%d" % self.gpu_id)
functools.partial(QCOMComputeQueue, self), arch=("a%d%d%d" + (",IMAGE_PITCH_ALIGNMENT=64" if IMAGE else "")) % self.gpu_id)
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP

View file

@ -93,7 +93,7 @@ def disas_adreno(lib:bytes, gpu_id=630):
class IR3Compiler(Compiler):
def __init__(self, arch):
assert arch == "a630", "only a630 supported, for now"
assert arch.split(',')[0] == "a630", "only a630 supported, for now"
self.arch, self.dev_id = arch, mesa.struct_fd_dev_id(630, 0x6030001)
self.cc = mesa.ir3_compiler_create(None, self.dev_id, mesa.fd_dev_info(self.dev_id),
mesa.struct_ir3_compiler_options(disable_cache=True)).contents

View file

@ -9,7 +9,7 @@ def _read_lib(lib, off) -> int: return struct.unpack("I", lib[off:off+4])[0]
class QCOMCompiler(Compiler):
def __init__(self, arch:str):
assert arch == "a630", "only a630 supported"
assert arch.split(',')[0] == "a630", "only a630 supported"
self.arch, self.chip_id, self.llvm_inst = arch, 0x6030001, llvm_qcom.cl_compiler_create_llvm_instance()
super().__init__(f"compile_qcomcl_{arch}")