new endrange solution

This commit is contained in:
George Hotz 2025-08-15 15:44:31 -07:00
commit 4a6c2e5d68

View file

@ -84,10 +84,10 @@ class RangeifyContext:
seen_child: dict[UOp, Any] = field(default_factory=dict)
progress_children: dict[UOp, int] = field(default_factory=dict)
def map_reshape(x:UOp, r:UOp):
def map_reshape(idx:UOp, r:UOp):
acc = 1
to_sum = []
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
to_sum.append(acc*src)
acc *= s
mish = sum(to_sum)
@ -100,10 +100,10 @@ def map_reshape(x:UOp, r:UOp):
else:
ret.append(UOp.const(dtypes.int, 0))
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
return r.src[0].index(*ret, dtype=x.dtype)
return r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg)
def map_pad(x:UOp, r:UOp):
ret = list(x.src[1:])
def map_pad(idx:UOp, r:UOp):
ret = list(idx.src[1:])
bigwhere = UOp.const(dtypes.bool, True)
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
if s == 0 and e == 0: continue
@ -114,13 +114,13 @@ def map_pad(x:UOp, r:UOp):
# this is safe but dumb
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
# PAD is with 0
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
def map_expand(r:UOp, x:UOp):
def map_expand(r:UOp, idx:UOp):
new_rngs = []
ending_ranges = []
non_ending_ranges = []
for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape):
for a,x,y in zip(idx.src[1:], r.src[0].shape, r.shape):
axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE]
if resolve(x!=y, False):
ending_ranges.extend(axis_to_range)
@ -128,25 +128,24 @@ def map_expand(r:UOp, x:UOp):
else:
non_ending_ranges.extend(axis_to_range)
new_rngs.append(a)
ending_ranges = [x for x in ending_ranges if x not in non_ending_ranges]
ret = r.src[0]
ret = UOp(Ops.ENDRANGE, dtype=ret.dtype, src=(ret,)+tuple(ending_ranges)) if len(ending_ranges) else ret
return ret.index(*new_rngs)
ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
if idx.arg is not None: ending_ranges.append(idx.arg)
return r.src[0].index(*new_rngs, arg=min([x for x in ending_ranges]) if ending_ranges else None)
pm_mops = PatternMatcher([
# this is like the definitions of these
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)], dtype=x.dtype)),
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)])),
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
# expand needs to end ranges
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"), map_expand),
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand),
# reshape does a lot of symbolic stuff
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape),
# pad adds min and max
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad),
])
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
@ -213,15 +212,14 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
if len(idx_ranges) == 0: return c.index(*out_rngs)
return c.index(*out_rngs).bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
def indexed_endrange(er:UOp, idx:UOp):
ended = er.src[1:]
earliest_ending_axis = min([x.arg for x in ended])
def might_end_axis(idx:UOp):
if idx.arg is None: return None
to_end_axis = []
for i,a in enumerate(idx.src[1:]):
if any(x.arg > earliest_ending_axis for x in a.toposort() if x.op is Ops.RANGE):
if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
to_end_axis.append(i)
if to_end_axis: return idx.replace(src=(er.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:])
return idx.replace(src=(er.src[0],)+idx.src[1:])
if to_end_axis: return idx.replace(src=(idx.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
return idx.replace(arg=None)
pm_rangeify = pm_mops+PatternMatcher([
# if there are new ended children, tag the SINK
@ -236,17 +234,8 @@ pm_rangeify = pm_mops+PatternMatcher([
# sink contigs to kick it off
(UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"), lambda ctx,x: map_contiguous(ctx, x)),
# handle ENDRANGE on ENDRANGE
(UPat(Ops.ENDRANGE, src=(UPat(Ops.ENDRANGE, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
# handle ENDRANGE on movement
(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Movement),), allow_any_len=True, name="er"),
lambda er: er.src[0].replace(src=(UOp(Ops.ENDRANGE, dtype=er.dtype, src=(er.src[0].src[0],)+er.src[1:]),))),
# handle ENDRANGE on BUFFER
# and CHILD: python3 test/test_schedule.py TestSchedule.test_cache_reduce_parent
(UPat(Ops.ENDRANGE, src=(UPat({Ops.BUFFER, Ops.CONST, Ops.BUFFERIZE, Ops.CHILD}),), allow_any_len=True, name="er"), lambda er: er.src[0]),
# handle INDEXed ENDRANGE
(UPat(Ops.INDEX, src=(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="er"),),
allow_any_len=True, name="idx"), indexed_endrange),
# handle arg on any op with weight. old endrange stuff
(UPat(Ops.INDEX,src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
@ -379,6 +368,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
from tinygrad.renderer.cstyle import CStyleLanguage
src = CStyleLanguage().render(rsink.arg.lst)
print(src)
return {sink:sink}
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="split kernels")
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")