mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
small beautiful_mnist update (#11596)
Gather is fast now. There's a conv/bw kernel that only gets fast with BEAM, but the whole thing runs in < 5 seconds now regardless.
This commit is contained in:
parent
45baec1aab
commit
7338ffead0
1 changed file with 2 additions and 3 deletions
|
|
@ -1,12 +1,12 @@
|
|||
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import List, Callable
|
||||
from typing import Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||
self.layers: list[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
|
|
@ -28,7 +28,6 @@ if __name__ == "__main__":
|
|||
def train_step() -> Tensor:
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue