mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fixes
This commit is contained in:
parent
5768042e3f
commit
e76de41110
2 changed files with 17 additions and 11 deletions
|
|
@ -107,7 +107,9 @@ class WGSLRenderer(CStyleLanguage):
|
|||
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
||||
def render_load(self, x:str, uop:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(uop) and uop.addrspace != AddrSpace.REG else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt.itemsize < 4 and dt != dtypes.half else self.type_map[dt.base]
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
def arg_dtype(u:UOp) -> DType:
|
||||
return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
|
||||
if not local_size: local_size = [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
|
|
@ -117,8 +119,9 @@ class WGSLRenderer(CStyleLanguage):
|
|||
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
f"{'var<storage,read_write>' if isinstance(dt, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(dt.base)}>' if isinstance(dt,PtrDType) else self.buf_map(dt)};"
|
||||
for name,(u,_) in bufs for dt in (arg_dtype(u),)])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
|
|
|||
|
|
@ -54,18 +54,20 @@ class DSPRenderer(ClangRenderer):
|
|||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs)
|
||||
if isinstance(b[1][0].dtype, PtrDType)]
|
||||
if isinstance(arg_dtype(b[1][0]), PtrDType)]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0].dtype, PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
params = [(f'buf_{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)]
|
||||
msrc += [f"{function_name}({', '.join(params)});"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0].dtype, PtrDType)]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(arg_dtype(b[1][0]), PtrDType)]
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -275,22 +277,23 @@ class MockDSPRenderer(DSPRenderer):
|
|||
def __init__(self, target:Target): self.target, self.compiler = target, DSPCompiler(mock=True)
|
||||
def _render_defines(self, uops) -> list[str]: return ClangRenderer._render_defines(self, uops)
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[UOp,bool]]]) -> str:
|
||||
def arg_dtype(u:UOp): return u.dtype if isinstance(u.dtype, PtrDType) or u.op is not Ops.PARAM else u.dtype.ptr(u.max_numel(), u.addrspace)
|
||||
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
|
||||
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
|
||||
msrc = [mockdsp_boilerplate, 'void _start(void) {']
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0].dtype, PtrDType):
|
||||
sz = b[1][0].dtype.size*b[1][0].dtype.itemsize
|
||||
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType):
|
||||
sz = dt.size*dt.itemsize
|
||||
# for loop for big reads
|
||||
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); for(int rd = 0; rd < {sz}; rd += read(0, buf{i}+rd, {sz}-rd));")
|
||||
else:
|
||||
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
|
||||
msrc.append("unsigned int st = inscount();")
|
||||
params = [(f'(void*)buf{i}' if isinstance(b[1][0].dtype, PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
|
||||
params = [(f'(void*)buf{i}' if isinstance(arg_dtype(b[1][0]), PtrDType) else f'val{i}') for i,b in enumerate(bufs)]
|
||||
msrc.append(f"{function_name}({', '.join(params)});")
|
||||
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0].dtype, PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].dtype.size*b[1][0].dtype.itemsize});")
|
||||
if isinstance(dt:=arg_dtype(b[1][0]), PtrDType): msrc.append(f"write(1, buf{i}, {dt.size*dt.itemsize});")
|
||||
msrc.append('exit(0); }')
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue