update jit type annotation post lazy rewrite (#3091)

This commit is contained in:
chenyu 2024-01-11 15:49:30 -05:00 committed by GitHub
commit dcf7ecaaff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -27,7 +27,7 @@ def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(tuple(ji.prg.global_size))) or (ji.prg.local_size and not all_int(tuple(ji.prg.local_size))))] # noqa: E501
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
@ -52,12 +52,11 @@ class TinyJit(Generic[ReturnType]):
def __call__(self, *args, **kwargs) -> ReturnType:
# all inputs (except const) are realized
input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k):v.realize().lazydata for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} # noqa: E501
input_tensors: Dict[Union[int, str], LazyBuffer] = {cast(Union[int, str], k): cast(LazyBuffer, v.realize().lazydata) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)} # noqa: E501
assert all(isinstance(x, LazyBuffer) for x in input_tensors.values()), "multilazybuffer JIT isn't supported"
expected_name_sts_dtype = tuple([(k, v.st.unbind()[0], v.dtype) for k,v in input_tensors.items()])
# get rawbuffers
# TODO: why can .realized have Any type?
input_rawbuffers: List[Buffer] = [v.base.realized for v in input_tensors.values() if v.base.realized is not None]
assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"