mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove ane, start supporting ops_torch
This commit is contained in:
parent
7d12482d80
commit
641b1dbb40
3 changed files with 192 additions and 29 deletions
|
|
@ -88,7 +88,7 @@ class Bottleneck:
|
|||
return out
|
||||
|
||||
class ResNet:
|
||||
def __init__(self, block, num_blocks, num_classes=10, pretrained=False):
|
||||
def __init__(self, block, num_blocks, num_classes=10):
|
||||
self.in_planes = 64
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
|
||||
|
|
@ -145,7 +145,7 @@ def ResNet101(num_classes, pretrained=False):
|
|||
return model
|
||||
|
||||
def ResNet152(num_classes, pretrained=False):
|
||||
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes, pretrained=pretrained)
|
||||
model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes)
|
||||
if pretrained:
|
||||
model = load_from_pretrained(model, model_urls['resnet152'])
|
||||
return model
|
||||
|
|
|
|||
170
tinygrad/ops_torch.py
Normal file
170
tinygrad/ops_torch.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
import torch
|
||||
from .tensor import Function
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class ReLU(Function):
  """Rectified linear unit for the torch backend."""

  def forward(ctx, input):
    """Remember the input for the backward pass and return relu(input)."""
    ctx.save_for_backward(input)
    activated = input.relu()
    return activated

  def backward(ctx, grad_output):
    """Let the gradient through wherever the saved input is >= 0."""
    (saved,) = ctx.saved_tensors
    gate = saved >= 0
    return grad_output * gate
|
||||
|
||||
class Log(Function):
  """Elementwise natural logarithm for the torch backend."""

  def forward(ctx, input):
    """Remember the input and return log(input)."""
    ctx.save_for_backward(input)
    return input.log()

  def backward(ctx, grad_output):
    """d/dx log(x) = 1/x, so divide the incoming gradient by the saved input."""
    (saved,) = ctx.saved_tensors
    return grad_output / saved
|
||||
|
||||
class Exp(Function):
  """Elementwise exponential for the torch backend."""

  def forward(ctx, input):
    """Return exp(input); the result (not the input) is saved for backward."""
    result = input.exp()
    ctx.save_for_backward(result)
    return result

  def backward(ctx, grad_output):
    """d/dx exp(x) = exp(x), so scale the gradient by the saved output."""
    (result,) = ctx.saved_tensors
    return grad_output * result
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
def unbroadcast(out, in_sh):
  """Reduce a broadcast gradient `out` back to the operand shape `in_sh`.

  The adjoint of broadcasting is a sum: every axis where the operand had
  size 1 but `out` is larger is summed, and the result is reshaped to
  `in_sh`. The special shape `(1,)` means "sum everything".

  Axes are compared position by position, so apart from the `(1,)` case
  `in_sh` is expected to have the same rank as `out`.
  """
  if in_sh == (1,):
    return out.sum().reshape((1,))
  sum_axis = tuple(i for i, dim in enumerate(in_sh) if dim == 1 and out.shape[i] > 1)
  # torch reduces over *every* axis when handed an empty dim tuple (numpy
  # reduces none), so only call sum when something was actually broadcast.
  if sum_axis:
    out = out.sum(dim=sum_axis, keepdim=True)
  return out.reshape(in_sh)
|
||||
|
||||
class Add(Function):
  """Broadcasting elementwise addition for the torch backend."""

  def forward(ctx, x, y):
    """Save both operand shapes (needed to undo broadcasting) and return x + y."""
    ctx.save_for_backward(x.shape, y.shape)
    total = x + y
    return total

  def backward(ctx, grad_output):
    """Both operands receive the gradient, reduced to their own shape."""
    x_shape, y_shape = ctx.saved_tensors
    grad_x = unbroadcast(grad_output, x_shape)
    grad_y = unbroadcast(grad_output, y_shape)
    return grad_x, grad_y
|
||||
|
||||
class Sub(Function):
  """Broadcasting elementwise subtraction for the torch backend."""

  def forward(ctx, x, y):
    """Save both operand shapes (needed to undo broadcasting) and return x - y."""
    ctx.save_for_backward(x.shape, y.shape)
    difference = x - y
    return difference

  def backward(ctx, grad_output):
    """x receives the gradient, y receives its negation, each reduced to its shape."""
    x_shape, y_shape = ctx.saved_tensors
    grad_x = unbroadcast(grad_output, x_shape)
    grad_y = unbroadcast(-grad_output, y_shape)
    return grad_x, grad_y
|
||||
|
||||
class Mul(Function):
  """Broadcasting elementwise multiplication for the torch backend."""

  def forward(ctx, x, y):
    """Save both operands and return x * y."""
    ctx.save_for_backward(x, y)
    product = x * y
    return product

  def backward(ctx, grad_output):
    """Product rule: each operand's gradient is the other operand times grad_output."""
    lhs, rhs = ctx.saved_tensors
    grad_lhs = unbroadcast(rhs * grad_output, lhs.shape)
    grad_rhs = unbroadcast(lhs * grad_output, rhs.shape)
    return grad_lhs, grad_rhs
|
||||
|
||||
class Pow(Function):
  """Broadcasting elementwise power (x ** y) for the torch backend."""

  def forward(ctx, x, y):
    """Save base and exponent and return x ** y."""
    ctx.save_for_backward(x, y)
    return x ** y

  def backward(ctx, grad_output):
    """Return the gradients with respect to the base and the exponent.

    d/dx x**y = y * x**(y-1)   and   d/dy x**y = x**y * log(x).
    """
    base, exponent = ctx.saved_tensors
    grad_base = unbroadcast(exponent * (base ** (exponent - 1.0)) * grad_output, base.shape)
    grad_exponent = unbroadcast((base ** exponent) * torch.log(base) * grad_output, exponent.shape)
    return grad_base, grad_exponent
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
class Sum(Function):
  """Sum reduction over all elements or over the given axis/axes."""

  def forward(ctx, input, axis=None):
    """Sum `input` over `axis`.

    With `axis=None` everything is summed and the result has shape (1,);
    otherwise the reduced axes are dropped from the result.
    """
    ctx.save_for_backward(input, axis)
    if axis is None:
      return input.sum().reshape((1,))
    return input.sum(axis)

  def backward(ctx, grad_output):
    """Broadcast the incoming gradient back to the shape of the input."""
    input, axis = ctx.saved_tensors
    ndim = len(input.shape)
    if axis is None:
      reduced = set(range(ndim))
    else:
      if isinstance(axis, int): axis = [axis]
      # map negative axes to their positive index so the membership test
      # below matches the axes that forward actually reduced
      reduced = {a + ndim if a < 0 else a for a in axis}
    # reduced axes become size 1 so grad_output broadcasts against the input
    shape = [1 if i in reduced else dim for i, dim in enumerate(input.shape)]
    return grad_output.reshape(shape) + torch.zeros_like(input)
|
||||
|
||||
class Max(Function):
  """Max reduction over all elements or over the given axis/axes."""

  def forward(ctx, inp, axis=None):
    """Return the maximum of `inp` over `axis`.

    With `axis=None` every axis is reduced and the result keeps one
    size-1 axis per input axis; otherwise the reduced axes are dropped.
    """
    if isinstance(axis, int): axis = [axis]
    ndim = len(inp.shape)
    if axis is not None:
      # map negative axes to their positive index so the `in axis` tests match
      axis = [a + ndim if a < 0 else a for a in axis]
    # spell out every axis instead of passing None as the reduction dim
    dims = tuple(range(ndim)) if axis is None else tuple(axis)
    ret = torch.amax(inp, dim=dims, keepdim=True)
    # the keepdim result is saved so backward can compare it against the input
    ctx.save_for_backward(inp, axis, ret)
    if axis is not None:
      ret = ret.reshape([dim for i, dim in enumerate(inp.shape) if i not in axis])
    return ret

  def backward(ctx, grad_output):
    """Route the gradient to the max positions, split evenly between ties."""
    input, axis, ret = ctx.saved_tensors
    ndim = len(input.shape)
    dims = tuple(range(ndim)) if axis is None else tuple(axis)
    shape = [1 if i in dims else dim for i, dim in enumerate(input.shape)]
    is_max = (input == ret.reshape(shape))
    # number of elements tied for the max in each reduced group
    ties = is_max.sum(dim=dims, keepdim=True)
    return is_max * grad_output.reshape(shape) / ties
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
class Reshape(Function):
  """Reshape for the torch backend."""

  def forward(ctx, x, shape):
    """Remember the original shape and return `x` viewed as `shape`."""
    original_shape = x.shape
    ctx.save_for_backward(original_shape)
    return x.reshape(shape)

  def backward(ctx, grad_output):
    """Reshape the gradient back to the shape the input had."""
    (original_shape,) = ctx.saved_tensors
    return grad_output.reshape(original_shape)
|
||||
|
||||
class Transpose(Function):
  """Axis permutation for the torch backend."""

  def forward(ctx, x, order):
    """Return `x` with its axes rearranged according to `order`.

    `order` is a full permutation of the axes (numpy `transpose` style).
    torch.transpose only swaps two dims, so `permute` is used instead.
    """
    ctx.save_for_backward(order)
    return x.permute(tuple(order))

  def backward(ctx, x):
    """Apply the inverse permutation to the incoming gradient `x`."""
    order, = ctx.saved_tensors
    rank = len(order)
    # argsort of a permutation is its inverse; done in plain Python because
    # torch.argsort needs a tensor. `% rank` lets negative axes sort correctly.
    inverse = sorted(range(rank), key=lambda i: order[i] % rank)
    return x.permute(tuple(inverse))
|
||||
|
||||
def inner_slice(x, arg):
  """Slice `x` with one (start, stop) pair per axis, zero-padding where needed.

  A start below 0 or a stop beyond the axis length is allowed: the tensor
  is first padded with zeros so the requested window exists, then sliced.
  `arg` is expected to hold one pair for every axis of `x`.
  """
  # (before, after) zero padding required on each axis, in axis order
  padding = [(max(0, -p[0]), max(0, p[1] - x.shape[i])) for i, p in enumerate(arg)]
  # torch.nn.functional.pad lists pad amounts starting from the LAST axis,
  # so the per-axis pairs have to be reversed before flattening
  flat_padding = [amount for pair in reversed(padding) for amount in pair]
  x = torch.nn.functional.pad(x, flat_padding)
  # shift the requested window by the amount padded in front of each axis
  bounds = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i, p in enumerate(arg)]
  return x[tuple(slice(lo, hi, None) for lo, hi in bounds)]
|
||||
|
||||
class Slice(Function):
  """Padded slicing for the torch backend, built on inner_slice."""

  def forward(ctx, x, arg=None):
    """Remember the input shape and return inner_slice(x, arg)."""
    ctx.save_for_backward(x.shape)
    return inner_slice(x, arg)

  def backward(ctx, grad_output):
    """Slice/pad the gradient back to the input shape.

    Reads `ctx.arg`, which is not assigned anywhere in this class.
    NOTE(review): presumably populated by the Function machinery from the
    forward keyword argument - confirm against tensor.py.
    """
    (in_shape,) = ctx.saved_tensors
    inverse_arg = []
    for i, p in enumerate(ctx.arg):
      start = 0 - p[0]
      stop = grad_output.shape[i] + (in_shape[i] - p[1])
      inverse_arg.append((start, stop))
    return inner_slice(grad_output, inverse_arg)
|
||||
|
||||
# ************* processing ops *************
|
||||
|
||||
class Matmul(Function):
  """Matrix multiplication for the torch backend."""

  def forward(ctx, input, weight):
    """Save both operands and return input @ weight."""
    ctx.save_for_backward(input, weight)
    return torch.matmul(input, weight)

  def backward(ctx, grad_output):
    """Return (grad_input, grad_weight).

    grad_input  = grad_output @ weight^T
    grad_weight = input^T @ grad_output
    where ^T swaps the last two axes.
    """
    lhs, rhs = ctx.saved_tensors
    rhs_t = torch.swapaxes(rhs, -2, -1)
    lhs_t = torch.swapaxes(lhs, -2, -1)
    return grad_output @ rhs_t, lhs_t @ grad_output
|
||||
|
||||
class Conv2D(Function):
  """2D convolution (no padding) for the torch backend."""

  def forward(ctx, x, w, stride=1, groups=1):
    """Convolve `x` with filters `w` using the given stride and groups."""
    # stride and groups are saved too: the gradients depend on them
    ctx.save_for_backward(x, w, stride, groups)
    return torch.nn.functional.conv2d(x, w, stride=stride, groups=groups)

  def backward(ctx, grad_output):
    """Return (grad_input, grad_weight) for the convolution run in forward."""
    x, w, stride, groups = ctx.saved_tensors
    # the grad helpers default to stride=1, groups=1, so the forward
    # settings must be passed on or the gradients are wrong / mis-shaped
    grad_input = torch.nn.grad.conv2d_input(x.shape, w, grad_output, stride=stride, groups=groups)
    grad_weight = torch.nn.grad.conv2d_weight(x, w.shape, grad_output, stride=stride, groups=groups)
    return grad_input, grad_weight
|
||||
|
|
@ -55,20 +55,11 @@ class GPUBuffer:
|
|||
def __repr__(self):
|
||||
return f"<GPUBuffer with shape {self.shape!r}>"
|
||||
|
||||
# **** ANE functions ****
|
||||
|
||||
ane = None
|
||||
def require_init_ane():
|
||||
global ane
|
||||
if ane is None:
|
||||
import accel.ane.lib.ane as anelib, accel.ane.tinygrad.ops_ane as ops_ane
|
||||
ane = anelib.ANE()
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Device: CPU, GPU, ANE = 0, 1, 2
|
||||
class Device: CPU, GPU, TORCH = 0, 1, 2
|
||||
|
||||
DEFAULT_DEVICE = Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU
|
||||
DEFAULT_DEVICE = (Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU) if os.environ.get("TORCH", 0) != "1" else Device.TORCH
|
||||
|
||||
class Tensor:
|
||||
did_float_warning = False
|
||||
|
|
@ -95,7 +86,10 @@ class Tensor:
|
|||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.data.dtype
|
||||
if self.device == Device.TORCH:
|
||||
return np.float32
|
||||
else:
|
||||
return self.data.dtype
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
|
|
@ -165,10 +159,8 @@ class Tensor:
|
|||
with ProfileOp("toCPU", [data]):
|
||||
cl.enqueue_copy(cl_queue, data, old.cl, is_blocking=True)
|
||||
|
||||
elif "ANETensor" in str(type(data)):
|
||||
if device == Device.ANE: return data
|
||||
with ProfileOp("toCPU", [data]):
|
||||
data = data.data().astype(np.float32)
|
||||
if str(type(data)).startswith("torch"):
|
||||
data = data.numpy()
|
||||
|
||||
if not isinstance(data, np.ndarray):
|
||||
data = np.array(data, dtype=np.float32)
|
||||
|
|
@ -183,12 +175,11 @@ class Tensor:
|
|||
with ProfileOp("toGPU", [data]):
|
||||
return GPUBuffer(data.shape, data)
|
||||
|
||||
elif device == Device.ANE:
|
||||
require_init_ane()
|
||||
with ProfileOp("toANE", [data]):
|
||||
ndata = ane.tensor(data.shape)
|
||||
ndata.data()[:] = data
|
||||
return ndata
|
||||
if device == Device.TORCH:
|
||||
import torch
|
||||
with ProfileOp("toTORCH", [data]):
|
||||
return torch.from_numpy(data)
|
||||
|
||||
return data
|
||||
|
||||
def to_(self, device):
|
||||
|
|
@ -335,7 +326,7 @@ def register(name, fxn, device=Device.CPU):
|
|||
tt = [arg for arg in x if isinstance(arg, Tensor)][0]
|
||||
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
|
||||
f = Tensor.ops[tt.device][name]
|
||||
f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
|
||||
f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
|
||||
return f.apply(f, *x, **kwargs)
|
||||
setattr(Tensor, name, dispatch)
|
||||
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
|
||||
|
|
@ -354,9 +345,6 @@ def _register_ops(namespace, device=Device.CPU):
|
|||
|
||||
from tinygrad import ops_cpu
|
||||
_register_ops(ops_cpu)
|
||||
if os.getenv("CHERRY", None) is not None:
|
||||
from accel.cherry.tinygrad import ops_cherry
|
||||
_register_ops(ops_cherry)
|
||||
try:
|
||||
import pyopencl as cl
|
||||
# TODO: move this import to require_init_gpu?
|
||||
|
|
@ -366,4 +354,9 @@ try:
|
|||
except ImportError:
|
||||
# no GPU support
|
||||
GPU = False
|
||||
ANE = False
|
||||
try:
|
||||
import torch
|
||||
from tinygrad import ops_torch
|
||||
_register_ops(ops_torch, device=Device.TORCH)
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue