mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
jit input_buffers cleanup [pr] (#14532)
This commit is contained in:
parent
67f91e897b
commit
024f57ecf5
1 changed files with 1 additions and 2 deletions
|
|
@ -257,8 +257,7 @@ def _prepare_jit_inputs(args, kwargs):
|
|||
input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
||||
if any(u.base.op is Ops.CONST for u in input_uops):
|
||||
raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b:=u.base.realized, MultiBuffer) else [b]
|
||||
for u in input_uops if u.base.realized is not None])
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=u.base.realized) is not None])
|
||||
if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
|
||||
inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops]
|
||||
_var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue