mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
float16 fix
This commit is contained in:
parent
8d4a48fcd3
commit
587259976d
2 changed files with 8 additions and 7 deletions
|
|
@ -11,7 +11,7 @@ from tinygrad import nn, dtypes, Device, Tensor, Variable
|
|||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, CPU_X86
|
||||
from tinygrad.schedule.rangeify import Kernel
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
|
||||
|
|
@ -1816,6 +1816,7 @@ class TestSchedule(unittest.TestCase):
|
|||
self.assertEqual(b.tolist(), [False, False])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Validation error on WebGPU")
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU" and CPU_X86, "seg fault")
|
||||
def test_mnist_val(self):
|
||||
from tinygrad.nn.datasets import mnist
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -292,11 +292,11 @@ isel_matcher = PatternMatcher([
|
|||
(UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"),
|
||||
lambda ctx,x: x.replace(op=X86Ops.LEA, src=(UOp(X86Ops.DEFINE_REG, x.dtype, arg=RSP), UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)), # noqa: E501
|
||||
# constants that can't be immediates, move them to registers
|
||||
#(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, UOp(X86Ops.MOVi, dtypes.int32, (x.replace(op=X86Ops.IMM))))),
|
||||
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (x.replace(op=X86Ops.IMM),)),))),
|
||||
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (x.replace(op=X86Ops.IMM),)),))),
|
||||
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (x.replace(op=X86Ops.IMM),)) if x.tag is None else None),
|
||||
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (x.replace(op=X86Ops.IMM),)) if x.tag is None else None),
|
||||
(UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), UOp(X86Ops.MOVi, dtypes.int16, (imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))),
|
||||
(UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (imm(x.dtype, x.arg),)),))),
|
||||
(UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (imm(x.dtype, x.arg),)),))),
|
||||
(UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (imm(x.dtype, x.arg),)) if x.tag is None else None),
|
||||
(UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (imm(x.dtype, x.arg),)) if x.tag is None else None),
|
||||
# LEA, first 2 cases only happen if INDEX is followed by a WHERE preventing the displacement being moved to the LOAD/STORE
|
||||
# if the idx can be less than 0 need to sign extend
|
||||
(UPat(Ops.INDEX, src=(UPat.var("base"), UPat.var("idx") + UPat.cvar("dis")), name="x"), lambda base,idx,dis,x: x.replace(op=X86Ops.LEA, src=(base, idx.cast(dtypes.int64) if idx.vmin < 0 else idx, disp(dis.const_like(dis.arg * base.dtype.itemsize))))),
|
||||
|
|
@ -452,7 +452,7 @@ isel_matcher = PatternMatcher([
|
|||
(UPat(Ops.LOAD, dt_128bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVUPS, src=fuse_index(ctx, x))),
|
||||
(UPat(Ops.LOAD, dt_64bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVSD, src=fuse_index(ctx, x))),
|
||||
(UPat(Ops.LOAD, dt_32bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVSS, src=fuse_index(ctx, x))),
|
||||
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype),) + fuse_index(ctx, x) + (imm(dtypes.uint8, 0),))),
|
||||
(UPat(Ops.LOAD, dt_16bit, name="x"), lambda ctx,x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_index(ctx, x) + (imm(dtypes.uint8, 0),))),
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda ctx,x: x.replace(op=X86Ops.MOV, src=fuse_index(ctx, x))),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(dtype=dt_128bit)), name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVUPSm, src=fuse_index(ctx, x) + (x.src[-1],))), # noqa: E501
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(dtype=dt_64bit)), name="x"), lambda ctx,x: x.replace(op=X86Ops.VMOVSDm, src=fuse_index(ctx, x) + (x.src[-1],))), # noqa: E501
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue