mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
clean up ops_torch and ops_cpu (#2819)
This commit is contained in:
parent
f409b57854
commit
959d9cfed4
2 changed files with 30 additions and 29 deletions
|
|
@ -1,44 +1,45 @@
|
|||
import numpy as np
|
||||
from typing import Callable, Dict, Tuple
|
||||
from tinygrad.helpers import dtypes, flat_mv
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
|
||||
from tinygrad.device import Interpreted, Allocator
|
||||
|
||||
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  """Return the indices of the axes whose size differs between two shapes.

  Args:
    old_shape: shape before the reduction.
    new_shape: shape after the reduction; must have the same number of dimensions.

  Returns:
    A tuple of axis indices `i` for which `old_shape[i] != new_shape[i]`, in ascending order.

  Raises:
    AssertionError: if the two shapes have a different number of dimensions.
  """
  assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
  return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
|
||||
def reduce_axis(in_shape:Tuple[int, ...], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  """Return the axes that a reduction from `in_shape` to `out_shape` collapses.

  Args:
    in_shape: shape of the reduce input.
    out_shape: shape of the reduce output; must have the same number of dimensions.

  Returns:
    A tuple of axis indices `i` for which `in_shape[i] != out_shape[i]`, in ascending order.

  Raises:
    AssertionError: if the two shapes have a different number of dimensions.
  """
  assert len(in_shape) == len(out_shape), "reduce shapes must have same dimensions"
  return tuple(i for i,(a,b) in enumerate(zip(in_shape, out_shape)) if a != b)
|
||||
|
||||
# TODO: this should be global infrastructure
|
||||
def output_type(x, y):
  """Pick the result dtype for a binary op on `x` and `y`.

  The numpy dtype of each operand is converted with `dtypes.from_np` and the
  `priority` attributes are compared: `x.dtype` is returned only when its
  priority is strictly greater, otherwise (including ties) `y.dtype` is returned.
  """
  x_priority = dtypes.from_np(x.dtype).priority
  y_priority = dtypes.from_np(y.dtype).priority
  if x_priority > y_priority:
    return x.dtype
  return y.dtype
|
||||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
  """Build a fused multiply-accumulate function on top of a backend's einsum.

  Args:
    einsum: callable `(subscripts, a, b)` that performs the contraction.
    get_strides: callable returning the per-axis strides of an operand.
    expand: callable `(x, shape)` that broadcasts `x` to `shape`.

  Returns:
    `mulacc(a, b, new_shape)`, which multiplies `a` and `b` and sums over the axes
    that are not kept in `new_shape`, returning a result expanded to `new_shape`.
  """
  # map axis indices to einsum subscript letters (axis 0 -> 'a', 1 -> 'b', ...)
  def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
  # axes with a non-zero stride are "real"; zero-stride (broadcast) axes are dropped by indexing them at 0
  def axes_slice(strides): return tuple(i for i,s in enumerate(strides) if s != 0), tuple(slice(None) if s != 0 else 0 for s in strides)
  def mulacc(a, b, new_shape):
    (a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
    # output keeps the axes whose size is unchanged and that are real in at least one operand
    out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
    ret = einsum(f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}", a[a_slices], b[b_slices])
    # axes that are broadcast in both operands come back as size 1 and are then expanded to new_shape
    return expand(ret.reshape(tuple(1 if i not in a_axes and i not in b_axes else s for i,s in enumerate(new_shape))), new_shape)
  return mulacc
|
||||
|
||||
def as_strided(x, arg):
  """Create a strided numpy view over the data of `x`.

  Args:
    x: source array; it is first made C-contiguous via `np.require`.
    arg: a `(shape, stride, offset)` triple, with `stride` and `offset` given in
      elements (they are multiplied by `x.dtype.itemsize` to get bytes).

  Returns:
    An `np.ndarray` of the requested shape and `x`'s dtype backed by that buffer.
  """
  shape, stride, offset = arg
  return np.ndarray(shape, x.dtype, buffer=np.require(x, requirements='C'), offset=offset*x.dtype.itemsize,
                    strides=tuple(y*x.dtype.itemsize for y in stride))
|
||||
|
||||
# Dispatch table mapping each op to the numpy callable that implements it.
# Every key appears exactly once; where the listing repeated a key, the last value is kept.
numpy_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
  UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.SIN: np.sin, UnaryOps.SQRT: np.sqrt,
  # y is indexed as (target dtype, flag): a truthy flag reinterprets the bytes with view(), otherwise values are converted
  UnaryOps.CAST: lambda x,y: x.view(y[0].np) if y[1] else x.astype(y[0].np, copy=False),
  # NEG on a bool array is logical negation
  UnaryOps.NEG: lambda x: np.logical_not(x) if x.dtype == np.bool_ else np.negative(x),
  BinaryOps.MAX: np.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).astype(output_type(x,y)),
  BinaryOps.ADD: np.add, BinaryOps.SUB: np.subtract, BinaryOps.MUL: np.multiply,
  BinaryOps.DIV: lambda x, y: np.divide(x, y).astype(output_type(x, y), copy=False),
  BinaryOps.XOR: np.bitwise_xor,
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy(), optimize=True), lambda x: x.strides, np.broadcast_to),
  TernaryOps.WHERE: np.where,
  # reductions keep dims and are a no-op when the shape is already the target shape
  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.max(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: np.broadcast_to, MovementOps.PAD: np.pad
}
|
||||
|
||||
class NumpyAllocator(Allocator):
|
||||
|
|
|
|||
|
|
@ -1,47 +1,47 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Dict, Callable
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op
|
||||
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, Op
|
||||
from tinygrad.device import Interpreted, Allocator
|
||||
from tinygrad.helpers import getenv, dtypes
|
||||
from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis
|
||||
from tinygrad.helpers import getenv, dtypes, flatten
|
||||
from tinygrad.runtime.ops_cpu import einsum_mulacc, reduce_axis
|
||||
|
||||
# Device selection: CUDA when available, otherwise "mps" if getenv("MPS", 0) is truthy, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
# torch dtype -> tinygrad dtype
type_map = {torch.bool: dtypes.bool, torch.int8: dtypes.int8, torch.uint8: dtypes.uint8, torch.int16: dtypes.int16, torch.int32: dtypes.int32,
            torch.int64: dtypes.int64, torch.float16: dtypes.float16, torch.bfloat16: dtypes.bfloat16, torch.float32: dtypes.float32,
            torch.float64: dtypes.float64}
# tinygrad dtype -> torch dtype
inverse_type_map = {v: k for k,v in type_map.items()}
# TODO: should unsupported types fail instead of implicit conversion?
# unsigned 16/32/64-bit dtypes are mapped onto the signed torch dtype of the same width
inverse_type_map.update({dtypes.uint16: torch.int16, dtypes.uint32: torch.int32, dtypes.uint64: torch.int64})
|
||||
def np_type_cvt(t):
  """Map an unsigned numpy scalar type to the signed type of the same width.

  `np.uint32` becomes `np.int32` and `np.uint64` becomes `np.int64`; any other
  value of `t` is returned unchanged.
  """
  return {np.uint32: np.int32, np.uint64: np.int64}.get(t, t)
|
||||
|
||||
def output_type(x, y):
  """Pick the result torch dtype for an op on tensors `x` and `y`.

  Each operand's dtype is looked up in `type_map` and the `priority` attributes
  are compared: `x.dtype` is returned only when its priority is strictly greater,
  otherwise (including ties) `y.dtype` is returned.
  """
  x_priority = type_map[x.dtype].priority
  y_priority = type_map[y.dtype].priority
  if x_priority > y_priority:
    return x.dtype
  return y.dtype
|
||||
|
||||
def as_strided(x, arg):
  """Create a strided view of a torch tensor, emulating negative strides.

  Args:
    x: source tensor; `.contiguous()` is taken before striding.
    arg: a `(shape, stride, offset)` triple in elements. Strides may be negative.

  Returns:
    The strided tensor. When any stride is negative, the view is built with the
    absolute strides from an offset moved back to the lowest addressed element of
    each negative-stride axis, and those axes are then flipped.
  """
  shape, stride, offset = arg
  if any(i < 0 for i in stride):
    # torch.as_strided is only called with non-negative strides here: shift the start by (size-1)*stride
    # for every negative-stride axis, stride forwards, then reverse those axes with flip()
    return torch.as_strided(x.contiguous(), shape, tuple(abs(i) for i in stride),
      offset + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(shape, stride))).flip([i for i,a in enumerate(stride) if a < 0])
  return torch.as_strided(x.contiguous(), shape, stride, offset)
|
||||
|
||||
# Dispatch table mapping each op to the torch callable that implements it.
# Every key appears exactly once; where the listing repeated a key, the last value is kept.
torch_fxn_for_op: Dict[Op, Callable] = {
  # TODO: torch.tensor should work here. it doesn't due to "overflow" in uint8
  #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
  BufferOps.CONST: lambda val, dtype: torch.from_numpy(np.array(val, dtype=np_type_cvt(dtype.np))).to(device),
  UnaryOps.EXP2: torch.exp2, UnaryOps.LOG2: torch.log2, UnaryOps.SIN: torch.sin, UnaryOps.SQRT: torch.sqrt,
  # y is indexed as (target dtype, flag): a truthy flag reinterprets with view(), otherwise values are converted with type()
  UnaryOps.CAST: lambda x,y: (x.view if y[1] else x.type)(inverse_type_map[y[0]]),
  # NEG on a bool tensor is logical negation
  UnaryOps.NEG: lambda x: torch.logical_not(x) if x.dtype is torch.bool else torch.neg(x),
  BinaryOps.MAX: torch.maximum, BinaryOps.CMPLT: lambda x,y: (x<y).type(torch.promote_types(x.dtype, y.dtype)),
  BinaryOps.ADD: torch.add, BinaryOps.SUB: torch.sub, BinaryOps.MUL: torch.mul,
  BinaryOps.DIV: lambda x,y: torch.div(x, y).type(torch.promote_types(x.dtype, y.dtype)),
  BinaryOps.XOR: torch.bitwise_xor,
  # the contraction runs on float copies of the operands and is cast back to output_type(a, b)
  TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: torch.einsum(s, a.float(), b.float()).type(output_type(a,b)),
                                   lambda x: x.stride(), lambda x,s: x.expand(s)),
  TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
  # reductions keep dims and are a no-op when the shape is already the target shape
  ReduceOps.SUM: lambda x, new_shape: x.sum(reduce_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True) if x.shape != new_shape else x,
  ReduceOps.MAX: lambda x, new_shape: x.amax(reduce_axis(x.shape, new_shape), keepdims=True) if x.shape != new_shape else x,
  MovementOps.AS_STRIDED: as_strided, MovementOps.EXPAND: lambda x, arg: x.expand(arg),
  # the per-axis padding sequence is reversed and flattened before being handed to torch.nn.functional.pad
  MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, flatten(padding[::-1])),
}
|
||||
|
||||
class TorchAllocator(Allocator):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue