Promote Embedding to nn (#798)

* feat: promote Embedding to nn

* fix: fix failing test

* feat: add test with jit

* feat: rewrite embedding to no longer need stacked for loops

* clean+fix: don't know how that happened
This commit is contained in:
wozeparrot 2023-05-25 21:39:45 -04:00 committed by GitHub
commit 0dc333cfab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 34 deletions

View file

@ -14,7 +14,7 @@ from tinygrad.helpers import getenv, DEBUG
from tinygrad.lazy import Device
from extra.helpers import Timing
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from tinygrad.nn import Embedding, Linear
from tinygrad.ops import GlobalCounters
from tinygrad.jit import TinyJit
@ -133,13 +133,13 @@ class Transformer:
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024):
self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps) for _ in range(n_layers)]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = {"weight": Tensor.glorot_uniform(vocab_size, dim)}
self.tok_embeddings = Embedding(vocab_size, dim)
self.output = Linear(dim, vocab_size, bias=False)
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
def __call__(self, tokens:Tensor, start_pos:int):
_bsz, seqlen, _ = tokens.shape
h = tokens @ self.tok_embeddings['weight']
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
# get only the part we are using. making it contiguous avoids more kernel calls
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize()
@ -174,13 +174,6 @@ WEIGHTS_13B_0_FILENAME = WEIGHTS_DIR / "13B/consolidated.00.pth"
WEIGHTS_13B_1_FILENAME = WEIGHTS_DIR / "13B/consolidated.01.pth"
# **** helper functions ****
def onehot_encode(toks, vocab_size=VOCAB_SIZE):
# this allows the embedding to work in tinygrad
onehot = np.zeros((1, len(toks), vocab_size), dtype=np.float32)
onehot[0,range(len(toks)),toks] = 1
return Tensor(onehot)
def sample(logits, temperature):
if temperature < 1e-6:
# so close to 0 we use argmax
@ -365,7 +358,7 @@ After you are done speaking, output [EOS]. You are not Chad.
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
with Timing():
model(onehot_encode(toks), 0).realize() # NOTE: output logits are not used
model(Tensor([toks]), 0).realize() # NOTE: output logits are not used
start_pos = len(toks)
else:
# non chat bot mode
@ -400,7 +393,7 @@ After you are done speaking, output [EOS]. You are not Chad.
if args.timing: print("")
st = GlobalCounters.time_sum_s
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
logits = model(onehot_encode(toks[start_pos:]), start_pos).realize()
logits = model(Tensor([toks[start_pos:]]), start_pos).realize()
with Timing("sync in ", enabled=args.timing):
tok = sample(logits, args.temperature)

View file

@ -1,6 +1,6 @@
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.nn import Linear
from tinygrad.nn import Linear, Embedding
import numpy as np
from extra.utils import download_file
from pathlib import Path
@ -171,22 +171,6 @@ class Encoder:
return x.transpose(0, 1), x_lens
class Embedding:
def __init__(self, vocab_size: int, embed_size: int):
self.vocab_size = vocab_size
self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False)
self.weight = Tensor.scaled_uniform(vocab_size, embed_size)
def __call__(self, idx: Tensor) -> Tensor:
oha = []
for i in range(idx.shape[0]):
ohba = []
for j in range(idx.shape[1]):
ohba.append((self.vocab_counter == idx[i, j]).realize())
oha.append(Tensor.stack(ohba).realize())
return Tensor.stack(oha) @ self.weight
class Prediction:
def __init__(self, vocab_size, hidden_size, layers, dropout):
self.hidden_size = hidden_size

View file

@ -81,12 +81,12 @@ class TestInferenceMinKernels(unittest.TestCase):
out.realize()
def test_llama(self):
from examples.llama import Transformer, onehot_encode
from examples.llama import Transformer
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
with CLCache(85):
model(onehot_encode([1,2,3,4], vocab_size=args_tiny['vocab_size']), 0).realize()
with CLCache(86):
model(Tensor([[1,2,3,4]]), 0).realize()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptBinOp(unittest.TestCase):

View file

@ -1,8 +1,9 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding
import torch
class TestNN(unittest.TestCase):
@ -150,5 +151,35 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
def test_embedding(self):
B, T, C, VS = 4, 10, 20, 28
# create in tinygrad
layer = Embedding(VS, C)
with torch.no_grad():
torch_layer = torch.nn.Embedding(VS, C).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
# test
x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
z = layer(x)
torch_x = torch.tensor(x.cpu().numpy().astype(np.int32))
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
# test with jit enabled
@TinyJit
def layer_jit(x):
return layer(x).realize()
for _ in range(3):
x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
z = layer_jit(x)
torch_x = torch.tensor(x.cpu().numpy().astype(np.int32))
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
if __name__ == '__main__':
unittest.main()

View file

@ -81,3 +81,12 @@ class LayerNorm:
class LayerNorm2d(LayerNorm):
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class Embedding:
def __init__(self, vocab_size:int, embed_size:int):
self.vocab_size = vocab_size
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
def __call__(self, idx:Tensor) -> Tensor:
vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size)
return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight