cleanup tensor BIND + remove outdated comments in tensor.py [pr] (#9712)

* cleanup tensor BIND + remove outdated comments in tensor.py [pr]

* from_blob whitespace

* assert
This commit is contained in:
qazal 2025-04-03 11:21:53 +08:00 committed by GitHub
commit bbd13191f4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 6 deletions

View file

@ -478,6 +478,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
var, val = arg.unbind()
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
# otherwise it's just a RESHAPE(BUFFER)
assert op is Ops.EMPTY, f"unkown op {op}"
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
return UOp.new_buffer(device, size, dtype).reshape(shape)
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False): return UOp(Ops.COPY, self.dtype, (UOp(Ops.DEVICE, arg=device), self), clone)

View file

@ -138,11 +138,10 @@ class Tensor(SimpleMathTrait):
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad:bool|None = requires_grad
# create a LazyBuffer from the different types of inputs
# create a UOp from the different types of inputs
if isinstance(data, UOp):
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
# NOTE: this is here because LazyBuffer = UOp
if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
if data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
@ -398,7 +397,7 @@ class Tensor(SimpleMathTrait):
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False)
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
@ -438,8 +437,7 @@ class Tensor(SimpleMathTrait):
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
"""
r = Tensor._metaop(Ops.EMPTY, shape, **kwargs)
r = Tensor.empty(*shape, **kwargs)
r.lazydata.buffer.allocate(external_ptr=ptr)
return r