tinygrad/test/external/external_test_llama3_layer.py
Christopher Milan bc180a963c
deprecate <dev>=1 in favor of DEV=<dev> (#15467)
* start work on target

* add test

* update actions to use DEV

* update docs

* update readmes

* tests need that too

* update example

* update tests (comments)

* fix that test

* ruff

* mypy

* oops

* remove getenvs

* don't add Target yet

* and the test

* lint

* and docs

* more stuff

* assert

* few more fixes

* test assert
2026-03-26 03:48:03 -04:00

24 lines
1,004 B
Python

#!/usr/bin/env python3
from tinygrad import Tensor, TinyJit, nn, dtypes
from tinygrad.helpers import getenv
from extra.models.llama import TransformerBlock, precompute_freqs_cis
# Batch size and sequence length of the random input tensor; both are read
# from the environment via getenv, falling back to the defaults given here.
BS = getenv("BS", 1)
SEQLEN = getenv("SEQLEN", 128)

# Example invocation:
# DEFAULT_FLOAT=bfloat16 SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 DEV=NULL DEBUG=2 VIZ=1 PYTHONPATH="."
# python test/external/external_test_llama3_layer.py
if __name__ == "__main__":
  # Hard-coded hyperparameters for the single transformer block under test.
  dim, hidden_dim, n_heads, n_kv_heads, norm_eps = 4096, 14336, 32, 8, 1e-5
  layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)

  # Cast every parameter of the block to the default float dtype and realize it.
  for param in nn.state.get_parameters(layer):
    param.replace(param.cast(dtypes.default_float)).realize()

  # Precompute the frequency table once; it is realized up front and excluded
  # from gradient tracking.
  head_dim = dim // n_heads
  freqs_cis = precompute_freqs_cis(head_dim, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()

  @TinyJit
  def run(t):
    """Call the block on `t`, passing 0, the precomputed `freqs_cis`, and None."""
    return layer(t, 0, freqs_cis, None)

  # Run the jitted block five times, each on a freshly realized random input.
  for i in range(5):
    print(f"*** run {i}")
    run(Tensor.rand(BS, SEQLEN, dim, dtype=dtypes.default_float).realize())