mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove range split support
This commit is contained in:
parent
a64030d8c8
commit
273dde69bd
3 changed files with 1 additions and 17 deletions
|
|
@ -639,13 +639,8 @@ class Kernel:
|
|||
# otherwise we just replace the VIEW source
|
||||
return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
|
||||
if op.op is Ops.SINK:
|
||||
range_split_axis = ()
|
||||
if self.full_shape[self.first_reduce:self.first_reduce+3] == (3,3,32) and self.full_shape[-1] != 7:
|
||||
#range_split_axis = (self.first_reduce-2, self.first_reduce-1)
|
||||
#range_split_axis = (self.first_reduce-1,)
|
||||
pass
|
||||
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
|
||||
self.local_dims, self.upcasted, self.dont_use_locals, range_split_axis))
|
||||
self.local_dims, self.upcasted, self.dont_use_locals))
|
||||
if op.op is Ops.REDUCE_AXIS:
|
||||
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
|
||||
|
||||
|
|
|
|||
|
|
@ -101,16 +101,6 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
|||
assert isinstance(g, int), "needs to be int to upcast/unroll"
|
||||
idxs.append(UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
|
||||
|
||||
# range splitting
|
||||
for i in ki.range_split_axis:
|
||||
rng = idxs[i]
|
||||
#rngv = UOp(Ops.VECTORIZE, rng.dtype.vec(2),
|
||||
# (rng.const_like(rng.src[0]), rng.replace(src=(rng.src[0]+1, rng.src[1]))))
|
||||
#idxs[i] = UOp(Ops.UNROLL, rng.dtype, (rngv,), ((0, 2),),)
|
||||
rngv = UOp(Ops.VECTORIZE, rng.dtype.vec(3),
|
||||
(rng.const_like(rng.src[0]), rng.replace(src=(rng.src[0]+1, rng.src[1]-1)), rng.const_like(rng.src[1]-1)))
|
||||
idxs[i] = UOp(Ops.UNROLL, rng.dtype, (rngv,), ((rng.arg, 3),),)
|
||||
|
||||
# late indexes (group for reduce)
|
||||
ridxs = idxs[:]
|
||||
for a in range(first_reduce, first_reduce+group_for_reduces):
|
||||
|
|
|
|||
|
|
@ -661,7 +661,6 @@ class KernelInfo:
|
|||
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
|
||||
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
|
||||
dont_use_locals: bool = False # don't use local indexing
|
||||
range_split_axis: tuple[int, ...] = ()
|
||||
|
||||
# ******** ops in python ********
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue