jit input_buffers cleanup [pr] (#14532)

This commit is contained in:
chenyu 2026-02-04 10:14:38 -05:00 committed by GitHub
commit 024f57ecf5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -257,8 +257,7 @@ def _prepare_jit_inputs(args, kwargs):
input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
if any(u.base.op is Ops.CONST for u in input_uops):
raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b:=u.base.realized, MultiBuffer) else [b]
for u in input_uops if u.base.realized is not None])
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=u.base.realized) is not None])
if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops]
_var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])