mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37a40bf975 | ||
|
|
af1db22b25 | ||
|
|
be0f9d1055 | ||
|
|
5b9a6c5520 |
8 changed files with 45 additions and 8 deletions
|
|
@ -199,6 +199,13 @@ class TestMultiTensor(unittest.TestCase):
|
|||
run_schedule(sched)
|
||||
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
|
||||
|
||||
def test_cat_on_non_shard_axis(self):
|
||||
# cat must be lowered to PAD/ADD before multi_pm runs, otherwise MULTI nodes are not handled
|
||||
X = Tensor.arange(8).reshape(4, 2).realize().shard_(devices_2, 0)
|
||||
Y = Tensor.arange(8, 16).reshape(4, 2).realize().shard_(devices_2, 0)
|
||||
Z = X.cat(Y, dim=1)
|
||||
np.testing.assert_equal(Z.numpy(), np.concatenate([np.arange(8).reshape(4, 2), np.arange(8, 16).reshape(4, 2)], axis=1))
|
||||
|
||||
@given(strat.sampled_from((devices_2, devices_3)),
|
||||
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
||||
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
from typing import cast
|
||||
import math, dataclasses
|
||||
import math, itertools, dataclasses
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||
from tinygrad.helpers import argsort
|
||||
|
||||
def cat_gradient(ctx:UOp, ret:UOp) -> tuple[UOp, ...]:
|
||||
axis = ret.arg
|
||||
dim_acc = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
|
||||
return tuple(ctx.shrink(tuple([(dim_acc[i], dim_acc[i+1]) if j==axis else (0, ctx.shape[j])
|
||||
for j in range(len(ctx.shape))])) for i in range(len(ret.src)))
|
||||
|
||||
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||
def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
|
||||
if op == Ops.ADD: return (broadcast_to_input(ctx),)
|
||||
|
|
@ -54,6 +60,7 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
||||
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
|
||||
(UPat(Ops.CAT, name="ret"), lambda ctx, ret: cat_gradient(ctx, ret)),
|
||||
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
|
||||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||
# NOTE: this is only correct when the KERNEL has a single output
|
||||
|
|
|
|||
|
|
@ -112,7 +112,19 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
|||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
return c.src[0].substitute(dict_map, walk=True)
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
def lower_cat(cat:UOp) -> UOp:
|
||||
axis = cat.arg
|
||||
dim_acc = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
|
||||
padded = [s.pad(tuple((dim_acc[i], dim_acc[-1]-dim_acc[i+1]) if j==axis else (0,0) for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
|
||||
ret = padded[0]
|
||||
for p in padded[1:]: ret = ret.alu(Ops.ADD, p)
|
||||
return ret
|
||||
|
||||
pm_lower_cat = PatternMatcher([
|
||||
(UPat(Ops.CAT, name="cat"), lower_cat),
|
||||
])
|
||||
|
||||
earliest_rewrites = mop_cleanup+pm_lower_cat+PatternMatcher([
|
||||
# early fixup const copy
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
||||
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
||||
|
|
@ -537,7 +549,8 @@ split_kernels = PatternMatcher([
|
|||
|
||||
@profile_matches
|
||||
def get_kernel_graph(sink:UOp) -> UOp:
|
||||
tsink = graph_rewrite(sink, multi_pm, name="multi_pm")
|
||||
tsink = graph_rewrite(sink, pm_lower_cat, name="lower_cat")
|
||||
tsink = graph_rewrite(tsink, multi_pm, name="multi_pm")
|
||||
if OPENPILOT_HACKS: tsink = graph_rewrite(tsink, pm_fold_moved_assign, ctx={}, name="fold moved assigns")
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
|
|
|
|||
|
|
@ -1384,10 +1384,8 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
dim = self._resolve_dim(dim)
|
||||
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
|
||||
tensors = [self, *args]
|
||||
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
|
||||
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
|
||||
return functools.reduce(Tensor.add, tensors)
|
||||
_ = [t.shape[dim] for t in [self, *args]] # validate dim in bounds (catches scalar cat)
|
||||
return self._apply_uop(lambda *uops, arg: UOp(Ops.CAT, uops[0].dtype, uops, arg), *args, arg=dim)
|
||||
|
||||
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -103,6 +103,9 @@ class Ops(FastEnum):
|
|||
# expander ops
|
||||
UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto()
|
||||
|
||||
# CAT is a movement op (placed here to preserve enum ordering of existing ops)
|
||||
CAT = auto()
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
|
||||
|
|
|
|||
|
|
@ -263,6 +263,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
# MULTI marker (axis info in PARAM sources) has no shape
|
||||
case Ops.MULTI if len(self.src) == 0: return None
|
||||
|
||||
case Ops.CAT:
|
||||
shapes = [s.shape for s in self.src]
|
||||
axis = self.arg
|
||||
for s in shapes[1:]:
|
||||
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis):
|
||||
raise ValueError(f"CAT shape mismatch: {shapes}")
|
||||
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0])))
|
||||
|
||||
# movement ops change the shape
|
||||
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ movement_ops = PatternMatcher([
|
|||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||
(UPat(Ops.CAT, name="mv"), lambda mv: isinstance(mv.arg, int) and len(mv.src) >= 1),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.CAT:"#C1FFD7",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue