mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
revert-118
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f8526d2b2 |
2 changed files with 11 additions and 6 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import math
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, ssimplify, AxisType
|
||||
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
|
||||
from tinygrad.helpers import all_int, partition, flatten, prod, dedup
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.view import get_contraction
|
||||
|
|
@ -68,8 +68,14 @@ def add_gpudims(ctx:Renderer, s:UOp):
|
|||
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims])
|
||||
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims])
|
||||
|
||||
# define indexes for GPU-like execution
|
||||
idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
|
||||
# get the idxs
|
||||
ki: KernelInfo = s.arg
|
||||
if ki.dont_use_locals:
|
||||
assert not local_dims, "can't use locals if there's no local dims"
|
||||
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
|
||||
else:
|
||||
# define indexes for GPU-like execution
|
||||
idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
|
||||
|
||||
# apply to multiple ranges
|
||||
subs = {}
|
||||
|
|
@ -77,7 +83,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
|
|||
if r.op is not Ops.RANGE: continue
|
||||
try:
|
||||
ii = (global_dims+local_dims).index(r.arg[0]%1000)
|
||||
if r.arg[0] < 2000 and r.arg[1] == AxisType.GROUP_REDUCE: continue
|
||||
if r.arg[0] < 2000 and ki.axis_types[r.arg[0]%1000] == AxisType.GROUP_REDUCE: continue
|
||||
subs[r] = idxs[ii]
|
||||
except ValueError: continue
|
||||
return s.substitute(subs)
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ def shape_to_idx(s, axis_types, start=0):
|
|||
|
||||
def get_index(ast:UOp) -> IndexContext:
|
||||
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
|
||||
if len(ast.full_shape) != len(axis_types):
|
||||
axis_types = tuple([AxisType.REDUCE if s is not fs else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
|
||||
if len(ast.full_shape) != len(axis_types): axis_types = (AxisType.LOOP,)*len(ast.full_shape)
|
||||
return IndexContext(axis_types, [], 0)
|
||||
|
||||
# ***** lowering (given index) *****
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue