final const cleanups [PR] (#16688)

This commit is contained in:
chenyu 2026-06-20 21:38:16 -04:00 committed by GitHub
commit 4618d27129
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View file

@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
# find the FUNCTION nodes
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
# the function bodies (src[0]) should have identical keys
self.assertEqual(c0.src[0].key, c1.src[0].key)
def test_precompile_symbolic_2d(self):

View file

@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
grads = compute_gradient(fxn, root_grad, set(params.values()))
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}