mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
minor function.py cleanups [PR] (#16662)
This commit is contained in:
parent
924bece1d5
commit
d7a1022188
1 changed files with 2 additions and 14 deletions
|
|
@ -56,13 +56,9 @@ class _function(Generic[ReturnType]):
|
|||
raise RuntimeError(f"function return type {type(ret)} not supported")
|
||||
|
||||
# replace the known inputs with params (using deduplicated slots)
|
||||
subs = {}
|
||||
for i,x in enumerate(call_uops): subs[x] = x.param_like(i)
|
||||
subs = {x:x.param_like(i) for i,x in enumerate(call_uops)}
|
||||
uret = uret.substitute(subs)
|
||||
|
||||
# add contiguous to call_uops
|
||||
#call_uops = [x.contiguous() for x in call_uops]
|
||||
|
||||
# the BUFFERs that are left are the implicit inputs
|
||||
num_explicit = len(call_uops)
|
||||
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
|
||||
|
|
@ -73,19 +69,11 @@ class _function(Generic[ReturnType]):
|
|||
buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.arg}, device={b.device}" for i,b in enumerate(implicit_buffers))
|
||||
raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")
|
||||
|
||||
# assign output
|
||||
#pbuffer = uret.param_like(len(call_uops))
|
||||
#assigned = pbuffer.assign(uret).sink()
|
||||
#buffer = UOp.new_buffer(pbuffer.device, pbuffer.size, pbuffer.dtype).reshape(uret.shape)
|
||||
#call = assigned.call(*call_uops, buffer, name=name)
|
||||
#ret = buffer.after(call)
|
||||
|
||||
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
|
||||
precompile_backward=self.precompile_backward)
|
||||
|
||||
if DEBUG >= 2:
|
||||
#signature = [(x._shape, x.dtype, x.device) for x in call_uops]
|
||||
print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}") # with sig {signature}")
|
||||
print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}")
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue