mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hotfix: add JITGRAPH and invert sints
This commit is contained in:
parent
80f53245e8
commit
954a2fef75
3 changed files with 5 additions and 4 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
|
||||
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int, Context, GRAPH
|
||||
from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
|
@ -71,7 +71,8 @@ class TinyJit(Generic[ReturnType]):
|
|||
# jit capture
|
||||
self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
|
||||
CacheCollector.start(var_vals)
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
self.jit_cache = CacheCollector.finish()
|
||||
assert len(self.jit_cache) != 0, "didn't JIT anything!"
|
||||
assert len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) == len(input_rawbuffers), "some input tensors not found"
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class ShapeTracker:
|
|||
|
||||
def __add__(self, st:ShapeTracker) -> ShapeTracker: return ShapeTracker(self.views + st.views).simplify()
|
||||
|
||||
def invert(self, out_shape:Tuple[int, ...]) -> Optional[ShapeTracker]:
|
||||
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
|
||||
ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
|
||||
return ShapeTracker(cast(Tuple[View, ...], ret)) if all(x is not None for x in ret) else None
|
||||
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class View:
|
|||
return View.create(new_shape, new_strides, new_offset, new_mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def invert(self, out_shape:Tuple[int, ...]) -> Optional[View]:
|
||||
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
|
||||
ret = self.shrink(self.mask) if self.mask else self
|
||||
if prod(ret.shape) != prod(out_shape): return None # don't support shrink, expand, or stride != (-1, 1)
|
||||
ret = cast(View, ret.reshape(tuple(s for s in ret.shape if s != 1))) # removing ones will never be an issue
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue