Merge branch 'master' into shrink_in_render

This commit is contained in:
George Hotz 2026-05-29 13:06:31 -07:00 committed by GitHub
commit 1670dbfacd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 24 additions and 21 deletions

View file

@ -323,7 +323,7 @@ class TestWithGrad(unittest.TestCase):
def test_set_overlapping_backward(self):
z = Tensor.zeros(6)
x = Tensor.ones(4)
x = Tensor.ones(4).contiguous()
y = Tensor.ones(4) * 2
z[:4] = x
z[2:] = y

View file

@ -456,14 +456,14 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 4, [c1.weight, c1.bias, *nn.state.get_parameters(bn)])
def test_fold_conv_batchnorm_optim(self, adam=False):
# 2 is too low?
optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2)
optim, cnt = (nn.optim.Adam, 29) if adam else (nn.optim.SGD, 15)
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
_realize_weights([c1, bn])
opt = optim(nn.state.get_parameters([c1, bn]))
Tensor.realize(img, *nn.state.get_parameters(opt))
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
@ -477,7 +477,7 @@ class TestSchedule(unittest.TestCase):
fw = bn(x).contiguous_backward().relu().contiguous()
fw.sum().backward()
# TODO: this is too many
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10, nn.state.get_parameters(bn))
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)

View file

@ -15,6 +15,10 @@ class TestRealizeIsRealized(unittest.TestCase):
t = Tensor.zeros(10).contiguous().realize()
assert t.uop.is_realized
def test_ones(self):
t = Tensor.ones(4, 4).realize()
assert t.uop.is_realized
def test_bytes(self):
t = Tensor(b'\x01\x02\x03').realize()
assert t.uop.is_realized
@ -51,10 +55,6 @@ class TestRealizeIsRealized(unittest.TestCase):
t = Tensor(3.14).realize()
assert not t.uop.is_realized
def test_ones_not_realized(self):
t = Tensor.ones(4, 4).realize()
assert not t.uop.is_realized
def test_none_not_realized(self):
t = Tensor(None).realize()
assert not t.uop.is_realized

View file

@ -11,7 +11,7 @@ from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, Pt
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
if TYPE_CHECKING:
from tinygrad.uop.ops import sint
from tinygrad.uop.ops import sint, UOp
ReductionStr = Literal["mean", "sum", "none"]
@ -23,12 +23,13 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def const(dtype, b, device=None): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Self:
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,
device:str|tuple[str, ...]|None=None, buffer=True) -> Self:
"""
Creates a tensor with the given shape, filled with the given value.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
Pass `buffer=False` to get a broadcast const value instead of a materialized buffer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.full((2, 3), 42).numpy())
@ -37,11 +38,13 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
print(Tensor.full((2, 3), False).numpy())
```
"""
from tinygrad.uop.ops import UOp
new_shape = argfix(shape)
if not kwargs.pop("buffer", True):
dt = to_dtype(kwargs.pop("dtype", None) or dtypes.from_py(fill_value))
return cls.const(dt, fill_value, canonicalize_device(kwargs.pop("device", None))).reshape((1,)*len(new_shape)).expand(new_shape)
return cls.unique_const(fill_value, **kwargs).reshape((1,)*len(new_shape)).expand(new_shape)
dt = to_dtype(dtype) if dtype is not None else None
if isinstance(fill_value, UOp): val = cls.const(dt or fill_value.dtype, fill_value)
else: val = cls.const(dt or dtypes.from_py(fill_value), fill_value, None if buffer else canonicalize_device(device))
val = val.reshape((1,)*len(new_shape)).expand(new_shape)
return val.clone(device=device) if buffer else val
@classmethod
def invalids(cls, *shape, **kwargs) -> Self:
@ -52,7 +55,8 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
Eventually Tensor.empty will be replaced by this.
"""
return cls.full(argfix(*shape), Invalid, **kwargs)
new_shape = argfix(*shape)
return cls.unique_const(Invalid, **kwargs).reshape((1,)*len(new_shape)).expand(new_shape)
@classmethod
def zeros(cls, *shape, **kwargs) -> Self:

View file

@ -580,7 +580,7 @@ class Tensor(OpMixin):
def _multi_like(self, fxn, *args, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", self.dtype)
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device])
@ -598,9 +598,8 @@ class Tensor(OpMixin):
print(Tensor.full_like(t, 42).numpy())
```
"""
if device is None: return super().full_like(fill_value, dtype)
if isinstance(self.device, tuple): raise RuntimeError("cannot specify `device` on `full_like` of a multi device tensor")
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=device)
if isinstance(self.device, tuple): return self._multi_like(Tensor.full, fill_value, dtype=dtype or self.dtype, device=device)
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
def rand_like(self, **kwargs) -> Tensor:
"""

View file

@ -125,7 +125,7 @@ spec_tensor = PatternMatcher([
# CONST with a UNIQUE or DEVICE
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE))), lambda: True),
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
# BUFFER
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),