fix unit test dtypes

This commit is contained in:
George Hotz 2025-08-13 12:15:12 -07:00
commit 260da2017c
2 changed files with 7 additions and 7 deletions

View file

@ -303,8 +303,8 @@ class TestRecurse(unittest.TestCase):
def test_inf_loop(self):
a = UOp.variable('a', 0, 10)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm)
@ -312,8 +312,8 @@ class TestRecurse(unittest.TestCase):
def test_inf_loop_bottom_up(self):
a = UOp.variable('a', 0, 10)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm, bottom_up=True)

View file

@ -124,10 +124,10 @@ class TestViz(BaseTestViz):
def test_inf_loop(self):
a = UOp.variable('a', 0, 10)
b = a.replace(op=Ops.DEFINE_REG)
b = a.replace(op=Ops.CONST)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
])
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))