mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
pass through name function args in track_rewrites (#11572)
This commit is contained in:
parent
1826004ef9
commit
960cc6533a
2 changed files with 4 additions and 4 deletions
|
|
@ -97,9 +97,9 @@ class TestViz(BaseTestViz):
|
|||
|
||||
# name can also come from a function that returns a string
|
||||
def test_dyn_name_fxn(self):
|
||||
@track_rewrites(name=lambda a,ret: a.render())
|
||||
def name_from_fxn(s:UOp): return graph_rewrite(s, PatternMatcher([]))
|
||||
name_from_fxn(UOp.variable("a", 1, 10)+1)
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: ret.render())
|
||||
def name_from_fxn(s:UOp, arg:list|None=None): return graph_rewrite(s, PatternMatcher([]))
|
||||
name_from_fxn(UOp.variable("a", 1, 10)+1, arg=["test"])
|
||||
lst = get_viz_list()
|
||||
# name gets deduped by the function call counter
|
||||
self.assertEqual(lst[0]["name"], "(a+1) n1")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from tinygrad.codegen.opt.kernel import Opt
|
|||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
@track_rewrites(name=lambda _ast,_renderer,ret: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret))
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret))
|
||||
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
|
||||
"""
|
||||
Transform an AST into a ProgramSpec. May trigger BEAM search.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue