Fix z3 rendering of floats in indexing (#11740)

* Fix floating point comparison in indexing

* wrap in noop

* update tests

* improve rules for loading and comparing floats

* add test cast to bool
This commit is contained in:
Sieds Lykles 2025-08-23 05:56:19 +02:00 committed by GitHub
commit 5a6817d5f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 49 additions and 9 deletions

View file

@ -441,18 +441,16 @@ class TestUOpGraph(unittest.TestCase):
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
@unittest.skip("outdated")
def test_in_out_of_bounds_access_gated_store(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
v = Variable("v", 0, 20)
st0 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), UOp.const(dtypes.int, 0), v<16))
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v), UOp.const(dtypes.int, 0), UOp(Ops.IF, src=(v<16,))))
to_uops_list([st0])
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20))
with self.assertRaises(RuntimeError): to_uops_list([st1])
@unittest.skip("outdated")
def test_in_bounds_access_gated_local(self):
with Context(IGNORE_OOB=0):
# Define buffers
@ -465,7 +463,7 @@ class TestUOpGraph(unittest.TestCase):
gate = (gidx<400) & (lidx<8)
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.uint, 1), lidx<8))
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.uint, 1), UOp(Ops.IF, src=(lidx<8,))))
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
@ -477,6 +475,34 @@ class TestUOpGraph(unittest.TestCase):
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
to_uops_list([global_store])
def test_load_with_float_in_index(self):
  """Gated loads whose index is computed through floats must pass validation.

  Two cases are built against the same ``dtypes.int.ptr(16)`` buffer, each gated
  by ``(0 <= i) & (i < 16)``:

  1. the index is a range variable cast to float, scaled, truncated and cast
     back to int;
  2. the index is a float value loaded from a second buffer, offset and cast
     to int.

  Both loads are passed to ``to_uops_list``; the test passes when neither call
  raises.
  """
  with Context(IGNORE_OOB=0):
    ridx = UOp.range(dtypes.int, 20, 0)
    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

    # case 1: index derived from float arithmetic on the range variable
    i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
    to_uops_list([ld0])

    # case 2: index derived from a float that is itself loaded from memory
    # NOTE(review): glblfloat reuses DEFINE_GLOBAL arg 0, same as glbl0 -- confirm this is intended
    glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
    ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
    i = (ldfloat+3.14).cast(dtypes.int)
    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
    # previously this second load was constructed but never run through
    # to_uops_list, so the loaded-float case was not actually exercised
    to_uops_list([ld0])
def test_load_cast_to_bool(self):
  """A load gated on the negated int-to-bool cast of its index must pass validation.

  The index is a range variable built with ``UOp.range(dtypes.int, 20, 0)`` and
  the buffer is ``dtypes.int.ptr(1)``; the gate is the logical negation of the
  index cast to bool.
  """
  with Context(IGNORE_OOB=0):
    single_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
    loop_idx = UOp.range(dtypes.int, 20, 0)
    # gate on "index cast to bool is False"
    gate = loop_idx.cast(dtypes.bool).logical_not()
    gated_index = single_buf.index(loop_idx, gate)
    gated_load = UOp(Ops.LOAD, dtypes.int, (gated_index,))
    to_uops_list([gated_load])
# Skipped test: builds an int load from glbl0 whose index expression involves a bool buffer `mask`,
# then passes it to to_uops_list. The skip reason says bool loads are not supported yet.
@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
# NOTE(review): mask reuses DEFINE_GLOBAL arg 0, same as glbl0 -- confirm whether a distinct arg is intended
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(dtypes.int, 20, 0)
# NOTE(review): `&` binds tighter than `<`, so `ridx<16&mask` parses as `ridx < (16 & mask)`;
# `(ridx<16) & mask` was presumably intended -- TODO confirm before un-skipping.
# NOTE(review): the trailing `,)` closes the index(...) call, so the outer parentheses are plain
# grouping and the third UOp argument is not a tuple here -- TODO confirm the intended expression.
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
to_uops_list([ld0])
def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)

View file

@ -109,6 +109,9 @@ class GroupOp:
# BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
# These can change the dtype to bool
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

View file

@ -23,14 +23,25 @@ try:
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))),
# float loads only become a variable when they get cast to int/bool
(UPat(Ops.LOAD, dtypes.ints, name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
(UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.bool,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))),
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,), name="x"),
lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
# z3 can cast from bool to int automatically
(UPat(Ops.CAST, dtype=dtypes.ints, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=(x.src[0].arg!=0))),
# if the source of the cast is not a noop it means that it is a float and so we create a new variable
(UPat(Ops.CAST, dtype=dtypes.ints, name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))),
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx))),
(UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg, x.dtype.itemsize*8) for s in x.src))))),
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(s.arg for s in x.src)))),
# A comparison between floats introduces a new bool variable
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
UOp(Ops.NOOP, arg=z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx))),
])
z3_imported = True