global/locals from AxisType in range

This commit is contained in:
George Hotz 2025-08-23 15:33:39 -07:00
commit 5ed146b520

View file

@ -1,6 +1,6 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
from tinygrad.helpers import all_int, partition, flatten, prod
from tinygrad.helpers import all_int, partition, flatten, prod, dedup
from tinygrad.dtype import dtypes
from tinygrad.shape.view import get_contraction
from tinygrad.renderer import Renderer
@ -52,20 +52,24 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
def add_gpudims(ctx:Renderer, s:UOp):
if s.arg is None: return None
ki: KernelInfo = s.arg
global_dims = [i for i,x in enumerate(ki.axis_types) if x is AxisType.GLOBAL]
local_dims = [i for i,x in enumerate(ki.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
if not global_dims and not local_dims: return None
s_topo = list(s.toposort())
if any(x.op is Ops.SPECIAL for x in s_topo): return None
# get global and local shape
# get ranges
all_ranges = {x.arg[0]%1000:x for x in s_topo if x.op is Ops.RANGE}
# extract global/local dims
global_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] is AxisType.GLOBAL]))
local_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
if not global_dims and not local_dims: return None
# get global and local shape
ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims])
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims])
# get the idxs
ki: KernelInfo = s.arg
if ki.dont_use_locals:
assert not local_dims, "can't use locals if there's no local dims"
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)