broadcast

This commit is contained in:
George Hotz 2026-05-14 17:04:37 -07:00
commit 770dac0e0d
2 changed files with 4 additions and 6 deletions

View file

@ -306,11 +306,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
if not isinstance(y, type(self)): y = self.ufix(y)
x, y = (self, y) if not reverse else (y, self)
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
#try:
# out_shape = _broadcast_shape(x.shape, y.shape)
# x, y = x._broadcast_to(out_shape), y._broadcast_to(out_shape)
#except (RuntimeError, ValueError): pass
# ptr dtypes aren't in the promo lattice
if x.dtype == y.dtype or any(isinstance(d, PtrDType) for d in (x.dtype, y.dtype)): return x, y
return x.cast(out_dtype := least_upper_dtype(x.dtype, y.dtype)), y.cast(out_dtype)

View file

@ -51,7 +51,10 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
max_dim = max(len(s) for s in shapes)
return tuple((1,)*(max_dim-len(s))+s for s in shapes)
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
ret = tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
if not all(resolve(s == ns) or resolve(s == 1) for shape in _align_left(*shapes) for s,ns in zip(shape, ret)):
raise ValueError(f"shape mismatch: objects cannot be broadcast to a single shape {shapes}")
return ret
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop