mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
new endrange solution
This commit is contained in:
parent
801712880e
commit
4a6c2e5d68
1 changed files with 28 additions and 38 deletions
|
|
@ -84,10 +84,10 @@ class RangeifyContext:
|
|||
seen_child: dict[UOp, Any] = field(default_factory=dict)
|
||||
progress_children: dict[UOp, int] = field(default_factory=dict)
|
||||
|
||||
def map_reshape(x:UOp, r:UOp):
|
||||
def map_reshape(idx:UOp, r:UOp):
|
||||
acc = 1
|
||||
to_sum = []
|
||||
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
|
||||
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
|
||||
to_sum.append(acc*src)
|
||||
acc *= s
|
||||
mish = sum(to_sum)
|
||||
|
|
@ -100,10 +100,10 @@ def map_reshape(x:UOp, r:UOp):
|
|||
else:
|
||||
ret.append(UOp.const(dtypes.int, 0))
|
||||
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
|
||||
return r.src[0].index(*ret, dtype=x.dtype)
|
||||
return r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg)
|
||||
|
||||
def map_pad(x:UOp, r:UOp):
|
||||
ret = list(x.src[1:])
|
||||
def map_pad(idx:UOp, r:UOp):
|
||||
ret = list(idx.src[1:])
|
||||
bigwhere = UOp.const(dtypes.bool, True)
|
||||
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
|
||||
if s == 0 and e == 0: continue
|
||||
|
|
@ -114,13 +114,13 @@ def map_pad(x:UOp, r:UOp):
|
|||
# this is safe but dumb
|
||||
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
|
||||
# PAD is with 0
|
||||
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
|
||||
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
|
||||
|
||||
def map_expand(r:UOp, x:UOp):
|
||||
def map_expand(r:UOp, idx:UOp):
|
||||
new_rngs = []
|
||||
ending_ranges = []
|
||||
non_ending_ranges = []
|
||||
for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape):
|
||||
for a,x,y in zip(idx.src[1:], r.src[0].shape, r.shape):
|
||||
axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE]
|
||||
if resolve(x!=y, False):
|
||||
ending_ranges.extend(axis_to_range)
|
||||
|
|
@ -128,25 +128,24 @@ def map_expand(r:UOp, x:UOp):
|
|||
else:
|
||||
non_ending_ranges.extend(axis_to_range)
|
||||
new_rngs.append(a)
|
||||
ending_ranges = [x for x in ending_ranges if x not in non_ending_ranges]
|
||||
ret = r.src[0]
|
||||
ret = UOp(Ops.ENDRANGE, dtype=ret.dtype, src=(ret,)+tuple(ending_ranges)) if len(ending_ranges) else ret
|
||||
return ret.index(*new_rngs)
|
||||
ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
|
||||
if idx.arg is not None: ending_ranges.append(idx.arg)
|
||||
return r.src[0].index(*new_rngs, arg=min([x for x in ending_ranges]) if ending_ranges else None)
|
||||
|
||||
pm_mops = PatternMatcher([
|
||||
# this is like the definitions of these
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)], dtype=x.dtype)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
|
||||
# expand needs to end ranges
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"), map_expand),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand),
|
||||
# reshape does a lot of symbolic stuff
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape),
|
||||
# pad adds min and max
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad),
|
||||
])
|
||||
|
||||
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
|
||||
|
|
@ -213,15 +212,14 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
|
|||
if len(idx_ranges) == 0: return c.index(*out_rngs)
|
||||
return c.index(*out_rngs).bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
|
||||
def indexed_endrange(er:UOp, idx:UOp):
|
||||
ended = er.src[1:]
|
||||
earliest_ending_axis = min([x.arg for x in ended])
|
||||
def might_end_axis(idx:UOp):
|
||||
if idx.arg is None: return None
|
||||
to_end_axis = []
|
||||
for i,a in enumerate(idx.src[1:]):
|
||||
if any(x.arg > earliest_ending_axis for x in a.toposort() if x.op is Ops.RANGE):
|
||||
if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
|
||||
to_end_axis.append(i)
|
||||
if to_end_axis: return idx.replace(src=(er.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:])
|
||||
return idx.replace(src=(er.src[0],)+idx.src[1:])
|
||||
if to_end_axis: return idx.replace(src=(idx.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
|
||||
return idx.replace(arg=None)
|
||||
|
||||
pm_rangeify = pm_mops+PatternMatcher([
|
||||
# if there are new ended children, tag the SINK
|
||||
|
|
@ -236,17 +234,8 @@ pm_rangeify = pm_mops+PatternMatcher([
|
|||
# sink contigs to kick it off
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"), lambda ctx,x: map_contiguous(ctx, x)),
|
||||
|
||||
# handle ENDRANGE on ENDRANGE
|
||||
(UPat(Ops.ENDRANGE, src=(UPat(Ops.ENDRANGE, name="e1"),), allow_any_len=True, name="e2"), lambda e1,e2: e1.replace(src=e1.src+e2.src[1:])),
|
||||
# handle ENDRANGE on movement
|
||||
(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Movement),), allow_any_len=True, name="er"),
|
||||
lambda er: er.src[0].replace(src=(UOp(Ops.ENDRANGE, dtype=er.dtype, src=(er.src[0].src[0],)+er.src[1:]),))),
|
||||
# handle ENDRANGE on BUFFER
|
||||
# and CHILD: python3 test/test_schedule.py TestSchedule.test_cache_reduce_parent
|
||||
(UPat(Ops.ENDRANGE, src=(UPat({Ops.BUFFER, Ops.CONST, Ops.BUFFERIZE, Ops.CHILD}),), allow_any_len=True, name="er"), lambda er: er.src[0]),
|
||||
# handle INDEXed ENDRANGE
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="er"),),
|
||||
allow_any_len=True, name="idx"), indexed_endrange),
|
||||
# handle arg on any op with weight. old endrange stuff
|
||||
(UPat(Ops.INDEX,src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
|
||||
|
||||
# move MAP through elementwise ALU / reduce. these are the items with cost
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
|
||||
|
|
@ -379,6 +368,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
src = CStyleLanguage().render(rsink.arg.lst)
|
||||
print(src)
|
||||
return {sink:sink}
|
||||
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="split kernels")
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue