mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cl: image alignment in arch (#16106)
This commit is contained in:
parent
71a8c0da09
commit
105b037c3c
10 changed files with 24 additions and 18 deletions
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
|
@ -169,8 +169,8 @@ jobs:
|
|||
run: DEV=PYTHON python3 -m pytest -rA test/backend/test_renderer_failures.py::TestRendererFailures
|
||||
- name: Test IMAGE support
|
||||
run: |
|
||||
IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_gemm
|
||||
IMAGE=1 DEV=PYTHON python3 test/backend/test_ops.py TestOps.test_simple_conv2d
|
||||
IMAGE=1 DEV=PYTHON::IMAGE_PITCH_ALIGNMENT=64 python3 test/backend/test_ops.py TestOps.test_gemm
|
||||
IMAGE=1 DEV=PYTHON::IMAGE_PITCH_ALIGNMENT=64 python3 test/backend/test_ops.py TestOps.test_simple_conv2d
|
||||
- name: Test emulated METAL tensor cores
|
||||
run: |
|
||||
DEBUG=2 FORWARD_ONLY=1 DEV=PYTHON::METAL python3 test/backend/test_ops.py TestOps.test_big_gemm
|
||||
|
|
|
|||
|
|
@ -69,7 +69,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
|||
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
|
||||
|
||||
# create image buffers
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
|
||||
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON"}:
|
||||
sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch)
|
||||
|
||||
# devectorize (TODO: does this need opts?)
|
||||
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
|
||||
|
|
|
|||
|
|
@ -57,13 +57,13 @@ load_store_indexing = PatternMatcher([
|
|||
|
||||
# ***** load/store grouping *****
|
||||
|
||||
def expand_index(buf:UOp, vec:UOp):
|
||||
def expand_index(ctx, buf:UOp, vec:UOp):
|
||||
# determine optimal image shapes
|
||||
if isinstance(dt:=buf.dtype, ImageDType):
|
||||
x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0)
|
||||
# search for dims that drop the most valid statements
|
||||
best_drop, cands = -1, []
|
||||
for ch, cw in ImageDType.valid_dims(dt):
|
||||
for ch, cw in ImageDType.valid_dims(dt, ctx.target.arch):
|
||||
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
|
||||
best_drop, cands = dropped, [(ch, cw, cidx)]
|
||||
elif dropped == best_drop: cands.append((ch, cw, cidx))
|
||||
|
|
@ -366,9 +366,9 @@ pm_imageh_store = PatternMatcher([
|
|||
(UPat(GroupOp.All, name="x"), lambda x: x.cast(dtypes.float))
|
||||
])
|
||||
|
||||
def make_image(ls, buf, off):
|
||||
def make_image(ctx, ls, buf, off):
|
||||
if (vcount:=buf.dtype.vcount) != 1: buf = buf.src[0]
|
||||
if buf.op == Ops.PARAM and not isinstance(dt:=buf.dtype, ImageDType) and (dims:=ImageDType.valid_dims(dt)):
|
||||
if buf.op == Ops.PARAM and not isinstance(dt:=buf.dtype, ImageDType) and (dims:=ImageDType.valid_dims(dt, ctx)):
|
||||
buf = buf.replace(dtype=(dtypes.imageh if dt.base == dtypes.half else dtypes.imagef)((*dims[0], 4)))
|
||||
if vcount != 1: buf = UOp.vectorize(*([buf] * vcount))
|
||||
if ls.op is Ops.LOAD: return ls.replace(src=(buf.index(off, ptr=True),), dtype=dtypes.float.vec(ls.dtype.vcount)).cast(dt.base)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ def hand_coded_optimizations(k:Scheduler) -> Scheduler:
|
|||
# upcast float4 images, this must be early so we don't accidentally add locals before the upcast
|
||||
if IMAGE:
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
if isinstance(buf.src[0].dtype, PtrDType) and ImageDType.valid_dims(buf.src[0].dtype):
|
||||
if isinstance(buf.src[0].dtype, PtrDType) and ImageDType.valid_dims(buf.src[0].dtype, k.ren.target.arch):
|
||||
# part of is_expanded
|
||||
unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
|
||||
c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
|
||||
|
|
|
|||
|
|
@ -138,11 +138,12 @@ class ImageDType(PtrDType):
|
|||
|
||||
# get list of (height, width) that do not require pitch padding
|
||||
@staticmethod
|
||||
def valid_dims(ptr:PtrDType) -> list[tuple[int,int]]:
|
||||
ALIGN, MAXW, pxls = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64), 16384, ptr.size // 4
|
||||
def valid_dims(ptr:PtrDType, arch:str) -> list[tuple[int,int]]:
|
||||
if (ALIGN:=next((int(p.split('=')[1]) for p in arch.split(',') if p.startswith("IMAGE_PITCH_ALIGNMENT=")), 0)) == 0: return []
|
||||
MAXW, pxls = 16384, ptr.size // 4
|
||||
if ptr.base not in (dtypes.half, dtypes.float) or ptr.size > 4*MAXW*MAXW: return []
|
||||
# height=1 images just need to abide by alignment requirements in bytes, not pixels!
|
||||
if ptr.size % (ALIGN * 4) != 0: return [] if ptr.nbytes() % getenv("IMAGE_BASE_ALIGN", 64) != 0 or pxls > MAXW else [(1, pxls)]
|
||||
if ptr.size % (ALIGN * 4) != 0: return [] if ptr.nbytes() % (64 if OSX else ALIGN) != 0 or pxls > MAXW else [(1, pxls)]
|
||||
return [(pxls//ALIGN//k, ALIGN*k) for k in range(ceildiv(pxls//ALIGN, MAXW), min(pxls//ALIGN, MAXW//ALIGN)+1) if (pxls//ALIGN)%k == 0]
|
||||
|
||||
class dtypes:
|
||||
|
|
|
|||
|
|
@ -119,7 +119,12 @@ class CLDevice(Compiled):
|
|||
|
||||
renderer = IntelRenderer if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts else OpenCLRenderer
|
||||
self.cl_compiler = CLCompiler(self, f"{hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()}")
|
||||
super().__init__(device, CLAllocator(self), [renderer], functools.partial(CLProgram, self))
|
||||
|
||||
if "cl_khr_image2d_from_buffer" in self.device_exts:
|
||||
check(cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_IMAGE_PITCH_ALIGNMENT, 4, ctypes.byref(ipa := ctypes.c_uint32()), None))
|
||||
arch = f"IMAGE_PITCH_ALIGNMENT={ipa.value}"
|
||||
else: arch = ""
|
||||
super().__init__(device, CLAllocator(self), [renderer], functools.partial(CLProgram, self), arch=arch)
|
||||
|
||||
def count(self) -> int: return len(unwrap(self.device_ids))
|
||||
|
||||
|
|
|
|||
|
|
@ -220,8 +220,7 @@ class PythonRenderer(Renderer):
|
|||
elif target.arch.startswith("sm"):
|
||||
self.target = replace(target, device="CUDA")
|
||||
self.tensor_cores = tc.get_cuda(target.arch)
|
||||
elif target.arch == "": self.target = target
|
||||
else: raise RuntimeError(f"unsupported arch: {target.arch}")
|
||||
else: self.target = target
|
||||
|
||||
def render(self, uops:list[UOp]) -> str:
|
||||
# the value of SPECIAL comes from local/global_size, not from its source
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from tinygrad.runtime.autogen import kgsl, mesa
|
|||
from tinygrad.renderer.cstyle import QCOMCLRenderer
|
||||
from tinygrad.renderer.nir import IR3Renderer
|
||||
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, ceildiv, prod, cpu_profile, lo32, suppress_finalizing
|
||||
from tinygrad.helpers import next_power2, flatten, PROFILE
|
||||
from tinygrad.helpers import next_power2, flatten, PROFILE, IMAGE
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.runtime.support.system import System
|
||||
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
|
@ -371,7 +371,7 @@ class QCOMDevice(HCQCompiled):
|
|||
System.write_sysfs("/sys/class/kgsl/kgsl-3d0/idle_timer", value="4000000000", msg="Failed to disable suspend mode", expected="4294967276")
|
||||
|
||||
super().__init__(device, QCOMAllocator(self), [QCOMCLRenderer, IR3Renderer], functools.partial(QCOMProgram, self), QCOMSignal,
|
||||
functools.partial(QCOMComputeQueue, self), arch="a%d%d%d" % self.gpu_id)
|
||||
functools.partial(QCOMComputeQueue, self), arch=("a%d%d%d" + (",IMAGE_PITCH_ALIGNMENT=64" if IMAGE else "")) % self.gpu_id)
|
||||
|
||||
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
|
||||
flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def disas_adreno(lib:bytes, gpu_id=630):
|
|||
|
||||
class IR3Compiler(Compiler):
|
||||
def __init__(self, arch):
|
||||
assert arch == "a630", "only a630 supported, for now"
|
||||
assert arch.split(',')[0] == "a630", "only a630 supported, for now"
|
||||
self.arch, self.dev_id = arch, mesa.struct_fd_dev_id(630, 0x6030001)
|
||||
self.cc = mesa.ir3_compiler_create(None, self.dev_id, mesa.fd_dev_info(self.dev_id),
|
||||
mesa.struct_ir3_compiler_options(disable_cache=True)).contents
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def _read_lib(lib, off) -> int: return struct.unpack("I", lib[off:off+4])[0]
|
|||
|
||||
class QCOMCompiler(Compiler):
|
||||
def __init__(self, arch:str):
|
||||
assert arch == "a630", "only a630 supported"
|
||||
assert arch.split(',')[0] == "a630", "only a630 supported"
|
||||
self.arch, self.chip_id, self.llvm_inst = arch, 0x6030001, llvm_qcom.cl_compiler_create_llvm_instance()
|
||||
super().__init__(f"compile_qcomcl_{arch}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue