mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
update jit type annotation post lazy rewrite (#3091)
This commit is contained in:
parent
0fe6904351
commit
dcf7ecaaff
1 changed files with 2 additions and 3 deletions
|
|
@ -27,7 +27,7 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -
|
|||
input_replace[(j,i)] = input_rawbuffers.index(a)
|
||||
return input_replace
|
||||
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
|
||||
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] # noqa: E501
|
||||
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
|
||||
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
|
||||
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
|
||||
|
||||
|
|
@ -52,12 +52,11 @@ class TinyJit(Generic[ReturnType]):
|
|||
|
||||
def __call__(self, *args, **kwargs) -> ReturnType:
|
||||
# all inputs (except const) are realized
|
||||
input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501
|
||||
input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k): cast(LazyBuffer, v.realize().lazydata) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)} # noqa: E501
|
||||
assert all(isinstance(x, LazyBuffer) for x in input_tensors.values()), "multilazybuffer JIT isn't supported"
|
||||
expected_name_sts_dtype = tuple([(k, v.st.unbind()[0], v.dtype) for k,v in input_tensors.items()])
|
||||
|
||||
# get rawbuffers
|
||||
# TODO: why can .realized have Any type?
|
||||
input_rawbuffers: List[Buffer] = [v.base.realized for v in input_tensors.values() if v.base.realized is not None]
|
||||
assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue