mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
199 commits
master
...
dsp_search
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd51728795 | ||
|
|
a59b1ed970 | ||
|
|
147fc0e648 | ||
|
|
17f7b226cb | ||
|
|
9c34d9eb6e | ||
|
|
95261b6193 | ||
|
|
1e2becfeae | ||
|
|
e18cdbcbe2 | ||
|
|
f3cb4c3eef | ||
|
|
6ecaf11224 | ||
|
|
8b24f9cb0d | ||
|
|
797e512c00 | ||
|
|
f600482982 | ||
|
|
da35edbb55 | ||
|
|
661431ee75 | ||
|
|
8340d9c1c2 | ||
|
|
910cddbbca | ||
|
|
e6e0c0ec86 | ||
|
|
d0eedb5a79 | ||
|
|
f69deddbd4 | ||
|
|
be11fbbf78 | ||
|
|
812c391617 | ||
|
|
3306083f42 | ||
|
|
18d7e9d3f1 | ||
|
|
1c3f249ecf | ||
|
|
bb7b89475c |
||
|
|
8005e6c974 | ||
|
|
a3d61a0372 | ||
|
|
c73e35aa24 | ||
|
|
0b4b9f61b9 | ||
|
|
ee3ddfcdc1 | ||
|
|
220d682489 | ||
|
|
9c388c3539 | ||
|
|
4b3a4c8c46 | ||
|
|
eb606d7230 | ||
|
|
49d52a2763 | ||
|
|
a59c3dd09a | ||
|
|
a640292aed | ||
|
|
2f48c12441 |
||
|
|
be3b5efc64 | ||
|
|
996d0ac1d2 | ||
|
|
77e897b3b1 |
||
|
|
273dde69bd | ||
|
|
a64030d8c8 | ||
|
|
9b19129e87 | ||
|
|
48221d9024 | ||
|
|
bcfcd60f55 | ||
|
|
abc90024ac | ||
|
|
f0e6d8394c |
||
|
|
a1c1ecd597 |
||
|
|
489a5e24c4 |
||
|
|
e0fd84dd64 | ||
|
|
1a9d7a1628 | ||
|
|
45646fe102 | ||
|
|
9c928afafe | ||
|
|
d4f1c5049b | ||
|
|
11b478f85d | ||
|
|
0aa7031b5f | ||
|
|
ab67d5ff6e | ||
|
|
cbe23e13c2 | ||
|
|
9bbd12dc65 | ||
|
|
b09142a893 | ||
|
|
1d7faf4777 | ||
|
|
59438be39b | ||
|
|
cc23836a38 | ||
|
|
e4354effa2 |
||
|
|
d180e909a3 | ||
|
|
52364231dc | ||
|
|
d32ad080c3 | ||
|
|
a8bd26d9bc | ||
|
|
6d860389f4 | ||
|
|
5d5286489d | ||
|
|
917e0e925b | ||
|
|
6081f8427e | ||
|
|
23035bf028 | ||
|
|
5e33163ef3 |
||
|
|
cee9fc7540 | ||
|
|
9041072dea |
||
|
|
444d6279ac | ||
|
|
f27f484621 | ||
|
|
38488ec3b0 | ||
|
|
ff96f0adae | ||
|
|
5dd59a6096 | ||
|
|
6bec82b918 | ||
|
|
a436d7542f | ||
|
|
5d98688de6 | ||
|
|
09d877ed8c |
||
|
|
6ff894d674 | ||
|
|
da03b4520a |
||
|
|
013c6e0b10 | ||
|
|
31ffa1607e |
||
|
|
928994c6ea | ||
|
|
e283bec62e | ||
|
|
c4f5db8467 | ||
|
|
bf0d928417 | ||
|
|
f823324eb9 | ||
|
|
6995e0c91b |
||
|
|
b934b5b907 | ||
|
|
290ba9ee37 | ||
|
|
e0d63696d7 | ||
|
|
acafd57f14 |
||
|
|
905f847d10 | ||
|
|
9e19cdfbbe | ||
|
|
f7b38fa94c | ||
|
|
bd03942bd8 | ||
|
|
880b4a5e47 | ||
|
|
2e4cae342b | ||
|
|
8660fecb02 | ||
|
|
e3e43df0c9 | ||
|
|
a47e61b097 | ||
|
|
f1ff18acec | ||
|
|
60cbfe4222 |
||
|
|
311df3ff21 | ||
|
|
f6e64a5e8e | ||
|
|
622ff115a3 | ||
|
|
5a6e8ee268 | ||
|
|
a9f1227625 | ||
|
|
74c2587ef4 | ||
|
|
bce252e0b8 | ||
|
|
66a90a3c92 | ||
|
|
0d76b0d461 | ||
|
|
5e4505d363 | ||
|
|
29920b74d5 | ||
|
|
ccd18a803c | ||
|
|
943bde47ab | ||
|
|
0d10c7ae2f | ||
|
|
3cab6a3d4a | ||
|
|
22a56cbaea | ||
|
|
afd61730b4 | ||
|
|
536556434b | ||
|
|
52bff5f39d | ||
|
|
64d0f14d3d | ||
|
|
1b61cc6ec3 | ||
|
|
6f792e8045 | ||
|
|
b1f8018bf4 | ||
|
|
2eb9241329 | ||
|
|
554a490751 |
||
|
|
651c678edf | ||
|
|
3274bd2d81 | ||
|
|
30f4d64148 | ||
|
|
2634975d5a | ||
|
|
fd73ec2b1b | ||
|
|
e1d2bec4a4 | ||
|
|
1b4e9f5e91 | ||
|
|
25c023bcbe | ||
|
|
07abf9e6bc | ||
|
|
26b02a037c | ||
|
|
5089a601c6 | ||
|
|
6b49a63c48 | ||
|
|
dca95428a5 | ||
|
|
8a477ba4e1 | ||
|
|
264dd91b8a | ||
|
|
bdf716b915 | ||
|
|
cf41c803d0 | ||
|
|
3cf9224df5 | ||
|
|
af94addb3a | ||
|
|
dc1469a188 | ||
|
|
0416b0998d | ||
|
|
c715c25420 |
||
|
|
f66b03f0a6 | ||
|
|
2729a46ca6 | ||
|
|
dbb50e4a00 | ||
|
|
71c7c455a6 | ||
|
|
ff3438be4e | ||
|
|
bc5e23061b | ||
|
|
5ce951fb34 | ||
|
|
4a49d05a3f |
||
|
|
c3c85c64ee | ||
|
|
61c02ca634 | ||
|
|
325044bcaf | ||
|
|
91ac508878 | ||
|
|
2ed30f5366 | ||
|
|
d0b9c7e7ca | ||
|
|
f6ed8f4a27 | ||
|
|
87718170d2 | ||
|
|
b67af4049c | ||
|
|
16e425a4c0 | ||
|
|
c867a48ab4 | ||
|
|
2dc82c0604 | ||
|
|
e7402e6643 | ||
|
|
e5ccd9e846 | ||
|
|
624197f169 | ||
|
|
d42350a401 | ||
|
|
223feb2118 |
||
|
|
8eb9093fb8 | ||
|
|
45f7c08111 | ||
|
|
58fc77fdb3 | ||
|
|
e57258b17b | ||
|
|
31cd00e72f | ||
|
|
b00ccc08c3 | ||
|
|
94d578aec5 | ||
|
|
45010f7eff | ||
|
|
249141026e | ||
|
|
a913c1aab7 | ||
|
|
469ec6b6b4 | ||
|
|
1a84d504b7 |
||
|
|
14c9f14125 | ||
|
|
cc0041cb8c | ||
|
|
e4615e0cd9 |
19 changed files with 913 additions and 90 deletions
|
|
@ -1,4 +1,4 @@
|
|||
import sys, onnx, time
|
||||
import sys, onnx, time, pickle
|
||||
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx_helpers import get_example_inputs, validate
|
||||
|
|
@ -33,4 +33,9 @@ if __name__ == "__main__":
|
|||
|
||||
if getenv("ORT"):
|
||||
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
|
||||
print("model validated")
|
||||
print("model validated")
|
||||
|
||||
if (fn:=getenv("SAVE_PKL", "")) != "":
|
||||
with open(fn, "wb") as f:
|
||||
pickle.dump(run_onnx_jit, f)
|
||||
print(f"pkl saved to {fn}")
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ if __name__ == "__main__":
|
|||
GlobalCounters.reset()
|
||||
p = run_onnx_jit(**{t_name:img})
|
||||
assert p.shape == (1,1000)
|
||||
t = p.argmax().item()
|
||||
t = p.to('cpu').argmax().item()
|
||||
hit += y==t
|
||||
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||
|
||||
|
|
|
|||
19
examples/test_pkl_imagenet.py
Normal file
19
examples/test_pkl_imagenet.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import sys, pickle
|
||||
from tinygrad import GlobalCounters
|
||||
from tinygrad.helpers import fetch, getenv
|
||||
from examples.test_onnx_imagenet import imagenet_dataloader
|
||||
|
||||
if __name__ == "__main__":
  # usage: test_pkl_imagenet.py <pickled jit>
  # the CNT env var sets how many dataloader samples are scored (default 100)
  if len(sys.argv) < 2: raise SystemExit(f"usage: {sys.argv[0]} <pickled jit>")

  # SECURITY: pickle.load executes arbitrary code stored in the file.
  # only point this script at pickles you produced yourself, never at an untrusted download.
  with open(fetch(sys.argv[1]), "rb") as f:
    run_onnx_jit = pickle.load(f)

  # the captured jit records the name of its first input and, as the last field of the
  # matching expected_st_vars_dtype_device entry, the device that input lives on
  input_name = run_onnx_jit.captured.expected_names[0]
  device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
  print(f"input goes into {input_name=} on {device=}")

  hit = 0
  for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
    GlobalCounters.reset()
    # move the image to the device the jit expects, then run it
    p = run_onnx_jit(**{input_name:img.to(device)})
    assert p.shape == (1,1000)
    t = p.argmax().item()
    hit += y==t
    # running top-1 accuracy over the samples seen so far
    print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||
|
|
@ -415,6 +415,7 @@ def get_onnx_ops():
|
|||
|
||||
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  """ONNX Conv, implemented with Tensor.conv2d.

  The padding is computed by _resolve_pool_pads from pads/auto_pad, using kernel_shape when given
  and the trailing dims of W otherwise.
  """
  # grouped conv whose weight is (C,1,3,3): take the group count from the weight itself
  # NOTE(review): presumably so a weight whose channel dim was padded still gets one group per channel -- confirm
  groups = W.shape[0] if W.shape[1:] == (1,3,3) and group > 1 else group
  spatial = kernel_shape or W.shape[2:]
  padding = _resolve_pool_pads(X, pads, spatial, dilations, strides, auto_pad)
  return X.conv2d(W, B, stride=strides, groups=groups, dilation=dilations, padding=padding)
|
||||
|
||||
|
|
@ -724,8 +725,20 @@ def get_onnx_ops():
|
|||
ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
|
||||
else:
|
||||
ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
|
||||
|
||||
# you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
|
||||
if getenv("NHWC") and len(ret.shape) == 4: return ret.permute(0,2,3,1).contiguous().permute(0,3,1,2)
|
||||
if getenv("NHWC") and len(ret.shape) == 4:
|
||||
in_chans = ret.shape[1]
|
||||
if ret.shape[1] == 3 or in_chans%32 != 0:
|
||||
return ret.permute(0,2,3,1).contiguous().permute(0,3,1,2)
|
||||
else:
|
||||
if in_chans%32 != 0: ret = ret.pad(((0,0), (0,32-(in_chans%32)), (0,0), (0,0)))
|
||||
ret = ret.reshape(ret.shape[0], ret.shape[1]//32, 32, ret.shape[-2], ret.shape[-1])
|
||||
order = (0, 1, 3, 4, 2)
|
||||
ret = ret.permute(order).contiguous().permute(*argsort(order))
|
||||
ret = ret.reshape(ret.shape[0], -1, ret.shape[-2], ret.shape[-1])
|
||||
if in_chans%32 != 0: ret = ret[:, :in_chans, :, :]
|
||||
return ret
|
||||
return ret.contiguous()
|
||||
|
||||
def DynamicQuantizeLinear(x: Tensor):
|
||||
|
|
@ -737,10 +750,66 @@ def get_onnx_ops():
|
|||
return y, scale, zero_point
|
||||
|
||||
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
|
||||
if getenv("NHWC"):
|
||||
# pad channels
|
||||
in_shape = x.shape
|
||||
if len(x.shape) == 4 and x.shape[1:] == (1,3,3) and x.shape[0]%32 != 0:
|
||||
# 3x3 depthwise (C,1,3,3). pad C to 32
|
||||
x = x.pad(((0,32-(x.shape[0]%32)), (0,0), (0,0), (0,0)))
|
||||
elif len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[0] != 1:
|
||||
# 1x1 conv (C_out,C_in,1,1), pad C_out and C_in to 32
|
||||
if x.shape[0]%32 != 0: x = x.pad(((0,32-(x.shape[0]%32)), (0,0), (0,0), (0,0)))
|
||||
if x.shape[1]%32 != 0: x = x.pad(((0,0), (0,32-(x.shape[1]%32)), (0,0), (0,0)))
|
||||
elif len(x.shape) == 1 and x.shape[0]%32 != 0 and x.shape[0] != 1000:
|
||||
# bias
|
||||
x = x.pad(((0,32-(x.shape[0]%32)),))
|
||||
|
||||
if in_shape != x.shape:
|
||||
xzp = x_zero_point.item()
|
||||
print(f"{in_shape} -> {x.shape}", xzp)
|
||||
# fix up the zero point in the padded area
|
||||
pp = (Tensor.full(in_shape, -xzp, dtype=dtypes.int).pad(tuple([(0, so-si) for si,so in zip(in_shape, x.shape)])) + xzp).cast(x.dtype)
|
||||
x = (x + pp).contiguous()
|
||||
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[1:] == (3,3,3):
|
||||
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
|
||||
assert x.shape[0] == 32
|
||||
order = (1,2,0,3)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x[:, :, :, :3]
|
||||
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[1:] == (1,3,3):
|
||||
# 3x3 depthwise (C,1,3,3)
|
||||
# "width multiple of 4 depth multiple of 32 aligned to 128bytes"
|
||||
x = x.pad(((0,0), (0,0), (0,0), (0,1)))
|
||||
if x.shape[0]%32 == 0:
|
||||
# depth/32 is a loop -- lsr(depth, #5)
|
||||
# width/4 is a loop -- lsr(out_width, #2)
|
||||
# height is a loop
|
||||
x = x.reshape(-1, 32, 1, 3, 4)
|
||||
order = (0,3,1,2,4)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x.reshape(-1, 1, 3, 4)
|
||||
else:
|
||||
assert False # (doesn't happen anymore)
|
||||
#print("HERE", x.shape)
|
||||
order = (2,0,1,3)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x[:, :, :, :3]
|
||||
# we increase the filts to 4-aligned for speed (75% util)
|
||||
WEIGHT_SHIFT = 4
|
||||
if getenv("NHWC") and len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[1]%WEIGHT_SHIFT == 0:
|
||||
# DSP swizzle memory
|
||||
x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
|
||||
if x.shape[0]%32 == 0:
|
||||
# DSP swizzle memory (big)
|
||||
x = x.reshape(x.shape[0]//32, 32, x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(0,2,1,3).contiguous().permute(0,2,1,3).reshape(x.shape)
|
||||
else:
|
||||
# DSP swizzle memory
|
||||
x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
|
||||
if getenv("NHWC") and x.shape == (1000, 1280):
|
||||
x = x.reshape(-1, 320, 4)
|
||||
order = (1,0,2)
|
||||
x = x.permute(*order).contiguous().permute(*argsort(order))
|
||||
x = x.reshape(-1, 1280)
|
||||
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
|
||||
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,34 @@
|
|||
import pickle, sys
|
||||
from dataclasses import replace
|
||||
from tinygrad import Device, Context
|
||||
from tinygrad import Device, Context, Tensor, GlobalCounters
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import getenv, BEAM
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
import numpy as np
|
||||
|
||||
def move_jit_captured_to_dev(captured, device="DSP"):
  """Retarget a captured jit to `device`, mutating `captured` and returning it.

  The last field of every expected_st_vars_dtype_device entry is replaced by `device`, and every
  non-None buffer referenced from the jit cache is replaced by a new Buffer created on `device`.
  """
  captured.expected_st_vars_dtype_device = [entry[:3] + (device,) for entry in captured.expected_st_vars_dtype_device]

  moved = {}  # original Buffer -> its replacement on `device`
  def move_buffer(old):
    if old in moved: return moved[old]
    if old._base is None:
      new = Buffer(device, old.size, old.dtype)
      # NOTE(review): assumes the data copy applies only to base buffers (this branch), not to views -- confirm
      if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
    else:
      # a view: move its base first (recursively), then recreate the view at the same offset
      new = Buffer(device, old.size, old.dtype, base=move_buffer(old._base), offset=old.offset)
    moved[old] = new
    return new

  # first pass creates the replacements, second pass rebuilds the cache with them
  for item in captured.jit_cache:
    for b in item.bufs:
      if b is not None: move_buffer(b)
  captured.jit_cache = [ExecItem(item.prg, [moved.get(b,b) for b in item.bufs]) for item in captured.jit_cache]
  return captured
|
||||
|
||||
if __name__ == "__main__":
|
||||
with Context(DEBUG=0):
|
||||
|
|
@ -15,6 +37,10 @@ if __name__ == "__main__":
|
|||
print(f"{f.tell()/1e6:.2f}M loaded")
|
||||
print(type(fxn))
|
||||
|
||||
# Move all buffers to DSP device.
|
||||
fxn.captured = move_jit_captured_to_dev(fxn.captured, "DSP")
|
||||
new_jit = []
|
||||
|
||||
knum = 1
|
||||
for ei in fxn.captured.jit_cache:
|
||||
# skip the copy and the first kernel
|
||||
|
|
@ -22,9 +48,47 @@ if __name__ == "__main__":
|
|||
if knum == (pknum:=getenv("KNUM", 0)) or pknum == 0:
|
||||
p: ProgramSpec = ei.prg.p
|
||||
k = Kernel(p.ast, Device["DSP"].renderer)
|
||||
dsp_bufs = [Buffer("DSP", 8192+b.size, b.dtype).view(b.size, b.dtype, 4096) for b in ei.bufs]
|
||||
k.hand_coded_optimizations()
|
||||
|
||||
if getenv("VALIDATE"):
|
||||
with Context(NOOPT=1):
|
||||
lower_schedule_item(ScheduleItem(p.ast, ei.bufs)).run()
|
||||
correct = ei.bufs[0].numpy()
|
||||
ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
|
||||
GlobalCounters.kernel_count -= 1
|
||||
|
||||
#if knum != 1 and not getenv("NOOPT"): k.hand_coded_optimizations()
|
||||
if not getenv("NOOPT"): k.hand_coded_optimizations()
|
||||
#if knum == 6:
|
||||
# k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
# k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
|
||||
# k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
|
||||
p2 = k.to_program()
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=dsp_bufs)
|
||||
new_ei = replace(ei, prg=CompiledRunner(p2))
|
||||
new_ei.run()
|
||||
new_jit.append(new_ei)
|
||||
test = ei.bufs[0].numpy()
|
||||
|
||||
if getenv("VALIDATE"):
|
||||
import numpy as np
|
||||
"""
|
||||
print("first")
|
||||
print(correct[:150])
|
||||
print(test[:150])
|
||||
print("middle")
|
||||
print(correct[500:600])
|
||||
print(test[500:600])
|
||||
print("last")
|
||||
print(correct[-150:])
|
||||
print(test[-150:])
|
||||
print("skip")
|
||||
print(correct[::32])
|
||||
print(test[::32])
|
||||
"""
|
||||
np.testing.assert_allclose(correct, test, rtol=1e-3, atol=1e-3)
|
||||
knum += 1
|
||||
|
||||
if getenv("RUN_JIT", 0):
|
||||
fxn.captured.free_intermediates()
|
||||
fxn.captured.jit_cache = new_jit
|
||||
fxn(input=Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32), device="DSP"))
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from collections import defaultdict
|
|||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
|
||||
from tinygrad.ops import graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, commutative
|
||||
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
|
@ -12,10 +12,16 @@ from tinygrad.renderer import Renderer
|
|||
# ***** load/store grouping *****
|
||||
|
||||
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
|
||||
vectorize_mask = getenv("VECTORIZE_MASK", 0) and buf.arg == 0 and mask is not None
|
||||
|
||||
# generate the individual indexes
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}")
|
||||
if vectorize_mask:
|
||||
# no load_store_indexing if we are doing vectorized mask
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic_flat+commutative, name=f"index_buf_{buf.arg}")
|
||||
else:
|
||||
midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), mask.gep(i) if mask is not None else None) for i in range(vec.dtype.count)]),
|
||||
symbolic_flat+commutative+load_store_indexing, name=f"index_buf_{buf.arg}")
|
||||
# extract all the relevant offsets
|
||||
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
|
||||
for i in range(vec.dtype.count):
|
||||
|
|
@ -24,7 +30,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
|
||||
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
|
||||
else: root_src, arg = idx, 0
|
||||
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
|
||||
if len(midx.src[i].src) == 3 and not vectorize_mask: root_src = (midx.src[i].src[2], root_src)
|
||||
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
|
||||
|
||||
# the buf.dtype is always a pointer
|
||||
|
|
@ -35,10 +41,26 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
|||
idxs: list[int|None] = [None]*vec.dtype.count
|
||||
global_offset = 0
|
||||
for offsets in offsets_rootsrc.values():
|
||||
if 0 in offsets:
|
||||
match = True
|
||||
for i in range(0, max(offsets.keys()), 4):
|
||||
if i in offsets and i+1 in offsets and i+2 in offsets and i+3 not in offsets: pass
|
||||
else: match = False
|
||||
if match:
|
||||
for i in range(0, max(offsets.keys()), 4):
|
||||
assert i+3 not in offsets
|
||||
offsets[i+3] = {}
|
||||
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
|
||||
for grp in grouped_offsets:
|
||||
# get the index offset for this element. using [0] is okay, because they are the same
|
||||
lidx = midx.src[offsets[grp[0]][0]]
|
||||
|
||||
if vectorize_mask:
|
||||
allgrp = [midx.src[offsets[g][0]] for g in grp]
|
||||
base = [x.src[2].cast(dtypes.uchar) if len(x.src) > 2 else UOp.const(dtypes.uchar, 1) for x in allgrp]
|
||||
vecmask = UOp(Ops.VECTORIZE, dtypes.uchar.vec(len(base)), tuple(base))
|
||||
lidx = lidx.replace(src=lidx.src[0:2]+(vecmask,))
|
||||
|
||||
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
# set the idxs of the output
|
||||
for i,g in enumerate(grp):
|
||||
|
|
@ -88,6 +110,11 @@ load_store_folding = PatternMatcher([
|
|||
|
||||
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
|
||||
if getenv("DEBUG_SIMPLIFY"):
|
||||
print("****")
|
||||
print("idx in: ", start_idx.render())
|
||||
print("valid: ", valid.render())
|
||||
print("simp: ", idx.render())
|
||||
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
|
||||
|
||||
# wait for it to be image indexed before running simplification
|
||||
|
|
@ -150,7 +177,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
must_divide = True
|
||||
if ctx is not None and ctx.device == "DSP":
|
||||
lengths = [128,64,32,16,8,4]
|
||||
if ls.dtype.count < 128: return None # leave these as loads (probably means something is broken)
|
||||
if ls.src[0].dtype.count < 128: return None # leave these as loads (probably means something is broken)
|
||||
must_divide = False
|
||||
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
|
||||
pass
|
||||
|
|
@ -168,7 +195,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
|||
if global_offset+fold_length > sz: continue
|
||||
oidx = idx.src[1] + global_offset
|
||||
if must_divide and oidx.simplify().divides(fold_length) is None: continue
|
||||
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
|
||||
if len(idx.src) > 2 and idx.src[2].dtype.count > 1:
|
||||
# vectorized
|
||||
lidx = buf.index(oidx, idx.src[2].gep(tuple(range(global_offset, global_offset+fold_length))))
|
||||
else:
|
||||
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
|
||||
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
|
||||
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
|
||||
|
|
@ -272,7 +303,34 @@ pm_render = PatternMatcher([
|
|||
|
||||
# *** uop graph ***
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import identity_element
|
||||
from tinygrad.helpers import partition
|
||||
|
||||
@dataclass
class ReduceContext:
  """Mutable state carried through the rewrite that removes REDUCE uops."""
  # arg given to the next DEFINE_ACC that gets created; bumped after each one
  acc_num: int = 0
|
||||
|
||||
def reduce_to_acc(ctx:ReduceContext, x:UOp):
  """Replace a REDUCE uop `x` with an accumulator.

  x.src[0] is the value being reduced and x.arg is the ALU op. The RANGE sources of x become the
  sources of a new DEFINE_ACC (numbered from ctx.acc_num), and the result is an assign of the
  accumulator combined with every non-RANGE source.
  """
  value = x.src[0]
  ranges, values = partition(x.src, lambda s: s.op is Ops.RANGE)
  # no RANGE to reduce over: the REDUCE is just its input
  if len(ranges) == 0: return value
  if all(node not in ranges for node in value.toposort):
    # TODO: this shouldn't be here
    # the value uses none of the ranges, so it is scaled by the product of each range's src[1] instead of looping
    return value*prod([r.src[1] for r in ranges]).broadcast(value.dtype.count)
  op = x.arg
  # the accumulator starts at the identity element of the op and is tied to the reduce ranges
  acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(op, x.dtype.scalar())),) + tuple(ranges), (ctx.acc_num,))
  ctx.acc_num += 1
  combined = functools.reduce(lambda a,b: a.alu(op, b), [acc]+list(values))
  return acc.assign(combined)
|
||||
|
||||
# every REDUCE uop is handed to reduce_to_acc, which turns it into a DEFINE_ACC + assign
pm_reduce = PatternMatcher([(UPat(Ops.REDUCE, name="x"), reduce_to_acc)])
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None, is_conv=False) -> UOp:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
|
@ -283,8 +341,16 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
|||
else: sink = graph_rewrite(sink, sym+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
|
||||
# optional pre matcher
|
||||
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
if opts is not None and opts.pre_matcher is not None:
|
||||
if is_conv:
|
||||
from tinygrad.runtime.ops_dsp import conv_pm
|
||||
sink = graph_rewrite(sink, conv_pm+opts.pre_matcher)
|
||||
else:
|
||||
sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
|
||||
# remove reduce
|
||||
sink = graph_rewrite(sink, pm_reduce, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+extra_matcher+pm_render)
|
||||
return sink
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# this converts a lowerer program into a vectorized program
|
||||
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod
|
||||
from tinygrad.helpers import AMX, dedup, flatten, all_same, prod, getenv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, graph_rewrite
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
|
||||
|
|
@ -116,9 +117,49 @@ migrate_indexing = PatternMatcher([
|
|||
(UPat(Ops.STORE, name="root"), create_gate),
|
||||
])
|
||||
|
||||
# **** IGNORE support ****
|
||||
|
||||
def _wrap_store_value_in_ignore(store:UOp, mask:UOp) -> UOp|None:
  # already wrapped: returning None means "no rewrite", which keeps the rule from firing forever
  if store.src[1].op is Ops.IGNORE: return None
  wrapped = UOp(Ops.IGNORE, src=(store.src[1], mask))
  return store.replace(src=(store.src[0], wrapped))

pm_store_ignore = PatternMatcher([
  # STORE through an INDEX that carries a mask: the stored value becomes IGNORE(value, mask)
  (UPat().index(UPat(), UPat(name="mask")).store(UPat()).named("store"), _wrap_store_value_in_ignore),
])
|
||||
|
||||
# Rewrite callback used by pm_move_ignore for an IGNORE whose value `x` and mask `y` are both bool.
# When the DEBUG_IGNORE env var is set it prints the rendered value and mask.
# NOTE(review): the indentation of this block is not visible here, so it is unclear whether the
# `return` below is unconditional or only runs when DEBUG_IGNORE is set -- confirm against the repository.
def debug_ignore(x, y):
|
||||
if getenv("DEBUG_IGNORE"):
|
||||
print("****")
|
||||
print("seen ", x.render())
|
||||
|
||||
print("ignore", y.render())
|
||||
# this is totally wrong
|
||||
return x.const_like(True)
|
||||
|
||||
def _push_ignore_into_srcs(ig:UOp, alu:UOp, mask:UOp) -> UOp:
  # IGNORE(op(a, b, ...), mask) -> op(IGNORE(a, mask), IGNORE(b, mask), ...)
  return alu.replace(src=tuple(UOp(Ops.IGNORE, s.dtype, (s, mask)) for s in alu.src))

# rules are tried in order, so the special cases come before the generic push-down
pm_move_ignore = PatternMatcher([
  # IGNORE of a value masked by itself: replaced by a True const
  (UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x"))), lambda x: x.const_like(True)),
  # bool value with a bool mask goes through the debug hook
  (UPat(Ops.IGNORE, src=(UPat(dtype=dtypes.bool, name="x"), UPat(dtype=dtypes.bool, name="y"))), debug_ignore),
  # IGNORE with AND on SELF is nothing (is this right?)
  #(UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x") & UPat())), lambda x: x.const_like(True)),
  # IGNORE around a CONST/VCONST is dropped, the constant stays
  (UPat(Ops.IGNORE, src=(UPat((Ops.CONST, Ops.VCONST), name="c"), UPat())), lambda c: c),
  # IGNORE around an ALU/CAST/VECTORIZE moves onto each of its sources
  (UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.VECTORIZE), name="alu"), UPat.var("mask")), name="ig"), _push_ignore_into_srcs),
])
|
||||
|
||||
pm_delete_ignore = PatternMatcher([
  # any IGNORE, whatever its mask, is replaced by the value it wraps (its first source)
  (UPat(Ops.IGNORE, src=(UPat(name="x"), UPat())), lambda x: x),
])
|
||||
|
||||
def expand_rewrite(sink:UOp) -> UOp:
  """Run the expander pipeline on `sink` and return the rewritten graph.

  Passes, in order: symbolic + migrate_indexing, pm_store_ignore, pm_move_ignore, and finally the
  expander together with sym and pm_delete_ignore.
  """
  # initial symbolic + migrate indexing (remove this)
  sink = graph_rewrite(sink, sym+migrate_indexing)

  # store IGNORE
  sink = graph_rewrite(sink, pm_store_ignore, name="store_ignore")

  # move IGNORE
  sink = graph_rewrite(sink, pm_move_ignore, name="move_ignore")

  # expand + remove surviving ignores
  # (the old `return graph_rewrite(sink, sym+expander)` used to sit before the IGNORE passes and made them unreachable)
  return graph_rewrite(sink, pm_delete_ignore+sym+expander)
|
||||
|
|
|
|||
|
|
@ -374,7 +374,7 @@ class Kernel:
|
|||
if opt.op is OptOps.LOCAL: # cyan
|
||||
# NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
|
||||
# it's disabled for now since it makes BEAM slow for little gain
|
||||
check(self.opts.has_local, "target does not support local")
|
||||
#check(self.opts.has_local, "target does not support local")
|
||||
check(axis < self.global_dims, "local is for globals")
|
||||
self.shift_to(axis, amt, insert_before=self.first_reduce)
|
||||
self.local_dims += 1
|
||||
|
|
@ -437,6 +437,72 @@ class Kernel:
|
|||
return self
|
||||
|
||||
def hand_coded_optimizations(self) -> Kernel:
|
||||
if self.opts.device == "DSP":
|
||||
k = self
|
||||
# special path for DSP
|
||||
if k.full_shape[-3:] == (32,3,3):
|
||||
# 3x3 dwconv
|
||||
# kernel 49 is broken
|
||||
if k.full_shape[-4]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, len(k.full_shape)-4, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 0))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-3, 32))
|
||||
if k.full_shape[len(k.full_shape)-4]%4 == 0:
|
||||
#if k.full_shape[len(k.full_shape)-4] <= 8: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 0))
|
||||
#else: k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, len(k.full_shape)-4, 4))
|
||||
# if this is small, swap it
|
||||
# NOTE: this is breaking something (should be fixed w/o padto)
|
||||
# kernel 23 is broken with this
|
||||
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.SWAP, 0, 1))
|
||||
elif k.full_shape[-4:] == (32,3,3,3):
|
||||
# 3x3 normal conv
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 2, 0))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 1, 0))
|
||||
# more UNROLLs aren't working well here, but they should be
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
elif len(k.full_shape) == 3 and k.full_shape[1] == 32 and k.first_reduce == 2:
|
||||
# weight that's exactly 32
|
||||
# NOTE: this pad might be broken
|
||||
if k.full_shape[0]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 32))
|
||||
if k.full_shape[0]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
elif len(k.full_shape) == 4 and k.full_shape[2] == 32 and k.first_reduce == 3:
|
||||
# weight that has more than 32
|
||||
# NOTE: this pad is broken on kernel 50
|
||||
if k.full_shape[1]%4 != 0: k.apply_opt(Opt(OptOps.PADTO, 1, 4))
|
||||
# weight with more
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 2, 32))
|
||||
if k.full_shape[1]%4 == 0: k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
# if the more is small, upcast it (kernel 50 is broken with this)
|
||||
if k.full_shape[0] <= 6: k.apply_opt(Opt(OptOps.UPCAST, 0, 0))
|
||||
elif len(k.full_shape) == 2 and k.first_reduce == 1:
|
||||
# unroll to 4 if we can
|
||||
if k.full_shape[k.first_reduce]%4 == 0: k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
# always pad to 128
|
||||
# NOTE: this breaks kernel 66
|
||||
if k.full_shape[0]%128 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 128))
|
||||
if k.full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
elif len(k.full_shape) == 1 and k.full_shape[0] > 1000:
|
||||
# pad to 128 and run on 128
|
||||
if k.full_shape[0]%128 != 0: k.apply_opt(Opt(OptOps.PADTO, 0, 128))
|
||||
if k.full_shape[0]%128 == 0: k.apply_opt(Opt(OptOps.UPCAST, 0, 128))
|
||||
|
||||
# make all non first dimensions local
|
||||
"""
|
||||
if getenv("MULTICORE", 0) == 2 and len(k.full_shape) >= 1 and k.full_shape[0] > 1:
|
||||
if k.full_shape[0]%2 == 1: k.apply_opt(Opt(OptOps.PADTO, 0, 2))
|
||||
if k.full_shape[0] > 2: k.apply_opt(Opt(OptOps.LOCAL, 0, k.full_shape[0]//2))
|
||||
for i in range(1, k.first_reduce-1): k.apply_opt(Opt(OptOps.LOCAL, 1, 0))
|
||||
else:
|
||||
for i in range(1, k.first_reduce): k.apply_opt(Opt(OptOps.LOCAL, 1, 0))
|
||||
"""
|
||||
|
||||
return self
|
||||
|
||||
self.required_optimizations()
|
||||
|
||||
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
|
||||
|
|
@ -678,7 +744,11 @@ class Kernel:
|
|||
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
|
||||
#if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
|
||||
|
||||
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
|
||||
is_conv = (len(self.full_shape) == 6 and self.full_shape[2:4] == (3,3))
|
||||
is_conv = is_conv or (len(self.full_shape) == 6 and self.full_shape[3:5] == (3,3))
|
||||
is_conv = is_conv or (len(self.full_shape) == 7 and self.full_shape[3:5] == (3,3))
|
||||
is_conv = is_conv or self.full_shape[-2:] == (3,3)
|
||||
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts, is_conv))
|
||||
if DEBUG >= 5: print_uops(self.uops)
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import functools, itertools, operator, math
|
|||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
|
||||
from tinygrad.helpers import all_int, prod, flatten, unwrap, QUANTIZE
|
||||
from tinygrad.codegen.expander import expand_rewrite
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
|
||||
|
|
@ -112,12 +112,24 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
|||
|
||||
def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
#reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
|
||||
reduce_indexes = [ctx.ridxs[i] for i in x.axis_arg]
|
||||
all_nodes = flatten([x.toposort for x in reduce_indexes])
|
||||
reduce_expand = [x for x in all_nodes if x.op is Ops.UNROLL]
|
||||
reduce_range = [x for x in all_nodes if x.op is Ops.RANGE]
|
||||
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
|
||||
alu_op: Ops = x.arg[0]
|
||||
ret = x.src[0]
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
ret = (functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)]),)
|
||||
else:
|
||||
ret = (ret,)
|
||||
return UOp(Ops.REDUCE, x.dtype, ret+tuple(reduce_range), alu_op)
|
||||
"""
|
||||
# create acc
|
||||
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
|
||||
if len(x.src) > 1: acc = acc + x.src[1]
|
||||
ctx.acc_num += 1
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
|
|
@ -127,6 +139,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
|||
if not len(reduce_range): return ret
|
||||
# create ACC and assign
|
||||
return acc.assign(ret)
|
||||
"""
|
||||
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
|
|
@ -162,7 +175,28 @@ pm_lowerer = PatternMatcher([
|
|||
|
||||
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
|
||||
|
||||
FP = (1 << 16)
|
||||
FP = (1 << 15)
|
||||
FP_DTYPE = dtypes.int32
|
||||
def fixed_point_mul(x, c1, cc, y=None, c2=None):
  """Integer replacement for `(x.float()*c1 [+ y.float()*c2] + cc).int()` using FP-scaled arithmetic.

  Every scale (`c1`, `c2`) and the offset `cc` is multiplied by FP and cast to FP_DTYPE, the
  products are summed in FP_DTYPE, and the sum is floor-divided by FP before the final cast
  to `dtypes.int`. The `y`/`c2` term is only included when `y` is given.
  """
  # NOTE(review): the whole sum lives in FP_DTYPE, so large inputs times FP could overflow it -- confirm ranges with callers
  total = x.cast(FP_DTYPE) * (c1*FP).cast(FP_DTYPE)
  if y is not None:
    total = total + y.cast(FP_DTYPE) * (c2*FP).cast(FP_DTYPE)
  total = total + (cc*FP).cast(FP_DTYPE)
  return (total // FP).cast(dtypes.int)
|
||||
|
||||
def fixed_const_reduce(r, gate, c1):
  """Debug stub for folding a REDUCE over a gated constant; it never rewrites anything.

  Prints the arg of the gate's first source together with the reduce's arg and returns None.
  `c1` is accepted but unused.
  """
  # TODO: this is doable
  print(gate.src[0].arg, r.arg)
  #return c1
  return None
|
||||
|
||||
def remove_matching_mask(v1, v2, ld):
  """Drop a 0/1 mask multiply on a LOAD when the mask and the load are gated by the same validity expression.

  `v1` is the VIEW under the VALID that builds the 0/1 mask and `v2` is the VIEW of the load `ld`.
  The second element of `to_indexed_uops()` for each view is simplified and the two are compared.

  Returns `ld` when they are identical, and None (no rewrite) otherwise. Removing the mask when
  the expressions differ would change the computed values.
  """
  mask_valid = v1.arg.to_indexed_uops()[1].simplify()
  load_valid = v2.arg.to_indexed_uops()[1].simplify()
  #print(mask_valid.render(), load_valid.render())
  # only an identical validity expression makes `mask * ld` equal to `ld`
  return ld if mask_valid == load_valid else None
|
||||
|
||||
pm_quant = symbolic+PatternMatcher([
|
||||
# cast after add/mul
|
||||
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
|
||||
|
|
@ -192,12 +226,13 @@ pm_quant = symbolic+PatternMatcher([
|
|||
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
|
||||
lambda ld,v,c1: ld*c1),
|
||||
|
||||
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("c2")).cast(dtypes.int),
|
||||
lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
|
||||
# fixed point mult, replace (x.float()*c1 + y.float()*c2) with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")),
|
||||
lambda x,y,c1,c2: ((x * (c1 * FP).cast(dtypes.int) + y * (c2 * FP).cast(dtypes.int)) // FP).cast(dtypes.float)),
|
||||
# const push through add
|
||||
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
|
||||
|
||||
# fixed point mult, replace (x.float()*c1 + c2).int() with an int expression
|
||||
# fixed point mult, replace (x.float()*c1 + y.float()*c2 + c3).int() with an int expression
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int), fixed_point_mul),
|
||||
((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),fixed_point_mul),
|
||||
|
||||
# where move
|
||||
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
|
||||
|
|
@ -211,13 +246,24 @@ pm_quant = symbolic+PatternMatcher([
|
|||
|
||||
# where on two adds
|
||||
(UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
|
||||
lambda x,v,a0,a1,b0,b1: x + v.where(a0+a1, b0+b1)),
|
||||
lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
|
||||
|
||||
# split REDUCE into multiple reduces
|
||||
# split REDUCE into multiple reduces (who remembers FOIL?)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2",), name="r"),
|
||||
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")), name="r"),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,))),
|
||||
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
|
||||
|
||||
# hack REDUCE (so wrong that it breaks it)
|
||||
#(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VALID, name='gate').where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)),), name="r"), fixed_const_reduce),
|
||||
|
||||
# MUL by 1/0 on LOAD where the masks match
|
||||
(UPat(Ops.WHERE, src=(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v1"),)), UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))) * \
|
||||
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v2")), name="ld"), remove_matching_mask),
|
||||
|
||||
#lambda ld,v1,v2: ld if v1.arg.to_indexed_uops()[1].simplify() == v2.arg.to_indexed_uops()[1].simplify() else None),
|
||||
# NOTE: this clause is completely false and might break things
|
||||
# or True else None),
|
||||
])
|
||||
|
||||
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import math, operator, struct, functools
|
|||
from collections import defaultdict
|
||||
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType
|
||||
from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten
|
||||
from tinygrad.helpers import partition, all_same, prod, getenv, DEBUG, flatten, get_single_element
|
||||
from tinygrad.codegen.transcendental import xpow
|
||||
|
||||
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
|
||||
|
|
@ -113,6 +113,8 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
|
|||
return functools.reduce(operator.add, ret) if changed else None
|
||||
|
||||
def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
|
||||
if x.vmin < -10000000: return None
|
||||
|
||||
# simplify x // y or x % y, None means no change
|
||||
# simple cancel div/mod case
|
||||
if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
|
||||
|
|
@ -173,7 +175,7 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
|||
gep_pushing = PatternMatcher([
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
|
||||
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
|
|
@ -182,11 +184,28 @@ gep_pushing = PatternMatcher([
|
|||
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) else None),
|
||||
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
|
||||
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
|
||||
if not isinstance(x.dtype, PtrDType) else None),
|
||||
# VECTORIZE on same GEP
|
||||
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
||||
# CAST on multi GEP
|
||||
#(UPat(Ops.CAST, src=(UPat(Ops.GEP, name="g"),), name="c"),
|
||||
# lambda c,g: g.src[0].gep(g.arg[0]).cast(c.dtype.scalar()).broadcast(len(g.arg)) if len(g.arg) > 1 and all_same(g.arg) else None),
|
||||
# VECTORIZE/CONST
|
||||
#(UPat(Ops.VECTORIZE, src=UPat.var("x"))+UPat.cvar("c", vec=False), lambda x,c: (x+c.arg).broadcast(c.dtype.count)),
|
||||
])
|
||||
|
||||
# Stand-alone matcher holding only the commutative operand-flipping rule.
commutative = PatternMatcher([
  # ** COMMUTATIVE flipping (only for ints) **
  # For a commutative op with dtype=dtypes.int, swap the sources when src[1].tuplize sorts before src[0].tuplize.
  # Returning None leaves the op unchanged.
  # NOTE: this can break merging vector math by only flipping some of them
  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for ints) **
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.ints, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
# NOTE: this can break merging vector math by only flipping some of them
|
||||
#(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
# ** boolean algebra **
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# ** combine terms **
|
||||
|
|
@ -418,9 +437,6 @@ sym = symbolic_flat+PatternMatcher([
|
|||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
|
||||
# push some GEPs through WMMAs
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
|
||||
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
|
||||
(UPat(Ops.CAT, name="x"), lambda x: UOp(Ops.VECTORIZE, x.dtype, tuple(y.gep(i) for y in x.src for i in range(y.dtype.count))) \
|
||||
if not isinstance(x.dtype, PtrDType) else None),
|
||||
# tensor core with a 0 input is acc
|
||||
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
|
||||
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
|
||||
|
|
|
|||
|
|
@ -168,8 +168,8 @@ class CapturedJit(Generic[ReturnType]):
|
|||
update_depends(depends, self.jit_cache)
|
||||
for b in depends:
|
||||
if b is not None:
|
||||
b.deallocate()
|
||||
if b._base is not None and b._base.allocated_views == 0: b._base.deallocate()
|
||||
if b.is_allocated(): b.deallocate()
|
||||
if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
|
||||
self.__post_init__() # reset the graph state
|
||||
|
||||
def optimize_weights(self):
|
||||
|
|
@ -314,6 +314,7 @@ class TinyJit(Generic[ReturnType]):
|
|||
|
||||
# set this for next run
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
|
||||
self.captured.optimize_weights()
|
||||
elif self.cnt >= 2:
|
||||
# jit exec
|
||||
assert self.captured is not None
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Optional, cast, Generator
|
|||
import time, pprint
|
||||
from dataclasses import dataclass, replace
|
||||
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, Context
|
||||
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
|
|
@ -178,7 +178,8 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, i
|
|||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
|
||||
with Context(NOOPT=1):
|
||||
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
np.testing.assert_allclose(nb[0].numpy(), si.bufs[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -402,7 +402,7 @@ if CAPTURE_PROCESS_REPLAY:
|
|||
class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...]
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
|
||||
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
|
|
|
|||
|
|
@ -258,6 +258,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
src:tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
children:set[weakref.ref[UOp]] = field(default_factory=set)
|
||||
def __post_init__(self):
|
||||
if self.op is Ops.MUL: assert all_same([x.dtype for x in self.src])
|
||||
assert all(isinstance(x, UOp) for x in self.src)
|
||||
def __del__(self):
|
||||
if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
|
||||
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
|
||||
|
|
@ -320,10 +323,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
@functools.cached_property
|
||||
def full_shape(self) -> tuple[sint, ...]:
|
||||
if self.op is Ops.VIEW: return self.shape
|
||||
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
|
||||
parent_shapes = [x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} and not (x.op is Ops.CONST and x.st is None)]
|
||||
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
|
||||
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
|
||||
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
|
||||
and not (x.op is Ops.CONST and x.st is None)]))
|
||||
return tuple(smax(x) for x in zip(*[x for x in parent_shapes if x != ()]))
|
||||
@property
|
||||
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
||||
@property
|
||||
|
|
@ -333,9 +336,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
def simplify(self):
|
||||
# late import!
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.codegen.symbolic import symbolic_flat, commutative
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, symbolic)
|
||||
return graph_rewrite(self, symbolic_flat+commutative)
|
||||
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
||||
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
||||
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
||||
|
|
@ -347,6 +350,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def __int__(self): return self._eval(dtypes.ints, int)
|
||||
def __float__(self): return self._eval(dtypes.floats, float)
|
||||
def substitute(self, dvars:dict[UOp, UOp]):
|
||||
if len(dvars) == 0: return self
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, _substitute, dvars, bottom_up=True)
|
||||
|
||||
|
|
@ -384,7 +388,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
||||
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
||||
i = (i,)
|
||||
if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
|
||||
#if self.dtype.vcount == 1 and i == (0,): return self
|
||||
if (self.dtype.count == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
|
||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
|
|
@ -974,6 +979,8 @@ renderer = PatternMatcher([
|
|||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UNROLL({x.src[0].arg}, {x.arg})")),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
|
||||
(UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")),
|
||||
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
(UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
|
||||
|
|
|
|||
|
|
@ -70,12 +70,16 @@ class Estimates:
|
|||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= (u.src[1] - u.src[0]).ssimplify()
|
||||
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
||||
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
elif u.op in {Ops.CUSTOM, Ops.CUSTOMI} and u not in dont_count:
|
||||
if u.arg.startswith("__builtin_HEXAGON_V6_vrmpy"): flops += 32*mults*(8 if 'acc' in u.arg else 7)
|
||||
if u.arg.startswith("__builtin_HEXAGON_A2_vraddub"): flops += mults*(17 if 'acc' in u.arg else 16)
|
||||
return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
|
||||
|
||||
@dataclass
|
||||
|
|
@ -104,13 +108,16 @@ class ProgramSpec:
|
|||
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
|
||||
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
|
||||
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
# DSP masked store
|
||||
if u.op is Ops.CUSTOM and u.arg.startswith("__builtin_HEXAGON_V6_vS32b"):
|
||||
self.outs.extend([x.arg for x in u.src[1].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
if u.op is Ops.SPECIAL:
|
||||
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
|
||||
if u.arg[0][0] == 'i': self.local_size = None
|
||||
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
|
||||
assert special_size is not None
|
||||
special_size[int(u.arg[0][-1])] = u.arg[1]
|
||||
# special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
|
||||
# assert special_size is not None
|
||||
# special_size[int(u.arg[0][-1])] = u.arg[1]
|
||||
self.vars = sorted(self.vars, key=lambda v: v.arg)
|
||||
self.outs = sorted(dedup(self.outs))
|
||||
self.ins = sorted(dedup(self.ins))
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ base_rewrite = PatternMatcher([
|
|||
extra_pm = PatternMatcher([
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(Ops.BITCAST, name="x"),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op not in {Ops.NOOP, Ops.LOAD, Ops.CUSTOM} else None),
|
||||
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# devectorize any bools
|
||||
|
|
@ -160,7 +160,7 @@ class CStyleLanguage(Renderer):
|
|||
|
||||
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
|
||||
(u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
r[u] = l
|
||||
else:
|
||||
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.helpers import dedup, DEBUG, to_function_name
|
||||
from tinygrad.helpers import dedup, DEBUG, to_function_name, getenv
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
|
|
@ -24,17 +24,18 @@ class CPUGraph(GraphRunner):
|
|||
if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
|
||||
return f"({device.renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
|
||||
|
||||
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+") {"]
|
||||
batched = ["void batched("+','.join([f"{device.renderer.render_dtype(x[1][0])} {x[0]}" for x in targs])+", int gl0, void* sync) {"]
|
||||
for i, ji in enumerate(jit_cache):
|
||||
args = [render_arg(buf) for buf in ji.bufs] + [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
|
||||
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)});")
|
||||
batched.append(f" {to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(args)}, gl0, 0x0);")
|
||||
if getenv("MULTICORE", 0) != 0: batched.append(f" qurt_barrier_wait(&(((qurt_barrier_t*)sync)[{i}]));")
|
||||
batched.append("}")
|
||||
|
||||
prep = [device.renderer._render(cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache)]
|
||||
funcs = dedup(device.renderer._render_body(prep[i][0], *prep[i][1:], cast(CompiledRunner, ji.prg).p.uops) for i,ji in enumerate(jit_cache))
|
||||
|
||||
defines = '\n'.join(set(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
|
||||
entry = device.renderer._render_entry("batched", targs)
|
||||
defines = '\n'.join(dedup(itertools.chain.from_iterable(device.renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache)))
|
||||
entry = device.renderer._render_entry("batched", targs, sync_cnt=len(jit_cache))
|
||||
code = defines + '\n' + '\n'.join([''.join(f) for f in funcs]) + '\n'.join(batched) + '\n' + entry
|
||||
|
||||
if DEBUG >= 4: print(code)
|
||||
|
|
|
|||
|
|
@ -4,64 +4,461 @@ assert sys.platform != 'win32'
|
|||
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG
|
||||
from tinygrad.codegen.symbolic import gep_pushing
|
||||
from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv, cpu_objdump, DEBUG, dedup, all_same
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.runtime.autogen import libc, qcom_dsp
|
||||
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
def multi_mul(a0, a1, b0, b1, c0, c1, d0=None, d1=None, acc=None):
  """Fuse `a0*a1 + b0*b1 + c0*c1 [+ d0*d1]` (optionally `+ acc`) into one Hexagon vrmpy CUSTOMI op.

  Each operand must be a CAST; the sources of the casts are concatenated into two 128-lane
  vectors and interleaved so that lane i of the four sources sits in four adjacent positions.
  Returns None (no rewrite) when `a0` or `a1` is not a CAST. The remaining operands are only
  checked with asserts.

  A missing `d1` is filled with a zero vector. A missing `d0` is filled with zeros too, unless
  the source of `c0` is a GEP selecting lanes 2, 6, ..., 126, in which case the GEP of the
  next lanes (3, 7, ..., 127) of the same vector is used instead.
  """
  # output positions 4*i..4*i+3 take lane i of sources 0..3
  swizzle = tuple(32*src_idx + lane for lane in range(32) for src_idx in range(4))

  if a0.op is not Ops.CAST or a1.op is not Ops.CAST:
    #print("rejected on a0/a1")
    return None

  if d0 is None:
    c0_src = c0.src[0]
    if c0_src.op is Ops.GEP and c0_src.arg == tuple(range(2, 128, 4)):
      # presumably chosen so the concatenated vector simplifies to a plain load -- d1 is zero on this path either way
      d0 = c0_src.src[0].gep(tuple(i+1 for i in c0_src.arg)).cast(dtypes.int.vec(32))
    else:
      d0 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))
  if d1 is None:
    #if c1.src[0].op is Ops.GEP: print("here1", c1.src[0].arg)
    d1 = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))

  for operand in (a0, b0, c0, d0, a1, b1, c1, d1):
    assert operand.op is Ops.CAST

  lhs_dtype = a0.src[0].dtype.scalar().vec(128)
  rhs_dtype = a1.src[0].dtype.scalar().vec(128)
  m0 = UOp(Ops.CAT, lhs_dtype, src=tuple(u.src[0] for u in (a0, b0, c0, d0))).gep(swizzle)
  m1 = UOp(Ops.CAT, rhs_dtype, src=tuple(u.src[0] for u in (a1, b1, c1, d1))).gep(swizzle)

  out_dtype = dtypes.int.vec(32)
  simp_m1 = m1.simplify()
  if simp_m1.op is Ops.GEP and simp_m1.arg == simp_m1.arg[0:4]*32:
    # the second operand repeats the same four lanes 32 times: pass those four lanes as one 32-bit scalar
    # Vx32.w+=vrmpy(Vu32.ub,Rt32.b) -> __builtin_HEXAGON_V6_vrmpybus_acc
    # Vx32.uw+=vrmpy(Vu32.ub,Rt32.ub) -> __builtin_HEXAGON_V6_vrmpyub_acc
    scalar_m1 = simp_m1.src[0].gep(simp_m1.arg[0:4])
    if acc is None:
      return UOp(Ops.CUSTOMI, out_dtype, (m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
    return UOp(Ops.CUSTOMI, out_dtype, (acc, m0, scalar_m1.bitcast(dtypes.uint)), "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")

  # general case: vector by vector
  if acc is None:
    return UOp(Ops.CUSTOMI, out_dtype, (m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_128B({0}, {1})")
  return UOp(Ops.CUSTOMI, out_dtype, (acc, m0, m1), "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})")
|
||||
|
||||
# __builtin_HEXAGON_A2_vraddub_acc
|
||||
|
||||
def gep_on_reduce(gep, alu):
  """Push a multi-lane GEP through `alu` by applying it to each source instead.

  RANGE sources are passed through untouched; every other source gets `.gep(gep.arg)`.
  Returns None (no rewrite) when the GEP has a single lane, when its dtype is a pointer,
  or when `alu` has fewer lanes than the GEP.
  """
  if gep.dtype.vcount == 1: return None
  if isinstance(gep.dtype, PtrDType) or alu.dtype.count < gep.dtype.count: return None
  out_dtype = alu.dtype.scalar().vec(gep.dtype.count)
  new_src = tuple(s if s.op is Ops.RANGE else s.gep(gep.arg) for s in alu.src)
  return UOp(alu.op, out_dtype, new_src, alu.arg)
|
||||
|
||||
def multi_add_int32(**aa):
  """Fuse a sum of three or four widened byte vectors (optionally `+ acc`) into a vrmpyub CUSTOMI op.

  Keyword arguments are the names bound by the pattern: `a0`, `b0`, `c0`, optionally `d0`
  and `acc`. The non-acc operands must be CASTs of uchar vectors (asserted). Their sources are
  concatenated in sorted key order, interleaved, and multiplied by a constant 32-bit mask.

  When `d0` is absent the mask is 0x00010101 instead of 0x01010101 and a filler `d0` is added:
  zeros, unless the source of `c0` is a GEP selecting lanes 2, 6, ..., 126, in which case the
  GEP of the next lanes of the same vector is used.
  """
  acc = aa.pop('acc', None)

  if 'd0' in aa:
    mask = 0x01010101
  else:
    # presumably the zero top byte of the mask cancels the filler lane -- confirm against the vrmpy byte order
    mask = 0x00010101
    c0_src = aa['c0'].src[0]
    if c0_src.op is Ops.GEP and c0_src.arg == tuple(range(2, 128, 4)):
      aa['d0'] = c0_src.src[0].gep(tuple(i+1 for i in c0_src.arg)).cast(dtypes.int.vec(32))
    else:
      aa['d0'] = UOp.const(dtypes.uchar.vec(32), 0).cast(dtypes.int.vec(32))

  for x in aa.values():
    assert x.src[0].dtype.scalar() is dtypes.uchar
    assert x.op is Ops.CAST

  # output positions 4*i..4*i+3 take lane i of sources 0..3
  swizzle = tuple(32*src_idx + lane for lane in range(32) for src_idx in range(4))
  m0 = UOp(Ops.CAT, dtypes.uchar.vec(128), src=tuple(aa[k].src[0] for k in sorted(aa))).gep(swizzle)

  if acc is None:
    return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (m0, UOp.const(dtypes.uint, mask)), "__builtin_HEXAGON_V6_vrmpyub_128B({0}, {1})")
  return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (acc, m0, UOp.const(dtypes.uint, mask)),
             "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})")
|
||||
|
||||
def multi_add_int2(**aa):
  """Fuse a sum of widened 2-lane byte vectors (optionally `+ acc`) into a vraddub CUSTOMI op.

  Keyword arguments are the names bound by the pattern, plus an optional `acc`. For every
  non-acc operand (in sorted key order) lane 0 and lane 1 of its source are collected, then
  packed into two 8-byte vectors that are bitcast to int64 and handed to the intrinsic.
  """
  acc = aa.pop('acc', None)
  keys = sorted(aa)

  # all lane-0 elements first, then all lane-1 elements
  eles = [aa[k].src[0].gep(0) for k in keys] + [aa[k].src[0].gep(1) for k in keys]
  # NOTE(review): the slices below look like they expect exactly eight operands -- confirm with the pattern that calls this
  r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[0:4]+eles[8:12]))
  r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(eles[4:8]+eles[12:16]))

  # reference C kept from the original source:
  #   unsigned long long precast0 = __builtin_HEXAGON_A2_vraddub((*((unsigned long long*)&val1)), (*((unsigned long long*)&val2)));
  #   acc4 = (acc4+(*((int2*)&precast0)));
  #
  #   int2 cast2 = __builtin_convertvector((unsigned_char2){val1[0],val2[0]}, int2);
  #   int2 cast4 = __builtin_convertvector((unsigned_char2){val1[1],val2[1]}, int2);
  #   int2 cast6 = __builtin_convertvector((unsigned_char2){val1[2],val2[2]}, int2);
  #   int2 cast8 = __builtin_convertvector((unsigned_char2){val1[3],val2[3]}, int2);
  #   int2 cast10 = __builtin_convertvector((unsigned_char2){val1[4],val2[4]}, int2);
  #   int2 cast12 = __builtin_convertvector((unsigned_char2){val1[5],val2[5]}, int2);
  #   int2 cast14 = __builtin_convertvector((unsigned_char2){val1[6],val2[6]}, int2);
  #   int2 cast16 = __builtin_convertvector((unsigned_char2){val1[7],val2[7]}, int2);
  #   acc4 = (acc4+cast2+cast4+cast6+cast8+cast10+cast12+cast14+cast16);

  #r0 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(0) for x in aa.values()))
  #r1 = UOp(Ops.VECTORIZE, dtypes.uchar.vec(8), tuple(x.src[0].gep(1) for x in aa.values()))
  if acc is None:
    #return UOp(Ops.CUSTOMI, dtypes.uint64,
    #  (r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})").bitcast(dtypes.int.vec(2))
    return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)), arg="__builtin_HEXAGON_A2_vraddub({0}, {1})")
  #return UOp(Ops.CUSTOMI, dtypes.uint64, (acc.bitcast(dtypes.uint64), r0.bitcast(dtypes.uint64), r1.bitcast(dtypes.uint64)),
  #  arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))
  return UOp(Ops.CUSTOMI, dtypes.int.vec(2), (acc, r0.bitcast(dtypes.int64), r1.bitcast(dtypes.int64)),
             arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})")
|
||||
|
||||
# Rewrite rules for three-term sums. Each shape appears twice: bare, and with a
# leading "acc" term that the handler receives as the accumulator.
conv_pm = PatternMatcher([
  # a0*a1 + b0*b1 + c0*c1 (only a0 is pinned to int.vec(32)) -> multi_mul
  ((UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") +
    UPat(name="c0")*UPat(name="c1")), multi_mul),
  ((UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") +
    UPat(name="c0")*UPat(name="c1")), multi_mul),

  # CAST + CAST + CAST (only a0 is pinned to int.vec(32)) -> multi_add_int32,
  # which emits __builtin_HEXAGON_V6_vrmpyub_128B / vrmpyub_acc_128B
  ((UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0")), multi_add_int32),
  ((UPat(name="acc") + UPat(Ops.CAST, dtype=dtypes.int.vec(32), name="a0") + UPat(Ops.CAST, name="b0") + UPat(Ops.CAST, name="c0")),
   multi_add_int32),
])
|
||||
|
||||
dsp_pm = PatternMatcher([
|
||||
# convert load char32 to load char128
|
||||
(UPat(Ops.LOAD, (dtypes.uchar.vec(96), dtypes.uchar.vec(64), dtypes.uchar.vec(32)), src=(UPat.var("buf").cast(),), name="load"),
|
||||
lambda load, buf: load.replace(dtype=dtypes.uchar.vec(128),
|
||||
src=(buf.cast(buf.dtype.base.vec(128).ptr(size=buf.dtype.size, local=buf.dtype.local)),)+load.src[1:]).gep(tuple(range(0, load.dtype.count)))),
|
||||
# GEP on REDUCE
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.REDUCE, name='alu'),), name='gep'), gep_on_reduce),
|
||||
# no swizzle down convert
|
||||
(((UPat.var('x').maximum(0) ^ -1).maximum(-256) ^ -1).cast(dtypes.uchar.vec(128)),
|
||||
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
|
||||
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
|
||||
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,
|
||||
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 1 else None),
|
||||
])
|
||||
|
||||
# REDUCE int4 -> 2xint2, int8 -> 4xint2
|
||||
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(4), name="r"),
|
||||
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r)))),
|
||||
(UPat(Ops.REDUCE, dtype=dtypes.int.vec(8), name="r"),
|
||||
lambda r: UOp(Ops.CAT, r.dtype, (gep_on_reduce(r.gep((0,1)), r), gep_on_reduce(r.gep((2,3)), r),
|
||||
gep_on_reduce(r.gep((4,5)), r), gep_on_reduce(r.gep((6,7)), r)))),
|
||||
|
||||
# __builtin_HEXAGON_V6_vrmpybus
|
||||
(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
|
||||
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
|
||||
(UPat(name="acc") + UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") + UPat(name="b0")*UPat(name="b1") + \
|
||||
UPat(name="c0")*UPat(name="c1") + UPat(name="d0")*UPat(name="d1"), multi_mul),
|
||||
|
||||
# build __builtin_HEXAGON_V6_vrmpybus_128B
|
||||
(UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
|
||||
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+
|
||||
UPat(Ops.CAST,name="c0")+UPat(Ops.CAST,name="d0"), multi_add_int32),
|
||||
|
||||
# build __builtin_HEXAGON_A2_vraddub
|
||||
(UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
(UPat(name="acc")+UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a0")+UPat(Ops.CAST,name="a1")+UPat(Ops.CAST,name="a2")+UPat(Ops.CAST,name="a3")+ \
|
||||
UPat(Ops.CAST,dtype=dtypes.int.vec(2),name="a4")+UPat(Ops.CAST,name="a5")+UPat(Ops.CAST,name="a6")+UPat(Ops.CAST,name="a7"), multi_add_int2),
|
||||
|
||||
# we upcast 3 as 4
|
||||
(UPat(Ops.REDUCE, name="r", src=(UPat(dtype=dtypes.int.vec(32), name="a0")*UPat(name="a1") +
|
||||
UPat(name="b0")*UPat(name="b1") + UPat(name="c0")*UPat(name="c1"),), allow_any_len=True),
|
||||
lambda r, **kwargs: r.replace(src=(mm,)+r.src[1:]) if (mm:=multi_mul(**kwargs)) else None),
|
||||
(UPat(Ops.REDUCE, name="r", src=(UPat(Ops.CAST,dtype=dtypes.int.vec(32),name="a0")+UPat(Ops.CAST,name="b0")+UPat(Ops.CAST,name="c0"),
|
||||
), allow_any_len=True),
|
||||
lambda r, **kwargs: r.replace(src=(mm,)+r.src[1:]) if (mm:=multi_add_int32(**kwargs)) else None),
|
||||
|
||||
# mul by const on GEP
|
||||
(UPat(Ops.GEP, src=(UPat.var('x'),), name="gep")*UPat.cvar("c", vec=False),
|
||||
lambda x, gep, c: (x.gep(gep.arg[0])*c.arg).broadcast(c.dtype.count) if all_same(gep.arg) and c.dtype.count > 1 else None),
|
||||
])+gep_pushing
|
||||
|
||||
def add_to_mul(c:UOp, x:UOp):
  """Fold the addend `x` into the CUSTOMI op `c` as its accumulator operand.

  Returns the rewritten UOp, or None when `c` is not one of the handled forms.
  """
  # non-accumulating builtin prefix -> accumulating call template
  acc_forms = {
    "__builtin_HEXAGON_V6_vrmpyub_128B": "__builtin_HEXAGON_V6_vrmpyub_acc_128B({0}, {1}, {2})",
    "__builtin_HEXAGON_V6_vrmpyubv_128B": "__builtin_HEXAGON_V6_vrmpyubv_acc_128B({0}, {1}, {2})",
  }
  for prefix, acc_template in acc_forms.items():
    if c.arg.startswith(prefix):
      return UOp(Ops.CUSTOMI, dtypes.int.vec(32), (x, c.src[0], c.src[1]), acc_template)
  # already an accumulating form: add x onto the existing accumulator (src[0])
  if 'acc' in c.arg and x.op is not Ops.CUSTOM:
    return c.replace(src=(x+c.src[0], c.src[1], c.src[2]))
  return None
|
||||
|
||||
def prefetch_l1(ld:UOp, idx:UOp):
  """Attach two `dcfetch` CUSTOM ops as extra sources of the load `ld`.

  `idx` is the INDEX op under the load's cast. Returns the rewritten load, or
  None when the load already ends in a CUSTOM source or has no enclosing RANGE.
  """
  # a trailing CUSTOM source means this load was already rewritten; stop re-matching
  if ld.src[-1].op is Ops.CUSTOM: return None
  ranges = sorted([x for x in ld.src[0].src[0].toposort if x.op is Ops.RANGE], key=lambda x: x.arg)
  # the second prefetch substitutes the last range; without any range there is nothing to
  # substitute (indexing ranges[-1] would raise IndexError), so leave the load untouched
  if not ranges: return None
  last_range = ranges[-1]
  # fetch at the current index plus twice the load's element count
  x1 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1]+UOp.const(dtypes.int, ld.dtype.count*2)),
           arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
  # fetch at the index with the last range replaced by that range's src[0]
  x2 = UOp(Ops.CUSTOM, dtypes.void, src=(idx.src[0], idx.src[1].substitute({last_range: last_range.src[0]})),
           arg="__builtin_HEXAGON_Y2_dcfetch({0}+{1});")
  return ld.replace(src=ld.src+(x1, x2))
|
||||
|
||||
def prefetch_l2(ld:UOp, idx:UOp):
  """Attach an `l2fetch` CUSTOM op as an extra source of the load `ld`.

  Disabled when the PREFETCHL2 env var is 0. Returns None when the load already
  carries an l2fetch source or when no RANGE is found under its index.
  """
  if not getenv("PREFETCHL2", 1): return None
  tail = ld.src[-1]
  if tail.op is Ops.CUSTOM and 'l2fetch' in tail.arg: return None
  ranges = sorted([u for u in ld.src[0].src[0].toposort if u.op is Ops.RANGE], key=lambda u: u.arg)
  if not ranges: return None

  last_range = ranges[-1]
  nidx = idx.src[1]
  const = 0
  if nidx.op is Ops.ADD and nidx.src[1].op is Ops.CONST:
    # NOTE: carrying the constant over (const = nidx.src[1].arg) causes access alignment issues,
    # so it is dropped and const stays 0
    nidx = nidx.src[0]

  # index difference between the last range's src[1] and src[0], with every other range set to zero
  zero_ranges = {r:r.const_like(0) for r in ranges[:-1]}
  at_end = nidx.substitute({last_range: last_range.src[1], **zero_ranges})
  at_start = nidx.substitute({last_range: last_range.src[0], **zero_ranges})
  nlen_uop = (at_end - at_start).simplify()
  nidx = nidx.substitute({last_range: last_range.src[0]})

  buf = idx.src[0]
  buf_lines_total = ((buf.dtype.size*buf.dtype.itemsize)+127)//128
  if buf_lines_total < 8192//128:
    # if the total buffer size is sub 8k, fetch it all
    x1 = UOp(Ops.CUSTOM, dtypes.void, src=(buf, UOp.const(dtypes.int, buf_lines_total)),
             arg="__builtin_HEXAGON_Y4_l2fetch({0}, 0x808000|{1});")
  else:
    # fetch up to 8192, never fewer than 8 lines
    wanted = ((nlen_uop.arg+127)//128)*2+1 if nlen_uop.op is Ops.CONST and nlen_uop.arg <= 8192 else 8
    fetch_lines = max(wanted, 8)
    x1 = UOp(Ops.CUSTOM, dtypes.void, src=(buf, nidx+const, UOp.const(dtypes.int, fetch_lines)),
             arg="__builtin_HEXAGON_Y4_l2fetch({0}+{1}, 0x808000|{2});")
  return ld.replace(src=ld.src+(x1,))
|
||||
|
||||
def vectorize_shuffle(vec:UOp):
  """Rewrite a VECTORIZE of GEPs/CONSTs into `__builtin_shufflevector` CUSTOM ops.

  Handles one, two or three distinct GEP sources; CONST lanes become index -1.
  Returns None when the env var DISABLE_VECTORIZED_SHUFFLE is set or the shape
  is not supported.
  """
  if getenv("DISABLE_VECTORIZED_SHUFFLE", 0): return None
  if any(s.op not in {Ops.GEP, Ops.CONST} for s in vec.src): return None
  gepped = dedup([s.src[0] for s in vec.src if s.op is Ops.GEP])
  if not gepped: return None

  def fmt(indices): return ','.join([f'{y:4d}' for y in indices])

  def first_two_lanes():
    # lanes read from gepped[0] keep their index, lanes from gepped[1] are offset by the
    # source width, everything else (CONST or a third source) is -1
    out = []
    for s in vec.src:
      if s.op is Ops.GEP and s.src[0] is gepped[0]: out.append(s.arg[0])
      elif s.op is Ops.GEP and s.src[0] is gepped[1]: out.append(gepped[0].dtype.count + s.arg[0])
      else: out.append(-1)
    return out

  if len(gepped) == 1:
    # this pattern is broken in DSP clang
    if gepped[0].dtype.count == 4: return None
    single = [s.arg[0] if s.op is Ops.GEP else -1 for s in vec.src]
    return UOp(Ops.CUSTOM, vec.dtype, tuple(gepped), "__builtin_shufflevector({0}, {0}, "+fmt(single)+")")

  # multi-source shuffles need uchar sources that all have the output's width
  if any(x.dtype.scalar() is not dtypes.uchar for x in gepped): return None
  if not all_same([x.dtype.count for x in gepped]) or gepped[0].dtype.count != vec.dtype.count: return None

  if len(gepped) == 2:
    return UOp(Ops.CUSTOM, vec.dtype, tuple(gepped), "__builtin_shufflevector({0}, {1}, "+fmt(first_two_lanes())+")")
  if len(gepped) != 3: return None

  # three sources: shuffle the first two, then shuffle that result with the third.
  # in the outer shuffle, lane i of the inner result is kept unless the lane comes
  # from gepped[2] (offset by the vector width) or is a CONST (-1)
  outer = []
  for i,s in enumerate(vec.src):
    if s.op is Ops.GEP and s.src[0] is gepped[2]: outer.append(vec.dtype.count + s.arg[0])
    elif s.op is Ops.CONST: outer.append(-1)
    else: outer.append(i)
  full_arg = "__builtin_shufflevector(__builtin_shufflevector({0}, {1}, "+fmt(first_two_lanes())+"), {2}, "+fmt(outer)+")"
  return UOp(Ops.CUSTOM, vec.dtype, tuple(gepped), full_arg)
|
||||
|
||||
def multicore_range(r:UOp):
  """Split the RANGE `r` in two halves selected by a 2-wide SPECIAL ("g0") index.

  Only active when the MULTICORE env var is exactly 1; returns None otherwise,
  or when a SPECIAL already appears under `r`.
  """
  # NOTE: THIS IS BROKEN if this is a reduce range. TODO: check for that
  if getenv("MULTICORE", 0) != 1: return None
  if any(u.op is Ops.SPECIAL for u in r.toposort): return None
  core = UOp(Ops.SPECIAL, dtypes.int, arg=("g0", 2))
  lo, hi = r.src[0], r.src[1]
  # NOTE(review): the split point is hi//2, independent of lo — looks like it assumes the range starts at 0; confirm
  start = (core.eq(0)).where(lo, hi//2)
  end = (core.eq(0)).where(hi//2, hi)
  return r.replace(src=(start,end))
|
||||
|
||||
def store_with_mask(buf, idx, val, mask, cast):
  """Lower a masked store of a 128-wide uchar vector into two predicated store CUSTOM ops.

  Any other value/mask shape falls back to a plain store and the mask is dropped
  (a "DROP MASK" line is printed when that happens).
  """
  if val.dtype.count != 128 or mask.dtype.count != 128 or val.dtype.scalar() != dtypes.uchar:
    print("DROP MASK", val.dtype.count, mask.dtype.count)
    # NOTE: we are dropping the mask
    return buf.index(idx).cast(cast.dtype).store(val)

  vec128 = dtypes.uchar.vec(128)
  const_0 = UOp.const(vec128, 0)
  cmask = UOp(Ops.CUSTOM, vec128, src=(mask,),arg="{0}")

  # unaligned: the low 7 bits of the address drive the vlalignb calls below
  min_128 = (buf.index(idx).cast(dtypes.uint)&0x7F)

  def lalign(a, b):
    return UOp(Ops.CUSTOM, vec128, src=(a, b, min_128), arg="__builtin_HEXAGON_V6_vlalignb_128B({0}, {1}, {2})")

  def pred_store(m, offset_idx, v):
    return UOp(Ops.CUSTOM, dtypes.void, src=(m, buf.index(offset_idx).cast(cast.dtype), v, const_0),
               arg='__builtin_HEXAGON_V6_vS32b_nqpred_ai_128B(__builtin_HEXAGON_V6_veqb_128B({0}, {3}), {1}, {2});')

  cmask_l = lalign(cmask, const_0)
  # NOTE(review): the right half uses the raw `mask` while the left half uses `cmask` — confirm this is intended
  cmask_r = lalign(const_0, mask)
  val_l = lalign(val, const_0)
  val_r = lalign(const_0, val)
  # one store at idx, one at idx+128
  store_l = pred_store(cmask_l, idx, val_l)
  store_r = pred_store(cmask_r, idx+128, val_r)
  return UOp(Ops.CUSTOM, src=(store_l,store_r), arg="")
|
||||
|
||||
dsp_pm_late = PatternMatcher([
|
||||
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
|
||||
# prefetch L1
|
||||
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(4), dtypes.uchar.vec(8)), src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld"), prefetch_l1),
|
||||
|
||||
# prefetch L2
|
||||
(UPat(Ops.LOAD, dtype=(dtypes.uchar.vec(8), dtypes.uchar.vec(128), dtypes.int.vec(32)),
|
||||
src=(UPat(Ops.INDEX, name="idx").cast(),), name="ld", allow_any_len=True), prefetch_l2),
|
||||
|
||||
# 64 -> 128
|
||||
#(UPat(Ops.LOAD, dtype=dtypes.uchar.vec(64), src=(UPat(Ops.CAST, src=(UPat(Ops.INDEX, name="idx"),)),)),
|
||||
# lambda idx: idx.cast(dtypes.uchar.vec(128).ptr(idx.dtype.size)).load(dtype=dtypes.uchar.vec(128)).gep(tuple(range(0,64)))),
|
||||
|
||||
# unaligned load
|
||||
#(UPat(Ops.LOAD, src=(UPat(Ops.CAST, src=(UPat(Ops.INDEX, src=(UPat(), UPat()+UPat.cvar("c"))),), name="ptr"),), dtype=dtypes.uchar.vec(128)),
|
||||
# lambda c,ptr: UOp(Ops.CUSTOM, dtype=dtypes.uchar.vec(128), src=(ptr,), arg='vmemu({0})')),
|
||||
|
||||
(UPat(Ops.VECTORIZE, dtypes.uchar.vec(128), name="vec"), vectorize_shuffle),
|
||||
|
||||
# __builtin_HEXAGON_V6_vrmpyub_acc_128B
|
||||
(UPat(Ops.CUSTOMI, dtype=dtypes.int.vec(32), name="c")+UPat.var("x"), add_to_mul),
|
||||
|
||||
# add acc to __builtin_HEXAGON_A2_vraddub (must be after the reduce expansion)
|
||||
(UPat(Ops.CUSTOMI, name="cu", arg="__builtin_HEXAGON_A2_vraddub({0}, {1})") + UPat.var("x"),
|
||||
lambda x,cu: cu.replace(dtype=dtypes.int64, src=(x.bitcast(dtypes.int64), cu.src[0], cu.src[1]),
|
||||
arg="__builtin_HEXAGON_A2_vraddub_acc({0}, {1}, {2})").bitcast(dtypes.int.vec(2))),
|
||||
|
||||
#(UPat(Ops.BITCAST, src=(UPat(Ops.LOAD, name="ld"),), name="bc"),
|
||||
# lambda ld, bc: ld.src[0].src[0].cast(bc.dtype.ptr(ld.src[0].dtype.size)).load(dtype=bc.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src,
|
||||
"__builtin_shufflevector({0}, {0}, "+','.join([f'{y:4d}' for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 4 else None),
|
||||
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")),
|
||||
lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
|
||||
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")),
|
||||
lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
|
||||
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")),
|
||||
lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI or x.arg != "{0}" else None),
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
|
||||
|
||||
# multicore
|
||||
(UPat(Ops.RANGE, name="r", arg=0), multicore_range),
|
||||
|
||||
# store with mask
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask"))).cast().named("cast"), UPat.var("val"))),
|
||||
store_with_mask),
|
||||
])
|
||||
|
||||
# NOTE: this just increases readability of the generated code
|
||||
dsp_string = PatternMatcher([
|
||||
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
|
||||
pretty_render = PatternMatcher([
|
||||
# makes rendering nicer
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, dtype=(dtypes.uint8, dtypes.int8)), name="v"),
|
||||
lambda v: UOp(Ops.VECTORIZE, v.dtype, src=tuple(UOp(Ops.CUSTOMI, x.dtype, src=(UOp.const(dtypes.int, x.arg),), arg="{0}") for x in v.src))),
|
||||
])
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
global_max = (2, 1, 1)
|
||||
buffer_suffix = " restrict __attribute__((align_value(128)))"
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
pre_matcher = dsp_pm
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
|
||||
string_rewrite = dsp_string+ClangRenderer.string_rewrite
|
||||
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher+pretty_render
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {k:v for k,v in ClangRenderer.code_for_op.items() if k != Ops.SQRT}
|
||||
extra_args = ['int global_idx_0', 'void* sync']
|
||||
code_for_workitem = {"g": lambda x: f"global_idx_{x}"}
|
||||
|
||||
def _render_defines(self, uops) -> list[str]:
|
||||
return ['''/* DSP boilerplate */ struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency;
|
||||
_Bool set_dcvs_params; short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3];};''','int HAP_power_set(void*, void*);',
|
||||
'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
|
||||
'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
|
||||
'unsigned long long HAP_perf_get_time_us(void);'] + super()._render_defines(uops)
|
||||
'unsigned long long HAP_perf_get_time_us(void);', 'typedef unsigned long qurt_thread_t;', 'void qurt_thread_exit(int);',
|
||||
'typedef struct _qurt_barrier { char padding[64]; } qurt_barrier_t;', 'int qurt_barrier_init(qurt_barrier_t*, unsigned int);',
|
||||
'int qurt_barrier_wait(qurt_barrier_t*);',
|
||||
'typedef struct _qurt_thread_attr { char name[16]; unsigned char tcb_partition; unsigned char affinity; unsigned short priority;',
|
||||
'unsigned char asid; unsigned char bus_priority; unsigned short timetest_id; unsigned int stack_size; void *stack_addr; char padding[96]; } qurt_thread_attr_t;',
|
||||
'int qurt_thread_join(qurt_thread_t tid, int *status);', 'void* malloc(unsigned int);', 'void free(void*);',
|
||||
'int qurt_thread_create (qurt_thread_t *thread_id, qurt_thread_attr_t *attr, void (*entrypoint) (void *), void *arg);',
|
||||
] + super()._render_defines(uops)
|
||||
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str:
|
||||
msrc = ['int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]], sync_cnt=0x0) -> str:
|
||||
msrc = ['typedef struct all_args {', *[f'int sz_or_val_{i}; int off{i}; void *buf_{i};' for i in range(len(bufs))], 'void* sync; } all_args_t;']
|
||||
msrc += ['void threader(all_args_t* args) {']
|
||||
msrc += [f"{function_name}({', '.join([(f'args->buf_{i}' if isinstance(b[1][0], PtrDType) else f'args->sz_or_val_{i}') for i,b in enumerate(bufs)])}, 1, args->sync);"]
|
||||
msrc += ['qurt_thread_exit(0); }'
|
||||
'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
|
||||
'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
|
||||
'HAP_power_set((void*)handle, (void*)&req);']
|
||||
msrc += ['if ((sc>>24) != 2) return 0;']
|
||||
msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
if sync_cnt > 0:
|
||||
msrc += [f"qurt_barrier_t* sync = malloc({sync_cnt} * sizeof(qurt_barrier_t));"]
|
||||
msrc += [f"qurt_barrier_init(&sync[{i}], 2);" for i in range(sync_cnt)]
|
||||
else: msrc += [f"qurt_barrier_t* sync = 0x0;"]
|
||||
msrc += ['all_args_t args = { 0 };']
|
||||
msrc += [f'args.sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
|
||||
msrc += [f'args.off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'args.buf_{i} = HAP_mmap(0,args.sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+args.off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += ['args.sync = sync;']
|
||||
msrc += ["qurt_thread_attr_t attr = { 0 };"]
|
||||
msrc += ["attr.name[0] = 't';", "attr.priority = 255;", "attr.asid = 0;"]
|
||||
msrc += ["attr.stack_size = (64 << 10);", "attr.stack_addr = malloc(attr.stack_size);"]
|
||||
msrc += [""]
|
||||
msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
|
||||
msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
msrc += ["qurt_thread_t thread_ = 0; qurt_thread_create(&thread_, &attr, (void (*)(void*))threader, (void*)&args);"]
|
||||
msrc += [f"{function_name}({', '.join([(f'args.buf_{i}' if isinstance(b[1][0], PtrDType) else f'args.sz_or_val_{i}') for i,b in enumerate(bufs)])}, 0, args.sync);"]
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
msrc += ['int status;', f"qurt_thread_join(thread_, &status);"]
|
||||
msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
|
||||
msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += [f'HAP_munmap(args.buf_{i}, args.sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
|
||||
msrc += ['free(attr.stack_addr);']
|
||||
if sync_cnt > 0: msrc += ['free(sync);']
|
||||
msrc += ["return 0; }"]
|
||||
return '\n'.join(msrc)
|
||||
|
||||
|
|
@ -137,10 +534,10 @@ class DSPDevice(Compiled):
|
|||
try:
|
||||
self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
|
||||
# Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
|
||||
sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
|
||||
sections = ['text', 'rela.plt', 'rela.dyn', 'plt', 'data', 'bss', 'hash', 'dynamic', 'got', 'got.plt', 'dynsym', 'dynstr', 'symtab', 'shstrtab', 'strtab']
|
||||
sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
|
||||
with tempfile.NamedTemporaryFile(delete=False) as self.link_ld:
|
||||
self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
|
||||
self.link_ld.write(f"SECTIONS {{ . = 0x0;\n{sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
|
||||
self.link_ld.flush()
|
||||
|
||||
from tinygrad.runtime.graph.cpu import CPUGraph
|
||||
|
|
@ -282,7 +679,14 @@ class MockDSPRenderer(DSPRenderer):
|
|||
else:
|
||||
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
|
||||
msrc.append("unsigned int st = inscount();")
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
|
||||
if getenv("MULTICORE", 0) != 0:
|
||||
# TODO: get count?
|
||||
# NOTE: we do them in reverse order to reveal bugs
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])}, 1, 0);")
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])}, 0, 0);")
|
||||
else:
|
||||
# huh, why did this change?
|
||||
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])}, 0, 0);")
|
||||
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
|
||||
for i,b in enumerate(bufs):
|
||||
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")
|
||||
|
|
|
|||
|
|
@ -81,6 +81,9 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
|
||||
|
||||
# all loads (we add Ops.CUSTOM as fake sources on this for l1fetch)
|
||||
(UPat(Ops.LOAD), lambda: True),
|
||||
|
||||
# early STORE has a <buf, shapetracker, val>
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
|
||||
|
||||
|
|
@ -91,6 +94,9 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index),
|
||||
|
||||
# any mask
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(name="mask")), name="idx"), validate_index),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, barrier?>
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue