mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
move cast around expand backward to tensor.py (#11483)
This commit is contained in:
parent
0ce0f51010
commit
823f1a01db
2 changed files with 4 additions and 5 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
import math, dataclasses
|
||||
from tinygrad.dtype import dtypes, sum_acc_dtype
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||
from tinygrad.helpers import argsort
|
||||
|
||||
|
|
@ -38,9 +38,7 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
|
||||
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
|
||||
# TODO: this cast can be removed by putting the casts around the EXPAND
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
|
||||
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)),
|
||||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda ctx: (None,)),
|
||||
|
|
|
|||
|
|
@ -3561,7 +3561,8 @@ class Tensor(MathTrait):
|
|||
# for each dimension, check either dim is 1, or it does not change
|
||||
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
|
||||
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
|
||||
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
|
||||
# NOTE: this cast is no-op in forward and uses sum_acc_dtype in the backward sum
|
||||
return self.reshape(shape).cast(sum_acc_dtype(self.dtype))._apply_uop(UOp.expand, arg=new_shape).cast(self.dtype)
|
||||
|
||||
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
|
||||
x: Tensor = self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue