mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
test mnist passes
This commit is contained in:
parent
ffa08e9c94
commit
35116959ea
2 changed files with 23 additions and 14 deletions
|
|
@ -393,16 +393,17 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tensor_map = graph_rewrite_map(tensor_map[sink], rangeify_fixups, bottom_up=True, input_map=tensor_map, name="* contiguous")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="* children")
|
||||
|
||||
#tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="* rangeify")
|
||||
#tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, input_map=tensor_map, name="* buffer")
|
||||
#tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="* split kernels")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="* rangeify")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, input_map=tensor_map, name="* buffer")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="* split kernels")
|
||||
return tensor_map
|
||||
|
||||
"""
|
||||
rsink = tensor_map[sink]
|
||||
|
||||
my_ctx = []
|
||||
rsink = graph_rewrite(rsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="* rangeify")
|
||||
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, name="* buffer")
|
||||
rsink = graph_rewrite(rsink, do_debuf, ctx=my_ctx, name="* debuf")
|
||||
rsink = graph_rewrite(rsink, do_debuf, ctx=[], name="* debuf")
|
||||
"""
|
||||
#if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
|
||||
#rsink = graph_rewrite(rsink, sym, name="* symbolic")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,15 +7,15 @@ from tinygrad.helpers import argsort, prod, all_same
|
|||
rangeify_fixups = PatternMatcher([
|
||||
# all contiguous on SINK
|
||||
(UPat(Ops.SINK, name="x"),
|
||||
lambda x: x.replace(src=tuple([s.contiguous().index() if s.op not in {Ops.INDEX, Ops.CONST} else s for s in x.src]))),
|
||||
lambda x: x.replace(src=tuple([s.contiguous() if s.op not in {Ops.CONTIGUOUS, Ops.CONST} else s for s in x.src]))),
|
||||
# all contiguous on COPY
|
||||
(UPat(Ops.COPY, name="x"), lambda x: x.replace(tag=1).contiguous() if x.tag is None else None),
|
||||
# double contiguous merge
|
||||
(UPat(Ops.CONTIGUOUS, name="c2", src=(UPat(Ops.CONTIGUOUS, name="c1"))), lambda c1,c2: c1 if c1.arg is None and c2.arg is None else None),
|
||||
# const
|
||||
(UPat(Ops.CONST, name="x"), lambda x:
|
||||
x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
|
||||
len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
|
||||
#(UPat(Ops.CONST, name="x"), lambda x:
|
||||
# x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
|
||||
# len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
|
|
@ -43,6 +43,7 @@ class RangeifyContext:
|
|||
regs: int = 0
|
||||
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
|
||||
seen_child: dict[UOp, Any] = field(default_factory=dict)
|
||||
is_sink_contig: tuple[UOp, ...] = ()
|
||||
|
||||
def map_reshape(x:UOp, r:UOp):
|
||||
acc = 1
|
||||
|
|
@ -111,7 +112,7 @@ pm_mops = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
|
||||
])
|
||||
|
||||
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
|
||||
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
|
||||
if x.tag == 1: return None
|
||||
ranges = []
|
||||
new_ranges = []
|
||||
|
|
@ -121,7 +122,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
|
|||
assert idx is not None, "partial contig requires index"
|
||||
ranges.append(idx.src[1+i])
|
||||
continue
|
||||
if len(idx.src) > 1: passthrough_idx.append(idx.src[1+i])
|
||||
if idx is not None: passthrough_idx.append(idx.src[1+i])
|
||||
if resolve(s!=1):
|
||||
ranges.append(UOp.range(dtypes.int, s, ctx.idx))
|
||||
new_ranges.append(ranges[-1])
|
||||
|
|
@ -185,13 +186,20 @@ def indexed_endrange(er:UOp, idx:UOp):
|
|||
if to_end_axis: return idx.replace(src=(er.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:])
|
||||
return idx.replace(src=(er.src[0],)+idx.src[1:])
|
||||
|
||||
def get_sink_contig(ctx:RangeifyContext, s:UOp):
|
||||
ctx.is_sink_contig = [x for x in s.src if x.op is Ops.CONTIGUOUS and x.tag is None]
|
||||
|
||||
pm_rangeify = pm_mops+PatternMatcher([
|
||||
# if there are new ended children, tag the SINK
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
|
||||
|
||||
# if there's an INDEX it can support partial contig
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, name="x"),), allow_any_len=True, name="idx"), map_contiguous),
|
||||
#(UPat(Ops.CONTIGUOUS, name="x"), map_contiguous),
|
||||
|
||||
# sink contigs to kick it off
|
||||
(UPat(Ops.SINK, name="s"), get_sink_contig),
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: map_contiguous(ctx, x) if x in ctx.is_sink_contig else None),
|
||||
#(UPat(Ops.SINK, name="s"), lambda ctx,s: s.replace(src=tuple([map_contiguous(ctx,x) if x.op is Ops.CONTIGUOUS else x for x in s.src]))),
|
||||
|
||||
# handle ENDRANGE on movement
|
||||
(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Movement),), allow_any_len=True, name="er"),
|
||||
|
|
@ -246,7 +254,7 @@ def add_load_on_store(ctx:AddBufferContext, x:UOp, st:UOp):
|
|||
pm_add_buffers = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), add_store),
|
||||
(UPat(Ops.ENDRANGE, name="x"), lambda x: x.src[0]),
|
||||
#(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
|
||||
(UPat(Ops.BIND, name="b"), lambda b: b.src[0]),
|
||||
# CONST can't have axes. remove srcs when we idx
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue