mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37a40bf975 | ||
|
|
af1db22b25 | ||
|
|
be0f9d1055 | ||
|
|
5b9a6c5520 |
8 changed files with 45 additions and 8 deletions
|
|
@ -199,6 +199,13 @@ class TestMultiTensor(unittest.TestCase):
|
||||||
run_schedule(sched)
|
run_schedule(sched)
|
||||||
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
|
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
|
||||||
|
|
||||||
|
def test_cat_on_non_shard_axis(self):
|
||||||
|
# cat must be lowered to PAD/ADD before multi_pm runs, otherwise MULTI nodes are not handled
|
||||||
|
X = Tensor.arange(8).reshape(4, 2).realize().shard_(devices_2, 0)
|
||||||
|
Y = Tensor.arange(8, 16).reshape(4, 2).realize().shard_(devices_2, 0)
|
||||||
|
Z = X.cat(Y, dim=1)
|
||||||
|
np.testing.assert_equal(Z.numpy(), np.concatenate([np.arange(8).reshape(4, 2), np.arange(8, 16).reshape(4, 2)], axis=1))
|
||||||
|
|
||||||
@given(strat.sampled_from((devices_2, devices_3)),
|
@given(strat.sampled_from((devices_2, devices_3)),
|
||||||
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
||||||
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))
|
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,14 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import math, dataclasses
|
import math, itertools, dataclasses
|
||||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||||
from tinygrad.helpers import argsort
|
from tinygrad.helpers import argsort
|
||||||
|
|
||||||
|
def cat_gradient(ctx:UOp, ret:UOp) -> tuple[UOp, ...]:
|
||||||
|
axis = ret.arg
|
||||||
|
dim_acc = list(itertools.accumulate([s.shape[axis] for s in ret.src], initial=0))
|
||||||
|
return tuple(ctx.shrink(tuple([(dim_acc[i], dim_acc[i+1]) if j==axis else (0, ctx.shape[j])
|
||||||
|
for j in range(len(ctx.shape))])) for i in range(len(ret.src)))
|
||||||
|
|
||||||
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||||
def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
|
def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
|
||||||
if op == Ops.ADD: return (broadcast_to_input(ctx),)
|
if op == Ops.ADD: return (broadcast_to_input(ctx),)
|
||||||
|
|
@ -54,6 +60,7 @@ pm_gradient = PatternMatcher([
|
||||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
||||||
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
|
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
|
||||||
|
(UPat(Ops.CAT, name="ret"), lambda ctx, ret: cat_gradient(ctx, ret)),
|
||||||
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
|
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
|
||||||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||||
# NOTE: this is only correct when the KERNEL has a single output
|
# NOTE: this is only correct when the KERNEL has a single output
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,19 @@ def resolve_call(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||||
return c.src[0].substitute(dict_map, walk=True)
|
return c.src[0].substitute(dict_map, walk=True)
|
||||||
|
|
||||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
def lower_cat(cat:UOp) -> UOp:
|
||||||
|
axis = cat.arg
|
||||||
|
dim_acc = list(itertools.accumulate([s.shape[axis] for s in cat.src], initial=0))
|
||||||
|
padded = [s.pad(tuple((dim_acc[i], dim_acc[-1]-dim_acc[i+1]) if j==axis else (0,0) for j in range(len(s.shape)))) for i,s in enumerate(cat.src)]
|
||||||
|
ret = padded[0]
|
||||||
|
for p in padded[1:]: ret = ret.alu(Ops.ADD, p)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
pm_lower_cat = PatternMatcher([
|
||||||
|
(UPat(Ops.CAT, name="cat"), lower_cat),
|
||||||
|
])
|
||||||
|
|
||||||
|
earliest_rewrites = mop_cleanup+pm_lower_cat+PatternMatcher([
|
||||||
# early fixup const copy
|
# early fixup const copy
|
||||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
(UPat(Ops.COPY, src=(UPat.var("s"), UPat.var("d"))),
|
||||||
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
lambda s,d: s.substitute({UOp(Ops.DEVICE, arg=s.device):d}) if s.base.op is Ops.CONST else None),
|
||||||
|
|
@ -537,7 +549,8 @@ split_kernels = PatternMatcher([
|
||||||
|
|
||||||
@profile_matches
|
@profile_matches
|
||||||
def get_kernel_graph(sink:UOp) -> UOp:
|
def get_kernel_graph(sink:UOp) -> UOp:
|
||||||
tsink = graph_rewrite(sink, multi_pm, name="multi_pm")
|
tsink = graph_rewrite(sink, pm_lower_cat, name="lower_cat")
|
||||||
|
tsink = graph_rewrite(tsink, multi_pm, name="multi_pm")
|
||||||
if OPENPILOT_HACKS: tsink = graph_rewrite(tsink, pm_fold_moved_assign, ctx={}, name="fold moved assigns")
|
if OPENPILOT_HACKS: tsink = graph_rewrite(tsink, pm_fold_moved_assign, ctx={}, name="fold moved assigns")
|
||||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1384,10 +1384,8 @@ class Tensor(OpMixin):
|
||||||
"""
|
"""
|
||||||
dim = self._resolve_dim(dim)
|
dim = self._resolve_dim(dim)
|
||||||
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
|
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
|
||||||
tensors = [self, *args]
|
_ = [t.shape[dim] for t in [self, *args]] # validate dim in bounds (catches scalar cat)
|
||||||
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
|
return self._apply_uop(lambda *uops, arg: UOp(Ops.CAT, uops[0].dtype, uops, arg), *args, arg=dim)
|
||||||
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
|
|
||||||
return functools.reduce(Tensor.add, tensors)
|
|
||||||
|
|
||||||
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,9 @@ class Ops(FastEnum):
|
||||||
# expander ops
|
# expander ops
|
||||||
UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto()
|
UNROLL = auto(); CONTRACT = auto(); VCAT = auto(); PTRCAT = auto()
|
||||||
|
|
||||||
|
# CAT is a movement op (placed here to preserve enum ordering of existing ops)
|
||||||
|
CAT = auto()
|
||||||
|
|
||||||
class GroupOp:
|
class GroupOp:
|
||||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
|
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
|
||||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
|
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
|
||||||
|
|
|
||||||
|
|
@ -263,6 +263,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
# MULTI marker (axis info in PARAM sources) has no shape
|
# MULTI marker (axis info in PARAM sources) has no shape
|
||||||
case Ops.MULTI if len(self.src) == 0: return None
|
case Ops.MULTI if len(self.src) == 0: return None
|
||||||
|
|
||||||
|
case Ops.CAT:
|
||||||
|
shapes = [s.shape for s in self.src]
|
||||||
|
axis = self.arg
|
||||||
|
for s in shapes[1:]:
|
||||||
|
if len(s) != len(shapes[0]) or not all(a==b for i,(a,b) in enumerate(zip(s, shapes[0])) if i!=axis):
|
||||||
|
raise ValueError(f"CAT shape mismatch: {shapes}")
|
||||||
|
return tuple(ssimplify(sum(s[i] for s in shapes)) if i==axis else shapes[0][i] for i in range(len(shapes[0])))
|
||||||
|
|
||||||
# movement ops change the shape
|
# movement ops change the shape
|
||||||
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
||||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,7 @@ movement_ops = PatternMatcher([
|
||||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||||
|
(UPat(Ops.CAT, name="mv"), lambda mv: isinstance(mv.arg, int) and len(mv.src) >= 1),
|
||||||
|
|
||||||
# inputs to movement ops
|
# inputs to movement ops
|
||||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.CAT:"#C1FFD7",
|
||||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#7DF4FF", Ops.BINARY: "#404040",
|
||||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue