test mnist passes

This commit is contained in:
George Hotz 2025-08-14 16:24:32 -07:00
commit 35116959ea
2 changed files with 23 additions and 14 deletions

View file

@ -393,16 +393,17 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(tensor_map[sink], rangeify_fixups, bottom_up=True, input_map=tensor_map, name="* contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="* children")
#tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="* rangeify")
#tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, input_map=tensor_map, name="* buffer")
#tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="* split kernels")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="* rangeify")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, input_map=tensor_map, name="* buffer")
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="* split kernels")
return tensor_map
"""
rsink = tensor_map[sink]
my_ctx = []
rsink = graph_rewrite(rsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="* rangeify")
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, name="* buffer")
rsink = graph_rewrite(rsink, do_debuf, ctx=my_ctx, name="* debuf")
rsink = graph_rewrite(rsink, do_debuf, ctx=[], name="* debuf")
"""
#if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
#rsink = graph_rewrite(rsink, sym, name="* symbolic")

View file

@ -7,15 +7,15 @@ from tinygrad.helpers import argsort, prod, all_same
rangeify_fixups = PatternMatcher([
# all contiguous on SINK
(UPat(Ops.SINK, name="x"),
lambda x: x.replace(src=tuple([s.contiguous().index() if s.op not in {Ops.INDEX, Ops.CONST} else s for s in x.src]))),
lambda x: x.replace(src=tuple([s.contiguous() if s.op not in {Ops.CONTIGUOUS, Ops.CONST} else s for s in x.src]))),
# all contiguous on COPY
(UPat(Ops.COPY, name="x"), lambda x: x.replace(tag=1).contiguous() if x.tag is None else None),
# double contiguous merge
(UPat(Ops.CONTIGUOUS, name="c2", src=(UPat(Ops.CONTIGUOUS, name="c1"))), lambda c1,c2: c1 if c1.arg is None and c2.arg is None else None),
# const
(UPat(Ops.CONST, name="x"), lambda x:
x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
#(UPat(Ops.CONST, name="x"), lambda x:
# x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
# len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
])
@dataclass
@ -43,6 +43,7 @@ class RangeifyContext:
regs: int = 0
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
seen_child: dict[UOp, Any] = field(default_factory=dict)
is_sink_contig: tuple[UOp, ...] = ()
def map_reshape(x:UOp, r:UOp):
acc = 1
@ -111,7 +112,7 @@ pm_mops = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
])
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
if x.tag == 1: return None
ranges = []
new_ranges = []
@ -121,7 +122,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp):
assert idx is not None, "partial contig requires index"
ranges.append(idx.src[1+i])
continue
if len(idx.src) > 1: passthrough_idx.append(idx.src[1+i])
if idx is not None: passthrough_idx.append(idx.src[1+i])
if resolve(s!=1):
ranges.append(UOp.range(dtypes.int, s, ctx.idx))
new_ranges.append(ranges[-1])
@ -185,13 +186,20 @@ def indexed_endrange(er:UOp, idx:UOp):
if to_end_axis: return idx.replace(src=(er.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:])
return idx.replace(src=(er.src[0],)+idx.src[1:])
def get_sink_contig(ctx:RangeifyContext, s:UOp) -> None:
  """Record on `ctx` which sources of the SINK `s` are untagged CONTIGUOUS ops.

  The result is stored in `ctx.is_sink_contig`; the CONTIGUOUS rule in
  `pm_rangeify` checks membership in it before calling `map_contiguous`
  without an index.

  Args:
    ctx: rangeify context whose `is_sink_contig` field is overwritten.
    s: the matched SINK; only its direct `src` entries are inspected.

  Returns:
    Always None. NOTE(review): presumably the pattern matcher treats None as
    "no rewrite", as the other rules in this file do -- confirm.
  """
  # Stored as a tuple so the value matches the field's declared type
  # (`tuple[UOp, ...]`, default `()`) instead of silently becoming a list.
  ctx.is_sink_contig = tuple(x for x in s.src if x.op is Ops.CONTIGUOUS and x.tag is None)
pm_rangeify = pm_mops+PatternMatcher([
# if there are new ended children, tag the SINK
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
# if there's an INDEX it can support partial contig
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, name="x"),), allow_any_len=True, name="idx"), map_contiguous),
#(UPat(Ops.CONTIGUOUS, name="x"), map_contiguous),
# sink contigs to kick it off
(UPat(Ops.SINK, name="s"), get_sink_contig),
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: map_contiguous(ctx, x) if x in ctx.is_sink_contig else None),
#(UPat(Ops.SINK, name="s"), lambda ctx,s: s.replace(src=tuple([map_contiguous(ctx,x) if x.op is Ops.CONTIGUOUS else x for x in s.src]))),
# handle ENDRANGE on movement
(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Movement),), allow_any_len=True, name="er"),
@ -246,7 +254,7 @@ def add_load_on_store(ctx:AddBufferContext, x:UOp, st:UOp):
pm_add_buffers = pm_mops+PatternMatcher([
(UPat(Ops.CONTIGUOUS, name="x"), add_store),
(UPat(Ops.ENDRANGE, name="x"), lambda x: x.src[0]),
#(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
(UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
(UPat(Ops.BIND, name="b"), lambda b: b.src[0]),
# CONST can't have axes. remove srcs when we idx