mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
deviceless const cleanups (#16341)
This commit is contained in:
parent
35461d4d8f
commit
149a87dac2
4 changed files with 5 additions and 5 deletions
|
|
@ -10,7 +10,7 @@ class TestTensorVariable(unittest.TestCase):
|
|||
|
||||
def test_inner_tvar_node(self):
|
||||
vv = Variable("w", 0, 10).bind(2)
|
||||
ret = Tensor.from_uop(vv * 4).item()
|
||||
ret = Tensor(vv * 4).item()
|
||||
assert ret == 8
|
||||
|
||||
def test_inner_tvar_mul(self):
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class TestAssign(unittest.TestCase):
|
|||
def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)
|
||||
|
||||
def test_assign_changes_buffer_alt(self):
|
||||
a, b = [Tensor(Tensor(0).contiguous().realize().uop.buf_uop) for _ in range(2)]
|
||||
a, b = [Tensor(Tensor([0]).realize().uop.buf_uop) for _ in range(2)]
|
||||
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
|
||||
self.assertEqual((a + b).item(), 3)
|
||||
|
||||
|
|
|
|||
|
|
@ -398,7 +398,7 @@ class Transformer:
|
|||
v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
|
||||
v_toks = UOp.variable("toks", 1, chunk_size)
|
||||
# TODO: use UOp.variable for temperature once float variables are supported
|
||||
temp = Tensor(temperature).contiguous()
|
||||
temp = Tensor([temperature])
|
||||
# assign all input tokens once, then slice from start_pos for the model call
|
||||
t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
|
||||
# recompute start_pos from what's currently valid in the caches
|
||||
|
|
|
|||
|
|
@ -1270,7 +1270,7 @@ class Tensor(OpMixin):
|
|||
def ufix(self, x) -> Tensor:
|
||||
# TODO: x:ConstType|UOp does not work because mixin only accepts Self | ConstType
|
||||
assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
|
||||
return Tensor(x, self.device, self.dtype if self._ufix_keep_dtype(x) else None)
|
||||
return Tensor(self.uop.ufix(x))
|
||||
|
||||
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
||||
"""
|
||||
|
|
@ -1292,7 +1292,7 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
if isinstance(x, Tensor): x, y = x._broadcasted(y)
|
||||
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
|
||||
else: x, y = Tensor(x, self.device)._broadcasted(y)
|
||||
else: x, y = self.ufix(x)._broadcasted(y)
|
||||
out_shape = _broadcast_shape(self.shape, x.shape)
|
||||
return self.cast(dtypes.bool)._broadcast_to(out_shape)._apply_uop(UOp.where, x._broadcast_to(out_shape), y._broadcast_to(out_shape))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue