mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30b9c6ed26 |
2 changed files with 8 additions and 10 deletions
|
|
@@ -1,5 +1,5 @@
|
|||
import math
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
|
||||
from tinygrad.helpers import all_int
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.view import get_contraction
|
||||
|
|
@@ -53,16 +53,18 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
|
|||
def add_gpudims(ctx:Renderer, s:UOp):
|
||||
if s.arg is None: return None
|
||||
ki: KernelInfo = s.arg
|
||||
if not ki.global_dims and not ki.local_dims: return None
|
||||
global_dims = [i for i,x in enumerate(ki.axis_types) if x is AxisType.GLOBAL]
|
||||
local_dims = [i for i,x in enumerate(ki.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
|
||||
if not global_dims and not local_dims: return None
|
||||
s_topo = list(s.toposort())
|
||||
if any(x.op is Ops.SPECIAL for x in s_topo): return None
|
||||
all_ranges = {x.arg:x for x in s_topo if x.op is Ops.RANGE}
|
||||
# NOTE: this supports globals/locals in any position
|
||||
ranges = [all_ranges[r] for r in ki.global_dims+ki.local_dims]
|
||||
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.global_dims])
|
||||
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.local_dims])
|
||||
ranges = [all_ranges[r] for r in global_dims+local_dims]
|
||||
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in global_dims])
|
||||
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in local_dims])
|
||||
if ki.dont_use_locals:
|
||||
assert not ki.local_dims, "can't use locals if there's no local dims"
|
||||
assert not local_dims, "can't use locals if there's no local dims"
|
||||
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
|
||||
else:
|
||||
# define indexes for GPU-like execution
|
||||
|
|
|
|||
|
|
@@ -537,10 +537,6 @@ class KernelInfo:
|
|||
opts_to_apply: tuple|None = None
|
||||
@property
|
||||
def function_name(self): return to_function_name(self.name)
|
||||
@property
|
||||
def global_dims(self) -> list[int]: return [i for i,x in enumerate(self.axis_types) if x is AxisType.GLOBAL]
|
||||
@property
|
||||
def local_dims(self) -> list[int]: return [i for i,x in enumerate(self.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue