mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
broadcast
This commit is contained in:
parent
b827858479
commit
770dac0e0d
2 changed files with 4 additions and 6 deletions
|
|
@ -306,11 +306,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
|
||||
if not isinstance(y, type(self)): y = self.ufix(y)
|
||||
x, y = (self, y) if not reverse else (y, self)
|
||||
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
|
||||
#try:
|
||||
# out_shape = _broadcast_shape(x.shape, y.shape)
|
||||
# x, y = x._broadcast_to(out_shape), y._broadcast_to(out_shape)
|
||||
#except (RuntimeError, ValueError): pass
|
||||
# ptr dtypes aren't in the promo lattice
|
||||
if x.dtype == y.dtype or any(isinstance(d, PtrDType) for d in (x.dtype, y.dtype)): return x, y
|
||||
return x.cast(out_dtype := least_upper_dtype(x.dtype, y.dtype)), y.cast(out_dtype)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,10 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
|
|||
max_dim = max(len(s) for s in shapes)
|
||||
return tuple((1,)*(max_dim-len(s))+s for s in shapes)
|
||||
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
|
||||
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
|
||||
ret = tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
|
||||
if not all(resolve(s == ns) or resolve(s == 1) for shape in _align_left(*shapes) for s,ns in zip(shape, ret)):
|
||||
raise ValueError(f"shape mismatch: objects cannot be broadcast to a single shape {shapes}")
|
||||
return ret
|
||||
|
||||
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue