remove range split support

This commit is contained in:
George Hotz 2025-03-31 12:43:21 +08:00
commit 273dde69bd
3 changed files with 1 additions and 17 deletions

View file

@ -639,13 +639,8 @@ class Kernel:
# otherwise we just replace the VIEW source
return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
if op.op is Ops.SINK:
range_split_axis = ()
if self.full_shape[self.first_reduce:self.first_reduce+3] == (3,3,32) and self.full_shape[-1] != 7:
#range_split_axis = (self.first_reduce-2, self.first_reduce-1)
#range_split_axis = (self.first_reduce-1,)
pass
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
self.local_dims, self.upcasted, self.dont_use_locals, range_split_axis))
self.local_dims, self.upcasted, self.dont_use_locals))
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2

View file

@ -101,16 +101,6 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
assert isinstance(g, int), "needs to be int to upcast/unroll"
idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
# range splitting
for i in ki.range_split_axis:
rng = idxs[i]
#rngv = UOp(Ops.VECTORIZE, rng.dtype.vec(2),
# (rng.const_like(rng.src[0]), rng.replace(src=(rng.src[0]+1, rng.src[1]))))
#idxs[i] = UOp(Ops.UNROLL, rng.dtype, (rngv,), ((0, 2),),)
rngv = UOp(Ops.VECTORIZE, rng.dtype.vec(3),
(rng.const_like(rng.src[0]), rng.replace(src=(rng.src[0]+1, rng.src[1]-1)), rng.const_like(rng.src[1]-1)))
idxs[i] = UOp(Ops.UNROLL, rng.dtype, (rngv,), ((rng.arg, 3),),)
# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):

View file

@ -661,7 +661,6 @@ class KernelInfo:
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
dont_use_locals: bool = False # don't use local indexing
range_split_axis: tuple[int, ...] = ()
# ******** ops in python ********