clean up _function.depth properly [PR] (#16663)

This commit is contained in:
chenyu 2026-06-18 14:10:22 -04:00 committed by GitHub
commit d74f488376
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 2 deletions

View file

@ -20,6 +20,13 @@ class TestFunction(unittest.TestCase):
a = Tensor([1,2,3])
np.testing.assert_equal(f(a,a).numpy(), [2,4,6])
def test_depth_restored_on_exception(self):
from tinygrad.function import _function
@function
def f(a:Tensor) -> Tensor: raise ValueError("error")
with self.assertRaises(ValueError): f(Tensor([1]))
self.assertEqual(_function.depth, 0)
def test_implicit(self):
inp = Tensor([7,8,9])
@function(allow_implicit=True)

View file

@ -46,8 +46,10 @@ class _function(Generic[ReturnType]):
# run it and do surgery later
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
_function.depth += 1
ret = self.fxn(*args, **kwargs)
_function.depth -= 1
try:
ret = self.fxn(*args, **kwargs)
finally:
_function.depth -= 1
if isinstance(ret, Tensor):
uret = ret.uop
elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):