mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
5 commits
master
...
python_spe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82aa943cd4 | ||
|
|
e16782cf9e |
||
|
|
1c47ee729e | ||
|
|
a8f9e69bd9 | ||
|
|
ffff194e93 |
3 changed files with 23 additions and 12 deletions
|
|
@ -517,7 +517,7 @@ class TestUOpStr(unittest.TestCase):
|
||||||
|
|
||||||
class TestUPatHelpers(unittest.TestCase):
|
class TestUPatHelpers(unittest.TestCase):
|
||||||
def test_location(self):
|
def test_location(self):
|
||||||
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "math.py")
|
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
|
||||||
self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
|
self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
|
||||||
test_upat = UPat(Ops.CONST, dtypes.bool)
|
test_upat = UPat(Ops.CONST, dtypes.bool)
|
||||||
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import unittest, time
|
import unittest, time
|
||||||
|
from tinygrad.helpers import Profiling
|
||||||
from tinygrad.uop.ops import UOp
|
from tinygrad.uop.ops import UOp
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
|
|
||||||
|
|
@ -38,6 +39,14 @@ class TestMicrobenchmarks(unittest.TestCase):
|
||||||
a = UOp.const(dtypes.int, 2)
|
a = UOp.const(dtypes.int, 2)
|
||||||
for _ in range(N): (a+a).simplify()
|
for _ in range(N): (a+a).simplify()
|
||||||
|
|
||||||
|
class TestMicroprofile(unittest.TestCase):
|
||||||
|
def test_uop_simplify_complex(self):
|
||||||
|
x = UOp.variable("x", 0, 10)
|
||||||
|
y = UOp.variable("y", 0, 10)
|
||||||
|
expr = (x*2)+5+(x*4)+(y*2)+y
|
||||||
|
with Profiling():
|
||||||
|
for _ in range(1000): expr.simplify()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -866,8 +866,8 @@ def print_uops(uops:list[UOp]):
|
||||||
|
|
||||||
def get_location() -> tuple[str, int]:
|
def get_location() -> tuple[str, int]:
|
||||||
frm = sys._getframe(1)
|
frm = sys._getframe(1)
|
||||||
# skip over ops.py/mathtraits.py (unless there's nothing but ops.py/mathtraits.py)
|
# skip over ops.py and anything in mixin
|
||||||
while pathlib.Path(frm.f_code.co_filename).name in ("ops.py", "mathtraits.py") and frm.f_back is not None and \
|
while ((codepath:=pathlib.Path(frm.f_code.co_filename)).name == "ops.py" or codepath.parent.name == "mixin") and frm.f_back is not None and \
|
||||||
not frm.f_back.f_code.co_filename.startswith("<frozen"):
|
not frm.f_back.f_code.co_filename.startswith("<frozen"):
|
||||||
frm = frm.f_back
|
frm = frm.f_back
|
||||||
return frm.f_code.co_filename, frm.f_lineno
|
return frm.f_code.co_filename, frm.f_lineno
|
||||||
|
|
@ -1077,20 +1077,22 @@ def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=Fal
|
||||||
|
|
||||||
active_rewrites:list[TrackedGraphRewrite] = []
|
active_rewrites:list[TrackedGraphRewrite] = []
|
||||||
def profile_matches(fxn:Callable):
|
def profile_matches(fxn:Callable):
|
||||||
def wrap(*args, **kwargs):
|
def wrap_profile_matches(*args, **kwargs):
|
||||||
|
if TRACK_MATCH_STATS >= 2:
|
||||||
name = str(kwargs.get("name", None) or fxn.__name__)
|
name = str(kwargs.get("name", None) or fxn.__name__)
|
||||||
assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}"
|
assert args and isinstance(args[0], UOp), f"invalid match tracing inputs for {name} with {args}"
|
||||||
if tracking:=(TRACK_MATCH_STATS >= 2):
|
|
||||||
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
|
loc = ((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno)
|
||||||
depth = len(active_rewrites)
|
depth = len(active_rewrites)
|
||||||
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
|
if not tracked_ctxs: add_trace_group(TracingKey(f"default {fxn.__name__}"))
|
||||||
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
|
tracked_ctxs[-1].append(ctx:=TrackedGraphRewrite(loc, args[0].trace_num, [], name, depth, kwargs.get("bottom_up", False)))
|
||||||
active_rewrites.append(ctx)
|
active_rewrites.append(ctx)
|
||||||
with cpu_profile(name, "TINY", display=tracking):
|
with cpu_profile(name, "TINY"):
|
||||||
ret = fxn(*args, **kwargs)
|
ret = fxn(*args, **kwargs)
|
||||||
if tracking: active_rewrites.pop()
|
active_rewrites.pop()
|
||||||
return ret
|
return ret
|
||||||
return wrap
|
# without tracking, we just call the function
|
||||||
|
return fxn(*args, **kwargs)
|
||||||
|
return wrap_profile_matches
|
||||||
|
|
||||||
class TrackedPatternMatcher(PatternMatcher):
|
class TrackedPatternMatcher(PatternMatcher):
|
||||||
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue