Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
5f8526d2b2
Revert "remove KernelInfo from gpudims (#11809)"
This reverts commit 846753f343.
2025-08-23 19:37:00 -07:00
2 changed files with 11 additions and 6 deletions

View file

@ -1,5 +1,5 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, ssimplify, AxisType
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType
from tinygrad.helpers import all_int, partition, flatten, prod, dedup
from tinygrad.dtype import dtypes
from tinygrad.shape.view import get_contraction
@ -68,8 +68,14 @@ def add_gpudims(ctx:Renderer, s:UOp):
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in global_dims])
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg[0]%1000 in local_dims])
# define indexes for GPU-like execution
idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
# get the idxs
ki: KernelInfo = s.arg
if ki.dont_use_locals:
assert not local_dims, "can't use locals if there's no local dims"
idxs = get_grouped_dims("idx", global_shape, ctx.global_max, reverse=True)
else:
# define indexes for GPU-like execution
idxs = get_grouped_dims("gidx", global_shape, ctx.global_max, reverse=True) + get_grouped_dims("lidx", local_shape, ctx.local_max)
# apply to multiple ranges
subs = {}
@ -77,7 +83,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
if r.op is not Ops.RANGE: continue
try:
ii = (global_dims+local_dims).index(r.arg[0]%1000)
if r.arg[0] < 2000 and r.arg[1] == AxisType.GROUP_REDUCE: continue
if r.arg[0] < 2000 and ki.axis_types[r.arg[0]%1000] == AxisType.GROUP_REDUCE: continue
subs[r] = idxs[ii]
except ValueError: continue
return s.substitute(subs)

View file

@ -18,8 +18,7 @@ def shape_to_idx(s, axis_types, start=0):
def get_index(ast:UOp) -> IndexContext:
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
if len(ast.full_shape) != len(axis_types):
axis_types = tuple([AxisType.REDUCE if s is not fs else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
if len(ast.full_shape) != len(axis_types): axis_types = (AxisType.LOOP,)*len(ast.full_shape)
return IndexContext(axis_types, [], 0)
# ***** lowering (given index) *****