safe changes from new dtype branch [pr] (#7397)

* safe changes from new dtype branch [pr]

* only image test on GPU
This commit is contained in:
George Hotz 2024-10-30 16:18:48 +07:00 committed by GitHub
commit 4e2895f8d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 36 additions and 28 deletions

View file

@ -116,7 +116,7 @@ def run_linearizer(lin: Kernel, rawbufs=None, var_vals=None) -> Tuple[str, Any]:
def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
# TODO: for bfloat16 it compiles linearizer, but it does not run because numpy cannot generate bf16 buffer.
has_bf16 = any(b.dtype == dtypes.bfloat16 for b in lin.membufs)
has_bf16 = any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs)
# TODO: raise specific fuzzing errors instead of str, and propagate the error message
try:

View file

@ -297,7 +297,7 @@ class TestBoolDType(TestDType): DTYPE = dtypes.bool
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr(v=4)
dt1 = dtypes.float.vec(4).ptr().vec(4)
dt2 = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt1, dt2)
self.assertEqual(str(dt1), str(dt2))
@ -313,7 +313,7 @@ class TestPtrDType(unittest.TestCase):
self.assertEqual(dt, dtypes.float)
def test_serialize(self):
dt = dtypes.float.vec(4).ptr(v=4)
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt, eval(str(dt)))
def test_vcount(self):
@ -327,18 +327,18 @@ class TestPtrDType(unittest.TestCase):
self.assertEqual(dt.v, 1)
self.assertEqual(dt.count, 4)
dt = dtypes.float.vec(4).ptr(v=4)
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 4)
class TestImageDType(unittest.TestCase):
def test_image_scalar(self):
assert dtypes.imagef((10,10)).scalar() == dtypes.float32
assert dtypes.imageh((10,10)).scalar() == dtypes.float32
assert dtypes.imagef((10,10)).base.scalar() == dtypes.float32
assert dtypes.imageh((10,10)).base.scalar() == dtypes.float32
def test_image_vec(self):
assert dtypes.imagef((10,10)).vec(4) == dtypes.float32.vec(4)
assert dtypes.imageh((10,10)).vec(4) == dtypes.float32.vec(4)
assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4)
assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4)
class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):

View file

@ -13,8 +13,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
def helper_test_lin(lin: Kernel, opts, failed_platforms, rtol=1e-2, atol=1e-2):
if any(b.dtype == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
for opt in opts:
try:

View file

@ -1,6 +1,6 @@
# basic self-contained tests of the external functionality of tinygrad
import unittest
from tinygrad import Tensor, Context, Variable, TinyJit
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
class TestTiny(unittest.TestCase):
@ -18,11 +18,11 @@ class TestTiny(unittest.TestCase):
out = Tensor.cat(Tensor.ones(8).contiguous(), Tensor.ones(8).contiguous())
self.assertListEqual(out.tolist(), [1]*16)
def test_gemm(self):
N = 4
def test_gemm(self, N=4, out_dtype=dtypes.float):
a = Tensor.ones(N,N).contiguous()
b = Tensor.eye(N).contiguous()
self.assertListEqual((a@b).flatten().tolist(), [1.0]*(N*N))
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
self.assertEqual(out.dtype, out_dtype)
# *** JIT (for Python speed) ***
@ -57,6 +57,15 @@ class TestTiny(unittest.TestCase):
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
self.assertEqual(ret.item(), s)
# *** image ***
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
def test_image(self):
with Context(IMAGE=2): self.test_gemm(out_dtype=dtypes.imagef((4, 1, 4)))
def test_beam_image(self):
with Context(BEAM=1): self.test_image()
if __name__ == '__main__':
unittest.main()

View file

@ -13,7 +13,7 @@ if TYPE_CHECKING: from tinygrad.renderer import Renderer
# ***** float4/image store handling *****
def fold_expanded(ex, buf):
if buf.dtype != dtypes.float.ptr() and buf.dtype != dtypes.half.ptr() and not isinstance(buf.dtype, ImageDType): return None
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
new_srcs = dedup(list(ex.src))
old_new_srcs = new_srcs[:]
is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
@ -32,7 +32,7 @@ def fold_expanded(ex, buf):
offsets_rootsrc[root_src][arg] = i
# then rewrite everything we can
lengths = [4] if is_image else ([8,4,2] if buf.dtype == dtypes.half.ptr() and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
used = set()
for rootsrc, offsets in offsets_rootsrc.items():
for o in offsets:

View file

@ -4,7 +4,7 @@ from collections import defaultdict
from typing import Optional, Dict, Tuple, Any, Iterator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib
from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType
from tinygrad.dtype import DType, ImageDType, PtrDType
from tinygrad.renderer import Renderer
# **************** Device ****************
@ -55,8 +55,8 @@ class BufferOptions:
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
if base is None:
assert offset == 0, "base buffers can't have offset"

View file

@ -5,7 +5,7 @@ from dataclasses import replace
from tinygrad.ops import UOp, UOps, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
from tinygrad.engine.realize import CompiledRunner
@ -92,8 +92,10 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
for k,lx in bufsts.items():
buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
assert isinstance(dtype, (PtrDType, ImageDType))
if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
rawbufs[k] = Buffer(lin.opts.device, buf_size, dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, dtype)
buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
assert all(r is not None for r in rawbufs)
return cast(List[Buffer], rawbufs)

View file

@ -74,7 +74,7 @@ class LLVMRenderer(Renderer):
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")

View file

@ -2012,6 +2012,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.conv2d(w).numpy())
```
"""
if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, acc_dtype)
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
@ -2097,6 +2098,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(a.dot(b).numpy())
```
"""
if IMAGE: return self.image_dot(w, acc_dtype)
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})")
@ -3438,7 +3440,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
# *** image Tensor function replacements ***
def image_dot(self, w:Tensor, acc_dtype=None):
def image_dot(self, w:Tensor, acc_dtype=None) -> Tensor:
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
@ -3453,7 +3455,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None) -> Tensor:
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
@ -3513,11 +3515,6 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
if IMAGE:
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
setattr(Tensor, "conv2d", Tensor.image_conv2d)
setattr(Tensor, "dot", Tensor.image_dot)
def _metadata_wrapper(fn):
def _wrapper(*args, **kwargs):
if _METADATA.get() is not None: return fn(*args, **kwargs)