clean renderer migrations [pr] (#16472)

* clean renderer migrations

* minor webgpu

* use PARAM UOp as API

* make linter happy
This commit is contained in:
George Hotz 2026-06-02 11:19:00 -07:00 committed by GitHub
commit 82f1c983d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 41 additions and 38 deletions

View file

@ -115,7 +115,7 @@ pm_linearize_cleanups = PatternMatcher([
# if statements are not allowed in the graph
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
(UPat(Ops.STORE, name="u", src=(UPat((Ops.INDEX, Ops.SHRINK)).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
])

View file

@ -42,7 +42,7 @@ class Estimates:
if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= cast(sint, u.src[0].ssimplify())

View file

@ -131,12 +131,12 @@ class CStyleLanguage(Renderer):
string_rewrite = base_rewrite
extra_matcher = extra_pm
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
tmp = ""
if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs):
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
launch_bounds = prod([d.vmax for d in local_dims])
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
@ -156,14 +156,14 @@ class CStyleLanguage(Renderer):
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
def __getitem__(self, key): return self.r[key] # hacky helper
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[UOp,bool]]]]:
r: dict[UOp, str] = {}
self.r = r
child_count = Counter(v for ru in uops for v in ru.src)
# find which PARAMs are stored to with a single toposort
writable_params = {u for u in UOp.sink(*[u.src[0] for u in uops if u.op is Ops.STORE]).toposort(lambda u: u.op != Ops.END) if u.op is Ops.PARAM}
bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
bufs: dict[UOp, tuple[str, tuple[UOp, bool]]] = {}
kernel = []
depth = 1
c: defaultdict[str, int] = defaultdict(int)
@ -180,7 +180,7 @@ class CStyleLanguage(Renderer):
if u.op is not Ops.PARAM: r[u] = u.arg[0]
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
bufs[u] = (r[u], (u.dtype, u in writable_params))
bufs[u] = (r[u], (u, u in writable_params))
continue
# naming
@ -266,7 +266,7 @@ class ClangRenderer(CStyleLanguage):
AMX_SET(1);\n return data0;\n}}"""]
return prefix
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str: return ""
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
defines = '\n'.join(self._render_defines(uops))
@ -303,12 +303,11 @@ class OpenCLRenderer(CStyleLanguage):
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
(UPat(Ops.LOAD, dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx_y,idx_x,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
(UPat(Ops.LOAD, dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
lambda ctx,buf,idx_y,idx_x: f"read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]}))"),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),
UPat.var("var", dtypes.float.vec(4)))),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var", dtypes.float))),
lambda ctx,buf,idx_y,idx_x,var: f"write_imagef({ctx[buf]}, (int2)({ctx[idx_x]},{ctx[idx_y]}), {ctx[var]});"),
]) + base_rewrite

View file

@ -135,13 +135,13 @@ class LLVMRenderer(Renderer):
code_for_op = {k:lambda:None for v in lop.values() for k in v.keys()}
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
def _render_fn(self, name:str, args:list[tuple[str,UOp]], kernel:list[str], prefix:list[str]|None=None) -> str:
# NOTE: CPUAllocator promises 0x20 alignment
sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
sargs = ", ".join([f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {name}" for name,u in args])
return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
r: dict[UOp, str] = {}
args: list[tuple[str, DType]] = []
args: list[tuple[str, UOp]] = []
kernel: list[str] = []
vc = -1
@ -165,17 +165,17 @@ class LLVMRenderer(Renderer):
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}"
args.append((r[u], u.dtype))
args.append((r[u], u))
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
assert isinstance(u.dtype, PtrDType)
size = u.max_numel()
if u.op is Ops.DEFINE_REG:
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype.base)}]")
elif self.has_local:
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*")
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{size} x {ldt(u.dtype)}] undef, align 16")
kernel.append(f" {r[u]} = addrspacecast [{size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{size} x {ldt(u.dtype)}]*")
else:
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}], align 16")
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype.base)}], align 16")
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop, or pointer cast

View file

@ -29,7 +29,9 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
def is_packed(dt:DType, odt:DType|None = None) -> bool:
if odt is None: odt = dt
return dt.itemsize < 4 and dt.base != dtypes.half and (not isinstance(odt, PtrDType) or odt.addrspace != AddrSpace.REG)
# registers aren't packed
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
return dt.itemsize < 4 and dt.base != dtypes.half
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
def is_nan(a):
@ -98,7 +100,7 @@ class WGSLRenderer(CStyleLanguage):
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
if not local_size: local_size = [1]
bind_it = iter(range(len(bufs)))
@ -108,8 +110,8 @@ class WGSLRenderer(CStyleLanguage):
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, system, DEBUG, suppress_finalizing, Target
from tinygrad.renderer.cstyle import ClangRenderer
@ -53,18 +53,19 @@ class DSPRenderer(ClangRenderer):
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
if isinstance(b[1][0].dtype, PtrDType)]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += ["return 0; }"]
return '\n'.join(msrc)
@ -273,22 +274,23 @@ return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}'''
class MockDSPRenderer(DSPRenderer):
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
msrc = [mockdsp_boilerplate, 'void _start(void) {']
for i,b in enumerate(bufs):
if isinstance(b[1][0], PtrDType):
sz = b[1][0].size*b[1][0].itemsize
if isinstance(b[1][0].dtype, PtrDType):
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
# for loop for big reads
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();")
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
msrc.append(f"{function_name}({', '.join(params)});")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs):
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
msrc.append('exit(0); }')
return '\n'.join(msrc)