deviceless const cleanups (#16341)

This commit is contained in:
chenyu 2026-05-22 20:11:01 -04:00 committed by GitHub
commit 149a87dac2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 5 additions and 5 deletions

View file

@ -10,7 +10,7 @@ class TestTensorVariable(unittest.TestCase):
def test_inner_tvar_node(self):
vv = Variable("w", 0, 10).bind(2)
ret = Tensor.from_uop(vv * 4).item()
ret = Tensor(vv * 4).item()
assert ret == 8
def test_inner_tvar_mul(self):

View file

@ -140,7 +140,7 @@ class TestAssign(unittest.TestCase):
def test_assign_changes_realized_alt(self): return self.test_assign_changes_alt(realize=True)
def test_assign_changes_buffer_alt(self):
a, b = [Tensor(Tensor(0).contiguous().realize().uop.buf_uop) for _ in range(2)]
a, b = [Tensor(Tensor([0]).realize().uop.buf_uop) for _ in range(2)]
Tensor.realize(a.contiguous().assign(1), b.contiguous().assign(2))
self.assertEqual((a + b).item(), 3)

View file

@ -398,7 +398,7 @@ class Transformer:
v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
v_toks = UOp.variable("toks", 1, chunk_size)
# TODO: use UOp.variable for temperature once float variables are supported
temp = Tensor(temperature).contiguous()
temp = Tensor([temperature])
# assign all input tokens once, then slice from start_pos for the model call
t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
# recompute start_pos from what's currently valid in the caches

View file

@ -1270,7 +1270,7 @@ class Tensor(OpMixin):
def ufix(self, x) -> Tensor:
# TODO: x:ConstType|UOp does not work because mixin only accepts Self | ConstType
assert isinstance(x, (*get_args(ConstType), UOp)), f"{type(x)=}, {x=}"
return Tensor(x, self.device, self.dtype if self._ufix_keep_dtype(x) else None)
return Tensor(self.uop.ufix(x))
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
"""
@ -1292,7 +1292,7 @@ class Tensor(OpMixin):
"""
if isinstance(x, Tensor): x, y = x._broadcasted(y)
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
else: x, y = Tensor(x, self.device)._broadcasted(y)
else: x, y = self.ufix(x)._broadcasted(y)
out_shape = _broadcast_shape(self.shape, x.shape)
return self.cast(dtypes.bool)._broadcast_to(out_shape)._apply_uop(UOp.where, x._broadcast_to(out_shape), y._broadcast_to(out_shape))