import unittest

from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
from tinygrad.renderer.isa import IselContext


class TestIselX86(unittest.TestCase):
  """Checks which X86Ops the X86Renderer instruction-selection matcher picks for small UOp graphs."""

  def isel_rewrite(self, x:UOp):
    # run the x86 isel matcher bottom-up over the graph rooted at x and return the rewritten root
    return graph_rewrite(x, X86Renderer().isel_matcher, IselContext(x), bottom_up=True)

  def test_cmove(self):
    a = UOp.variable("a", 0, 0, dtypes.int32)
    b = UOp.variable("b", 0, 0, dtypes.int32)
    c = (a < b).where(a, b)
    d = (a != b).where(a, b)
    f = c + d
    n = self.isel_rewrite(f)
    self.assertIs(n.src[0].arg, X86Ops.CMOVL)
    self.assertIs(n.src[1].arg, X86Ops.CMOVNE)
    # both comparisons become the same instruction
    self.assertEqual(n.src[0].src[2], n.src[1].src[2])
    self.assertIs(n.src[0].src[2].arg, X86Ops.CMP)

  def test_vmax(self):
    a = UOp.variable("a", 0, 0, dtypes.float32)
    b = UOp.variable("b", 0, 0, dtypes.float32)
    # (a < b) ? b : a
    n = self.isel_rewrite((a < b).where(b, a))
    self.assertIs(n.arg, X86Ops.VMAXSS)

  def test_vmin(self):
    a = UOp.variable("a", 0, 0, dtypes.float32)
    b = UOp.variable("b", 0, 0, dtypes.float32)
    # (a < b) ? a : b
    n = self.isel_rewrite((a < b).where(a, b))
    self.assertIs(n.arg, X86Ops.VMINSS)

  def test_vpbroadcast(self):
    a = UOp.variable("a", 0, 0, dtypes.int32)
    n = self.isel_rewrite(a.broadcast(4))
    # need to move src from gpr to xmm before broadcasting
    self.assertIs(n.arg, X86Ops.VPBROADCASTD)
    self.assertIs(n.src[0].arg, X86Ops.VMOVD)
    # if we can fuse a load we can skip the move and access memory directly
    load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
    n = self.isel_rewrite(load.broadcast(4))
    self.assertIs(n.arg, X86Ops.VPBROADCASTD)
    self.assertEqual(len(n.src), 3)

  # the lower two 32-bit elements must come from the same register and the upper two 32-bit elements must come from the same register
  def test_vshufps(self):
    a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
    b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
    c = UOp.variable("c", 0, 0, dtypes.float32)
    d = UOp.variable("d", 0, 0, dtypes.float32)
    # shuffle between 2 vectors
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), a.gep(1), b.gep(2), b.gep(3))))
    self.assertIs(n.arg, X86Ops.VSHUFPS)
    # shuffle between 2 scalars
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (c, c, d, d)))
    self.assertIs(n.arg, X86Ops.VSHUFPS)
    # shuffle between vector and scalar
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), a.gep(1), c, c)))
    self.assertIs(n.arg, X86Ops.VSHUFPS)
    # shuffle between 1 vector
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(1), a.gep(2), a.gep(3), a.gep(0))))
    self.assertIs(n.arg, X86Ops.VSHUFPS)
    self.assertIs(n.src[0], n.src[1])
    # a shuffle between 1 scalar is just a broadcast and matches X86Ops.VBROADCASTSS to allow for load fusion

  # this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE
  def test_vinsertps(self):
    a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
    b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
    c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
    # NOTE(review): the variable is named "e" but bound to d -- confirm this is intentional
    d = UOp.variable("e", 0, 0, dtypes.float32)
    # pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, dtypes.float32.vec(2), (a.gep(0), d)))
    self.assertIs(n.arg, X86Ops.VINSERTPS)
    self.assertIs(n.src[0].arg, X86Ops.DEFINE_REG)
    # interleaved shuffle between 2 vectors
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), b.gep(1), a.gep(2), b.gep(3))))
    self.assertIs(n.arg, X86Ops.VINSERTPS)
    # shuffle between 4 sources
    n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(3), b.gep(2), c.gep(1), d)))
    self.assertIs(n.arg, X86Ops.VINSERTPS)

  # complex address is [base + index*scale + displacement]
  def test_complex_address(self):
    a = UOp.variable("a", 0, 0, dtypes.int32)
    load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load()
    n = self.isel_rewrite(load)
    # base is PARAM, index is "a"
    self.assertIs(n.src[0].arg, X86Ops.DEFINE_REG)
    self.assertIs(n.src[1].arg, X86Ops.DEFINE_REG)
    # displacement is the constant in "a + 1" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
    self.assertIs(n.src[2].arg, X86Ops.IMM)
    self.assertIs(n.src[2].dtype, dtypes.int8)
    self.assertEqual(n.src[2].tag, 4)

  def test_fuse_load(self):
    load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
    load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
    n = self.isel_rewrite(load1 + load2)
    # presumably one load is folded into the add as a memory operand, giving 4 srcs instead of 2 -- TODO confirm
    self.assertEqual(len(n.src), 4)

  # don't fuse when used multiple times
  def test_dont_fuse_load_diff_users(self):
    load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
    add = load + 1
    # the same load feeds both `add` and the outer add
    n = self.isel_rewrite(add + load)
    self.assertEqual(len(n.src), 2)

  def test_dont_fuse_load_same_user(self):
    load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
    # the same load is both operands of a single multiply
    n = self.isel_rewrite(load * load)
    self.assertEqual(len(n.src), 2)

  # TODO: might want to check that load isn't part of another range when fusing


if __name__ == "__main__":
  unittest.main()