minor function.py cleanups [PR] (#16662)

This commit is contained in:
chenyu 2026-06-18 13:36:48 -04:00 committed by GitHub
commit d7a1022188
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -56,13 +56,9 @@ class _function(Generic[ReturnType]):
raise RuntimeError(f"function return type {type(ret)} not supported")
# replace the known inputs with params (using deduplicated slots)
subs = {}
for i,x in enumerate(call_uops): subs[x] = x.param_like(i)
subs = {x:x.param_like(i) for i,x in enumerate(call_uops)}
uret = uret.substitute(subs)
# add contiguous to call_uops
#call_uops = [x.contiguous() for x in call_uops]
# the BUFFERs that are left are the implicit inputs
num_explicit = len(call_uops)
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
@ -73,19 +69,11 @@ class _function(Generic[ReturnType]):
buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.arg}, device={b.device}" for i,b in enumerate(implicit_buffers))
raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")
# assign output
#pbuffer = uret.param_like(len(call_uops))
#assigned = pbuffer.assign(uret).sink()
#buffer = UOp.new_buffer(pbuffer.device, pbuffer.size, pbuffer.dtype).reshape(uret.shape)
#call = assigned.call(*call_uops, buffer, name=name)
#ret = buffer.after(call)
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
precompile_backward=self.precompile_backward)
if DEBUG >= 2:
#signature = [(x._shape, x.dtype, x.device) for x in call_uops]
print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}") # with sig {signature}")
print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}")
if isinstance(ret, tuple):
return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))