tinygrad/examples
hooved 69857d0ab0
Stable Diffusion mlperf training (#11304)
* entrypoint for sd mlperf train development

* match sd-v2 mlperf reference unet

* implement dataloader from mlperf ref

* update dataloader reference

* implement LambdaLR scheduler from mlperf ref

* match tokenizer from mlperf reference

* sample latent

* add noise to latent

* complete training epoch

* run full training step

* jit training loop

* replicate mlperf ref. losses over 11 train steps

* save tinygrad loss checkpoints properly

* match out.2.bias.grad to reference

* match weights to ref after 1 step

* compare out.2.bias to ref over three train steps

* implement attn_mask; cleanup closeness testing

* correct mse loss

* update dev_run / dependencies

* setup validation config/checkpointing

* implement validation sampling

* test closeness of eval denoise step to mlperf ref

* test closeness of decoder to mlperf ref

* confirm inception matches mlperf ref

* resize w/ bicubic interpolation, test closeness

* confirm closeness of clip preprocess to mlperf ref

* confirm clip score matches mlperf ref

* confirm fid/clip scores match mlperf ref

* cleanup

* cleanup

* zero-init some unet params as in mlperf reference

* revert jit change

* uncomment dependencies

* move to tinybox red

* implement GradScaler from torch but jittable

* simplify lr_scheduler, ensure jittability

* instantiate GradScaler

* only check if grads are finite with fp16

* implement fp16 training loop

* refactor UNet: norm, gelu, mixed precision

* refactor clip_tokenizer to enable versioning

* make fp16 attention closer to torch

* remove comparisons to torch fp16 attention

* add globvars.py for reference

* confirm closeness of fp16 unet forward to mlperf

* test norm closeness to torch with precast

* remeasure e2e with master attention

* more detailed softmax upcast comparison to torch

* parameterize softmax upcast in attention and unet

* use fp32 weights with autocast to fp16

* cleanup

* add data/checkpoint download script

* debug kernel timeout on AMD

* fix finite grads check; start multigpu

* pass numpy arrays from dataloader

* include text encoder in jit train step

* use int32 for tokens instead of int64

* prevent multi bug in reshape within clip

* corealize more, del refs before

* add more logging and wandb

* use erf gelu in clip encoder

* minor changes to train step and logging

* save checkpoints for eval or resuming

* add eval-only logic to training script

* multigpu eval

* remove PARALLEL=0

* cleanup

* pad eval batches of size < EVAL_BS

* workaround silent multigpu bug in jit

* cleanup

* tokenize captions

* verify correctness of multigpu eval

* cleanup

* verify correctness of grads in train step

* verify correctness of training (20 steps)

* don't shard in the training jit

* training settings

* minor cleanup

* overfit train w/ eval on 6 samples

* offload to enable combined train and eval

* download to raid; use local rclone

* misc changes for mi300x / logging

* refactor eval for larger BS, verify correctness

* cleanup

* ckpt resuming and remove eval cats

* eval BEAM config on mi300x and red

* resume eval after crash

* confirm eval correctness (one iteration, 6 samples)

* verify eval correctness at full scale

* cleanup correctness testing

* training correctness (20 steps, BS=248 uniform)

* cleanup

* remove eval cache at end of run

* switch f16 for bf16, del grad scaler

* confirm bf16 training correctness

* timestamps, new jits

* merge jits in training

* realize loss/lr on CPU

* training correctness

* post-bf16 train/eval

* implement grad_acc with timing/logging

* beam offline; debug gradacc; use float32

* fix gradacc in jit, correctness test

* prepare f32 BS=512 gradacc=4 run

* workaround jit problem in diffusion eval

* scale lr by BS

* revert gradacc, prepare bf16 BS=336 lr*=BS train

* make checkpointing faster

* resume bf16 BS=336 base_lr=1.25e-7 run

* jit ckpt at beginning

* don't alloc more gpu mem in ckpt

* cleanup

* move script to mi300x dir

* cleanup

* cleanup unneeded files

* revert beam search to master

* minor changes

* fix regression: realize before assign in eval

* cleanup mlperf SD data/ckpt downloads

* workaround BEAM failure

* workaround bug in Tensor.stack

* minor changes

* revert gradscaler

* cleanup

* cleanup/validate dataloader

* ensure checksum of laion data

* simplify config

* load training state to jitted bufs

* simplify lr scheduler

* simplify train script

* cleanup comments

* refactor stable diffusion/unet init

* more refactoring of stable diffusion init

* fix import errors in tests

* refactor: separate train/eval

* fix import errors

* eval checkpoints in reverse chron. order

* save/load cycle in sd init

* refactor and verify eval

* verify training correctness

* prepare repro train run

* cleanup

* integrate beam retry, train, eval

* simplify wandb

* kill orphaned processes

* better logging

* train to 10 ckpts instead of 7

* remove optimizer/scheduler checkpointing/resume

* cleanup

* BEAM=2 7 ckpts

* add test to compare with torch softmax in amp

* cleanup

* stop eval early if checkpoint converged

* add test for lr scheduler

* add proper test method

* add test for training

* use venv name that is ignored by .gitignore

* linting

* add simple f32 softmax fxn

* revert change to scaled_dot_product_attention

* refactor gelu_erf init

* simplify mixed precision in unet

* add norm autocasting to fp32

* rm extra test

* test eval with NULL backend

* fix venv name

* simplify norm autocast

* use temp dir for training test

* actually add eval test

* remove parallel env variable from tests

* update clip with tests

* reorg init functions

* use np for testing

* remove unused var

* factor out GPUS

* add sd model init tests

* more unet tests

* match master

* rerun CI due to linux (remote) hang

* explain UNET_CKPTDIR

* rerun CI due to linux (remote) timeout

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2025-10-05 07:56:05 -04:00
conversation_data Whisper + LLAMA + VITS (#2332) 2023-12-02 15:03:46 -08:00
llm.c clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00
mlperf Stable Diffusion mlperf training (#11304) 2025-10-05 07:56:05 -04:00
openpilot ops_gpu -> ops_cl (#12103) 2025-09-10 15:15:48 -04:00
other_mnist clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00
rl more beautiful_cartpole with exposed hparams 2024-01-07 17:41:09 -08:00
sovits_helpers combine pad2d with pad (#7677) 2024-11-14 17:56:02 +08:00
tinychat rename lazydata to uop (#10698) 2025-06-08 08:42:22 -07:00
vgg7_helpers leakyrelu to leaky_relu (#9270) 2025-02-26 13:22:08 -05:00
webgpu rename lazydata to uop (#10698) 2025-06-08 08:42:22 -07:00
__init__.py failing llama test 2023-03-11 16:28:10 -08:00
beautiful_cartpole.py remove Tensor.no_grad, it's meaningless now [pr] (#10556) 2025-05-28 22:20:02 -07:00
beautiful_cifar.py remove np from beautiful_cifar (#10988) 2025-08-29 19:34:16 -04:00
beautiful_mnist.py test backward in test_tiny (#11697) 2025-08-16 20:29:39 -07:00
beautiful_mnist_multigpu.py Fix mypy examples/beautiful_*.py (#6978) 2024-10-10 11:34:29 -04:00
benchmark_onnx.py OnnxRunner file as input (#10789) 2025-07-12 14:27:46 -04:00
coder.py clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00
compile_efficientnet.py CLANG -> CPU (#9189) 2025-02-20 18:03:09 -05:00
compile_tensorflow.py hcq: move cpu to hcq (#11262) 2025-07-21 15:10:38 +03:00
conversation.py remove Tensor.no_grad, it's meaningless now [pr] (#10556) 2025-05-28 22:20:02 -07:00
efficientnet.py remove clang program header (#4422) 2024-05-04 08:38:01 -07:00
flux1.py flux set model path in args (#7660) 2024-11-12 22:11:40 -05:00
flux1_seed0.png Flux.1 (#6334) 2024-09-24 10:08:04 +08:00
gpt2.py push copy to disk (#12348) 2025-09-29 21:55:05 -07:00
hlb_cifar10.py fix cifar training in RANGEIFY (#12355) 2025-09-30 15:59:19 +08:00
llama.py replace hardcoded GPU in llama debug msg (#12102) 2025-09-10 13:56:40 -04:00
llama3.py replace hardcoded GPU in llama debug msg (#12102) 2025-09-10 13:56:40 -04:00
mamba.py added top k sampling to examples/mamba (#12061) 2025-09-14 15:27:34 -04:00
mask_rcnn.py change Tensor.stack to method (#4719) 2024-05-24 17:04:19 -04:00
minrf.py remove Tensor.no_grad, it's meaningless now [pr] (#10556) 2025-05-28 22:20:02 -07:00
mixtral.py Subtract 1 from Variable upper bound (#10715) 2025-06-09 09:25:53 -07:00
mnist_gan.py leakyrelu to leaky_relu (#9270) 2025-02-26 13:22:08 -05:00
olmoe.py remove .float calls in olmoe (#11610) 2025-08-10 20:33:22 -04:00
openelm.py nn.RMSNorm (#5272) 2024-07-02 21:39:01 -04:00
qwq.py replace hardcoded GPU in llama debug msg (#12102) 2025-09-10 13:56:40 -04:00
sdv2.py remove Tensor.no_grad, it's meaningless now [pr] (#10556) 2025-05-28 22:20:02 -07:00
sdxl.py don't validate output in sdxl with fakeweights (#12160) 2025-09-13 21:47:51 -04:00
sdxl_seed0.png fix failed threefry (#10646) 2025-06-05 17:17:42 -07:00
self_tokenize.py make self_tokenize output more like a python file (#8411) 2024-12-25 14:16:30 -05:00
serious_mnist.py combine pad2d with pad (#7677) 2024-11-14 17:56:02 +08:00
simple_conv_bn.py fix various examples (#4691) 2024-05-22 20:43:21 -04:00
so_vits_svc.py small fix replacing download_file with fetch (#10877) 2025-06-19 12:12:09 -04:00
stable_diffusion.py Stable Diffusion model init for mlperf (#12314) 2025-10-02 02:28:41 -04:00
stable_diffusion_seed0.png default threefry (#6116) 2024-09-25 17:45:13 +08:00
stunning_mnist.py clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00
test_onnx_imagenet.py add MobileNetV2 benchmark to comma CI (#10250) 2025-05-19 18:22:50 +03:00
test_pkl_imagenet.py more stuff from DSP (#9689) 2025-04-02 15:27:48 +08:00
torch_cuda_kernel.py clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00
train_efficientnet.py tinytqdm.set_description and tinytrange (#5101) 2024-06-22 14:45:06 -04:00
train_resnet.py move things, clean up extra (#2292) 2023-11-13 20:18:40 -08:00
transformer.py fixing transformer training bug (#9877) 2025-04-13 19:34:20 -04:00
vgg7.py clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00
vit.py move to new cached fetch (#2493) 2023-11-28 17:36:55 -08:00
vits.py update vits vctk model to use download from huggingface (#10688) 2025-06-07 20:47:28 -04:00
whisper.py Tensor.pad_to and Tensor.shrink_to (#12210) 2025-09-16 12:24:55 -04:00
yolov3.py fix bugs at examples/yolov3.py (#11614) 2025-08-11 21:14:47 -04:00
yolov8-onnx.py OnnxRunner file as input (#10789) 2025-07-12 14:27:46 -04:00
yolov8.py clean up unused imports in examples and update CI linting (#11024) 2025-06-30 08:21:27 -07:00