mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Symbolic Shape JIT main PR (#1353)
* Symbolic Shape JIT update tests 2 variables symbolic ops, adding more tests test passing cleanup
* more test cases
* single flag
* review update
* jit attention one piece
* realize
* symbolic_jit test for cuda
* old artifact
* works with cuda gpu but failed ci
* CUDACPU
This commit is contained in:
parent
84e6693915
commit
ae39cf84ab
13 changed files with 223 additions and 65 deletions
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
|
@ -188,7 +188,7 @@ jobs:
|
|||
- name: Run JIT test
|
||||
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
|
||||
- name: Run symbolic shapetracker test
|
||||
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py
|
||||
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
|
||||
- name: Check Device.DEFAULT
|
||||
run: WEBGPU=1 python -c "from tinygrad.lazy import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
|
||||
#- name: Run webgpu pytest
|
||||
|
|
|
|||
|
|
@ -307,7 +307,7 @@ GlobalCounters.cache = None # disable the cache
|
|||
|
||||
# there's one ASTRunner in the cache
|
||||
assert len(cache_saved) == 1
|
||||
prg, bufs = cache_saved[0]
|
||||
prg, bufs, _ = cache_saved[0]
|
||||
|
||||
# print the C Program :)
|
||||
print(prg.prg)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn import Embedding, Linear
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
|
||||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
|
|
@ -66,36 +67,27 @@ class Attention:
|
|||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
def prepare_attention(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
def __call__(self, x:Tensor, cache_k:Tensor, cache_v:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
return xq, xk, xv
|
||||
|
||||
def inner_attention(self, xq:Tensor, xk:Tensor, xv:Tensor, start_pos:int, mask:Optional[Tensor]) -> Tensor:
|
||||
bsz, seqlen, _, _ = xq.shape
|
||||
# kv caching!
|
||||
if start_pos == 0:
|
||||
keys, values = xk, xv
|
||||
else:
|
||||
assert hasattr(self, 'cache_k'), "no cache"
|
||||
assert start_pos == self.cache_k.shape[1] and start_pos == self.cache_v.shape[1], "cache is wrong shape"
|
||||
assert cache_k.shape[0] > 0, "no cache"
|
||||
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
|
||||
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
|
||||
keys, values = self.cache_k.cat(xk, dim=1), self.cache_v.cat(xv, dim=1)
|
||||
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
|
||||
|
||||
# save the cache
|
||||
self.cache_k, self.cache_v = keys.realize(), values.realize()
|
||||
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
return Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
|
||||
|
||||
# NOTE: this is not called
|
||||
def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
|
||||
xq, xk, xv = self.prepare_attention(x, freqs_cis)
|
||||
output = self.inner_attention(xq, xk, xv, start_pos, mask)
|
||||
return self.wo(output)
|
||||
cache_k, cache_v = keys, values
|
||||
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
|
||||
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
|
||||
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, hidden_dim, multiple_of, linear=Linear, ffn_dim_multiplier=None):
|
||||
|
|
@ -118,26 +110,32 @@ class TransformerBlock:
|
|||
self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
|
||||
self.attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.ffn_norm = RMSNorm(dim, norm_eps)
|
||||
if getenv("JIT"):
|
||||
self._pre = TinyJit(self.pre)
|
||||
self._post = TinyJit(self.post)
|
||||
else:
|
||||
self._pre, self._post = self.pre, self.post
|
||||
self.cache_k, self.cache_v = None, None
|
||||
|
||||
def pre(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
xq, xk, xv = self.attention.prepare_attention(self.attention_norm(x), freqs_cis)
|
||||
return xq.realize(), xk.realize(), xv.realize()
|
||||
self.jitted_attention_norm = TinyJit(lambda x: self.attention_norm(x).realize())
|
||||
self.jitted_attn = TinyJit(self.attention.__call__)
|
||||
self.jitted_norm_output = TinyJit(self.norm_output)
|
||||
|
||||
def post(self, x:Tensor, output:Tensor) -> Tensor:
|
||||
h = x + self.attention.wo(output)
|
||||
def norm_output(self, x:Tensor, output:Tensor) -> Tensor:
|
||||
h = x + output
|
||||
return (h + self.feed_forward(self.ffn_norm(h))).realize()
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
|
||||
# if mask is not None, x's shape is dymanic based on user input and pre/post can't be jitted
|
||||
xq, xk, xv = self._pre(x, freqs_cis) if mask is None else self.pre(x, freqs_cis)
|
||||
# inner_attention can't be jitted because it's dynamic based on start_pos
|
||||
output = self.attention.inner_attention(xq, xk, xv, start_pos, mask)
|
||||
return self._post(x, output) if mask is None else self.post(x, output)
|
||||
bsz, seqlen, _ = x.shape
|
||||
do_jit = getenv("JIT") and mask is None
|
||||
if do_jit:
|
||||
pos = Variable("pos", 1, 1024)
|
||||
self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3])
|
||||
self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3])
|
||||
output, cache_k, cache_v = self.jitted_attn(self.jitted_attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
|
||||
else:
|
||||
output, cache_k, cache_v = self.attention(self.attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
|
||||
|
||||
# save the cache. with symbolic shape, cast it back to int shape so we have int shape in cache
|
||||
self.cache_k = cache_k.reshape(cache_k.shape[0], start_pos+seqlen, cache_k.shape[2], cache_k.shape[3]).realize()
|
||||
self.cache_v = cache_v.reshape(cache_v.shape[0], start_pos+seqlen, cache_v.shape[2], cache_v.shape[3]).realize()
|
||||
|
||||
return self.jitted_norm_output(x, output) if do_jit else self.norm_output(x, output)
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
|
||||
|
|
@ -146,16 +144,21 @@ class Transformer:
|
|||
self.tok_embeddings = Embedding(vocab_size, dim)
|
||||
self.output = linear(dim, vocab_size, bias=False)
|
||||
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
|
||||
self.norm_output = lambda x: self.output(self.norm(x))
|
||||
|
||||
self.jitted_tok_embeddings = TinyJit(lambda x: self.tok_embeddings(x).realize())
|
||||
self.jitted_norm_output = TinyJit(lambda x: self.norm_output(x).realize())
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
# get only the part we are using. making it contiguous avoids more kernel calls
|
||||
# get only the part we are using. TODO: removing contiguous resulted in a bug?
|
||||
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize()
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
|
||||
|
||||
do_jit = getenv("JIT") and mask is None
|
||||
h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
|
||||
h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
|
||||
return self.output(self.norm(h))
|
||||
return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
|
||||
|
||||
# **** files and arguments ****
|
||||
|
||||
|
|
|
|||
8
extra/dist/world.py
vendored
8
extra/dist/world.py
vendored
|
|
@ -9,12 +9,12 @@ from tinygrad.tensor import Tensor, Function
|
|||
import numpy as np
|
||||
|
||||
# fake the function signature of ASTRunner so we can put it in the cache
|
||||
def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], jit=False, force_wait=False):
|
||||
def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], variables=None, jit=False, force_wait=False):
|
||||
args[0]._copyout(np.frombuffer(args[1]._buffer(), dtype=args[0].dtype.np))
|
||||
dist.OOB.send(args[3], args[2])
|
||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} sent {args[0]} to rank {args[2]}")
|
||||
|
||||
def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], jit=False, force_wait=False):
|
||||
def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], variables=None, jit=False, force_wait=False):
|
||||
dist.OOB.recv(args[2])
|
||||
args[0]._copyin(args[1].toCPU())
|
||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} recv {args[0]} from rank {args[2]}")
|
||||
|
|
@ -35,7 +35,7 @@ def _send_rb(x:RawBufferCopyInOut, target_rank:int, cache_id:Optional[str]=None)
|
|||
__send_rb((x, rb, target_rank, (shm_name, cache_id)))
|
||||
|
||||
# jit support
|
||||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__send_rb, [x, rb, target_rank, None]))
|
||||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__send_rb, [x, rb, target_rank, None], {}))
|
||||
setattr(_send_rb, "shared_memory_cache", {})
|
||||
|
||||
# receive a rawbuffer from the target rank
|
||||
|
|
@ -52,7 +52,7 @@ def _recv_rb(x:RawBufferCopyIn, target_rank:int):
|
|||
s.unlink()
|
||||
|
||||
# jit support
|
||||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__recv_rb, [x, rb, target_rank]))
|
||||
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__recv_rb, [x, rb, target_rank], {}))
|
||||
|
||||
# sends a lazybuffer from our rank to the target rank
|
||||
def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None: _send_rb(x.contiguous().realize().realized, target_rank, cache_id=cache_id)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ import json
|
|||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for fxn,args in run.jit_cache:
|
||||
for fxn,args,var_vals in run.jit_cache:
|
||||
assert not var_vals, "symbolic shape is not supported"
|
||||
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i,arg in enumerate(args):
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def compile(dat, output_fn):
|
|||
# transform to CL.CACHE
|
||||
used_ops = 0
|
||||
cl_cache = []
|
||||
for prg,args in model_exec.jit_cache:
|
||||
for prg,args,_ in model_exec.jit_cache:
|
||||
# pass these to thneed
|
||||
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
|
||||
setattr(prg.clprg, 'prg', prg.prg)
|
||||
|
|
|
|||
2
test/external/external_copy_benchmark.py
vendored
2
test/external/external_copy_benchmark.py
vendored
|
|
@ -11,7 +11,7 @@ class TestCopy(unittest.TestCase):
|
|||
t = Tensor.randn(i).realize()
|
||||
GlobalCounters.cache = []
|
||||
t.assign(t+1).realize()
|
||||
fxn, args = GlobalCounters.cache[0]
|
||||
fxn, args, _ = GlobalCounters.cache[0]
|
||||
GlobalCounters.reset()
|
||||
def run(): return fxn(args, force_wait=True)
|
||||
ct = min([run() for _ in range(10)])
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def derandomize(x):
|
|||
if isinstance(x, LazyOp):
|
||||
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
|
||||
x.src = [derandomize(s) for s in x.src]
|
||||
else:
|
||||
elif hasattr(x, "op"):
|
||||
x.op = derandomize(x.op)
|
||||
return x
|
||||
|
||||
|
|
|
|||
152
test/test_symbolic_jit.py
Normal file
152
test/test_symbolic_jit.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
import unittest
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA"], f"{Device.DEFAULT} is not supported")
class TestSymbolicJit(unittest.TestCase):
  """Compare TinyJit runs on symbolically-shaped tensors against the eager function.

  Each test calls the jitted function with inputs reshaped to use `Variable`
  dimensions, reshapes the result back to concrete sizes where needed, and checks
  it against the un-jitted function on the same concrete tensors. The number of
  entries in `jit_cache` is checked after the loops.
  """

  def test_plus1(self):
    """Elementwise add of a constant with one symbolic dimension."""
    def f(a):
      return (a+1).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      jit_out = jitted(a.reshape(3, vi)).reshape(3, i).cpu().numpy()
      ref_out = f(a).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_add(self):
    """Elementwise add of two tensors sharing one symbolic dimension."""
    def f(a, b):
      return (a+b).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(3, i)
      jit_out = jitted(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy()
      ref_out = f(a, b).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_matmul(self):
    """Matmul where the contracted dimension is symbolic."""
    def f(a, b):
      return (a@b).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(i, 5)
      jit_out = jitted(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
      ref_out = f(a, b).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_mixed_with_no_symbol_kernel(self):
    """Two realized kernels, the second of which has no symbols in its input."""
    def f(a, b):
      s = (a@b).realize()
      s = (s+s).realize() # this one does not have symbols in input
      return s
    jitted = TinyJit(f)
    for i in range(1, 5):
      # a fresh Variable is built on every iteration here, unlike the other tests
      vi = Variable("i", 1, 10)
      a = Tensor.rand(3, i)
      b = Tensor.rand(i, 5)
      jit_out = jitted(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
      ref_out = f(a, b).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 2

  @unittest.skipIf(Device.DEFAULT == "CLANG", "broken on CLANG CI")
  def test_attention(self):
    """Scaled dot-product attention with a symbolic key/value length."""
    def f(q, k, v):
      return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      q = Tensor.rand(2, 1, 4, 8)
      k = Tensor.rand(2, i, 4, 8)
      v = Tensor.rand(2, i, 4, 8)
      jit_out = jitted(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
      ref_out = f(q, k, v).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 6

  def test_cat_dim0(self):
    """Concatenate along dim 0 where the first operand's dim 0 is symbolic."""
    def f(a, b):
      return a.cat(b, dim=0).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(i, 3)
      b = Tensor.rand(2, 3)
      jit_out = jitted(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy()
      ref_out = f(a, b).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_cat_dim1(self):
    """Concatenate along dim 1 where the first operand's dim 1 is symbolic."""
    def f(a, b):
      return a.cat(b, dim=1).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(3, 2)
      jit_out = jitted(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy()
      ref_out = f(a, b).cpu().numpy()
      np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_cat_dim0_two_vars(self):
    """Concatenate along dim 0 with a different symbolic size on each operand."""
    def f(a, b):
      return a.cat(b, dim=0).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(i, 3)
        b = Tensor.rand(j, 3)
        jit_out = jitted(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy()
        ref_out = f(a, b).cpu().numpy()
        np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_cat_dim1_two_vars(self):
    """Concatenate along dim 1 with a different symbolic size on each operand."""
    def f(a, b):
      return a.cat(b, dim=1).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(3, i)
        b = Tensor.rand(3, j)
        jit_out = jitted(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy()
        ref_out = f(a, b).cpu().numpy()
        np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_two_vars_plus1(self):
    """Matmul plus a constant where both output dimensions are symbolic."""
    def f(a, b):
      return (a@b+1).realize()
    jitted = TinyJit(f)
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(i, 3)
        b = Tensor.rand(3, j)
        jit_out = jitted(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
        ref_out = f(a, b).cpu().numpy()
        np.testing.assert_allclose(jit_out, ref_out, atol=1e-6, rtol=1e-6)
    assert len(jitted.jit_cache) == 1

  def test_jit_symbolic_shape_mismatch(self):
    """A call whose symbolic shape differs from the captured one must raise AssertionError."""
    @TinyJit
    def add(a, b):
      return (a+b).realize()
    vi = Variable("i", 1, 10)
    # warm up the JIT with matching (3, vi) inputs; the results are not inspected
    for i in range(1, 5):
      a = Tensor.rand(3, i).reshape(3, vi)
      b = Tensor.rand(3, i).reshape(3, vi)
      _out = add(a, b)
    a = Tensor.rand(3, 7).reshape(3, vi)
    bad = Tensor.rand(4, 7).reshape(4, vi)
    with self.assertRaises(AssertionError):
      add(a, bad)
|
||||
|
|
@ -131,7 +131,7 @@ class GlobalCounters:
|
|||
kernel_count: ClassVar[int] = 0
|
||||
mem_used: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
|
||||
cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
|
||||
cache: ClassVar[Optional[List[Tuple[Callable, Any, Dict[Any, int]]]]] = None # List[Tuple[Callable, List[RawBuffer], Dict[Variable, int]]]
|
||||
@staticmethod
|
||||
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import DEBUG, DType
|
||||
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters, RawBuffer
|
||||
from tinygrad.shape.symbolic import Variable, Node
|
||||
|
||||
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
|
||||
|
||||
|
|
@ -12,25 +12,27 @@ class TinyJit:
|
|||
def __init__(self, fxn:Callable):
|
||||
self.fxn: Callable = fxn
|
||||
self.cnt: int = 0
|
||||
self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]]]] = []
|
||||
self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
|
||||
self.ret: Any = None
|
||||
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], int, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_size, expected_type)
|
||||
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], Tuple[Union[Node, int],...], DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shape, expected_type)
|
||||
|
||||
# add support for instance methods
|
||||
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
|
||||
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
|
||||
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
|
||||
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
|
||||
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, Tuple[Union[Node, int],...]]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.shape) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
|
||||
assert len(input_rawbuffers) != 0, "no inputs to JIT"
|
||||
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
|
||||
if self.cnt >= 2:
|
||||
for (j,i),(input_name, expected_size, expected_type) in self.input_replace.items():
|
||||
assert input_rawbuffers[input_name].size == expected_size and input_rawbuffers[input_name].dtype == expected_type, f"size or type mismatch in JIT, {input_rawbuffers[input_name]} != <{expected_size}, {expected_type}>"
|
||||
self.jit_cache[j][1][i] = input_rawbuffers[input_name]
|
||||
for prg, pargs in self.jit_cache: # type: Callable, List[Optional[RawBuffer]]
|
||||
prg(pargs, jit=True)
|
||||
var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key))
|
||||
for (j,i),(input_name, expected_shape, expected_type) in self.input_replace.items():
|
||||
assert input_rawbuffers[input_name][1] == expected_shape and input_rawbuffers[input_name][0].dtype == expected_type, f"shape or type mismatch in JIT, <{input_rawbuffers[input_name][1]}, {input_rawbuffers[input_name][0].dtype}> != <{expected_shape}, {expected_type}>"
|
||||
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
|
||||
for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
|
||||
for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
|
||||
prg(pargs, variables, jit=True)
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
|
||||
elif self.cnt == 1:
|
||||
GlobalCounters.cache = []
|
||||
|
|
@ -41,10 +43,10 @@ class TinyJit:
|
|||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
|
||||
|
||||
# get the inputs for replacement
|
||||
for j_,(_,pargs_) in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]]]]
|
||||
for i,a in enumerate(pargs_):
|
||||
if a in input_rawbuffers.values():
|
||||
self.input_replace[(j_,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0]
|
||||
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
|
||||
for i,a in enumerate(cache[1]):
|
||||
if a in [v[0] for v in input_rawbuffers.values()]:
|
||||
self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
|
||||
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
|
||||
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class ASTRunner:
|
|||
|
||||
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
|
||||
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
|
||||
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
|
||||
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs, var_vals if var_vals is not None else {}))
|
||||
return self(rawbufs, var_vals, force_wait=force_wait)
|
||||
|
||||
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class CUDAProgram:
|
|||
if wait:
|
||||
start, end = cuda.Event(), cuda.Event()
|
||||
start.record()
|
||||
self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else x for x in args], block=tuple(local_size), grid=tuple(global_size))
|
||||
self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size))
|
||||
if wait:
|
||||
end.record()
|
||||
end.synchronize()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue