mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
dsp matchers + bump line count to 11300 (#9130)
This commit is contained in:
parent
638d925e4e
commit
af9d8d39d2
2 changed files with 27 additions and 14 deletions
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
|
|
@ -263,8 +263,8 @@ jobs:
|
|||
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
|
||||
- name: Run unit tests
|
||||
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
|
||||
- name: Repo line count < 11200 lines
|
||||
run: MAX_LINE_COUNT=11200 python sz.py
|
||||
- name: Repo line count < 11300 lines
|
||||
run: MAX_LINE_COUNT=11300 python sz.py
|
||||
|
||||
fuzzing:
|
||||
name: Fuzzing
|
||||
|
|
|
|||
|
|
@ -12,9 +12,19 @@ if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-i
|
|||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
dsp_pm = PatternMatcher([
|
||||
(UPat(Ops.VECTORIZE, src=UPat.var("y"))*UPat.var("x"), lambda x,y: UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")*x),
|
||||
(((UPat.var('x').maximum(0) ^ -1).maximum(-256) ^ -1).cast(dtypes.uchar.vec(128)),
|
||||
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
|
||||
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
|
||||
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,
|
||||
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 else None),
|
||||
])
|
||||
|
||||
dsp_pm_late = PatternMatcher([
|
||||
(UPat.var("x")+UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")),
|
||||
(UPat.var("x")*UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")),
|
||||
(UPat.var("x")//UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")),
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOM, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]))
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOM, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
])
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
|
|
@ -22,7 +32,8 @@ class DSPRenderer(ClangRenderer):
|
|||
supports_float4 = True
|
||||
buffer_suffix = " restrict __attribute__((align_value(128)))"
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
extra_matcher = dsp_pm+ClangRenderer.extra_matcher
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
|
||||
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
|
||||
|
|
@ -237,20 +248,22 @@ class RPCListener(threading.Thread):
|
|||
|
||||
# ***** mock DSP *****
|
||||
|
||||
mockdsp_boilerplate = '''/* DSP boilerplate */ static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) {
|
||||
long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = %7; trap0(#1); %0 = r0" : "=r" (retval)
|
||||
: "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "r" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; }
|
||||
static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }}
|
||||
static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }}
|
||||
static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }}
|
||||
static unsigned int inscount(void) {{ unsigned int ret; __asm__ volatile(".word 0x6a15c000; %0 = R0" : "=r" (ret) : : "r0"); return ret; }}
|
||||
static void *mmap2(void *addr, unsigned int length, int prot, int flags, int fd, unsigned long offset) {{
|
||||
return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}'''
|
||||
|
||||
class MockDSPRenderer(DSPRenderer):
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
ret = ClangRenderer.render_kernel(self, function_name, kernel, bufs, uops, prefix)
|
||||
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
|
||||
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
|
||||
msrc = ['''/* DSP boilerplate */ static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) {
|
||||
long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = %7; trap0(#1); %0 = r0" : "=r" (retval)
|
||||
: "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "r" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; }
|
||||
static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }}
|
||||
static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }}
|
||||
static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }}
|
||||
static unsigned int inscount(void) {{ unsigned int ret; __asm__ volatile(".word 0x6a15c000; %0 = R0" : "=r" (ret) : : "r0"); return ret; }}
|
||||
static void *mmap2(void *addr, unsigned int length, int prot, int flags, int fd, unsigned long offset) {{
|
||||
return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}''', 'void _start(void) {']
|
||||
msrc = [mockdsp_boilerplate, 'void _start(void) {']
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType):
|
||||
sz = b[1][0].size*b[1][0].itemsize
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue