Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
30b9c6ed26 global local dims in gpudims [pr] 2025-07-29 10:31:39 -07:00
2 changed files with 8 additions and 10 deletions

View file

@@ -1,5 +1,5 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
from tinygrad.helpers import all_int
from tinygrad.dtype import dtypes
from tinygrad.shape.view import get_contraction
@@ -53,16 +53,18 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
def add_gpudims(ctx:Renderer, s:UOp):
if s.arg is None: return None
ki: KernelInfo = s.arg
if not ki.global_dims and not ki.local_dims: return None
global_dims = [i for i,x in enumerate(ki.axis_types) if x is AxisType.GLOBAL]
local_dims = [i for i,x in enumerate(ki.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
if not global_dims and not local_dims: return None
s_topo = list(s.toposort())
if any(x.op is Ops.SPECIAL for x in s_topo): return None
all_ranges = {x.arg:x for x in s_topo if x.op is Ops.RANGE}
# NOTE: this supports globals/locals in any position
ranges = [all_ranges[r] for r in ki.global_dims+ki.local_dims]
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.global_dims])
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.local_dims])
ranges = [all_ranges[r] for r in global_dims+local_dims]
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in global_dims])
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in local_dims])
if ki.dont_use_locals:
assert not ki.local_dims, "can't use locals if there's no local dims"
assert not local_dims, "can't use locals if there's no local dims"
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
else:
# define indexes for GPU-like execution

View file

@@ -537,10 +537,6 @@ class KernelInfo:
opts_to_apply: tuple|None = None
@property
def function_name(self): return to_function_name(self.name)
@property
def global_dims(self) -> list[int]: return [i for i,x in enumerate(self.axis_types) if x is AxisType.GLOBAL]
@property
def local_dims(self) -> list[int]: return [i for i,x in enumerate(self.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
# ******** ops in python ********