mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
spend 5 lines to bring mnist into the repo (#4122)
This commit is contained in:
parent
42edae8935
commit
fea774f669
3 changed files with 15 additions and 8 deletions
|
|
@ -2,7 +2,7 @@
|
|||
from typing import List, Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from extra.datasets import fetch_mnist
|
||||
from tinygrad.features.datasets import mnist
|
||||
from tqdm import trange
|
||||
|
||||
class Model:
|
||||
|
|
@ -19,7 +19,7 @@ class Model:
|
|||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
||||
X_train, Y_train, X_test, Y_test = mnist()
|
||||
|
||||
model = Model()
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from tinygrad.tensor import Tensor # noqa: F401
|
||||
from tinygrad.engine.jit import TinyJit # noqa: F401
|
||||
from tinygrad.shape.symbolic import Variable # noqa: F401
|
||||
from tinygrad.dtype import dtypes # noqa: F401
|
||||
from tinygrad.helpers import GlobalCounters # noqa: F401
|
||||
from tinygrad.device import Device # noqa: F401
|
||||
from tinygrad.tensor import Tensor # noqa: F401
|
||||
from tinygrad.engine.jit import TinyJit # noqa: F401
|
||||
from tinygrad.shape.symbolic import Variable # noqa: F401
|
||||
from tinygrad.dtype import dtypes # noqa: F401
|
||||
from tinygrad.helpers import GlobalCounters, fetch # noqa: F401
|
||||
from tinygrad.device import Device # noqa: F401
|
||||
7
tinygrad/features/datasets.py
Normal file
7
tinygrad/features/datasets.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import gzip
|
||||
from tinygrad import Tensor, fetch
|
||||
|
||||
def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
|
||||
def mnist():
|
||||
return _fetch_mnist("train-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("train-labels-idx1-ubyte.gz", 8), \
|
||||
_fetch_mnist("t10k-images-idx3-ubyte.gz", 0x10).reshape(-1, 1, 28, 28), _fetch_mnist("t10k-labels-idx1-ubyte.gz", 8)
|
||||
Loading…
Add table
Add a link
Reference in a new issue