more const cleanups [PR] (#16682)

This commit is contained in:
chenyu 2026-06-20 20:41:43 -04:00 committed by GitHub
commit 9ae0a93d0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 1 additions and 5 deletions

View file

@ -231,8 +231,7 @@ def _prepare_jit_inputs(args, kwargs):
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
def get_input_uops() -> list[UOp]: return flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
# TODO: drop the CONST branch once all CONST are deviceless
if any(u.device is None or u.base.op is Ops.CONST for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
if any(u.device is None for u in get_input_uops()): raise JitError("JIT inputs must be real buffers; use .clone()")
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
input_uops = get_input_uops()
# collect buffer UOps (including MultiBuffer)

View file

@ -543,9 +543,6 @@ to_define_global = PatternMatcher([
# remove device from local BUFFERIZE
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# remove UNIQUE/DEVICE to dedup CONST
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])