tinygrad/test/backend/test_encodings.py
ttomsa aa1e59ab97
X86 with Ops.INS (#14873)
* draft

* cleanup test_encodings

* cleanup test_isel

* model flag state and support rematerialization

* woops

* add vbroadcastss instruction

* don't fuse load if used multiple times in src

* add movabs instruction and fix idiv

* fixes

* add x86 backend to tests

* float16 fix

* rm TwoAddress2nd

* add BARRIER

* test windows ci

* yup isel fixes the mask stuff too and its beautiful

* add cmoves to the spec

* support storing imms

* no TUPLE_ORDER, breaks tests

* fix remaining seg faults

* add float max

* always fuse index

* minor

* fix DEFINE_VAR/SPECIAL and enable multithreading

* linter

* more linter

* more

* more

* more

* let's try this

* perhaps

* start new scheduler

* more scheduling info

* cleaner shuffle functions

* fixup isel tests

* skip bounds check when NOOPs exist

* skip inf rewrite tests

* fix const tag hack and add x86ops to _shape

* fix

* skip a few tests

* func arg order independent from op value

* x86 goes in own linearize

* switch to PARAM

* more

* add min x86op and neg in decomps

* do mulacc in isel

* use def_reg in test_encodings

* enable emulated int64 tests

* how much does this fix

* Ops becomes OpType

* fix

* rm noqa

* rm machine scheduler stuff

* and this

* allow for extending enums and move X86Ops out of uop

* fix imports

* rm X86GroupOp from ops.py

* spacing

* tell mypy to shut up

* more linter

* add x86op test

* allow set[X86Ops] in upat

* move NOOPs to pre_isel_matcher and rm NOOP from spec

* more asserts

* also this

* cleanup encode

* simplify live range

* fix idiv

* add Ops.INS to x86

* more changes

* more changes

* more changes

* fix

* fix

* fix

* fix

* print formatted assembly

* fix 8bit idiv?

* oops

* enable float16  and unaligned vector load/store

* actually no

* move x86 tests

* no more bool cast

* fix

* linter

* linter

* move X86Ops to x86.py

* fix vpbroadcast

* cleanups

* linter

* print correct reg names

* canonical max

* move max/min and add test

* support float16 vector load/store

* rm bad rewrite

* vpsrldq can't access memory

* regalloc takes renderer

* enable vector load/store on all dtypes

* more isel tests

* rm this for now

* a lot better

* fix

* fix

* fix

* deal with flags correctly

* fix

* enable gep noop rule

* fix

* fix

* fix

* add callee saved registers

* use Ops.CONST instead of X86Ops.IMM

* fix

* enable TUPLE_ORDER

* fix

* rm x86 code in linearizer

* fix

* fix

* fix

* move isa rewrites to codegen

* fix

* fix

* skip test_linearizer.py

* skip more tests

* fix

* fix for idiv/mod changes

* fix

* don't use fmadd if it duplicates fused op

* hacky

* fix

* cleanups

* cleanups

* fix

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2026-05-19 12:42:54 -07:00

149 lines
No EOL
7.8 KiB
Python

import unittest
from tinygrad import Device
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg
def ins(op, dt, src, tag=None):
  """Build an `Ops.INS` UOp for the x86 instruction `op`.

  `dt` is the dtype of the uop, `src` its tuple of source uops and `tag` is forwarded unchanged as the uop's tag
  (the tests in this file pass the destination register there, or nothing for instructions that write to memory).
  """
  return UOp(Ops.INS, dtype=dt, src=src, arg=op, tag=tag)
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "only on x86")
class TestEncodingsX86(unittest.TestCase):
  """Byte-level checks of the x86 encoder: each test renders one instruction uop and compares it with a known encoding."""
  # NOTE: x86 also allows a lone displacement as a memory address, and an index without a base address
  # neither form has a use case here, so they aren't supported

  def encode(self, u:UOp):
    """Render the single uop `u` with the default device's renderer and return whatever the renderer produces."""
    renderer = Device[Device.DEFAULT].renderer
    return renderer.render([u])

  def _check(self, u:UOp, expected_hex:str):
    """Assert that the rendered output of `u`, parsed as hex, equals the bytes spelled out in `expected_hex`."""
    self.assertEqual(bytes.fromhex(self.encode(u)), bytes.fromhex(expected_hex))

  # a displacement of 0 is not emitted
  def test_base_address(self):
    base = def_reg(dtypes.int32.ptr(), RDI)
    load = ins(X86Ops.MOV, dtypes.int32, (base, UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI)
    # mov edi, dword ptr [rdi]
    self._check(load, "8B 3F")

  # rsp/r12 need a sib byte when they are the base of a memory address
  def test_rsp_base_address(self):
    base = def_reg(dtypes.int32.ptr(), RSP)
    load = ins(X86Ops.MOV, dtypes.int32, (base, UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP)
    # mov esp, dword ptr [rsp]
    self._check(load, "8B 24 24")

  # rbp/r13 need a displacement when they are the base of a memory address
  def test_rbp_base_address(self):
    base = def_reg(dtypes.int32.ptr(), RBP)
    load = ins(X86Ops.MOV, dtypes.int32, (base, UOp(Ops.NOOP), imm(dtypes.int8, 0)), RBP)
    # mov ebp, dword ptr [rbp + 0]
    self._check(load, "8B 6D 00")

  # [base + index*scale]
  def test_base_index_address(self):
    base, index = def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX)
    load = ins(X86Ops.MOV, dtypes.int32, (base, index, imm(dtypes.int8, 0)), RAX)
    # mov eax, dword ptr [rax + rdx*4]
    self._check(load, "8B 04 90")

  # rsp in the index slot means there is no index
  def test_rsp_index_address(self):
    base, index = def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP)
    load = ins(X86Ops.MOV, dtypes.int32, (base, index, imm(dtypes.int8, 0)), RAX)
    # mov eax, dword ptr [rax]
    self._check(load, "8B 00")

  # r12, however, is a valid index
  def test_r12_index_address(self):
    base, index = def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12])
    load = ins(X86Ops.MOV, dtypes.int32, (base, index, imm(dtypes.int8, 0)), RAX)
    # mov eax, dword ptr [rax + r12*4]
    self._check(load, "42 8B 04 A0")

  # [base + index*scale + 8bit disp]
  def test_complex_address_8bit_disp(self):
    base, index = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI)
    load = ins(X86Ops.MOV, dtypes.int32, (base, index, imm(dtypes.int8, 10)), RDI)
    # mov edi, dword ptr [rdi + rsi*4 + 0xa]
    self._check(load, "8B 7C B7 0A")

  # [base + index*scale + 32bit disp]
  def test_complex_address_32bit_disp(self):
    base, index = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI)
    load = ins(X86Ops.MOV, dtypes.int32, (base, index, imm(dtypes.int32, 10000)), RDI)
    # mov edi, dword ptr [rdi + rsi*4 + 0x2710]
    self._check(load, "8B BC B7 10 27 00 00")

  # 8bit variants of legacy instructions subtract 1 from opcode
  def test_8bit_legacy_encoding(self):
    narrow = def_reg(dtypes.int8, RDX)
    cast = ins(X86Ops.MOVSX, dtypes.int32, (narrow,), RAX)
    # movsx eax, dl
    self._check(cast, "0F BE C2")

  # the low 8 bits of rsp, rbp, rsi, rdi are only reachable with a rex prefix
  def test_lower_8bits_reg(self):
    narrow = def_reg(dtypes.int8, RDI)
    cast = ins(X86Ops.MOVSX, dtypes.int32, (narrow,), RAX)
    # movsx eax, dil
    self._check(cast, "40 0F BE C7")

  # 16 bit variant of a legacy instruction
  def test_16bit_legacy_encoding(self):
    narrow = def_reg(dtypes.int8, RDX)
    cast = ins(X86Ops.MOVSX, dtypes.int16, (narrow,), RAX)
    # movsx ax, dl
    self._check(cast, "66 0F BE C2")

  # 64 bit variant of a legacy instruction
  def test_64bit_legacy_encoding(self):
    narrow = def_reg(dtypes.int8, RDX)
    cast = ins(X86Ops.MOVSX, dtypes.int64, (narrow,), RAX)
    # movsx rax, dl
    self._check(cast, "48 0F BE C2")

  # compact vex encoding
  def test_compact_vex_encoding(self):
    lhs = def_reg(dtypes.float32, XMM[0])
    rhs = def_reg(dtypes.float32, XMM[1])
    add = ins(X86Ops.VADDSS, dtypes.float32, (lhs, rhs), XMM[0])
    # vaddss xmm0, xmm0, xmm1
    self._check(add, "C5 FA 58 C1")

  # long vex encoding
  def test_long_vex_encoding(self):
    lhs = def_reg(dtypes.float32, XMM[0])
    rhs = def_reg(dtypes.float32, XMM[8])
    add = ins(X86Ops.VADDSS, dtypes.float32, (lhs, rhs), XMM[0])
    # vaddss xmm0, xmm0, xmm8
    self._check(add, "C4 C1 7A 58 C0")

  # ymm encoding
  def test_ymm_encoding(self):
    vec_dt = dtypes.float32.vec(8)
    ymm0, ymm1 = def_reg(vec_dt, XMM[0]), def_reg(vec_dt, XMM[1])
    add = ins(X86Ops.VADDPS, vec_dt, (ymm0, ymm1), XMM[0])
    # vaddps ymm0, ymm0, ymm1
    self._check(add, "C5 FC 58 C1")

  # encoding where a register lives in the immediate field
  def test_reg_in_imm_field(self):
    xmm0 = def_reg(dtypes.float32, XMM[0])
    xmm1 = def_reg(dtypes.float32, XMM[1])
    xmm2 = def_reg(dtypes.float32, XMM[2])
    blend = ins(X86Ops.VBLENDVPS, dtypes.float32, (xmm0, xmm1, xmm2), XMM[0])
    # vblendvps xmm0, xmm0, xmm1, xmm2
    self._check(blend, "C4 E3 79 4A C1 20")

  # when writing to mem the uop takes the store form: its dtype is void and it has no definition
  def test_write_mem(self):
    base = def_reg(dtypes.int32.ptr(), RDI)
    index = def_reg(dtypes.int32, RSI)
    disp = imm(dtypes.int8, 10)
    xmm0 = def_reg(dtypes.float32, XMM[0])
    extr = ins(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0)))
    # vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0
    self._check(extr, "C4 E3 79 16 44 B7 0A 00")

  # a two address instruction with a fused load works
  def test_two_address_load(self):
    base = def_reg(dtypes.int32.ptr(), RDI)
    index = def_reg(dtypes.int32, RSI)
    disp = imm(dtypes.int8, 10)
    cmove = ins(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX)
    # cmove eax, dword ptr [rdi + rsi*4 + 0xa]
    self._check(cmove, "0F 44 44 B7 0A")

  # instructions whose displacement and immediate have the same value
  def test_disp_imm_same_value(self):
    # 8 bit case: the very same imm uop is used as both displacement and immediate
    base = def_reg(dtypes.int8.ptr(), RDI)
    index = def_reg(dtypes.int8, RSI)
    disp = imm(dtypes.int8, 10)
    mov = ins(X86Ops.MOVi, dtypes.void, (base, index, disp, disp))
    # mov byte ptr [rdi + rsi + 0xa], 0xa
    self._check(mov, "40 C6 44 37 0A 0A")
    # 32 bit case: the immediate is built by a second imm() call with the same dtype and value
    base = def_reg(dtypes.int32.ptr(), RDI)
    index = def_reg(dtypes.int32, RSI)
    disp = imm(dtypes.int32, 10)
    imul = ins(X86Ops.IMULi, dtypes.int32, (base, index, disp, imm(dtypes.int32, 10)), RDI)
    # imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa
    self._check(imul, "69 BC B7 0A 00 00 00 0A 00 00 00")

  # cmoves carry the cmp as their last src even though it is not explicitly used
  # the cmp doesn't define a reg and is ignored in the encoding
  def test_cmove_ignore_cmp(self):
    value = def_reg(dtypes.int32, RAX)
    cmp = UOp(Ops.INS, arg=X86Ops.CMP)
    cmove = ins(X86Ops.CMOVE, dtypes.int32, (value, cmp), RDX)
    # cmove edx, eax
    self._check(cmove, "0F 44 D0")
# run the test suite when this file is executed directly as a script
if __name__ == "__main__":
  unittest.main()