extend to 128

This commit is contained in:
George Hotz 2025-03-27 10:35:06 +08:00
commit 38488ec3b0
2 changed files with 6 additions and 1 deletions

View file

@ -46,6 +46,10 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
offsets[i+3] = {}
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
for grp in grouped_offsets:
if len(grp) == 32:
for jj in range(grp[0]+32, grp[0]+128):
grp.append(jj)
offsets[jj] = []
# get the index offset for this element. using [0] is okay, because they are the same
lidx = midx.src[offsets[grp[0]][0]]
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))

View file

@ -588,7 +588,8 @@ class Kernel:
range_split_axis = ()
if self.full_shape[self.first_reduce:self.first_reduce+3] == (3,3,32) and self.full_shape[-1] != 7:
#range_split_axis = (self.first_reduce-2, self.first_reduce-1)
range_split_axis = (self.first_reduce-1,)
#range_split_axis = (self.first_reduce-1,)
pass
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
self.local_dims, self.upcasted, self.dont_use_locals, range_split_axis))
if op.op is Ops.REDUCE_AXIS: