mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
less CONST(DEVICE) (#16452)
* less CONST(DEVICE) no DEVICE for single device in const_like, multi has other issues * maybe * that?
This commit is contained in:
parent
556defa0f7
commit
7e7b481ba7
6 changed files with 20 additions and 12 deletions
|
|
@ -25,10 +25,12 @@ def calculate_storage_offset(x: Tensor) -> int:
|
|||
u_strides = strides_for_shape(u.src[0].shape)
|
||||
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
|
||||
return offset
|
||||
def wrap(x: Tensor) -> torch.Tensor:
|
||||
def wrap(x: Tensor, dev: torch.device|None=None) -> torch.Tensor:
|
||||
x._strides = strides_for_shape(x.shape) # always recalculate
|
||||
if (not hasattr(x, '_storage_offset')) or (not x.uop.is_realized): x._storage_offset = calculate_storage_offset(x)
|
||||
return mod.wrap(x, _to_torch_dtype(x.dtype), _to_torch_device(x.device).index)
|
||||
# a deviceless tinygrad value takes the device from the op context
|
||||
idx = _to_torch_device(x.device).index if x.device is not None else (dev.index if dev is not None else 0)
|
||||
return mod.wrap(x, _to_torch_dtype(x.dtype), idx)
|
||||
def _update_torch_metadata(tensor: torch.Tensor, tiny: Tensor) -> None:
|
||||
tiny._strides = strides_for_shape(tiny.shape)
|
||||
tiny._storage_offset = calculate_storage_offset(tiny)
|
||||
|
|
@ -545,14 +547,17 @@ def wrap_out(f):
|
|||
assigned = f(*args, **kwargs)
|
||||
if getenv("ALLOW_DTYPE_MISMATCH", 1): assigned = assigned.cast(out.dtype)
|
||||
assert out.shape == assigned.shape, f"shape mismatch: {assigned.shape} -> {out.shape}"
|
||||
assert out.device == assigned.device, f"device mismatch: {assigned.device} -> {out.device}"
|
||||
assert out.device == assigned.device or out.device is None or assigned.device is None, f"device mismatch: {assigned.device} -> {out.device}"
|
||||
assert out.dtype == assigned.dtype, f"dtype mismatch: {assigned.dtype} -> {out.dtype}"
|
||||
if out.device is None and assigned.device is not None: out.replace(out.empty_like(device=assigned.device))
|
||||
return out.assign(assigned)
|
||||
return _wrap_out
|
||||
|
||||
def _inplace_op(t, new_value):
|
||||
if not hasattr(t, "_view_base") and not getattr(canonical_base(t), "_views", set()): t.replace(new_value)
|
||||
else: _apply_inplace(t, new_value)
|
||||
else:
|
||||
if (base:=canonical_base(t)).device is None and new_value.device is not None: base.replace(base.empty_like(device=new_value.device))
|
||||
_apply_inplace(t, new_value)
|
||||
return t
|
||||
|
||||
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
|
||||
|
|
@ -679,10 +684,11 @@ def wrap_fxn(k,f):
|
|||
if TORCH_DEBUG:
|
||||
print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
|
||||
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
|
||||
dev = next((a.device for a in args if isinstance(a, torch.Tensor) and a.device.type == "tiny"), None)
|
||||
args, kwargs = unwrap_args(args, kwargs)
|
||||
out = f(*args, **kwargs)
|
||||
if isinstance(out, Tensor): return wrap(out)
|
||||
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
|
||||
if isinstance(out, Tensor): return wrap(out, dev)
|
||||
elif isinstance(out, tuple): return tuple(wrap(x, dev) for x in out)
|
||||
else: raise RuntimeError(f"unknown output type {type(out)}")
|
||||
return nf
|
||||
|
||||
|
|
|
|||
|
|
@ -968,7 +968,7 @@ class TestSchedule(unittest.TestCase):
|
|||
|
||||
def test_const_schedule_contig(self):
|
||||
constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
|
||||
check_schedule(constv, 1)
|
||||
check_schedule(constv, 0)
|
||||
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class TestWinograd(unittest.TestCase):
|
|||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = x.grad.schedule_linear(w.grad)
|
||||
self.assertEqual(len(backward_schedule.src), 4)
|
||||
self.assertEqual(len(backward_schedule.src), 2)
|
||||
|
||||
@unittest.skip("this requires optimizations")
|
||||
def test_counters(self):
|
||||
|
|
|
|||
|
|
@ -851,6 +851,7 @@ class Tensor(OpMixin):
|
|||
# clear contexts
|
||||
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient)):
|
||||
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
|
||||
if g.device is None and t.device is not None: g = g.clone(device=t.device)
|
||||
if t.grad is None: t.grad = g
|
||||
else: t.grad.assign(t.grad + g.to(t.grad.device))
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -489,8 +489,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return perm.index(*non_slice_args, ptr=True)
|
||||
return self.index(*[UOp.const(dtypes.weakint, x) if isinstance(x, int) else x for x in idx])
|
||||
def const_like(self, b:ConstLike, dtype:DType|None=None):
|
||||
# constants can optionally have a DEVICE source
|
||||
ret = UOp.const(dtype or self.dtype.base, b, device=self.device, shape=self.shard_shape if self.axis is not None else self._shape)
|
||||
# multi constants can optionally have a DEVICE source # TODO: no const with DEVICE
|
||||
dev = self.device if isinstance(self.device, tuple) else None
|
||||
ret = UOp.const(dtype or self.dtype.base, b, device=dev, shape=self.shard_shape if self.axis is not None else self._shape)
|
||||
return ret.multi(self.axis) if self.axis is not None else ret
|
||||
def ufix(self, x):
|
||||
if isinstance(x, UOp): return x
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import cast, Any
|
|||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, AxisType, KernelInfo, ParamArg
|
||||
from tinygrad.uop.render import print_uops, pyrender
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid, ConstFloat
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic, CHECK_OOB, all_same
|
||||
|
||||
# ***** uop helpers *****
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ spec_tensor = PatternMatcher([
|
|||
# MULTI/MSELECT/MSTACK
|
||||
(UPat(Ops.MULTI, name="multi"), lambda multi: all(x.dtype == multi.dtype for x in multi.src) and isinstance(multi.arg, int)),
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(s.device, str) for s in x.src) or (all_same(x.src) and x.src[0].device is None)),
|
||||
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue