Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
db8c6d9a04 work 2025-11-10 15:26:16 -08:00
George Hotz
0647f87bf8 outer range runs in the scheduler 2025-11-10 14:49:42 -08:00
5 changed files with 41 additions and 12 deletions

View file

@ -0,0 +1,18 @@
from tinygrad import Tensor, UOp
from tinygrad.uop.ops import Ops, AxisType
import unittest
# this test is only focused on transformers and using range for the layers
class TestOuterworldTransformer(unittest.TestCase):
  """Builds a graph that uses an OUTER-axis range to index stacked weight matrices and realizes it."""
  def test_three_mats(self):
    # three stacked 1024x1024 weight matrices and one 1x1024 input row
    weights = Tensor.empty(3, 1024, 1024)
    x = Tensor.empty(1, 1024)
    # range over the 3 weight matrices, tagged with AxisType.OUTER
    layer = UOp.range(3, -1, AxisType.OUTER)
    # make the input depend on the range, then multiply by the range-indexed weight
    x_in_range = Tensor(x.uop.after(layer))
    gemm = x_in_range @ weights[layer]
    # store the product back into the input buffer and end the range
    ended_store = x.uop.store(gemm.uop).end(layer)
    # the final value is the input ordered after the ended store
    result = Tensor(x.uop.after(ended_store).contiguous())
    result.realize()
if __name__ == "__main__":
unittest.main()

View file

@ -2,7 +2,8 @@ from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler:
def __init__(self, ast:UOp, ren:Renderer):
self.ast, self.ren = ast, ren

View file

@ -16,6 +16,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_store(ctx:dict[UOp, None], a:UOp) -> None:
  """Record the second source of STORE `a` in the realize map `ctx`, unless its op is in ALWAYS_CONTIGUOUS."""
  stored = a.src[1]
  # ops in ALWAYS_CONTIGUOUS are not added to the map
  if stored.op in ALWAYS_CONTIGUOUS: return
  ctx[stored] = None
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
# if it's a kernel, we don't realize it
@ -30,6 +33,8 @@ pm_generate_realize_map = PatternMatcher([
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# realize ASSIGN and input to assign (might be optimized out)
(UPat(Ops.ASSIGN, name="a"), realize_assign),
# realize STORE
(UPat(Ops.STORE, name="a"), realize_store),
])
@dataclass(frozen=True)
@ -50,13 +55,14 @@ class IndexingContext:
# if a range's src is 1, it's the same as UOp.const(dtypes.index, 0)
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
ops_allowed_after = (Ops.KERNEL, Ops.RANGE)
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
if x.op is Ops.AFTER and x.src[1].op in ops_allowed_after: return None
new_srcs = []
for s in x.src:
new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op in ops_allowed_after):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s]
@ -176,7 +182,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# mark all ranges as ended
assert rctx.realize_map[x] is None
rctx.realize_map[x] = list(range(len(x.shape)))
elif x.op in {Ops.MSTACK, Ops.MSELECT}:
elif x.op in {Ops.MSTACK, Ops.MSELECT, Ops.END}:
# treat MSTACK/MSELECT/END like SINK
continue
elif len(consumer_rngs) == 0:

View file

@ -396,6 +396,7 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
return buf
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.arg[-1] == AxisType.OUTER: return None
if r.tag != (): return None
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
ctx.range += 1
@ -469,7 +470,10 @@ pm_add_range_tags = PatternMatcher([
])
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if len(x.ranges): return None
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
# ends of outer range don't go in kernels
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
# local kernel rewrite
lctx = LocalAddBufferContext()

View file

@ -15,11 +15,15 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto()
THREAD = auto(); OUTER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: ("O")}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta", AxisType.OUTER: "green"}
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.OUTER: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}