mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Fix z3 rendering of floats in indexing (#11740)
* Fix floating point comparison in indexing
* wrap in noop
* update tests
* improve rules for loading and comparing floats
* add test cast to bool
This commit is contained in:
parent
4267c45db3
commit
5a6817d5f8
3 changed files with 49 additions and 9 deletions
|
|
@@ -441,18 +441,16 @@ class TestUOpGraph(unittest.TestCase):
|
|||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20)),))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([ld0])
|
||||
|
||||
@unittest.skip("outdated")
|
||||
def test_in_out_of_bounds_access_gated_store(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
|
||||
v = Variable("v", 0, 20)
|
||||
st0 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), UOp.const(dtypes.int, 0), v<16))
|
||||
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v), UOp.const(dtypes.int, 0), UOp(Ops.IF, src=(v<16,))))
|
||||
to_uops_list([st0])
|
||||
|
||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v), v, v<20))
|
||||
with self.assertRaises(RuntimeError): to_uops_list([st1])
|
||||
|
||||
@unittest.skip("outdated")
|
||||
def test_in_bounds_access_gated_local(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
# Define buffers
|
||||
|
|
@@ -465,7 +463,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
|
||||
gate = (gidx<400) & (lidx<8)
|
||||
|
||||
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.uint, 1), lidx<8))
|
||||
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx), UOp.const(dtypes.uint, 1), UOp(Ops.IF, src=(lidx<8,))))
|
||||
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
|
||||
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
|
||||
|
|
@@ -477,6 +475,34 @@ class TestUOpGraph(unittest.TestCase):
|
|||
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
|
||||
to_uops_list([global_store])
|
||||
|
||||
def test_load_with_float_in_index(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
ridx = UOp.range(dtypes.int, 20, 0)
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
||||
to_uops_list([ld0])
|
||||
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
|
||||
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
|
||||
i = (ldfloat+3.14).cast(dtypes.int)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16))),))
|
||||
|
||||
def test_load_cast_to_bool(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
|
||||
ridx = UOp.range(dtypes.int, 20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx, ridx.cast(dtypes.bool).logical_not()),))
|
||||
to_uops_list([ld0])
|
||||
|
||||
@unittest.skip("Bool load is not supported yet")
|
||||
def test_load_mask(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
|
||||
ridx = UOp.range(dtypes.int, 20, 0)
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask),)))
|
||||
to_uops_list([ld0])
|
||||
|
||||
def test_out_of_bounds_off_by_one_access(self):
|
||||
with Context(IGNORE_OOB=0):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
|
||||
|
|
|
|||
|
|
@@ -109,6 +109,9 @@ class GroupOp:
|
|||
# BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
|
||||
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
|
||||
|
||||
# These can change the dtype to bool
|
||||
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
||||
|
||||
|
|
|
|||
|
|
@@ -23,14 +23,25 @@ try:
|
|||
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))),
|
||||
# float loads only become a variable when they get cast to int/bool
|
||||
(UPat(Ops.LOAD, dtypes.ints, name="x"),
|
||||
lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
|
||||
(UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.ints+(dtypes.bool,), src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.CAST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
|
||||
lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))),
|
||||
(UPat(Ops.CONST, dtype=dtypes.ints+(dtypes.bool,), name="x"),
|
||||
lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
|
||||
# z3 can cast from bool to int automatically
|
||||
(UPat(Ops.CAST, dtype=dtypes.ints, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=(x.src[0].arg!=0))),
|
||||
# if the source of the cast is not a noop it means that it is a float and so we create a new variable
|
||||
(UPat(Ops.CAST, dtype=dtypes.ints, name="x"), lambda x,ctx:
|
||||
UOp(Ops.NOOP, arg=create_bounded(f"cast{ctx[1].setdefault(x, len(ctx[1]))}", x.dtype.min, x.dtype.max, ctx[0]))),
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x,ctx:
|
||||
UOp(Ops.NOOP, arg=z3.Bool(f"cast{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx))),
|
||||
(UPat(Ops.XOR, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=z3.BV2Int(z3_alu[x.op](*(z3.Int2BV(s.arg, x.dtype.itemsize*8) for s in x.src))))),
|
||||
(UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=z3_alu[x.op](*(s.arg for s in x.src)))),
|
||||
# A comparison between floats introduces a new bool variable
|
||||
(UPat(GroupOp.Comparison, src=UPat(dtype=dtypes.floats), name="x"), lambda x,ctx:
|
||||
UOp(Ops.NOOP, arg=z3.Bool(f"float_cmp{ctx[1].setdefault(x, len(ctx[1]))}",ctx=ctx[0].ctx))),
|
||||
])
|
||||
|
||||
z3_imported = True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue