mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
broadcast_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
770dac0e0d | ||
|
|
b827858479 |
5 changed files with 26 additions and 19 deletions
|
|
@ -14,6 +14,13 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
|||
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
|
||||
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
||||
|
||||
def unbroadcast(ctx:UOp, shape:tuple|None) -> UOp:
  """Reduce the gradient `ctx` back down to `shape`.

  `shape` is left-padded with 1s to the rank of `ctx`; every axis where the padded shape differs from
  `ctx.shape` is reduced with Ops.ADD, and the result is reshaped to `shape`.

  Args:
    ctx: the incoming gradient.
    shape: the shape to reduce to. If it is None, `ctx` is returned unchanged.

  Returns:
    `ctx` itself when it has no shape, `shape` is None, or the shapes already agree; otherwise the
    reduced gradient reshaped to `shape`.

  Raises:
    RuntimeError: if `shape` has more dims than `ctx`, or differs from `ctx` on an axis where it is not 1.
  """
  if ctx._shape is None or shape is None or ctx.shape == shape: return ctx
  if len(shape) > len(ctx.shape): raise RuntimeError(f"can't unbroadcast {ctx.shape} to {shape}")
  # left-pad the target with 1s so both shapes have the same rank
  aligned = (1,)*(len(ctx.shape)-len(shape)) + shape
  axis = tuple(i for i,(s,n) in enumerate(zip(aligned, ctx.shape)) if s != n)
  # a reduced axis can only land on a size-1 target dim; anything else cannot be reshaped to `shape` below
  if any(aligned[i] != 1 for i in axis): raise RuntimeError(f"can't unbroadcast {ctx.shape} to {shape}")
  # reduce in sum_acc_dtype(ctx.dtype), then cast back to the gradient's own dtype
  return ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, axis).cast(ctx.dtype).reshape(shape)
|
||||
|
||||
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
|
||||
"""Remove unused PARAMs from body and return compacted (body, args)."""
|
||||
used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items())
|
||||
|
|
@ -66,9 +73,7 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
|
||||
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
|
||||
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
|
||||
(ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n))
|
||||
.cast(ctx.dtype), None)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (unbroadcast(ctx, ret.src[0]._shape), None)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
||||
|
|
@ -114,6 +119,7 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
|
|||
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
||||
for k,v in zip(t0.src, lgrads):
|
||||
if v is None: continue
|
||||
v = unbroadcast(v, k._shape)
|
||||
if k in grads and grads[k].op is not Ops.NOOP:
|
||||
if v.op is Ops.TUPLE and grads[k].op is Ops.TUPLE:
|
||||
grads[k] = UOp.maketuple(*(p + n if (p.op is not Ops.NOOP and n.op is not Ops.NOOP) else
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tinygrad.mixin.elementwise import ElementwiseMixin
|
|||
from tinygrad.mixin.movement import MovementMixin
|
||||
from tinygrad.mixin.reduce import ReduceMixin
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import _broadcast_shape, resolve, smax, smin, identity_element
|
||||
from tinygrad.uop.ops import resolve, smax, smin, identity_element
|
||||
from tinygrad.dtype import ConstType, DType, DTypeLike, Invalid, InvalidType, PtrDType, PyConst, dtypes, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import all_int, argfix, ceildiv, flatten, flat_to_grouped, make_tuple, prod, resolve_pool_pads, round_up
|
||||
|
||||
|
|
@ -306,11 +306,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
def _broadcasted(self, y, reverse=False) -> tuple[Self, Self]:
|
||||
if not isinstance(y, type(self)): y = self.ufix(y)
|
||||
x, y = (self, y) if not reverse else (y, self)
|
||||
# ValueError: unsized ptr has shape (-1,) which can't broadcast; RuntimeError: shape mismatch
|
||||
try:
|
||||
out_shape = _broadcast_shape(x.shape, y.shape)
|
||||
x, y = x._broadcast_to(out_shape), y._broadcast_to(out_shape)
|
||||
except (RuntimeError, ValueError): pass
|
||||
# ptr dtypes aren't in the promo lattice
|
||||
if x.dtype == y.dtype or any(isinstance(d, PtrDType) for d in (x.dtype, y.dtype)): return x, y
|
||||
return x.cast(out_dtype := least_upper_dtype(x.dtype, y.dtype)), y.cast(out_dtype)
|
||||
|
|
|
|||
|
|
@ -144,11 +144,18 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
|||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
return rngs
|
||||
|
||||
def _broadcast_srcs(x:UOp) -> UOp:
  """Return `x` with each of its sources replaced by that source broadcast to `x.shape`."""
  return x.replace(src=tuple(operand._broadcast_to(x.shape) for operand in x.src))

# single rewrite rule: applies to every op in GroupOp.Broadcastable, bound to the name "x"
pm_do_broadcast = PatternMatcher([(UPat(GroupOp.Broadcastable, name="x"), _broadcast_srcs)])
|
||||
|
||||
@profile_matches
|
||||
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
if debug: print("**************************")
|
||||
rctx = IndexingContext()
|
||||
|
||||
# run broadcasting
|
||||
tsink = graph_rewrite(tsink, pm_do_broadcast, name="do broadcast")
|
||||
|
||||
# get ops to realize
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
|
||||
|
||||
|
|
|
|||
|
|
@ -118,6 +118,9 @@ class GroupOp:
|
|||
# TODO: is BITCAST always Elementwise if it's shape changing?
|
||||
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
|
||||
|
||||
# all ops that support shape broadcasting: everything in Elementwise (which already holds CAST and BITCAST),
# plus GROUP and STORE
# NOTE(review): via Elementwise this set includes BITCAST; confirm a shape-changing BITCAST is safe to broadcast
Broadcastable = set.union(Elementwise, {Ops.GROUP, Ops.STORE})
|
||||
|
||||
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
|
|
|
|||
|
|
@ -51,7 +51,10 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
|
|||
max_dim = max(len(s) for s in shapes)
|
||||
return tuple((1,)*(max_dim-len(s))+s for s in shapes)
|
||||
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
|
||||
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
|
||||
ret = tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
|
||||
if not all(resolve(s == ns) or resolve(s == 1) for shape in _align_left(*shapes) for s,ns in zip(shape, ret)):
|
||||
raise ValueError(f"shape mismatch: objects cannot be broadcast to a single shape {shapes}")
|
||||
return ret
|
||||
|
||||
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||
def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
|
||||
|
|
@ -317,13 +320,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
|
||||
|
||||
# broadcastable ops: every input must have a shape, and the output shape is the broadcast of the input shapes
|
||||
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.GROUP, Ops.STORE}):
|
||||
if self.op in GroupOp.Broadcastable:
|
||||
input_shapes = [x._shape for x in self.src]
|
||||
assert len(self.src) > 0 and all(x is not None for x in input_shapes), f"None input shape not supported for {self.op}"
|
||||
# TODO: add broadcasting here
|
||||
if not all_same(input_shapes):
|
||||
raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes} {[x.op for x in self.src]}")
|
||||
return input_shapes[0]
|
||||
return _broadcast_shape(*input_shapes)
|
||||
|
||||
# all Ops must be explicitly handled
|
||||
raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}")
|
||||
|
|
@ -483,10 +483,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
|
||||
def alu(self, op, *src:UOp, **kwargs):
|
||||
all_srcs = (self, *src)
|
||||
# broadcast shaped operands to a common shape (None and () are falsy, so only real shapes participate)
|
||||
if (shapes := [s for x in all_srcs if (s:=x._shape)]) and not all_same(shapes):
|
||||
out_shape = _broadcast_shape(*shapes)
|
||||
all_srcs = tuple(x._broadcast_to(out_shape) if x._shape else x for x in all_srcs)
|
||||
out_dtype = all_srcs[-1].dtype
|
||||
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return UOp(op, out_dtype, all_srcs, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue