mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
11 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee5f17bca2 | ||
|
|
1b879967c4 | ||
|
|
d11f6d316d | ||
|
|
b67def38d2 |
||
|
|
f9010fdfc9 | ||
|
|
bf116deb5a | ||
|
|
8179a07477 |
||
|
|
b14da7f9d4 | ||
|
|
dd2ff2ddb9 | ||
|
|
79393bddb4 |
||
|
|
348ab6c30f |
5 changed files with 114 additions and 0 deletions
31
test/unit/test_png.py
Normal file
31
test/unit/test_png.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
import io, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, fetch
|
||||
from tinygrad.nn.state import png_load
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
raise unittest.SkipTest("PIL not installed")
|
||||
|
||||
class TestPNGLoad(unittest.TestCase):
  """Compares `png_load` output with reference pixel data."""

  def test_real_png(self):
    # test against a real PNG file (uses only filters 0, 1)
    path = fetch('https://upload.wikimedia.org/wikipedia/en/d/d4/Norwegian_Forest_Cat_in_Norway.png')
    with open(path, 'rb') as fh:
      raw = fh.read()
    # PIL is the reference decoder; only the first three channels are compared
    reference = np.array(Image.open(io.BytesIO(raw)))[:, :, :3]
    decoded = png_load(Tensor(np.frombuffer(raw, dtype=np.uint8))).numpy()
    np.testing.assert_array_equal(decoded, reference)

  def test_roundtrip_png(self):
    # horizontal stripes pattern uses only filters 0, 1
    stripes = np.zeros((32, 32, 3), dtype=np.uint8)
    stripes[::2] = 255  # white stripes on black
    sink = io.BytesIO()
    Image.fromarray(stripes).save(sink, format='PNG')
    encoded = sink.getvalue()
    decoded = png_load(Tensor(np.frombuffer(encoded, dtype=np.uint8))).numpy()
    np.testing.assert_array_equal(decoded, stripes)

if __name__ == '__main__':
  unittest.main()
|
||||
50
tinygrad/apps/resnet.py
Normal file
50
tinygrad/apps/resnet.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
# classification in 50 lines
|
||||
import sys
|
||||
from tinygrad import nn, Tensor
|
||||
|
||||
class Bottleneck:
  """
  Residual bottleneck block: three conv + batchnorm stages (kernel sizes 1, 3, 1) added to a shortcut path.

  `in_c` is the input channel count, `mid_c` the width of the inner convolutions, and the block
  outputs `mid_c * expansion` channels. `stride` is passed to the second convolution and to the
  shortcut projection when one is needed.
  """
  expansion = 4

  def __init__(self, in_c, mid_c, stride=1):
    out_c = mid_c * self.expansion
    self.conv1, self.bn1 = nn.Conv2d(in_c, mid_c, 1, bias=False), nn.BatchNorm2d(mid_c)
    self.conv2, self.bn2 = nn.Conv2d(mid_c, mid_c, 3, stride, 1, bias=False), nn.BatchNorm2d(mid_c)
    self.conv3, self.bn3 = nn.Conv2d(mid_c, out_c, 1, bias=False), nn.BatchNorm2d(out_c)
    # the shortcut only needs a projection when the main path changes the stride or the channel count;
    # it stays a (possibly empty) list so it can be handed straight to Tensor.sequential
    needs_projection = stride != 1 or in_c != out_c
    self.downsample = [nn.Conv2d(in_c, out_c, 1, stride, bias=False), nn.BatchNorm2d(out_c)] if needs_projection else []

  def __call__(self, x:Tensor) -> Tensor:
    # shortcut path; with an empty `downsample` this is presumably `x` unchanged (confirm Tensor.sequential)
    identity = x.sequential(self.downsample)
    x = self.bn1(self.conv1(x)).relu()
    x = self.bn2(self.conv2(x)).relu()
    x = self.bn3(self.conv3(x))
    # the final relu is applied after the residual add
    return (x + identity).relu()
|
||||
|
||||
class ResNet50:
  """
  Image classifier built from a conv/batchnorm stem, four stages of `Bottleneck` blocks
  (3, 4, 6 and 3 blocks), a spatial mean, and a linear layer producing `num_classes` outputs.
  """

  def __init__(self, num_classes=1000):
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    # arguments per stage: (input channels, inner channels, number of blocks, stride of the first block)
    self.layer1 = self._make_layer(64, 64, 3, 1)
    self.layer2 = self._make_layer(256, 128, 4, 2)
    self.layer3 = self._make_layer(512, 256, 6, 2)
    self.layer4 = self._make_layer(1024, 512, 3, 2)
    self.fc = nn.Linear(2048, num_classes)

  def _make_layer(self, in_c, mid_c, blocks, stride):
    # only the first block of a stage receives `in_c` and `stride`;
    # the remaining blocks take the expanded channel count and the default stride
    first = Bottleneck(in_c, mid_c, stride)
    rest = [Bottleneck(mid_c * Bottleneck.expansion, mid_c) for _ in range(blocks - 1)]
    return [first] + rest

  def __call__(self, x:Tensor) -> Tensor:
    stem = self.bn1(self.conv1(x)).relu()
    # TODO: max_pool2d return type is Tensor | tuple[Tensor, Tensor], this should be type specialised
    # NOTE(review): max_pool2d is called with its defaults here; confirm they match the pooling the pretrained weights expect
    pooled = stem.max_pool2d() # type: ignore
    stages = self.layer1 + self.layer2 + self.layer3 + self.layer4
    features = pooled.sequential(stages)
    # average over axes 2 and 3 before the linear layer
    return self.fc(features.mean((2, 3)))
|
||||
|
||||
if __name__ == "__main__":
  # classify the PNG given as the first command-line argument, falling back to a default test image
  default_image = "https://upload.wikimedia.org/wikipedia/en/d/d4/Norwegian_Forest_Cat_in_Norway.png"
  source = sys.argv[1] if len(sys.argv) > 1 else default_image
  img = nn.state.png_load(Tensor.from_url(source))

  model = ResNet50()
  weights = Tensor.from_url("https://huggingface.co/timm/resnet50.a1_in1k/resolve/main/model.safetensors")
  nn.state.load_state_dict(model, nn.state.safe_load(weights))

  # rearrange (h, w, c) pixels into a single-image (1, c, h, w) batch and divide by 255
  batch = img.rearrange("h w c -> 1 c h w").float()/255
  value = model(batch).argmax().item()
  print(value, nn.datasets.imagenet_labels()[value])
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import ast
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import tar_extract
|
||||
|
||||
|
|
@ -12,3 +13,8 @@ def cifar(device=None):
|
|||
train = Tensor.cat(*[tt[f"cifar-10-batches-bin/data_batch_{i}.bin"].reshape(-1, 3073).to(device) for i in range(1,6)])
|
||||
test = tt["cifar-10-batches-bin/test_batch.bin"].reshape(-1, 3073).to(device)
|
||||
return train[:, 1:].reshape(-1,3,32,32), train[:, 0], test[:, 1:].reshape(-1,3,32,32), test[:, 0]
|
||||
|
||||
def imagenet_labels():
  """
  Downloads the ImageNet label listing and returns it parsed with `ast.literal_eval`.

  The result is whatever Python literal the remote text contains; presumably a dict
  mapping class index to label (verify against the gist).
  """
  url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
  text = Tensor.from_url(url).tobytes().decode()
  # literal_eval only accepts Python literals, so the downloaded text is never executed as code
  return ast.literal_eval(text)
|
||||
|
|
|
|||
|
|
@ -383,3 +383,24 @@ def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
|
|||
for name, dims, typ, off in t_infos: state_dict[name] = ggml_data_to_tensor(tensor[data_start + off:], prod(dims), typ).reshape(*reversed(dims))
|
||||
|
||||
return kv_data, state_dict
|
||||
|
||||
@accept_filename
def png_load(t:Tensor) -> Tensor:
  """
  Decodes a PNG held in `t` (the raw file bytes) and returns its pixels as a `(height, width, 3)` tensor.

  Only 8-bit RGB (color type 2) and RGBA (color type 6) images are supported, they must be
  non-interlaced, and every scanline must use filter type 0 (None) or 1 (Sub). For RGBA input
  the fourth channel is dropped. Chunk CRCs are skipped, not verified.

  Raises `AssertionError` if the signature is wrong, the IHDR chunk is missing, or the image
  uses an unsupported bit depth, color type, interlace method or scanline filter.
  """
  f = io.BufferedReader(TensorIO(t))
  assert f.read(8) == b'\x89PNG\r\n\x1a\n', "not a PNG"
  ihdr, idats = None, []
  # chunk layout: 4-byte big-endian length, 4-byte type, payload, 4-byte CRC
  while (slen:=f.read(4)):
    typ, dat = f.read(4), f.read(struct.unpack(">I", slen)[0])
    if DEBUG >= 3: print(len(dat), typ)
    if typ == b'IHDR':
      # IHDR payload: width, height, bit depth, color type, compression method, filter method, interlace method
      width, height, depth, color_type, _, _, interlace = struct.unpack(">IIBBBBB", dat[:13])
      assert depth == 8 and color_type in [2, 6], f"only 8-bit RGB/RGBA PNG supported {depth=} {color_type=}"
      # interlaced data is stored as several passes, which the single (height, row) reshape below cannot represent
      assert interlace == 0, f"interlaced PNG not supported {interlace=}"
      ihdr = (width, height, 3 if color_type == 2 else 4)
    if typ == b'IDAT': idats.append(dat)
    f.seek(4, 1)  # step over the CRC
  assert ihdr is not None, "PNG has no IHDR chunk"
  width, height, bpp = ihdr
  # the IDAT payloads form one zlib stream; each decoded row is a filter-type byte followed by width*bpp sample bytes
  data = Tensor(zlib.decompress(b''.join(idats))).reshape(height, width * bpp + 1)
  filters, pixels = data[:, 0], data[:, 1:].reshape(height, width, bpp)
  assert filters.max().item() <= 1, f"only PNG filters 0/1 supported, got {set(filters.tolist())}" # type: ignore[arg-type]
  # Sub filter (type 1): each pixel adds the pixel to its left, which is cumsum along width
  # the running sum is taken in int16 and masked with 0xff to bring it back to a byte value
  unfiltered = pixels.cast(dtypes.int16).cumsum(axis=1).bitwise_and(0xff).cast(dtypes.uint8)
  # rows whose filter byte is 1 take the cumulative sum, all other rows keep their bytes as stored
  pixels = (filters == 1).reshape(height, 1, 1).where(unfiltered, pixels)
  return pixels[:, :, :3]
|
||||
|
|
|
|||
|
|
@ -312,6 +312,12 @@ class Tensor(OpMixin):
|
|||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
return self._buffer().as_typed_buffer(self.shape)
|
||||
|
||||
def tobytes(self) -> bytes:
  """
  Returns the contents of this tensor as a `bytes` object, like numpy's `.tobytes()`.

  The bytes are built from whatever `self.data()` returns.
  """
  buf = self.data()
  return bytes(buf)
|
||||
|
||||
def item(self) -> ConstType:
|
||||
"""
|
||||
Returns the value of this tensor as a standard Python number.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue