hotfix: add JITGRAPH and invert sints

This commit is contained in:
George Hotz 2023-12-18 16:33:22 -08:00
commit 954a2fef75
3 changed files with 5 additions and 4 deletions

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools, operator
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int, Context, GRAPH
from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
@ -71,7 +71,8 @@ class TinyJit(Generic[ReturnType]):
# jit capture
self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
CacheCollector.start(var_vals)
self.ret = self.fxn(*args, **kwargs)
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
self.ret = self.fxn(*args, **kwargs)
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
assert len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) == len(input_rawbuffers), "some input tensors not found"

View file

@ -58,7 +58,7 @@ class ShapeTracker:
def __add__(self, st:ShapeTracker) -> ShapeTracker: return ShapeTracker(self.views + st.views).simplify()
def invert(self, out_shape:Tuple[int, ...]) -> Optional[ShapeTracker]:
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
return ShapeTracker(cast(Tuple[View, ...], ret)) if all(x is not None for x in ret) else None

View file

@ -99,7 +99,7 @@ class View:
return View.create(new_shape, new_strides, new_offset, new_mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def invert(self, out_shape:Tuple[int, ...]) -> Optional[View]:
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
ret = self.shrink(self.mask) if self.mask else self
if prod(ret.shape) != prod(out_shape): return None # don't support shrink, expand, or stride != (-1, 1)
ret = cast(View, ret.reshape(tuple(s for s in ret.shape if s != 1))) # removing ones will never be an issue