tinygrad/extra/models
Yixiang Gao 13e872b53f
add multigpu support for llama attention (#3064)
* add llama attention test for multigpu
* test fails
* kv cache trying to shrink on sharded axis
* mask None works for scaled dot product
* kv cache seems to be working but scaled dot product breaks
* scaled dot product works, but the last linear layer failed
* running into the reshape case where it could be wrong for multigpu
* making sure it was the reshape
* adding contiguous doesn't solve it
* need to shard more properly
* remove reshape test
* minor adjustment to scaled dot product attention test
* weights are sharded wrong
* continue fixing new weight sharding
* clean up
* fix attention when start_pos is 0
* remove print
* add TODOs for the best multigpu interface
2024-01-11 16:31:02 -08:00
File             Last commit                                         Date
bert.py          move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00
convnext.py      move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00
efficientnet.py  add name support to fetch (#2407)                   2023-11-23 14:16:17 -08:00
llama.py         add multigpu support for llama attention (#3064)    2024-01-11 16:31:02 -08:00
mask_rcnn.py     move dtypes to dtype.py (#2964)                     2024-01-01 14:58:48 -08:00
resnet.py        move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00
retinanet.py     move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00
rnnt.py          move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00
transformer.py   fix onehot and jit in examples/transformer (#3073)  2024-01-10 02:22:41 -05:00
unet3d.py        move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00
vit.py           move to new cached fetch (#2493)                    2023-11-28 17:36:55 -08:00