mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
master
...
outer_rang
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db8c6d9a04 | ||
|
|
0647f87bf8 |
5 changed files with 41 additions and 12 deletions
18
test/test_outerworld_transformer.py
Normal file
18
test/test_outerworld_transformer.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from tinygrad import Tensor, UOp
|
||||
from tinygrad.uop.ops import Ops, AxisType
|
||||
import unittest
|
||||
# this test is only focused on transformers and using range for the layers
|
||||
|
||||
# Smoke test for driving a per-layer matmul with a UOp range whose axis type is OUTER.
class TestOuterworldTransformer(unittest.TestCase):
|
||||
# Applies three stacked matrices to one input by indexing the stack with an OUTER range.
# Nothing is asserted about values: the test passes if realize() completes without raising.
def test_three_mats(self):
|
||||
# stack of three 1024x1024 matrices; w[i] below selects one of them
w = Tensor.empty(3, 1024, 1024)
|
||||
# 1x1024 input that is overwritten by the store below
inp = Tensor.empty(1, 1024)
|
||||
# range of size 3 with axis type OUTER
# NOTE(review): -1 sits in the slot where rangeify passes its range index; presumably picked to avoid clashing with generated ids - confirm
i = UOp.range(3, -1, AxisType.OUTER)
|
||||
# wrap the input's uop in an AFTER on the range so the read is tied to the range
inp_after = Tensor(inp.uop.after(i))
|
||||
# matmul the input with the matrix selected by the range
inp_gemm = inp_after@w[i]
|
||||
# store the product back into inp's uop, END the range on that store, then take inp AFTER the ended store
inp = inp.uop.after(inp.uop.store(inp_gemm.uop).end(i)).contiguous()
|
||||
# rewrap the resulting UOp as a Tensor so it can be realized
inp = Tensor(inp)
|
||||
inp.realize()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -2,7 +2,8 @@ from __future__ import annotations
|
|||
import math, itertools
|
||||
from collections import defaultdict
|
||||
from typing import cast, Final
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
|
||||
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
|
|
@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer
|
|||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self, ast:UOp, ren:Renderer):
|
||||
self.ast, self.ren = ast, ren
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
for s in rb.src:
|
||||
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
# Rewrite callback for STORE uops (registered in pm_generate_realize_map under "realize STORE").
# Adds the store's second source (a.src[1], presumably the value being stored - confirm) as a key in ctx
# with value None, unless that source's op is in ALWAYS_CONTIGUOUS. Returns nothing; ctx is mutated in place.
# The body is the same check realize_assign starts with.
def realize_store(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
|
||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
# if it's a kernel, we don't realize it
|
||||
|
|
@ -30,6 +33,8 @@ pm_generate_realize_map = PatternMatcher([
|
|||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||
# realize ASSIGN and input to assign (might be optimized out)
|
||||
(UPat(Ops.ASSIGN, name="a"), realize_assign),
|
||||
# realize STORE
|
||||
(UPat(Ops.STORE, name="a"), realize_store),
|
||||
])
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -50,13 +55,14 @@ class IndexingContext:
|
|||
# if the range's src (its size) is 1, it's the same as UOp.const(dtypes.index, 0)
|
||||
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
|
||||
|
||||
ops_allowed_after = (Ops.KERNEL, Ops.RANGE)
|
||||
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
||||
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
|
||||
if x.op is Ops.AFTER and x.src[1].op in ops_allowed_after: return None
|
||||
new_srcs = []
|
||||
for s in x.src:
|
||||
new_src = s
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op in ops_allowed_after):
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
elif s in ctx.realize_map:
|
||||
realized_ranges = ctx.realize_map[s]
|
||||
|
|
@ -176,7 +182,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
# mark all ranges as ended
|
||||
assert rctx.realize_map[x] is None
|
||||
rctx.realize_map[x] = list(range(len(x.shape)))
|
||||
elif x.op in {Ops.MSTACK, Ops.MSELECT}:
|
||||
elif x.op in {Ops.MSTACK, Ops.MSELECT, Ops.END}:
|
||||
# treat MSTACK/MSELECT/END like SINK
|
||||
continue
|
||||
elif len(consumer_rngs) == 0:
|
||||
|
|
|
|||
|
|
@ -396,6 +396,7 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
|
|||
return buf
|
||||
|
||||
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
||||
if r.arg[-1] == AxisType.OUTER: return None
|
||||
if r.tag != (): return None
|
||||
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
|
||||
ctx.range += 1
|
||||
|
|
@ -469,7 +470,10 @@ pm_add_range_tags = PatternMatcher([
|
|||
])
|
||||
|
||||
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||
if len(x.ranges): return None
|
||||
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
|
||||
|
||||
# ends of outer range don't go in kernels
|
||||
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
|
||||
|
||||
# local kernel rewrite
|
||||
lctx = LocalAddBufferContext()
|
||||
|
|
|
|||
|
|
@ -15,11 +15,15 @@ if TYPE_CHECKING:
|
|||
class AxisType(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto()
|
||||
THREAD = auto(); OUTER = auto() # noqa: E702
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: ("O")}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta", AxisType.OUTER: "green"}
|
||||
|
||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||
axis_to_pos = {AxisType.OUTER: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
|
||||
AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue