isolate the 134ms kernel in train_gpt2.py (#4773)

133ms on tinybox red with BEAM=2
This commit is contained in:
chenyu 2024-05-29 17:26:24 -04:00 committed by GitHub
commit cde7a7cda7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

10
test/external/external_test_lm_head.py vendored Normal file
View file

@ -0,0 +1,10 @@
from tinygrad import Tensor, nn
if __name__ == "__main__":
  # Standalone reproduction of the lm_head projection from train_gpt2.py, so
  # that this single matmul kernel can be run and timed on its own.
  n_embd = 768
  vocab_size = 50257
  # Bias-free linear layer mapping embedding width -> vocabulary size.
  lm_head = nn.Linear(n_embd, vocab_size, bias=False)

  # Random input of shape (batch, sequence, embedding); the values themselves
  # do not matter here, only the shapes fed to the kernel.
  batch_size, seq_len = 4, 1024
  input_shape = (batch_size, seq_len, n_embd)
  hidden = Tensor.rand(*input_shape)

  # Apply the projection, then realize it so the computation actually runs
  # rather than staying pending.
  out = lm_head(hidden)
  out.realize()