dsp multicore 2 (#9644)

* dsp multicore 2

* hmm

* better
This commit is contained in:
nimlgen 2025-03-31 22:56:54 +07:00 committed by GitHub
commit bb7b89475c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 90 additions and 29 deletions

View file

@ -1,12 +1,34 @@
import pickle, sys
from dataclasses import replace
from tinygrad import Device, Context
from tinygrad import Device, Context, Tensor
from tinygrad.device import Buffer
from tinygrad.helpers import getenv, BEAM
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
import numpy as np
def move_jit_captured_to_dev(captured, device="DSP"):
captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]
assign = {}
def move_buffer(b):
if b in assign: return assign[b]
if b._base is not None:
newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
else:
newbuf = Buffer(device, b.size, b.dtype)
if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
assign[b] = newbuf
return assign[b]
for item in captured.jit_cache:
for b in item.bufs:
if b is not None: move_buffer(b)
captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
return captured
if __name__ == "__main__":
with Context(DEBUG=0):
@ -15,6 +37,10 @@ if __name__ == "__main__":
print(f"{f.tell()/1e6:.2f}M loaded")
print(type(fxn))
# Move all buffers to DSP device.
fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
new_jit = []
knum = 1
for ei in fxn.captured.jit_cache:
# skip the copy and the first kernel
@ -22,14 +48,18 @@ if __name__ == "__main__":
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
p: ProgramSpec = ei.prg.p
k = Kernel(p.ast, Device["DSP"].renderer)
dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
k.hand_coded_optimizations()
#if knum == 13: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
p2 = k.to_program()
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
new_ei = replace(ei, prg=CompiledRunner(p2))
if getenv("MULTICORE", 0) == 1:
new_ei.run({p2.vars[0]: 0})
new_ei.run({p2.vars[0]: 1})
new_ei.run()
else:
new_ei.run()
new_jit.append(new_ei)
knum += 1
if getenv("RUN_JIT", 0):
fxn.captured.free_intermediates()
fxn.captured.jit_cache = new_jit
fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))

View file

@ -168,8 +168,8 @@ class CapturedJit(Generic[ReturnType]):
update_depends(depends, self.jit_cache)
for b in depends:
if b is not None:
b.deallocate()
if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
if b.is_allocated(): b.deallocate()
if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
self.__post_init__() # reset the graph state
def optimize_weights(self):

View file

@ -111,9 +111,9 @@ class ProgramSpec:
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0][0] == 'i': self.local_size = None
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
assert special_size is not None
special_size[int(u.arg[0][-1])] = u.arg[1]
# special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
# assert special_size is not None
# special_size[int(u.arg[0][-1])] = u.arg[1]
self.vars = sorted(self.vars, key=lambda v: v.arg)
self.outs = sorted(dedup(self.outs))
self.ins = sorted(dedup(self.ins))

View file

@ -1,6 +1,6 @@
from typing import cast
import itertools
from tinygrad.helpers import dedup, DEBUG, to_function_name
from tinygrad.helpers import dedup, DEBUG, to_function_name, getenv
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
@ -24,17 +24,18 @@ class CPUGraph(GraphRunner):
if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
return f"({device.renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+") {"]
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+", int gl0, void* sync) {"]
for i, ji in enumerate(jit_cache):
args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});")
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)}, gl0, 0x0);")
if getenv("MULTICORE", 0) != 0: batched.append(f" qurt_barrier_wait(&(((qurt_barrier_t*)sync)[{i}]));")
batched.append("}")
prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
defines = '\n'.join(set(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
entry = device.renderer._render_entry("batched", targs)
defines = '\n'.join(dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
entry = device.renderer._render_entry("batched", targs, sync_cnt=len(jit_cache))
code = defines + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
if DEBUG >= 4: print(code)

View file

@ -287,8 +287,8 @@ def vectorize_shuffle(vec:UOp):
def multicore_range(r:UOp):
if getenv("MULTICORE", 0) != 1: return None
if any(x.op is Ops.DEFINE_VAR for x in r.toposort): return None
core = UOp(Ops.DEFINE_VAR, dtypes.int, arg=("core", 0, 1))
if any(x.op is Ops.SPECIAL for x in r.toposort): return None
core = UOp(Ops.SPECIAL, dtypes.int, arg=("g0", 0, 1))
start = (core.eq(0)).where(r.src[0], r.src[1]//2)
end = (core.eq(0)).where(r.src[1]//2, r.src[1])
return r.replace(src=(start,end))
@ -345,32 +345,62 @@ pretty_render = PatternMatcher([
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
global_max = (2, 1, 1)
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
pre_matcher = dsp_pm
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {k:v for k,v in ClangRenderer.code_for_op.items() if k != Ops.SQRT}
extra_args = ['int global_idx_0', 'void* sync']
code_for_workitem = {"g": lambda x: f"global_idx_{x}"}
def _render_defines(self, uops) -> list[str]:
return ['''/* DSP boilerplate */ struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency;
_Bool set_dcvs_params; short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3];};''','int HAP_power_set(void*, void*);',
'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
'unsigned long long HAP_perf_get_time_us(void);', 'typedef unsigned long qurt_thread_t;', 'void qurt_thread_exit(int);',
'typedef struct _qurt_barrier { char padding[64]; } qurt_barrier_t;', 'int qurt_barrier_init(qurt_barrier_t*, unsigned int);',
'int qurt_barrier_wait(qurt_barrier_t*);',
'typedef struct _qurt_thread_attr { char name[16]; unsigned char tcb_partition; unsigned char affinity; unsigned short priority;',
'unsigned char asid; unsigned char bus_priority; unsigned short timetest_id; unsigned int stack_size; void *stack_addr; char padding[96]; } qurt_thread_attr_t;',
'int qurt_thread_join(qurt_thread_t tid, int *status);', 'void* malloc(unsigned int);', 'void free(void*);',
'int qurt_thread_create (qurt_thread_t *thread_id, qurt_thread_attr_t *attr, void (*entrypoint) (void *), void *arg);',
] + super()._render_defines(uops)
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]], sync_cnt=0x0) -> str:
msrc = ['typedef struct all_args {', *[f'int sz_or_val_{i}; int off{i}; void *buf_{i};' for i in range(len(bufs))], 'void* sync; } all_args_t;']
msrc += ['void threader(all_args_t* args) {']
msrc += [f"{function_name}({', '.join([(f'args->buf_{i}' if isinstance(b[1][0], PtrDType) else f'args->sz_or_val_{i}') for i,b in enumerate(bufs)])}, 1, args->sync);"]
msrc += ['qurt_thread_exit(0); }'
'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
'HAP_power_set((void*)handle, (void*)&req);']
msrc += ['if ((sc>>24) != 2) return 0;']
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
if sync_cnt > 0:
msrc += [f"qurt_barrier_t* sync = malloc({sync_cnt} * sizeof(qurt_barrier_t));"]
msrc += [f"qurt_barrier_init(&sync[{i}], 2);" for i in range(sync_cnt)]
else: msrc += [f"qurt_barrier_t* sync = 0x0;"]
msrc += ['all_args_t args = { 0 };']
msrc += [f'args.sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
msrc += [f'args.off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'args.buf_{i} = HAP_mmap(0,args.sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+args.off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += ['args.sync = sync;']
msrc += ["qurt_thread_attr_t attr = { 0 };"]
msrc += ["attr.name[0] = 't';", "attr.priority = 255;", "attr.asid = 0;"]
msrc += ["attr.stack_size = (64 << 10);", "attr.stack_addr = malloc(attr.stack_size);"]
msrc += [""]
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
if getenv("MULTICORE", 0) != 0:
msrc += ["qurt_thread_t thread_ = 0; qurt_thread_create(&thread_, &attr, (void (*)(void*))threader, (void*)&args);"]
msrc += [f"{function_name}({', '.join([(f'args.buf_{i}' if isinstance(b[1][0], PtrDType) else f'args.sz_or_val_{i}') for i,b in enumerate(bufs)])}, 0, args.sync);"]
if getenv("MULTICORE", 0) != 0:
msrc += ['int status;', f"qurt_thread_join(thread_, &status);"]
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += [f'HAP_munmap(args.buf_{i}, args.sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
msrc += ['free(attr.stack_addr);']
if sync_cnt > 0: msrc += ['free(sync);']
msrc += ["return 0; }"]
return '\n'.join(msrc)
@ -446,10 +476,10 @@ class DSPDevice(Compiled):
try:
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
sections = ['text', 'rela.plt', 'rela.dyn', 'plt', 'data', 'bss', 'hash', 'dynamic', 'got', 'got.plt', 'dynsym', 'dynstr', 'symtab', 'shstrtab', 'strtab']
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
with tempfile.NamedTemporaryFile(delete=False) as self.link_ld:
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
self.link_ld.write(f"SECTIONS {{ . = 0x0;\n{sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
self.link_ld.flush()
from tinygrad.runtime.graph.cpu import CPUGraph