less CONST(DEVICE) (#16452)

* less CONST(DEVICE)

no DEVICE for single device in const_like, multi has other issues

* maybe

* that?
This commit is contained in:
chenyu 2026-06-01 15:55:12 -04:00 committed by GitHub
commit 7e7b481ba7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 20 additions and 12 deletions

View file

@ -25,10 +25,12 @@ def calculate_storage_offset(x: Tensor) -> int:
u_strides = strides_for_shape(u.src[0].shape)
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
return offset
def wrap(x: Tensor) -> torch.Tensor:
def wrap(x: Tensor, dev: torch.device|None=None) -> torch.Tensor:
x._strides = strides_for_shape(x.shape) # always recalculate
if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
# a deviceless tinygrad value takes the device from the op context
idx = _to_torch_device(x.device).index if x.device is not None else (dev.index if dev is not None else 0)
return mod.wrap(x, _to_torch_dtype(x.dtype), idx)
def _update_torch_metadata(tensor: torch.Tensor, tiny: Tensor) -> None:
tiny._strides = strides_for_shape(tiny.shape)
tiny._storage_offset = calculate_storage_offset(tiny)
@ -545,14 +547,17 @@ def wrap_out(f):
assigned = f(*args, **kwargs)
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
assert out.device == assigned.device or out.device is None or assigned.device is None, f"device mismatch: {assigned.device} -> {out.device}"
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
if out.device is None and assigned.device is not None: out.replace(out.empty_like(device=assigned.device))
return out.assign(assigned)
return _wrap_out
def _inplace_op(t, new_value):
if not hasattr(t, "_view_base") and not getattr(canonical_base(t), "_views", set()): t.replace(new_value)
else: _apply_inplace(t, new_value)
else:
if (base:=canonical_base(t)).device is None and new_value.device is not None: base.replace(base.empty_like(device=new_value.device))
_apply_inplace(t, new_value)
return t
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
@ -679,10 +684,11 @@ def wrap_fxn(k,f):
if TORCH_DEBUG:
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
dev = next((a.device for a in args if isinstance(a, torch.Tensor) and a.device.type == "tiny"), None)
args, kwargs = unwrap_args(args, kwargs)
out = f(*args, **kwargs)
if isinstance(out, Tensor): return wrap(out)
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
if isinstance(out, Tensor): return wrap(out, dev)
elif isinstance(out, tuple): return tuple(wrap(x, dev) for x in out)
else: raise RuntimeError(f"unknown output type {type(out)}")
return nf

View file

@ -968,7 +968,7 @@ class TestSchedule(unittest.TestCase):
def test_const_schedule_contig(self):
constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
check_schedule(constv, 1)
check_schedule(constv, 0)
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)

View file

@ -20,7 +20,7 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = x.grad.schedule_linear(w.grad)
self.assertEqual(len(backward_schedule.src), 4)
self.assertEqual(len(backward_schedule.src), 2)
@unittest.skip("this requires optimizations")
def test_counters(self):

View file

@ -851,6 +851,7 @@ class Tensor(OpMixin):
# clear contexts
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient)):
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
if g.device is None and t.device is not None: g = g.clone(device=t.device)
if t.grad is None: t.grad = g
else: t.grad.assign(t.grad + g.to(t.grad.device))
return self

View file

@ -489,8 +489,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return perm.index(*non_slice_args, ptr=True)
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
def const_like(self, b:ConstLike, dtype:DType|None=None):
# constants can optionally have a DEVICE source
ret = UOp.const(dtype or self.dtype.base, b, device=self.device, shape=self.shard_shape if self.axis is not None else self._shape)
# multi constants can optionally have a DEVICE source # TODO: no const with DEVICE
dev = self.device if isinstance(self.device, tuple) else None
ret = UOp.const(dtype or self.dtype.base, b, device=dev, shape=self.shard_shape if self.axis is not None else self._shape)
return ret.multi(self.axis) if self.axis is not None else ret
def ufix(self, x):
if isinstance(x, UOp): return x

View file

@ -3,7 +3,7 @@ from typing import cast, Any
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo, ParamArg
from tinygrad.uop.render import print_uops, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB, all_same
# ***** uop helpers *****
@ -166,7 +166,7 @@ spec_tensor = PatternMatcher([
# MULTI/MSELECT/MSTACK
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(s.device, str) for s in x.src) or (all_same(x.src) and x.src[0].device is None)),
# CONTIGUOUS ensures the source UOp realizes
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),