mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
final const cleanups [PR] (#16688)
This commit is contained in:
parent
9ae0a93d0e
commit
4618d27129
2 changed files with 2 additions and 2 deletions
|
|
@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
|
|||
# find the FUNCTION nodes
|
||||
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
|
||||
# the function bodies (src[0]) should have identical keys
|
||||
self.assertEqual(c0.src[0].key, c1.src[0].key)
|
||||
|
||||
def test_precompile_symbolic_2d(self):
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
|||
params = {x.arg.slot:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
grad_args = ctx.src
|
||||
root_grad = UOp(Ops.TUPLE, src=tuple(UOp(Ops.NOOP) if g.op is Ops.NOOP else
|
||||
g if g.base.op is Ops.CONST and g.device is None else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
||||
g if g.base.op is Ops.CONST else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
|
||||
grads = compute_gradient(fxn, root_grad, set(params.values()))
|
||||
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
|
||||
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue