mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ed146b520 |
1 changed file with 10 additions and 6 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
import math
|
import math
|
||||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
|
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
|
||||||
from tinygrad.helpers import all_int, partition, flatten, prod
|
from tinygrad.helpers import all_int, partition, flatten, prod, dedup
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.shape.view import get_contraction
|
from tinygrad.shape.view import get_contraction
|
||||||
from tinygrad.renderer import Renderer
|
from tinygrad.renderer import Renderer
|
||||||
|
|
@ -52,20 +52,24 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
|
||||||
|
|
||||||
def add_gpudims(ctx:Renderer, s:UOp):
|
def add_gpudims(ctx:Renderer, s:UOp):
|
||||||
if s.arg is None: return None
|
if s.arg is None: return None
|
||||||
ki: KernelInfo = s.arg
|
|
||||||
global_dims = [i for i,x in enumerate(ki.axis_types) if x is AxisType.GLOBAL]
|
|
||||||
local_dims = [i for i,x in enumerate(ki.axis_types) if x in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]
|
|
||||||
if not global_dims and not local_dims: return None
|
|
||||||
s_topo = list(s.toposort())
|
s_topo = list(s.toposort())
|
||||||
if any(x.op is Ops.SPECIAL for x in s_topo): return None
|
if any(x.op is Ops.SPECIAL for x in s_topo): return None
|
||||||
|
|
||||||
# get global and local shape
|
# get ranges
|
||||||
all_ranges = {x.arg[0]%1000:x for x in s_topo if x.op is Ops.RANGE}
|
all_ranges = {x.arg[0]%1000:x for x in s_topo if x.op is Ops.RANGE}
|
||||||
|
|
||||||
|
# extract global/local dims
|
||||||
|
global_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] is AxisType.GLOBAL]))
|
||||||
|
local_dims = sorted(dedup([x.arg[0]%1000 for x in all_ranges.values() if x.arg[1] in (AxisType.LOCAL, AxisType.GROUP_REDUCE)]))
|
||||||
|
if not global_dims and not local_dims: return None
|
||||||
|
|
||||||
|
# get global and local shape
|
||||||
ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
|
ranges = [all_ranges[r] for r in global_dims+local_dims if r in all_ranges]
|
||||||
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims])
|
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims])
|
||||||
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims])
|
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims])
|
||||||
|
|
||||||
# get the idxs
|
# get the idxs
|
||||||
|
ki: KernelInfo = s.arg
|
||||||
if ki.dont_use_locals:
|
if ki.dont_use_locals:
|
||||||
assert not local_dims, "can't use locals if there's no local dims"
|
assert not local_dims, "can't use locals if there's no local dims"
|
||||||
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
|
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue