mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix unit test dtypes
This commit is contained in:
parent
65dcd6dd45
commit
260da2017c
2 changed files with 7 additions and 7 deletions
|
|
@ -303,8 +303,8 @@ class TestRecurse(unittest.TestCase):
|
|||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm)
|
||||
|
|
@ -312,8 +312,8 @@ class TestRecurse(unittest.TestCase):
|
|||
def test_inf_loop_bottom_up(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm, bottom_up=True)
|
||||
|
|
|
|||
|
|
@ -124,10 +124,10 @@ class TestViz(BaseTestViz):
|
|||
|
||||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
b = a.replace(op=Ops.DEFINE_REG)
|
||||
b = a.replace(op=Ops.CONST)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError): exec_rewrite(a, [pm])
|
||||
graphs = flatten(x["graph"].values() for x in get_details(tracked_ctxs[0][0]))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue