Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
73331f4e55 print 2026-03-04 18:55:37 +08:00
George Hotz
fe25f98f49 clone forces it to be backed by a real buffer 2026-03-04 18:43:09 +08:00
George Hotz
4f91db06d7 work 2026-03-04 18:39:48 +08:00
George Hotz
1a4826f802 only support assign to buffers/after 2026-03-04 18:32:06 +08:00

View file

@ -309,6 +309,7 @@ class Tensor(OpMixin):
# but AFTER must be embedded before _apply_uop (so subsequent assigns see it)
assign_uop = self.uop.assign(x.uop)
base = self.uop.base
if base.op not in {Ops.BUFFER, Ops.AFTER, Ops.ASSIGN}: raise RuntimeError(f"can only assign to buffer, not {base.op}")
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
return self.replace(self._apply_uop(lambda *_: assign_uop, x))
@ -621,7 +622,7 @@ class Tensor(OpMixin):
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).clone()
# increment rng counter for devices
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)