mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
* draft * cleanup test_encodings * cleanup test_isel * model flag state and support rematerialization * woops * add vbroadcastss instruction * don't fuse load if used multiple times in src * add movabs instruction and fix idiv * fixes * add x86 backend to tests * float16 fix * rm TwoAddress2nd * add BARRIER * test windows ci * yup isel fixes the mask stuff too and its beautiful * add cmoves to the spec * support storing imms * no TUPLE_ORDER, breaks tests * fix remaining seg faults * add float max * always fuse index * minor * fix DEFINE_VAR/SPECIAL and enable multithreading * linter * more linter * more * more * more * let's try this * perhaps * start new scheduler * more scheduling info * cleaner shuffle functions * fixup isel tests * skip bounds check when NOOPs exist * skip inf rewrite tests * fix const tag hack and add x86ops to _shape * fix * skip a few tests * func arg order independent from op value * x86 goes in own linearize * switch to PARAM * more * add min x86op and neg in decomps * do mulacc in isel * use def_reg in test_encodings * enable emulated int64 tests * how much does this fix * Ops becomes OpType * fix * rm noqa * rm machine scheduler stuff * and this * allow for extending enums and move X86Ops out of uop * fix imports * rm X86GroupOp from ops.py * spacing * tell mypy to shut up * more linter * add x86op test * allow set[X86Ops] in upat * move NOOPs to pre_isel_matcher and rm NOOP from spec * more asserts * also this * cleanup encode * simplify live range * fix idiv * add Ops.INS to x86 * more changes * more changes * more changes * fix * fix * fix * fix * print formatted assembly * fix 8bit idiv? 
* oops * enable float16 and unaligned vector load/store * actually no * move x86 tests * no more bool cast * fix * linter * linter * move X86Ops to x86.py * fix vpbroadcast * cleanups * linter * print correct reg names * canonical max * move max/min and add test * support float16 vector load/store * rm bad rewrite * vpsrldq can't access memory * regalloc takes renderer * enable vector load/store on all dtypes * more isel tests * rm this for now * a lot better * fix * fix * fix * deal with flags correctly * fix * enable gep noop rule * fix * fix * fix * add callee saved registers * use Ops.CONST instead of X86Ops.IMM * fix * enable TUPLE_ORDER * fix * rm x86 code in linearizer * fix * fix * fix * move isa rewrites to codegen * fix * fix * skip test_linearizer.py * skip more tests * fix * fix for idiv/mod changes * fix * don't use fmadd if it duplicates fused op * hacky * fix * cleanups * cleanups * fix --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
149 lines
No EOL
7.8 KiB
Python
149 lines
No EOL
7.8 KiB
Python
import unittest
|
|
from tinygrad import Device
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
|
|
|
|
def ins(op, dt, src, tag=None):
  """Build an `Ops.INS` UOp for the instruction `op`.

  `op` becomes the UOp's arg, `dt` its dtype, `src` its sources and `tag` (None by default) its tag.
  """
  return UOp(Ops.INS, dtype=dt, src=src, arg=op, tag=tag)
|
|
|
|
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only on x86")
class TestEncodingsX86(unittest.TestCase):
  """Byte-level checks of single instructions rendered by the default device's renderer.

  Every test builds one `Ops.INS` UOp, renders it and compares the result with a hand-written hex string.
  """
  # NOTE: x86 can also address memory with only a displacement, or with an index but no base address.
  # Neither form has a use case, so they aren't supported (and aren't tested here).

  def encode(self, u:UOp):
    # render a one-instruction program; the result is parsed as a hex string by the checks below
    renderer = Device[Device.DEFAULT].renderer
    return renderer.render([u])

  def _check(self, uop:UOp, expected_hex:str):
    # both sides go through bytes.fromhex so the comparison is on raw bytes
    rendered = bytes.fromhex(self.encode(uop))
    self.assertEqual(rendered, bytes.fromhex(expected_hex))

  # a displacement of 0 is left out of the encoding
  def test_base_address(self):
    base, no_index, disp = def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)
    # mov edi, dword ptr [rdi]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, no_index, disp), RDI), "8B 3F")

  # rsp/r12 as base memory address need a sib byte
  def test_rsp_base_address(self):
    base, no_index, disp = def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)
    # mov esp, dword ptr [rsp]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, no_index, disp), RSP), "8B 24 24")

  # rbp/r13 as base memory address need a displacement
  def test_rbp_base_address(self):
    base, no_index, disp = def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), imm(dtypes.int8, 0)
    # mov ebp, dword ptr [rbp + 0]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, no_index, disp), RBP), "8B 6D 00")

  # [base + index*scale]
  def test_base_index_address(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)
    # mov eax, dword ptr [rax + rdx*4]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, index, disp), RAX), "8B 04 90")

  # using rsp as the index means there is no index
  def test_rsp_index_address(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)
    # mov eax, dword ptr [rax]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, index, disp), RAX), "8B 00")

  # r12, on the other hand, is a valid index
  def test_r12_index_address(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)
    # mov eax, dword ptr [rax + r12*4]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, index, disp), RAX), "42 8B 04 A0")

  # [base + index*scale + 8bit disp]
  def test_complex_address_8bit_disp(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
    # mov edi, dword ptr [rdi + rsi*4 + 0xa]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, index, disp), RDI), "8B 7C B7 0A")

  # [base + index*scale + 32bit disp]
  def test_complex_address_32bit_disp(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)
    # mov edi, dword ptr [rdi + rsi*4 + 0x2710]
    self._check(ins(X86Ops.MOV, dtypes.int32, (base, index, disp), RDI), "8B BC B7 10 27 00 00")

  # the 8bit variant of a legacy instruction has opcode - 1
  def test_8bit_legacy_encoding(self):
    src = def_reg(dtypes.int8, RDX)
    # movsx eax, dl
    self._check(ins(X86Ops.MOVSX, dtypes.int32, (src,), RAX), "0F BE C2")

  # a rex prefix is required to access the lower 8 bits of rsp, rbp, rsi, rdi
  def test_lower_8bits_reg(self):
    src = def_reg(dtypes.int8, RDI)
    # movsx eax, dil
    self._check(ins(X86Ops.MOVSX, dtypes.int32, (src,), RAX), "40 0F BE C7")

  # 16 bit variant of a legacy instruction
  def test_16bit_legacy_encoding(self):
    src = def_reg(dtypes.int8, RDX)
    # movsx ax, dl
    self._check(ins(X86Ops.MOVSX, dtypes.int16, (src,), RAX), "66 0F BE C2")

  # 64 bit variant of a legacy instruction
  def test_64bit_legacy_encoding(self):
    src = def_reg(dtypes.int8, RDX)
    # movsx rax, dl
    self._check(ins(X86Ops.MOVSX, dtypes.int64, (src,), RAX), "48 0F BE C2")

  # compact vex encoding
  def test_compact_vex_encoding(self):
    lhs, rhs = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1])
    # vaddss xmm0, xmm0, xmm1
    self._check(ins(X86Ops.VADDSS, dtypes.float32, (lhs, rhs), XMM[0]), "C5 FA 58 C1")

  # long vex encoding
  def test_long_vex_encoding(self):
    lhs, rhs = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[8])
    # vaddss xmm0, xmm0, xmm8
    self._check(ins(X86Ops.VADDSS, dtypes.float32, (lhs, rhs), XMM[0]), "C4 C1 7A 58 C0")

  # ymm encoding
  def test_ymm_encoding(self):
    vec_dt = dtypes.float32.vec(8)
    lhs, rhs = def_reg(vec_dt, XMM[0]), def_reg(dtypes.float32.vec(8), XMM[1])
    # vaddps ymm0, ymm0, ymm1
    self._check(ins(X86Ops.VADDPS, dtypes.float32.vec(8), (lhs, rhs), XMM[0]), "C5 FC 58 C1")

  # encoding where a register lives in the immediate field
  def test_reg_in_imm_field(self):
    a, b, mask = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1]), def_reg(dtypes.float32, XMM[2])
    # vblendvps xmm0, xmm0, xmm1, xmm2
    self._check(ins(X86Ops.VBLENDVPS, dtypes.float32, (a, b, mask), XMM[0]), "C4 E3 79 4A C1 20")

  # when writing to memory the uop takes the store form: dtype is void and there is no definition
  def test_write_mem(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
    value = def_reg(dtypes.float32, XMM[0])
    lane = imm(dtypes.uint8, 0)
    # vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
    self._check(ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, value, lane)), "C4 E3 79 16 44 B7 0A 00")

  # two address instruction with a fused load
  def test_two_address_load(self):
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)
    # cmove eax, dword ptr [rdi + rsi*4 + 0xa]
    self._check(ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX), "0F 44 44 B7 0A")

  # instructions whose displacement and immediate hold the same value
  def test_disp_imm_same_value(self):
    # 8 bit case: the very same imm uop is used as displacement and as immediate
    base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10)
    # mov byte ptr [rdi + rsi + 0xa], 0xa
    self._check(ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp)), "40 C6 44 37 0A 0A")

    # 32 bit case: the immediate is a separately built imm with the same value
    base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10)
    # imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
    self._check(ins(X86Ops.IMULi, dtypes.int32, (base, index, disp, imm(dtypes.int32, 10)), RDI), "69 BC B7 0A 00 00 00 0A 00 00 00")

  # a cmove carries the cmp as its last src even though it isn't explicitly used;
  # the cmp doesn't define a reg and is ignored in the encoding
  def test_cmove_ignore_cmp(self):
    value, cmp = def_reg(dtypes.int32, RAX), UOp(Ops.INS, arg=X86Ops.CMP)
    # cmove edx, eax
    self._check(ins(X86Ops.CMOVE, dtypes.int32, (value, cmp), RDX), "0F 44 D0")
|
|
|
|
# run the tests in this module when the file is executed as a script
if __name__ == "__main__": unittest.main()