mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
8005e6c974
commit
bb7b89475c
5 changed files with 90 additions and 29 deletions
|
|
@ -1,12 +1,34 @@
|
|||
import pickle, sys
|
||||
from dataclasses import replace
|
||||
from tinygrad import Device, Context
|
||||
from tinygrad import Device, Context, Tensor
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import getenv, BEAM
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
import numpy as np
|
||||
|
||||
def move_jit_captured_to_dev(captured, device="DSP"):
|
||||
captured.expected_st_vars_dtype_device = [x[:3] + (device,) for x in captured.expected_st_vars_dtype_device]
|
||||
|
||||
assign = {}
|
||||
def move_buffer(b):
|
||||
if b in assign: return assign[b]
|
||||
|
||||
if b._base is not None:
|
||||
newbuf = Buffer(device, b.size, b.dtype, base=move_buffer(b._base), offset=b.offset)
|
||||
else:
|
||||
newbuf = Buffer(device, b.size, b.dtype)
|
||||
if b.is_allocated(): newbuf.ensure_allocated().copyin(b.as_buffer())
|
||||
assign[b] = newbuf
|
||||
return assign[b]
|
||||
|
||||
for item in captured.jit_cache:
|
||||
for b in item.bufs:
|
||||
if b is not None: move_buffer(b)
|
||||
captured.jit_cache = [ExecItem(item.prg, [assign.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
|
||||
return captured
|
||||
|
||||
if __name__ == "__main__":
|
||||
with Context(DEBUG=0):
|
||||
|
|
@ -15,6 +37,10 @@ if __name__ == "__main__":
|
|||
print(f"{f.tell()/1e6:.2f}M loaded")
|
||||
print(type(fxn))
|
||||
|
||||
# Move all buffers to DSP device.
|
||||
fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
|
||||
new_jit = []
|
||||
|
||||
knum = 1
|
||||
for ei in fxn.captured.jit_cache:
|
||||
# skip the copy and the first kernel
|
||||
|
|
@ -22,14 +48,18 @@ if __name__ == "__main__":
|
|||
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
|
||||
p: ProgramSpec = ei.prg.p
|
||||
k = Kernel(p.ast, Device["DSP"].renderer)
|
||||
dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
|
||||
k.hand_coded_optimizations()
|
||||
#if knum == 13: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
p2 = k.to_program()
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2))
|
||||
if getenv("MULTICORE", 0) == 1:
|
||||
new_ei.run({p2.vars[0]: 0})
|
||||
new_ei.run({p2.vars[0]: 1})
|
||||
new_ei.run()
|
||||
else:
|
||||
new_ei.run()
|
||||
new_jit.append(new_ei)
|
||||
knum += 1
|
||||
|
||||
if getenv("RUN_JIT", 0):
|
||||
fxn.captured.free_intermediates()
|
||||
fxn.captured.jit_cache = new_jit
|
||||
fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))
|
||||
|
|
|
|||
|
|
@ -168,8 +168,8 @@ class CapturedJit(Generic[ReturnType]):
|
|||
update_depends(depends, self.jit_cache)
|
||||
for b in depends:
|
||||
if b is not None:
|
||||
b.deallocate()
|
||||
if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
|
||||
if b.is_allocated(): b.deallocate()
|
||||
if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
|
||||
self.__post_init__() # reset the graph state
|
||||
|
||||
def optimize_weights(self):
|
||||
|
|
|
|||
|
|
@ -111,9 +111,9 @@ class ProgramSpec:
|
|||
if u.op is Ops.SPECIAL:
|
||||
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
|
||||
if u.arg[0][0] == 'i': self.local_size = None
|
||||
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
|
||||
assert special_size is not None
|
||||
special_size[int(u.arg[0][-1])] = u.arg[1]
|
||||
# special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
|
||||
# assert special_size is not None
|
||||
# special_size[int(u.arg[0][-1])] = u.arg[1]
|
||||
self.vars = sorted(self.vars, key=lambda v: v.arg)
|
||||
self.outs = sorted(dedup(self.outs))
|
||||
self.ins = sorted(dedup(self.ins))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.helpers import dedup, DEBUG, to_function_name
|
||||
from tinygrad.helpers import dedup, DEBUG, to_function_name, getenv
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
|
|
@ -24,17 +24,18 @@ class CPUGraph(GraphRunner):
|
|||
if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
|
||||
return f"({device.renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
|
||||
|
||||
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+") {"]
|
||||
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+", int gl0, void* sync) {"]
|
||||
for i, ji in enumerate(jit_cache):
|
||||
args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
|
||||
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});")
|
||||
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)}, gl0, 0x0);")
|
||||
if getenv("MULTICORE", 0) != 0: batched.append(f" qurt_barrier_wait(&(((qurt_barrier_t*)sync)[{i}]));")
|
||||
batched.append("}")
|
||||
|
||||
prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
|
||||
funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
|
||||
|
||||
defines = '\n'.join(set(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
|
||||
entry = device.renderer._render_entry("batched", targs)
|
||||
defines = '\n'.join(dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
|
||||
entry = device.renderer._render_entry("batched", targs, sync_cnt=len(jit_cache))
|
||||
code = defines + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
|
||||
|
||||
if DEBUG >= 4: print(code)
|
||||
|
|
|
|||
|
|
@ -287,8 +287,8 @@ def vectorize_shuffle(vec:UOp):
|
|||
|
||||
def multicore_range(r:UOp):
|
||||
if getenv("MULTICORE", 0) != 1: return None
|
||||
if any(x.op is Ops.DEFINE_VAR for x in r.toposort): return None
|
||||
core = UOp(Ops.DEFINE_VAR, dtypes.int, arg=("core", 0, 1))
|
||||
if any(x.op is Ops.SPECIAL for x in r.toposort): return None
|
||||
core = UOp(Ops.SPECIAL, dtypes.int, arg=("g0", 0, 1))
|
||||
start = (core.eq(0)).where(r.src[0], r.src[1]//2)
|
||||
end = (core.eq(0)).where(r.src[1]//2, r.src[1])
|
||||
return r.replace(src=(start,end))
|
||||
|
|
@ -345,32 +345,62 @@ pretty_render = PatternMatcher([
|
|||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
global_max = (2, 1, 1)
|
||||
buffer_suffix = " restrict __attribute__((align_value(128)))"
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {k:v for k,v in ClangRenderer.code_for_op.items() if k != Ops.SQRT}
|
||||
extra_args = ['int global_idx_0', 'void* sync']
|
||||
code_for_workitem = {"g": lambda x: f"global_idx_{x}"}
|
||||
|
||||
def _render_defines(self, uops) -> list[str]:
|
||||
return ['''/* DSP boilerplate */ struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency;
|
||||
_Bool set_dcvs_params; short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3];};''','int HAP_power_set(void*, void*);',
|
||||
'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
|
||||
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
|
||||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
'unsigned long long HAP_perf_get_time_us(void);', 'typedef unsigned long qurt_thread_t;', 'void qurt_thread_exit(int);',
|
||||
'typedef struct _qurt_barrier { char padding[64]; } qurt_barrier_t;', 'int qurt_barrier_init(qurt_barrier_t*, unsigned int);',
|
||||
'int qurt_barrier_wait(qurt_barrier_t*);',
|
||||
'typedef struct _qurt_thread_attr { char name[16]; unsigned char tcb_partition; unsigned char affinity; unsigned short priority;',
|
||||
'unsigned char asid; unsigned char bus_priority; unsigned short timetest_id; unsigned int stack_size; void *stack_addr; char padding[96]; } qurt_thread_attr_t;',
|
||||
'int qurt_thread_join(qurt_thread_t tid, int *status);', 'void* malloc(unsigned int);', 'void free(void*);',
|
||||
'int qurt_thread_create (qurt_thread_t *thread_id, qurt_thread_attr_t *attr, void (*entrypoint) (void *), void *arg);',
|
||||
] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]], sync_cnt=0x0) -> str:
|
||||
msrc = ['typedef struct all_args {', *[f'int sz_or_val_{i}; int off{i}; void *buf_{i};' for i in range(len(bufs))], 'void* sync; } all_args_t;']
|
||||
msrc += ['void threader(all_args_t* args) {']
|
||||
msrc += [f"{function_name}({', '.join([(f'args->buf_{i}' if isinstance(b[1][0], PtrDType) else f'args->sz_or_val_{i}') for i,b in enumerate(bufs)])}, 1, args->sync);"]
|
||||
msrc += ['qurt_thread_exit(0); }'
|
||||
'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
if sync_cnt > 0:
|
||||
msrc += [f"qurt_barrier_t* sync = malloc({sync_cnt} * sizeof(qurt_barrier_t));"]
|
||||
msrc += [f"qurt_barrier_init(&sync[{i}], 2);" for i in range(sync_cnt)]
|
||||
else: msrc += [f"qurt_barrier_t* sync = 0x0;"]
|
||||
msrc += ['all_args_t args = { 0 };']
|
||||
msrc += [f'args.sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'args.off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'args.buf_{i} = HAP_mmap(0,args.sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+args.off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += ['args.sync = sync;']
|
||||
msrc += ["qurt_thread_attr_t attr = { 0 };"]
|
||||
msrc += ["attr.name[0] = 't';", "attr.priority = 255;", "attr.asid = 0;"]
|
||||
msrc += ["attr.stack_size = (64 << 10);", "attr.stack_addr = malloc(attr.stack_size);"]
|
||||
msrc += [""]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
msrc += ["qurt_thread_t thread_ = 0; qurt_thread_create(&thread_, &attr, (void (*)(void*))threader, (void*)&args);"]
|
||||
msrc += [f"{function_name}({', '.join([(f'args.buf_{i}' if isinstance(b[1][0], PtrDType) else f'args.sz_or_val_{i}') for i,b in enumerate(bufs)])}, 0, args.sync);"]
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
msrc += ['int status;', f"qurt_thread_join(thread_, &status);"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'HAP_munmap(args.buf_{i}, args.sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += ['free(attr.stack_addr);']
|
||||
if sync_cnt > 0: msrc += ['free(sync);']
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -446,10 +476,10 @@ class DSPDevice(Compiled):
|
|||
try:
|
||||
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
|
||||
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
|
||||
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
|
||||
sections = ['text', 'rela.plt', 'rela.dyn', 'plt', 'data', 'bss', 'hash', 'dynamic', 'got', 'got.plt', 'dynsym', 'dynstr', 'symtab', 'shstrtab', 'strtab']
|
||||
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
|
||||
with tempfile.NamedTemporaryFile(delete=False) as self.link_ld:
|
||||
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
|
||||
self.link_ld.write(f"SECTIONS {{ . = 0x0;\n{sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
|
||||
self.link_ld.flush()
|
||||
|
||||
from tinygrad.runtime.graph.cpu import CPUGraph
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue