Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
db8c6d9a04 work 2025-11-10 15:26:16 -08:00
George Hotz
0647f87bf8 outer range runs in the scheduler 2025-11-10 14:49:42 -08:00
5 changed files with 41 additions and 12 deletions

View file

@ -0,0 +1,18 @@
from tinygrad import Tensor, UOp
from tinygrad.uop.ops import Ops, AxisType
import unittest
# this test focuses only on transformers, using a range to iterate over the layers
class TestOuterworldTransformer(unittest.TestCase):
  """Smoke test: matmul against a weight selected by an AxisType.OUTER range, stored back into the input."""
  def test_three_mats(self):
    # a (3, 1024, 1024) stack of weights and a (1, 1024) input
    weights = Tensor.empty(3, 1024, 1024)
    inp = Tensor.empty(1, 1024)

    # range of size 3 tagged as an OUTER axis
    layer = UOp.range(3, -1, AxisType.OUTER)

    # order the input after the range, then multiply by the weight indexed with the range
    ordered_inp = Tensor(inp.uop.after(layer))
    product = ordered_inp @ weights[layer]

    # store the product into the input, end the range on that store,
    # and make the input depend on the ended store
    stored = inp.uop.store(product.uop)
    ended = stored.end(layer)
    out = Tensor(inp.uop.after(ended).contiguous())

    # no value assertions: the test only checks that realize() completes
    out.realize()
# run the tests in this module when the file is executed directly
if __name__ == "__main__":
  unittest.main()

View file

@ -2,7 +2,8 @@ from __future__ import annotations
import math, itertools import math, itertools
from collections import defaultdict from collections import defaultdict
from typing import cast, Final from typing import cast, Final
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)]) remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
class Scheduler: class Scheduler:
def __init__(self, ast:UOp, ren:Renderer): def __init__(self, ast:UOp, ren:Renderer):
self.ast, self.ren = ast, ren self.ast, self.ren = ast, ren

View file

@ -16,6 +16,9 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src: for s in rb.src:
if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None if s.base.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_store(ctx:dict[UOp, None], a:UOp) -> None:
  """Record the second source of `a` in `ctx`, unless that source's op is in ALWAYS_CONTIGUOUS.

  `ctx` is mutated in place: the source is inserted as a key with value None.
  """
  stored = a.src[1]
  # ops listed in ALWAYS_CONTIGUOUS are skipped and never added to ctx
  if stored.op in ALWAYS_CONTIGUOUS: return
  ctx[stored] = None
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
# if it's a kernel, we don't realize it # if it's a kernel, we don't realize it
@ -30,6 +33,8 @@ pm_generate_realize_map = PatternMatcher([
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# realize ASSIGN and input to assign (might be optimized out) # realize ASSIGN and input to assign (might be optimized out)
(UPat(Ops.ASSIGN, name="a"), realize_assign), (UPat(Ops.ASSIGN, name="a"), realize_assign),
# realize STORE
(UPat(Ops.STORE, name="a"), realize_store),
]) ])
@dataclass(frozen=True) @dataclass(frozen=True)
@ -50,13 +55,14 @@ class IndexingContext:
# if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0) # if a range has a 1 src, it's the same as UOp.const(dtypes.index, 0)
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0) return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.index, 0)
ops_allowed_after = (Ops.KERNEL, Ops.RANGE)
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None if x.op is Ops.AFTER and x.src[1].op in ops_allowed_after: return None
new_srcs = [] new_srcs = []
for s in x.src: for s in x.src:
new_src = s new_src = s
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL): if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op in ops_allowed_after):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map: elif s in ctx.realize_map:
realized_ranges = ctx.realize_map[s] realized_ranges = ctx.realize_map[s]
@ -176,7 +182,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# mark all ranges as ended # mark all ranges as ended
assert rctx.realize_map[x] is None assert rctx.realize_map[x] is None
rctx.realize_map[x] = list(range(len(x.shape))) rctx.realize_map[x] = list(range(len(x.shape)))
elif x.op in {Ops.MSTACK, Ops.MSELECT}: elif x.op in {Ops.MSTACK, Ops.MSELECT, Ops.END}:
# treat MSTACK/MSELECT like SINK # treat MSTACK/MSELECT/END like SINK
continue continue
elif len(consumer_rngs) == 0: elif len(consumer_rngs) == 0:

View file

@ -396,6 +396,7 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
return buf return buf
def renumber_range(ctx:LocalAddBufferContext, r:UOp): def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.arg[-1] == AxisType.OUTER: return None
if r.tag != (): return None if r.tag != (): return None
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None) ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
ctx.range += 1 ctx.range += 1
@ -469,7 +470,10 @@ pm_add_range_tags = PatternMatcher([
]) ])
def split_store(ctx:list[UOp], x:UOp) -> UOp|None: def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
if len(x.ranges): return None if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
# ends of outer range don't go in kernels
if x.op is Ops.END and x.src[1].op is Ops.RANGE and x.src[1].arg[-1] == AxisType.OUTER: return None
# local kernel rewrite # local kernel rewrite
lctx = LocalAddBufferContext() lctx = LocalAddBufferContext()

View file

@ -15,11 +15,15 @@ if TYPE_CHECKING:
class AxisType(Enum): class AxisType(Enum):
def __repr__(self): return str(self) def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702 GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto() THREAD = auto(); OUTER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u", axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"} AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: ("O")}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE", axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"} AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta", AxisType.OUTER: "green"}
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.OUTER: -2, AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2,
AxisType.UPCAST: 3, AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1} range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}