mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
global/locals from AxisType in range
This commit is contained in:
parent
a75da49951
commit
5ed146b520
1 changed files with 10 additions and 6 deletions
|
|
@@ -1,6 +1,6 @@
|
|||
import math
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
|
||||
from tinygrad.helpers import all_int, partition, flatten, prod
|
||||
from tinygrad.helpers import all_int, partition, flatten, prod, dedup
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.view import get_contraction
|
||||
from tinygrad.renderer import Renderer
|
||||
|
|
@@ -52,20 +52,24 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
|
|||
|
||||
def add_gpudims(ctx:Renderer, s:UOp):
|
||||
if s.arg is None: return None
|
||||
ki: KernelInfo = s.arg
|
||||
global_dims = [i for i,x in enumerate(ki.axis_types) if x is AxisType.GLOBAL]
|
||||
local_dims = [i for i,x in enumerate(ki.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
|
||||
if not global_dims and not local_dims: return None
|
||||
s_topo = list(s.toposort())
|
||||
if any(x.op is Ops.SPECIAL for x in s_topo): return None
|
||||
|
||||
# get global and local shape
|
||||
# get ranges
|
||||
all_ranges = {x.arg[0]%1000:x for x in s_topo if x.op is Ops.RANGE}
|
||||
|
||||
# extract global/local dims
|
||||
global_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] is AxisType.GLOBAL]))
|
||||
local_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
|
||||
if not global_dims and not local_dims: return None
|
||||
|
||||
# get global and local shape
|
||||
ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
|
||||
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims])
|
||||
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims])
|
||||
|
||||
# get the idxs
|
||||
ki: KernelInfo = s.arg
|
||||
if ki.dont_use_locals:
|
||||
assert not local_dims, "can't use locals if there's no local dims"
|
||||
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue