This commit is contained in:
George Hotz 2026-06-02 12:53:00 -07:00
commit e76de41110
2 changed files with 17 additions and 11 deletions

View file

@ -107,7 +107,9 @@ class WGSLRenderer(CStyleLanguage):
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_load(self, x:str, uop:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(uop) and uop.addrspace != AddrSpace.REG else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 and dt != dtypes.half else self.type_map[dt.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
def arg_dtype(u:UOp) -> DType:
return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
if not local_size: local_size = [1]
bind_it = iter(range(len(bufs)))
@ -117,8 +119,9 @@ class WGSLRenderer(CStyleLanguage):
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
f"{'var<storage,read_write>' if isinstance(dt, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(dt.base)}>' if isinstance(dt,PtrDType) else self.buf_map(dt)};"
for name,(u,_) in bufs for dt in (arg_dtype(u),)])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

View file

@ -54,18 +54,20 @@ class DSPRenderer(ClangRenderer):
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
if isinstance(b[1][0].dtype, PtrDType)]
if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
params = [(f'buf_{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)]
msrc += [f"{function_name}({', '.join(params)});"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
msrc += ["return 0; }"]
return '\n'.join(msrc)
@ -275,22 +277,23 @@ class MockDSPRenderer(DSPRenderer):
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
msrc = [mockdsp_boilerplate, 'void _start(void) {']
for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType):
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType):
sz = dt.size*dt.itemsize
# for loop for big reads
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();")
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
params = [(f'(void*)buf{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
msrc.append(f"{function_name}({', '.join(params)});")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs):
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType): msrc.append(f"write(1, buf{i}, {dt.size*dt.itemsize});")
msrc.append('exit(0); }')
return '\n'.join(msrc)