mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Promote Embedding to nn (#798)
* feat: promote Embedding to nn * fix: fix failing test * feat: add test with jit * feat: rewrite embedding to no longer need stacked for loops * clean+fix: don't know how that happened
This commit is contained in:
parent
f4f23dc9a3
commit
0dc333cfab
5 changed files with 51 additions and 34 deletions
|
|
@ -14,7 +14,7 @@ from tinygrad.helpers import getenv, DEBUG
|
|||
from tinygrad.lazy import Device
|
||||
from extra.helpers import Timing
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Linear
|
||||
from tinygrad.nn import Embedding, Linear
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
|
|
@ -133,13 +133,13 @@ class Transformer:
|
|||
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, max_batch_size=32, max_seq_len=1024):
|
||||
self.layers = [TransformerBlock(dim, multiple_of, n_heads, norm_eps) for _ in range(n_layers)]
|
||||
self.norm = RMSNorm(dim, norm_eps)
|
||||
self.tok_embeddings = {"weight": Tensor.glorot_uniform(vocab_size, dim)}
|
||||
self.tok_embeddings = Embedding(vocab_size, dim)
|
||||
self.output = Linear(dim, vocab_size, bias=False)
|
||||
self.freqs_cis = Tensor(precompute_freqs_cis(dim // n_heads, max_seq_len * 2))
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int):
|
||||
_bsz, seqlen, _ = tokens.shape
|
||||
h = tokens @ self.tok_embeddings['weight']
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
# get only the part we are using. making it contiguous avoids more kernel calls
|
||||
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize()
|
||||
|
|
@ -174,13 +174,6 @@ WEIGHTS_13B_0_FILENAME = WEIGHTS_DIR / "13B/consolidated.00.pth"
|
|||
WEIGHTS_13B_1_FILENAME = WEIGHTS_DIR / "13B/consolidated.01.pth"
|
||||
|
||||
# **** helper functions ****
|
||||
|
||||
def onehot_encode(toks, vocab_size=VOCAB_SIZE):
|
||||
# this allows the embedding to work in tinygrad
|
||||
onehot = np.zeros((1, len(toks), vocab_size), dtype=np.float32)
|
||||
onehot[0,range(len(toks)),toks] = 1
|
||||
return Tensor(onehot)
|
||||
|
||||
def sample(logits, temperature):
|
||||
if temperature < 1e-6:
|
||||
# so close to 0 we use argmax
|
||||
|
|
@ -365,7 +358,7 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
|
||||
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
|
||||
with Timing():
|
||||
model(onehot_encode(toks), 0).realize() # NOTE: output logits are not used
|
||||
model(Tensor([toks]), 0).realize() # NOTE: output logits are not used
|
||||
start_pos = len(toks)
|
||||
else:
|
||||
# non chat bot mode
|
||||
|
|
@ -400,7 +393,7 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
if args.timing: print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Timing("ran model in ", on_exit=(lambda et: f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU") if DEBUG else None, enabled=args.timing):
|
||||
logits = model(onehot_encode(toks[start_pos:]), start_pos).realize()
|
||||
logits = model(Tensor([toks[start_pos:]]), start_pos).realize()
|
||||
with Timing("sync in ", enabled=args.timing):
|
||||
tok = sample(logits, args.temperature)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.nn import Linear
|
||||
from tinygrad.nn import Linear, Embedding
|
||||
import numpy as np
|
||||
from extra.utils import download_file
|
||||
from pathlib import Path
|
||||
|
|
@ -171,22 +171,6 @@ class Encoder:
|
|||
return x.transpose(0, 1), x_lens
|
||||
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vocab_size: int, embed_size: int):
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False)
|
||||
self.weight = Tensor.scaled_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx: Tensor) -> Tensor:
|
||||
oha = []
|
||||
for i in range(idx.shape[0]):
|
||||
ohba = []
|
||||
for j in range(idx.shape[1]):
|
||||
ohba.append((self.vocab_counter == idx[i, j]).realize())
|
||||
oha.append(Tensor.stack(ohba).realize())
|
||||
return Tensor.stack(oha) @ self.weight
|
||||
|
||||
|
||||
class Prediction:
|
||||
def __init__(self, vocab_size, hidden_size, layers, dropout):
|
||||
self.hidden_size = hidden_size
|
||||
|
|
|
|||
6
test/external/external_test_opt.py
vendored
6
test/external/external_test_opt.py
vendored
|
|
@ -81,12 +81,12 @@ class TestInferenceMinKernels(unittest.TestCase):
|
|||
out.realize()
|
||||
|
||||
def test_llama(self):
|
||||
from examples.llama import Transformer, onehot_encode
|
||||
from examples.llama import Transformer
|
||||
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
|
||||
model = Transformer(**args_tiny)
|
||||
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
|
||||
with CLCache(85):
|
||||
model(onehot_encode([1,2,3,4], vocab_size=args_tiny['vocab_size']), 0).realize()
|
||||
with CLCache(86):
|
||||
model(Tensor([[1,2,3,4]]), 0).realize()
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
|
||||
class TestOptBinOp(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
|
||||
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding
|
||||
import torch
|
||||
|
||||
class TestNN(unittest.TestCase):
|
||||
|
|
@ -150,5 +151,35 @@ class TestNN(unittest.TestCase):
|
|||
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
def test_embedding(self):
|
||||
B, T, C, VS = 4, 10, 20, 28
|
||||
|
||||
# create in tinygrad
|
||||
layer = Embedding(VS, C)
|
||||
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Embedding(VS, C).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().numpy().astype(np.int32))
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
|
||||
|
||||
# test with jit enabled
|
||||
@TinyJit
|
||||
def layer_jit(x):
|
||||
return layer(x).realize()
|
||||
|
||||
for _ in range(3):
|
||||
x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
|
||||
z = layer_jit(x)
|
||||
torch_x = torch.tensor(x.cpu().numpy().astype(np.int32))
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -81,3 +81,12 @@ class LayerNorm:
|
|||
|
||||
class LayerNorm2d(LayerNorm):
|
||||
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vocab_size:int, embed_size:int):
|
||||
self.vocab_size = vocab_size
|
||||
self.weight = Tensor.glorot_uniform(vocab_size, embed_size)
|
||||
|
||||
def __call__(self, idx:Tensor) -> Tensor:
|
||||
vocab_counter = Tensor.arange(self.vocab_size, requires_grad=False).reshape(1, 1, self.vocab_size).expand(*idx.shape, self.vocab_size)
|
||||
return (vocab_counter == idx.unsqueeze(2).expand(*idx.shape, self.vocab_size)) @ self.weight
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue