mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fixup isel tests
This commit is contained in:
parent
3fcde08b20
commit
db3ed92ae3
2 changed files with 50 additions and 59 deletions
|
|
@ -7,9 +7,7 @@ from tinygrad.helpers import SPEC
|
|||
|
||||
@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec")
|
||||
class TestIselX86(unittest.TestCase):
|
||||
def isel_rewrite(self, x:UOp):
|
||||
x = graph_rewrite(x, X86Renderer().pre_isel_matcher)
|
||||
return graph_rewrite(x, X86Renderer().isel_matcher, IselContext(x), bottom_up=True)
|
||||
def isel_rewrite(self, x:UOp): return graph_rewrite(x, X86Renderer().isel_matcher, IselContext(x), bottom_up=True)
|
||||
|
||||
def test_cmove(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
|
|
@ -35,75 +33,68 @@ class TestIselX86(unittest.TestCase):
|
|||
self.assertTrue(n.src[0].op is X86Ops.CMOVB and n.src[0].src[2].op is X86Ops.VUCOMISS)
|
||||
self.assertTrue(n.src[1].op is X86Ops.VBLENDVPS and n.src[1].src[2].op is X86Ops.VCMPSS and n.src[1].src[2].src[2].arg == 1)
|
||||
|
||||
# the geps become part of the immediate in the instruction
|
||||
def test_vshufps_same_src(self):
|
||||
# lower 2 32 bits must come from the same register and upper 2 32 bits must come from the same register
|
||||
def test_vshufps(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
|
||||
vec = UOp(Ops.VECTORIZE, a.dtype, (a.gep(3), a.gep(2), a.gep(1), a.gep(0)))
|
||||
n = self.isel_rewrite(vec)
|
||||
self.assertTrue(n.op is X86Ops.VSHUFPS and n.src[0] is a and n.src[1] is a and n.src[2].arg == 27)
|
||||
|
||||
def test_vshufps_diff_src(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
|
||||
b = UOp.variable("b", 0, 0, dtypes.float32)
|
||||
vec = UOp(Ops.VECTORIZE, a.dtype, (a.gep(2), a.gep(3), b, b))
|
||||
n = self.isel_rewrite(vec)
|
||||
self.assertTrue(n.op is X86Ops.VSHUFPS and n.src[0] is a and n.src[1] is b and n.src[2].arg == 14)
|
||||
b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
|
||||
c = UOp.variable("c", 0, 0, dtypes.float32)
|
||||
d = UOp.variable("d", 0, 0, dtypes.float32)
|
||||
# shuffle between 2 vectors
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), a.gep(1), b.gep(2), b.gep(3))))
|
||||
self.assertTrue(n.op is X86Ops.VSHUFPS)
|
||||
# shuffle between 2 scalars
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (c, c, d, d)))
|
||||
self.assertTrue(n.op is X86Ops.VSHUFPS)
|
||||
# shuffle between vector and scalar
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), a.gep(1), c, c)))
|
||||
self.assertTrue(n.op is X86Ops.VSHUFPS)
|
||||
# shuffle between 1 vector
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(1), a.gep(2), a.gep(3), a.gep(0))))
|
||||
self.assertTrue(n.op is X86Ops.VSHUFPS and n.src[0] is n.src[1])
|
||||
# a shuffle between 1 scalar is just a broadcast and matches X86Ops.VBROADCASTSS to allow for load fusion
|
||||
|
||||
# this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE
|
||||
def test_vinsertps(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.float32.vec(4))
|
||||
b = UOp.variable("b", 0, 0, dtypes.float32.vec(4))
|
||||
c = UOp.variable("c", 0, 0, dtypes.float32.vec(4))
|
||||
d = UOp.variable("d", 0, 0, dtypes.float32)
|
||||
vec = UOp(Ops.VECTORIZE, dtypes.float32.vec(4), (a.gep(0), b.gep(0), c.gep(0), d))
|
||||
n = self.isel_rewrite(vec)
|
||||
self.assertTrue(n.op is X86Ops.VINSERTPS and len(n.src) == 3)
|
||||
self.assertTrue(n.src[0].op is X86Ops.VINSERTPS and n.src[1] is d and n.src[2].arg == 48)
|
||||
n = n.src[0]
|
||||
self.assertTrue(n.src[0].op is X86Ops.VINSERTPS and n.src[1] is c and n.src[2].arg == 32)
|
||||
n = n.src[0]
|
||||
# first gep is just moving the first element from a reg to another which does nothing
|
||||
self.assertTrue(n.src[0] is a and n.src[1] is b and n.src[2].arg == 16)
|
||||
d = UOp.variable("e", 0, 0, dtypes.float32)
|
||||
# pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, dtypes.float32.vec(2), (a.gep(0), d)))
|
||||
self.assertTrue(n.op is X86Ops.VINSERTPS and n.src[0].op is X86Ops.DEFINE_REG)
|
||||
# interleaved shuffle between 2 vectors
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), b.gep(1), a.gep(2), b.gep(3))))
|
||||
self.assertTrue(n.op is X86Ops.VINSERTPS)
|
||||
# shuffle between 4 sources
|
||||
n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(3), b.gep(2), c.gep(1), d)))
|
||||
self.assertTrue(n.op is X86Ops.VINSERTPS)
|
||||
|
||||
# 8bit displacement should be used when possible
|
||||
def test_load_8bit_disp(self):
|
||||
offset = UOp.variable("a", 0, 0, dtypes.int32) + UOp.const(dtypes.int32, 1)
|
||||
index = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(offset, ptr=True)
|
||||
load = index.load()
|
||||
# complex address is [base + index*scale + displacement]
|
||||
def test_complex_address(self):
|
||||
a = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
load = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load()
|
||||
n = self.isel_rewrite(load)
|
||||
self.assertTrue(n.src[2].op is X86Ops.IMM and n.src[2].dtype is dtypes.int8)
|
||||
|
||||
def test_fuse_index(self):
|
||||
var = UOp.variable("a", 0, 0, dtypes.int32)
|
||||
offset = var + UOp.const(dtypes.int32, 1)
|
||||
index = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(offset, ptr=True)
|
||||
load = index.load()
|
||||
n = self.isel_rewrite(load)
|
||||
self.assertTrue(n.src[1] is var)
|
||||
# base is DEFINE_GLOBAL, index is "a"
|
||||
self.assertTrue(n.src[0].op is X86Ops.DEFINE_REG and n.src[1].op is X86Ops.DEFINE_REG)
|
||||
# displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32
|
||||
self.assertTrue(n.src[2].op is X86Ops.IMM and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4)
|
||||
|
||||
def test_fuse_load(self):
|
||||
offset = UOp.variable("a", 0, 0, dtypes.int32) + UOp.const(dtypes.int32, 1)
|
||||
index = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(offset, ptr=True)
|
||||
load = index.load()
|
||||
add = offset + load
|
||||
n = self.isel_rewrite(add)
|
||||
load1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
load2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load()
|
||||
n = self.isel_rewrite(load1 + load2)
|
||||
self.assertTrue(len(n.src) == 4)
|
||||
|
||||
# don't fuse when used multiple times
|
||||
def test_dont_fuse_load(self):
|
||||
offset = UOp.variable("a", 0, 0, dtypes.int32) + UOp.const(dtypes.int32, 1)
|
||||
index = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(offset, ptr=True)
|
||||
load = index.load()
|
||||
add1 = offset + load
|
||||
add2 = add1 + load
|
||||
n = self.isel_rewrite(add2)
|
||||
def test_dont_fuse_load_diff_users(self):
|
||||
load = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
add = load + 1
|
||||
n = self.isel_rewrite(add + load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
||||
def test_dont_fuse_load_same_user(self):
|
||||
offset = UOp.variable("a", 0, 0, dtypes.int32) + UOp.const(dtypes.int32, 1)
|
||||
index = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(offset, ptr=True)
|
||||
load = index.load()
|
||||
add = load + load
|
||||
n = self.isel_rewrite(add)
|
||||
load = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load()
|
||||
n = self.isel_rewrite(load * load)
|
||||
self.assertTrue(len(n.src) == 2)
|
||||
|
||||
# test noop has same reg as src, this is because noops aren't instructions but still need to be part of the graph
|
||||
|
|
|
|||
|
|
@ -217,9 +217,9 @@ def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(X86Ops.DEFINE_R
|
|||
# xmm2 selects its lower 2 32 bits from xmm0 and its upper 2 32 bits from xmm1
|
||||
def vshufps(x:UOp) -> UOp:
|
||||
def _in(i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s
|
||||
if len(x.src) != 4 or not (_in(0) is _in(1) and _in(2) is _in(3)): return None
|
||||
if len(x.src) != 4 or _in(0) is not _in(1) or _in(2) is not _in(3): return None
|
||||
return UOp(X86Ops.VSHUFPS, x.dtype, (_in(0), _in(2),
|
||||
imm(dtypes.uint8, sum((s.arg[0] if s.op is Ops.GEP else 0) << (2*i) for i,s in enumerate(x.src)))))
|
||||
imm(dtypes.uint8, sum(s.arg[0] << 2*i if s.op is Ops.GEP else 0 for i,s in enumerate(x.src)))))
|
||||
|
||||
# vinsertps xmm2, xmm0, xmm1
|
||||
# inserts any 32 bit element in xmm1 into any position in xmm0, result is written to xmm2
|
||||
|
|
@ -228,7 +228,7 @@ def vinsertps(x:UOp) -> UOp:
|
|||
def _insert(ret:UOp, i:int) -> UOp:
|
||||
s, v = x.src[i], 0
|
||||
if s.op is Ops.GEP: s, v = s.src[0], s.arg[0]
|
||||
# if first src is not a gep or gep[0] it's just moving the 0th element from an xmm reg to another without shuffling which does nothing
|
||||
# moving the 0th element into the 0th position does nothing
|
||||
return s if i == v == 0 else UOp(X86Ops.VINSERTPS, x.dtype, (ret, s, imm(dtypes.uint8, v << 6 | i << 4)))
|
||||
return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue