Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
82aa943cd4 fix that test 2025-11-19 08:48:49 -08:00
George Hotz
e16782cf9e
Merge branch 'master' into python_speed 2025-11-19 08:41:40 -08:00
George Hotz
1c47ee729e fix names of rewrite rules 2025-11-19 08:41:34 -08:00
George Hotz
a8f9e69bd9 work on python speed 2025-11-19 08:34:15 -08:00
George Hotz
ffff194e93 skip process replay by default 2025-11-19 08:14:44 -08:00
3 changed files with 23 additions and 12 deletions

View file

@ -517,7 +517,7 @@ class TestUOpStr(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "math.py")
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])

View file

@ -1,4 +1,5 @@
import unittest, time
from tinygrad.helpers import Profiling
from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes
@ -38,6 +39,14 @@ class TestMicrobenchmarks(unittest.TestCase):
a = UOp.const(dtypes.int, 2)
for _ in range(N): (a+a).simplify()
class TestMicroprofile(unittest.TestCase):
def test_uop_simplify_complex(self):
x = UOp.variable("x", 0, 10)
y = UOp.variable("y", 0, 10)
expr = (x*2)+5+(x*4)+(y*2)+y
with Profiling():
for _ in range(1000): expr.simplify()
if __name__ == '__main__':
unittest.main()

View file

@ -866,8 +866,8 @@ def print_uops(uops:list[UOp]):
def get_location() -> tuple[str, int]:
frm = sys._getframe(1)
# skip over ops.py/mathtraits.py (unless there's nothing but ops.py/mathtraits.py)
while pathlib.Path(frm.f_code.co_filename).name in ("ops.py", "mathtraits.py") and frm.f_back is not None and \
# skip over ops.py and anything in mixin
while ((codepath:=pathlib.Path(frm.f_code.co_filename)).name == "ops.py" or codepath.parent.name == "mixin") and frm.f_back is not None and \
not frm.f_back.f_code.co_filename.startswith("<frozen"):
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@ -1077,20 +1077,22 @@ def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=Fal
active_rewrites:list[TrackedGraphRewrite] = []
def profile_matches(fxn:Callable):
def wrap(*args, **kwargs):
name = str(kwargs.get("name", None) or fxn.__name__)
assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}"
if tracking:=(TRACK_MATCH_STATS >= 2):
def wrap_profile_matches(*args, **kwargs):
if TRACK_MATCH_STATS >= 2:
name = str(kwargs.get("name", None) or fxn.__name__)
assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}"
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
depth = len(active_rewrites)
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
active_rewrites.append(ctx)
with cpu_profile(name, "TINY", display=tracking):
ret = fxn(*args, **kwargs)
if tracking: active_rewrites.pop()
return ret
return wrap
with cpu_profile(name, "TINY"):
ret = fxn(*args, **kwargs)
active_rewrites.pop()
return ret
# without tracking, we just call the function
return fxn(*args, **kwargs)
return wrap_profile_matches
class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> UOp|None: