mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
only_assig
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73331f4e55 | ||
|
|
fe25f98f49 | ||
|
|
4f91db06d7 | ||
|
|
1a4826f802 |
1 changed files with 2 additions and 1 deletions
|
|
@ -309,6 +309,7 @@ class Tensor(OpMixin):
|
|||
# but AFTER must be embedded before _apply_uop (so subsequent assigns see it)
|
||||
assign_uop = self.uop.assign(x.uop)
|
||||
base = self.uop.base
|
||||
if base.op not in {Ops.BUFFER, Ops.AFTER, Ops.ASSIGN}: raise RuntimeError(f"can only assign to buffer, not {base.op}")
|
||||
if base.op in {Ops.BUFFER, Ops.AFTER} and not self.uop.has_buffer_identity():
|
||||
_apply_map_to_tensors({base: base.after(assign_uop)}, name="Embed View Assign", walk=True)
|
||||
return self.replace(self._apply_uop(lambda *_: assign_uop, x))
|
||||
|
|
@ -621,7 +622,7 @@ class Tensor(OpMixin):
|
|||
Tensor._device_seeds[device] = Tensor(
|
||||
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
||||
device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
|
||||
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).clone()
|
||||
# increment rng counter for devices
|
||||
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue