mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
more isel tests
This commit is contained in:
parent
255a788dea
commit
41f2bd8a05
2 changed files with 93 additions and 56 deletions
|
|
@ -4,9 +4,18 @@ from tinygrad.uop.ops import UOp, dtypes, graph_rewrite
|
|||
from tinygrad.renderer.isa.x86 import X86Renderer, X86Ops
|
||||
from tinygrad.renderer.isa import IselContext
|
||||
|
||||
# these tests are to catch changes that don't cause incorrect codegen but cause worse codegen
|
||||
class TestIselX86(unittest.TestCase):
|
||||
def isel_rewrite(self, x:UOp):
  # run the X86 instruction selection matcher bottom-up over x and hand back what graph_rewrite returns
  renderer = X86Renderer()
  return graph_rewrite(x, renderer.isel_matcher, IselContext(x), bottom_up=True)
|
||||
|
||||
def _check_op(self, dt_op, expr):
  # dt_op: iterable of (dtype, expected X86Ops member) pairs
  # expr: callable taking only positional arguments, its argument count decides how many variables are built
  # every dtype runs in its own subTest so one failing dtype doesn't hide the others
  nargs = expr.__code__.co_argcount
  for dt,op in dt_op:
    with self.subTest(dtype=dt):
      # one variable per positional argument of expr, named "0", "1", ...
      operands = [UOp.variable(str(i), 0, 0, dt) for i in range(nargs)]
      n = self.isel_rewrite(expr(*operands))
      self.assertIs(n.arg, op)
|
||||
|
||||
def test_cmove(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
b = UOp.variable("b", 0, 0, dtypes.int32)
|
||||
|
|
@ -19,16 +28,26 @@ class TestIselX86(unittest.TestCase):
|
|||
self.assertTrue(n.src[0].src[2] == n.src[1].src[2] and n.src[0].src[2].arg is X86Ops.CMP)
|
||||
|
||||
def test_vmax(self):
  # where(a < b, b, a) is a max, each float dtype must select its own vmax variant
  dt_op = [(dtypes.float32, X86Ops.VMAXSS), (dtypes.float64, X86Ops.VMAXSD),
           (dtypes.float32.vec(4), X86Ops.VMAXPS), (dtypes.float64.vec(4), X86Ops.VMAXPD)]
  self._check_op(dt_op, lambda a,b: (a < b).where(b, a))
|
||||
|
||||
def test_vmin(self):
  # where(a < b, a, b) is a min, each float dtype must select its own vmin variant
  dt_op = [(dtypes.float32, X86Ops.VMINSS), (dtypes.float64, X86Ops.VMINSD),
           (dtypes.float32.vec(4), X86Ops.VMINPS), (dtypes.float64.vec(4), X86Ops.VMINPD)]
  self._check_op(dt_op, lambda a,b: (a < b).where(a, b))
|
||||
|
||||
def test_vfmadd(self):
  # a multiply feeding an add must be selected as a single fused multiply-add
  mul_add = lambda a,b,c: a * b + c
  expected = [(dtypes.float32, X86Ops.VFMADD213SS), (dtypes.float64, X86Ops.VFMADD213SD),
              (dtypes.float32.vec(4), X86Ops.VFMADD213PS), (dtypes.float64.vec(4), X86Ops.VFMADD213PD)]
  self._check_op(expected, mul_add)
|
||||
|
||||
# TODO: shouldn't match fmadd if var is used multiple times
|
||||
@unittest.expectedFailure
def test_vfmadd_fail(self):
  # b feeds both the multiply and the add, the expected root here is the plain add
  reused_operand = lambda a,b: a * b + b
  expected = [(dtypes.float32, X86Ops.VADDSS), (dtypes.float64, X86Ops.VADDSD),
              (dtypes.float32.vec(4), X86Ops.VADDPS), (dtypes.float64.vec(4), X86Ops.VADDPD)]
  self._check_op(expected, reused_operand)
|
||||
|
||||
def test_vpbroadcast(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
|
|
@ -40,41 +59,63 @@ class TestIselX86(unittest.TestCase):
|
|||
n = self.isel_rewrite(load.broadcast(4))
|
||||
self.assertTrue(n.arg is X86Ops.VPBROADCASTD and len(n.src) == 3)
|
||||
|
||||
# lower 2 32 bits must come from the same register and upper 2 32 bits must come from the same register
|
||||
def test_vbroadcastss(self):
  # vectorizing the same scalar into every lane (4 or 8 wide) must select a broadcast
  a = UOp.variable("a", 0, 0, dtypes.float32)
  four_wide = UOp.vectorize(a, a, a, a)
  eight_wide = UOp.vectorize(a, a, a, a, a, a, a, a)
  for shuf in (four_wide, eight_wide):
    selected = self.isel_rewrite(shuf)
    self.assertIs(selected.arg, X86Ops.VBROADCASTSS)
|
||||
|
||||
def test_vshufps(self):
  a = UOp.variable("a", 0, 0, dtypes.float32.vec(8))
  b = UOp.variable("b", 0, 0, dtypes.float32.vec(8))
  c = UOp.variable("c", 0, 0, dtypes.float32)
  d = UOp.variable("d", 0, 0, dtypes.float32)
  # a shuffle between 1 scalar is just a broadcast and matches X86Ops.VBROADCASTSS to allow for load fusion

  # in order: 2 scalars, vector and scalar, 2 vectors, 1 vector, then two 8 wide shuffles
  valid = [UOp.vectorize(c, c, d, d),
           UOp.vectorize(a.gep(0), a.gep(1), c, c),
           UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
           UOp.vectorize(a.gep(1), a.gep(2), a.gep(3), a.gep(0)),
           UOp.vectorize(a.gep(3), a.gep(2), a.gep(1), a.gep(0), a.gep(7), a.gep(6), a.gep(5), a.gep(4)),
           UOp.vectorize(a.gep(0), a.gep(0), b.gep(1), b.gep(1), a.gep(4), a.gep(4), b.gep(5), b.gep(5))]
  for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)

  # these read elements above lane 3 in a 4 wide shuffle or don't repeat the low half pattern in the high half
  invalid = [UOp.vectorize(a.gep(0), a.gep(1), b.gep(4), b.gep(5)),
             UOp.vectorize(a.gep(0), a.gep(5), b.gep(2), b.gep(3)),
             UOp.vectorize(a.gep(0), a.gep(0), a.gep(0), a.gep(0), a.gep(4), a.gep(4), a.gep(4), a.gep(5)),
             UOp.vectorize(a.gep(0), a.gep(0), b.gep(0), b.gep(0), a.gep(4), a.gep(4), b.gep(4), a.gep(4))]
  for shuf in invalid: self.assertIsNot(self.isel_rewrite(shuf).arg, X86Ops.VSHUFPS)
|
||||
|
||||
def test_vshufpd(self):
  a = UOp.variable("a", 0, 0, dtypes.float64.vec(4))
  b = UOp.variable("b", 0, 0, dtypes.float64.vec(4))
  c = UOp.variable("c", 0, 0, dtypes.float64)
  d = UOp.variable("d", 0, 0, dtypes.float64)

  # shuffles that must be selected as vshufpd
  accepted = (UOp.vectorize(c, d),
              UOp.vectorize(a.gep(0), c),
              UOp.vectorize(a.gep(1), b.gep(1)),
              UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)),
              UOp.vectorize(a.gep(1), a.gep(1), a.gep(3), a.gep(3)))
  for shuf in accepted:
    selected = self.isel_rewrite(shuf)
    self.assertIs(selected.arg, X86Ops.VSHUFPD)

  # shuffles that must be selected as something else
  rejected = (UOp.vectorize(c, c, c, c),
              UOp.vectorize(a.gep(0), a.gep(1), b.gep(2), b.gep(3)),
              UOp.vectorize(a.gep(2), b.gep(3), a.gep(2), b.gep(3)),
              UOp.vectorize(a.gep(0), b.gep(1), a.gep(0), b.gep(1)))
  for shuf in rejected:
    selected = self.isel_rewrite(shuf)
    self.assertIsNot(selected.arg, X86Ops.VSHUFPD)
|
||||
|
||||
# this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE
|
||||
def test_vinsertps(self):
  a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
  b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
  c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
  d = UOp.variable("e", 0, 0, dtypes.float32)
  # pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated
  n = self.isel_rewrite(UOp.vectorize(a.gep(0), d))
  self.assertTrue(n.arg is X86Ops.VINSERTPS and n.src[0].arg is X86Ops.DEFINE_REG)

  # interleaved shuffle between 2 vectors and a shuffle between 4 sources both take the fallback
  valid = [UOp.vectorize(a.gep(0), b.gep(1), a.gep(2), b.gep(3)), # TODO: this should be vunpck
           UOp.vectorize(a.gep(3), b.gep(2), c.gep(1), d)]
  for shuf in valid: self.assertIs(self.isel_rewrite(shuf).arg, X86Ops.VINSERTPS)
|
||||
|
||||
# complex address is [base + index*scale + displacement]
|
||||
def test_complex_address(self):
|
||||
|
|
@ -86,21 +127,19 @@ class TestIselX86(unittest.TestCase):
|
|||
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
|
||||
self.assertTrue(n.src[2].arg is X86Ops.IMM and n.src[2].dtype is dtypes.int8 and n.src[2].tag == 4)
|
||||
|
||||
def test_fuse_load(self):
|
||||
def test_fold_load(self):
|
||||
load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
n = self.isel_rewrite(load1 + load2)
|
||||
self.assertTrue(len(n.src) == 4)
|
||||
|
||||
# don't fuse when used multiple times
|
||||
def test_dont_fuse_load_diff_users(self):
|
||||
# don't fold when used multiple times
|
||||
def test_dont_fold_load(self):
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
add = load + 1
|
||||
n = self.isel_rewrite(add + load)
|
||||
# used by multiple users
|
||||
n = self.isel_rewrite(load + 1 + load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
||||
def test_dont_fuse_load_same_user(self):
|
||||
load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
# used multiple times by same user
|
||||
n = self.isel_rewrite(load * load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
||||
|
|
|
|||
|
|
@ -231,6 +231,8 @@ reg_strs = {"rax": {4:"eax", 2:"ax", 1:"al"}, "rcx": {4:"ecx", 2:"cx", 1:"cl"},
|
|||
# ***** X86 instruction selection *****
|
||||
# if the load is used multiple times we don't fold
|
||||
def is_foldable_load(ctx:IselContext, x:UOp, s:UOp) -> bool:
  """True when s is a LOAD with exactly one entry in ctx.uses[s] and exactly one appearance in x.src."""
  if s.op is not Ops.LOAD: return False
  return len(ctx.uses[s]) == x.src.count(s) == 1
|
||||
def base(x:UOp, i:int) -> UOp:
  """Return the i-th source of x, or the value that source indexes into when it is a GEP."""
  src = x.src[i]
  if src.op is Ops.GEP: return src.src[0]
  return src
|
||||
def lane(x:UOp, i:int) -> int:
  """Return the element index picked by the i-th source of x when it is a GEP, otherwise 0."""
  src = x.src[i]
  if src.op is Ops.GEP: return src.arg[0]
  return 0
|
||||
def to_int(dt:DType):
  """Map float16/float32/float64 to the int dtype of the same width, any other dtype raises KeyError."""
  same_width_int = {dtypes.float16: dtypes.int16, dtypes.float32: dtypes.int32, dtypes.float64: dtypes.int64}
  return same_width_int[dt]
|
||||
def def_reg(dt:DType, reg:Register|None=None) -> UOp:
  """Build a DEFINE_REG instruction of dtype dt whose tag is reg (None when no register is passed)."""
  return UOp(Ops.INS, dtype=dt, arg=X86Ops.DEFINE_REG, tag=reg)
|
||||
def imm(dt:DType, v:int) -> UOp:
  """Build an IMM instruction of dtype dt whose tag is v passed through truncate[dt]."""
  value = truncate[dt](v)
  return UOp(Ops.INS, arg=X86Ops.IMM, dtype=dt, tag=value)
|
||||
|
|
@ -253,31 +255,27 @@ def vcmp(x:UOp) -> UOp:
|
|||
# for 128 bit xmm2 selects its lower 2 32 bits from xmm0 and its upper 2 32 bits from xmm1 according to imm
|
||||
# for 256 bit ymm2 repeats the shuffle for its upper 128 bits selecting from the upper 128 bits of ymm0 and ymm1
|
||||
def vshufps(x:UOp) -> UOp|None:
  """Select VSHUFPS for a VECTORIZE whose sources fit the shuffle, return None when they don't.

  Sources 0,1 must share one base (a) and sources 2,3 another (b), each reading lane 0-3.
  With 8 sources the upper half must repeat the lower half pattern shifted by 4 lanes on the same bases.
  """
  a, b = base(x, 0), base(x, 2)
  if not (a is base(x, 1) and b is base(x, 3)) or any(lane(x, i) > 3 for i in range(4)): return None
  if len(x.src) == 8:
    if not (a is base(x, 4) is base(x, 5) and b is base(x, 6) is base(x, 7)) or any(lane(x, i+4) != lane(x, i)+4 for i in range(4)): return None
  # 2 selector bits per destination element, only the low 4 elements are encoded
  return x.ins(X86Ops.VSHUFPS, src=(a, b, imm(dtypes.uint8, sum(lane(x, i) << 2*i for i in range(4)))))
|
||||
|
||||
# vshufpd xmm2, xmm0, xmm1, imm
|
||||
# for 128 bit xmm2 selects its lower 64 bits from xmm0 and its upper 64 bits from xmm1 according to imm
|
||||
# for 256 bit ymm2 additionally selects its upper 128 bits from the upper 128 bits of ymm0 and ymm1 from following the same constraint
|
||||
# for 256 bit ymm2 also selects its upper 128 bits from the upper 128 bits of ymm0 and ymm1 following the same constraint
|
||||
def vshufpd(x:UOp) -> UOp|None:
  """Select VSHUFPD for a VECTORIZE whose sources fit the shuffle, return None when they don't.

  Source 0 comes from base a and source 1 from base b, each reading lane 0 or 1.
  With 4 sources, source 2 must read lane 2 or 3 of a and source 3 must read lane 2 or 3 of b.
  """
  a, b = base(x, 0), base(x, 1)
  if lane(x, 0) > 1 or lane(x, 1) > 1: return None
  if len(x.src) == 4 and not (a is base(x, 2) and b is base(x, 3) and lane(x, 2) > 1 and lane(x, 3) > 1): return None
  # imm has 1 selector bit per destination element choosing low/high within its own 128 bit half,
  # so only the low bit of the lane goes in: lanes 2,3 of the upper half encode as 0,1 like lanes 0,1
  return x.ins(X86Ops.VSHUFPD, src=(a, b, imm(dtypes.uint8, sum((lane(x, i) & 1) << i for i in range(len(x.src))))))
|
||||
|
||||
# vinsertps xmm2, xmm0, xmm1, imm
|
||||
# inserts any 32 bit element in xmm1 into any position in xmm0 according to imm, result is written to xmm2
|
||||
# this is the fallback slow case for when you can't match a more powerful shuffle
|
||||
def vinsertps(x:UOp) -> UOp:
  """Build a VECTORIZE as a chain of VINSERTPS, one per source, starting from a fresh DEFINE_REG."""
  def _insert(ret:UOp, i:int) -> UOp:
    # s is the register holding the element, v the element's index inside s
    s, v = base(x, i), lane(x, i)
    # moving the 0th element into the 0th position does nothing
    # imm bits 7:6 select the source element, bits 5:4 the destination position
    return s if i == v == 0 else x.ins(X86Ops.VINSERTPS, src=(ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
  return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue