Symbolic Shape JIT main PR (#1353)

* Symbolic Shape JIT

update tests

2 variables symbolic ops, adding more tests

test passing

cleanup

* more test cases

* single flag

* review update

* jit attention one piece

* realize

* symbolic_jit test for cuda

* old artifact

* works with cuda gpu but failed ci

* CUDACPU
This commit is contained in:
chenyu 2023-08-18 14:39:55 -07:00 committed by GitHub
commit ae39cf84ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 223 additions and 65 deletions

View file

@ -188,7 +188,7 @@ jobs:
- name: Run JIT test
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
- name: Run symbolic shapetracker test
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
- name: Check Device.DEFAULT
run: WEBGPU=1 python -c "from tinygrad.lazy import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
#- name: Run webgpu pytest

View file

@ -307,7 +307,7 @@ GlobalCounters.cache = None # disable the cache
# there's one ASTRunner in the cache
assert len(cache_saved) == 1
prg, bufs = cache_saved[0]
prg, bufs, _ = cache_saved[0]
# print the C Program :)
print(prg.prg)

View file

@ -16,6 +16,7 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.ops import GlobalCounters
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable, sym_infer
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@ -66,36 +67,27 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def prepare_attention(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
def __call__(self, x:Tensor, cache_k:Tensor, cache_v:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
return xq, xk, xv
def inner_attention(self, xq:Tensor, xk:Tensor, xv:Tensor, start_pos:int, mask:Optional[Tensor]) -> Tensor:
bsz, seqlen, _, _ = xq.shape
# kv caching!
if start_pos == 0:
keys, values = xk, xv
else:
assert hasattr(self, 'cache_k'), "no cache"
assert start_pos == self.cache_k.shape[1] and start_pos == self.cache_v.shape[1], "cache is wrong shape"
assert cache_k.shape[0] > 0, "no cache"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = self.cache_k.cat(xk, dim=1), self.cache_v.cat(xv, dim=1)
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
# save the cache
self.cache_k, self.cache_v = keys.realize(), values.realize()
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
return Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
# NOTE: this is not called
def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
xq, xk, xv = self.prepare_attention(x, freqs_cis)
output = self.inner_attention(xq, xk, xv, start_pos, mask)
return self.wo(output)
cache_k, cache_v = keys, values
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()
class FeedForward:
def __init__(self, dim, hidden_dim, multiple_of, linear=Linear, ffn_dim_multiplier=None):
@ -118,26 +110,32 @@ class TransformerBlock:
self.feed_forward = FeedForward(dim, 4*dim, multiple_of, linear, ffn_dim_multiplier)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
if getenv("JIT"):
self._pre = TinyJit(self.pre)
self._post = TinyJit(self.post)
else:
self._pre, self._post = self.pre, self.post
self.cache_k, self.cache_v = None, None
def pre(self, x:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
xq, xk, xv = self.attention.prepare_attention(self.attention_norm(x), freqs_cis)
return xq.realize(), xk.realize(), xv.realize()
self.jitted_attention_norm = TinyJit(lambda x: self.attention_norm(x).realize())
self.jitted_attn = TinyJit(self.attention.__call__)
self.jitted_norm_output = TinyJit(self.norm_output)
def post(self, x:Tensor, output:Tensor) -> Tensor:
h = x + self.attention.wo(output)
def norm_output(self, x:Tensor, output:Tensor) -> Tensor:
h = x + output
return (h + self.feed_forward(self.ffn_norm(h))).realize()
def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
# if mask is not None, x's shape is dynamic based on user input and pre/post can't be jitted
xq, xk, xv = self._pre(x, freqs_cis) if mask is None else self.pre(x, freqs_cis)
# inner_attention can't be jitted because it's dynamic based on start_pos
output = self.attention.inner_attention(xq, xk, xv, start_pos, mask)
return self._post(x, output) if mask is None else self.post(x, output)
bsz, seqlen, _ = x.shape
do_jit = getenv("JIT") and mask is None
if do_jit:
pos = Variable("pos", 1, 1024)
self.cache_k = self.cache_k.reshape(self.cache_k.shape[0], pos, self.cache_k.shape[2], self.cache_k.shape[3])
self.cache_v = self.cache_v.reshape(self.cache_v.shape[0], pos, self.cache_v.shape[2], self.cache_v.shape[3])
output, cache_k, cache_v = self.jitted_attn(self.jitted_attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
else:
output, cache_k, cache_v = self.attention(self.attention_norm(x), self.cache_k, self.cache_v, start_pos, freqs_cis, mask)
# save the cache. with symbolic shape, cast it back to int shape so we have int shape in cache
self.cache_k = cache_k.reshape(cache_k.shape[0], start_pos+seqlen, cache_k.shape[2], cache_k.shape[3]).realize()
self.cache_v = cache_v.reshape(cache_v.shape[0], start_pos+seqlen, cache_v.shape[2], cache_v.shape[3]).realize()
return self.jitted_norm_output(x, output) if do_jit else self.norm_output(x, output)
class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None, n_kv_heads=None):
@ -146,16 +144,21 @@ class Transformer:
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
self.norm_output = lambda x: self.output(self.norm(x))
self.jitted_tok_embeddings = TinyJit(lambda x: self.tok_embeddings(x).realize())
self.jitted_norm_output = TinyJit(lambda x: self.norm_output(x).realize())
def __call__(self, tokens:Tensor, start_pos:int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
# get only the part we are using. making it contiguous avoids more kernel calls
# get only the part we are using. TODO: removing contiguous resulted in a bug?
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize()
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
do_jit = getenv("JIT") and mask is None
h = self.jitted_tok_embeddings(tokens) if do_jit else self.tok_embeddings(tokens)
h = h.sequential([functools.partial(layer, start_pos=start_pos, freqs_cis=freqs_cis, mask=mask) for layer in self.layers])
return self.output(self.norm(h))
return self.jitted_norm_output(h) if do_jit else self.norm_output(h)
# **** files and arguments ****

8
extra/dist/world.py vendored
View file

@ -9,12 +9,12 @@ from tinygrad.tensor import Tensor, Function
import numpy as np
# fake the function signature of ASTRunner so we can put it in the cache
def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], jit=False, force_wait=False):
def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], variables=None, jit=False, force_wait=False):
args[0]._copyout(np.frombuffer(args[1]._buffer(), dtype=args[0].dtype.np))
dist.OOB.send(args[3], args[2])
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} sent {args[0]} to rank {args[2]}")
def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], jit=False, force_wait=False):
def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], variables=None, jit=False, force_wait=False):
dist.OOB.recv(args[2])
args[0]._copyin(args[1].toCPU())
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} recv {args[0]} from rank {args[2]}")
@ -35,7 +35,7 @@ def _send_rb(x:RawBufferCopyInOut, target_rank:int, cache_id:Optional[str]=None)
__send_rb((x, rb, target_rank, (shm_name, cache_id)))
# jit support
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__send_rb, [x, rb, target_rank, None]))
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__send_rb, [x, rb, target_rank, None], {}))
setattr(_send_rb, "shared_memory_cache", {})
# receive a rawbuffer from the target rank
@ -52,7 +52,7 @@ def _recv_rb(x:RawBufferCopyIn, target_rank:int):
s.unlink()
# jit support
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__recv_rb, [x, rb, target_rank]))
if GlobalCounters.cache is not None: GlobalCounters.cache.append((__recv_rb, [x, rb, target_rank], {}))
# sends a lazybuffer from our rank to the target rank
def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None: _send_rb(x.contiguous().realize().realized, target_rank, cache_id=cache_id)

View file

@ -7,7 +7,8 @@ import json
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for fxn,args in run.jit_cache:
for fxn,args,var_vals in run.jit_cache:
assert not var_vals, "symbolic shape is not supported"
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(args):

View file

@ -69,7 +69,7 @@ def compile(dat, output_fn):
# transform to CL.CACHE
used_ops = 0
cl_cache = []
for prg,args in model_exec.jit_cache:
for prg,args,_ in model_exec.jit_cache:
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)

View file

@ -11,7 +11,7 @@ class TestCopy(unittest.TestCase):
t = Tensor.randn(i).realize()
GlobalCounters.cache = []
t.assign(t+1).realize()
fxn, args = GlobalCounters.cache[0]
fxn, args, _ = GlobalCounters.cache[0]
GlobalCounters.reset()
def run(): return fxn(args, force_wait=True)
ct = min([run() for _ in range(10)])

View file

@ -31,7 +31,7 @@ def derandomize(x):
if isinstance(x, LazyOp):
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
x.src = [derandomize(s) for s in x.src]
else:
elif hasattr(x, "op"):
x.op = derandomize(x.op)
return x

152
test/test_symbolic_jit.py Normal file
View file

@ -0,0 +1,152 @@
import unittest
from tinygrad.jit import TinyJit
from tinygrad.helpers import getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor, Device
import numpy as np
# Tests for TinyJit with symbolic (Variable) dimensions in the input shapes.
# Every test follows the same pattern: wrap a small function in TinyJit, call it
# repeatedly with inputs reshaped so that one or two dims are a Variable, compare
# against the un-jitted function on the concrete-shaped inputs, and finally check
# how many entries ended up in jf.jit_cache.
# Skipped when ARM64 or PTX is set, and on any device other than GPU/METAL/CLANG/CUDA.
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA"], f"{Device.DEFAULT} is not supported")
class TestSymbolicJit(unittest.TestCase):
  # elementwise a+1 with the second dim symbolic; one cache entry expected
  def test_plus1(self):
    def f(a): return (a+1).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)  # presumably (name, min, max) -- the loop below stays inside 1..10
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      # jitted call sees shape (3, vi); reshape the result back to ints before reading it out
      symbolic = jf(a.reshape(3, vi)).reshape(3, i).cpu().numpy()
      expected = f(a).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    # checked after the loop: the same single entry served every value of i
    assert len(jf.jit_cache) == 1

  # elementwise a+b where both inputs share the same symbolic dim; one cache entry expected
  def test_add(self):
    def f(a, b): return (a+b).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(3, i)
      symbolic = jf(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # matmul where the contracted dim is symbolic; the (3, 5) output is read out without a reshape
  def test_matmul(self):
    def f(a, b): return (a@b).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(i, 5)
      symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # two realized steps: a symbolic matmul followed by an add on its result; two cache entries expected
  def test_mixed_with_no_symbol_kernel(self):
    def f(a, b):
      s = (a@b).realize()
      s = (s+s).realize() # this one does not have symbols in input
      return s
    jf = TinyJit(f)
    for i in range(1, 5):
      # unlike the other tests, the Variable is re-created on every iteration
      vi = Variable("i", 1, 10)
      a = Tensor.rand(3, i)
      b = Tensor.rand(i, 5)
      symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 2

  # scaled_dot_product_attention with q at a fixed shape and k/v symbolic in dim 1; six cache entries expected
  @unittest.skipIf(Device.DEFAULT == "CLANG", "broken on CLANG CI")
  def test_attention(self):
    def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      q = Tensor.rand(2, 1, 4, 8)
      k = Tensor.rand(2, i, 4, 8)
      v = Tensor.rand(2, i, 4, 8)
      symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
      expected = f(q, k, v).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 6

  # cat along dim 0 of a symbolic-length tensor with a fixed (2, 3) tensor
  def test_cat_dim0(self):
    def f(a, b): return a.cat(b, dim=0).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(i, 3)
      b = Tensor.rand(2, 3)
      # result is reshaped to the concrete length i+2 before comparison
      symbolic = jf(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # cat along dim 1 of a symbolic-width tensor with a fixed (3, 2) tensor
  def test_cat_dim1(self):
    def f(a, b): return a.cat(b, dim=1).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(3, 2)
      symbolic = jf(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # cat along dim 0 where each input has its own Variable; every (i, j) pair in 1..4 is exercised
  def test_cat_dim0_two_vars(self):
    def f(a, b): return a.cat(b, dim=0).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(i, 3)
        b = Tensor.rand(j, 3)
        symbolic = jf(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy()
        expected = f(a, b).cpu().numpy()
        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # cat along dim 1 where each input has its own Variable
  def test_cat_dim1_two_vars(self):
    def f(a, b): return a.cat(b, dim=1).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(3, i)
        b = Tensor.rand(3, j)
        symbolic = jf(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy()
        expected = f(a, b).cpu().numpy()
        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # matmul plus a constant where both output dims are symbolic (vi rows, vj columns); one cache entry expected
  def test_two_vars_plus1(self):
    def f(a, b): return (a@b+1).realize()
    jf = TinyJit(f)
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(i, 3)
        b = Tensor.rand(3, j)
        symbolic = jf(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
        expected = f(a, b).cpu().numpy()
        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
    assert len(jf.jit_cache) == 1

  # after the JIT has been exercised with (3, vi) inputs, an input of shape (4, vi) must raise AssertionError
  def test_jit_symbolic_shape_mismatch(self):
    @TinyJit
    def add(a, b): return (a+b).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i).reshape(3, vi)
      b = Tensor.rand(3, i).reshape(3, vi)
      # result is unused; these calls only exercise the JIT with matching shapes
      c = add(a, b)
    a = Tensor.rand(3, 7).reshape(3, vi)
    # same symbolic dim but a different leading dim than the calls above
    bad = Tensor.rand(4, 7).reshape(4, vi)
    with self.assertRaises(AssertionError):
      add(a, bad)

View file

@ -131,7 +131,7 @@ class GlobalCounters:
kernel_count: ClassVar[int] = 0
mem_used: ClassVar[int] = 0 # NOTE: this is not reset
mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
cache: ClassVar[Optional[List[Tuple[Callable, Any, Dict[Any, int]]]]] = None # List[Tuple[Callable, List[RawBuffer], Dict[Variable, int]]]
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None

View file

@ -1,10 +1,10 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters, RawBuffer
from tinygrad.shape.symbolic import Variable, Node
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]
@ -12,25 +12,27 @@ class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.cnt: int = 0
self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]]]] = []
self.jit_cache: List[Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], int, DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_size, expected_type)
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], Tuple[Union[Node, int],...], DType]]= {} # (kernel_number, buffer_number) -> (input_name, expected_shape, expected_type)
# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
# NOTE: this cast is needed since although we know realize will create a ".realized" RawBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, Tuple[Union[Node, int],...]]] = {cast(Union[int, str], k):(cast(RawBuffer, v.realize().lazydata.realized), v.shape) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
for (j,i),(input_name, expected_size, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name].size == expected_size and input_rawbuffers[input_name].dtype == expected_type, f"size or type mismatch in JIT, {input_rawbuffers[input_name]} != <{expected_size}, {expected_type}>"
self.jit_cache[j][1][i] = input_rawbuffers[input_name]
for prg, pargs in self.jit_cache: # type: Callable, List[Optional[RawBuffer]]
prg(pargs, jit=True)
var_vals = dict(sorted(merge_dicts([arg.lazydata.st.var_vals for arg in args if isinstance(arg, Tensor)]).items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_shape, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][1] == expected_shape and input_rawbuffers[input_name][0].dtype == expected_type, f"shape or type mismatch in JIT, <{input_rawbuffers[input_name][1]}, {input_rawbuffers[input_name][0].dtype}> != <{expected_shape}, {expected_type}>"
self.jit_cache[j][1][i] = input_rawbuffers[input_name][0]
for prg, pargs, variables in self.jit_cache: # type: Callable, List[Optional[RawBuffer]], Dict[Variable, int]
for v in (var_vals.keys() & variables.keys()): variables[v] = var_vals[v]
prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 1:
GlobalCounters.cache = []
@ -41,10 +43,10 @@ class TinyJit:
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# get the inputs for replacement
for j_,(_,pargs_) in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]]]]
for i,a in enumerate(pargs_):
if a in input_rawbuffers.values():
self.input_replace[(j_,i)] = [(k, v.size, v.dtype) for k,v in input_rawbuffers.items() if v == a][0]
for j_,cache in enumerate(self.jit_cache): # type: Tuple[int, Tuple[Callable, List[Optional[RawBuffer]], Dict[Variable, int]]]
for i,a in enumerate(cache[1]):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j_,i)] = [(k, v[1], v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None

View file

@ -134,7 +134,7 @@ class ASTRunner:
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs, var_vals if var_vals is not None else {}))
return self(rawbufs, var_vals, force_wait=force_wait)
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:

View file

@ -78,7 +78,7 @@ class CUDAProgram:
if wait:
start, end = cuda.Event(), cuda.Event()
start.record()
self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else x for x in args], block=tuple(local_size), grid=tuple(global_size))
self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size))
if wait:
end.record()
end.synchronize()