move cast around expand backward to tensor.py (#11483)

This commit is contained in:
chenyu 2025-08-02 20:03:54 -07:00 committed by GitHub
commit 823f1a01db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 4 additions and 5 deletions

View file

@@ -1,6 +1,6 @@
from typing import cast
import math, dataclasses
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
@@ -38,9 +38,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
# TODO: this cast can be removed by putting the casts around the EXPAND
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)),
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda ctx: (None,)),

View file

@@ -3561,7 +3561,8 @@ class Tensor(MathTrait):
# for each dimension, check either dim is 1, or it does not change
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
# NOTE: this cast is no-op in forward and uses sum_acc_dtype in the backward sum
return self.reshape(shape).cast(sum_acc_dtype(self.dtype))._apply_uop(UOp.expand, arg=new_shape).cast(self.dtype)
def _broadcasted(self, y:Tensor|ConstType|UOp, reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
x: Tensor = self