spend 5 lines to bring mnist into the repo (#4122)

This commit is contained in:
George Hotz 2024-04-09 19:24:57 -07:00 committed by GitHub
commit fea774f669
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 15 additions and 8 deletions

View file

@ -2,7 +2,7 @@
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad.helpers import getenv, colored
from extra.datasets import fetch_mnist
from tinygrad.features.datasets import mnist
from tqdm import trange
class Model:
@ -19,7 +19,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
X_train, Y_train, X_test, Y_test = mnist()
model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model))

View file

@ -1,6 +1,6 @@
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.engine.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
from tinygrad.dtype import dtypes # noqa: F401
from tinygrad.helpers import GlobalCounters # noqa: F401
from tinygrad.device import Device # noqa: F401
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.engine.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
from tinygrad.dtype import dtypes # noqa: F401
from tinygrad.helpers import GlobalCounters, fetch # noqa: F401
from tinygrad.device import Device # noqa: F401

View file

@ -0,0 +1,7 @@
import gzip
from tinygrad import Tensor, fetch
def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
def mnist():
return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
_fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)