mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ab5b5b865 | ||
|
|
3b123b4d53 | ||
|
|
80f1ff18ad | ||
|
|
09c02868a5 | ||
|
|
9d638a0202 |
||
|
|
b815c6a3df |
10 changed files with 92 additions and 16 deletions
31
test/test_fold.py
Normal file
31
test/test_fold.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
import numpy as np
|
||||||
|
import unittest
|
||||||
|
from tinygrad import Tensor, UOp
|
||||||
|
from tinygrad.uop.ops import AxisType
|
||||||
|
|
||||||
|
class TestFold(unittest.TestCase):
|
||||||
|
def test_reduce_add(self):
|
||||||
|
a = Tensor.randn(10, 10).realize()
|
||||||
|
a_red = a.sum(axis=1)
|
||||||
|
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1), atol=1e-6)
|
||||||
|
|
||||||
|
def test_fold_add(self):
|
||||||
|
a = Tensor.randn(10, 10).realize()
|
||||||
|
init = Tensor.zeros(10, 1).contiguous()
|
||||||
|
a_red = (init+a).fold(init).reshape(10)
|
||||||
|
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1), atol=1e-6)
|
||||||
|
|
||||||
|
#@unittest.skip("no outer fold yet")
|
||||||
|
def test_fold_matmul(self):
|
||||||
|
vec = Tensor.randn(1, 10).realize()
|
||||||
|
mats = Tensor.randn(3, 10, 10).realize()
|
||||||
|
np_mats = mats.numpy()
|
||||||
|
np_ref = ((vec.numpy() @ np_mats[0]) @ np_mats[1]) @ np_mats[2]
|
||||||
|
|
||||||
|
i = UOp.range(3, -1, AxisType.OUTER)
|
||||||
|
out = (vec @ mats[i]).contiguous().fold(vec, i)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(out.numpy(), np_ref, atol=1e-6)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
|
@ -308,9 +308,19 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||||
if len(reduce_range) == 0: return ret
|
if len(reduce_range) == 0: return ret
|
||||||
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
|
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
|
||||||
|
|
||||||
|
def fold_to_store(x:UOp):
|
||||||
|
_, acc, ranges = x.src[0], x.src[1], x.src[2:]
|
||||||
|
assert acc.op is Ops.INDEX
|
||||||
|
buf = acc.src[0]
|
||||||
|
ret = x.substitute({buf: buf.rtag().after(*ranges)}).substitute({buf.rtag(): buf})
|
||||||
|
base, acc, ranges = ret.src[0], ret.src[1], ret.src[2:]
|
||||||
|
return buf.after(acc.store(base).end(*ranges)).index(acc.src[1])
|
||||||
|
|
||||||
pm_reduce = PatternMatcher([
|
pm_reduce = PatternMatcher([
|
||||||
# REDUCE -> DEFINE_ACC+ASSIGN
|
# REDUCE -> DEFINE_ACC+STORE
|
||||||
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
||||||
|
# FOLD -> STORE
|
||||||
|
(UPat(Ops.FOLD, name="x"), fold_to_store),
|
||||||
# tensor core built in accumulate
|
# tensor core built in accumulate
|
||||||
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
|
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
|
||||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ from __future__ import annotations
|
||||||
import math, itertools
|
import math, itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import cast, Final
|
from typing import cast, Final
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
|
||||||
|
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||||
from tinygrad.device import Buffer
|
from tinygrad.device import Buffer
|
||||||
from tinygrad.dtype import dtypes, ImageDType
|
from tinygrad.dtype import dtypes, ImageDType
|
||||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||||
|
|
@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer
|
||||||
|
|
||||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||||
|
|
||||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
|
||||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
|
||||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
def __init__(self, ast:UOp, ren:Renderer):
|
def __init__(self, ast:UOp, ren:Renderer):
|
||||||
self.ast, self.ren = ast, ren
|
self.ast, self.ren = ast, ren
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
||||||
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
|
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
|
||||||
new_srcs = []
|
new_srcs = []
|
||||||
for s in x.src:
|
for i,s in enumerate(x.src):
|
||||||
new_src = s
|
new_src = s
|
||||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
|
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
|
||||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||||
|
|
@ -65,7 +65,9 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||||
# None in the device assigns it a number later
|
# None in the device assigns it a number later
|
||||||
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
||||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
if x in ctx.range_map:
|
||||||
|
# for scan we use the output ranges on the 2nd arg
|
||||||
|
new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][int(x.op is Ops.FOLD and i == 1)]) if i in realized_ranges])
|
||||||
new_srcs.append(new_src)
|
new_srcs.append(new_src)
|
||||||
# NOTE: do we need this?
|
# NOTE: do we need this?
|
||||||
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
||||||
|
|
@ -84,6 +86,13 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
||||||
ctx.range_map[ret] = ctx.range_map[x]
|
ctx.range_map[ret] = ctx.range_map[x]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def add_ranges_to_scan(ctx:IndexingContext, x:UOp):
|
||||||
|
if x not in ctx.range_map: return None
|
||||||
|
new_ranges = [r for r,ar in zip(*ctx.range_map[x]) if r is not ar and r not in x.src]
|
||||||
|
ret = x.replace(src=x.src+tuple(new_ranges))
|
||||||
|
ctx.range_map[ret] = ctx.range_map[x]
|
||||||
|
return ret
|
||||||
|
|
||||||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||||
|
|
||||||
|
|
@ -97,6 +106,8 @@ def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||||
pm_apply_rangeify = PatternMatcher([
|
pm_apply_rangeify = PatternMatcher([
|
||||||
# REDUCE_AXIS -> REDUCE
|
# REDUCE_AXIS -> REDUCE
|
||||||
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
|
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
|
||||||
|
# SCAN -> SCAN (with new ranges)
|
||||||
|
(UPat(Ops.FOLD, name="x"), add_ranges_to_scan),
|
||||||
# PAD -> WHERE
|
# PAD -> WHERE
|
||||||
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
|
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
|
||||||
# add third op to assign
|
# add third op to assign
|
||||||
|
|
@ -244,6 +255,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||||
# REDUCE_AXIS creates ranges for the axes it is reducing
|
# REDUCE_AXIS creates ranges for the axes it is reducing
|
||||||
if x.op is Ops.REDUCE_AXIS:
|
if x.op is Ops.REDUCE_AXIS:
|
||||||
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
|
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
|
||||||
|
if x.op is Ops.FOLD:
|
||||||
|
rngs = tuple(rctx.new_range(s, axistype=AxisType.FOLD) if resolve(x.src[1].shape[i] == 1) else r \
|
||||||
|
for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
realized_ranges = rctx.realize_map.get(x, None)
|
realized_ranges = rctx.realize_map.get(x, None)
|
||||||
|
|
|
||||||
|
|
@ -396,6 +396,7 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
||||||
|
if r.arg[-1] == AxisType.OUTER: return None
|
||||||
if r.tag != (): return None
|
if r.tag != (): return None
|
||||||
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
|
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
|
||||||
ctx.range += 1
|
ctx.range += 1
|
||||||
|
|
@ -469,7 +470,10 @@ pm_add_range_tags = PatternMatcher([
|
||||||
])
|
])
|
||||||
|
|
||||||
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||||
if len(x.ranges): return None
|
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
|
||||||
|
|
||||||
|
# ends of outer range don't go in kernels
|
||||||
|
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
|
||||||
|
|
||||||
# local kernel rewrite
|
# local kernel rewrite
|
||||||
lctx = LocalAddBufferContext()
|
lctx = LocalAddBufferContext()
|
||||||
|
|
|
||||||
|
|
@ -1509,6 +1509,9 @@ class Tensor(OpMixin):
|
||||||
|
|
||||||
# ***** reduce ops *****
|
# ***** reduce ops *****
|
||||||
|
|
||||||
|
def fold(self, init:Tensor, *ranges:UOp) -> Tensor:
|
||||||
|
return self._apply_uop(UOp.fold, init, extra_args=ranges)
|
||||||
|
|
||||||
def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
def _reduce(self, op:Ops, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
||||||
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
|
||||||
if self.ndim == 0: axis = ()
|
if self.ndim == 0: axis = ()
|
||||||
|
|
|
||||||
|
|
@ -92,6 +92,9 @@ class Ops(FastEnum):
|
||||||
# reduce
|
# reduce
|
||||||
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto()
|
||||||
|
|
||||||
|
# scan
|
||||||
|
FOLD = auto()
|
||||||
|
|
||||||
# errors/placeholders
|
# errors/placeholders
|
||||||
REWRITE_ERROR = auto(); SENTINEL = auto()
|
REWRITE_ERROR = auto(); SENTINEL = auto()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,14 +13,20 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
class AxisType(Enum):
|
class AxisType(Enum):
|
||||||
def __repr__(self): return str(self)
|
def __repr__(self): return str(self)
|
||||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); FOLD = auto() # noqa: E702
|
||||||
THREAD = auto()
|
UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||||
|
THREAD = auto(); OUTER = auto() # noqa: E702
|
||||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.FOLD: "F", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
|
||||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.FOLD: "red", AxisType.UNROLL: "magenta",
|
||||||
|
AxisType.OUTER: "green"}
|
||||||
|
|
||||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||||
|
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||||
|
AxisType.GROUP_REDUCE: 2, AxisType.FOLD: 4, AxisType.REDUCE: 5, AxisType.UNROLL: 6, AxisType.OUTER: -2}
|
||||||
|
|
||||||
|
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.FOLD: 2}
|
||||||
|
|
||||||
# https://en.wikipedia.org/wiki/Identity_element
|
# https://en.wikipedia.org/wiki/Identity_element
|
||||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||||
|
|
@ -213,6 +219,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
|
||||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||||
|
|
||||||
|
# shape of init
|
||||||
|
case Ops.FOLD:
|
||||||
|
return self.src[1]._shape
|
||||||
|
|
||||||
# passthrough ops
|
# passthrough ops
|
||||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||||
return self.src[0]._shape
|
return self.src[0]._shape
|
||||||
|
|
@ -437,6 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
|
assert self.dtype.scalar() is dtypes.index, "Can only call get_valid on index dtype"
|
||||||
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
|
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
|
||||||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||||
|
def fold(self, *src:UOp, **kwargs): return UOp(Ops.FOLD, self.dtype, (self,)+src, **kwargs)
|
||||||
|
|
||||||
def is_contiguous(self):
|
def is_contiguous(self):
|
||||||
# TODO: this is is_realized
|
# TODO: this is is_realized
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,9 @@ shared_spec = PatternMatcher([
|
||||||
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
rng.dtype == x.dtype and isinstance(rng.arg, tuple) and len(rng.arg) >= 2 and \
|
||||||
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
|
||||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||||
|
|
||||||
|
# FOLD, 2 ops + ranges
|
||||||
|
(UPat(Ops.FOLD, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[2:])),
|
||||||
])
|
])
|
||||||
|
|
||||||
# ***** UOp spec in the Tensor graph *****
|
# ***** UOp spec in the Tensor graph *****
|
||||||
|
|
@ -172,7 +175,7 @@ kernel_spec = PatternMatcher([
|
||||||
# bufferize can be on anything
|
# bufferize can be on anything
|
||||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
|
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
|
||||||
|
|
||||||
# reduce must be on ranges
|
# reduce/fold must be on ranges
|
||||||
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:])),
|
||||||
])+movement_ops+shared_codegen_spec+shared_spec
|
])+movement_ops+shared_codegen_spec+shared_spec
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from tinygrad.dtype import dtypes
|
||||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||||
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.FOLD: "#FF7B7B",
|
||||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0",
|
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0",
|
||||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue