mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
clean renderer migrations [pr] (#16472)
* clean renderer migrations * minor webgpu * use PARAM UOp as API * make linter happy
This commit is contained in:
parent
9897658895
commit
82f1c983d4
6 changed files with 41 additions and 38 deletions
|
|
@ -115,7 +115,7 @@ pm_linearize_cleanups = PatternMatcher([
|
|||
# if statements are not allowed in the graph
|
||||
(UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")),
|
||||
# gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF
|
||||
(UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
(UPat(Ops.STORE, name="u", src=(UPat((Ops.INDEX, Ops.SHRINK)).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))),
|
||||
lambda u, gate: ((st:=u.replace(src=u.src[0:2])), [mif:=UOp(Ops.IF, src=(gate, u.src[0])), st, UOp(Ops.ENDIF, src=(mif,))]))
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class Estimates:
|
|||
if buf.op is Ops.PARAM:
|
||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
|
||||
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= cast(sint, u.src[0].ssimplify())
|
||||
|
|
|
|||
|
|
@ -131,12 +131,12 @@ class CStyleLanguage(Renderer):
|
|||
string_rewrite = base_rewrite
|
||||
extra_matcher = extra_pm
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
tmp = ""
|
||||
if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs):
|
||||
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
|
||||
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
|
||||
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
launch_bounds = prod([d.vmax for d in local_dims])
|
||||
prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
|
||||
|
|
@ -156,14 +156,14 @@ class CStyleLanguage(Renderer):
|
|||
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
|
||||
|
||||
def __getitem__(self, key): return self.r[key] # hacky helper
|
||||
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
|
||||
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[UOp,bool]]]]:
|
||||
r: dict[UOp, str] = {}
|
||||
self.r = r
|
||||
|
||||
child_count = Counter(v for ru in uops for v in ru.src)
|
||||
# find which PARAMs are stored to with a single toposort
|
||||
writable_params = {u for u in UOp.sink(*[u.src[0] for u in uops if u.op is Ops.STORE]).toposort(lambda u: u.op != Ops.END) if u.op is Ops.PARAM}
|
||||
bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
|
||||
bufs: dict[UOp, tuple[str, tuple[UOp, bool]]] = {}
|
||||
kernel = []
|
||||
depth = 1
|
||||
c: defaultdict[str, int] = defaultdict(int)
|
||||
|
|
@ -180,7 +180,7 @@ class CStyleLanguage(Renderer):
|
|||
if u.op is not Ops.PARAM: r[u] = u.arg[0]
|
||||
elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg.slot}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
|
||||
else: r[u] = f"data{u.arg.slot}_{sz}" if (sz:=u.max_numel()) > 0 else f"data{u.arg.slot}"
|
||||
bufs[u] = (r[u], (u.dtype, u in writable_params))
|
||||
bufs[u] = (r[u], (u, u in writable_params))
|
||||
continue
|
||||
|
||||
# naming
|
||||
|
|
@ -266,7 +266,7 @@ class ClangRenderer(CStyleLanguage):
|
|||
AMX_SET(1);\n return data0;\n}}"""]
|
||||
return prefix
|
||||
def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str: return ""
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
defines = '\n'.join(self._render_defines(uops))
|
||||
|
|
@ -303,12 +303,11 @@ class OpenCLRenderer(CStyleLanguage):
|
|||
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda ctx,buf,idx_y,idx_x,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})"),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
|
||||
(UPat(Ops.LOAD, dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
|
||||
lambda ctx,buf,idx_y,idx_x: f"read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]}))"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),
|
||||
UPat.var("var", dtypes.float.vec(4)))),
|
||||
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var", dtypes.float))),
|
||||
lambda ctx,buf,idx_y,idx_x,var: f"write_imagef({ctx[buf]}, (int2)({ctx[idx_x]},{ctx[idx_y]}), {ctx[var]});"),
|
||||
]) + base_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -135,13 +135,13 @@ class LLVMRenderer(Renderer):
|
|||
code_for_op = {k:lambda:None for v in lop.values() for k in v.keys()}
|
||||
|
||||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast
|
||||
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
|
||||
def _render_fn(self, name:str, args:list[tuple[str,UOp]], kernel:list[str], prefix:list[str]|None=None) -> str:
|
||||
# NOTE: CPUAllocator promises 0x20 alignment
|
||||
sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
|
||||
sargs = ", ".join([f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {name}" for name,u in args])
|
||||
return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
|
||||
def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
|
||||
r: dict[UOp, str] = {}
|
||||
args: list[tuple[str, DType]] = []
|
||||
args: list[tuple[str, UOp]] = []
|
||||
kernel: list[str] = []
|
||||
vc = -1
|
||||
|
||||
|
|
@ -165,17 +165,17 @@ class LLVMRenderer(Renderer):
|
|||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg.slot}" if u.op is Ops.PARAM else f"%{u.expr}"
|
||||
args.append((r[u], u.dtype))
|
||||
args.append((r[u], u))
|
||||
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
|
||||
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
|
||||
assert isinstance(u.dtype, PtrDType)
|
||||
size = u.max_numel()
|
||||
if u.op is Ops.DEFINE_REG:
|
||||
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
|
||||
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype.base)}]")
|
||||
elif self.has_local:
|
||||
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
|
||||
kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*")
|
||||
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{size} x {ldt(u.dtype)}] undef, align 16")
|
||||
kernel.append(f" {r[u]} = addrspacecast [{size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{size} x {ldt(u.dtype)}]*")
|
||||
else:
|
||||
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}], align 16")
|
||||
kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype.base)}], align 16")
|
||||
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
|
||||
elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
|
||||
r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop, or pointer cast
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
|
|||
|
||||
def is_packed(dt:DType, odt:DType|None = None) -> bool:
|
||||
if odt is None: odt = dt
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half and (not isinstance(odt, PtrDType) or odt.addrspace != AddrSpace.REG)
|
||||
# registers aren't packed
|
||||
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half
|
||||
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
|
||||
|
||||
def is_nan(a):
|
||||
|
|
@ -98,7 +100,7 @@ class WGSLRenderer(CStyleLanguage):
|
|||
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
||||
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
|
||||
if not local_size: local_size = [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
|
|
@ -108,8 +110,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
|
||||
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
|
||||
assert sys.platform != 'win32'
|
||||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import getenv, round_up, mv_address, to_mv, cpu_objdump, system, DEBUG, suppress_finalizing, Target
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
|
|
@ -53,18 +53,19 @@ class DSPRenderer(ClangRenderer):
|
|||
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
|
||||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
|
||||
if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -273,22 +274,23 @@ return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}'''
|
|||
class MockDSPRenderer(DSPRenderer):
|
||||
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
|
||||
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
|
||||
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
|
||||
msrc = [mockdsp_boilerplate, 'void _start(void) {']
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType):
|
||||
sz = b[1][0].size*b[1][0].itemsize
|
||||
if isinstance(b[1][0].dtype, PtrDType):
|
||||
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
|
||||
# for loop for big reads
|
||||
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
|
||||
else:
|
||||
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
|
||||
msrc.append("unsigned int st = inscount();")
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
|
||||
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
|
||||
msrc.append(f"{function_name}({', '.join(params)});")
|
||||
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")
|
||||
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
|
||||
msrc.append('exit(0); }')
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue