don't strip sink in to_uops_list [pr] (#14111)

This commit is contained in:
chenyu 2026-01-12 11:19:03 -05:00 committed by GitHub
commit 6b0a9f5ee6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 28 additions and 28 deletions

View file

@ -49,7 +49,7 @@ xfail_broken_const_wraparound = pytest.mark.xfail(reason="const folding does not
class TestModularWraparound(unittest.TestCase):
def _test(self, uop:UOp, expected:int):
results = to_uops_list([uop])
self.assertEqual(len(results), 1)
self.assertEqual(len(results), 2) # +1 for SINK
self.assertEqual(results[0].op, Ops.CONST)
self.assertEqual(results[0].dtype, uop.dtype)
self.assertEqual(results[0].arg, expected)
@ -198,8 +198,8 @@ class TestUOpGraph(unittest.TestCase):
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
out = UOp(Ops.ADD, dtypes.float, (c1, c2))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(len(uops), 2) # +1 for SINK
out = uops[-2]
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 3.0)
@ -210,8 +210,8 @@ class TestUOpGraph(unittest.TestCase):
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
out = UOp(Ops.WHERE, dtypes.float, (vc, c1, c1))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(len(uops), 2) # +1 for SINK
out = uops[-2]
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 1.0)
@ -221,8 +221,8 @@ class TestUOpGraph(unittest.TestCase):
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
out = UOp(Ops.WHERE, dtypes.float, (bf, c1, c2))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(len(uops), 2) # +1 for SINK
out = uops[-2]
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 2.0)
@ -230,8 +230,8 @@ class TestUOpGraph(unittest.TestCase):
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
out = UOp(Ops.CAST, dtypes.int, (bf,))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(len(uops), 2) # +1 for SINK
out = uops[-2]
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 0)
@ -239,8 +239,8 @@ class TestUOpGraph(unittest.TestCase):
bf = UOp(Ops.CONST, dtypes.float, arg=1.0)
out = UOp(Ops.BITCAST, dtypes.uint32, (bf,))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(len(uops), 2) # +1 for SINK
out = uops[-2]
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 0x3F800000)
@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
bf = UOp(Ops.CONST, dtypes.uint8, arg=0x3F)
out = UOp(Ops.BITCAST, dtypes.half, (bf,))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
self.assertEqual(len(uops), 2) # +1 for SINK
@unittest.skip("this test isn't valid uops")
def test_noop_vectorize_fold(self):
@ -276,7 +276,7 @@ class TestUOpGraph(unittest.TestCase):
if DEBUG >= 4:
from tinygrad import Device
print(Device[Device.DEFAULT].renderer.render(uops))
return uops[-1].src[-1]
return uops[-2].src[-1] # -2 to skip SINK
# possible
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1.index(idx),))
@ -321,7 +321,7 @@ class TestUOpGraph(unittest.TestCase):
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
self.assertEqual(len(uops), 2) # +1 for SINK
for i in [2, 4, 8]:
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
@ -330,7 +330,7 @@ class TestUOpGraph(unittest.TestCase):
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
self.assertEqual(len(uops), 2) # +1 for SINK
@unittest.skip("wmma is wrong here, it needs an arg")
def test_wmma_vectorize_no_fold(self):
@ -342,7 +342,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
for i in [4, 8]:
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
@ -352,7 +352,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
for i in [2, 4, 8]:
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
@ -361,7 +361,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
for i in [2, 4, 8]:
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
@ -370,7 +370,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
self.assertEqual(uops[-2], wmma) # -2 to skip SINK
def test_cast_alu_fold(self):
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
@ -399,8 +399,8 @@ class TestUOpGraph(unittest.TestCase):
vc = UOp(Ops.ADD, dtypes.int, (v, c2))
out = UOp(Ops.ADD, dtypes.int, (vc, c4))
uops = to_uops_list([out])
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(len(uops), 4) # +1 for SINK
out = uops[-2] # -2 to skip SINK
self.assertEqual(out.op, Ops.ADD)
self.assertEqual(out.src[1].op, Ops.CONST)
self.assertEqual(out.src[1].arg, 6)
@ -415,7 +415,8 @@ class TestUOpGraph(unittest.TestCase):
def test_sub_with_cast_folds(self):
a = Variable("a", 0, 5)
uops = to_uops_list([a.cast(dtypes.int)+(-a).cast(dtypes.int)])
assert uops == [UOp.const(dtypes.int, 0)]
assert uops[0] == UOp.const(dtypes.int, 0)
assert uops[-1].op == Ops.SINK
def test_where_on_gated_load_fold(self):
ridx0 = UOp.range(100, 0)
@ -486,7 +487,7 @@ class TestUOpGraph(unittest.TestCase):
ld0 = glbl1.index(UOp.invalid())
ld1 = glbl2.index(idx.valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(idx), ld1+ld0))])
ld0 = uops[-1].src[-1]
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx, ptr=True), dtype=dtypes.int))
@ -500,7 +501,7 @@ class TestUOpGraph(unittest.TestCase):
ld1 = smem.after(barrier).index((lidx+2).valid(UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0.index(lidx), ld1+ld0))])
ld0 = uops[-1].src[-1]
ld0 = uops[-2].src[-1] # -2 to skip SINK
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.after(barrier).index(lidx+2, ptr=True))
@ -513,8 +514,8 @@ class TestUOpGraph(unittest.TestCase):
st1 = glbl.index(idx0.valid(UOp.const(dtypes.bool, True)), ptr=True).store(val)
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 5)
self.assertEqual(uops[-1], glbl.index(idx1, ptr=True).store(val))
self.assertEqual(len(uops), 6) # +1 for SINK
self.assertEqual(uops[-2], glbl.index(idx1, ptr=True).store(val)) # -2 to skip SINK
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):

View file

@ -20,10 +20,9 @@ from dataclasses import replace
def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
sink = UOp.group(*u)
for r in sink.ranges: sink = sink.end(r)
# we strip the SINK here for legacy reasons
ret = get_uops(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
assert ret[-1].op is Ops.SINK
return ret[:-1]
return ret
def _uops_to_prg(uops_list):
prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)