use assertEqual with new style uops [pr] (#7360)

This commit is contained in:
George Hotz 2024-10-29 17:43:21 +07:00 committed by GitHub
commit 0af1212164
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 21 additions and 29 deletions

View file

@ -63,11 +63,6 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
diff = ocdiff.console_diff(str(s0), str(s1))
logging.info(diff)
def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
if u1 is not u2:
print_diff(u1, u2)
raise AssertionError("uops aren't equal.")
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)

View file

@ -1,7 +1,6 @@
import unittest
import time
import numpy as np
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item, run_schedule
@ -43,8 +42,8 @@ class TestFusionOp(unittest.TestCase):
c = Tensor([1,2,3,4])
for _ in range(23): c = c + c
sched3 = create_schedule([c.lazydata])
assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
self.assertEqual(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast)
self.assertLess(time.perf_counter()-st, 2.0)
if __name__ == '__main__':

View file

@ -1,6 +1,5 @@
import unittest, pickle, types
import numpy as np
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import PatternMatcher, UPat, UOp
@ -83,7 +82,7 @@ class TestPickle(unittest.TestCase):
sched = create_schedule([out.lazydata])
pk = pickle.dumps(sched)
sched_pk = pickle.loads(pk)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
self.assertEqual(sched_pk[-1].ast, sched[-1].ast)
def test_pickle_renderer(self):
from tinygrad.device import Device

View file

@ -1,6 +1,5 @@
from typing import List
import unittest, time
from test.helpers import assert_equiv_uops
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
@ -262,7 +261,7 @@ class TestUOpGraph(unittest.TestCase):
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in range(4))
assert_equiv_uops(_test_vec(xyzw), val)
self.assertEqual(_test_vec(xyzw), val)
# unaligned
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
@ -290,7 +289,7 @@ class TestUOpGraph(unittest.TestCase):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
for uop, const in zip(uops, consts):
assert_equiv_uops(uop, const)
self.assertEqual(uop, const)
def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
@ -299,7 +298,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
for i in [2, 4, 8]:
@ -308,7 +307,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
@unittest.skip("wmma is wrong here, it needs an arg")
@ -321,7 +320,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
self.assertEqual(uops[-1], wmma)
for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
@ -331,7 +330,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
self.assertEqual(uops[-1], wmma)
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
@ -340,7 +339,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
self.assertEqual(uops[-1], wmma)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
@ -349,7 +348,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
self.assertEqual(uops[-1], wmma)
def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
@ -395,9 +394,9 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
self.assertEqual(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
@ -410,9 +409,9 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
self.assertEqual(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
@ -424,7 +423,7 @@ class TestUOpGraph(unittest.TestCase):
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 4)
assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
self.assertEqual(uops[-1], UOp.store(glbl, idx1, val))
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):
@ -679,7 +678,7 @@ class TestIFUOps(unittest.TestCase):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
assert_equiv_uops(if_uops[0].src[0], gate)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
@ -697,7 +696,7 @@ class TestIFUOps(unittest.TestCase):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
assert_equiv_uops(if_uops[0].src[0], gate)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
@ -713,7 +712,7 @@ class TestIFUOps(unittest.TestCase):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
assert_equiv_uops(if_uops[0].src[0], gate)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)

View file

@ -12,7 +12,7 @@ from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from test.helpers import is_dtype_supported, assert_equiv_uops
from test.helpers import is_dtype_supported
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@ -395,7 +395,7 @@ class TestUOpStr(unittest.TestCase):
# nice big complicated uop
with Context(NOOPT=1):
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
assert_equiv_uops(sink, eval(str(sink)))
self.assertEqual(sink, eval(str(sink)))
def test_vectorized_str(self):
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))