Compare commits

...

5 commits

Author SHA1 Message Date
George Hotz
82aa943cd4 fix that test 2025-11-19 08:48:49 -08:00
George Hotz
e16782cf9e
Merge branch 'master' into python_speed 2025-11-19 08:41:40 -08:00
George Hotz
1c47ee729e fix names of rewrite rules 2025-11-19 08:41:34 -08:00
George Hotz
a8f9e69bd9 work on python speed 2025-11-19 08:34:15 -08:00
George Hotz
ffff194e93 skip process replay by default 2025-11-19 08:14:44 -08:00
3 changed files with 23 additions and 12 deletions

View file

@ -517,7 +517,7 @@ class TestUOpStr(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase): class TestUPatHelpers(unittest.TestCase):
def test_location(self): def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "math.py") self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py") self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
test_upat = UPat(Ops.CONST, dtypes.bool) test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1]) self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])

View file

@ -1,4 +1,5 @@
import unittest, time import unittest, time
from tinygrad.helpers import Profiling
from tinygrad.uop.ops import UOp from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
@ -38,6 +39,14 @@ class TestMicrobenchmarks(unittest.TestCase):
a = UOp.const(dtypes.int, 2) a = UOp.const(dtypes.int, 2)
for _ in range(N): (a+a).simplify() for _ in range(N): (a+a).simplify()
class TestMicroprofile(unittest.TestCase):
def test_uop_simplify_complex(self):
x = UOp.variable("x", 0, 10)
y = UOp.variable("y", 0, 10)
expr = (x*2)+5+(x*4)+(y*2)+y
with Profiling():
for _ in range(1000): expr.simplify()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -866,8 +866,8 @@ def print_uops(uops:list[UOp]):
def get_location() -> tuple[str, int]: def get_location() -> tuple[str, int]:
frm = sys._getframe(1) frm = sys._getframe(1)
# skip over ops.py/mathtraits.py (unless there's nothing but ops.py/mathtraits.py) # skip over ops.py and anything in mixin
while pathlib.Path(frm.f_code.co_filename).name in ("ops.py", "mathtraits.py") and frm.f_back is not None and \ while ((codepath:=pathlib.Path(frm.f_code.co_filename)).name == "ops.py" or codepath.parent.name == "mixin") and frm.f_back is not None and \
not frm.f_back.f_code.co_filename.startswith("<frozen"): not frm.f_back.f_code.co_filename.startswith("<frozen"):
frm = frm.f_back frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno return frm.f_code.co_filename, frm.f_lineno
@ -1077,20 +1077,22 @@ def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=Fal
active_rewrites:list[TrackedGraphRewrite] = [] active_rewrites:list[TrackedGraphRewrite] = []
def profile_matches(fxn:Callable): def profile_matches(fxn:Callable):
def wrap(*args, **kwargs): def wrap_profile_matches(*args, **kwargs):
if TRACK_MATCH_STATS >= 2:
name = str(kwargs.get("name", None) or fxn.__name__) name = str(kwargs.get("name", None) or fxn.__name__)
assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}" assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}"
if tracking:=(TRACK_MATCH_STATS >= 2):
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno) loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
depth = len(active_rewrites) depth = len(active_rewrites)
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}")) if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False))) tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
active_rewrites.append(ctx) active_rewrites.append(ctx)
with cpu_profile(name, "TINY", display=tracking): with cpu_profile(name, "TINY"):
ret = fxn(*args, **kwargs) ret = fxn(*args, **kwargs)
if tracking: active_rewrites.pop() active_rewrites.pop()
return ret return ret
return wrap # without tracking, we just call the function
return fxn(*args, **kwargs)
return wrap_profile_matches
class TrackedPatternMatcher(PatternMatcher): class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> UOp|None: def rewrite(self, uop:UOp, ctx=None) -> UOp|None: