float16 fix

This commit is contained in:
ttomsa 2026-01-01 18:55:28 +00:00
commit 587259976d
2 changed files with 8 additions and 7 deletions

View file

@ -11,7 +11,7 @@ from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, CPU_X86
from tinygrad.schedule.rangeify import Kernel
from tinygrad.engine.realize import CompiledRunner, run_schedule
@ -1816,6 +1816,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(b.tolist(), [False, False])
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Validation error on WebGPU")
@unittest.skipIf(Device.DEFAULT == "CPU" and CPU_X86, "seg fault")
def test_mnist_val(self):
from tinygrad.nn.datasets import mnist
import torch

View file

@ -292,11 +292,11 @@ isel_matcher = PatternMatcher([
(UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"),
lambda ctx,x: x.replace(op=X86Ops.LEA, src=(UOp(X86Ops.DEFINE_REG, x.dtype, arg=RSP), UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)), # noqa: E501
# constants that can't be immediates, move them to registers
#(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, UOp(X86Ops.MOVi, dtypes.int32, (x.replace(op=X86Ops.IMM))))),
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (x.replace(op=X86Ops.IMM),)),))),
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (x.replace(op=X86Ops.IMM),)),))),
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (x.replace(op=X86Ops.IMM),)) if x.tag is None else None),
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (x.replace(op=X86Ops.IMM),)) if x.tag is None else None),
(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), UOp(X86Ops.MOVi, dtypes.int16, (imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))),
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (imm(x.dtype, x.arg),)),))),
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (imm(x.dtype, x.arg),)),))),
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (imm(x.dtype, x.arg),)) if x.tag is None else None),
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (imm(x.dtype, x.arg),)) if x.tag is None else None),
# LEA, first 2 cases only happen if INDEX is followed by a WHERE preventing the displacement being moved to the LOAD/STORE
# if the idx can be less than 0 need to sign extend
(UPat(Ops.INDEX, src=(UPat.var("base"), UPat.var("idx") + UPat.cvar("dis")), name="x"), lambda base,idx,dis,x: x.replace(op=X86Ops.LEA, src=(base, idx.cast(dtypes.int64) if idx.vmin < 0 else idx, disp(dis.const_like(dis.arg * base.dtype.itemsize))))),
@ -452,7 +452,7 @@ isel_matcher = PatternMatcher([
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVUPS, src=fuse_index(ctx, x))),
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVSD, src=fuse_index(ctx, x))),
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVSS, src=fuse_index(ctx, x))),
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype),) + fuse_index(ctx, x) + (imm(dtypes.uint8, 0),))),
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_index(ctx, x) + (imm(dtypes.uint8, 0),))),
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda ctx,x: x.replace(op=X86Ops.MOV, src=fuse_index(ctx, x))),
(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(dtype=dt_128bit)), name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVUPSm, src=fuse_index(ctx, x) + (x.src[-1],))), # noqa: E501
(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(dtype=dt_64bit)), name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVSDm, src=fuse_index(ctx, x) + (x.src[-1],))), # noqa: E501