# TODO: implement BatchNorm2d and Swish # aka batch_norm, pad, swish, dropout # https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth # a rough copy of # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py from tinygrad.tensor import Tensor from tinygrad.utils import fetch def swish(x): return x.mul(x.sigmoid()) class BatchNorm2D: def __init__(self, sz, eps=0.001): self.eps = eps self.weight = Tensor.zeros(sz) self.bias = Tensor.zeros(sz) # TODO: need running_mean and running_var self.running_mean = Tensor.zeros(sz) self.running_var = Tensor.zeros(sz) self.num_batches_tracked = Tensor.zeros(0) def __call__(self, x): # this work at inference? x = x.sub(self.running_mean.reshape(shape=[1, -1, 1, 1])) x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1])) x = x.div(self.running_var.add(Tensor([self.eps])).reshape(shape=[1, -1, 1, 1]).sqrt()) x = x.add(self.bias.reshape(shape=[1, -1, 1, 1])) return x class MBConvBlock: def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio): oup = expand_ratio * input_filters if expand_ratio != 1: self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1) self._bn0 = BatchNorm2D(oup) else: self._expand_conv = None self.pad = (kernel_size-1)//2 self.strides = strides self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size) self._bn1 = BatchNorm2D(oup) num_squeezed_channels = max(1, int(input_filters * se_ratio)) self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1) self._se_reduce_bias = Tensor.zeros(num_squeezed_channels) self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1) self._se_expand_bias = Tensor.zeros(oup) self._project_conv = Tensor.zeros(output_filters, oup, 1, 1) self._bn2 = BatchNorm2D(output_filters) def __call__(self, x): if self._expand_conv: x = swish(self._bn0(x.conv2d(self._expand_conv))) x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad)) x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0]) x = swish(self._bn1(x)) # has_se x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4]) x_squeezed = swish(x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1]))) x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1])) x = x.mul(x_squeezed.sigmoid()) x = self._bn2(x.conv2d(self._project_conv)) return swish(x) class EfficientNet: def __init__(self): self._conv_stem = Tensor.zeros(32, 3, 3, 3) self._bn0 = BatchNorm2D(32) blocks_args = [ [1, 3, (1,1), 1, 32, 16, 0.25], [2, 3, (2,2), 6, 16, 24, 0.25], [2, 5, (2,2), 6, 24, 40, 0.25], [3, 3, (2,2), 6, 40, 80, 0.25], [3, 5, (1,1), 6, 80, 112, 0.25], [4, 5, (1,1), 6, 112, 192, 0.25], [1, 3, (1,1), 6, 192, 320, 0.25], ] self._blocks = [] # num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio for b in blocks_args: args = b[1:] for n in range(b[0]): self._blocks.append(MBConvBlock(*args)) args[3] = args[4] args[1] = (1,1) self._conv_head = Tensor.zeros(1280, 320, 1, 1) self._bn1 = BatchNorm2D(1280) self._fc = Tensor.zeros(1280, 1000) self._fc_bias = Tensor.zeros(1000) def forward(self, x): x = x.pad2d(padding=(0,1,0,1)) x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2))) for b in self._blocks: print(x.shape) x = b(x) x = swish(self._bn1(x.conv2d(self._conv_head))) x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280)) #x = x.dropout(0.2) return swish(x.dot(self._fc).add(self._fc_bias)) if __name__ == "__main__": import numpy as np np.set_printoptions(suppress=True) # instantiate my net model = EfficientNet() # load b0 import io, torch b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth") b0 = torch.load(io.BytesIO(b0)) for k,v in b0.items(): if '_blocks.' in k: k = "%s[%s].%s" % tuple(k.split(".", 2)) mk = "model."+k #print(k, v.shape) try: mv = eval(mk) except AttributeError: try: mv = eval(mk.replace(".weight", "")) except AttributeError: mv = eval(mk.replace(".bias", "_bias")) mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T # load cat image from PIL import Image img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"))) img = img.resize((224, 224)) img = np.moveaxis(np.array(img), [2,0,1], [0,1,2]) img = img.astype(np.float32).reshape(1,3,224,224) print(img.shape) #b0 = pickle.loads(b0) out = model.forward(Tensor(img)) print(np.argmax(out.data), np.max(out.data))