mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
clean up _function.depth properly [PR] (#16663)
This commit is contained in:
parent
d7a1022188
commit
d74f488376
2 changed files with 11 additions and 2 deletions
|
|
@ -20,6 +20,13 @@ class TestFunction(unittest.TestCase):
|
|||
a = Tensor([1,2,3])
|
||||
np.testing.assert_equal(f(a,a).numpy(), [2,4,6])
|
||||
|
||||
def test_depth_restored_on_exception(self):
|
||||
from tinygrad.function import _function
|
||||
@function
|
||||
def f(a:Tensor) -> Tensor: raise ValueError("error")
|
||||
with self.assertRaises(ValueError): f(Tensor([1]))
|
||||
self.assertEqual(_function.depth, 0)
|
||||
|
||||
def test_implicit(self):
|
||||
inp = Tensor([7,8,9])
|
||||
@function(allow_implicit=True)
|
||||
|
|
|
|||
|
|
@ -46,8 +46,10 @@ class _function(Generic[ReturnType]):
|
|||
# run it and do surgery later
|
||||
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
|
||||
_function.depth += 1
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
_function.depth -= 1
|
||||
try:
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
finally:
|
||||
_function.depth -= 1
|
||||
if isinstance(ret, Tensor):
|
||||
uret = ret.uop
|
||||
elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue