mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
219 commits
master
...
dsp_search
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
07102624a1 |
||
|
|
ed76dd71eb |
||
|
|
5d6e8bd681 | ||
|
|
f90656c647 | ||
|
|
86727875f9 | ||
|
|
7e4ac744ac |
||
|
|
4496cc6e61 | ||
|
|
13d3bcb6e1 | ||
|
|
efad1ebd0d |
||
|
|
e20eed6208 | ||
|
|
7dc265ef93 | ||
|
|
13dec71ab0 | ||
|
|
bb453613ac | ||
|
|
d6c3ae186b | ||
|
|
c066653428 | ||
|
|
66c6d35fe2 | ||
|
|
64e1ddf2a9 | ||
|
|
d6013a2d50 | ||
|
|
1d36aa8116 |
||
|
|
6251ab3d90 | ||
|
|
dd51728795 | ||
|
|
a59b1ed970 | ||
|
|
147fc0e648 | ||
|
|
17f7b226cb | ||
|
|
9c34d9eb6e | ||
|
|
95261b6193 | ||
|
|
1e2becfeae | ||
|
|
e18cdbcbe2 | ||
|
|
f3cb4c3eef | ||
|
|
6ecaf11224 | ||
|
|
8b24f9cb0d | ||
|
|
797e512c00 | ||
|
|
f600482982 | ||
|
|
da35edbb55 | ||
|
|
661431ee75 | ||
|
|
8340d9c1c2 | ||
|
|
910cddbbca | ||
|
|
e6e0c0ec86 | ||
|
|
d0eedb5a79 | ||
|
|
f69deddbd4 | ||
|
|
be11fbbf78 | ||
|
|
812c391617 | ||
|
|
3306083f42 | ||
|
|
18d7e9d3f1 | ||
|
|
1c3f249ecf | ||
|
|
bb7b89475c |
||
|
|
8005e6c974 | ||
|
|
a3d61a0372 | ||
|
|
c73e35aa24 | ||
|
|
0b4b9f61b9 | ||
|
|
ee3ddfcdc1 | ||
|
|
220d682489 | ||
|
|
9c388c3539 | ||
|
|
4b3a4c8c46 | ||
|
|
eb606d7230 | ||
|
|
49d52a2763 | ||
|
|
a59c3dd09a | ||
|
|
a640292aed | ||
|
|
2f48c12441 |
||
|
|
be3b5efc64 | ||
|
|
996d0ac1d2 | ||
|
|
77e897b3b1 |
||
|
|
273dde69bd | ||
|
|
a64030d8c8 | ||
|
|
9b19129e87 | ||
|
|
48221d9024 | ||
|
|
bcfcd60f55 | ||
|
|
abc90024ac | ||
|
|
f0e6d8394c |
||
|
|
a1c1ecd597 |
||
|
|
489a5e24c4 |
||
|
|
e0fd84dd64 | ||
|
|
1a9d7a1628 | ||
|
|
45646fe102 | ||
|
|
9c928afafe | ||
|
|
d4f1c5049b | ||
|
|
11b478f85d | ||
|
|
0aa7031b5f | ||
|
|
ab67d5ff6e | ||
|
|
cbe23e13c2 | ||
|
|
9bbd12dc65 | ||
|
|
b09142a893 | ||
|
|
1d7faf4777 | ||
|
|
59438be39b | ||
|
|
cc23836a38 | ||
|
|
e4354effa2 |
||
|
|
d180e909a3 | ||
|
|
52364231dc | ||
|
|
d32ad080c3 | ||
|
|
a8bd26d9bc | ||
|
|
6d860389f4 | ||
|
|
5d5286489d | ||
|
|
917e0e925b | ||
|
|
6081f8427e | ||
|
|
23035bf028 | ||
|
|
5e33163ef3 |
||
|
|
cee9fc7540 | ||
|
|
9041072dea |
||
|
|
444d6279ac | ||
|
|
f27f484621 | ||
|
|
38488ec3b0 | ||
|
|
ff96f0adae | ||
|
|
5dd59a6096 | ||
|
|
6bec82b918 | ||
|
|
a436d7542f | ||
|
|
5d98688de6 | ||
|
|
09d877ed8c |
||
|
|
6ff894d674 | ||
|
|
da03b4520a |
||
|
|
013c6e0b10 | ||
|
|
31ffa1607e |
||
|
|
928994c6ea | ||
|
|
e283bec62e | ||
|
|
c4f5db8467 | ||
|
|
bf0d928417 | ||
|
|
f823324eb9 | ||
|
|
6995e0c91b |
||
|
|
b934b5b907 | ||
|
|
290ba9ee37 | ||
|
|
e0d63696d7 | ||
|
|
acafd57f14 |
||
|
|
905f847d10 | ||
|
|
9e19cdfbbe | ||
|
|
f7b38fa94c | ||
|
|
bd03942bd8 | ||
|
|
880b4a5e47 | ||
|
|
2e4cae342b | ||
|
|
8660fecb02 | ||
|
|
e3e43df0c9 | ||
|
|
a47e61b097 | ||
|
|
f1ff18acec | ||
|
|
60cbfe4222 |
||
|
|
311df3ff21 | ||
|
|
f6e64a5e8e | ||
|
|
622ff115a3 | ||
|
|
5a6e8ee268 | ||
|
|
a9f1227625 | ||
|
|
74c2587ef4 | ||
|
|
bce252e0b8 | ||
|
|
66a90a3c92 | ||
|
|
0d76b0d461 | ||
|
|
5e4505d363 | ||
|
|
29920b74d5 | ||
|
|
ccd18a803c | ||
|
|
943bde47ab | ||
|
|
0d10c7ae2f | ||
|
|
3cab6a3d4a | ||
|
|
22a56cbaea | ||
|
|
afd61730b4 | ||
|
|
536556434b | ||
|
|
52bff5f39d | ||
|
|
64d0f14d3d | ||
|
|
1b61cc6ec3 | ||
|
|
6f792e8045 | ||
|
|
b1f8018bf4 | ||
|
|
2eb9241329 | ||
|
|
554a490751 |
||
|
|
651c678edf | ||
|
|
3274bd2d81 | ||
|
|
30f4d64148 | ||
|
|
2634975d5a | ||
|
|
fd73ec2b1b | ||
|
|
e1d2bec4a4 | ||
|
|
1b4e9f5e91 | ||
|
|
25c023bcbe | ||
|
|
07abf9e6bc | ||
|
|
26b02a037c | ||
|
|
5089a601c6 | ||
|
|
6b49a63c48 | ||
|
|
dca95428a5 | ||
|
|
8a477ba4e1 | ||
|
|
264dd91b8a | ||
|
|
bdf716b915 | ||
|
|
cf41c803d0 | ||
|
|
3cf9224df5 | ||
|
|
af94addb3a | ||
|
|
dc1469a188 | ||
|
|
0416b0998d | ||
|
|
c715c25420 |
||
|
|
f66b03f0a6 | ||
|
|
2729a46ca6 | ||
|
|
dbb50e4a00 | ||
|
|
71c7c455a6 | ||
|
|
ff3438be4e | ||
|
|
bc5e23061b | ||
|
|
5ce951fb34 | ||
|
|
4a49d05a3f |
||
|
|
c3c85c64ee | ||
|
|
61c02ca634 | ||
|
|
325044bcaf | ||
|
|
91ac508878 | ||
|
|
2ed30f5366 | ||
|
|
d0b9c7e7ca | ||
|
|
f6ed8f4a27 | ||
|
|
87718170d2 | ||
|
|
b67af4049c | ||
|
|
16e425a4c0 | ||
|
|
c867a48ab4 | ||
|
|
2dc82c0604 | ||
|
|
e7402e6643 | ||
|
|
e5ccd9e846 | ||
|
|
624197f169 | ||
|
|
d42350a401 | ||
|
|
223feb2118 |
||
|
|
8eb9093fb8 | ||
|
|
45f7c08111 | ||
|
|
58fc77fdb3 | ||
|
|
e57258b17b | ||
|
|
31cd00e72f | ||
|
|
b00ccc08c3 | ||
|
|
94d578aec5 | ||
|
|
45010f7eff | ||
|
|
249141026e | ||
|
|
a913c1aab7 | ||
|
|
469ec6b6b4 | ||
|
|
1a84d504b7 |
||
|
|
14c9f14125 | ||
|
|
cc0041cb8c | ||
|
|
e4615e0cd9 |
11 changed files with 640 additions and 54 deletions
|
|
@ -415,6 +415,7 @@ def get_onnx_ops():
|
|||
|
||||
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
|
||||
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
|
||||
if W.shape[1:] == (1,3,3) and group > 1: group = W.shape[0]
|
||||
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
|
||||
padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
|
||||
|
||||
|
|
@ -724,6 +725,19 @@ def get_onnx_ops():
|
|||
ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
|
||||
else:
|
||||
ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
|
||||
# you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
|
||||
if getenv("NHWC") and len(ret.shape) == 4:
|
||||
in_chans = ret.shape[1]
|
||||
if ret.shape[1] == 3 or in_chans%32 != 0:
|
||||
return ret.permute(0,2,3,1).contiguous().permute(0,3,1,2)
|
||||
else:
|
||||
if in_chans%32 != 0: ret = ret.pad(((0,0), (0,32-(in_chans%32)), (0,0), (0,0)))
|
||||
ret = ret.reshape(ret.shape[0], ret.shape[1]//32, 32, ret.shape[-2], ret.shape[-1])
|
||||
order = (0, 1, 3, 4, 2)
|
||||
ret = ret.permute(order).contiguous().permute(*argsort(order))
|
||||
ret = ret.reshape(ret.shape[0], -1, ret.shape[-2], ret.shape[-1])
|
||||
if in_chans%32 != 0: ret = ret[:, :in_chans, :, :]
|
||||
return ret
|
||||
return ret.contiguous()
|
||||
|
||||
def DynamicQuantizeLinear(x: Tensor):
|
||||
|
|
@ -735,6 +749,66 @@ def get_onnx_ops():
|
|||
return y, scale, zero_point
|
||||
|
||||
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
|
||||
if getenv("NHWC"):
|
||||
# pad channels
|
||||
in_shape = x.shape
|
||||
if len(x.shape) == 4 and x.shape[1:] == (1,3,3) and x.shape[0]%32 != 0:
|
||||
# 3x3 depthwise (C,1,3,3). pad C to 32
|
||||
x = x.pad(((0,32-(x.shape[0]%32)), (0,0), (0,0), (0,0)))
|
||||
elif len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[0] != 1:
|
||||
# 1x1 conv (C_out,C_in,1,1), pad C_out and C_in to 32
|
||||
if x.shape[0]%32 != 0: x = x.pad(((0,32-(x.shape[0]%32)), (0,0), (0,0), (0,0)))
|
||||
if x.shape[1]%32 != 0: x = x.pad(((0,0), (0,32-(x.shape[1]%32)), (0,0), (0,0)))
|
||||
elif len(x.shape) == 1 and x.shape[0]%32 != 0 and x.shape[0] != 1000:
|
||||
# bias
|
||||
x = x.pad(((0,32-(x.shape[0]%32)),))
|
||||
|
||||
if in_shape != x.shape:
|
||||
xzp = x_zero_point.item()
|
||||
print(f"{in_shape} -> {x.shape}", xzp)
|
||||
# fix up the zero point in the padded area
|
||||
pp = (Tensor.full(in_shape, -xzp, dtype=dtypes.int).pad(tuple([(0, so-si) for si,so in zip(in_shape, x.shape)])) + xzp).cast(x.dtype)
|
||||
x = (x + pp).contiguous()
|
||||
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[1:] == (3,3,3):
|
||||
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
|
||||
assert x.shape[0] == 32
|
||||
order = (1,2,0,3)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x[:, :, :, :3]
|
||||
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[1:] == (1,3,3):
|
||||
# 3x3 depthwise (C,1,3,3)
|
||||
# "width multiple of 4 depth multiple of 32 aligned to 128bytes"
|
||||
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
|
||||
if x.shape[0]%32 == 0:
|
||||
# depth/32 is a loop -- lsr(depth, #5)
|
||||
# width/4 is a loop -- lsr(out_width, #2)
|
||||
# height is a loop
|
||||
x = x.reshape(-1, 32, 1, 3, 4)
|
||||
order = (0,3,1,2,4)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x.reshape(-1, 1, 3, 4)
|
||||
else:
|
||||
assert False # (doesn't happen anymore)
|
||||
#print("HERE", x.shape)
|
||||
order = (2,0,1,3)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x[:, :, :, :3]
|
||||
# we increase the filts to 4-aligned for speed (75% util)
|
||||
WEIGHT_SHIFT = 4
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[1]%WEIGHT_SHIFT == 0:
|
||||
if x.shape[0]%32 == 0:
|
||||
# DSP swizzle memory (big)
|
||||
x = x.reshape(x.shape[0]//32, 32, x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(0,2,1,3).contiguous().permute(0,2,1,3).reshape(x.shape)
|
||||
else:
|
||||
# DSP swizzle memory
|
||||
x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
|
||||
if getenv("NHWC") and x.shape == (1000, 1280):
|
||||
x = x.reshape(-1, 320, 4)
|
||||
order = (1,0,2)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x.reshape(-1, 1280)
|
||||
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
|
||||
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from collections import defaultdict
|
|||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
|
||||
from tinygrad.ops import graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, commutative
|
||||
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
|
@ -12,10 +12,16 @@ from tinygrad.renderer import Renderer
|
|||
# ***** load/store grouping *****
|
||||
|
||||
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
|
||||
vectorize_mask = getenv("VECTORIZE_MASK", 0) and buf.arg == 0 and mask is not None
|
||||
|
||||
# generate the individual indexes
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
|
||||
if vectorize_mask:
|
||||
# no load_store_indexing if we are doing vectorized mask
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic_flat+commutative, name=f"index_buf_{buf.arg}")
|
||||
else:
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic_flat+commutative+load_store_indexing, name=f"index_buf_{buf.arg}")
|
||||
# extract all the relevant offsets
|
||||
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
||||
for i in range(vec.dtype.count):
|
||||
|
|
@ -24,7 +30,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||
else: root_src, arg = idx, 0
|
||||
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
|
||||
if len(midx.src[i].src) == 3 and not vectorize_mask: root_src = (midx.src[i].src[2], root_src)
|
||||
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
|
||||
|
||||
# the buf.dtype is always a pointer
|
||||
|
|
@ -35,10 +41,26 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
idxs: list[int|None] = [None]*vec.dtype.count
|
||||
global_offset = 0
|
||||
for offsets in offsets_rootsrc.values():
|
||||
if 0 in offsets:
|
||||
match = True
|
||||
for i in range(0, max(offsets.keys()), 4):
|
||||
if i in offsets and i+1 in offsets and i+2 in offsets and i+3 not in offsets: pass
|
||||
else: match = False
|
||||
if match:
|
||||
for i in range(0, max(offsets.keys()), 4):
|
||||
assert i+3 not in offsets
|
||||
offsets[i+3] = {}
|
||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||
for grp in grouped_offsets:
|
||||
# get the index offset for this element. using [0] is okay, because they are the same
|
||||
lidx = midx.src[offsets[grp[0]][0]]
|
||||
|
||||
if vectorize_mask:
|
||||
allgrp = [midx.src[offsets[g][0]] for g in grp]
|
||||
base = [x.src[2].cast(dtypes.uchar) if len(x.src) > 2 else UOp.const(dtypes.uchar, 1) for x in allgrp]
|
||||
vecmask = UOp(Ops.VECTORIZE, dtypes.uchar.vec(len(base)), tuple(base))
|
||||
lidx = lidx.replace(src=lidx.src[0:2]+(vecmask,))
|
||||
|
||||
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
# set the idxs of the output
|
||||
for i,g in enumerate(grp):
|
||||
|
|
@ -167,7 +189,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
if global_offset+fold_length > sz: continue
|
||||
oidx = idx.src[1] + global_offset
|
||||
if must_divide and oidx.simplify().divides(fold_length) is None: continue
|
||||
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
|
||||
if len(idx.src) > 2 and idx.src[2].dtype.count > 1:
|
||||
# vectorized
|
||||
lidx = buf.index(oidx, idx.src[2].gep(tuple(range(global_offset, global_offset+fold_length))))
|
||||
else:
|
||||
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
|
||||
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
|
||||
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
|
||||
|
|
@ -271,7 +297,34 @@ pm_render = PatternMatcher([
|
|||
|
||||
# *** uop graph ***
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import identity_element
|
||||
from tinygrad.helpers import partition
|
||||
|
||||
@dataclass
|
||||
class ReduceContext:
|
||||
acc_num: int = 0
|
||||
|
||||
def reduce_to_acc(ctx:ReduceContext, x:UOp):
|
||||
ret = x.src[0]
|
||||
reduce_range, reduce_expand = partition(x.src, lambda y: y.op is Ops.RANGE)
|
||||
if len(reduce_range) == 0: return ret
|
||||
if all(y not in reduce_range for y in ret.toposort):
|
||||
# TODO: this shouldn't be here
|
||||
return ret*prod([y.src[1] for y in reduce_range]).broadcast(ret.dtype.count)
|
||||
alu_op = x.arg
|
||||
# create acc
|
||||
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+list(reduce_expand))
|
||||
# create ACC and assign
|
||||
return acc.assign(ret)
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
(UPat(Ops.REDUCE, name="x"), reduce_to_acc)
|
||||
])
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) -> UOp:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
|
@ -282,7 +335,15 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
|||
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
|
||||
# optional pre matcher
|
||||
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
if opts is not None and opts.pre_matcher is not None:
|
||||
if is_conv:
|
||||
from tinygrad.runtime.ops_dsp import conv_pm
|
||||
sink = graph_rewrite(sink, conv_pm+opts.pre_matcher)
|
||||
else:
|
||||
sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
|
||||
# remove reduce
|
||||
sink = graph_rewrite(sink, pm_reduce, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# this converts a lowerer program into a vectorized program
|
||||
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
|
||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, getenv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
|
||||
|
|
@ -123,9 +124,21 @@ pm_store_ignore = PatternMatcher([
|
|||
lambda store,mask: store.replace(src=(store.src[0], UOp(Ops.IGNORE, src=(store.src[1], mask)))) if store.src[1].op is not Ops.IGNORE else None),
|
||||
])
|
||||
|
||||
def debug_ignore(x, y):
|
||||
if getenv("DEBUG_IGNORE"):
|
||||
print("****")
|
||||
print("seen ", x.render())
|
||||
print("ignore", y.render())
|
||||
# this is totally wrong
|
||||
return x.const_like(True)
|
||||
|
||||
pm_move_ignore = PatternMatcher([
|
||||
# IGNORE on SELF is nothing
|
||||
(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x"))), lambda x: x.const_like(True)),
|
||||
# IGNORE debug
|
||||
(UPat(Ops.IGNORE, src=(UPat(dtype=dtypes.bool, name="x"), UPat(dtype=dtypes.bool, name="y"))), debug_ignore),
|
||||
# IGNORE with AND on SELF is nothing (is this right?)
|
||||
#(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x") & UPat())), lambda x: x.const_like(True)),
|
||||
# IGNORE on a CONST is nothing
|
||||
(UPat(Ops.IGNORE, src=(UPat((Ops.CONST, Ops.VCONST), name="c"), UPat())), lambda c: c),
|
||||
# move the IGNOREs
|
||||
|
|
|
|||
|
|
@ -437,6 +437,61 @@ class Kernel:
|
|||
return self
|
||||
|
||||
def hand_coded_optimizations(self) -> Kernel:
|
||||
if self.opts.device == "DSP":
|
||||
k = self
|
||||
# special path for DSP
|
||||
if k.full_shape[-3:] == (32,3,3):
|
||||
# 3x3 dwconv
|
||||
# kernel 49 is broken
|
||||
if k.full_shape[-4]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, len(k.full_shape)-4, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-3, 32))
|
||||
if k.full_shape[len(k.full_shape)-4]%4 == 0:
|
||||
#if k.full_shape[len(k.full_shape)-4] <= 8: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 0))
|
||||
#else: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
|
||||
# if this is small, swap it
|
||||
# NOTE: this is breaking something (should be fixed w/o padto)
|
||||
# kernel 23 is broken with this
|
||||
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.SWAP, 0, 1))
|
||||
elif k.full_shape[-4:] == (32,3,3,3):
|
||||
# 3x3 normal conv
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 2, 0))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
|
||||
# more UNROLLs aren't working well here, but they should be
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
elif len(k.full_shape) == 3 and k.full_shape[1] == 32 and k.first_reduce == 2:
|
||||
# weight that's exactly 32
|
||||
# NOTE: this pad might be broken
|
||||
if k.full_shape[0]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
|
||||
if k.full_shape[0]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
elif len(k.full_shape) == 4 and k.full_shape[2] == 32 and k.first_reduce == 3:
|
||||
# weight that has more than 32
|
||||
# NOTE: this pad is broken on kernel 50
|
||||
if k.full_shape[1]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 1, 4))
|
||||
# weight with more
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
|
||||
if k.full_shape[1]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
# if the more is small, upcast it (kernel 50 is broken with this)
|
||||
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.UPCAST, 0, 0))
|
||||
elif len(k.full_shape) == 2 and k.first_reduce == 1:
|
||||
# unroll to 4 if we can
|
||||
if k.full_shape[k.first_reduce]%4 == 0: k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
# always pad to 128
|
||||
# NOTE: this breaks kernel 66
|
||||
if k.full_shape[0]%128 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 128))
|
||||
if k.full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
elif len(k.full_shape) == 1 and k.full_shape[0] > 1000:
|
||||
# pad to 128 and run on 128
|
||||
if k.full_shape[0]%128 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 128))
|
||||
if k.full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
return self
|
||||
|
||||
self.required_optimizations()
|
||||
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
|
|
@ -678,7 +733,11 @@ class Kernel:
|
|||
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
|
||||
#if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
|
||||
|
||||
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
|
||||
is_conv = (len(self.full_shape) == 6 and self.full_shape[2:4] == (3,3))
|
||||
is_conv = is_conv or (len(self.full_shape) == 6 and self.full_shape[3:5] == (3,3))
|
||||
is_conv = is_conv or (len(self.full_shape) == 7 and self.full_shape[3:5] == (3,3))
|
||||
is_conv = is_conv or self.full_shape[-2:] == (3,3)
|
||||
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts, is_conv))
|
||||
if DEBUG >= 6: print_uops(self.uops)
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import functools, itertools, operator, math
|
|||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
|
||||
from tinygrad.helpers import all_int, prod, flatten, unwrap, QUANTIZE
|
||||
from tinygrad.codegen.expander import expand_rewrite
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
|
||||
|
|
@ -112,21 +112,20 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
|||
|
||||
def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
#reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
reduce_indexes = [ctx.ridxs[i] for i in x.axis_arg]
|
||||
all_nodes = flatten([x.toposort for x in reduce_indexes])
|
||||
reduce_expand = [x for x in all_nodes if x.op is Ops.UNROLL]
|
||||
reduce_range = [x for x in all_nodes if x.op is Ops.RANGE]
|
||||
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
|
||||
alu_op: Ops = x.arg[0]
|
||||
ret = x.src[0]
|
||||
# create acc
|
||||
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
|
||||
ctx.acc_num += 1
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
|
||||
ret = (functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)]),)
|
||||
else:
|
||||
ret = acc.alu(alu_op, ret)
|
||||
if not len(reduce_range): return ret
|
||||
# create ACC and assign
|
||||
return acc.assign(ret)
|
||||
ret = (ret,)
|
||||
return UOp(Ops.REDUCE, x.dtype, ret+tuple(reduce_range), alu_op)
|
||||
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
|
|
@ -221,6 +220,13 @@ pm_quant = symbolic+PatternMatcher([
|
|||
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
|
||||
|
||||
# MUL by 1/0 on LOAD where the masks match
|
||||
(UPat(Ops.WHERE, src=(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v1"),)), UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))) * \
|
||||
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"),
|
||||
lambda ld,v1,v2: ld if v1.arg.to_indexed_uops()[1].simplify() == v2.arg.to_indexed_uops()[1].simplify()
|
||||
# NOTE: this clause is completely false and might break things
|
||||
or True else None),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
|||
quo += q * v
|
||||
|
||||
# if numerator is negative, and it has remainder, don't simplify because C divmod is different from python divmod.
|
||||
if x.vmin < 0 and remainders: return None
|
||||
if x.vmin < -10000000 and remainders: return None
|
||||
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
|
||||
return rem//(c//gcd)+quo
|
||||
|
||||
|
|
@ -201,7 +201,7 @@ commutative = PatternMatcher([
|
|||
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** boolean algebra **
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# ** combine terms **
|
||||
|
|
|
|||
|
|
@ -335,9 +335,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
def simplify(self):
|
||||
# late import!
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.codegen.symbolic import symbolic_flat, commutative
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, symbolic)
|
||||
return graph_rewrite(self, symbolic_flat+commutative)
|
||||
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
||||
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
||||
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
||||
|
|
|
|||
|
|
@ -78,6 +78,9 @@ class Estimates:
|
|||
elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
elif u.op in {Ops.CUSTOM, Ops.CUSTOMI} and u not in dont_count:
|
||||
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpy"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
|
||||
if u.arg.startswith("__builtin_HEXAGON_A2_vraddub"): flops += mults*(17 if 'acc' in u.arg else 16)
|
||||
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
|
||||
|
||||
@dataclass
|
||||
|
|
@ -106,6 +109,9 @@ class ProgramSpec:
|
|||
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
|
||||
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
|
||||
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
# DSP masked store
|
||||
if u.op is Ops.CUSTOM and u.arg.startswith("__builtin_HEXAGON_V6_vS32b"):
|
||||
self.outs.extend([x.arg for x in u.src[1].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
if u.op is Ops.SPECIAL:
|
||||
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.helpers import dedup, DEBUG, to_function_name
|
||||
from tinygrad.helpers import dedup, DEBUG, to_function_name, getenv
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
|
|
@ -24,17 +24,18 @@ class CPUGraph(GraphRunner):
|
|||
if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
|
||||
return f"({device.renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
|
||||
|
||||
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+") {"]
|
||||
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+", int gl0, void* sync) {"]
|
||||
for i, ji in enumerate(jit_cache):
|
||||
args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
|
||||
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});")
|
||||
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)}, gl0, 0x0);")
|
||||
if getenv("MULTICORE", 0) != 0: batched.append(f" qurt_barrier_wait(&(((qurt_barrier_t*)sync)[{i}]));")
|
||||
batched.append("}")
|
||||
|
||||
prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
|
||||
funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
|
||||
|
||||
defines = dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
|
||||
entry = device.renderer._render_entry("batched", targs)
|
||||
entry = device.renderer._render_entry("batched", targs, sync_cnt=len(jit_cache))
|
||||
code = '\n'.join(defines) + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
|
||||
|
||||
if DEBUG >= 4: print(code)
|
||||
|
|
|
|||
|
|
@ -3,65 +3,417 @@ import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, context
|
|||
assert sys.platform != 'win32'
|
||||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG
|
||||
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat
|
||||
from tinygrad.codegen.symbolic import gep_pushing
|
||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG, dedup, all_same
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.runtime.autogen import libc, qcom_dsp
|
||||
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
def multi_mul(a0, a1, b0, b1, c0, c1, d0=None, d1=None, acc=None):
  """Collapse a sum of products (optionally plus `acc`) into a single HVX vrmpy builtin call.

  a0..d0 are the left factors and a1..d1 the right factors of up to four products. Every factor
  must be a CAST; the rewrite operates on the values *before* the cast. A missing fourth product
  (d0/d1) is replaced by an all-zero term.

  Returns a CUSTOMI UOp of dtype int.vec(32), or None when a0 or a1 is not a CAST (rule rejected).
  Raises AssertionError if any of the remaining factors is not a CAST.
  """
  if a0.op is not Ops.CAST or a1.op is not Ops.CAST:
    #print("rejected on a0/a1")
    return None

  # pad the absent fourth product with zeros so the four-way layout below always applies
  if d0 is None: d0 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
  if d1 is None: d1 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
  for factor in (a0, b0, c0, d0, a1, b1, c1, d1):
    assert factor.op is Ops.CAST

  # output lane 4*i+k of the concatenated vector is element i of the k-th source
  swizzle = tuple(32*part + lane for lane in range(32) for part in range(4))
  lhs_dtype = a0.src[0].dtype.scalar().vec(128)
  rhs_dtype = a1.src[0].dtype.scalar().vec(128)
  m0 = UOp(Ops.CAT, lhs_dtype, src=tuple(f.src[0] for f in (a0, b0, c0, d0))).gep(swizzle)
  m1 = UOp(Ops.CAT, rhs_dtype, src=tuple(f.src[0] for f in (a1, b1, c1, d1))).gep(swizzle)

  simp_m1 = m1.simplify()
  if simp_m1.op is Ops.GEP and simp_m1.arg == simp_m1.arg[0:4]*32:
    # the right-hand side repeats the same 4 lanes 32 times, so it fits the vector*scalar-word form
    # Vx32.w+=vrmpy(Vu32.ub,Rt32.b) -> __builtin_HEXAGON_V6_vrmpybus_acc
    # Vx32.uw+=vrmpy(Vu32.ub,Rt32.ub) -> __builtin_HEXAGON_V6_vrmpyub_acc
    rhs = simp_m1.src[0].gep(simp_m1.arg[0:4]).bitcast(dtypes.uint)
    plain, fused = "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})", "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})"
  else:
    # general case: vector*vector form
    rhs = m1
    plain, fused = "__builtin_HEXAGON_V6_vrmpyubv_128B({0}, {1})", "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})"

  if acc is None: return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, rhs), plain)
  return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, rhs), fused)
|
||||
|
||||
def gep_on_reduce(gep, alu):
  """Move a multi-lane GEP from the output of `alu` onto its sources.

  Returns a copy of `alu` (same op and arg) whose dtype is narrowed to the GEP's lane count and
  whose non-RANGE sources are each GEP'ed with the same lanes. Returns None when the GEP yields a
  single lane, when its dtype is a pointer, or when `alu` has fewer lanes than the GEP selects.
  """
  if gep.dtype.vcount == 1: return None
  if isinstance(gep.dtype, PtrDType) or alu.dtype.count < gep.dtype.count: return None
  # RANGE sources are loop variables, not vector operands, and are passed through untouched
  narrowed_src = tuple(s if s.op is Ops.RANGE else s.gep(gep.arg) for s in alu.src)
  return UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), narrowed_src, alu.arg)
|
||||
|
||||
def multi_add_int32(**aa):
  """Turn a sum of CASTs from uchar vectors (optionally plus `acc`) into one vrmpyub builtin call.

  Keyword arguments are the matched terms (a0, b0, c0 and optionally d0) plus an optional `acc`.
  The pre-cast sources are concatenated in sorted key order, interleaved, and multiplied by a
  word of per-byte ones, which sums each group of four lanes.

  Returns a CUSTOMI UOp of dtype int.vec(32). Raises AssertionError if a term is not a CAST of a
  uchar value.
  """
  acc = aa.pop('acc', None)
  if 'd0' in aa:
    mask = 0x01010101
  else:
    # only three real terms: zero the fourth byte of the multiplier and pad with a zero term
    mask = 0x00010101
    aa['d0'] = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
  for term in aa.values():
    assert term.src[0].dtype.scalar() is dtypes.uchar
    assert term.op is Ops.CAST

  # output lane 4*i+k of the concatenated vector is element i of the k-th term
  swizzle = tuple(32*part + lane for lane in range(32) for part in range(4))
  m0 = UOp(Ops.CAT, dtypes.uchar.vec(128), src=tuple(aa[k].src[0] for k in sorted(aa))).gep(swizzle)
  ones = UOp.const(dtypes.uint, mask)
  if acc is None:
    return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, ones), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
  return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, ones),
             "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
|
||||
|
||||
def multi_add_int2(**aa):
  """Turn a sum of CAST terms (optionally plus `acc`) into one __builtin_HEXAGON_A2_vraddub call.

  Keyword arguments are the matched CAST terms (the rules that use this bind a0..a7) plus an
  optional `acc`. Lane 0 and lane 1 of every pre-cast source are gathered, in sorted key order,
  into two 8-element uchar vectors that are passed to the builtin as 64-bit values.

  Returns a CUSTOMI UOp of dtype int.vec(2).
  """
  acc = aa.pop('acc', None)
  sources = [aa[k].src[0] for k in sorted(aa)]
  # all lane-0 elements first, then all lane-1 elements
  eles = [s.gep(lane) for lane in (0, 1) for s in sources]
  r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[0:4]+eles[8:12])).bitcast(dtypes.int64)
  r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[4:8]+eles[12:16])).bitcast(dtypes.int64)

  # TODO: types aren't right here
  if acc is None:
    return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (r0, r1), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})")
  return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0, r1),
             arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
|
||||
|
||||
# Rewrite rules for three-term sums; the callbacks pad the missing fourth term with zeros.
conv_pm = PatternMatcher([
  # sum of three products (optionally plus an accumulator) -> multi_mul
  (UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
    UPat(name="c0")*UPat(name="c1"), multi_mul),
  (UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
    UPat(name="c0")*UPat(name="c1"), multi_mul),

  # sum of three CASTs (optionally plus an accumulator) -> multi_add_int32
  (UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
  (UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0"), multi_add_int32),
])
|
||||
|
||||
dsp_pm = PatternMatcher([
|
||||
# convert load char32 to load char128
|
||||
(UPat(Ops.LOAD, (dtypes.uchar.vec(96), dtypes.uchar.vec(64), dtypes.uchar.vec(32)), src=(UPat.var("buf").cast(),), name="load"),
|
||||
lambda load, buf: load.replace(dtype=dtypes.uchar.vec(128),
|
||||
src=(buf.cast(buf.dtype.base.vec(128).ptr(size=buf.dtype.size, local=buf.dtype.local)),)+load.src[1:]).gep(tuple(range(0, load.dtype.count)))),
|
||||
# GEP on REDUCE
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.REDUCE, name='alu'),), name='gep'), gep_on_reduce),
|
||||
# no swizzle down convert
|
||||
(((UPat.var('x').maximum(0) ^ -1).maximum(-256) ^ -1).cast(dtypes.uchar.vec(128)),
|
||||
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
|
||||
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
|
||||
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,
|
||||
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 1 else None),
|
||||
])
|
||||
|
||||
# REDUCE int4 -> 2xint2, int8 -> 4xint2
|
||||
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(4), name="r"),
|
||||
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r)))),
|
||||
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(8), name="r"),
|
||||
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r),
|
||||
gep_on_reduce(r.gep((4,5)), r), gep_on_reduce(r.gep((6,7)), r)))),
|
||||
|
||||
# __builtin_HEXAGON_V6_vrmpybus
|
||||
(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
|
||||
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
|
||||
(UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
|
||||
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
|
||||
|
||||
# build __builtin_HEXAGON_V6_vrmpybus_128B
|
||||
(UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
|
||||
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+
|
||||
UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
|
||||
|
||||
# build __builtin_HEXAGON_A2_vraddub
|
||||
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
|
||||
# we upcast 3 as 4
|
||||
(UPat(Ops.REDUCE, name="r", src=(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") +
|
||||
UPat(name="b0")*UPat(name="b1") + UPat(name="c0")*UPat(name="c1"),), allow_any_len=True),
|
||||
lambda r, **kwargs: r.replace(src=(mm,)+r.src[1:]) if (mm:=multi_mul(**kwargs)) else None),
|
||||
(UPat(Ops.REDUCE, name="r", src=(UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+UPat(Ops.CAST,name="c0"),
|
||||
), allow_any_len=True),
|
||||
lambda r, **kwargs: r.replace(src=(mm,)+r.src[1:]) if (mm:=multi_add_int32(**kwargs)) else None),
|
||||
|
||||
# mul by const on GEP
|
||||
(UPat(Ops.GEP, src=(UPat.var('x'),), name="gep")*UPat.cvar("c", vec=False),
|
||||
lambda x, gep, c: (x.gep(gep.arg[0])*c.arg).broadcast(c.dtype.count) if all_same(gep.arg) and c.dtype.count > 1 else None),
|
||||
])+gep_pushing
|
||||
|
||||
def add_to_mul(c:UOp, x:UOp):
  """Fold the addition `c + x` into the vrmpy builtin call held by `c`.

  A non-accumulating vrmpyub/vrmpyubv call becomes its `_acc` variant with `x` as the accumulator.
  A call whose arg already contains 'acc' gets `x` added to its existing accumulator operand,
  unless `x` is a CUSTOM op. Returns None when neither case applies.
  """
  # (prefix of the non-accumulating call, accumulating replacement)
  fused_forms = (
    ("__builtin_HEXAGON_V6_vrmpyub_128B", "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})"),
    ("__builtin_HEXAGON_V6_vrmpyubv_128B", "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})"),
  )
  for prefix, fused in fused_forms:
    if c.arg.startswith(prefix):
      return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), fused)
  if 'acc' in c.arg and x.op is not Ops.CUSTOM:
    return c.replace(src=(x+c.src[0], c.src[1], c.src[2]))
  return None
|
||||
|
||||
def prefetch_l1(ld:UOp, idx:UOp):
  """Attach two __builtin_HEXAGON_Y2_dcfetch CUSTOM ops to a LOAD as extra sources.

  ld:  the LOAD being rewritten; its first source is a CAST of `idx`.
  idx: the INDEX UOp (buffer, offset) the load reads from.

  One fetch targets the load's offset advanced by twice the load's element count; the other
  targets the offset with the highest-numbered RANGE replaced by that range's first source.

  Returns the LOAD with the two fetches appended to its sources, or None when the load's last
  source is already a CUSTOM op (so the rule is not applied twice) or when the index does not
  depend on any RANGE.
  """
  if ld.src[-1].op is Ops.CUSTOM: return None
  ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
  # with no RANGE there is no loop to prefetch for; indexing ranges[-1] would raise IndexError
  if not ranges: return None
  innermost = ranges[-1]
  fetch = "__builtin_HEXAGON_Y2_dcfetch({0}+{1});"
  x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1]+UOp.const(dtypes.int, ld.dtype.count*2)), arg=fetch)
  x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1].substitute({innermost: innermost.src[0]})), arg=fetch)
  return ld.replace(src=ld.src+(x1, x2))
|
||||
|
||||
def prefetch_l2(ld:UOp, idx:UOp):
  """Attach a __builtin_HEXAGON_Y4_l2fetch CUSTOM op to a LOAD as an extra source.

  ld:  the LOAD being rewritten; its first source is a CAST of `idx`.
  idx: the INDEX UOp (buffer, offset) the load reads from.

  Returns the LOAD with the fetch appended to its sources. Returns None when the PREFETCHL2
  environment flag is 0, when the load already carries an l2fetch, or when the index does not
  depend on any RANGE.
  """
  if not getenv("PREFETCHL2", 1): return None
  tail = ld.src[-1]
  if tail.op is Ops.CUSTOM and 'l2fetch' in tail.arg: return None
  ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
  if not ranges: return None

  innermost, outer = ranges[-1], ranges[:-1]
  buf, nidx, const = idx.src[0], idx.src[1], 0
  if nidx.op is Ops.ADD and nidx.src[1].op is Ops.CONST:
    # NOTE: the constant offset is dropped (const stays 0); keeping it causes access alignment issues
    nidx = nidx.src[0]

  # offset span covered by the innermost range, with every other range pinned to 0
  zero_ranges = {r:r.const_like(0) for r in outer}
  at_end = nidx.substitute({innermost: innermost.src[1], **zero_ranges})
  at_start = nidx.substitute({innermost: innermost.src[0], **zero_ranges})
  nlen_uop = (at_end - at_start).simplify()
  nidx = nidx.substitute({innermost: innermost.src[0]})

  buf_lines_total = ((buf.dtype.size*buf.dtype.itemsize)+127)//128
  if buf_lines_total < 8192//128:
    # if the total buffer size is sub 8k, fetch it all
    x1 = UOp(Ops.CUSTOM, dtypes.void, src=(buf, UOp.const(dtypes.int, buf_lines_total)),
             arg="__builtin_HEXAGON_Y4_l2fetch({0}, 0x808000|{1});")
  else:
    # fetch up to 8192, and never fewer than 8 lines
    span_known = nlen_uop.op is Ops.CONST and nlen_uop.arg <= 8192
    fetch_lines = max(((nlen_uop.arg+127)//128)*2+1 if span_known else 8, 8)
    x1 = UOp(Ops.CUSTOM, dtypes.void, src=(buf, nidx+const, UOp.const(dtypes.int, fetch_lines)),
             arg="__builtin_HEXAGON_Y4_l2fetch({0}+{1}, 0x808000|{2});")
  return ld.replace(src=ld.src+(x1,))
|
||||
|
||||
def vectorize_shuffle(vec:UOp):
  """Rewrite a VECTORIZE of GEPs/CONSTs as (nested) __builtin_shufflevector calls.

  Handles elements drawn from one, two or three distinct source vectors; CONST elements become
  the shuffle index -1. Returns a CUSTOM UOp with the dtype of `vec`, or None when the rewrite is
  disabled via DISABLE_VECTORIZED_SHUFFLE or the shape is not supported.
  """
  if getenv("DISABLE_VECTORIZED_SHUFFLE", 0): return None
  if any(s.op not in {Ops.GEP, Ops.CONST} for s in vec.src): return None
  sources = dedup([s.src[0] for s in vec.src if s.op is Ops.GEP])
  if not sources: return None

  def fmt(lanes): return ','.join([f'{y:4d}' for y in lanes])

  if len(sources) == 1:
    # this pattern is broken in DSP clang
    if sources[0].dtype.count == 4: return None
    lanes = [s.arg[0] if s.op is Ops.GEP else -1 for s in vec.src]
    return UOp(Ops.CUSTOM, vec.dtype, tuple(sources), "__builtin_shufflevector({0}, {0}, "+fmt(lanes)+")")

  # the multi-source forms need uchar sources that all have the same lane count as the result
  if any(x.dtype.scalar() is not dtypes.uchar for x in sources): return None
  if not all_same([x.dtype.count for x in sources]) or sources[0].dtype.count != vec.dtype.count: return None
  if len(sources) > 3: return None

  def lane_in_first_two(s):
    # lanes of the second source are numbered after those of the first; everything else is -1
    if s.op is Ops.GEP:
      if s.src[0] is sources[0]: return s.arg[0]
      if s.src[0] is sources[1]: return sources[0].dtype.count + s.arg[0]
    return -1

  lanes = [lane_in_first_two(s) for s in vec.src]
  if len(sources) == 2:
    return UOp(Ops.CUSTOM, vec.dtype, tuple(sources), "__builtin_shufflevector({0}, {1}, "+fmt(lanes)+")")

  # three sources: shuffle the first two together, then merge the third into that result
  lanes2 = []
  for i,s in enumerate(vec.src):
    if s.op is Ops.GEP and s.src[0] is sources[2]: lanes2.append(vec.dtype.count + s.arg[0])
    elif s.op is Ops.CONST: lanes2.append(-1)
    else: lanes2.append(i)  # keep the lane already placed by the inner shuffle
  full_arg = "__builtin_shufflevector(__builtin_shufflevector({0}, {1}, "+fmt(lanes)+"), {2}, "+fmt(lanes2)+")"
  return UOp(Ops.CUSTOM, vec.dtype, tuple(sources), full_arg)
|
||||
|
||||
def multicore_range(r:UOp):
  """Split a RANGE in half between two values of the SPECIAL index "g0".

  When g0 == 0 the range covers [start, end//2); otherwise it covers [end//2, end). Returns None
  unless the MULTICORE environment flag is exactly 1, or when the range already depends on a
  SPECIAL op.
  """
  # NOTE: THIS IS BROKEN if this is a reduce range. TODO: check for that
  if getenv("MULTICORE", 0) != 1: return None
  if any(x.op is Ops.SPECIAL for x in r.toposort): return None
  on_core0 = UOp(Ops.SPECIAL, dtypes.int, arg=("g0", 2)).eq(0)
  midpoint = r.src[1]//2
  return r.replace(src=(on_core0.where(r.src[0], midpoint), on_core0.where(midpoint, r.src[1])))
|
||||
|
||||
def store_with_mask(buf, idx, val, mask, cast):
  """Lower a masked STORE into two predicated HVX stores.

  buf, idx, mask: the sources of the INDEX being stored to.
  val:            the value being stored.
  cast:           the CAST wrapped around the INDEX; its dtype is reused for the store address.

  Only 128-lane uchar values with a 128-lane mask are handled. Any other shape prints a
  "DROP MASK" notice and falls back to a plain store that ignores the mask.

  Returns a CUSTOM UOp grouping the two predicated stores, or the plain STORE in the fallback.
  """
  if val.dtype.count != 128 or mask.dtype.count != 128 or val.dtype.scalar() != dtypes.uchar:
    print("DROP MASK", val.dtype.count, mask.dtype.count)
    # NOTE: we are dropping the mask
    return buf.index(idx).cast(cast.dtype).store(val)

  const_0 = UOp.const(dtypes.uchar.vec(128), 0)
  # view of the mask typed as a uchar vector so it can be fed to the byte-wise builtins
  cmask = UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(mask,),arg="{0}")

  # unaligned: the low 7 bits of the address are the shift amount passed to vlalignb
  min_128 = (buf.index(idx).cast(dtypes.uint)&0x7F)

  def align(first, second):
    return UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=(first, second, min_128), arg="__builtin_HEXAGON_V6_vlalignb_128B({0}, {1}, {2})")

  def pred_store(shifted_mask, addr, shifted_val):
    return UOp(Ops.CUSTOM, dtypes.void, src=(shifted_mask, addr, shifted_val, const_0),
               arg='__builtin_HEXAGON_V6_vS32b_nqpred_ai_128B(__builtin_HEXAGON_V6_veqb_128B({0}, {3}), {1}, {2});')

  # both halves shift the same uchar-typed mask view (the right half previously used the raw mask)
  cmask_l, cmask_r = align(cmask, const_0), align(const_0, cmask)
  val_l, val_r = align(val, const_0), align(const_0, val)
  store_l = pred_store(cmask_l, buf.index(idx).cast(cast.dtype), val_l)
  store_r = pred_store(cmask_r, buf.index(idx+128).cast(cast.dtype), val_r)
  return UOp(Ops.CUSTOM, src=(store_l,store_r), arg="")
|
||||
|
||||
dsp_pm_late = PatternMatcher([
|
||||
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
# prefetch L1
|
||||
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld"), prefetch_l1),
|
||||
|
||||
# prefetch L2
|
||||
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(8), dtypes.uchar.vec(128), dtypes.int.vec(32)),
|
||||
src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld", allow_any_len=True), prefetch_l2),
|
||||
|
||||
# use __builtin_shufflevector
|
||||
(UPat(Ops.VECTORIZE, dtypes.uchar.vec(128), name="vec"), vectorize_shuffle),
|
||||
|
||||
# __builtin_HEXAGON_V6_vrmpyub_acc_128B
|
||||
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
|
||||
|
||||
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
|
||||
(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
|
||||
lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
|
||||
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
|
||||
|
||||
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src,
|
||||
"__builtin_shufflevector({0}, {0}, "+','.join([f'{y:4d}' for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 4 else None),
|
||||
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")),
|
||||
lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
|
||||
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")),
|
||||
lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
|
||||
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")),
|
||||
lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
|
||||
# multicore
|
||||
(UPat(Ops.RANGE, name="r", arg=0), multicore_range),
|
||||
|
||||
# store with mask
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))).cast().named("cast"), UPat.var("val"))),
|
||||
store_with_mask),
|
||||
])
|
||||
|
||||
# NOTE: this just increases readability of the generated code
|
||||
dsp_string = PatternMatcher([
|
||||
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
|
||||
# Rendering-only rewrite: changes how constant byte vectors are emitted.
pretty_render = PatternMatcher([
  # makes rendering nicer: each int8/uint8 CONST element of a VECTORIZE is replaced by a CUSTOMI
  # wrapping an int CONST with the same value
  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, dtype=(dtypes.uint8, dtypes.int8)), name="v"),
    lambda v: UOp(Ops.VECTORIZE, v.dtype, src=tuple(UOp(Ops.CUSTOMI, x.dtype, src=(UOp.const(dtypes.int, x.arg),), arg="{0}") for x in v.src))),
])
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
global_max = (2, 1, 1)
|
||||
buffer_suffix = " restrict __attribute__((align_value(128)))"
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
|
||||
string_rewrite = dsp_string+ClangRenderer.string_rewrite
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {k:v for k,v in ClangRenderer.code_for_op.items() if k != Ops.SQRT}
|
||||
extra_args = ['int global_idx_0', 'void* sync']
|
||||
code_for_workitem = {"g": lambda x: f"global_idx_{x}"}
|
||||
|
||||
def _render_defines(self, uops) -> list[str]:
|
||||
return ['''/* DSP boilerplate */ struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency;
|
||||
_Bool set_dcvs_params; short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3];};''','int HAP_power_set(void*, void*);',
|
||||
'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
|
||||
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
|
||||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
'unsigned long long HAP_perf_get_time_us(void);', 'typedef unsigned long qurt_thread_t;', 'void qurt_thread_exit(int);',
|
||||
'typedef struct _qurt_barrier { char padding[64]; } qurt_barrier_t;', 'int qurt_barrier_init(qurt_barrier_t*, unsigned int);',
|
||||
'int qurt_barrier_wait(qurt_barrier_t*);',
|
||||
'typedef struct _qurt_thread_attr { char name[16]; unsigned char tcb_partition; unsigned char affinity; unsigned short priority;',
|
||||
'unsigned char asid; unsigned char bus_priority; unsigned short timetest_id; unsigned int stack_size;'
|
||||
'void *stack_addr; char padding[96]; } qurt_thread_attr_t;',
|
||||
'int qurt_thread_join(qurt_thread_t tid, int *status);', 'void* malloc(unsigned int);', 'void free(void*);',
|
||||
'int qurt_thread_create (qurt_thread_t *thread_id, qurt_thread_attr_t *attr, void (*entrypoint) (void *), void *arg);',
|
||||
] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]], sync_cnt=0x0) -> str:
|
||||
msrc = ['typedef struct all_args {', *[f'int sz_or_val_{i}; int off{i}; void *buf_{i};' for i in range(len(bufs))], 'void* sync; } all_args_t;']
|
||||
msrc += ['void threader(all_args_t* args) {']
|
||||
buf_inputs = ', '.join([(f'args->buf_{i}' if isinstance(b[1][0], PtrDType) else f'args->sz_or_val_{i}') for i,b in enumerate(bufs)])
|
||||
msrc += [f"{function_name}({buf_inputs}, 1, args->sync);"]
|
||||
msrc += ['qurt_thread_exit(0); }'
|
||||
'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
if sync_cnt > 0:
|
||||
msrc += [f"qurt_barrier_t* sync = malloc({sync_cnt} * sizeof(qurt_barrier_t));"]
|
||||
msrc += [f"qurt_barrier_init(&sync[{i}], 2);" for i in range(sync_cnt)]
|
||||
else: msrc += ["qurt_barrier_t* sync = 0x0;"]
|
||||
msrc += ['all_args_t args = { 0 };']
|
||||
msrc += [f'args.sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'args.off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'args.buf_{i} = HAP_mmap(0,args.sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+args.off{i};'
|
||||
for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += ['args.sync = sync;']
|
||||
msrc += ["qurt_thread_attr_t attr = { 0 };"]
|
||||
msrc += ["attr.name[0] = 't';", "attr.priority = 255;", "attr.asid = 0;"]
|
||||
msrc += ["attr.stack_size = (64 << 10);", "attr.stack_addr = malloc(attr.stack_size);"]
|
||||
msrc += [""]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
msrc += ["qurt_thread_t thread_ = 0; qurt_thread_create(&thread_, &attr, (void (*)(void*))threader, (void*)&args);"]
|
||||
buf_inputs = ', '.join([(f'args.buf_{i}' if isinstance(b[1][0], PtrDType) else f'args.sz_or_val_{i}') for i,b in enumerate(bufs)])
|
||||
msrc += [f"{function_name}({buf_inputs}, 0, args.sync);"]
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
msrc += ['int status;', "qurt_thread_join(thread_, &status);"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'HAP_munmap(args.buf_{i}, args.sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += ['free(attr.stack_addr);']
|
||||
if sync_cnt > 0: msrc += ['free(sync);']
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -141,7 +493,7 @@ class DSPDevice(Compiled):
|
|||
'got', 'got.plt', 'dynsym', 'dynstr', 'symtab', 'shstrtab', 'strtab']
|
||||
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
|
||||
with tempfile.NamedTemporaryFile(delete=False) as self.link_ld:
|
||||
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
|
||||
self.link_ld.write(f"SECTIONS {{ . = 0x0;\n{sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
|
||||
self.link_ld.flush()
|
||||
|
||||
from tinygrad.runtime.graph.cpu import CPUGraph
|
||||
|
|
@ -283,7 +635,15 @@ class MockDSPRenderer(DSPRenderer):
|
|||
else:
|
||||
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
|
||||
msrc.append("unsigned int st = inscount();")
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
|
||||
buf_inputs = ', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
# TODO: get count?
|
||||
# NOTE: we do them in reverse order to reveal bugs
|
||||
msrc.append(f"{function_name}({buf_inputs}, 1, 0);")
|
||||
msrc.append(f"{function_name}({buf_inputs}, 0, 0);")
|
||||
else:
|
||||
# huh, why did this change?
|
||||
msrc.append(f"{function_name}({buf_inputs}, 0, 0);")
|
||||
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")
|
||||
|
|
|
|||
|
|
@ -81,6 +81,9 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
|
||||
|
||||
# all loads (we add Ops.CUSTOM as fake sources on this for l1fetch)
|
||||
(UPat(Ops.LOAD), lambda: True),
|
||||
|
||||
# early STORE has a <buf, shapetracker, val>
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
|
||||
|
||||
|
|
@ -91,6 +94,9 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index),
|
||||
|
||||
# any mask
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(name="mask")), name="idx"), validate_index),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, barrier?>
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue