mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
min_outer_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54ab6aa247 |
4 changed files with 11 additions and 9 deletions
0
test/test_vmap.py
Normal file
0
test/test_vmap.py
Normal file
|
|
@ -2,7 +2,8 @@ from __future__ import annotations
|
||||||
import math, itertools
|
import math, itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import cast, Final
|
from typing import cast, Final
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, axis_letters, axis_colors
|
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp
|
||||||
|
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||||
from tinygrad.device import Buffer
|
from tinygrad.device import Buffer
|
||||||
from tinygrad.dtype import dtypes, ImageDType
|
from tinygrad.dtype import dtypes, ImageDType
|
||||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||||
|
|
@ -12,10 +13,6 @@ from tinygrad.renderer import Renderer
|
||||||
|
|
||||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||||
|
|
||||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
|
||||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
|
||||||
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
def __init__(self, ast:UOp, ren:Renderer):
|
def __init__(self, ast:UOp, ren:Renderer):
|
||||||
self.ast, self.ren = ast, ren
|
self.ast, self.ren = ast, ren
|
||||||
|
|
|
||||||
|
|
@ -469,7 +469,7 @@ pm_add_range_tags = PatternMatcher([
|
||||||
])
|
])
|
||||||
|
|
||||||
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||||
if len(x.ranges): return None
|
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
|
||||||
|
|
||||||
# local kernel rewrite
|
# local kernel rewrite
|
||||||
lctx = LocalAddBufferContext()
|
lctx = LocalAddBufferContext()
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,16 @@ if TYPE_CHECKING:
|
||||||
class AxisType(Enum):
|
class AxisType(Enum):
|
||||||
def __repr__(self): return str(self)
|
def __repr__(self): return str(self)
|
||||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||||
THREAD = auto()
|
THREAD = auto(); OUTER = auto() # noqa: E702
|
||||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r"}
|
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
|
||||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
|
||||||
|
AxisType.OUTER: "green"}
|
||||||
|
|
||||||
|
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||||
|
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||||
|
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5, AxisType.OUTER: -2}
|
||||||
|
|
||||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue