mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
don't strip sink in to_uops_list [pr] (#14111)
This commit is contained in:
parent
cad7feec02
commit
6b0a9f5ee6
2 changed files with 28 additions and 28 deletions
|
|
@ -49,7 +49,7 @@ xfail_broken_const_wraparound = pytest.mark.xfail(reason="const folding does not
|
|||
class TestModularWraparound(unittest.TestCase):
|
||||
def _test(self, uop:UOp, expected:int):
|
||||
results = to_uops_list([uop])
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(len(results), 2) # +1 for SINK
|
||||
self.assertEqual(results[0].op, Ops.CONST)
|
||||
self.assertEqual(results[0].dtype, uop.dtype)
|
||||
self.assertEqual(results[0].arg, expected)
|
||||
|
|
@ -198,8 +198,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(Ops.ADD, dtypes.float, (c1, c2))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 3.0)
|
||||
|
||||
|
|
@ -210,8 +210,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 1.0)
|
||||
|
||||
|
|
@ -221,8 +221,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(Ops.WHERE, dtypes.float, (bf, c1, c2))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 2.0)
|
||||
|
||||
|
|
@ -230,8 +230,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
out = UOp(Ops.CAST, dtypes.int, (bf,))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 0)
|
||||
|
||||
|
|
@ -239,8 +239,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
bf = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
out = UOp(Ops.BITCAST, dtypes.uint32, (bf,))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
out = uops[-2]
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 0x3F800000)
|
||||
|
||||
|
|
@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
bf = UOp(Ops.CONST, dtypes.uint8, arg=0x3F)
|
||||
out = UOp(Ops.BITCAST, dtypes.half, (bf,))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
|
||||
@unittest.skip("this test isn't valid uops")
|
||||
def test_noop_vectorize_fold(self):
|
||||
|
|
@ -276,7 +276,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
if DEBUG >= 4:
|
||||
from tinygrad import Device
|
||||
print(Device[Device.DEFAULT].renderer.render(uops))
|
||||
return uops[-1].src[-1]
|
||||
return uops[-2].src[-1] # -2 to skip SINK
|
||||
|
||||
# possible
|
||||
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
|
||||
|
|
@ -321,7 +321,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[0], acc)
|
||||
self.assertEqual(len(uops), 1)
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
|
||||
|
|
@ -330,7 +330,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[0], acc)
|
||||
self.assertEqual(len(uops), 1)
|
||||
self.assertEqual(len(uops), 2) # +1 for SINK
|
||||
|
||||
@unittest.skip("wmma is wrong here, it needs an arg")
|
||||
def test_wmma_vectorize_no_fold(self):
|
||||
|
|
@ -342,7 +342,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
|
|
@ -352,7 +352,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
|
||||
|
|
@ -361,7 +361,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
|
|
@ -370,7 +370,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
self.assertEqual(uops[-1], wmma)
|
||||
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
|
||||
|
|
@ -399,8 +399,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
vc = UOp(Ops.ADD, dtypes.int, (v, c2))
|
||||
out = UOp(Ops.ADD, dtypes.int, (vc, c4))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 3)
|
||||
out = uops[-1]
|
||||
self.assertEqual(len(uops), 4) # +1 for SINK
|
||||
out = uops[-2] # -2 to skip SINK
|
||||
self.assertEqual(out.op, Ops.ADD)
|
||||
self.assertEqual(out.src[1].op, Ops.CONST)
|
||||
self.assertEqual(out.src[1].arg, 6)
|
||||
|
|
@ -415,7 +415,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
def test_sub_with_cast_folds(self):
|
||||
a = Variable("a", 0, 5)
|
||||
uops = to_uops_list([a.cast(dtypes.int)+(-a).cast(dtypes.int)])
|
||||
assert uops == [UOp.const(dtypes.int, 0)]
|
||||
assert uops[0] == UOp.const(dtypes.int, 0)
|
||||
assert uops[-1].op == Ops.SINK
|
||||
|
||||
def test_where_on_gated_load_fold(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
|
|
@ -486,7 +487,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
ld0 = glbl1.index(UOp.invalid())
|
||||
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
|
||||
ld0 = uops[-1].src[-1]
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
|
||||
|
||||
|
|
@ -500,7 +501,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
|
||||
|
||||
ld0 = uops[-1].src[-1]
|
||||
ld0 = uops[-2].src[-1] # -2 to skip SINK
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
|
||||
|
||||
|
|
@ -513,8 +514,8 @@ class TestUOpGraph(unittest.TestCase):
|
|||
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
|
||||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 5)
|
||||
self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val))
|
||||
self.assertEqual(len(uops), 6) # +1 for SINK
|
||||
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
|
|
|
|||
|
|
@ -20,10 +20,9 @@ from dataclasses import replace
|
|||
def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
|
||||
sink = UOp.group(*u)
|
||||
for r in sink.ranges: sink = sink.end(r)
|
||||
# we strip the SINK here for legacy reasons
|
||||
ret = get_uops(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
|
||||
assert ret[-1].op is Ops.SINK
|
||||
return ret[:-1]
|
||||
return ret
|
||||
|
||||
def _uops_to_prg(uops_list):
|
||||
prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue