Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
770dac0e0d broadcast 2026-05-14 17:04:37 -07:00
George Hotz
b827858479 broadcast shape 2026-05-14 17:01:20 -07:00
5 changed files with 26 additions and 19 deletions

View file

@ -14,6 +14,13 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),) return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],) if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def unbroadcast(ctx:UOp, shape:tuple|None) -> UOp:
  """Reduce `ctx` back down to `shape` by summing over the axes where the two shapes differ.

  `ctx` is returned untouched when either side has no shape (`ctx._shape` or `shape` is None)
  or when `ctx.shape` already equals `shape`.

  Otherwise `shape` is left-padded with 1s to the rank of `ctx`, every axis where the padded
  size differs from `ctx.shape` is reduced with Ops.ADD, and the result is reshaped to `shape`.
  The reduction is done in `sum_acc_dtype(ctx.dtype)` and cast back to `ctx.dtype` afterwards.

  Raises:
    RuntimeError: if `shape` has more dimensions than `ctx.shape`.
  """
  # nothing to do: unknown shape on either side, or the shapes already agree
  if ctx._shape is None or shape is None or ctx.shape == shape: return ctx
  extra_dims = len(ctx.shape) - len(shape)
  if extra_dims < 0: raise RuntimeError(f"can't unbroadcast {ctx.shape} to {shape}")
  # right-align the target shape against ctx.shape by prepending size-1 dims
  padded = (1,)*extra_dims + shape
  reduce_axes = tuple(i for i,(want,have) in enumerate(zip(padded, ctx.shape)) if want != have)
  # accumulate in the sum accumulation dtype, then restore the original dtype
  summed = ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, reduce_axes)
  return summed.cast(ctx.dtype).reshape(shape)
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]: def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
"""Remove unused PARAMs from body and return compacted (body, args).""" """Remove unused PARAMs from body and return compacted (body, args)."""
used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items()) used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items())
@ -66,9 +73,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)), (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)), (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)), (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (unbroadcast(ctx, ret.src[0]._shape), None)),
(ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n))
.cast(ctx.dtype), None)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)), (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)), (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)), (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
@ -114,6 +119,7 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}" assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
for k,v in zip(t0.src, lgrads): for k,v in zip(t0.src, lgrads):
if v is None: continue if v is None: continue
v = unbroadcast(v, k._shape)
if k in grads and grads[k].op is not Ops.NOOP: if k in grads and grads[k].op is not Ops.NOOP:
if v.op is Ops.TUPLE and grads[k].op is Ops.TUPLE: if v.op is Ops.TUPLE and grads[k].op is Ops.TUPLE:
grads[k] = UOp.maketuple(*(p + n if (p.op is not Ops.NOOP and n.op is not Ops.NOOP) else grads[k] = UOp.maketuple(*(p + n if (p.op is not Ops.NOOP and n.op is not Ops.NOOP) else

View file

@ -5,7 +5,7 @@ from tinygrad.mixin.elementwise import ElementwiseMixin
from tinygrad.mixin.movement import MovementMixin from tinygrad.mixin.movement import MovementMixin
from tinygrad.mixin.reduce import ReduceMixin from tinygrad.mixin.reduce import ReduceMixin
from tinygrad.uop import Ops from tinygrad.uop import Ops
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element from tinygrad.uop.ops import resolve, smax, smin, identity_element
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
@ -306,11 +306,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]: def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
if not isinstance(y, type(self)): y = self.ufix(y) if not isinstance(y, type(self)): y = self.ufix(y)
x, y = (self, y) if not reverse else (y, self) x, y = (self, y) if not reverse else (y, self)
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
try:
out_shape = _broadcast_shape(x.shape, y.shape)
x, y = x._broadcast_to(out_shape), y._broadcast_to(out_shape)
except (RuntimeError, ValueError): pass
# ptr dtypes aren't in the promo lattice # ptr dtypes aren't in the promo lattice
if x.dtype == y.dtype or any(isinstance(d, PtrDType) for d in (x.dtype, y.dtype)): return x, y if x.dtype == y.dtype or any(isinstance(d, PtrDType) for d in (x.dtype, y.dtype)): return x, y
return x.cast(out_dtype := least_upper_dtype(x.dtype, y.dtype)), y.cast(out_dtype) return x.cast(out_dtype := least_upper_dtype(x.dtype, y.dtype)), y.cast(out_dtype)

View file

@ -144,11 +144,18 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
case _: raise RuntimeError(f"{op} is not a MovementOp") case _: raise RuntimeError(f"{op} is not a MovementOp")
return rngs return rngs
# Rewrite rules that make broadcasting explicit in the graph.
# The single rule matches any op in GroupOp.Broadcastable (bound as "x") and replaces its
# sources with each source passed through `_broadcast_to(x.shape)`, i.e. broadcast to the
# shape of the op itself.
# NOTE(review): presumably `_broadcast_to` is a no-op for sources that already have x.shape,
# so the rule stops matching once applied -- confirm against its definition.
pm_do_broadcast = PatternMatcher([
(UPat(GroupOp.Broadcastable, name="x"), lambda x: x.replace(src=tuple(y._broadcast_to(x.shape) for y in x.src))),
])
@profile_matches @profile_matches
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if debug: print("**************************") if debug: print("**************************")
rctx = IndexingContext() rctx = IndexingContext()
# run broadcasting
tsink = graph_rewrite(tsink, pm_do_broadcast, name="do broadcast")
# get ops to realize # get ops to realize
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize") graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")

View file

@ -118,6 +118,9 @@ class GroupOp:
# TODO: is BITCAST always Elementwise if it's shape changing? # TODO: is BITCAST always Elementwise if it's shape changing?
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST}) Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
# all ops that support shape broadcasting
Broadcastable = set.union(Elementwise, {Ops.CAST, Ops.GROUP, Ops.STORE})
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG} Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}

View file

@ -51,7 +51,10 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
max_dim = max(len(s) for s in shapes) max_dim = max(len(s) for s in shapes)
return tuple((1,)*(max_dim-len(s))+s for s in shapes) return tuple((1,)*(max_dim-len(s))+s for s in shapes)
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) ret = tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
if not all(resolve(s == ns) or resolve(s == 1) for shape in _align_left(*shapes) for s,ns in zip(shape, ret)):
raise ValueError(f"shape mismatch: objects cannot be broadcast to a single shape {shapes}")
return ret
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
@ -317,13 +320,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps)) return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
# elementwise ops keep the shape the same. all inputs with shape must match # elementwise ops keep the shape the same. all inputs with shape must match
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.GROUP, Ops.STORE}): if self.op in GroupOp.Broadcastable:
input_shapes = [x._shape for x in self.src] input_shapes = [x._shape for x in self.src]
assert len(self.src) > 0 and all(x is not None for x in input_shapes), f"None input shape not supported for {self.op}" assert len(self.src) > 0 and all(x is not None for x in input_shapes), f"None input shape not supported for {self.op}"
# TODO: add broadcasting here return _broadcast_shape(*input_shapes)
if not all_same(input_shapes):
raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes} {[x.op for x in self.src]}")
return input_shapes[0]
# all Ops must be explicitly handled # all Ops must be explicitly handled
raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}") raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
@ -483,10 +483,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs)) return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
def alu(self, op, *src:UOp, **kwargs): def alu(self, op, *src:UOp, **kwargs):
all_srcs = (self, *src) all_srcs = (self, *src)
# broadcast shaped operands to a common shape (None and () are falsy, so only real shapes participate)
if (shapes := [s for x in all_srcs if (s:=x._shape)]) and not all_same(shapes):
out_shape = _broadcast_shape(*shapes)
all_srcs = tuple(x._broadcast_to(out_shape) if x._shape else x for x in all_srcs)
out_dtype = all_srcs[-1].dtype out_dtype = all_srcs[-1].dtype
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(op, out_dtype, all_srcs, **kwargs) return UOp(op, out_dtype, all_srcs, **kwargs)