remove ane, start supporting ops_torch

This commit is contained in:
George Hotz 2021-10-30 17:47:00 -07:00
commit 641b1dbb40
3 changed files with 192 additions and 29 deletions

View file

@ -88,7 +88,7 @@ class Bottleneck:
return out
class ResNet:
def __init__(self, block, num_blocks, num_classes=10, pretrained=False):
def __init__(self, block, num_blocks, num_classes=10):
self.in_planes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
@ -145,7 +145,7 @@ def ResNet101(num_classes, pretrained=False):
return model
def ResNet152(num_classes, pretrained=False):
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes, pretrained=pretrained)
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
if pretrained:
model = load_from_pretrained(model, model_urls['resnet152'])
return model

170
tinygrad/ops_torch.py Normal file
View file

@ -0,0 +1,170 @@
import torch
from .tensor import Function
# ************* unary ops *************
class ReLU(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return input.relu()
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * (input >= 0)
class Log(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return torch.log(input)
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output / input
class Exp(Function):
def forward(ctx, input):
ret = torch.exp(input)
ctx.save_for_backward(ret)
return ret
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return grad_output * ret
# ************* binary ops *************
def unbroadcast(out, in_sh):
# adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
if in_sh == (1,):
return out.sum().reshape((1,))
else:
sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1])
return out.sum(axis=sum_axis).reshape(in_sh)
class Add(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return x+y
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
class Sub(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return x-y
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
class Mul(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x*y
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
class Pow(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return x ** y
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
unbroadcast((x**y) * torch.log(x) * grad_output, y.shape)
# ************* reduce ops *************
class Sum(Function):
def forward(ctx, input, axis=None):
ctx.save_for_backward(input, axis)
if axis == None:
return input.sum().reshape((1,))
else:
return input.sum(axis)
def backward(ctx, grad_output):
input, axis = ctx.saved_tensors
if isinstance(axis, int): axis = [axis]
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
return grad_output.reshape(shape) + torch.zeros_like(input)
class Max(Function):
def forward(ctx, inp, axis=None):
if isinstance(axis, int): axis = [axis]
ret = torch.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
ctx.save_for_backward(inp, axis, ret)
if axis is not None:
ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
return ret
def backward(ctx, grad_output):
input, axis, ret = ctx.saved_tensors
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
ret2 = (input==ret.reshape(shape))
div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
return ret2*grad_output.reshape(shape)/div
# ************* movement ops *************
class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
return x.reshape(shape)
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return grad_output.reshape(in_shape)
class Transpose(Function):
def forward(ctx, x, order):
ctx.save_for_backward(order)
return torch.transpose(x, order)
def backward(ctx, x):
return torch.transpose(x, torch.argsort(ctx.order))
def inner_slice(x, arg):
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
x = torch.nn.functional.pad(x, [item for sublist in padding for item in sublist])
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
return x[tuple([slice(x[0], x[1], None) for x in slicee])]
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape)
return inner_slice(x, arg)
def backward(ctx, grad_output):
shape, = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
return inner_slice(grad_output, narg)
# ************* processing ops *************
class Matmul(Function):
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
return input @ weight
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_output @ torch.swapaxes(weight, -2, -1)
grad_weight = torch.swapaxes(input, -2, -1) @ grad_output
return grad_input, grad_weight
class Conv2D(Function):
def forward(ctx, x, w, stride=1, groups=1):
ctx.save_for_backward(x, w)
return torch.nn.functional.conv2d(x, w, stride=stride, groups=groups)
def backward(ctx, grad_output):
x, w = ctx.saved_tensors
grad_input = torch.nn.grad.conv2d_input(x.shape, w, grad_output)
grad_weight = torch.nn.grad.conv2d_weight(x, w.shape, grad_output)
return grad_input, grad_weight

View file

@ -55,20 +55,11 @@ class GPUBuffer:
def __repr__(self):
return f"<GPUBuffer with shape {self.shape!r}>"
# **** ANE functions ****
ane = None
def require_init_ane():
global ane
if ane is None:
import accel.ane.lib.ane as anelib, accel.ane.tinygrad.ops_ane as ops_ane
ane = anelib.ANE()
# **** start with two base classes, Tensor and Function ****
class Device: CPU, GPU, ANE = 0, 1, 2
class Device: CPU, GPU, TORCH = 0, 1, 2
DEFAULT_DEVICE = Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU
DEFAULT_DEVICE = (Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU) if os.environ.get("TORCH", 0) != "1" else Device.TORCH
class Tensor:
did_float_warning = False
@ -95,7 +86,10 @@ class Tensor:
@property
def dtype(self):
return self.data.dtype
if self.device == Device.TORCH:
return np.float32
else:
return self.data.dtype
# ***** creation helper functions *****
@ -165,10 +159,8 @@ class Tensor:
with ProfileOp("toCPU", [data]):
cl.enqueue_copy(cl_queue, data, old.cl, is_blocking=True)
elif "ANETensor" in str(type(data)):
if device == Device.ANE: return data
with ProfileOp("toCPU", [data]):
data = data.data().astype(np.float32)
if str(type(data)).startswith("torch"):
data = data.numpy()
if not isinstance(data, np.ndarray):
data = np.array(data, dtype=np.float32)
@ -183,12 +175,11 @@ class Tensor:
with ProfileOp("toGPU", [data]):
return GPUBuffer(data.shape, data)
elif device == Device.ANE:
require_init_ane()
with ProfileOp("toANE", [data]):
ndata = ane.tensor(data.shape)
ndata.data()[:] = data
return ndata
if device == Device.TORCH:
import torch
with ProfileOp("toTORCH", [data]):
return torch.from_numpy(data)
return data
def to_(self, device):
@ -335,7 +326,7 @@ def register(name, fxn, device=Device.CPU):
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
f = Tensor.ops[tt.device][name]
f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
@ -354,9 +345,6 @@ def _register_ops(namespace, device=Device.CPU):
from tinygrad import ops_cpu
_register_ops(ops_cpu)
if os.getenv("CHERRY", None) is not None:
from accel.cherry.tinygrad import ops_cherry
_register_ops(ops_cherry)
try:
import pyopencl as cl
# TODO: move this import to require_init_gpu?
@ -366,4 +354,9 @@ try:
except ImportError:
# no GPU support
GPU = False
ANE = False
try:
import torch
from tinygrad import ops_torch
_register_ops(ops_torch, device=Device.TORCH)
except ImportError:
pass