Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
37a40bf975 early lower cat 2026-03-07 11:39:20 +08:00
George Hotz
af1db22b25 simpler 2026-03-07 10:11:21 +08:00
George Hotz
be0f9d1055 min 2026-03-07 10:00:29 +08:00
George Hotz
5b9a6c5520 Add Ops.CAT movement op (ai slop) 2026-03-06 18:51:25 +08:00
8 changed files with 45 additions and 8 deletions

View file

@ -199,6 +199,13 @@ class TestMultiTensor(unittest.TestCase):
run_schedule(sched)
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
def test_cat_on_non_shard_axis(self):
# cat must be lowered to PAD/ADD before multi_pm runs, otherwise MULTI nodes are not handled
X = Tensor.arange(8).reshape(4, 2).realize().shard_(devices_2, 0)
Y = Tensor.arange(8, 16).reshape(4, 2).realize().shard_(devices_2, 0)
Z = X.cat(Y, dim=1)
np.testing.assert_equal(Z.numpy(), np.concatenate([np.arange(8).reshape(4, 2), np.arange(8, 16).reshape(4, 2)], axis=1))
@given(strat.sampled_from((devices_2, devices_3)),
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))

View file

@ -1,8 +1,14 @@
from typing import cast
import math, dataclasses
import math, itertools, dataclasses
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
def cat_gradient(ctx:UOp, ret:UOp) -> tuple[UOp, ...]:
axis = ret.arg
dim_acc = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
return tuple(ctx.shrink(tuple([(dim_acc[i], dim_acc[i+1]) if j==axis else (0, ctx.shape[j])
for j in range(len(ctx.shape))])) for i in range(len(ret.src)))
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
if op == Ops.ADD: return (broadcast_to_input(ctx),)
@ -54,6 +60,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
(UPat(Ops.CAT, name="ret"), lambda ctx, ret: cat_gradient(ctx, ret)),
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# NOTE: this is only correct when the KERNEL has a single output

View file

@ -112,7 +112,19 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
return c.src[0].substitute(dict_map, walk=True)
earliest_rewrites = mop_cleanup+PatternMatcher([
def lower_cat(cat:UOp) -> UOp:
axis = cat.arg
dim_acc = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
padded = [s.pad(tuple((dim_acc[i], dim_acc[-1]-dim_acc[i+1]) if j==axis else (0,0) for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
ret = padded[0]
for p in padded[1:]: ret = ret.alu(Ops.ADD, p)
return ret
pm_lower_cat = PatternMatcher([
(UPat(Ops.CAT, name="cat"), lower_cat),
])
earliest_rewrites = mop_cleanup+pm_lower_cat+PatternMatcher([
# early fixup const copy
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
@ -537,7 +549,8 @@ split_kernels = PatternMatcher([
@profile_matches
def get_kernel_graph(sink:UOp) -> UOp:
tsink = graph_rewrite(sink, multi_pm, name="multi_pm")
tsink = graph_rewrite(sink, pm_lower_cat, name="lower_cat")
tsink = graph_rewrite(tsink, multi_pm, name="multi_pm")
if OPENPILOT_HACKS: tsink = graph_rewrite(tsink, pm_fold_moved_assign, ctx={}, name="fold moved assigns")
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")

View file

@ -1384,10 +1384,8 @@ class Tensor(OpMixin):
"""
dim = self._resolve_dim(dim)
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
tensors = [self, *args]
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
return functools.reduce(Tensor.add, tensors)
_ = [t.shape[dim] for t in [self, *args]] # validate dim in bounds (catches scalar cat)
return self._apply_uop(lambda *uops, arg: UOp(Ops.CAT, uops[0].dtype, uops, arg), *args, arg=dim)
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""

View file

@ -103,6 +103,9 @@ class Ops(FastEnum):
# expander ops
UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto()
# CAT is a movement op (placed here to preserve enum ordering of existing ops)
CAT = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,

View file

@ -263,6 +263,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# MULTI marker (axis info in PARAM sources) has no shape
case Ops.MULTI if len(self.src) == 0: return None
case Ops.CAT:
shapes = [s.shape for s in self.src]
axis = self.arg
for s in shapes[1:]:
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis):
raise ValueError(f"CAT shape mismatch: {shapes}")
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0])))
# movement ops change the shape
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):

View file

@ -69,6 +69,7 @@ movement_ops = PatternMatcher([
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
(UPat(Ops.CAT, name="mv"), lambda mv: isinstance(mv.arg, int) and len(mv.src) >= 1),
# inputs to movement ops
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),

View file

@ -47,7 +47,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.CAT:"#C1FFD7",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",