mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
use assertEqual with new style uops [pr] (#7360)
This commit is contained in:
parent
0beb2d8f84
commit
0af1212164
5 changed files with 21 additions and 29 deletions
|
|
@ -63,11 +63,6 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
|
|||
diff = ocdiff.console_diff(str(s0), str(s1))
|
||||
logging.info(diff)
|
||||
|
||||
def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
|
||||
if u1 is not u2:
|
||||
print_diff(u1, u2)
|
||||
raise AssertionError("uops aren't equal.")
|
||||
|
||||
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
|
||||
if st_src is None:
|
||||
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item, run_schedule
|
||||
|
|
@ -43,8 +42,8 @@ class TestFusionOp(unittest.TestCase):
|
|||
c = Tensor([1,2,3,4])
|
||||
for _ in range(23): c = c + c
|
||||
sched3 = create_schedule([c.lazydata])
|
||||
assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
|
||||
with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
|
||||
self.assertEqual(sched1[-1].ast, sched2[-1].ast)
|
||||
with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast)
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import unittest, pickle, types
|
||||
import numpy as np
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import PatternMatcher, UPat, UOp
|
||||
|
|
@ -83,7 +82,7 @@ class TestPickle(unittest.TestCase):
|
|||
sched = create_schedule([out.lazydata])
|
||||
pk = pickle.dumps(sched)
|
||||
sched_pk = pickle.loads(pk)
|
||||
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
|
||||
self.assertEqual(sched_pk[-1].ast, sched[-1].ast)
|
||||
|
||||
def test_pickle_renderer(self):
|
||||
from tinygrad.device import Device
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from typing import List
|
||||
import unittest, time
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
|
|
@ -262,7 +261,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
# possible
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
|
||||
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in range(4))
|
||||
assert_equiv_uops(_test_vec(xyzw), val)
|
||||
self.assertEqual(_test_vec(xyzw), val)
|
||||
|
||||
# unaligned
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
|
||||
|
|
@ -290,7 +289,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
||||
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
|
||||
for uop, const in zip(uops, consts):
|
||||
assert_equiv_uops(uop, const)
|
||||
self.assertEqual(uop, const)
|
||||
|
||||
def test_wmma_vectorize_fold(self):
|
||||
for i in [2, 4, 8]:
|
||||
|
|
@ -299,7 +298,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
self.assertEqual(uops[0], acc)
|
||||
self.assertEqual(len(uops), 1)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
|
|
@ -308,7 +307,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
self.assertEqual(uops[0], acc)
|
||||
self.assertEqual(len(uops), 1)
|
||||
|
||||
@unittest.skip("wmma is wrong here, it needs an arg")
|
||||
|
|
@ -321,7 +320,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
|
|
@ -331,7 +330,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
|
|
@ -340,7 +339,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
|
|
@ -349,7 +348,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
|
||||
|
|
@ -395,9 +394,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
# the gate and invalid value are deleted from ld1
|
||||
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||
self.assertEqual(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
|
|
@ -410,9 +409,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
# the gate and invalid value are deleted from ld1
|
||||
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||
self.assertEqual(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
|
|
@ -424,7 +423,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 4)
|
||||
assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
|
||||
self.assertEqual(uops[-1], UOp.store(glbl, idx1, val))
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
|
|
@ -679,7 +678,7 @@ class TestIFUOps(unittest.TestCase):
|
|||
sink = gate_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
assert_equiv_uops(if_uops[0].src[0], gate)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
|
|
@ -697,7 +696,7 @@ class TestIFUOps(unittest.TestCase):
|
|||
sink = gate_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
assert_equiv_uops(if_uops[0].src[0], gate)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
|
|
@ -713,7 +712,7 @@ class TestIFUOps(unittest.TestCase):
|
|||
sink = gate_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
assert_equiv_uops(if_uops[0].src[0], gate)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from tinygrad.engine.schedule import create_schedule, to_si
|
|||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
from test.helpers import is_dtype_supported, assert_equiv_uops
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
|
||||
|
||||
|
|
@ -395,7 +395,7 @@ class TestUOpStr(unittest.TestCase):
|
|||
# nice big complicated uop
|
||||
with Context(NOOPT=1):
|
||||
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
|
||||
assert_equiv_uops(sink, eval(str(sink)))
|
||||
self.assertEqual(sink, eval(str(sink)))
|
||||
|
||||
def test_vectorized_str(self):
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue