more isel tests

This commit is contained in:
ttomsa 2026-03-06 00:28:36 +00:00
commit 41f2bd8a05
2 changed files with 93 additions and 56 deletions

View file

@ -4,9 +4,18 @@ from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
from tinygrad.renderer.isa import IselContext
# these tests are to catch changes that don't cause incorrect codegen but cause worse codegen
class TestIselX86(unittest.TestCase):
def isel_rewrite(self, x:UOp):
  # run the x86 instruction-selection matcher bottom-up over x and return the rewritten UOp
  renderer = X86Renderer()
  return graph_rewrite(x, renderer.isel_matcher, IselContext(x), bottom_up=True)
def _check_op(self, dt_op, expr):
nargs = expr.__code__.co_argcount
for dt,op in dt_op:
with self.subTest(dtype=dt):
vars = [UOp.variable(str(i), 0, 0, dt) for i in range(nargs)]
n = self.isel_rewrite(expr(*vars))
self.assertIs(n.arg, op)
def test_cmove(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
b = UOp.variable("b", 0, 0, dtypes.int32)
@ -19,16 +28,26 @@ class TestIselX86(unittest.TestCase):
self.assertTrue(n.src[0].src[2] == n.src[1].src[2] and n.src[0].src[2].arg is X86Ops.CMP)
def test_vmax(self):
  # (a < b).where(b, a) is a max and must select the VMAX variant matching the dtype;
  # the float32 scalar case is the first row, so no separate scalar-only check is needed
  dt_op = [(dtypes.float32, X86Ops.VMAXSS), (dtypes.float64, X86Ops.VMAXSD),
           (dtypes.float32.vec(4), X86Ops.VMAXPS), (dtypes.float64.vec(4), X86Ops.VMAXPD)]
  self._check_op(dt_op, lambda a,b: (a < b).where(b, a))
def test_vmin(self):
  # (a < b).where(a, b) is a min and must select the VMIN variant matching the dtype;
  # the float32 scalar case is the first row, so no separate scalar-only check is needed
  dt_op = [(dtypes.float32, X86Ops.VMINSS), (dtypes.float64, X86Ops.VMINSD),
           (dtypes.float32.vec(4), X86Ops.VMINPS), (dtypes.float64.vec(4), X86Ops.VMINPD)]
  self._check_op(dt_op, lambda a,b: (a < b).where(a, b))
def test_vfmadd(self):
  # a * b + c with three distinct operands must select the VFMADD213 variant matching the dtype
  expected = [
    (dtypes.float32, X86Ops.VFMADD213SS),
    (dtypes.float64, X86Ops.VFMADD213SD),
    (dtypes.float32.vec(4), X86Ops.VFMADD213PS),
    (dtypes.float64.vec(4), X86Ops.VFMADD213PD),
  ]
  self._check_op(expected, lambda a,b,c: a * b + c)
# TODO: shouldn't match fmadd if var is used multiple times
@unittest.expectedFailure
def test_vfmadd_fail(self):
  # when an operand of the multiply is reused by the add, the root is expected to be a plain VADD;
  # marked expectedFailure until that is the case
  expected = [
    (dtypes.float32, X86Ops.VADDSS),
    (dtypes.float64, X86Ops.VADDSD),
    (dtypes.float32.vec(4), X86Ops.VADDPS),
    (dtypes.float64.vec(4), X86Ops.VADDPD),
  ]
  self._check_op(expected, lambda a,b: a * b + b)
def test_vpbroadcast(self):
a = UOp.variable("a", 0, 0, dtypes.int32)
@ -40,41 +59,63 @@ class TestIselX86(unittest.TestCase):
n = self.isel_rewrite(load.broadcast(4))
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
# lower 2 32 bits must come from the same register and upper 2 32 bits must come from the same register
def test_vbroadcastss(self):
  # vectorizing one float32 scalar into every lane (4-wide and 8-wide) must select VBROADCASTSS
  a = UOp.variable("a", 0, 0, dtypes.float32)
  splats = [UOp.vectorize(*(a,)*width) for width in (4, 8)]
  for splat in splats:
    self.assertIs(self.isel_rewrite(splat).arg, X86Ops.VBROADCASTSS)
def test_vshufps(self):
  a = UOp.variable("a", 0, 0, dtypes.float32.vec(8))
  b = UOp.variable("b", 0, 0, dtypes.float32.vec(8))
  c = UOp.variable("c", 0, 0, dtypes.float32)
  d = UOp.variable("d", 0, 0, dtypes.float32)
  # a shuffle between 1 scalar is just a broadcast and matches X86Ops.VBROADCASTSS to allow for load fusion
  valid = [UOp.vectorize(c, c, d, d),                                      # shuffle between 2 scalars
           UOp.vectorize(a.gep(0), a.gep(1), c, c),                        # shuffle between vector and scalar
           UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),          # shuffle between 2 vectors
           UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),          # shuffle between 1 vector
           # 8 wide: upper 4 elements repeat the lower 4 selections offset by 4
           UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
           UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
  for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
  invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),        # 4 wide selecting elements above 3
             UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),        # 4 wide selecting an element above 3
             # 8 wide: last upper selection isn't the lower selection offset by 4
             UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
             # 8 wide: last element comes from a where the second source is required
             UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
  for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
def test_vshufpd(self):
  a = UOp.variable("a", 0, 0, dtypes.float64.vec(4))
  b = UOp.variable("b", 0, 0, dtypes.float64.vec(4))
  c = UOp.variable("c", 0, 0, dtypes.float64)
  d = UOp.variable("d", 0, 0, dtypes.float64)
  # shapes that must be selected as a single VSHUFPD
  matching = [
    UOp.vectorize(c, d),
    UOp.vectorize(a.gep(0), c),
    UOp.vectorize(a.gep(1), b.gep(1)),
    UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
    UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3)),
  ]
  for vec in matching:
    self.assertIs(self.isel_rewrite(vec).arg, X86Ops.VSHUFPD)
  # shapes that must NOT be selected as VSHUFPD
  rejected = [
    UOp.vectorize(c, c, c, c),
    UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
    UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
    UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1)),
  ]
  for vec in rejected:
    self.assertIsNot(self.isel_rewrite(vec).arg, X86Ops.VSHUFPD)
# this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE
def test_vinsertps(self):
  a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
  b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
  c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
  d = UOp.variable("d", 0, 0, dtypes.float32)
  # pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated
  n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
  self.assertTrue(n.arg is X86Ops.VINSERTPS and n.src[0].arg is X86Ops.DEFINE_REG)
  valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)), # interleaved shuffle between 2 vectors, TODO: this should be vunpck
           UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]        # shuffle between 4 sources
  for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
# complex address is [base + index*scale + displacement]
def test_complex_address(self):
@ -86,21 +127,19 @@ class TestIselX86(unittest.TestCase):
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
self.assertTrue(n.src[2].arg is X86Ops.IMM and n.src[2].dtype is dtypes.int8 and n.src[2].tag == 4)
def test_fold_load(self):
  # two distinct loads, each used once by the add
  load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
  load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
  n = self.isel_rewrite(load1 + load2)
  # a folded load shows up as 4 sources on the add instead of 2
  self.assertTrue(len(n.src) == 4)
# don't fold when used multiple times
def test_dont_fold_load(self):
  load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
  # used by multiple users
  n = self.isel_rewrite(load + 1 + load)
  self.assertTrue(len(n.src) == 2)
  # used multiple times by same user
  n = self.isel_rewrite(load * load)
  self.assertTrue(len(n.src) == 2)

View file

@ -231,6 +231,8 @@ reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"},
# ***** X86 instruction selection *****
# if the load is used multiple times we don't fold
def is_foldable_load(ctx:IselContext, x:UOp, s:UOp) -> bool:
  """Return True when `s` is a LOAD with exactly one entry in ctx.uses and exactly one occurrence in x.src."""
  if s.op is not Ops.LOAD: return False
  return len(ctx.uses[s]) == 1 and x.src.count(s) == 1
def base(x:UOp, i:int) -> UOp:
  """Return the value behind x.src[i]: the GEP's first source when x.src[i] is a GEP, otherwise x.src[i] itself."""
  operand = x.src[i]
  if operand.op is Ops.GEP: return operand.src[0]
  return operand
def lane(x:UOp, i:int) -> int:
  """Return the element index x.src[i] selects: the first entry of the GEP's arg, or 0 when x.src[i] is not a GEP."""
  operand = x.src[i]
  if operand.op is Ops.GEP: return operand.arg[0]
  return 0
def to_int(dt:DType):
  """Map float16/float32/float64 to int16/int32/int64 respectively; any other dtype raises KeyError."""
  float_to_int = {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}
  return float_to_int[dt]
def def_reg(dt:DType, reg:Register|None=None) -> UOp:
  """Build a DEFINE_REG instruction UOp of dtype `dt`; `reg` (possibly None) is stored in the tag."""
  return UOp(Ops.INS, dtype=dt, arg=X86Ops.DEFINE_REG, tag=reg)
def imm(dt:DType, v:int) -> UOp:
  """Build an IMM instruction UOp of dtype `dt` whose tag is `v` after truncate[dt] is applied."""
  value = truncate[dt](v)
  return UOp(Ops.INS, dtype=dt, arg=X86Ops.IMM, tag=value)
@ -253,31 +255,27 @@ def vcmp(x:UOp) -> UOp:
# for 128 bit xmm2 selects its lower 2 32 bits from xmm0 and its upper 2 32 bits from xmm1 according to imm
# for 256 bit ymm2 repeats the shuffle for its upper 128 bits selecting from the upper 128 bits of ymm0 and ymm1
def vshufps(x:UOp) -> UOp|None:
  """Try to select a VECTORIZE of float32 elements as a single VSHUFPS.

  Elements 0,1 must share one source and elements 2,3 another, each selecting an element index <= 3.
  For 8 elements the upper 4 must use the same two sources and repeat the lower selections offset by 4.
  Returns the VSHUFPS instruction, or None when the shape doesn't fit.
  """
  a, b = base(x, 0), base(x, 2)
  if not (a is base(x, 1) and b is base(x, 3)) or any(lane(x, i) > 3 for i in range(4)): return None
  if len(x.src) == 8:
    if not (a is base(x, 4) is base(x, 5) and b is base(x, 6) is base(x, 7)) or any(lane(x, i+4) != lane(x, i)+4 for i in range(4)): return None
  # 2 bits of immediate per output element, only the lower 4 selections are encoded
  return x.ins(X86Ops.VSHUFPS, src=(a, b, imm(dtypes.uint8, sum(lane(x, i) << 2*i for i in range(4)))))
# vshufpd xmm2, xmm0, xmm1, imm
# for 128 bit xmm2 selects its lower 64 bits from xmm0 and its upper 64 bits from xmm1 according to imm
# for 256 bit ymm2 additionally selects its upper 128 bits from the upper 128 bits of ymm0 and ymm1 from following the same constraint
# for 256 bit ymm2 also selects its upper 128 bits from the upper 128 bits of ymm0 and ymm1 following the same constraint
def vshufpd(x:UOp) -> UOp|None:
  """Try to select a VECTORIZE of float64 elements as a single VSHUFPD.

  Element 0 comes from the first source and element 1 from the second, each selecting index 0 or 1.
  For 4 elements, elements 2 and 3 must come from the same two sources and select index 2 or 3.
  Returns the VSHUFPD instruction, or None when the shape doesn't fit.
  """
  a, b = base(x, 0), base(x, 1)
  if lane(x, 0) > 1 or lane(x, 1) > 1: return None
  if len(x.src) == 4 and not (a is base(x, 2) and b is base(x, 3) and lane(x, 2) > 1 and lane(x, 3) > 1): return None
  # immediate bit i picks the low (0) or high (1) element within the 128 bit half that output element i reads from,
  # so element indices 2/3 of the upper half must be reduced to 0/1 before being placed at bit i
  return x.ins(X86Ops.VSHUFPD, src=(a, b, imm(dtypes.uint8, sum((lane(x, i) & 1) << i for i in range(len(x.src))))))
# vinsertps xmm2, xmm0, xmm1, imm
# inserts any 32 bit element in xmm1 into any position in xmm0 according to imm, result is written to xmm2
# this is the fallback slow case for when you can't match a more powerful shuffle
def vinsertps(x:UOp) -> UOp:
  """Lower a VECTORIZE to a chain of VINSERTPS, one per element, starting from a fresh DEFINE_REG of x.dtype."""
  def _insert(ret:UOp, i:int) -> UOp:
    # s is the register holding the element, v the element index inside it
    s, v = base(x, i), lane(x, i)
    # moving the 0th element into the 0th position does nothing
    # otherwise imm bits 7:6 select the source element and bits 5:4 the destination position
    return s if i == v == 0 else x.ins(X86Ops.VINSERTPS, src=(ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
  return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype))